# Evaluation of fine-tuned model of ESM-2(esm2-t33-650M-UR50D) on the IMMUNECODE dataset

## Global configuration

In [1]:
import pandas as pd
import numpy as np
import warnings
import logging

# Display
pd.set_option('display.max.rows', 2000)
pd.set_option('display.max.columns', 2000)

# Logger
warnings.filterwarnings('ignore')
logging.config.fileConfig('../config/logging.conf')
logger = logging.getLogger('gentcr')
logger.setLevel(logging.INFO)


## Load the base pretrained model and tokenizer

In [2]:
import tempfile

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from peft import *

# model_name = 'facebook/esm2_t33_650M_UR50D'
# base_model = AutoModelForMaskedLM.from_pretrained(model_name, device_map='auto').eval()
model_name = '../output/peft_esm2_t33_650M_UR50D'
base_model = AutoPeftModel.from_pretrained(model_name, device_map='auto').eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)

(urllib3.connectionpool) 2023-12-18 17:33:55 [DEBUG]: Starting new HTTPS connection (1): huggingface.co:443
(urllib3.connectionpool) 2023-12-18 17:33:56 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-12-18 17:33:56 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0


## Training and evaluation datasets

### Training datasets
- SARS-CoV-2 T-cell epitope-specific TCR CDR3$\\beta$ sequence data from ImmuneCODE project launched by Adaptive Biotech. and Microsoft{Nolan, 2020} with Release 002.1
- Over 160,000 high-confidence SARS-CoV-2-specific TCRs covering 511 epitopes from 1,414 subjects exposed to or infected with the SARSCoV-2 virus
- 554,707 positive datapoints covering 545 epitopes

In [3]:
from gentcr.data import EpitopeTargetDataset, CN

train_ds = EpitopeTargetDataset.from_key('immunecode')
print(f'eval_ds.df.shape: {train_ds.df.shape}')
display(train_ds.df.head())

(gentcr) 2023-12-18 17:34:12 [INFO]: Loaded immunecode data from ../output/immunecode.data.csv, df.shape: (554707, 16)
eval_ds.df.shape: (554707, 16)


Unnamed: 0,epitope_species,epitope_gene,epitope_seq,orig_epitope_seq,epitope_start,epitope_end,epitope_len,mhc_allele,mhc_seq,orig_mhc_seq,cdr3b_seq,orig_cdr3b_seq,cdr3b_len,ref_id,source,bind_level
ADAGFIKQY_CASSAQGTGDRGYTF,SARS-CoV-2,"ORF1ab,surface glycoprotein",ADAGFIKQY,ADAGFIKQY,,,9,"HLA-A*11:01,HLA-A*68:01,HLA-B*35:01,HLA-B*35:0...",,,CASSAQGTGDRGYTF,CASSAQGTGDRGYTF,15,PMC:7418738,ImmuneCODE_002.1,2
AELEGIQY_CASSAQGTGDRGYTF,SARS-CoV-2,"ORF1ab,surface glycoprotein",AELEGIQY,AELEGIQY,,,8,"HLA-A*11:01,HLA-A*68:01,HLA-B*35:01,HLA-B*35:0...",,,CASSAQGTGDRGYTF,CASSAQGTGDRGYTF,15,PMC:7418738,ImmuneCODE_002.1,2
LADAGFIKQY_CASSAQGTGDRGYTF,SARS-CoV-2,"ORF1ab,surface glycoprotein",LADAGFIKQY,LADAGFIKQY,,,10,"HLA-A*11:01,HLA-A*68:01,HLA-B*35:01,HLA-B*35:0...",,,CASSAQGTGDRGYTF,CASSAQGTGDRGYTF,15,PMC:7418738,ImmuneCODE_002.1,2
TLADAGFIK_CASSAQGTGDRGYTF,SARS-CoV-2,"ORF1ab,surface glycoprotein",TLADAGFIK,TLADAGFIK,,,9,"HLA-A*11:01,HLA-A*68:01,HLA-B*35:01,HLA-B*35:0...",,,CASSAQGTGDRGYTF,CASSAQGTGDRGYTF,15,PMC:7418738,ImmuneCODE_002.1,2
ADAGFIKQY_CASSLVATGNTGELFF,SARS-CoV-2,"ORF1ab,surface glycoprotein",ADAGFIKQY,ADAGFIKQY,,,9,"HLA-A*02:01,HLA-A*33:03,HLA-B*53:01,HLA-B*58:0...",,,CASSLVATGNTGELFF,CASSLVATGNTGELFF,16,PMC:7418738,ImmuneCODE_002.1,2


