In [1]:
from tqdm.notebook import tqdm
import pandas as pd
from datasets import Dataset
from src.dataset import SPFastaDatasetBinary, SPFastaDatasetBinaryWithTokenizedCategory
import torch
from torch import nn
from transformers import AutoTokenizer, AdamW
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
ds = SPFastaDatasetBinaryWithTokenizedCategory("data/train.fasta")
dataset_train = Dataset.from_pandas(ds.data).with_format("torch", device=device)
ds = SPFastaDatasetBinaryWithTokenizedCategory("data/test.fasta")
dataset_test = Dataset.from_pandas(ds.data).with_format("torch", device=device)
del ds

  0%|          | 0/20290 [00:00<?, ?it/s]

100%|██████████| 20290/20290 [00:00<00:00, 317035.33it/s]
100%|██████████| 20290/20290 [00:00<00:00, 1068079.37it/s]
100%|██████████| 8811/8811 [00:00<00:00, 326328.17it/s]
100%|██████████| 8811/8811 [00:00<00:00, 978967.22it/s]


In [3]:
from datasets import ClassLabel
dataset_train = dataset_train.cast_column('labels', ClassLabel(num_classes = 2, names=["NO_SP", "SP"]))
dataset_test = dataset_test.cast_column('labels', ClassLabel(num_classes = 2, names=["NO_SP", "SP"]))
dataset_train[0]['labels']

Casting the dataset:   0%|          | 0/20290 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/8811 [00:00<?, ? examples/s]

tensor(0, device='cuda:0')

In [4]:
# Tokenize:
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", device="cpu")
dataset_train = dataset_train.map(lambda x: tokenizer(x['text'], return_tensors="pt", padding='max_length', max_length=81, truncation=True), batched=True)
dataset_test = dataset_test.map(lambda x: tokenizer(x['text'], return_tensors="pt", padding='max_length', max_length=81, truncation=True), batched=True)
dataset_train, dataset_test

Map:   0%|          | 0/20290 [00:00<?, ? examples/s]

Map:   0%|          | 0/8811 [00:00<?, ? examples/s]

