In [3]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence

In [4]:
class PROTEIN2VEC(nn.Module):
    """Map a pair of padded sequences to an output vector using two separate LSTM encoders.

    Sequence A is encoded by ``rnnA`` and sequence B by ``rnnB`` (independent
    weights, each a single-layer unidirectional ``nn.LSTM``). The final hidden
    states of both encoders are concatenated and passed through a four-layer MLP
    with ReLU activations and dropout applied before every linear layer.

    Args:
        input_dim: feature size of each time step of the input embeddings
            (presumably per-residue protein embeddings -- TODO confirm).
        hidden_dim: LSTM hidden size; the MLP input width is ``2 * hidden_dim``.
        fc1_dim, fc2_dim, fc3_dim: widths of the hidden MLP layers.
        fc4_dim: width of the output layer.
        dropout: dropout probability used before each linear layer.
    """

    def __init__(self, input_dim, hidden_dim, fc1_dim, fc2_dim, fc3_dim, fc4_dim, dropout):
        super().__init__()

        self.rnnA = nn.LSTM(input_dim, hidden_dim)
        self.rnnB = nn.LSTM(input_dim, hidden_dim)

        self.fc1 = nn.Linear(hidden_dim * 2, fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, fc3_dim)
        self.fc4 = nn.Linear(fc3_dim, fc4_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, emb_proteinA, emb_proteinB, seq_lengths_A, seq_lengths_B):
        """Encode both sequences and map the pair through the MLP head.

        Args:
            emb_proteinA: padded batch-first tensor ``(batch, max_len_A, input_dim)``.
            emb_proteinB: padded batch-first tensor ``(batch, max_len_B, input_dim)``.
            seq_lengths_A: true lengths of the sequences in ``emb_proteinA``
                (list or 1-D tensor, on any device).
            seq_lengths_B: true lengths of the sequences in ``emb_proteinB``.

        Returns:
            Tensor ``(batch, fc4_dim)``: the output of ``fc4`` (no final activation).
        """
        hiddenA = self._encode(self.rnnA, emb_proteinA, seq_lengths_A)
        # Protein B must be packed with its *own* lengths; packing it with
        # seq_lengths_A made the LSTM read into padding or stop before B ends.
        hiddenB = self._encode(self.rnnB, emb_proteinB, seq_lengths_B)

        x = self.fc1(self.dropout(torch.cat((hiddenA, hiddenB), dim=1)))
        x = self.fc2(self.dropout(torch.relu(x)))
        x = self.fc3(self.dropout(torch.relu(x)))
        return self.fc4(self.dropout(torch.relu(x)))

    @staticmethod
    def _encode(rnn, emb, seq_lengths):
        """Run ``rnn`` over a padded batch-first batch; return the last layer's final hidden state."""
        # pack_padded_sequence requires the lengths on the CPU as int64.
        lengths = torch.as_tensor(seq_lengths, dtype=torch.int64).cpu()
        packed = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = rnn(packed)
        # hidden is (num_layers, batch, hidden_dim); for one layer [-1] == squeeze(0).
        return hidden[-1]

In [6]:
class PROTEIN2VEC_SHARED(nn.Module):
    """Weight-shared variant: one LSTM encodes both sequences of the pair.

    Both sequences are encoded by the same single-layer unidirectional
    ``nn.LSTM`` (stored as ``rnnA``; the name is kept so existing state_dicts
    still load). The final hidden states are concatenated (A first, then B) and
    passed through a four-layer MLP with ReLU activations and dropout applied
    before every linear layer.

    Args:
        input_dim: feature size of each time step of the input embeddings
            (presumably per-residue protein embeddings -- TODO confirm).
        hidden_dim: LSTM hidden size; the MLP input width is ``2 * hidden_dim``.
        fc1_dim, fc2_dim, fc3_dim: widths of the hidden MLP layers.
        fc4_dim: width of the output layer.
        dropout: dropout probability used before each linear layer.
    """

    def __init__(self, input_dim, hidden_dim, fc1_dim, fc2_dim, fc3_dim, fc4_dim, dropout):
        super().__init__()

        # Single encoder shared by both inputs.
        self.rnnA = nn.LSTM(input_dim, hidden_dim)

        self.fc1 = nn.Linear(hidden_dim * 2, fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, fc3_dim)
        self.fc4 = nn.Linear(fc3_dim, fc4_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, emb_proteinA, emb_proteinB, seq_lengths_A, seq_lengths_B):
        """Encode both sequences with the shared LSTM and map the pair through the MLP head.

        Args:
            emb_proteinA: padded batch-first tensor ``(batch, max_len_A, input_dim)``.
            emb_proteinB: padded batch-first tensor ``(batch, max_len_B, input_dim)``.
            seq_lengths_A: true lengths of the sequences in ``emb_proteinA``
                (list or 1-D tensor, on any device).
            seq_lengths_B: true lengths of the sequences in ``emb_proteinB``.

        Returns:
            Tensor ``(batch, fc4_dim)``: the output of ``fc4`` (no final activation).
        """
        hiddenA = self._encode(self.rnnA, emb_proteinA, seq_lengths_A)
        # Protein B must be packed with its *own* lengths; packing it with
        # seq_lengths_A made the LSTM read into padding or stop before B ends.
        hiddenB = self._encode(self.rnnA, emb_proteinB, seq_lengths_B)

        x = self.fc1(self.dropout(torch.cat((hiddenA, hiddenB), dim=1)))
        x = self.fc2(self.dropout(torch.relu(x)))
        x = self.fc3(self.dropout(torch.relu(x)))
        return self.fc4(self.dropout(torch.relu(x)))

    @staticmethod
    def _encode(rnn, emb, seq_lengths):
        """Run ``rnn`` over a padded batch-first batch; return the last layer's final hidden state."""
        # pack_padded_sequence requires the lengths on the CPU as int64.
        lengths = torch.as_tensor(seq_lengths, dtype=torch.int64).cpu()
        packed = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = rnn(packed)
        # hidden is (num_layers, batch, hidden_dim); for one layer [-1] == squeeze(0).
        return hidden[-1]