#### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import os
import numpy as np
import time
from tqdm import tqdm
from torch.utils import data
import pandas as pd
import random
from sklearn.metrics import classification_report

#### The BERT Model For Loading

In [2]:
class SingleAttention(nn.Module):
    def __init__(self, d_model):
        
        super(SingleAttention, self).__init__()
        
        self.d_k = int(d_model / 8)
        
        self.W_Q = nn.Linear(d_model, self.d_k)
        self.W_K = nn.Linear(d_model, self.d_k)
        self.W_V = nn.Linear(d_model, self.d_k)
        
    def forward(self, x):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        
        A = torch.matmul(Q, torch.transpose(K, 0, 1)) / torch.sqrt(torch.tensor(self.d_k))
        
        A = F.softmax(A, dim=1)
        
        V_prime = torch.matmul(A, V)
        
        return V_prime

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head, device):
        
        super(MultiHeadAttention, self).__init__()
        
        self.d_k = int(d_model / 8)
        self.n_head = n_head
        
        self.attentions = []
        for i in range(self.n_head):
            self.attentions.append(SingleAttention(d_model).to(device))
        
        self.W_O = nn.Linear(n_head * self.d_k, d_model)
    
    def forward(self, x):
        Vs = []
        for i in range(self.n_head):
            Vs.append(self.attentions[i](x))
        
        V = torch.cat(tuple(Vs), dim=1)
        
        x = self.W_O(V)
        
        return x

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, device):
        
        super(TransformerBlock, self).__init__()
        
        self.mha = MultiHeadAttention(d_model, n_head, device).to(device)
        self.ln1 = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, d_model)
        self.ln2 = nn.LayerNorm(d_model)
        
    def forward(self, x):
        x1 = self.mha(x)
        x2 = self.ln1(x + x1)
        x3 = self.fc(x2)
        x4 = self.ln2(x3 + x2)
        
        return x4

class ProtBERT(nn.Module):
    def __init__(self, d_model, n_head, vocab_size, device):
        
        super(ProtBERT, self).__init__()
        
        self.embedding = nn.Embedding(num_embeddings=vocab_size+2, embedding_dim=d_model)
        self.trans = TransformerBlock(d_model, n_head, device).to(device)
        self.fc = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x)
        #print(x)
        #x_embedding = torch.clone(x)
        x = self.trans(x)
        x = self.fc(x)
        
        return x

#### Model for Classification

In [3]:
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        '''in_dim: input layer dim
           hidden_dim: hidden layer dim
           out_dim: output layer dim'''
        
        super(MLP, self).__init__()
        
        self.flatten = nn.Flatten()
        
        #two fully connected layers
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        # flatten it first
        x = self.flatten(x)
        
        # compute output of fc1, and apply relu activation
        x = F.relu(self.fc1(x))
        
        # compute output layer
        # no activation: cross entropy will compute softmax
        x = self.fc2(x)
        return x

#### Define Dataset For classification

In [4]:
class ClsData(torch.utils.data.Dataset):
    def __init__(self, families, model, device):
        super(ClsData, self).__init__()
        
        self.data = []
        self.label = []
        
        for family in tqdm(families):
            for CLS in family:
                self.data.append(CLS)
                self.label.append(1)
                # randomly choose one negative sample
                neg_family = random.choice(families)
                neg_CLS = random.choice(neg_family)
                self.data.append(neg_CLS)
                self.label.append(0)
        
        self.data = torch.tensor(self.data, dtype=torch.float)
        self.label = torch.tensor(self.label, dtype=torch.float)
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

#### Set Up Loading

In [5]:
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")
#filename = 'small_uniprot.txt'
filename = 'family_classification_sequences.csv'
file = open(filename, "r")
while(True):
    line = file.readline()
    if(line != 'sequence' and line[0] != '#'):
        break
sequences = file.read().rstrip()
vocabs = sorted(set(sequences.replace('\n','')))
vocab_to_idx = {vocab: index for index, (vocab) in enumerate(vocabs)}
vocab_to_idx['CLS'] = len(vocabs)
vocab_to_idx['MASK'] = len(vocabs) + 1

d_model = 256
n_head = 8
vocab_size = len(vocabs)
model = ProtBERT(d_model, n_head, vocab_size, device).to(device)

checkpoint = torch.load('checkpoint2.pth')
model.load_state_dict(checkpoint['state_dict'])

model.eval()

using device: cuda:0


