# Evaluation of the ESM-2 model (esm2_t33_650M_UR50D) fine-tuned on the ImmuneCODE dataset

## Global configuration

In [1]:
import pandas as pd
import numpy as np
import warnings
import logging
# `import logging` alone does not guarantee that the `logging.config`
# submodule is loaded; import it explicitly so fileConfig() below resolves.
import logging.config

# Display: use the full pandas option names rather than abbreviated patterns.
pd.set_option('display.max_rows', 2000)
pd.set_option('display.max_columns', 2000)

# Logger: silence all warnings, then configure logging from the project's
# config file and raise the 'gentcr' logger to INFO.
warnings.filterwarnings('ignore')
logging.config.fileConfig('../config/logging.conf')
logger = logging.getLogger('gentcr')
logger.setLevel(logging.INFO)


## Load the base pretrained model and tokenizer

In [2]:
import tempfile

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from peft import *

# Alternative kept for reference (stock Hub checkpoint, no adapter):
#   model_name = 'facebook/esm2_t33_650M_UR50D'
#   base_model = AutoModelForMaskedLM.from_pretrained(model_name, device_map='auto').eval()

# Local checkpoint directory; the same path is handed to both the model loader
# and the tokenizer loader below.
model_name = '../output/peft_esm2_t33_650M_UR50D'

# Load first, then switch to evaluation mode as a separate step.
base_model = AutoPeftModel.from_pretrained(model_name, device_map='auto')
base_model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)

(urllib3.connectionpool) 2023-11-25 17:19:42 [DEBUG]: Starting new HTTPS connection (1): huggingface.co:443
(urllib3.connectionpool) 2023-11-25 17:19:42 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-11-25 17:19:43 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0


## Training and evaluation datasets

### Training datasets
- SARS-CoV-2 T-cell epitope-specific TCR CDR3$\beta$ sequence data from the ImmuneCODE project launched by Adaptive Biotechnologies and Microsoft {Nolan, 2020}, Release 002.1
- Over 160,000 high-confidence SARS-CoV-2-specific TCRs covering 511 epitopes from 1,414 subjects exposed to or infected with the SARS-CoV-2 virus
- 554,707 positive datapoints covering 545 epitopes

In [3]:
from gentcr.data import EpitopeTargetDataset, CN

# Load the training dataset registered under the 'immunecode' key.
train_ds = EpitopeTargetDataset.from_key('immunecode')
# The label now names the variable that is actually printed (it previously
# read 'eval_ds', which mislabeled the training-set shape in the output).
print(f'train_ds.df.shape: {train_ds.df.shape}')
display(train_ds.summary())

(gentcr) 2023-11-25 17:19:47 [INFO]: Loaded immunecode data from ../output/immunecode.data.csv, df.shape: (554707, 12)
eval_ds.df.shape: (554707, 12)


Unnamed: 0,source,epitope_species,epitope_gene,epitope_seq,cdr3_beta,positive,negative
0,ImmuneCODE_002.1,SARS-CoV-2,"ORF1ab,surface glycoprotein",ADAGFIKQY,109,109,0
1,ImmuneCODE_002.1,SARS-CoV-2,ORF1ab,AEAELAKNVSL,1861,1861,0
2,ImmuneCODE_002.1,SARS-CoV-2,nucleocapsid phosphoprotein,AEGSRGGSQA,5,5,0
3,ImmuneCODE_002.1,SARS-CoV-2,"ORF1ab,ORF3a",AEIPKEEVKPF,163,163,0
4,ImmuneCODE_002.1,SARS-CoV-2,surface glycoprotein,AEIRASANL,211,211,0
5,ImmuneCODE_002.1,SARS-CoV-2,surface glycoprotein,AEIRASANLA,211,211,0
6,ImmuneCODE_002.1,SARS-CoV-2,ORF1ab,AELAKNVSLDNVL,1861,1861,0
7,ImmuneCODE_002.1,SARS-CoV-2,"ORF1ab,surface glycoprotein",AELEGIQY,109,109,0
8,ImmuneCODE_002.1,SARS-CoV-2,surface glycoprotein,AENSVAYSN,41,41,0
9,ImmuneCODE_002.1,SARS-CoV-2,surface glycoprotein,AENSVAYSNN,41,41,0


