In [1]:
import numpy as np
import pandas as pd

import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import MulticlassAccuracy
from sklearn.model_selection import train_test_split

from transformers import BertModel, RobertaTokenizer, BertTokenizer, RobertaTokenizerFast
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


### Utils

In [3]:
def my_topk_accuracy(preds,targets,topk,ignore_idx,num_classes):

    topk_indices = torch.topk(preds, topk, dim=1)[1]
    topk_indices = topk_indices.cpu()
    targets = targets.cpu()
    total_correct = 0
    total_count = 0
    for i in range(targets.shape[0]):
        if targets[i] == ignore_idx:
            continue
        if targets[i] in topk_indices[i]:
            total_correct += 1
        total_count += 1
    if total_count == 0:
        print("debug target")
        print(targets)
    accuracy = total_correct / total_count
    return accuracy

def torch_metrics_accuracy(preds,targets,ignore_idx,num_classes):
    preds = preds.cpu()
    targets = targets.cpu()
    accuracy_obj = MulticlassAccuracy(num_classes,
                                      average='micro',
                                      ignore_index=ignore_idx)
    accuracy = accuracy_obj(preds, targets)
    return accuracy

In [4]:
dropout = 0.25
batch_size = 4
epochs = 20
lr = 0.0001
distance_mask_token = 2
no_mask_token = 1
max_len = 371
bert_max_len = 373
layer_no_to_view = 29
head_no_to_view = 0

#data_path = '/home/andrewkim/Desktop/GPCRBert/data/final_edry_class.npy'
data_path = '/home/andrewkim/Desktop/GPCRBert/data/final_cwxp_class.npy'
#data_path = '/home/andrewkim/Desktop/GPCRBert/data/final_npxxy_class.npy'
#parameter_path = '/home/andrewkim/Desktop/GPCRBert/parameter/proteins_EDRY.pt'
parameter_path = '/home/andrewkim/Desktop/GPCRBert/parameter/proteins_CWXP.pt'
#parameter_path = '/home/andrewkim/Desktop/GPCRBert/parameter/proteins_NPXXY.pt'

### Tokenization


In [5]:
### EXPERIMENT

data = np.load(data_path, allow_pickle=True)
pdb = data[0][0]
seq = list(data[0][1]) # inputfull
seq_join = ''.join(seq) # input_full_str
pad_start = len(seq_join) + 1 # +1 for the [CLS] token

# findind start and end of req_pre_string
motif = data[0][2] # req_pre
motif_join = ''.join(motif) # requ_pre_str
start_idx = seq_join.find(motif_join)
end_idx = start_idx + 1

seq_list = list(seq_join)
seq_list[start_idx + distance_mask_token : start_idx + distance_mask_token + no_mask_token] = 'J' * no_mask_token
label_list = list(seq_join)
label_list[:start_idx+distance_mask_token] = 'J'*len(label_list[:start_idx+distance_mask_token])
label_list[start_idx+distance_mask_token+no_mask_token:] = 'J'*len(label_list[start_idx+distance_mask_token+no_mask_token:])

seq_list_spaced = ' '.join(seq_list)
label_list_spaced = ' '.join(label_list)
print(seq_list_spaced)
print(label_list_spaced)
seq_list_spaced = seq_list_spaced.replace('J', '[MASK]')
label_list_spaced = label_list_spaced.replace('J', '[MASK]')

tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert')
seq_tokenized = tokenizer(seq_list_spaced, return_tensors='pt', padding='max_length', max_length=bert_max_len)
label_tokenized = tokenizer(label_list_spaced, return_tensors='pt', padding='max_length', max_length=bert_max_len)

label_tokenized['input_ids'][label_tokenized['input_ids'] == 4] = 0
label_tokenized['input_ids'][label_tokenized['input_ids'] == 3] = 0
label_tokenized['input_ids'][label_tokenized['input_ids'] == 2] = 0
label_tokenized['input_ids'][label_tokenized['input_ids'] == 1] = 0

