# Find a fine-tuned model whose epitope attention peaks at the target residue

## Global configurations

In [1]:
import logging
import logging.config
import os
import sys
import warnings
from enum import auto
import pandas as pd
import numpy as np
# IPython.display is the public home of display(); importing it from
# IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display

# Project root. Set the TCRBERT_ROOT environment variable to point at another
# checkout instead of editing the notebook; the default is the original path.
rootdir = os.environ.get('TCRBERT_ROOT', '/home/hym/trunk/TCRBert')
workdir = '%s/notebook' % rootdir
datadir = '%s/data' % rootdir
srcdir = '%s/tcrbert' % rootdir
outdir = '%s/output' % rootdir

# Relative paths used below (e.g. '../config/logging.conf') are resolved
# against the current working directory, so switch to the notebook dir first.
os.chdir(workdir)

# Only append when missing so that re-running this cell does not keep
# growing sys.path with duplicate entries.
for _path in (rootdir, srcdir):
    if _path not in sys.path:
        sys.path.append(_path)

from tcrbert.exp import Experiment
from tcrbert.predlistener import PredResultRecoder


# Display
pd.set_option('display.max.rows', 2000)
pd.set_option('display.max.columns', 2000)

# Logger
warnings.filterwarnings('ignore')
logging.config.fileConfig('../config/logging.conf')
logger = logging.getLogger('tcrbert')
logger.setLevel(logging.INFO)

# Target experiment
exp_key = 'exp1'
experiment = Experiment.from_key(exp_key)

exp_conf = experiment.exp_conf

display(exp_conf)

