In [1]:
import sys
import argparse
import numpy as np
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torch.utils.data as du
import torch.optim as optim
import einops
from tqdm import tqdm
import time

In [2]:
class Sequences:
    '''
    This class reads in the sequences, their family id, and 
    the protvec embeddings (one per sequence)
    '''
    def __init__(self, seq_fname, label_fname, protvec_fname,
                 block_size, alphabet_idx, seed=42):
        '''seq_fname: input sequence file -- family_classification_sequences.csv
           label_fname: family id info -- family_classification_metadata.tab
           protvec_fname: protvec file -- family_classification_protVec.csv
           block_size: the sequence length used for bert training (1000)
           alphabet_idx: the mapping from alphabet to token ids
           seed: random seed for determinsitic data generation'''
        
        super(Sequences, self).__init__()
             
        np.random.seed(seed) # set the random seed
        
        self.alphabet_idx = alphabet_idx
        self.block_size = block_size
        
        self.sequences = []  # set of sequences
        self.labels = [] # family ids
        self.protVec = [] # protvecs 
        
        '''read protvecs, one per sequence'''
        with open(protvec_fname, "r") as f:
            for i, line in enumerate(f.readlines()):
                a = line.strip().split(',')
                self.protVec.append([float(v) for v in a])
        self.protVec = np.array(self.protVec, dtype=np.float32)
        
        '''read protein sequences from file''' 
        with open(seq_fname, "r") as f:
            for i, line in enumerate(f.readlines()):
                a = line.strip()
                if i == 0:
                    continue
                else:
                    #truncate sequences to block_size-1
                    #since first token is CLS
                    self.sequences.append(a[:block_size-1])
                    
        '''read family ids'''
        with open(label_fname, "r") as f:
            for i, line in enumerate(f.readlines()):
                a = line.strip()
                if i == 0:
                    continue
                else:
                    a = a.split("\t")
                    self.labels.append(a[3].strip('"'))
        self.labels = np.array(self.labels)
        
        # sort families by frequency
        fam_cnt = defaultdict(int)
        for v in self.labels:
            fam_cnt[v] += 1
        self.fam_lst = sorted(fam_cnt.items(), 
                              reverse=True, key=lambda item: item[1])
        
    def get_fam(self, fam_id):
        '''return all indices belonging to fam_id'''
        idxs = np.where(self.labels == fam_id)[0]
        return idxs
    
    def get_neg_fam(self, fam_id, seed):
        '''sample negative instances for fam_id
           return indices of negative class sequences'''
        np.random.seed(seed) # set the random seed
        pos_idxs = np.where(self.labels == fam_id)[0]
        neg_idxs = np.random.choice(len(self.sequences), len(pos_idxs))
        # make sure there is no overlap between pos and neg classes
        neg_idxs = [idx for idx in neg_idxs if idx not in pos_idxs]
        return neg_idxs
    
    def tokenize_and_pad(self, idx):
        '''return tokenized sequence at idx
        pad it if necessary to be block_size length
        assumes 'PAD', 'CLS' are the tokens from BERT training'''
        tokenized_seq = [self.alphabet_idx['CLS']]
        #actual AA sequence
        S = self.sequences[idx]
        tokenized_seq.extend([self.alphabet_idx[S[i]]  
                              if S[i] in self.alphabet_idx 
                              else self.alphabet_idx['PAD']
                              for i in range(len(S))])
        #PAD as remaining elements
        pad_len = self.block_size - len(tokenized_seq)
        tokenized_seq.extend([self.alphabet_idx['PAD'] 
                              for i in range(pad_len)])
        return tokenized_seq
    
    def get_protvec(self, idx):
        '''return protvec at idx'''
        return self.protVec[idx]