# convert attention of mask to 0 in input_tokenized which is 4
seq_tokenized['attention_mask'][seq_tokenized['input_ids'] == 4] = 0

seq_vocab = tokenizer.convert_ids_to_tokens(seq_tokenized['input_ids'][0])
label_vocab = tokenizer.convert_ids_to_tokens(label_tokenized['input_ids'][0])

#print(seq_tokenized, label_tokenized, start_idx+1, end_idx+1, pdb)

# bert_vocab = dataset.tokenizer.vocab
# inverse_vocab = {v: k for k, v in bert_vocab.items()}


# for b_no, (input_token, label_token) in enumerate(zip(seq_tokenized['input_ids'], 
#                                                               label_tokenized['input_ids'])):
#     #print(b_no, input_token, label_token)

#     input_str_list = [inverse_vocab[int(x)] for x in input_token.tolist()]
#     #print(input_str_list)
#     input_str = "".join(input_str_list)
#     #print(input_str)
#     label_str_list = [inverse_vocab[int(x)] for x in label_token.tolist()]
#     label_str = "".join(label_str_list)
#     #print(input_str, label_str)


A A D E V W V V G M G I V M S L I V L A I V F G N V L V I T A I A K F E R L Q T V T N Y F I T S L A C A D L V M G L A V V P F G A A C I L T K T W T F G N F W C E F W T S I D V L C V T A S I E T L C V I A V D R Y F A I T S P F K Y Q S L L T K N K A R V I I L M V W I V S G L T S F L P I Q M H W Y R A T H Q E A I N C Y A E E T C C D F F T N Q A Y A I A S S I V S F Y V P L V I M V F V Y S R V F Q E A K R Q L Q X X X X X X X X X X X X X X X X X X X X X X X X X X X X X X X K F A L K E H K A L K T L G I I M G T F T L C W J P F F I V N I V H V I Q D N L I R K E V Y I L L N W I G Y V N S G F N P L I Y C R S P D F R I A F Q E L L C L
J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J J 

In [6]:
class PositionPredictionFullDataset(torch.utils.data.Dataset):

    def __init__(self,df) -> None:
        super().__init__()
        self.df = df
        self.tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert')
        print(f"vocabulary {self.tokenizer.vocab}")

        # # finding an alphabet that is not in the vocabulary
        # capital_letters = [chr(i) for i in range(ord('A'), ord('Z')+1)]
        # for letter in capital_letters:
        #     if letter not in self.tokenizer.vocab:
        #         print(f"letter {letter} not in vocabulary")

        self.max_len = max_len
        self.bert_max_len = bert_max_len
        self.my_mask_token = 'J'
        self.bert_mask_token = '[MASK]'
        self.distance_mask_token = distance_mask_token
        self.no_mask_token = no_mask_token

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        seq = list(self.df[idx][1]) # inputfull
        seq_join = ''.join(seq) # input_full_str
        pad_start = len(seq_join) + 1 # +1 for the [CLS] token

        # findind start and end of req_pre_string
        motif = self.df[idx][2] # req_pre
        motif_join = ''.join(motif) # requ_pre_str
        start_idx = seq_join.find(motif_join)
        end_idx = start_idx + self.distance_mask_token + self.no_mask_token

        # replacing the two positions with mask token
        seq_list = list(seq_join)
        seq_list[start_idx + self.distance_mask_token : start_idx + self.distance_mask_token + self.no_mask_token] = self.my_mask_token * self.no_mask_token
        label_list = list(seq_join)
        label_list[:start_idx+self.distance_mask_token] = self.my_mask_token*len(label_list[:start_idx+self.distance_mask_token])
        label_list[start_idx+self.distance_mask_token+self.no_mask_token:] = self.my_mask_token*len(label_list[start_idx+self.distance_mask_token+self.no_mask_token:])

        seq_list_spaced = ' '.join(seq_list)
        label_list_spaced = ' '.join(label_list)

        seq_list_spaced = seq_list_spaced.replace(self.my_mask_token, self.bert_mask_token)
        seq_tokenized = self.tokenizer(seq_list_spaced, return_tensors='pt', padding='max_length', max_length=self.bert_max_len)

        label_list_spaced = label_list_spaced.replace(self.my_mask_token, self.bert_mask_token)
        label_tokenized = self.tokenizer(label_list_spaced, return_tensors='pt', padding='max_length', max_length=self.bert_max_len)

        # convert label_full_tokenized to 4 to 0 (??)
        label_tokenized['input_ids'][label_tokenized['input_ids'] == 4] = 0
        label_tokenized['input_ids'][label_tokenized['input_ids'] == 3] = 0
        label_tokenized['input_ids'][label_tokenized['input_ids'] == 2] = 0
        label_tokenized['input_ids'][label_tokenized['input_ids'] == 1] = 0

        # convert attention of mask to 0 in input_tokenized which is 4
        seq_tokenized['attention_mask'][seq_tokenized['input_ids'] == 4] = 0
        # print(label_tokenized['input_ids'])
        seq_vocab = self.tokenizer.convert_ids_to_tokens(seq_tokenized['input_ids'][0])
        label_vocab = self.tokenizer.convert_ids_to_tokens(label_tokenized['input_ids'][0])

        pdb = self.df[idx][0]
        seq_len = len(self.df[idx][1])

        return seq_tokenized, label_tokenized, start_idx+self.distance_mask_token+1, end_idx+1, pdb

