In [1]:
import torch
import torch.nn as nn

import random
import numpy as np
import csv
import matplotlib.pyplot as plt

In [2]:
# !pip install transformers evaluate datasets

In [3]:
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, DataCollatorForLanguageModeling, Trainer,  TrainingArguments
from transformers import BertModel, BertConfig  

import datasets
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
dataset = datasets.load_dataset('csv', data_files='../data/data.csv', split="all")

In [5]:
dataset

Dataset({
    features: ['antigen', 'TCR', 'interaction'],
    num_rows: 130471
})

In [6]:
dataset = dataset.remove_columns('interaction')

In [7]:
dataset

Dataset({
    features: ['antigen', 'TCR'],
    num_rows: 130471
})

In [8]:
dataset = dataset.train_test_split(test_size=0.2)
dataset

DatasetDict({
    train: Dataset({
        features: ['antigen', 'TCR'],
        num_rows: 104376
    })
    test: Dataset({
        features: ['antigen', 'TCR'],
        num_rows: 26095
    })
})

In [9]:
BERT_CONFIG = BertConfig(
    vocab_size=25,  
    max_position_embeddings=64,
    type_vocab_size=2,
    num_attention_heads=8,
    num_hidden_layers=8,
    hidden_size=512,
    intermediate_size=2048,
    num_labels=2
)

In [10]:
from transformers import BertConfig

In [11]:
config  = BertConfig

In [12]:
model_name = 'Rostlab/prot_bert_bfd_localization' # "bert-base-uncased"

In [13]:
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
tokenizer.model_max_length = 64


tokenizer_config.json: 100%|██████████| 210/210 [00:00<00:00, 120kB/s]
config.json: 100%|██████████| 1.06k/1.06k [00:00<00:00, 1.20MB/s]
vocab.txt: 100%|██████████| 81.0/81.0 [00:00<00:00, 128kB/s]
special_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 318kB/s]


In [14]:
config = AutoConfig.from_pretrained(model_name)

In [15]:
model = AutoModelForMaskedLM.from_config(BERT_CONFIG)

In [16]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25, 512, padding_idx=0)
      (position_embeddings): Embedding(64, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-7): 8 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=

In [17]:
column_names = list(dataset["train"].features)

In [18]:
column_names

['antigen', 'TCR']

[CLS]antigen[SEP]TCR[EOS]

In [19]:
def tokenize_function(examples):
    return tokenizer(examples[column_names[0]],examples[column_names[1]], return_special_tokens_mask=False, 
                     padding='longest', truncation='longest_first', return_tensors="pt")

In [20]:
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=column_names,
    desc="Running tokenizer on every text in dataset",
    )

Running tokenizer on every text in dataset: 100%|██████████| 104376/104376 [00:02<00:00, 38590.99 examples/s]
Running tokenizer on every text in dataset: 100%|██████████| 26095/26095 [00:00<00:00, 38633.10 examples/s]


In [21]:
tokenized_datasets['train']

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 104376
})

In [22]:
max_seq_length = tokenizer.model_max_length # original BERT max length

In [23]:
max_seq_length

64

In [24]:
max_length = 64 # max_position_embeddings
tokenizer.model_max_length = max_length

In [25]:
from itertools import chain

In [26]:
pad_code = tokenizer.pad_token_id

In [27]:
pad_code

0

In [28]:
def group_texts(examples):

    # Concatenate all texts
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])

    # Calculate the number of chunks needed
    num_chunks = (total_length + max_seq_length - 1) // max_seq_length  # Round up division
    #num_chunks = (total_length + max_seq_length) // max_seq_length

    # Split by chunks of max_seq_length
    result = {
        k: [t[i * max_seq_length: (i + 1) * max_seq_length] for i in range(num_chunks)]
        for k, t in concatenated_examples.items()
    }

    # Pad the last chunk for each key if necessary
    k = 'input_ids'
    last_chunk_length = len(result[k][-1])
    if last_chunk_length < max_seq_length:
        result[k][-1] = result[k][-1] + [pad_code] * (max_seq_length - last_chunk_length)  

    k = 'attention_mask'
    last_chunk_length = len(result[k][-1])
    if last_chunk_length < max_seq_length:
        result[k][-1] = result[k][-1] + [0] * (max_seq_length - last_chunk_length)  

    k = 'token_type_ids'
    last_chunk_length = len(result[k][-1])
    if last_chunk_length < max_seq_length:
        result[k][-1] = result[k][-1] + [0] * (max_seq_length - last_chunk_length)

    # Create a new labels column
    result["labels"] = result["input_ids"].copy()

    return result


In [29]:
tokenized_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    desc=f"Grouping texts in chunks of {max_length}",
)

tokenized_datasets["train"][0]['input_ids']

In [30]:
tokenizer.decode(tokenized_datasets["train"][0]["input_ids"])

'[CLS] [UNK] [SEP] [UNK] [SEP]'

In [31]:
train_dataset = tokenized_datasets["train"]

In [32]:
eval_dataset = tokenized_datasets["test"]

In [33]:
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, logits may contain extra tensors,
        # like past_key_values, but logits always come first
        logits = logits[0]
    return logits.argmax(dim=-1)       

In [34]:
metric = evaluate.load("accuracy")

In [35]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # preds have the same shape as the labels, after the argmax(-1) has been calculated
    # by preprocess_logits_for_metrics
    labels = labels.reshape(-1)
    preds = preds.reshape(-1)
    mask = labels != -100
    labels = labels[mask]
    preds = preds[mask]
    return metric.compute(predictions=preds, references=labels)

In [36]:
mlm_probability = 0.15 # Percentage of data to mask 

In [37]:
data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm_probability=mlm_probability)

In [38]:
training_args = TrainingArguments(output_dir='./results', evaluation_strategy="epoch", 
                                  learning_rate=2e-5, per_device_train_batch_size=16, 
                                  per_device_eval_batch_size=16, num_train_epochs=1, 
                                  weight_decay=0.01)

In [39]:
# Initialize our Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
    # preprocess_logits_for_metrics=preprocess_logits_for_metrics
)

In [40]:
train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too for easy upload
metrics = train_result.metrics

  0%|          | 0/6524 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  8%|▊         | 500/6524 [01:22<15:25,  6.51it/s] 

{'loss': 0.0, 'learning_rate': 1.846719803801349e-05, 'epoch': 0.08}


 15%|█▌        | 1000/6524 [02:38<13:53,  6.63it/s]

{'loss': 0.0, 'learning_rate': 1.6934396076026978e-05, 'epoch': 0.15}


 23%|██▎       | 1500/6524 [03:54<12:27,  6.73it/s]

{'loss': 0.0, 'learning_rate': 1.540159411404047e-05, 'epoch': 0.23}


 31%|███       | 2000/6524 [05:09<11:09,  6.76it/s]

{'loss': 0.0, 'learning_rate': 1.3868792152053956e-05, 'epoch': 0.31}


 36%|███▌      | 2323/6524 [05:58<10:52,  6.44it/s]

KeyboardInterrupt: 

In [None]:
metrics["train_samples"] = len(train_dataset)

In [None]:
trainer.save_metrics("train", metrics)
trainer.save_state()

In [None]:
metrics = trainer.evaluate()

In [None]:
metrics["eval_samples"] = len(eval_dataset)

In [None]:
import math

In [None]:
try:
    perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
    perplexity = float("inf")
metrics["perplexity"] = perplexity

trainer.save_metrics("eval", metrics)

In [None]:
metrics