# This notebook demonstrates a training example

### 1. To run it please create a conda environment called training_env using main/requirments_training_env.txt and use it for trainig in this notebook of from command line

In [1]:
#! conda activate training env

### 2. Import libraries

In [2]:
import pandas as pd
import evaluate
import numpy as np
import torch

from transformers import AutoTokenizer
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

import comet_ml
import os

  from .autonotebook import tqdm as notebook_tqdm


### 3. Check GPU availability

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.get_device_name(0)

'Tesla V100S-PCIE-32GB'

In [20]:
torch.cuda.is_available()

True

### 4. Initialize comet-ml

comet_ml.init(project_name='esm2_t6_8M_transfer_learning', api_key='INSERT YOUR COMET ML API KEY HERE', experiment_name='experiment_1')

In [22]:
os.environ["COMET_LOG_ASSETS"] = "True"

### 5. Load and prepare the dataset

The training and validation data can be found in the main/data folder of the current repo

In [24]:
train = pd.read_csv('~/TCR-specificity-prediction-with-ESMv2/data/train.csv')
hard_test = pd.read_csv('~/TCR-specificity-prediction-with-ESMv2/data/hard_test.csv')
easy_test = pd.read_csv('~/TCR-specificity-prediction-with-ESMv2/data/easy_test.csv')

In [25]:
train_sequences = list(train['cdr3.alpha']+train['cdr3.beta']+train['antigen.epitope'])
hard_test_sequences = list(hard_test['cdr3.alpha']+hard_test['cdr3.beta']+hard_test['antigen.epitope'])
easy_test_sequences = list(easy_test['cdr3.alpha']+easy_test['cdr3.beta']+easy_test['antigen.epitope'])
train_labels = list(train['target'])
hard_test_labels = list(hard_test['target'])
easy_test_labels = list(easy_test['target'])

In [26]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [27]:
train_tokenized = tokenizer(train_sequences)
hard_test_tokenized = tokenizer(hard_test_sequences)
easy_test_tokenized = tokenizer(easy_test_sequences)

In [28]:
train_dataset = Dataset.from_dict(train_tokenized)
hard_test_dataset = Dataset.from_dict(hard_test_tokenized)
easy_test_dataset = Dataset.from_dict(easy_test_tokenized)

In [29]:
train_dataset = train_dataset.add_column("labels", train_labels)
hard_test_dataset = hard_test_dataset.add_column("labels", hard_test_labels)
easy_test_dataset = easy_test_dataset.add_column("labels", easy_test_labels)
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 49014
})

### 6. Load the esm2_t6_8M model with a classification layer

In [None]:
model_checkpoint = "facebook/esm2_t6_8M_UR50D"

In [5]:
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
print(model)

EsmForSequenceClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05,

Let's look if the model's original layers are frozen or not:

In [8]:
for name, param in model.named_parameters():
    print(f"{name}: {param.requires_grad}")

esm.embeddings.word_embeddings.weight: True
esm.embeddings.position_embeddings.weight: True
esm.encoder.layer.0.attention.self.query.weight: True
esm.encoder.layer.0.attention.self.query.bias: True
esm.encoder.layer.0.attention.self.key.weight: True
esm.encoder.layer.0.attention.self.key.bias: True
esm.encoder.layer.0.attention.self.value.weight: True
esm.encoder.layer.0.attention.self.value.bias: True
esm.encoder.layer.0.attention.output.dense.weight: True
esm.encoder.layer.0.attention.output.dense.bias: True
esm.encoder.layer.0.attention.LayerNorm.weight: True
esm.encoder.layer.0.attention.LayerNorm.bias: True
esm.encoder.layer.0.intermediate.dense.weight: True
esm.encoder.layer.0.intermediate.dense.bias: True
esm.encoder.layer.0.output.dense.weight: True
esm.encoder.layer.0.output.dense.bias: True
esm.encoder.layer.0.LayerNorm.weight: True
esm.encoder.layer.0.LayerNorm.bias: True
esm.encoder.layer.1.attention.self.query.weight: True
esm.encoder.layer.1.attention.self.query.bias: Tru

### 7. Specify the training arguments and metrics

In [38]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 2048

args = TrainingArguments(
    f"{model_name}-batch-size-test-finetuned-localization",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=300,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
    label_names=["labels"],
    report_to=["comet_ml"]
)



In [32]:
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return clf_metrics.compute(predictions=predictions, references=labels)

### 8. Train the model

In [39]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=easy_test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [40]:
trainer.train()

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : batch_size_test
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/crystalcolecrystal/esm2-t6-8m-transfer-learning/6a41be53a4584d1c9f47bd79730f219e
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INFO:[0m     Name : batch_size_test
[1;38;5;39mCOMET INFO:[0m   Parameters:
[1;38;5;39mCOMET INFO:[0m     args/_n_gpu                                  : 1
[1;38;5;39mCOMET INFO:[0m     args/_no_sync_in_gradient_accumulation       : True
[1;38;5;39mCOMET INFO:[0m     args/_setup_devices                          : cuda:0


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.655652,0.643727,0.59375,0.69063,0.520706
2,No log,0.604424,0.705238,0.673194,0.755303,0.607186
3,No log,0.583107,0.710719,0.676651,0.766975,0.605359
4,No log,0.564244,0.722899,0.706072,0.751719,0.665652
5,No log,0.548066,0.729903,0.714699,0.757328,0.676614
6,No log,0.541569,0.732643,0.700546,0.796124,0.625457
7,No log,0.520412,0.745737,0.73849,0.760155,0.718027
8,No log,0.503139,0.752132,0.755849,0.744681,0.767357
9,No log,0.504033,0.757308,0.748183,0.777413,0.721072
10,No log,0.50144,0.760962,0.755528,0.773104,0.738733


[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : batch_size_test
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/crystalcolecrystal/esm2-t6-8m-transfer-learning/214ed1343f2147e7ab4888844729581c
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     epoch [315]                   : (1.0, 300.0)
[1;38;5;39mCOMET INFO:[0m     eval_accuracy [300]           : (0.643727161997564, 0.7737515225334958)
[1;38;5;39mCOMET INFO:[0m     eval_f1 [300]                 : (0.5937499999999999, 0.7810858143607706)
[1;38;5;39mCOMET INFO:[0m     eval_loss [

TrainOutput(global_step=7200, training_loss=0.06385186142391629, metrics={'train_runtime': 6551.6953, 'train_samples_per_second': 2244.335, 'train_steps_per_second': 1.099, 'total_flos': 3.3920592132368136e+16, 'train_loss': 0.06385186142391629, 'epoch': 300.0})

In [50]:
trainer.save_model("esm2_t6_8M_UR50D_300epoch")

# This is an example of the ESM model trainig. For metrics calculation see tests.ipynb notebook in /main/notebooks of the current repo