### Evaluation datasets
- From three studies: {Shomuradova, 2020}, {Minervina, 2022}, and {Gfeller, 2023}
- Excluded duplicated entries in the training dataset
- 975 positive datapoints covering 16 epitopes

In [5]:
from gentcr.data import EpitopeTargetDataset, CN

eval_ds = EpitopeTargetDataset.from_key('smg_mutated')
eval_ds.exclude_by(train_ds, target_cols=['index', 'orig_cdr3b_seq'], inplace=True)
print(f'eval_ds.df.shape: {eval_ds.df.shape}')
display(eval_ds.df.head())

(gentcr) 2023-12-18 17:35:20 [INFO]: Loaded smg_mutated data from ../output/smg_mutated.data.csv, df.shape: (1174, 16)
(gentcr) 2023-12-18 17:35:20 [INFO]: Excluding immunecode data by column: index from smg_mutated
(gentcr) 2023-12-18 17:35:24 [INFO]: Current smg_mutated data.shape: (1174, 16)
(gentcr) 2023-12-18 17:35:24 [INFO]: Excluding immunecode data by column: orig_cdr3b_seq from smg_mutated
(gentcr) 2023-12-18 17:35:29 [INFO]: Current smg_mutated data.shape: (977, 16)
eval_ds.df.shape: (977, 16)


Unnamed: 0,epitope_species,epitope_gene,epitope_seq,orig_epitope_seq,epitope_start,epitope_end,epitope_len,mhc_allele,mhc_seq,orig_mhc_seq,cdr3b_seq,orig_cdr3b_seq,cdr3b_len,ref_id,source,bind_level
YLQ-RTFLL_CA-S--N-NUQ-F,SARS-CoV-2,Spike,YLQ-RTFLL,YLQPRTFLL,,,9,HLA-A*02,,,CA-S--N-NUQ-F,CASSSVNNNEQFF,13,PMID:33326767,Shomuradova,2
YLQPRTF-L_CFV---N-GELF-,SARS-CoV-2,Spike,YLQPRTF-L,YLQPRTFLL,,,9,HLA-A*02,,,CFV---N-GELF-,CAVGEANTGELFF,13,PMID:33326767,Shomuradova,2
YLQPRPFLL_C-Y-EV-TG---F,SARS-CoV-2,Spike,YLQPRPFLL,YLQPRTFLL,,,9,HLA-A*02,,,C-Y-EV-TG---F,CAYQEVNTGELFF,13,PMID:33326767,Shomuradova,2
YLQP-TFLL_-SARDD-A----EL--,SARS-CoV-2,Spike,YLQP-TFLL,YLQPRTFLL,,,9,HLA-A*02,,,-SARDD-A----EL--,CSARDDQAVNTGELFF,16,PMID:33326767,Shomuradova,2
YLQP-TFLL_-S-GLRN--ELFH,SARS-CoV-2,Spike,YLQP-TFLL,YLQPRTFLL,,,9,HLA-A*02,,,-S-GLRN--ELFH,CSAGQRNTGELFF,13,PMID:33326767,Shomuradova,2


## Experiment 1 with high mutation ratio in TCR CD3R$\\beta$ sequences
- Epitope: mut_ratio=0.15, mut_probs=[0.7, 03]
- TCR CDR3$\\beta$: mut_ratio=0.4, mut_probs=[0.8, 0.2]

### Evaluate the base model

In [9]:
from torch.utils.data import DataLoader
from gentcr.bioseq import UniformAASeqMutator, CalisImmunogenicAASeqMutator, needle_aaseq_pair
from gentcr.common import StrUtils
from gentcr.data import EpitopeTargetDataset, EpitopeTargetMaskedLMCollator, CN
import numpy as np

