In [1]:
import os
import random
import numpy as np
import pandas as pd
import argparse

def return_position_single(mutation):
    """Extract the integer position from a single-mutation string.

    Only the text before the first ":" is inspected, so for a multi-mutant
    string the position of the first listed mutation is returned. The first
    and last characters of that text are dropped and the remainder is parsed
    as an int (``int`` raises ValueError if it is not numeric).
    """
    first_mutation, _, _ = mutation.partition(":")
    digits = first_mutation[1:-1]
    return int(digits)

def keep_singles(DMS, mutant_column='mutant'):
    """Return the rows of DMS whose mutant string does not contain ":".

    :param DMS: DataFrame holding a string column named ``mutant_column``
    :param mutant_column: name of the column with the mutation strings
    :return: the filtered DataFrame (original index labels are preserved)
    """
    has_separator = DMS[mutant_column].str.contains(":")
    singles_only = DMS[~has_separator]
    return singles_only


def create_folds_random(DMS, n_folds=5, mutant_column='mutant', seed=None):
    """Assign every row of DMS to a uniformly random fold in ``[0, n_folds)``.

    The assignment is written in place to the column ``fold_random_<n_folds>``
    of ``DMS``, the fold sizes are printed, and the same frame is returned.

    :param DMS: DataFrame with the mutation strings in ``mutant_column``; the
        ``mutated_sequence`` column is used as a fallback when the positions
        cannot be parsed from ``mutant_column``
    :param n_folds: number of folds to draw from
    :param mutant_column: name of the column with the mutation strings
    :param seed: optional seed for a reproducible assignment; with the default
        ``None`` the global NumPy random state is used, as before
    :return: ``DMS`` with the fold column added
    :raises Exception: if there are fewer mutated positions than folds
    """
    column_name = 'fold_random_{}'.format(n_folds)
    try:
        mutated_region_list = DMS[mutant_column].apply(return_position_single).unique()
    except Exception:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt / SystemExit are no longer swallowed here.
        print("Mutated region not found from '{}' variable -- assuming the full protein sequence is mutated".format(mutant_column))
        mutated_region_list = range(len(DMS['mutated_sequence'].values[0]))
    len_mutated_region = len(mutated_region_list)
    if len_mutated_region < n_folds:
        raise Exception("Error, there are fewer mutated regions than requested folds")
    if seed is None:
        fold_ids = np.random.randint(0, n_folds, DMS.shape[0])
    else:
        # A dedicated generator keeps seeded runs independent of global state.
        fold_ids = np.random.default_rng(seed).integers(0, n_folds, DMS.shape[0])
    DMS[column_name] = fold_ids
    print(DMS[column_name].value_counts())
    return DMS

def create_folds_by_position_modulo(DMS, n_folds=5, mutant_column='mutant'):
    """Assign folds by the rank of each mutated position, modulo ``n_folds``.

    The distinct positions are sorted; the position with rank ``r`` goes to
    fold ``r % n_folds``, so neighbouring positions land in different folds.
    The result is written in place to the column ``fold_modulo_<n_folds>``,
    the fold sizes are printed, and the same frame is returned.

    :raises Exception: if there are fewer distinct positions than folds
    """
    fold_column = 'fold_modulo_{}'.format(n_folds)
    unique_positions = list(DMS[mutant_column].apply(return_position_single).unique())
    unique_positions.sort()
    if len(unique_positions) < n_folds:
        raise Exception("Error, there are fewer mutated regions than requested folds")

    fold_of_position = {}
    for rank, position in enumerate(unique_positions):
        fold_of_position[position] = rank % n_folds

    def fold_for(mutation):
        # Look up the fold of the position encoded in this mutation string.
        return fold_of_position[return_position_single(mutation)]

    DMS[fold_column] = DMS[mutant_column].apply(fold_for)
    print(DMS[fold_column].value_counts())
    return DMS

def create_folds_by_contiguous_position_discontiguous(DMS, n_folds=5, mutant_column='mutant'):
    """Assign folds as contiguous runs of the sorted mutated positions.

    The sorted distinct positions are cut into ``n_folds`` consecutive chunks
    whose sizes differ by at most one (the earliest folds take the extra
    positions). The result is written in place to the column
    ``fold_contiguous_<n_folds>``, the fold sizes are printed, and the same
    frame is returned.

    :raises Exception: if there are fewer distinct positions than folds
    """
    fold_column = 'fold_contiguous_{}'.format(n_folds)
    positions = list(DMS[mutant_column].apply(lambda x: return_position_single(x)).unique())
    positions.sort()
    n_positions = len(positions)
    if n_positions < n_folds:
        raise Exception("Error, there are fewer mutated regions than requested folds")

    base_size, remainder = divmod(n_positions, n_folds)
    fold_of_position = {}
    cursor = 0
    for fold in range(n_folds):
        # The first `remainder` folds each take one extra position.
        size = base_size + (1 if fold < remainder else 0)
        for position in positions[cursor:cursor + size]:
            fold_of_position[position] = fold
        cursor += size

    def fold_for(mutation):
        # Look up the fold of the position encoded in this mutation string.
        return fold_of_position[return_position_single(mutation)]

    DMS[fold_column] = DMS[mutant_column].apply(fold_for)
    print(DMS[fold_column].value_counts())
    return DMS