ProtBERT(
  (embedding): Embedding(27, 256)
  (trans): TransformerBlock(
    (mha): MultiHeadAttention(
      (W_O): Linear(in_features=256, out_features=256, bias=True)
    )
    (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (fc): Linear(in_features=256, out_features=256, bias=True)
    (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (fc): Linear(in_features=256, out_features=25, bias=True)
)

In [6]:
def cls_train(cls_model, cls_optimizer, train_loader, epochs):
    cls_model = cls_model.to(device)
    cls_model.train()
    for epoch in range(1, epochs + 1):    
        sum_loss = 0.
        correct = 0
        total = 0
        for batch_idx, (data, label) in enumerate(train_loader):
            
            data, label = data.to(device), label.to(device)
            
            # zero out prev gradients
            cls_optimizer.zero_grad()
            
            # run the forward pass
            output = cls_model(data)
            #print(output)
            
            label = label.reshape((-1,1))
            # compute loss/error
            loss = F.binary_cross_entropy_with_logits(output, label)
            
            # compute training accuracy
            pred = torch.clone(output)
            pred = torch.sigmoid(pred)
            pred[pred < float(0.5)] = 0
            pred[pred != 0] = 1
            
            correct += torch.sum((pred == label))
            total += len(pred)
            #print(classification_report(label.to('cpu').tolist(), pred.to('cpu').tolist()))
            
            # sum up batch losses
            sum_loss += loss.item()

            # compute gradients and take a step
            loss.backward()
            cls_optimizer.step()
        # average loss per example
        sum_loss /= len(train_loader)
        print(f'Epoch: {epoch}, Loss: {sum_loss:.6f}')
        
        if epoch == epochs:
            acc = correct / total
            print(f'Accuracy: {acc:.3f}')

#### Get the Data for classification

In [7]:
dataframe = {'Sequences':[]}
count = 0
length = set()
for sequence in tqdm(sequences.split('\n')):
    seq_idx = []
    seq_idx.append(vocab_to_idx['CLS'])
    for letter in sequence:
        seq_idx.append(vocab_to_idx[letter])
    
    CLS = None
    if count == 138252:
        length.add(len(seq_idx))
        #print(len(seq_idx))
        # torch.cuda.empty_cache()
        # print(torch.cuda.memory_summary(device=None, abbreviated=False))
        seq_idx = seq_idx[0:int(len(seq_idx)/2)]
        seq_tensor = torch.tensor(seq_idx, dtype=torch.int64).to(device)
        embedding = model.embedding(seq_tensor)
        CLS = model.trans(embedding)[0].tolist()
    else:
        seq_tensor = torch.tensor(seq_idx, dtype=torch.int64).to(device)
        embedding = model.embedding(seq_tensor)
        CLS = model.trans(embedding)[0].tolist()
        length.add(len(seq_idx))

    dataframe['Sequences'].append(CLS)
    count += 1

print(max(list(length)))

df1 = pd.DataFrame(dataframe)

cls_file = 'family_classification_metadata.csv'
df2 = pd.read_csv(cls_file)[['Family ID']]

df = pd.concat([df1, df2], join = 'outer', axis = 1)
grouped_df = df.groupby('Family ID')

100%|██████████| 324018/324018 [15:25<00:00, 350.23it/s] 


22153


#### Training Classification (BERT)

In [8]:
family = 0
families = []
cls_model = MLP(d_model, 256, 1)
cls_optimizer = optim.SGD(cls_model.parameters(), lr=0.05, momentum=0.9)
for key, item in grouped_df:
    
    if family % 1000 == 0 and family != 0:
        cls_dataset = ClsData(families, model, device)
        train_loader = data.DataLoader(dataset=cls_dataset, 
                        batch_size=500,
                        shuffle=True)
        epochs = 10
        print('First {} families:'.format(family))
        cls_train(cls_model, cls_optimizer, train_loader, epochs)
    
    families.append(list(grouped_df.get_group(key)['Sequences']))
    family += 1

cls_dataset = ClsData(families, model, device)
train_loader = data.DataLoader(dataset=cls_dataset, 
                        batch_size=200,
                        shuffle=True)
epochs = 10
print('The entire dataset:')
cls_train(cls_model, cls_optimizer, train_loader, epochs)

100%|██████████| 1000/1000 [00:00<00:00, 16493.07it/s]


First 1000 families:
Epoch: 1, Loss: 0.693108
Epoch: 2, Loss: 0.692583
Epoch: 3, Loss: 0.692621
Epoch: 4, Loss: 0.692853
Epoch: 5, Loss: 0.692764
Epoch: 6, Loss: 0.692666
Epoch: 7, Loss: 0.692468
Epoch: 8, Loss: 0.691601
Epoch: 9, Loss: 0.691193
Epoch: 10, Loss: 0.691485
Accuracy: 0.518


100%|██████████| 2000/2000 [00:00<00:00, 24337.45it/s]


First 2000 families:
Epoch: 1, Loss: 0.692331
Epoch: 2, Loss: 0.692464
Epoch: 3, Loss: 0.692096
Epoch: 4, Loss: 0.692104
Epoch: 5, Loss: 0.691789
Epoch: 6, Loss: 0.691936
Epoch: 7, Loss: 0.692004
Epoch: 8, Loss: 0.691661
Epoch: 9, Loss: 0.691482
Epoch: 10, Loss: 0.691855
Accuracy: 0.509


100%|██████████| 3000/3000 [00:00<00:00, 25540.50it/s]


First 3000 families:
Epoch: 1, Loss: 0.692652
Epoch: 2, Loss: 0.692582
Epoch: 3, Loss: 0.692510
Epoch: 4, Loss: 0.692473
Epoch: 5, Loss: 0.692388
Epoch: 6, Loss: 0.692373
Epoch: 7, Loss: 0.692510
Epoch: 8, Loss: 0.692487
Epoch: 9, Loss: 0.692382
Epoch: 10, Loss: 0.692488
Accuracy: 0.509


100%|██████████| 4000/4000 [00:00<00:00, 20433.33it/s]


First 4000 families:
Epoch: 1, Loss: 0.691879
Epoch: 2, Loss: 0.691753
Epoch: 3, Loss: 0.691778
Epoch: 4, Loss: 0.691835
Epoch: 5, Loss: 0.691792
Epoch: 6, Loss: 0.691826
Epoch: 7, Loss: 0.691425
Epoch: 8, Loss: 0.691288
Epoch: 9, Loss: 0.691608
Epoch: 10, Loss: 0.691792
Accuracy: 0.514


100%|██████████| 5000/5000 [00:00<00:00, 18286.58it/s]


First 5000 families:
Epoch: 1, Loss: 0.691512
Epoch: 2, Loss: 0.691715
Epoch: 3, Loss: 0.691225
Epoch: 4, Loss: 0.691139
Epoch: 5, Loss: 0.691317
Epoch: 6, Loss: 0.692046
Epoch: 7, Loss: 0.691579
Epoch: 8, Loss: 0.691342
Epoch: 9, Loss: 0.691630
Epoch: 10, Loss: 0.691851
Accuracy: 0.511


100%|██████████| 6000/6000 [00:00<00:00, 16513.73it/s]


First 6000 families:
Epoch: 1, Loss: 0.691607
Epoch: 2, Loss: 0.691604
Epoch: 3, Loss: 0.691370
Epoch: 4, Loss: 0.691804
Epoch: 5, Loss: 0.691630
Epoch: 6, Loss: 0.691492
Epoch: 7, Loss: 0.691683
Epoch: 8, Loss: 0.690732
Epoch: 9, Loss: 0.691491
Epoch: 10, Loss: 0.690384
Accuracy: 0.527


100%|██████████| 7000/7000 [00:00<00:00, 16970.15it/s]


First 7000 families:
Epoch: 1, Loss: 0.690639
Epoch: 2, Loss: 0.691137
Epoch: 3, Loss: 0.691282
Epoch: 4, Loss: 0.690734
Epoch: 5, Loss: 0.690928
Epoch: 6, Loss: 0.690856
Epoch: 7, Loss: 0.691585
Epoch: 8, Loss: 0.690978
Epoch: 9, Loss: 0.691265
Epoch: 10, Loss: 0.690791
Accuracy: 0.518


100%|██████████| 7027/7027 [00:00<00:00, 16796.00it/s]


The entire dataset:
Epoch: 1, Loss: 0.692001
Epoch: 2, Loss: 0.691962
Epoch: 3, Loss: 0.691636
Epoch: 4, Loss: 0.691757
Epoch: 5, Loss: 0.691654
Epoch: 6, Loss: 0.691534
Epoch: 7, Loss: 0.691693
Epoch: 8, Loss: 0.691790
Epoch: 9, Loss: 0.691816
Epoch: 10, Loss: 0.691718
Accuracy: 0.509


#### WordtoVec classification

In [9]:
class WordToVecData(torch.utils.data.Dataset):
    def __init__(self, families):
        super(WordToVecData, self).__init__()
        
        self.data = []
        self.label = []
        
        for family in tqdm(families):
            for vec in family:
                self.data.append(vec)
                self.label.append(1)
                # randomly choose one negative sample
                neg_family = random.choice(families)
                neg_vec = random.choice(neg_family)
                self.data.append(neg_vec)
                self.label.append(0)
        
        self.data = torch.tensor(self.data, dtype=torch.float)
        self.label = torch.tensor(self.label, dtype=torch.float)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

#### Training Classification (Word2Vec)

In [10]:
filename = 'family_classification_protVec.csv'
dataframe = {'Sequences':[]}
with open(filename, "r") as f:
    for line in f.readlines():
        line = line.strip()
        line = line.split(',')
        #print(line)
        line = list(map(float, line))
        dataframe['Sequences'].append(line)

df1 = pd.DataFrame(dataframe)
df = pd.concat([df1, df2], join = 'outer', axis = 1)
grouped_df = df.groupby('Family ID')

family = 0
families = []
cls_model = MLP(100, 256, 1)
cls_optimizer = optim.SGD(cls_model.parameters(), lr=0.05, momentum=0.9)

for key, item in grouped_df:
    
    if family % 1000 == 0 and family != 0:
        cls_dataset = WordToVecData(families)
        train_loader = data.DataLoader(dataset=cls_dataset, 
                        batch_size=500,
                        shuffle=True)
        epochs = 10
        cls_train(cls_model, cls_optimizer, train_loader, epochs)
    
    families.append(list(grouped_df.get_group(key)['Sequences']))
    family += 1

100%|██████████| 1000/1000 [00:00<00:00, 16598.55it/s]


Epoch: 1, Loss: 0.918050
Epoch: 2, Loss: 0.635993
Epoch: 3, Loss: 0.624709
Epoch: 4, Loss: 0.626691
Epoch: 5, Loss: 0.623777
Epoch: 6, Loss: 0.629583
Epoch: 7, Loss: 0.621729
Epoch: 8, Loss: 0.628340
Epoch: 9, Loss: 0.656839
Epoch: 10, Loss: 0.683181
Accuracy: 0.548


100%|██████████| 2000/2000 [00:00<00:00, 23280.35it/s]


Epoch: 1, Loss: 0.691830
Epoch: 2, Loss: 0.648510
Epoch: 3, Loss: 0.641624
Epoch: 4, Loss: 0.644351
Epoch: 5, Loss: 0.632373
Epoch: 6, Loss: 0.638298
Epoch: 7, Loss: 0.661344
Epoch: 8, Loss: 0.643872
Epoch: 9, Loss: 0.639670
Epoch: 10, Loss: 0.636381
Accuracy: 0.644


100%|██████████| 3000/3000 [00:00<00:00, 24932.46it/s]


Epoch: 1, Loss: 0.650589
Epoch: 2, Loss: 0.641074
Epoch: 3, Loss: 0.642069
Epoch: 4, Loss: 0.641184
Epoch: 5, Loss: 0.642199
Epoch: 6, Loss: 0.639744
Epoch: 7, Loss: 0.639338
Epoch: 8, Loss: 0.639199
Epoch: 9, Loss: 0.647120
Epoch: 10, Loss: 0.638023
Accuracy: 0.635


100%|██████████| 4000/4000 [00:00<00:00, 20221.29it/s]


Epoch: 1, Loss: 0.645117
Epoch: 2, Loss: 0.641550
Epoch: 3, Loss: 0.639623
Epoch: 4, Loss: 0.636080
Epoch: 5, Loss: 0.638415
Epoch: 6, Loss: 0.640109
Epoch: 7, Loss: 0.665185
Epoch: 8, Loss: 0.638915
Epoch: 9, Loss: 0.637107
Epoch: 10, Loss: 0.665449
Accuracy: 0.579


100%|██████████| 5000/5000 [00:00<00:00, 18368.77it/s]


Epoch: 1, Loss: 0.680393
Epoch: 2, Loss: 0.667234
Epoch: 3, Loss: 0.686472
Epoch: 4, Loss: 0.691886
Epoch: 5, Loss: 0.690515
Epoch: 6, Loss: 0.689566
Epoch: 7, Loss: 0.688043
Epoch: 8, Loss: 0.688127
Epoch: 9, Loss: 0.688282
Epoch: 10, Loss: 0.685242
Accuracy: 0.541


100%|██████████| 6000/6000 [00:00<00:00, 16406.48it/s]


Epoch: 1, Loss: 0.671846
Epoch: 2, Loss: 0.648210
Epoch: 3, Loss: 0.652663
Epoch: 4, Loss: 0.641747
Epoch: 5, Loss: 0.642654
Epoch: 6, Loss: 0.648403
Epoch: 7, Loss: 0.646867
Epoch: 8, Loss: 0.647661
Epoch: 9, Loss: 0.645399
Epoch: 10, Loss: 0.640117
Accuracy: 0.642


100%|██████████| 7000/7000 [00:00<00:00, 16722.20it/s]


Epoch: 1, Loss: 0.637734
Epoch: 2, Loss: 0.633219
Epoch: 3, Loss: 0.643028
Epoch: 4, Loss: 0.639594
Epoch: 5, Loss: 0.642396
Epoch: 6, Loss: 0.642078
Epoch: 7, Loss: 0.641197
Epoch: 8, Loss: 0.649333
Epoch: 9, Loss: 0.641734
Epoch: 10, Loss: 0.646760
Accuracy: 0.626