### Evaluation datasets
- From three studies: {Shomuradova, 2020}, {Minervina, 2022}, and {Gfeller, 2023}
- Entries that also appear in the training dataset are excluded
- 975 positive datapoints covering 16 epitopes

In [4]:
from gentcr.data import EpitopeTargetDataset, CN

# Load the evaluation dataset registered under the 'shomuradova_minervina_gfeller' key.
eval_ds = EpitopeTargetDataset.from_key('shomuradova_minervina_gfeller')
# Remove, in place, evaluation rows that match the training set on the listed
# columns ('index', then 'cdr3b_seq').
# NOTE(review): because this mutates eval_ds, re-running only this line without
# reloading operates on the already-filtered data.
eval_ds.exclude_by(train_ds, target_cols=['index', 'cdr3b_seq'], inplace=True)
print(f'eval_ds.df.shape: {eval_ds.df.shape}')
display(eval_ds.summary())

(gentcr) 2023-11-25 17:19:50 [INFO]: Loaded shomuradova_minervina_gfeller data from ../output/shomuradova_minervina_gfeller.data.csv, df.shape: (1160, 12)
(gentcr) 2023-11-25 17:19:50 [INFO]: Excluding immunecode data by column: index from shomuradova_minervina_gfeller
(gentcr) 2023-11-25 17:19:54 [INFO]: Current shomuradova_minervina_gfeller data.shape: (1060, 12)
(gentcr) 2023-11-25 17:19:54 [INFO]: Excluding immunecode data by column: cdr3b_seq from shomuradova_minervina_gfeller
(gentcr) 2023-11-25 17:19:58 [INFO]: Current shomuradova_minervina_gfeller data.shape: (975, 12)
eval_ds.df.shape: (975, 12)