(Dataset({
     features: ['text', 'labels', 'uniprot_ac', 'kingdom', 'type', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 20290
 }),
 Dataset({
     features: ['text', 'labels', 'uniprot_ac', 'kingdom', 'type', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 8811
 }))

In [5]:
import evaluate
acc = evaluate.load("accuracy")
pre = evaluate.load("precision")
rec = evaluate.load("recall")

metrics = [acc, pre, rec]

def compute_metrics(eval_pred):
    logits, y = eval_pred
    x = logits.argmax(-1)
    return {k: v for metric in [m.compute(predictions=x, references=y) for m in metrics] for k, v in metric.items()}

# Define model

In [6]:
from transformers import AutoModelForSequenceClassification

class ProtBertSequenceClassification(nn.Module):
    def __init__(self, device: str=device) -> None:
        super(ProtBertSequenceClassification, self).__init__()
        self.device = device

        self.model = AutoModelForSequenceClassification.from_pretrained(
            "Rostlab/prot_bert",
            num_labels=2,
            label2id = {
                'NO_SP': 0,
                'SP': 1
            },
            id2label= {
                0: 'NO_SP',
                1: 'SP'
            }).to(device)

        self.loss_fn = nn.functional.binary_cross_entropy

    def forward(self, input_ids, token_type_ids, attention_mask, labels):
        outputs = self.model.forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )
        return outputs
    
    def loss(self, x, weights=None):
        logits = self.forward(x).logits[:,1]
        pred = logits.softmax(-1)
        label = x['labels']
        return self.loss_fn(pred, label.float(), weight=weights)

model = ProtBertSequenceClassification()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Rostlab/prot_bert and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# with torch.no_grad():
#     loss = model.loss(dataset_train[0:1])
# loss

In [8]:
n_pos = dataset_train[:]['labels'].count_nonzero()
n_neg = len(dataset_train) - n_pos
neg_weight = 1 - ((n_neg) / (n_neg + n_pos))
pos_weight = 1 - ((n_pos) / (n_neg + n_pos))
class_weights = torch.tensor([neg_weight, pos_weight]).to(device)
n_neg, n_pos, class_weights

(tensor(15625, device='cuda:0'),
 tensor(4665, device='cuda:0'),
 tensor([0.2299, 0.7701], device='cuda:0'))

In [9]:
sample = dataset_train[0:2]
with torch.no_grad():
    outputs = model.forward(sample['input_ids'], sample['token_type_ids'], sample['attention_mask'], sample['labels'])
    logits = outputs.get('logits')
    labels = sample.get('labels')

    loss_fn = nn.CrossEntropyLoss(weight=class_weights)
    loss = loss_fn(logits, labels)
logits, labels, loss

(tensor([[-0.0274,  0.0034],
         [-0.0272,  0.0034]], device='cuda:0'),
 tensor([0, 0], device='cuda:0'),
 tensor(0.7086, device='cuda:0'))

In [10]:
from transformers import Trainer

class WeightedLossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        
        outputs = model(**inputs)
        logits = outputs.get('logits')
        labels = inputs.get('labels')

        loss_fn = nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [11]:
from transformers import TrainingArguments

epochs = 3
batch_size = 8

args = TrainingArguments(
    output_dir='./model',
    logging_dir="./logs",
    num_train_epochs=epochs,
    learning_rate=2e-4,
    # per_device_train_batch_size=batch_size,
    # per_device_eval_batch_size=batch_size,
    logging_first_step=True,
    weight_decay=0.01,
    evaluation_strategy='steps',
    save_strategy='epoch',
    logging_steps=100,
    fp16=True,
    optim='adamw_torch',
    remove_unused_columns=True,
    auto_find_batch_size=True,
)

In [12]:
trainer = WeightedLossTrainer(
    model=model,
    args=args,
    compute_metrics=compute_metrics,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    tokenizer=tokenizer
)
trainer.train()

  0%|          | 0/7611 [00:00<?, ?it/s]

{'loss': 0.6847, 'learning_rate': 0.00019997372224412038, 'epoch': 0.0}
{'loss': 0.7288, 'learning_rate': 0.00019737222441203523, 'epoch': 0.04}


  0%|          | 0/1102 [00:00<?, ?it/s]

{'eval_loss': 0.7293215394020081, 'eval_accuracy': 0.1397117239813869, 'eval_precision': 0.1397117239813869, 'eval_recall': 1.0, 'eval_runtime': 3007.0311, 'eval_samples_per_second': 2.93, 'eval_steps_per_second': 0.366, 'epoch': 0.04}
{'loss': 0.6967, 'learning_rate': 0.00019474444882407044, 'epoch': 0.08}


  0%|          | 0/1102 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 0.5964159965515137, 'eval_accuracy': 0.8602882760186131, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_runtime': 3520.9244, 'eval_samples_per_second': 2.502, 'eval_steps_per_second': 0.313, 'epoch': 0.08}
{'loss': 0.6975, 'learning_rate': 0.00019211667323610566, 'epoch': 0.12}


  0%|          | 0/1102 [00:00<?, ?it/s]

{'eval_loss': 0.7439395785331726, 'eval_accuracy': 0.1397117239813869, 'eval_precision': 0.1397117239813869, 'eval_recall': 1.0, 'eval_runtime': 3188.8005, 'eval_samples_per_second': 2.763, 'eval_steps_per_second': 0.346, 'epoch': 0.12}
{'loss': 0.6897, 'learning_rate': 0.00018948889764814085, 'epoch': 0.16}


  0%|          | 0/1102 [00:00<?, ?it/s]

KeyboardInterrupt: 

```python
# TODO: check if this is correct
loss_weights = torch.tensor([neg_weight if x[0] == 1 else pos_weight for x in batch['label']])
loss_weights = torch.tensor([neg_weight, pos_weight]).to(device) # or this
model_v = "microsoft/MiniLM-L12-H384-uncased"

# Classical training loop:
import pickle
loss_log = []
epoch0 = 0
if os.path.exists("model/temp.state"):
    with open("model/temp.state", "rb") as f:
        state = pickle.load(f)
        loss_log = state['loss_log']
        epoch0 = state['epoch']

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
epochs = 10
batch_size = 4

model.train()
batches = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
for epoch in range(epoch0, epochs):
    for i, batch in enumerate(tqdm(batches)):
        class_weights = torch.tensor([neg_weight if x[0] == 1 else pos_weight for x in batch['labels']]).to(device)

        optimizer.zero_grad()
        loss = model.loss(batch, weights=class_weights)

        loss.backward()
        optimizer.step()
        loss_log.append(loss.cpu().detach().item())
        if (i+1) % (len(batches) // 20) == 0:
            print(np.mean(loss_log[-5000:]))
```