2021-11-15 22:00:25 [INFO]: Loaded exp_conf: {'title': 'exp1', 'description': 'Fine-tuning of pre-trained TAPE model in a progressively specialized manner', 'paper': 'exp1', 'model_config': '../config/bert-base/', 'outdir': '../output/exp1', 'train': {'pretrained_model': {'type': 'tape', 'location': '../config/bert-base/'}, 'data_parallel': False, 'backup': 'train.bak.{date}.tar.gz', 'rounds': [{'data': 'dash_vdjdb_mcpas', 'test_size': 0.2, 'batch_size': 128, 'n_epochs': 150, 'n_workers': 12, 'metrics': ['accuracy'], 'optimizer': {'type': 'adam', 'lr': 0.0001}, 'train_bert_encoders': [-10, None], 'early_stopper': {'monitor': 'accuracy', 'patience': 15}, 'model_checkpoint': {'chk': 'train.{round}.model_{epoch}.chk', 'monitor': 'accuracy', 'save_best_only': True, 'period': 1}, 'result': 'train.{round}.result.json'}, {'data': 'iedb_sars2', 'test_size': 0.2, 'batch_size': 128, 'n_epochs': 100, 'n_workers': 12, 'metrics': ['accuracy'], 'optimizer': {'type': 'adam', 'lr': 0.0001}, 'train_ber

{'title': 'exp1',
 'description': 'Fine-tuning of pre-trained TAPE model in a progressively specialized manner',
 'paper': 'exp1',
 'model_config': '../config/bert-base/',
 'outdir': '../output/exp1',
 'train': {'pretrained_model': {'type': 'tape',
   'location': '../config/bert-base/'},
  'data_parallel': False,
  'backup': 'train.bak.{date}.tar.gz',
  'rounds': [{'data': 'dash_vdjdb_mcpas',
    'test_size': 0.2,
    'batch_size': 128,
    'n_epochs': 150,
    'n_workers': 12,
    'metrics': ['accuracy'],
    'optimizer': {'type': 'adam', 'lr': 0.0001},
    'train_bert_encoders': [-10, None],
    'early_stopper': {'monitor': 'accuracy', 'patience': 15},
    'model_checkpoint': {'chk': 'train.{round}.model_{epoch}.chk',
     'monitor': 'accuracy',
     'save_best_only': True,
     'period': 1},
    'result': 'train.{round}.result.json'},
   {'data': 'iedb_sars2',
    'test_size': 0.2,
    'batch_size': 128,
    'n_epochs': 100,
    'n_workers': 12,
    'metrics': ['accuracy'],
    'opt

## Re-train until the epitope attention peaks at the target aa position

In [None]:
from tcrbert.dataset import TCREpitopeSentenceDataset, CN
from collections import OrderedDict, Counter
from torch.utils.data import DataLoader

# Repeatedly re-train the experiment and back up every trained model whose
# position-wise attention over the epitope peaks at `target_attn_pos` for every
# selected CDR3b-length group in both evaluation datasets.

epitope = 'YLQPRTFLL'
epitope_len = len(epitope)

sh_ds = TCREpitopeSentenceDataset.from_key('shomuradova')
sh_df = sh_ds.df_enc
im_ds = TCREpitopeSentenceDataset.from_key('immunecode')

# Remove duplicated CDR3beta seqs with Shomuradova.
# Series.isin() does a hashed lookup per row instead of scanning sh_df per row.
im_ds.df_enc = im_ds.df_enc[~im_ds.df_enc[CN.cdr3b].isin(sh_df[CN.cdr3b])]
im_df = im_ds.df_enc

n_found = 0
n_train = 0
n_target_found = 5  # stop after this many matching models have been backed up
max_n_train = None  # optional cap on training runs; None = keep training until n_target_found
metrics = ['accuracy', 'f1', 'roc_auc']

# 0-based position within the epitope (3 -> 'P' of 'YLQPRTFLL')
target_attn_pos = 3

while n_found < n_target_found and (max_n_train is None or n_train < max_n_train):
    found = True

    logger.info('>>>Begin train %s' % n_train)

    experiment.train()

    logger.info('>>>Done to train %s' % n_train)
    n_train = n_train + 1

    for i in range(experiment.n_train_rounds):
        train_result = experiment.get_train_result(i)
        logger.info('Round %s train results======================' % i)
        logger.info('n_epochs: %s' % train_result['n_epochs'])
        logger.info('stopped_epoch: %s' % train_result['stopped_epoch'])
        logger.info('best_epoch: %s' % train_result['best_epoch'])
        logger.info('best_score: %s' % train_result['best_score'])
        logger.info('best_chk: %s' % train_result['best_chk'])

    model = experiment.load_eval_model()
    eval_recoder = PredResultRecoder(output_attentions=True, output_hidden_states=True)
    model.add_pred_listener(eval_recoder)

    # Each dataset has its own cumulative-coverage threshold for selecting CDR3b lengths
    for ds, max_cum_ratio in zip([sh_ds, im_ds], [0.9, 0.85]):
        df = ds.df_enc
        # Whole dataset in one batch; presumably so result_map holds outputs for
        # every row at once — TODO confirm against PredResultRecoder
        data_loader = DataLoader(ds, batch_size=len(ds), shuffle=False, num_workers=2)
        logger.info('Predicting for %s' % ds.name)
        model.predict(data_loader=data_loader, metrics=metrics)
        logger.info('Performace score_map for %s: %s' % (ds.name, eval_recoder.result_map['score_map']))

        output_labels = np.array(eval_recoder.result_map['output_labels'])
        cdr3b_seqs = df[CN.cdr3b].values
        # Length of every CDR3b row, computed once instead of once per candidate length
        cdr3b_lens = np.array([len(seq) for seq in cdr3b_seqs], dtype=int)

        # Select target CDR3b sequences with most common lengths
        target_index_map = OrderedDict()
        pos_indices = np.where(output_labels == 1)[0]
        if pos_indices.shape[0] == 0:
            # Without this guard zip(*[]) would fail to unpack and the ratio
            # below would divide by zero.
            logger.info('No positive predictions for %s' % ds.name)
        else:
            pos_cdr3b = cdr3b_seqs[pos_indices]
            lens, cnts = zip(*sorted(Counter(map(len, pos_cdr3b)).items()))
            lens = np.array(lens)
            cnts = np.array(cnts)

            # Walk lengths from most to least frequent and keep a length while the
            # cumulative share of positives (including it) stays below max_cum_ratio
            order = np.argsort(cnts)[::-1]
            cum_cnt = 0
            for cur_len, cur_cnt in zip(lens[order], cnts[order]):
                cum_cnt += cur_cnt
                cum_ratio = cum_cnt/pos_indices.shape[0]
                if cum_ratio < max_cum_ratio:
                    target_indices = np.where((output_labels == 1) & (cdr3b_lens == cur_len))[0]
                    logger.debug('target_indices for %s: %s(%s)' % (cur_len, target_indices, target_indices.shape[0]))
                    target_index_map[cur_len] = target_indices

        # Investigate attention weights
        attentions = eval_recoder.result_map['attentions']
        # attentions.shape: (n_layers, n_data, n_heads, max_len, max_len)
        logger.info('attentions.shape: %s' % str(attentions.shape))

        if not target_index_map:
            # No length group was selected, so the attention criterion was never
            # checked for this dataset; previously the model was accepted vacuously.
            logger.info('No target CDR3b length groups for %s' % ds.name)
            found = False

        for cur_len, cur_indices in target_index_map.items():
            attns = attentions[:, cur_indices]
            sent_len = epitope_len + cur_len

            # Marginalized position-wise attentions by mean over axes 0-3;
            # [1:sent_len+1] presumably skips a leading special token and keeps the
            # epitope + CDR3b tokens — TODO confirm tokenization
            attns = np.mean(attns, axis=(0, 1, 2, 3))[1:sent_len+1]
            logger.info('Marginalized attns for cdr3b %s: %s (%s)' % (cur_len, attns, str(attns.shape)))

            epi_attns = attns[:epitope_len]
            cur_max_attn_pos = np.argmax(epi_attns)
            logger.info('Current max epitope attention weight: %s at %s' % (epi_attns[cur_max_attn_pos],
                                                                            cur_max_attn_pos))
            if target_attn_pos != cur_max_attn_pos:
                found = False

    if found:
        logger.info('>>>>>Found it!, backup train results, n_found: %s' % n_found)
        experiment.backup_train_results()
        n_found = n_found + 1
            

2021-11-15 22:03:14 [INFO]: >>>Begin train 0
2021-11-15 22:03:14 [INFO]: Begin train at 2021-11-15 22:03:14.458654
2021-11-15 22:03:14 [INFO]: Loading the TAPE pretrained model from ../config/bert-base/
2021-11-15 22:03:18 [INFO]: Start 2 train rounds of exp1 at 2021-11-15 22:03:14.458654
2021-11-15 22:03:18 [INFO]: train_conf: {'pretrained_model': {'type': 'tape', 'location': '../config/bert-base/'}, 'data_parallel': False, 'backup': 'train.bak.{date}.tar.gz', 'rounds': [{'data': 'dash_vdjdb_mcpas', 'test_size': 0.2, 'batch_size': 128, 'n_epochs': 150, 'n_workers': 12, 'metrics': ['accuracy'], 'optimizer': {'type': 'adam', 'lr': 0.0001}, 'train_bert_encoders': [-10, None], 'early_stopper': {'monitor': 'accuracy', 'patience': 15}, 'model_checkpoint': {'chk': 'train.{round}.model_{epoch}.chk', 'monitor': 'accuracy', 'save_best_only': True, 'period': 1}, 'result': 'train.{round}.result.json'}, {'data': 'iedb_sars2', 'test_size': 0.2, 'batch_size': 128, 'n_epochs': 100, 'n_workers': 12, '

Training in epoch 0/150: 100%|██████████| 158/158 [03:18<00:00,  1.26s/batch]
Validating in epoch 0/150: 100%|██████████| 40/40 [00:37<00:00,  1.08batch/s]

2021-11-15 22:07:15 [INFO]: [EvalScoreRecoder]: In epoch 0/150, loss train score: 0.6940721258332457, val score: 0.6900263592600823
2021-11-15 22:07:15 [INFO]: [EvalScoreRecoder]: In epoch 0/150, accuracy train score: 0.5162607368896927, val score: 0.5508680555555555
2021-11-15 22:07:15 [INFO]: [EarlyStopper]: In epoch 0/150, accuracy score: 0.5508680555555555, best accuracy score: -inf;update best score to 0.5508680555555555
2021-11-15 22:07:15 [INFO]: [ModelCheckpoint]: Checkpoint at epoch 0: accuracy improved from -inf to 0.5508680555555555, saving model to ../output/exp1/train.0.model_0.chk



Training in epoch 1/150: 100%|██████████| 158/158 [03:21<00:00,  1.27s/batch]
Validating in epoch 1/150: 100%|██████████| 40/40 [00:36<00:00,  1.10batch/s]

2021-11-15 22:11:14 [INFO]: [EvalScoreRecoder]: In epoch 1/150, loss train score: 0.6928321036356914, val score: 0.6936042934656144
2021-11-15 22:11:14 [INFO]: [EvalScoreRecoder]: In epoch 1/150, accuracy train score: 0.5163808205244123, val score: 0.5004991319444445
2021-11-15 22:11:14 [INFO]: [EarlyStopper]: In epoch 1/150, accuracy score: 0.5004991319444445, best accuracy score: 0.5508680555555555;accuracy score was not improved
2021-11-15 22:11:14 [INFO]: [EarlyStopper]: Current wait count: 1, patience: 15
2021-11-15 22:11:14 [INFO]: [ModelCheckpoint]: Checkpoint at epoch 1: accuracy did not improve



Training in epoch 2/150:  85%|████████▍ | 134/158 [02:49<00:30,  1.25s/batch]