In [8]:
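# This notebook tokenizes the full-sequence (fullseq) LOOCV data for binary classification.
# Input columns: CDR loops (cdr1a, cdr2a, cdr25a, cdr3a, cdr1b, cdr2b, cdr25b, cdr3b), the MHC motif ('motif'), and 'label'.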


In [1]:
import os
import pandas as pd 
from datasets import load_dataset,Dataset,DatasetDict
from transformers import DebertaTokenizerFast

In [2]:
# Tokenizer configuration.
# Every (TCR, MHC) pair is padded/truncated to exactly `max_length` tokens.
max_length=288
model_name="tcrhlamotifs-crossencoder"
# The tokenizer directory is named after the part of model_name before the first '-'.
tokenizer_name=model_name.split("-")[0]
tokenizer_dir=f'/data/finetuning/tokenizers/{tokenizer_name}'
# `model_max_length` is the documented kwarg; `max_len` is only a deprecated fallback alias.
tokenizer = DebertaTokenizerFast.from_pretrained(tokenizer_dir, model_max_length=max_length)
tokenizer_name

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


'tcrhlamotifs'

In [3]:
# Register marker tokens for the CDR2.5 loops of both chains; tokenize_batch_binary
# inserts "[cdra25]" / "[cdrb25]" into every TCR sequence.
# add_tokens returns the number of tokens added, which is shown as the cell output.
# NOTE(review): the tokenizer is not saved in this notebook, so the model that consumes
# these tokenized datasets must add the same tokens in the same order (and resize its
# embeddings) -- confirm this happens downstream.
new_tokens = ["[cdra25]", "[cdrb25]"]
tokenizer.add_tokens([*new_tokens])

2

In [4]:
def unk(v):
    """Return `v`, substituting the tokenizer's unknown token when `v` is None.

    Only None is replaced; any other value (including "") is returned unchanged.
    """
    if v is not None:
        return v
    # Looked up lazily so the global tokenizer is only needed for missing values.
    return tokenizer.unk_token

In [5]:
def tokenize_batch_binary(batch):
    """Build and tokenize (TCR, MHC) sequence pairs for one batch of rows.

    `sequence1` concatenates the eight CDR loops, each preceded by its marker token
    ([cdra1], [cdra2], [cdra25], [cdra3], [cdrb1], [cdrb2], [cdrb25], [cdrb3]).
    `sequence2` is "[mhc]" followed by the motif. Missing (None) values in any of
    these columns are replaced with the tokenizer's unknown token via `unk`.

    Args:
        batch: Mapping of column name -> list of values. It must contain cdr1a,
            cdr2a, cdr25a, cdr3a, cdr1b, cdr2b, cdr25b, cdr3b and motif.

    Returns:
        The tokenizer encoding of the (sequence1, sequence2) pairs, padded and
        truncated to the global `max_length`.

    Side effects:
        Adds 'sequence1' and 'sequence2' lists to `batch`.
        NOTE(review): since these keys are not in the caller's remove_columns,
        datasets.map likely keeps them as output columns -- confirm intended.
    """
    # (marker token, source column) in the order they are concatenated.
    cdr_fields = [
        ('[cdra1]', 'cdr1a'),
        ('[cdra2]', 'cdr2a'),
        ('[cdra25]', 'cdr25a'),
        ('[cdra3]', 'cdr3a'),
        ('[cdrb1]', 'cdr1b'),
        ('[cdrb2]', 'cdr2b'),
        ('[cdrb25]', 'cdr25b'),
        ('[cdrb3]', 'cdr3b'),
    ]
    markers = [marker for marker, _ in cdr_fields]
    rows = zip(*(batch[column] for _, column in cdr_fields))
    batch['sequence1'] = [
        ''.join(marker + unk(value) for marker, value in zip(markers, row))
        for row in rows
    ]
    # unk() here too: a missing motif would otherwise raise TypeError on concatenation.
    batch['sequence2'] = ['[mhc]' + unk(mhc) for mhc in batch['motif']]
    return tokenizer(batch['sequence1'], batch['sequence2'], padding='max_length', max_length=max_length, truncation=True)

def tokenize_loocv_binary(loocv_path, save_path, num_proc=2):
    """Tokenize every LOOCV fold under `loocv_path` and save each one to disk.

    Each sub-directory `<loocv_path>/<mhc>/` is one fold and must contain
    `<mhc>_train.csv` and `<mhc>_test.csv`. The `_test.csv` file is loaded as
    both the 'eval' and the 'test' split; no separate validation CSV is read.
    Folds whose test split is empty are skipped with a message. In each fold,
    every column except 'label' is replaced by the output of
    `tokenize_batch_binary`, and the resulting DatasetDict is saved to
    `<save_path>/<mhc>`.

    Args:
        loocv_path: Directory containing one sub-directory per held-out MHC.
        save_path: Directory under which each tokenized fold is saved.
        num_proc: Number of worker processes passed to `map` (default 2).

    Raises:
        ValueError: If a fold's test CSV has no 'label' column.
    """
    # Only sub-directories are folds (stray files would break load_dataset).
    # Sort them because os.listdir order is arbitrary.
    mhcs = sorted(
        entry for entry in os.listdir(loocv_path)
        if os.path.isdir(os.path.join(loocv_path, entry))
    )
    print(len(mhcs))
    for mhc in mhcs:
        fold_dir = os.path.join(loocv_path, mhc)
        test_file = os.path.join(fold_dir, f'{mhc}_test.csv')
        data_files = {
            'train': os.path.join(fold_dir, f'{mhc}_train.csv'),
            'eval': test_file,  # eval reuses the held-out test file
            'test': test_file,
        }
        dataset = load_dataset('csv', data_files=data_files)

        if len(dataset['test']) == 0:
            print(f'{mhc} has no test data')
            continue

        column_names = dataset['test'].column_names
        if 'label' not in column_names:
            raise ValueError(f"{test_file} has no 'label' column")
        # Keep only 'label' from the raw columns; the tokenizer outputs are added by map.
        remove_columns = [column for column in column_names if column != 'label']

        tokenized_datasets = dataset.map(
            tokenize_batch_binary,
            batched=True,
            num_proc=num_proc,
            remove_columns=remove_columns,
        )

        tokenized_datasets.save_to_disk(os.path.join(save_path, mhc))

In [6]:
%%capture
ds_name='fs22_loocv_mismhc'
ds_path=f'/data/finetuning/01-BinaryClassification/fullseq/loocv/data/{ds_name}'
print('tokenizing All CDRs')
ds_save_path=(f'/data/finetuning/01-BinaryClassification/fullseq/loocv/tokenized_datasets/{ds_name}_{tokenizer_name}')
tokenize_loocv_binary(ds_path,ds_save_path) 