In [10]:
class Fam_Dataset(Dataset):
    '''
        This class creates a dataset for the given fam_id.
        getitem returns the CLS vector, protVec vector and label per idx
    '''
    def __init__(self, S, fam_id, transformer, device, seed=42):
        '''S: an instance of the Sequences class
           fam_id: create a dataset for fam_id (incude pos & neg instances)
           transformer: transformer model (to extract CLS embeddings)
        '''
        super(Fam_Dataset, self).__init__()

        self.seed = seed
        self.X = [] # tokenized sequences for transformer
        self.V = [] # protvec embeddings
        self.y = [] # 1/0 for pos/neg class label per sequence
        
        # positive instances
        fam_idx = S.get_fam(fam_id)       
        for idx in fam_idx:
            x = torch.tensor(S.tokenize_and_pad(idx))
            x = x.unsqueeze(0)
            x = x.to(device)
            x = transformer(x)
            # extract the embedding for CLS token at pos 0
            x = torch.squeeze(x[:,0,:])
            
            self.X.append(x.cpu().numpy())
            self.V.append(S.get_protvec(idx))
            self.y.append(1.)
        self.pos_sz = len(fam_idx)

        # negative instances
        neg_fam_idx = S.get_neg_fam(fam_id, seed=self.seed)
        #print("NEG", neg_fam_idx[:5])
        for idx in neg_fam_idx:
            x = torch.tensor(S.tokenize_and_pad(idx))
            x = x.unsqueeze(0)
            x = x.to(device)
            x = transformer(x)
            # extract the embedding for CLS token at pos 0
            x = torch.squeeze(x[:,0,:])
            self.X.append(x.cpu().numpy())
            self.V.append(S.get_protvec(idx))
            self.y.append(0.)
        self.neg_sz = len(neg_fam_idx)
        
        self.X = np.array(self.X)
        self.V = np.array(self.V)
        self.y = np.array(self.y)
            
    def __len__(self):
        '''return len of dataset'''
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        '''return X, V, y at idx'''
        return self.X[idx], self.V[idx], self.y[idx]

In [3]:
class SelfAttention(nn.Module):
    '''Self Attention'''

    def __init__(self, d, dk):
        '''define WQ, WK, WV projection matrices:
        d: d_model is the original model dimension
        dk: projection dimension for query, keys and values
        '''
        super(SelfAttention, self).__init__()
        self.d = d  # d_model
        self.dk = dk  # d_k: projection dimension
        self.WQ = nn.Linear(self.d, self.dk, bias=False)
        self.WK = nn.Linear(self.d, self.dk, bias=False)
        self.WV = nn.Linear(self.d, self.dk, bias=False)

    def forward(self, x):
        '''project the context onto key, query and value spaces and
        return the final value vectors
        '''
        # input shape: (batch_size, block_size, d)
        # let batch_size=b, block_size=l, num_heads=h
        Q = self.WQ(x)  # shape: b, l, dk
        K = self.WK(x)  # shape: b, l, dk
        V = self.WV(x)  # shape: b, l, dk

        K = torch.transpose(K, 1, 2)  # K.T transpose
        QKT = torch.bmm(Q, K)  # shape: b, l, l

        # attention matrix
        # row specifies weights for the value vectors, row add up to one
        A = F.softmax(QKT / np.sqrt(self.dk), dim=2)  # shape: b, l, l

        V = torch.bmm(A, V)  # shape: b, l, dk
        return V


class SepHeads_SelfAttention(nn.Module):
    '''Separate Headed Self Attention: List of Attention Heads
    This is a straightforward implementation of the multiple heads.
    We have separate WQ, WK and WV matrices, one per head.'''

    def __init__(self, d, dk, num_heads):
        '''create separate heads:
        d: d_model dimension
        dk: projection dimension for query, keys and values
        num_heads: number of attention heads
        '''
        super(SepHeads_SelfAttention, self).__init__()
        self.d = d  # d_model
        self.dk = dk  # d_k: projection dimension
        self.num_heads = num_heads  # number of attention heads

        self.sa_layers = nn.ModuleList()
        for i in range(self.num_heads):
            self.sa_layers.append(SelfAttention(self.d, self.dk))

        self.WO = nn.Linear(self.dk * self.num_heads, self.d, bias=False)

    def forward(self, x):
        '''use separate attention heads, and concat values'''
        # input shape: (batch_size, block_size, d)
        # let batch_size=b, block_size=l, num_heads=h
        V = []
        for i in range(self.num_heads):
            V.append(self.sa_layers[i](x))

        # concat all the value vectors from the heads
        V = torch.cat(V, dim=2)  # shape: b, l, h x dk
        # project back to d_model
        x = self.WO(V)  # shape: b, l, d
        return x