### Model

In [7]:
class ClassificationTwoPositionsBert(nn.Module):
    def __init__(self) -> None:
        super(ClassificationTwoPositionsBert, self).__init__()
        self.encoder = BertModel.from_pretrained('Rostlab/prot_bert', output_attentions=True)

        for key, value in self.encoder.encoder.named_parameters():
            layer_num = int(key.split('.')[1])
            if layer_num < 23:
                value.requires_grad = False

        self.fc = nn.Sequential(nn.Linear(1024,256),
                                nn.ReLU(),
                                nn.Dropout(dropout),
                                nn.Linear(256,30))

    def forward(self, input_tokens):
        # input_tokens: {input_ids, attention_mask, token_type_ids}
        encoded_features = self.encoder(**input_tokens)['last_hidden_state']
        # N*embedding_length*30
        logits = self.fc(encoded_features)
        logits = torch.permute(logits, (0, 2, 1))
        return logits # N*30*embedding_length

    def forward_test(self, input_tokens):
        # input_tokens: {input_ids, attention_mask, token_type_ids}
        encoded = self.encoder(**input_tokens)
        encoded_features = encoded['last_hidden_state'] # N*max_len*hidden_dims
        attentions = encoded['attentions']
        logits = self.fc(encoded_features)
        logits = torch.permute(logits, (0, 2, 1))
        return logits, attentions # N*classes*max_len, N*num_heads*max_len*max_len

### Inference Test

### 1