Unnamed: 0,source,epitope_species,epitope_gene,epitope_seq,cdr3_beta,positive,negative
0,IEDB,SARS-CoV2,ORF3a protein,ALSKGVHFV,13,13,0
1,Gfeller,,,EYADVFHLYL,8,8,0
2,IEDB,SARS-CoV2,ORF3a protein,FTSDYYQLY,35,35,0
3,IEDB,SARS-CoV2,Spike glycoprotein,LTDEMIAQY,126,126,0
4,Gfeller,,,LYLYALVYF,9,9,0
5,IEDB,SARS-CoV2,Spike glycoprotein,NQKLIANQF,73,73,0
6,IEDB,SARS-CoV2,Spike glycoprotein,NYNYLYRLF,29,29,0
7,IEDB,SARS-CoV2,orf1ab polyprotein [Severe acute respiratory s...,PTDNYITTY,21,21,0
8,IEDB,SARS-CoV2,Spike glycoprotein,QYIKWPWYI,20,20,0
9,Gfeller,,,QYIKWPWYIW,15,15,0


## Experiment 1 with high mutation ratio in TCR CDR3$\beta$ sequences
- Epitope: mut_ratio=0.15, mut_probs=[0.7, 0.3]
- TCR CDR3$\beta$: mut_ratio=0.4, mut_probs=[0.8, 0.2]

### Evaluate the base model

In [5]:
from torch.utils.data import DataLoader
from gentcr.bioseq import UniformAASeqMutator, CalisImmunogenicAASeqMutator, needle_aaseq_pair
from gentcr.common import StrUtils
from gentcr.data import EpitopeTargetDataset, EpitopeTargetMaskedLMCollator, CN
import numpy as np

def eval_model(target_model=None,
               ds=None,
               n_samples=32,
               tokenizer=None,
               epitope_seq_mutator=None,
               target_seq_mutator=None,
               seq_format='{epitope_seq}{target_seq}'):
    """Return the mean alignment similarity between original and predicted sequences.

    The first ``n_samples`` epitope/CDR3-beta pairs of ``ds.df`` are formatted
    with ``seq_format`` to build the reference sequences. One batch of
    ``n_samples`` items is drawn from ``ds`` through an
    ``EpitopeTargetMaskedLMCollator``, fed to ``target_model``, and the argmax
    token at every position is decoded back to a string. Each decoded string is
    truncated to the length of its reference and scored with
    ``needle_aaseq_pair``; the mean of the ``'similarity'`` values is returned.

    Args:
        target_model: Model called as ``target_model(**batch)``; its output must
            expose ``logits``.
        ds: Dataset exposing ``df``, ``max_epitope_len`` and ``max_target_len``.
        n_samples: Number of leading samples to evaluate (also the batch size).
        tokenizer: Tokenizer used by the collator and for decoding.
        epitope_seq_mutator: Forwarded to the collator; ``None`` is accepted.
        target_seq_mutator: Forwarded to the collator.
        seq_format: Format string with ``{epitope_seq}`` and ``{target_seq}``.

    Returns:
        Mean similarity score over the evaluated samples (``np.mean``).
    """
    epitope_seqs = ds.df[CN.epitope_seq].values[:n_samples]
    target_seqs = ds.df[CN.cdr3b_seq].values[:n_samples]
    input_seqs = [
        seq_format.format(epitope_seq=e_seq, target_seq=t_seq)
        for e_seq, t_seq in zip(epitope_seqs, target_seqs)
    ]

    collator = EpitopeTargetMaskedLMCollator(tokenizer=tokenizer,
                                             epitope_seq_mutator=epitope_seq_mutator,
                                             target_seq_mutator=target_seq_mutator,
                                             max_epitope_len=ds.max_epitope_len,
                                             max_target_len=ds.max_target_len,
                                             seq_format=seq_format)
    # shuffle=False so the batch lines up with the leading rows taken above.
    # NOTE(review): assumes ds[i] corresponds to row i of ds.df — confirm.
    data_loader = DataLoader(ds, batch_size=n_samples, shuffle=False, collate_fn=collator)

    batch = next(iter(data_loader))
    # Evaluation only: disable autograd so no graph is retained for the
    # (potentially full-dataset) batch.
    with torch.no_grad():
        output = target_model(**batch)
    token_ids = torch.argmax(output.logits, dim=-1)
    output_seqs = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    output_seqs = [StrUtils.rm_nonwords(seq) for seq in output_seqs]

    scores = []
    for input_seq, output_seq in zip(input_seqs, output_seqs):
        # Compare only the span covered by the reference sequence.
        output_seq = output_seq[:len(input_seq)]
        result = needle_aaseq_pair(input_seq, output_seq, output_identity=True, output_similarity=True)
        scores.append(result['similarity'])
    return np.mean(scores)
   
# Epitope mutator: constructed here, but the eval_model call below passes
# epitope_seq_mutator=None, so this object is not used by that call.
epitope_seq_mutator = CalisImmunogenicAASeqMutator(mut_ratio=0.15, mut_probs=[0.7, 0.3])
# Mutator for the TCR CDR3-beta (target) sequences used during evaluation.
# NOTE(review): the section header lists mut_ratio=0.4 for TCR CDR3-beta while
# 0.5 is used here — confirm whether the header describes fine-tuning settings.
target_seq_mutator = UniformAASeqMutator(mut_ratio=0.5, mut_probs=[0.8, 0.2])

# Score the base model on the whole evaluation set (n_samples=len(eval_ds)).
score = eval_model(target_model=base_model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer, 
                   epitope_seq_mutator=None, 
                   target_seq_mutator=target_seq_mutator)
print(f'>>> Mean similarity score of base model: {score}')

>>> Mean similarity score of base model: 0.7044615384615385


In [9]:
def load_peft_model(adapter_path='../output/exp1/mlm_finetune'):
    """Load the PEFT model saved at ``adapter_path`` and return it in eval mode."""
    peft_model = AutoPeftModel.from_pretrained(adapter_path, device_map='auto')
    # Adapter merging is intentionally left disabled:
    #   peft_model = peft_model.merge_and_unload()
    return peft_model.eval()

# Load the model fine-tuned in Experiment 1.
model = load_peft_model('../output/exp1/mlm_finetune')

# Score it with the same eval_ds, tokenizer and target_seq_mutator objects that
# were used for the base model.
# NOTE(review): no random seed is set in the visible cells; if the mutator is
# stochastic, the two scores come from different perturbations — confirm.
score = eval_model(target_model=model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer, 
                   epitope_seq_mutator=None, 
                   target_seq_mutator=target_seq_mutator)
print(f'>>> Mean similarity score of fine-tuned model: {score}')


(urllib3.connectionpool) 2023-11-25 17:22:58 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-11-25 17:22:59 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
>>> Mean similarity score of fine-tuned model: 0.7216933333333334


- Masking 50% of the amino acids (AAs) of the TCR CDR3$\beta$ sequences
- The mean similarity scores of the base model and the fine-tuned model are 0.704 and 0.722, respectively.
- The fine-tuned model is slightly better than the base model.
- The mutation ratio applied to the TCR CDR3$\beta$ sequences may be too high.
- We will investigate different mutation properties.

## Experiment 2 with different mutation properties
- Epitope: mut_ratio=0.15, mut_probs=[0.7, 0.3]
- TCR CDR3$\beta$: mut_ratio=0.2, mut_probs=[0.8, 0.2]

In [10]:
# Load the model fine-tuned in Experiment 2.
model = load_peft_model('../output/exp2/mlm_finetune')

# Set the evaluation-time mutation settings on the shared mutator object
# (mut_ratio=0.5, mut_probs=[0.8, 0.2]).
# NOTE(review): assumes the mutator reads these attributes at call time — confirm.
target_seq_mutator.mut_ratio = 0.5
target_seq_mutator.mut_probs = [0.8, 0.2]
# Base-model baseline, left disabled in this cell:
# score = eval_model(target_model=base_model, ds=eval_ds, n_samples=len(eval_ds),
#                    tokenizer=tokenizer, 
#                    epitope_seq_mutator=None, 
#                    target_seq_mutator=target_seq_mutator)
# print(f'>>>> Mean similarity score of base model: {score}')
# Score the Experiment 2 model on the whole evaluation set.
score = eval_model(target_model=model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer, 
                   epitope_seq_mutator=None, 
                   target_seq_mutator=target_seq_mutator)
print(f'>>>> Mean similarity score of fine-tuned model: {score}')

(urllib3.connectionpool) 2023-11-25 17:25:24 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-11-25 17:25:24 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
>>>> Mean similarity score of fine-tuned model: 0.8590153846153846


- Masking 50% of the amino acids (AAs) of the TCR CDR3$\beta$ sequences
- The mean similarity score of the fine-tuned model is 0.859.
- This score is significantly higher than that of the fine-tuned model in Experiment 1.
- Next, we will make the mutation properties of the TCR CDR3$\beta$ sequence equal to those of the epitope sequence.

## Experiment 3 with the same mutation properties on both epitope and TCR CDR3$\beta$
- mut_ratio=0.15, mut_probs=[0.7, 0.3]

In [11]:
# Load the model fine-tuned in Experiment 3.
model = load_peft_model('../output/exp3/mlm_finetune')

# Set the evaluation-time mutation settings on the shared mutator object
# (mut_ratio=0.5, mut_probs=[0.8, 0.2]).
# NOTE(review): assumes the mutator reads these attributes at call time — confirm.
target_seq_mutator.mut_ratio = 0.5
target_seq_mutator.mut_probs = [0.8, 0.2]
# Base-model baseline, left disabled in this cell:
# score = eval_model(target_model=base_model, ds=eval_ds, n_samples=len(eval_ds),
#                    tokenizer=tokenizer, 
#                    epitope_seq_mutator=None, 
#                    target_seq_mutator=target_seq_mutator)
# print(f'>>>> Mean similarity score of base model: {score}')
# Score the Experiment 3 model on the whole evaluation set.
score = eval_model(target_model=model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer, 
                   epitope_seq_mutator=None, 
                   target_seq_mutator=target_seq_mutator)
print(f'>>>> Mean similarity score of fine-tuned model: {score}')

(urllib3.connectionpool) 2023-11-25 17:29:11 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-11-25 17:29:11 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
>>>> Mean similarity score of fine-tuned model: 0.8649661538461539


- Masking 50% of the amino acids (AAs) of the TCR CDR3$\beta$ sequences
- The mean similarity score of the fine-tuned model is 0.865.
- This score is higher than that of the fine-tuned model in Experiment 2.
- The fine-tuned model shows reliable performance on unseen evaluation data.