class MultiHead_SelfAttention(nn.Module):
    '''Multi Headed Self Attention:
    Instead of using a list of attention heads with separate WQ, WK, WV matrices,
    we combine all heads into one, and use a single WQ, WK and WV matrix.
    Each matrix maps the d-dim input block into h*dk dim space, where h is num_heads.
    We have to carefully keep the heads separate for softmax to achieve the same
    effect at from the list of heads. We do that via einops and the very useful
    torch.einsum function.
    
    This function is much more efficient than using separate heads.
    '''

    def __init__(self, d, dk, num_heads):
        '''create multi-heads -- joint heads:
        d: d_model dimension
        dk: projection dimension for query, keys and values
        num_heads: number of attention heads
        '''
        super(MultiHead_SelfAttention, self).__init__()
        self.d = d  # d_model
        self.dk = dk  # d_k: projection dimension
        self.num_heads = num_heads  # number of attention heads

        self.WQ = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WK = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WV = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WO = nn.Linear(self.dk * self.num_heads, self.d, bias=False)

    def forward(self, x):
        # input shape: (batch_size, block_size, d)
        # let batch_size=b, block_size=l, num_heads=h, d_model=d
        Q = self.WQ(x)  # size: (b, l, h*dk)
        K = self.WK(x)  # size: (b, l, h*dk)
        V = self.WV(x)  # size: (b, l, h*dk)

        # split Q, K, V into heads and dk, move heads up front; KT is transpose of K
        Q = einops.rearrange(
            Q, 'b l (h dk)-> b h l dk', h=self.num_heads
        )  # size: (b, h, l, dk)
        KT = einops.rearrange(
            K, 'b l (h dk)-> b h dk l', h=self.num_heads
        )  # size: (b, h, dk, l)
        V = einops.rearrange(
            V, 'b l (h dk)-> b h l dk', h=self.num_heads
        )  # size: (b, h, l, dk)

        # compute Q x K.T, output is (b, h, l, l)
        QKT = torch.einsum('bhik,bhkj->bhij', Q, KT)
        A = F.softmax(QKT / np.sqrt(self.dk), dim=3)  # softmax along last dim

        # new value representation
        V = torch.einsum('bhik,bhkj->bhij', A, V)  # size: (b, h, l, dk)
        V = einops.rearrange(V, 'b h l dk -> b l (h dk)')  # size: (b, l, h*dk)

        # shape: b, l, h x dk
        x = self.WO(V)  # shape: b, l, d
        return x


class TransformerBlock(nn.Module):
    '''Transformer Block: multi-head or separate heads of attention,
    followed by layernorm, ffn, and another layernorm
    '''

    def __init__(self, d, dk, num_heads, block_size, use_sepheads):
        '''
        d: d_model dimension
        dk: projection dimension
        num_heads: number of attention heads
        use_sepheads: use separate heads or multiheads,
                      multiheads is much more efficient
        '''
        super(TransformerBlock, self).__init__()
        self.use_sepheads = use_sepheads
        self.drop_prob = 0.1

        if self.use_sepheads:
            # uses for loop for separate heads
            self.mhsa = SepHeads_SelfAttention(d, dk, num_heads)
        else:
            # this is more efficient
            self.mhsa = MultiHead_SelfAttention(d, dk, num_heads)

        self.ln1 = nn.LayerNorm(d)  # layer norm
        self.ffn = nn.Sequential( #FFN module
            nn.Linear(d, d),  # linear layer
            nn.ReLU(),  # relu
            nn.Linear(d, d)  # linear layer
        )
        self.ln2 = nn.LayerNorm(d)  # layer norm

    def forward(self, x):
        # input shape: (batch_size, block_size, d)
        # let batch_size=b, block_size=l, num_heads=h, d_model=d
        x_sa = self.mhsa(x)  # multiple attention heads
        x_sa = F.dropout(x_sa, p=self.drop_prob)
        x_ln1 = self.ln1(x + x_sa)  # residual layer + layer norm
        # two linear layers with relu in between
        x_ffn = self.ffn(x_ln1)
        x_ffn = F.dropout(x_ffn, p=self.drop_prob)
        x_ln2 = self.ln2(x_ln1 + x_ffn)  # residual layer + layer norm
        return x_ln2