In [8]:
def test_forward_two_position_prediction(model,data_loader,loss_fn,device):

    model = model.eval()
    total_loss = 0
    total_count = 0
    total_gts = torch.zeros((0), dtype=torch.long).to(device)
    total_preds = torch.zeros((0), dtype=torch.long).to(device)
    total_logits = torch.zeros((0, 30),dtype=torch.float).to(device)
    loop = tqdm(data_loader,leave=True,total=len(data_loader),colour='green')

    for idx, batch in enumerate(loop):
        input_tokens,label_tokens, start_idx, end_idx, pdb = batch

        input_tokens['input_ids'] = input_tokens['input_ids'].to(device).squeeze(1)
        input_tokens['token_type_ids'] = input_tokens['token_type_ids'].to(device).squeeze(1)
        input_tokens['attention_mask'] = input_tokens['attention_mask'].to(device).squeeze(1)

        label_tokens['input_ids'] = label_tokens['input_ids'].to(device).squeeze(1)
        label_tokens['token_type_ids'] = label_tokens['token_type_ids'].to(device).squeeze(1)
        label_tokens['attention_mask'] = label_tokens['attention_mask'].to(device).squeeze(1)

        with torch.no_grad():
            out = model(input_tokens)
            iteration_loss = loss_fn(out, label_tokens['input_ids'])

        total_loss += iteration_loss.item() * input_tokens['input_ids'].shape[0]
        total_count += input_tokens['input_ids'].shape[0]
        out = out.permute(0, 2, 1)

        for b_idx in range(out.shape[0]):

            #print(start_idx[b_idx], end_idx[b_idx])
            for seq_idx in range(start_idx[b_idx], end_idx[b_idx]):
                
                total_logits = torch.cat((total_logits, out[b_idx, seq_idx, :].reshape(1, -1)), dim=0)
                total_gts = torch.cat((total_gts, label_tokens['input_ids'][b_idx, seq_idx].reshape(1)), dim=0)
                #print(total_logits.shape)
                #print(total_gts.shape)
                #print(seq_idx, label_tokens['input_ids'][b_idx, seq_idx].reshape(1))

        loop.set_description(f"Loss: {iteration_loss.item()}")
    total_preds = torch.argmax(total_logits, dim=1)
    test_accuracy = torch_metrics_accuracy(total_preds.reshape(-1),total_gts.reshape(-1), num_classes=30,ignore_idx=0)

    for k in range(1, 5):
        test_topk_accuracy = my_topk_accuracy(total_logits,total_gts.reshape(-1),topk=k,ignore_idx=0,num_classes=30)
        print(f"Test Top-{k} Accuracy: {test_topk_accuracy}")

    #print(total_preds.shape)
    #print(total_gts.shape)
    final_pred = total_preds.reshape(-1)
    final_gts = total_gts.reshape(-1)
    #print(final_pred.shape)
    #print(final_gts.shape)

    return test_accuracy, total_loss / total_count, 0.0, final_pred, final_gts

In [9]:
def inference_test(model,test_loader,device,vocab,data_name):
    model.eval()
    total_loss = 0
    total_count = 0
    total_gts = torch.zeros((0), dtype=torch.long).to(device)
    total_preds = torch.zeros((0), dtype=torch.long).to(device)
    total_logits = torch.zeros((0,30),dtype=torch.float).to(device)
    loop = tqdm(test_loader,total=len(test_loader),leave=True) 
    

    weight_and_virus = []
    for idx, batch in enumerate(loop):
        input_tokens, label_tokens, start_id, end_id, name = batch

        input_tokens['input_ids'] = input_tokens['input_ids'].to(device).squeeze(1)
        input_tokens['attention_mask'] = input_tokens['attention_mask'].to(device).squeeze(1)
        input_tokens['token_type_ids'] = input_tokens['token_type_ids'].to(device).squeeze(1)
        
        label_tokens['input_ids'] = label_tokens['input_ids'].to(device).squeeze(1)
        label_tokens['attention_mask'] = label_tokens['attention_mask'].to(device).squeeze(1)
        label_tokens['token_type_ids'] = label_tokens['token_type_ids'].to(device).squeeze(1)

        with torch.no_grad():
            out, attention = model.forward_test(input_tokens)

        
        for b_no, (input_token, label_token) in enumerate(zip(input_tokens['input_ids'],label_tokens['input_ids'])):
            
            ### Integers to Residues
            input_str_list = [vocab[int(x)] for x in input_token.tolist()]
            #print(input_str_list)
            input_str = "".join(input_str_list)
            #print(input_str)
            label_str_list = [vocab[int(x)] for x in label_token.tolist()]
            #print(label_str_list)
            #label_str_list = [vocab.get(int(x),"") for x in label_token.tolist()]
            label_str = "".join(label_str_list)

            input_str = input_str.replace('[CLS]', 'J')
            input_str = input_str.replace('[SEP]', 'J')
            input_str = input_str.replace('[PAD]', 'J')
            input_str = input_str.replace('[UNK]', 'J')
            input_str = input_str.replace('[MASK]', 'J')

            label_str = label_str.replace('[CLS]', 'J')
            label_str = label_str.replace('[SEP]', 'J')
            label_str = label_str.replace('[PAD]', 'J')
            label_str = label_str.replace('[UNK]', 'J')
            label_str = label_str.replace('[MASK]', 'J')


            pdb_name = name[b_no]
            #print(f'pdb : {pdb_name}')
            #print(f"Input: {input_str[1:]}")
            req_attention = attention[layer_no_to_view]
            req_attention = req_attention[b_no, head_no_to_view,:,:].squeeze(0).cpu().numpy()
            req_attention = req_attention[1:, 1:]
            #print(req_attention.shape)

            weight_and_virus.append({"pdb": pdb_name ,"seq": input_str[1:], "attention": req_attention})

    
    weight_and_virus_np = np.array(weight_and_virus)
    np.save(f"{data_name}_(head{head_no_to_view}).npy", weight_and_virus_np)
    print(weight_and_virus_np.shape)