In [2]:
def spearman(y_pred, y_true):
    """Spearman rank correlation between predictions and targets.

    Returns 0.0 when either input has variance below 1e-6; otherwise returns
    the correlation coefficient computed by ``scipy.stats.spearmanr``.
    """
    for values in (y_pred, y_true):
        if np.var(values) < 1e-6:
            return 0.0
    rho, _ = spearmanr(y_pred, y_true)
    return rho

def compute_stat(sr):
    """Summarise a sequence of values.

    :param sr: array-like of numbers
    :return: ``(mean, std, ci)`` where ``ci`` is a two-element list holding
        the confidence interval of the mean returned by
        ``scipy.stats.bootstrap`` with its default settings
    """
    values = np.asarray(sr)
    average = np.mean(values)
    spread = np.std(values)
    # bootstrap expects a sequence of samples, hence the one-element tuple.
    interval = bootstrap((values,), np.mean).confidence_interval
    return average, spread, list(interval)

In [3]:
def sample_data(dataset_name, seed, shot, frac=0.2):
    '''
    Split a dataset into a held-out test set and a k-shot training set and
    write both to CSV.

    Reads ``data/<dataset_name>/data.csv`` (first column as the index), then
    writes ``train.csv`` and ``test.csv`` into the same directory.

    :param dataset_name: name of the sub-directory under ``data/``
    :param seed: sample seed, used for both the test and the train draw
    :param shot: the size of training data
    :param frac: the fraction of testing data, default to 0.2
    :raises ValueError: if ``shot`` exceeds the rows left after the test split
    '''

    data = pd.read_csv(f'data/{dataset_name}/data.csv', index_col=0)
    test_data = data.sample(frac=frac, random_state=seed)
    # drop() removes rows by index label, so the train pool is whatever is
    # left once the test rows' labels have been taken out.
    train_data = data.drop(test_data.index)
    if shot > len(train_data):
        # Fail with an explicit message instead of pandas' generic
        # "larger sample than population" error.
        raise ValueError(
            f'expected {shot} train examples, but only {len(train_data)} '
            f'are available after holding out the test split')
    kshot_data = train_data.sample(n=shot, random_state=seed)
    # Report the size that was actually sampled, not the size of the pool.
    assert len(kshot_data) == shot, (
        f'expected {shot} train examples, received {len(kshot_data)}')

    kshot_data.to_csv(f'data/{dataset_name}/train.csv')
    test_data.to_csv(f'data/{dataset_name}/test.csv')

In [6]:
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from scipy.stats import spearmanr
from scipy import stats
from scipy.stats import bootstrap
import numpy as np
import os
import re