class Transformer(nn.Module):
    '''Transformer model:
    input is a block of tokens: first token is always CLS
    MASK token for positions for training the masked language model
    PAD tokens at end for sequences shorter than block size'''

    def __init__(
        self, d, dk, block_size, num_layers, num_heads, alphabet_idx, use_sepheads
    ):
        '''
        d: d_model dimension
        dk: projection dimension
        block_size: the max sequence length
        num_layers: how many transformer blocks/layers?
        num_heads: number of attention heads
        alphabet_idx: dict of tokens to token ids (ints)
        use_sepheads: use separate heads or joint heads (multiheads),
                      multiheads is much more efficient

        '''
        super(Transformer, self).__init__()
        self.num_layers = num_layers
        self.drop_prob = 0.1  # for dropout layer

        # embedding layer to map tokens to d dim vectors
        self.embed = nn.Embedding(len(alphabet_idx), d, padding_idx=alphabet_idx['PAD'])

        # learnable position embeddings, one per sequence element in block
        # can also use sine/cosine embeddings: not done here!
        self.pos_embed = nn.Embedding(block_size, d)

        # list of transformer blocks/layers
        tb_list = [
            TransformerBlock(d, dk, num_heads, block_size, use_sepheads)
            for i in range(self.num_layers)
        ]
        # combine all layers into one "sequential" layer
        self.layers = nn.Sequential(*tb_list)

    def forward(self, x):
        # input shape: batch_size (b), block_size (l)
        # d is d_model
        p = self.pos_embed.weight  # shape: l, d
        x = self.embed(x) + p  # add pos embeddings, shape: b, l, d
        x = F.dropout(x, p=self.drop_prob)  # dropout
        x = self.layers(x)  # shape: (b, l, d)
        return x

In [4]:
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout):
        '''in_dim: input layer dim
           hidden_dim: hidden layer dim
           out_dim: output layer dim
           dropout: dropout probability
           '''
        
        super(MLP, self).__init__()
        self.dropout = dropout
        
        #two fully connected layers
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        # compute output of fc1, and apply relu activation, followed by dropout
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout)           
       # compute output layer
        x = self.fc2(x)
        
        return x

In [5]:
def train_model(indim, batch_size, learning_rate, use_protvec):
    model = MLP(indim, 64, 1, dropout)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                          weight_decay=weight_decay)

    # load training and validation data in batches
    data_loader = du.DataLoader(fam_data,
                                 batch_size=batch_size,
                                 shuffle=True)

    # send model over to device
    model = model.to(device)
    
    for epoch in range(1, epochs + 1):
        sum_loss = 0.
        correct = 0.
        for batch_idx, (data, protvec, target) in enumerate(data_loader):
            # send batch over to device
            data = data.to(device)
            protvec = protvec.to(device)
            target = target.to(device)
            # zero out prev gradients
            optimizer.zero_grad()

            # run the forward pass
            if use_protvec:
                output = model(protvec).flatten()
            else:
                output = model(data).flatten()
            # compute loss/error with weight per sample
            loss = F.binary_cross_entropy_with_logits(
                    output, target)
            sum_loss += loss.item()

            #compute training accuracy
            correct += compute_correct(output, target)

            # compute gradients and take a step
            loss.backward()
            optimizer.step()

        # average loss per example
        sum_loss /= (batch_idx+1)
        train_acc = correct / len(fam_data)
        #print(f'Epoch: {epoch}, Loss: {sum_loss:.6e}, Acc: {train_acc:.4f}') 
    return train_acc

In [6]:
def compute_correct(output, target):
    '''first apply sigmoid and predict class 1 if >= 0.5, 0 otherwise
    '''
    #use logsigmoid for log space computations
    output = F.logsigmoid(output.detach())
    pred = torch.where(output > F.logsigmoid(torch.tensor(0.5)), 
                       1, 0)

    # add up weights of correct predictions
    correct = torch.sum(pred == target)
    
    return correct.item()

