In [1]:
!pip install gdown srsly transformers



In [2]:
import os
import pathlib
import random
import shutil
import tempfile
from datetime import datetime

import gdown
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import srsly
import torch
from sklearn.metrics import classification_report
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedTokenizer
)
from tqdm import tqdm

In [3]:
# set seed -- seed every random number generator the notebook has imported
# so that shuffling and initialisation are repeatable across runs
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# kept as the last expression so the cell still displays the torch generator
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f68dce61250>

In [4]:
# create a throw-away workspace with one sub-folder per kind of artefact
dirname = pathlib.Path(tempfile.mkdtemp())

data_directory, model_directory, metric_directory = (
    dirname / sub_folder for sub_folder in ('data', 'model', 'metric')
)

# the workspace is brand new, so none of the sub-folders exists yet
for directory in (data_directory, model_directory, metric_directory):
    directory.mkdir()

In [5]:
# Google Drive links: the three dataset splits followed by the label metadata
TRAIN_DATA_URL = 'https://drive.google.com/uc?id=1deNCsmS9IlOFquGZ4NMQ_y8It-WHuJlT'
DEV_DATA_URL = 'https://drive.google.com/uc?id=1bZQ26XNFswld8pkAO2JHLWybfOPhj-EN'
TEST_DATA_URL = 'https://drive.google.com/uc?id=1tD37fUooeygRlP7TYLEjzHz9wTXhSUyO'
LABEL_MAP_URL = 'https://drive.google.com/uc?id=1S9yaAVqzfaEV5mV5VZvbYBDp4AxPiSK8'
LABEL_WEIGHTS_URL = 'https://drive.google.com/uc?id=1zxrS94Kh3mCQupTCtlF6rGpvnVZo1KvY'

# local destination of each download, all inside the data directory
train_data_path = data_directory.joinpath('train.jsonl')
dev_data_path = data_directory.joinpath('dev.jsonl')
test_data_path = data_directory.joinpath('test.jsonl')
label_map_path = data_directory.joinpath('label-map.json')
label_weights_path = data_directory.joinpath('label-weights.json')

In [6]:
# download the data from Google drive
# each entry pairs a source URL with the local path it is written to
splits = [(TRAIN_DATA_URL, train_data_path),
          (DEV_DATA_URL, dev_data_path),
          (TEST_DATA_URL, test_data_path),
          (LABEL_MAP_URL, label_map_path),
          (LABEL_WEIGHTS_URL, label_weights_path)]

for url, file in splits:
    gdown.download(url=url, output=str(file), quiet=False)
    # fail fast: every later cell reads these files, so stop here with a clear
    # message if a download did not produce its target file
    if not file.is_file():
        raise RuntimeError(f'download failed: {url} -> {file}')

Downloading...
From: https://drive.google.com/uc?id=1deNCsmS9IlOFquGZ4NMQ_y8It-WHuJlT
To: /tmp/tmplxllqcdc/data/train.jsonl
100%|██████████| 678M/678M [00:02<00:00, 265MB/s]
Downloading...
From: https://drive.google.com/uc?id=1bZQ26XNFswld8pkAO2JHLWybfOPhj-EN
To: /tmp/tmplxllqcdc/data/dev.jsonl
100%|██████████| 78.8M/78.8M [00:00<00:00, 147MB/s]
Downloading...
From: https://drive.google.com/uc?id=1tD37fUooeygRlP7TYLEjzHz9wTXhSUyO
To: /tmp/tmplxllqcdc/data/test.jsonl
100%|██████████| 79.0M/79.0M [00:00<00:00, 162MB/s]
Downloading...
From: https://drive.google.com/uc?id=1S9yaAVqzfaEV5mV5VZvbYBDp4AxPiSK8
To: /tmp/tmplxllqcdc/data/label-map.json
100%|██████████| 374k/374k [00:00<00:00, 77.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1zxrS94Kh3mCQupTCtlF6rGpvnVZo1KvY
To: /tmp/tmplxllqcdc/data/label-weights.json
100%|██████████| 280k/280k [00:00<00:00, 46.7MB/s]


In [7]:
# model paths
# using the prot-bert model from https://huggingface.co/Rostlab/prot_bert
# both names can be overridden through environment variables
LANGUAGE_MODEL = os.environ.get('LANGUAGE_MODEL', 'Rostlab/prot_bert')
TOKENIZER = os.environ.get('TOKENIZER', 'Rostlab/prot_bert')

# device -- use the GPU when one is available
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data variables
# environment values are strings, so cast once here to keep MAX_LENGTH an int
# regardless of whether it came from the environment or from the default
MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 256))