class Mutation_Set(Dataset):
    """Dataset of tokenized mutated/target sequence pairs with scores.

    Reads the columns ``mutated_sequence``, ``target_seq``, ``DMS_score``,
    ``PID`` and ``mut_pos`` from ``data`` and tokenizes both sequence columns
    up front, padded/truncated to a fixed length.
    """

    def __init__(self, data, tokenizer, sep_len=1024):
        """
        :param data: DataFrame with the columns listed on the class
        :param tokenizer: callable returning a mapping with ``input_ids`` and
            ``attention_mask`` when called on a list of sequences
        :param sep_len: fixed token length used for padding and truncation
        """
        self.data = data
        self.tokenizer = tokenizer
        self.seq_len = sep_len

        self.seq, self.attention_mask = self._tokenize(list(self.data['mutated_sequence']))
        self.target, self.tgt_mask = self._tokenize(list(data['target_seq']))
        self.score = torch.tensor(np.array(self.data['DMS_score']))
        self.pid = np.asarray(data['PID'])

        # Parse every row on its own so an empty frame or a column mixing
        # strings and numbers no longer depends on the type of the first row.
        self.position = [self._parse_positions(u) for u in self.data['mut_pos']]

    def _tokenize(self, sequences):
        """Tokenize sequences and return ``(input_ids, attention_mask)``."""
        encoded = self.tokenizer(sequences, padding='max_length', truncation=True,
                                 max_length=self.seq_len)
        # Select the fields by key: unpacking ``.values()`` breaks as soon as
        # the tokenizer returns an extra field such as ``token_type_ids``.
        return encoded['input_ids'], encoded['attention_mask']

    @staticmethod
    def _parse_positions(raw):
        """Turn one ``mut_pos`` entry into a list of positions."""
        if isinstance(raw, str):
            return [int(v) for v in re.findall(r'\d+', raw)]
        return [raw]

    def __getitem__(self, idx):
        """Return ``[seq, attention_mask, target, tgt_mask, position, score, pid]`` for one row."""
        return [self.seq[idx], self.attention_mask[idx], self.target[idx], self.tgt_mask[idx],
                self.position[idx], self.score[idx], self.pid[idx]]

    def __len__(self):
        return len(self.score)

    def collate_fn(self, data):
        """Stack a list of ``__getitem__`` results into batch tensors.

        Positions stay a list of per-sample tensors because samples may carry
        different numbers of mutated positions.
        """
        seq = torch.tensor(np.array([u[0] for u in data]))
        att_mask = torch.tensor(np.array([u[1] for u in data]))
        tgt = torch.tensor(np.array([u[2] for u in data]))
        tgt_mask = torch.tensor(np.array([u[3] for u in data]))
        pos = [torch.tensor(u[4]) for u in data]
        score = torch.stack([u[5] for u in data]).to(torch.float32)
        pid = torch.tensor(np.array([u[6] for u in data]))
        return seq, att_mask, tgt, tgt_mask, pos, score, pid

In [8]:
from transformers import EsmForMaskedLM, EsmTokenizer, EsmConfig
# Load the tokenizer of the pretrained ESM checkpoint named below.
tokenizer = EsmTokenizer.from_pretrained('facebook/esm1v_t33_650M_UR90S_1')
# Read one split file of the random 5-fold data as the training frame.
train_csv = pd.read_csv('../data/cleaned_split_data_singles/fold_random_5/data_1.csv')

In [9]:
# Wrap the training frame in the dataset and build a shuffled loader over it.
trainset = Mutation_Set(data=train_csv, tokenizer=tokenizer)
# Cap the worker count at the number of available CPUs: a fixed 96 workers
# oversubscribes smaller machines and spawns far more processes than needed.
trainloader = DataLoader(trainset, batch_size=4, collate_fn=trainset.collate_fn, shuffle=True,
                         num_workers=min(96, os.cpu_count() or 1))

In [10]:
# Peek at a single batch to check how the mutated positions get masked.
for step, data in enumerate(trainloader):
    # The batch layout follows the order returned by Mutation_Set.collate_fn.
    seq, mask, wt, wt_mask, pos = data[:5]
    m_id = tokenizer.mask_token_id
    mask_seq = seq.clone()

    batch_size = int(seq.size(0))
    for i in range(batch_size):
        mut_pos = pos[i]
        # Positions are shifted by one token before masking -- presumably to
        # skip the leading <cls> token; TODO confirm mut_pos is 0-based.
        mask_seq[i, mut_pos + 1] = m_id

    print(mask_seq)
    break

 


tensor([[ 0, 20, 15,  ..., 15,  8,  2],
        [ 0, 20, 10,  ...,  1,  1,  1],
        [ 0, 20, 11,  ...,  1,  1,  1],
        [ 0, 20, 18,  ...,  8,  5,  2]])


In [None]:
# Display the mask token id that was read from the tokenizer in the loop above.
m_id

32

In [None]:
# Evaluates the bound method without calling it; the displayed repr includes
# the tokenizer's configuration and its special tokens.
tokenizer.add_tokens

<bound method SpecialTokensMixin.add_tokens of EsmTokenizer(name_or_path='facebook/esm1v_t33_650M_UR90S_1', vocab_size=33, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}>

In [None]:
# List the tokenizer's tokens; a token's place in the list appears to be its id
# (e.g. <mask> is last, matching m_id) -- TODO confirm.
tokenizer.all_tokens

['<cls>',
 '<pad>',
 '<eos>',
 '<unk>',
 'L',
 'A',
 'G',
 'V',
 'S',
 'E',
 'R',
 'T',
 'I',
 'D',
 'P',
 'K',
 'Q',
 'N',
 'F',
 'Y',
 'M',
 'H',
 'W',
 'C',
 'X',
 'B',
 'U',
 'Z',
 'O',
 '.',
 '-',
 '<null_1>',
 '<mask>']

TypeError: PreTrainedTokenizerBase.encode() missing 1 required positional argument: 'text'