def eval_model(target_model=None, 
               ds=None,
               n_samples=32,
               tokenizer=None):
    

    collator = EpitopeTargetMaskedLMCollator(tokenizer=tokenizer,
                                             max_epitope_len=ds.max_epitope_len,
                                             max_target_len=ds.max_target_len)
    data_loader = DataLoader(ds, batch_size=n_samples, shuffle=False, collate_fn=collator)

    epitope_seqs = ds.df[CN.orig_epitope_seq].values[:n_samples]
    target_seqs = ds.df[CN.orig_cdr3b_seq].values[:n_samples]
    input_seqs = [
        collator.format_seqs(e_seq, t_seq) for e_seq, t_seq in zip(epitope_seqs, target_seqs)
    ]
    
    batch = next(iter(data_loader))
    output = target_model(**batch)
    token_ids = torch.argmax(output.logits, dim=-1)
    output_seqs = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    output_seqs = list(map(lambda seq: StrUtils.rm_nonwords(seq), output_seqs))
    
    scores = []
    for input_seq, output_seq in zip(input_seqs, output_seqs):
        output_seq = output_seq[:len(input_seq)]
        # print(f'input : {input_seq}\noutput: {output_seq}')
        score = needle_aaseq_pair(input_seq, output_seq, output_identity=True, output_similarity=True)['similarity']
        scores.append(score)
        print(f'>>Similarity score for {input_seq}/{output_seq}: {score:.3f}')
    return np.mean(scores)
   
score = eval_model(target_model=base_model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer)
print(f'>>> Mean similarity score of base model: {score:.3f}')

>>Similarity score for YLQPRTFLLCASSSVNNNEQFF/MLQLRTFLLCALSLLNLNLQLF: 0.682
>>Similarity score for YLQPRTFLLCAVGEANTGELFF/MLQPRTFLLCFVLLLNLGELFL: 0.682
>>Similarity score for YLQPRTFLLCAYQEVNTGELFF/MLQPRPFLLCLYLEVLTGLLLF: 0.682
>>Similarity score for YLQPRTFLLCSARDDQAVNTGELFF/MLQPLTFLLLSARDDLAALLLELLL: 0.600
>>Similarity score for YLQPRTFLLCSAGQRNTGELFF/MLQPLTFLLLSLGLRNLLELFH: 0.636
>>Similarity score for YLQPRTFLLCASSLEIEAFF/MLQPRTLLLQTLSLEIELLF: 0.650
>>Similarity score for YLQPRTFLLCAGDYLNTGELFF/MLQPRLFLLLAGIMLNTGLLFL: 0.682
>>Similarity score for YLQPRTFLLCASSPDIACTF/MLQPRTLLLCALLLDIXXXF: 0.600
>>Similarity score for YLQPRTFLLCASSVDNTGELFF/MLQPRSFLLCASLVDLSGELLL: 0.773
>>Similarity score for YLQPRTFLLCAGQDLNTGELFF/MLQPRTFVLLLGQLLNTGELLF: 0.773
>>Similarity score for YLQPRTFLLCASSPDIVAFF/MLQLRTFLLCLLSLDLVDLF: 0.650
>>Similarity score for YLQPRTFLLCAAQNLNTGELFF/MLQPRTLLLLALLNLLTGLLLF: 0.636
>>Similarity score for YLQPRTFLLCSAGDRNTGELFF/MLQPRTFFLVMAGLRLTLELLF: 0.636
>>Similarity score

In [10]:
def load_peft_model(adapter_path='../output/exp1/mlm_finetune'):
    model = AutoPeftModel.from_pretrained(adapter_path, device_map='auto').eval()
    # model = model.merge_and_unload()
    return model

model = load_peft_model('../output/exp1/mlm_finetune')

score = eval_model(target_model=model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer)
print(f'>>> Mean similarity score of fine-tuned model: {score}')


