In [2]:
import import_ipynb
from harvard_transformer import *

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

importing Jupyter notebook from harvard_transformer.ipynb


In [3]:
class TransformerGO(nn.Module):
    """Encoder-decoder transformer that maps a pair of embedded sequences to one score.

    The embeddings of protein A are run through the encoder stack; the
    embeddings of protein B are run through the decoder stack, which attends
    to the encoded A. The decoder output is collapsed over the sequence axis
    with softmax weights derived from each position's vector norm
    (the original comment credits the TransformerCPI paper), and a linear
    layer maps the pooled vector to a single value.

    Args:
        d_model: embedding width used by every layer and by the final linear layer.
        nhead: number of attention heads per layer.
        num_layers: number of layers in the encoder stack and in the decoder stack.
        dim_feedforward: hidden width of each layer's feed-forward sub-network.
        dropout: dropout probability passed to every layer (default 0.1).
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout=0.1):
        super().__init__()

        # Encoder and decoder layers are configured with the same hyper-parameters.
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            dim_feedforward=dim_feedforward,
        )

        # Encoder stack.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), num_layers=num_layers
        )

        # Decoder stack.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_kwargs), num_layers=num_layers
        )

        # Final projection from the pooled vector to one value.
        self.linear = nn.Linear(d_model, 1)

    def forward(self, emb_proteinA, emb_proteinB, protA_mask, protB_mask):
        """Score a batch of (protein A, protein B) embedding pairs.

        NOTE(review): the layers are built without ``batch_first`` and the
        decoder output is handled as seqLen * batch * embDim below, so the
        embeddings are presumably expected as (seq_len, batch, d_model); the
        original comment here said "batch * max_seq_len * node2vec_dim" --
        confirm against the caller.

        Args:
            emb_proteinA: embeddings fed to the encoder.
            emb_proteinB: embeddings fed to the decoder.
            protA_mask: key-padding mask for A, used both as the encoder's
                ``src_key_padding_mask`` and the decoder's ``memory_key_padding_mask``.
            protB_mask: key-padding mask for B, used as the decoder's
                ``tgt_key_padding_mask``.

        Returns:
            The linear layer's output after the pooled sequence axis of size 1
            is squeezed away, i.e. one value per batch element with a trailing
            dimension of 1.
        """
        encoded_a = self.transformer_encoder(emb_proteinA, src_key_padding_mask=protA_mask)
        decoded = self.transformer_decoder(
            emb_proteinB,
            encoded_a,
            memory_key_padding_mask=protA_mask,
            tgt_key_padding_mask=protB_mask,
        )
        # decoded: seqLen * batch * embDim -> batch * seqLen * embDim
        decoded = decoded.transpose(0, 1)

        # Collapse batch * seqLen * embDim --> batch * 1 * embDim (TransformerCPI paper):
        # one weight per position from the softmax of the position's L2 norm.
        # NOTE(review): the softmax runs over every position, including any
        # flagged in protB_mask -- confirm this is intended.
        position_norms = torch.linalg.norm(decoded, dim=-1)
        pooling_weights = F.softmax(position_norms, dim=1).unsqueeze(1)
        pooled = torch.bmm(pooling_weights, decoded)

        return self.linear(pooled).squeeze(1)

In [21]:
class TransformerGO_Scratch(nn.Module):
    """Same pair-scoring architecture as above, assembled from the hand-written
    transformer building blocks (``MultiHeadedAttention``,
    ``PositionwiseFeedForward``, ``Encoder``/``EncoderLayer``,
    ``Decoder``/``DecoderLayer``) that come from the ``harvard_transformer``
    wildcard import.

    Args:
        d_model: embedding width used by every layer and by the final linear layer.
        nhead: number of attention heads.
        num_layers: number of layers in the encoder and in the decoder.
        dim_feedforward: hidden width of the position-wise feed-forward block.
        dropout: dropout probability passed to every building block (default 0.1).
    """

    def __init__(self ,d_model, nhead, num_layers, dim_feedforward, dropout = 0.1):
        super().__init__()

        # Imported here explicitly: no cell of this notebook imports `copy`, so
        # the name previously resolved only if the wildcard import from
        # harvard_transformer happened to re-export it.
        import copy

        clone = copy.deepcopy
        attn = MultiHeadedAttention(nhead, d_model, dropout)
        ff = PositionwiseFeedForward(d_model, dim_feedforward, dropout)

        # Every sub-layer receives its own deep copy so no module instance is
        # shared between the encoder layer and the decoder layer's two attentions.
        self.encoder = Encoder(EncoderLayer(d_model, clone(attn), clone(ff), dropout), num_layers)
        self.decoder = Decoder(DecoderLayer(d_model, clone(attn), clone(attn), clone(ff), dropout), num_layers)

        # Final projection from the pooled vector to one value.
        self.linear = nn.Linear(d_model, 1)

    def forward(self, emb_proteinA, emb_proteinB, protA_mask, protB_mask):
        """Score a batch of (protein A, protein B) embedding pairs.

        Args:
            emb_proteinA: embeddings fed to the encoder; batch * max_seq_len *
                node2vec_dim according to the original comment.
            emb_proteinB: embeddings fed to the decoder, same layout.
            protA_mask: mask for A, passed to the encoder and, as the third
                positional argument, to the decoder.
            protB_mask: mask for B, passed as the decoder's fourth positional
                argument.
                NOTE(review): mask shape and polarity are defined by the
                harvard_transformer blocks, not visible here -- confirm.

        Returns:
            The linear layer's output after the pooled sequence axis of size 1
            is squeezed away, i.e. one value per batch element with a trailing
            dimension of 1.
        """
        memory = self.encoder(emb_proteinA, protA_mask)
        decoded = self.decoder(emb_proteinB, memory, protA_mask, protB_mask)
        # decoded: batch * seqLen * embDim

        # Collapse batch * seqLen * embDim --> batch * 1 * embDim (TransformerCPI paper):
        # one weight per position from the softmax of the position's L2 norm.
        # NOTE(review): the softmax runs over every position, masked or not.
        position_norms = torch.linalg.norm(decoded, dim = 2)
        pooling_weights = F.softmax(position_norms, dim = 1).unsqueeze(1)
        pooled = torch.bmm(pooling_weights, decoded)

        return self.linear(pooled).squeeze(1)

In [6]:
class GO_Sum_NN(nn.Module):
    """Feed-forward baseline: sum each protein's embeddings, concatenate, run four linear layers.

    Args:
        input_dim: embedding width of a single protein; the first linear layer
            takes ``input_dim * 2`` features because both sums are concatenated.
        fc1_dim: output width of the first linear layer.
        fc2_dim: output width of the second linear layer.
        fc3_dim: output width of the third linear layer.
        fc4_dim: output width of the fourth (last) linear layer.
        dropout: dropout probability applied before the second, third and fourth layers.
    """

    def __init__(self, input_dim, fc1_dim, fc2_dim, fc3_dim, fc4_dim, dropout):
        super().__init__()

        # Not used by forward(); kept as an attribute for compatibility.
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

        self.fc1 = nn.Linear(2 * input_dim, fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, fc3_dim)
        self.fc4 = nn.Linear(fc3_dim, fc4_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, emb_proteinA, emb_proteinB, maskA, maskB):
        """Compute the network output for a batch of protein pairs.

        Args:
            emb_proteinA: (batch, max_seq_len, node2vec_dim) embeddings of protein A.
            emb_proteinB: embeddings of protein B, summed over the same axis.
            maskA: not used (adding 'empty' GO terms); added only for
                downstream compatibility.
            maskB: not used, same reason as ``maskA``.

        Returns:
            Output of the last linear layer, ``fc4_dim`` features per batch element.
        """
        # Collapse the sequence axis of each protein by summation.
        summed_a = emb_proteinA.sum(dim=1)
        summed_b = emb_proteinB.sum(dim=1)

        hidden = self.fc1(torch.cat((summed_a, summed_b), dim=1))
        # ReLU -> dropout -> linear, three times; no activation after the last layer.
        for layer in (self.fc2, self.fc3, self.fc4):
            hidden = layer(self.dropout(F.relu(hidden)))

        return hidden