# training params
EPOCHS = 10
BATCH_SIZE = 5
LEARNING_RATE = 1e-7

In [8]:
class ProteinFamilyDataset(Dataset):
    """
    Map-style dataset serving (token ids, label index) pairs for protein sequences.

    Each row of ``data_source`` supplies the amino-acid string (``sequence``
    column) and the label (``family_accession`` column). The label is translated
    to an integer through ``label_map``. Both returned tensors are created
    directly on ``DEVICE``.
    """
    _source_column: str = 'sequence'
    _target_column: str = 'family_accession'

    def __init__(self, data_source: pd.DataFrame, tokenizer: PreTrainedTokenizer, label_map: dict, **tokenizer_args):
        """
        :param data_source: frame holding the sequence and label columns
        :param tokenizer: callable tokenizer whose result exposes ``input_ids``
        :param label_map: mapping from label string to class index
        :param tokenizer_args: keyword arguments forwarded to every tokenizer call;
            when none are given, sequences are padded/truncated to ``MAX_LENGTH``
        """
        self.data_source = data_source
        self.tokenizer = tokenizer
        self.label_map = label_map

        # fall back to fixed-length encoding when the caller passes no options
        self.tokenizer_args = tokenizer_args or {
            'padding': 'max_length',
            'truncation': True,
            'max_length': int(MAX_LENGTH),
        }

    def __len__(self):
        """Number of records in the underlying data source."""
        return len(self.data_source)

    def _space_separate_aa_tokens(self, sequence: str) -> str:
        """Put a single space between consecutive characters (required by the prot-bert tokenizer)."""
        return ' '.join(sequence)

    def _tokenize(self, sequence: str) -> torch.Tensor:
        """Run the tokenizer with the stored options and return the token ids as a long tensor."""
        encoded = self.tokenizer(sequence, **self.tokenizer_args)
        return torch.tensor(encoded.input_ids, dtype=torch.long, device=DEVICE)

    def _cast_sequence_to_tensor(self, sequence: str) -> torch.Tensor:
        """Space-separate the raw sequence, then tokenize it."""
        return self._tokenize(self._space_separate_aa_tokens(sequence))

    def _cast_label_to_tensor(self, label: str) -> torch.Tensor:
        """
        Look the label up in ``label_map`` and wrap its index in a long tensor.

        :raises RuntimeError: if the label has no entry in the map
        """
        idx = self.label_map.get(label)
        if idx is not None:
            return torch.tensor(idx, dtype=torch.long, device=DEVICE)
        raise RuntimeError(f'unknown label: {label}')

    def __getitem__(self, item: int):
        """Return the ``(source, target)`` tensor pair for the record at position ``item``."""
        row = self.data_source.iloc[item]
        raw_sequence, raw_label = row[self._source_column], row[self._target_column]

        source = self._cast_sequence_to_tensor(raw_sequence)
        target = self._cast_label_to_tensor(raw_label)
        return source, target

In [9]:
# read the label map first, then the label weights, from their JSON files
label_map, label_weights = (
    srsly.read_json(json_path) for json_path in (label_map_path, label_weights_path)
)

In [10]:
# tokenizer selected by the TOKENIZER setting
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# pre-trained language model set up for sequence classification,
# with one output per entry in the label map
n_labels = len(label_map)
model = AutoModelForSequenceClassification.from_pretrained(LANGUAGE_MODEL, num_labels=n_labels)

# move the model to the selected device (last expression, so the cell displays it)
model.to(DEVICE)

Downloading:   0%|          | 0.00/856M [00:00<?, ?B/s]

Some weights of the model checkpoint at Rostlab/prot_albert were not used when initializing AlbertForSequenceClassification: ['predictions.decoder.bias', 'predictions.bias', 'predictions.decoder.weight', 'sop_classifier.classifier.bias', 'sop_classifier.classifier.weight', 'predictions.dense.bias', 'predictions.dense.weight', 'predictions.LayerNorm.weight', 'predictions.LayerNorm.bias']
- This IS expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at Rostlab/prot_al