(urllib3.connectionpool) 2023-12-18 18:50:17 [DEBUG]: Resetting dropped connection: huggingface.co
(urllib3.connectionpool) 2023-12-18 18:50:17 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-12-18 18:50:17 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
>>Similarity score for YLQPRTFLLCASSSVNNNEQFF/MLHGRTFLLCASSPGAYNEQFF: 0.682
>>Similarity score for YLQPRTFLLCAVGEANTGELFF/MLHPRTFWLCAVGGANTGELFF: 0.818
>>Similarity score for YLQPRTFLLCAYQEVNTGELFF/MLSPRGVLSCASRESNTGELFF: 0.682
>>Similarity score for YLQPRTFLLCSARDDQAVNTGELFF/MLSPSGGWLCSARDRLAGNTGELFF: 0.640
>>Similarity score for YLQPRTFLLCSAGQRNTGELFF/MLCPGFFLCCSVGDANTGELFF: 0.636
>>Similarity score for YLQPRTFLLCASSLEIEAFF/MLQPRLSLLCASSLDGEQFF: 0.750
>>Similarity score for YLQPRTFLLCAGDYLNTGELFF/MLSPRGFWCCASGYANTGELFF: 0.636
>>Similarity score for YLQPRTFLLCASSPDIACTF/MLHPRTP

- Masking 50% AAs of TCR CDR3$\\beta$ sequences
- The mean similarity scores of the base model and the fine-tuned model is 0.705 and 0.729.
- The fine-tuned model is slightly better than the base model.
- Too high mutation in TCR CDR3$\\beta$ 
- We will investigate different mutation properties

## Experiment 2 with different mutation properties
- Epitope: mut_ratio=0.15, mut_probs=[0.7, 03]
- TCR CDR3$\\beta$: mut_ratio=0.2, mut_probs=[0.8, 0.2]

In [11]:
model = load_peft_model('../output/exp2/mlm_finetune')

# score = eval_model(target_model=base_model, ds=eval_ds, n_samples=len(eval_ds),
#                    tokenizer=tokenizer, 
#                    epitope_seq_mutator=None, 
#                    target_seq_mutator=target_seq_mutator)
# print(f'>>>> Mean similarity score of base model: {score}')
score = eval_model(target_model=model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer)
print(f'>>>> Mean similarity score of fine-tuned model: {score}')

(urllib3.connectionpool) 2023-12-18 18:52:30 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-12-18 18:52:30 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
>>Similarity score for YLQPRTFLLCASSSVNNNEQFF/YLIPRTFLLCASSLGNYNEQFF: 0.818
>>Similarity score for YLQPRTFLLCAVGEANTGELFF/YLDPRTFLLCAVEDENTGELFF: 0.864
>>Similarity score for YLQPRTFLLCAYQEVNTGELFF/YLDPRTFLLCASSEVNTGELFF: 0.864
>>Similarity score for YLQPRTFLLCSARDDQAVNTGELFF/YLQPRTFLICSARDDRAGNTGELFF: 0.960
>>Similarity score for YLQPRTFLLCSAGQRNTGELFF/YLDPRTFLLCSAGDRNTGELFF: 0.909
>>Similarity score for YLQPRTFLLCASSLEIEAFF/YLDPRTFLLCASSLEIEQFF: 0.900
>>Similarity score for YLQPRTFLLCAGDYLNTGELFF/YLDPRTFLLCAGIDGNTGELFF: 0.818
>>Similarity score for YLQPRTFLLCASSPDIACTF/YLDPRTFLLCASSLDIEQFF: 0.750
>>Similarity score for YLQPRTFLLCASSVDNTGELFF/YLQPRTFLLCASSVDNTGELFF: 1.000
>>

- Masking 50% AAs of TCR CDR3$\\beta$ sequences
- The mean similarity scores of the fine-tuned model is 0.859.
- The fine-tuned model score is significantly higher than the fine-tuned model in exp1.
- We will make the mutation properties of TCRCDR3$\\beta$ equal to that of epitope sequence.    

## Experiment 3 with same mutation properties on both epitope and TCR CDR3$\\beta$
- mut_ratio=0.15, mut_probs=[0.7, 03]