In [18]:
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")
model_fname = 'transformer_d256_dk32_l2_h8_lr0.0001_e10_j1700463.pth'
saveinfo = torch.load(model_fname)
d = saveinfo['d']
dk = saveinfo['dk']
num_layers = saveinfo['l']
num_heads = saveinfo['h']
block_size = saveinfo['block_size']
alphabet_idx = saveinfo['alphabet_idx']
use_sepheads=False

transformer = Transformer(d, dk, block_size, num_layers, 
                    num_heads, alphabet_idx, use_sepheads)
transformer.load_state_dict(saveinfo['model'])

#freeze the model
for param in transformer.parameters():
    param.requires_grad = False

transformer.to(device)
transformer.eval()

using device: cuda:0


Transformer(
  (embed): Embedding(28, 256, padding_idx=27)
  (pos_embed): Embedding(1000, 256)
  (layers): Sequential(
    (0): TransformerBlock(
      (mhsa): MultiHead_SelfAttention(
        (WQ): Linear(in_features=256, out_features=256, bias=False)
        (WK): Linear(in_features=256, out_features=256, bias=False)
        (WV): Linear(in_features=256, out_features=256, bias=False)
        (WO): Linear(in_features=256, out_features=256, bias=False)
      )
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=256, bias=True)
      )
      (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (mhsa): MultiHead_SelfAttention(
        (WQ): Linear(in_features=256, out_features=256, bias=False)
        (WK): Linear(in_features=256, out_features=256, bias=False)
        

In [12]:
S = Sequences('family_classification_sequences.csv', 
              'family_classification_metadata.tab', 
              'family_classification_protVec.csv',
              block_size, alphabet_idx)

In [19]:
batch_size = 512
epochs = 50
weight_decay = 0.
dropout = 0.5
d_protvec = 100
    
st = time.time()
ACC = []
for fidx, (fam_id, fam_cnt) in enumerate(S.fam_lst):
    if fidx > 19:
        break
    st_time = time.time()
    fam_data = Fam_Dataset(S, fam_id, transformer, device)
    en_time = time.time()
    fam_time = en_time - st_time
    T_acc = train_model(d, batch_size, 0.001, use_protvec = False)
    P_acc = train_model(d_protvec, batch_size, 0.01, use_protvec=True)
    en_time = time.time()
    print("ACC", fam_id, T_acc, P_acc, fam_time, en_time-st_time)
    ACC.append((fam_id, T_acc, P_acc))

en = time.time()
print("total time", en-st)

ACC MMR_HSR1 0.8547677261613692 0.9426242868785656 10.541239261627197 17.729961395263672
ACC Helicase_C 0.8875598086124402 0.6915869218500797 13.924832105636597 19.231462240219116
ACC ATP-synt_ab 0.9354092152324848 0.9671786240269303 8.295345783233643 13.140318870544434
ACC 7tm_1 0.9387417218543046 0.8887969094922737 6.207003593444824 9.908357381820679
ACC AA_kinase 0.8973623853211009 0.8864678899082569 6.003232717514038 9.422768592834473
ACC AAA 0.8535871156661786 0.8459736456808199 6.043012619018555 9.510087490081787
ACC tRNA-synt_1 0.9671879791475008 0.9546151487273843 5.704678773880005 8.96895456314087
ACC tRNA-synt_2 0.8684303350970017 0.8719576719576719 4.953840017318726 7.712321519851685
ACC MFS_1 0.9546153846153846 0.936923076923077 4.522503614425659 7.1766767501831055
ACC HSP70 0.9376971608832808 0.9684542586750788 4.395656108856201 6.82806921005249
ACC Oxidored_q1 0.9759711653984782 0.9763716459751702 4.362226247787476 6.778350830078125
ACC His_biosynth 0.9405861099959856 0.9

In [20]:
Tacc = np.array([float(t) for (f,t,p) in ACC])
Pacc = np.array([float(p) for (f,t,p) in ACC])
#print(Tacc)
print("ACC", Tacc.mean(), Tacc.std(), Pacc.mean(), Pacc.std())

ACC 0.9152775178168702 0.04191850109283851 0.9102803079509083 0.07204960254440099