AlbertForSequenceClassification(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(34, 128, padding_idx=0)
      (position_embeddings): Embedding(40000, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=4096, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((4096,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=4096, out_features=4096, bias=True)
                (key): Linear(in_features=4096, out_features=4096, bias=True)
                (value): Linear(in_features=4096, out_feature

In [11]:
# define loss and optimisation regime
# we weight the loss for different labels proportional to their abundance
lw = label_weights.get('weights', [])

# the model is built with len(label_map) outputs, so the loss needs exactly one
# weight per label; check now instead of failing inside the first training step
if len(lw) != len(label_map):
    raise ValueError(f'expected {len(label_map)} label weights, got {len(lw)}')

lw_tensor = torch.tensor(lw, dtype=torch.float32, device=DEVICE)

loss_fn = nn.CrossEntropyLoss(weight=lw_tensor)
optimizer = torch.optim.Adam(model.parameters(),
                             lr=LEARNING_RATE,
                             weight_decay=0.01)

In [12]:
# read the training records into a frame
lines = srsly.read_jsonl(train_data_path)
data_frame = pd.DataFrame(data=lines)

# wrap the frame in the dataset and serve it in shuffled mini-batches
train_data = ProteinFamilyDataset(data_source=data_frame, tokenizer=tokenizer, label_map=label_map)
loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

In [13]:
# train loop
# collects the loss of every batch, grouped per epoch, in `training_error`
size = len(loader)
training_error = []

# make sure layers such as dropout are in training mode
model.train()

for eps in range(EPOCHS):

    epoch_error = []
    print(f'epoch {eps + 1}/{EPOCHS}')
    for (X, y) in tqdm(loader, desc='training'):
        # nn.CrossEntropyLoss applies log-softmax itself, so it must be fed the
        # raw logits; applying softmax first would distort the loss and gradients
        # NOTE(review): the dataset yields only input ids, so padded positions are
        # not masked here -- consider passing an attention mask as well
        logits = model(X).logits
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_error.append(loss.item())
    # only complete epochs are recorded
    training_error.append(epoch_error)

epoch 0/10


training:   0%|          | 1/108675 [00:09<277:11:00,  9.18s/it]


RuntimeError: ignored

In [None]:
# mean training loss of each recorded epoch; computed per epoch so that an
# empty `training_error` (e.g. training stopped early) does not raise
error = np.array([np.mean(epoch_losses) for epoch_losses in training_error])

# labelled figure so the plot can be read on its own
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(error) + 1), error, marker='o')
ax.set(title='training loss per epoch', xlabel='epoch', ylabel='mean loss');

In [None]:
# load test data
lines = srsly.read_jsonl(test_data_path)
data_frame = pd.DataFrame(lines)

test_data = ProteinFamilyDataset(data_frame, tokenizer, label_map)
# evaluation does not benefit from shuffling; a fixed order keeps runs repeatable
loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# test-loop
# collects the class probabilities and the true labels of every test batch
predictions, targets = [], []
size = len(loader)

# switch layers such as dropout to inference behaviour
model.eval()

with torch.no_grad():
    for (X, y) in tqdm(loader, desc='testing'):
        y_hat = torch.softmax(model(X).logits, dim=-1)

        # move results to host memory: .numpy() below is not available for
        # tensors that live on a CUDA device
        predictions.append(y_hat.cpu())
        targets.append(y.cpu())

# NOTE(review): this keeps one probability per class for every test record,
# which may be large -- confirm it fits in memory for the full test set
predictions = torch.cat(predictions).numpy()
targets = torch.cat(targets).numpy()

In [None]:
# predicted class of each record = position of its highest score
pred_cls = np.argmax(predictions, axis=1)

# per-class and aggregate metrics as a plain dictionary
metrics = classification_report(y_true=targets, y_pred=pred_cls, output_dict=True)

In [None]:
# write the metrics to a timestamped JSON file
# explicit %H%M%S instead of the locale-dependent %X, whose output usually
# contains ':' characters that are not valid in Windows file names
timestamp = datetime.now().strftime('%Y-%m-%d-%H%M%S')
srsly.write_json(metric_directory / f'{timestamp}.json', metrics)

In [None]:
# target folders inside the model directory
lm_target = model_directory / 'lm'
tokenizer_target = model_directory / 'tokenizer'

# persist the model and the tokenizer next to each other
model.save_pretrained(lm_target)
tokenizer.save_pretrained(tokenizer_target)

In [None]:
# cleanup -- deletes temporary directory (including the saved model and metrics)
# guarded so that re-running the cell after the directory is gone is a no-op
if dirname.exists():
    shutil.rmtree(dirname)