In [12]:
model = load_peft_model('../output/exp3/mlm_finetune')
# score = eval_model(target_model=base_model, ds=eval_ds, n_samples=len(eval_ds),
#                    tokenizer=tokenizer, 
#                    epitope_seq_mutator=None, 
#                    target_seq_mutator=target_seq_mutator)
# print(f'>>>> Mean similarity score of base model: {score}')
score = eval_model(target_model=model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer)
print(f'>>>> Mean similarity score of fine-tuned model: {score}')

(urllib3.connectionpool) 2023-12-18 18:54:59 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-12-18 18:54:59 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
>>Similarity score for YLQPRTFLLCASSSVNNNEQFF/YLQPRTFLLCASSLGGYNEQFF: 0.818
>>Similarity score for YLQPRTFLLCAVGEANTGELFF/YLQPRTFLLCAVEDENTGELFF: 0.909
>>Similarity score for YLQPRTFLLCAYQEVNTGELFF/YLQPRTFLLCASSEVNTGELFF: 0.909
>>Similarity score for YLQPRTFLLCSARDDQAVNTGELFF/YLQPRTFLLCSARDDRAGGTGELFF: 0.920
>>Similarity score for YLQPRTFLLCSAGQRNTGELFF/YLQPRTFLLCSVGLRNTGELFF: 0.909
>>Similarity score for YLQPRTFLLCASSLEIEAFF/YLQPRTFLLCASSGENEQFF: 0.850
>>Similarity score for YLQPRTFLLCAGDYLNTGELFF/YLQPRTFLLCAGIGENTGELFF: 0.864
>>Similarity score for YLQPRTFLLCASSPDIACTF/YLQPRTFLLCASSQDIGQFF: 0.800
>>Similarity score for YLQPRTFLLCASSVDNTGELFF/YLQPRTFLLCASSVDGTGELFF: 0.955
>>

- Masking 50% AAs of TCR CDR3$\\beta$ sequences
- The mean similarity scores of the fine-tuned model is 0.865.
- The fine-tuned model score is higher than the fine-tuned model in exp2.
- The fine-tuned model shows reliable performance on unseen evaluation data

## Experiment 4
- Mutation ratio: 0.15->0.2  for epitope and TCR CDR3beta sequences, respectively

In [14]:
model = load_peft_model('../output/exp4/mlm_finetune/checkpoint-1972')

score = eval_model(target_model=model, ds=eval_ds, n_samples=len(eval_ds),
                   tokenizer=tokenizer)
print(f'>>>> Mean similarity score of fine-tuned model: {score}')

(urllib3.connectionpool) 2023-12-18 19:01:12 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
(urllib3.connectionpool) 2023-12-18 19:01:12 [DEBUG]: https://huggingface.co:443 "HEAD /facebook/esm2_t33_650M_UR50D/resolve/main/config.json HTTP/1.1" 200 0
>>Similarity score for YLQPRTFLLCASSSVNNNEQFF/FLQPRTFLICSSGEDGYDELFF: 0.727
>>Similarity score for YLQPRTFLLCAVGEANTGELFF/LLQPRTFLLCSAGDGDTGELFF: 0.864
>>Similarity score for YLQPRTFLLCAYQEVNTGELFF/LLQPRTFYICAVEDANSDELFF: 0.773
>>Similarity score for YLQPRTFLLCSARDDQAVNTGELFF/FLQPRTFLACSAEEPRGGNTGEQFF: 0.760
>>Similarity score for YLQPRTFLLCSAGQRNTGELFF/YLQPRTFLCCAVDEADTGEQFF: 0.773
>>Similarity score for YLQPRTFLLCASSLEIEAFF/SLQPRTFLCCASGLDNEQFF: 0.750
>>Similarity score for YLQPRTFLLCAGDYLNTGELFF/LLQPRTFLLCSSDDADTGELFF: 0.818
>>Similarity score for YLQPRTFLLCASSPDIACTF/FLQPRTFLLCSSSLGYGYYF: 0.700
>>Similarity score for YLQPRTFLLCASSVDNTGELFF/YLQPRTFLLCATSLENTNELFF: 0.955
>>