### 2

In [10]:
data = np.load(data_path, allow_pickle=True)
req_save_name = parameter_path.split("/")[-1].split(".")[0]

# Create dataset
dataset = PositionPredictionFullDataset(data)

# bert_tokenizer
bert_vocab = dataset.tokenizer.vocab

# Create dataloader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Create model
model = ClassificationTwoPositionsBert().to(device)

# Create optimizer, loss function, scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss(ignore_index=0,label_smoothing=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=1,verbose=True)
inverse_vocab = {v: k for k, v in bert_vocab.items()}

save_dict = torch.load(parameter_path)
model.load_state_dict(save_dict['model'], strict=False)
optimizer.load_state_dict(save_dict['optimizer'])
scheduler.load_state_dict(save_dict['scheduler'])
epoch = save_dict['epoch']

# test first to see if model is properly trained
# test_accuracy, test_loss, _, final_pred,final_gts = test_forward_two_position_prediction(model, data_loader, loss_fn, device)

# For result of single head
#head_no_to_view = 0
#inference_test(model,data_loader, device,inverse_vocab, data_name=req_save_name)

# For results of all heads
for i in range(0, 16):
    head_no_to_view = i
    inference_test(model,data_loader, device,inverse_vocab, data_name=req_save_name)

vocabulary OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('[CLS]', 2), ('[SEP]', 3), ('[MASK]', 4), ('L', 5), ('A', 6), ('G', 7), ('V', 8), ('E', 9), ('S', 10), ('I', 11), ('K', 12), ('R', 13), ('D', 14), ('T', 15), ('P', 16), ('N', 17), ('Q', 18), ('F', 19), ('Y', 20), ('M', 21), ('H', 22), ('C', 23), ('W', 24), ('X', 25), ('U', 26), ('B', 27), ('Z', 28), ('O', 29)])


100%|██████████| 42/42 [00:06<00:00,  6.15it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.42it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.40it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.41it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.39it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.40it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.41it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.38it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.37it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.36it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.34it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.32it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.29it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.26it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.32it/s]


(168,)


100%|██████████| 42/42 [00:06<00:00,  6.29it/s]


(168,)


### Confirm

In [11]:
weights = np.load('/home/andrewkim/Desktop/GPCRBert/proteins_NPXXY_(head2).npy', allow_pickle=True)
weights[0]
print(weights[0]['attention'].shape)
print(weights.shape)

FileNotFoundError: [Errno 2] No such file or directory: '/home/andrewkim/Desktop/GPCRBert/proteins_NPXXY_(head2).npy'