In [1]:
from Bio import SeqIO
import torch     
import numpy as np
from torch.utils import data
import torch.nn as nn

In [2]:
### preprocess the fasta data, keep the correspondence in the dictionary fasta_data
### file_in should be the input file of the real data, for our case it is the "zika-fasta1.fa"
file_in = "zika-fasta1.fa"
fasta_data = {}   # record id -> sequence (a repeated id keeps the last record)
seq_string = []   # every sequence, in file order
# Parse inside a `with` block so the file handle is closed once parsing is done.
# SeqIO.parse returns a lazy iterator, so the loop must run while the file is open.
with open(file_in) as handle:
    fasta_sequences = SeqIO.parse(handle, 'fasta')
    for fasta in fasta_sequences:
        name, sequence = fasta.id, str(fasta.seq)
        fasta_data[name] = sequence
        seq_string.append(sequence)

# later cells index seq_string[0]; fail here with a clear message instead
if not seq_string:
    raise ValueError(f"no FASTA records found in {file_in!r}")

In [4]:
###convert the sequence data to testing input
# extracting the number of samples (all loaded sequences form a single input here)
n_samples = 1

if not seq_string:
    raise ValueError("no sequences were loaded from the FASTA file")

# SeqIO already strips line breaks, so unconditionally dropping the last
# character (meant to remove a trailing "\n") discarded a real residue.
# Only strip a newline if one is actually present.
clean_seqs = [seq.rstrip("\n") for seq in seq_string]

# extracting the sequence length; every sequence must share it (aligned input)
seq_length = len(clean_seqs[0])
if any(len(seq) != seq_length for seq in clean_seqs):
    raise ValueError("all sequences must have the same length (aligned FASTA expected)")

# function to convert string to numbers
def convert_string_to_numbers(str, dict):
    ''' str: is the string to convert,
        dict: dictionary with the relative ordering of each char

        Returns a 1-D int64 numpy array with one code per character of str;
        raises KeyError for a character that is not a key of dict.'''
    # `str` / `dict` shadow the builtins; the names are kept so keyword callers still work
    return np.fromiter((dict[c] for c in str), dtype=np.int64, count=len(str))

# Relative position of each residue: the sorted set of characters seen in *all*
# sequences (building it from the first sequence alone raised KeyError for any
# residue that only occurs in a later sequence).
# NOTE(review): these indices must match the encoding used at training time and
# fit the model's embedding size; a different residue set shifts every index.
strL = sorted(set("".join(clean_seqs)))
dict_amino = {c: ii for ii, c in enumerate(strL)}

mats = np.zeros((len(clean_seqs), seq_length), dtype = np.int64)
for ii, seq in enumerate(clean_seqs):
    mats[ii, :] = convert_string_to_numbers(seq, dict_amino)

# (n_samples, sequences per sample, seq_length)
mats = mats.reshape((n_samples, -1, seq_length))
trunc_length = 1550
# keep at most the last 4 samples and only the first trunc_length positions
inputTest = torch.from_numpy(mats[-4:, :, :trunc_length])
datasetTest = data.TensorDataset(inputTest)
# test-time loader: no shuffling, so batch order matches the input order
dataloaderTest = torch.utils.data.DataLoader(datasetTest,
                                             batch_size=16,
                                             shuffle=False)

In [7]:
## define the model architecture (the trained weights are loaded in a later cell)

class _ResidueModule(torch.nn.Module):

    def __init__(self, channel_count):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Conv1d(channel_count, channel_count, 1),
            torch.nn.BatchNorm1d(channel_count),
            torch.nn.ReLU(),
            torch.nn.Conv1d(channel_count, channel_count, 1),
            torch.nn.BatchNorm1d(channel_count),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        return x + self.layers(x)


class _DoubleEmbedding(torch.nn.Module):
    """Quartet-structured embedding of four aligned sequences.

    Every sequence is embedded and passed through the shared residual block
    ``_res_module_1``. Pairs of the resulting feature maps are summed and passed
    through the shared block ``_res_module_2``; for each of the three quartet
    splits (12|34), (13|24) and (14|23) the two pair features are added,
    giving one feature map per split.

    Args:
        length_dict: vocabulary size of the ``nn.Embedding`` lookup.
        embeding_dim: embedding size, also the channel count of both residual blocks.
        trunc_length: accepted for interface compatibility; not used.
    """

    def __init__(self, length_dict, embeding_dim, trunc_length = 1550):
        super().__init__()
        # attribute names and creation order define the state_dict layout
        self.embedding_layer = nn.Embedding(length_dict, embeding_dim)
        self._res_module_1 = _ResidueModule(embeding_dim)
        self._res_module_2 = _ResidueModule(embeding_dim)

    def forward(self, x):
        """Map integer codes indexed as (batch, 4, seq_len) to (batch, 3, embeding_dim, seq_len).

        Only the first 4 entries along dim 1 are used.
        """
        # lookup gives (batch, 4, seq_len, emb); swap the last two axes so the
        # embedding becomes the Conv1d channel axis -> (batch, 4, emb, seq_len)
        embedded = self.embedding_layer(x).permute([0, 1, 3, 2])

        per_sequence = [self._res_module_1(embedded[:, k, :, :]) for k in range(4)]

        # (pair, pair) forming each split, in output order. _res_module_2 is
        # called in exactly this order, which matters for BatchNorm running
        # statistics when the module is in training mode.
        splits = (((0, 1), (2, 3)),
                  ((0, 2), (1, 3)),
                  ((0, 3), (1, 2)))
        split_features = []
        for (a, b), (c, d) in splits:
            left = self._res_module_2(per_sequence[a] + per_sequence[b])
            right = self._res_module_2(per_sequence[c] + per_sequence[d])
            split_features.append(left + right)

        # (batch, 3, emb_dim, seq_len)
        return torch.cat([feat.unsqueeze(1) for feat in split_features], dim=1)


class _Model(torch.nn.Module):
    """A neural network model to predict phylogenetic trees.

    Pipeline: quartet-structured embedding (one feature map per topology) ->
    shared LSTM over the sequence positions -> ``fc`` -> ``classifier``,
    yielding one score per topology.
    """

    def __init__(self, embeding_dim = 80, hidden_dim = 20, 
                      num_layers = 3, output_size = 20, 
                      dropout = 0.0, vocab_size = 20):
        """Create a neural network model.

        Args:
            embeding_dim: embedding size, also the LSTM input size.
            hidden_dim: LSTM hidden size.
            num_layers: number of stacked LSTM layers.
            output_size: width of the intermediate linear layer ``fc``.
            dropout: dropout passed to ``nn.LSTM``.
            vocab_size: number of residue codes the embedding accepts. Defaults
                to 20, the previously hard-coded value; it must match the
                checkpoint being loaded.
        """
        super().__init__()

        self.embedding_layer = _DoubleEmbedding(vocab_size, embeding_dim)
        self.hidden_dim = hidden_dim
        self.output_size = output_size

        # submodule names (and creation order) are kept so saved state_dicts still load
        self.classifier = torch.nn.Linear(self.output_size, 1)
        self.rnn = nn.LSTM(embeding_dim, hidden_dim,
                           num_layers, dropout=dropout,
                           batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, self.output_size)

        # flatenning the parameters
        self.rnn.flatten_parameters()


    def forward(self, x):
        """ Function that infers the phylogenetic trees for the input sequences.
        Input: x the raw sequences, integer codes indexed as (batch, 4, seq_len)
        Output: the scores for each topology, shape (batch, n_topologies)
                (n_topologies is 3 for the quartet embedding)
        """
        batch_size = x.size(0)

        # this is the structure preserving embedding
        # (batch, n_topologies, embeding_dim, seq_len)
        g = self.embedding_layer(x)
        n_topologies = g.size(1)

        # fold topologies into the batch dimension, topology-major:
        # rows [k*batch, (k+1)*batch) belong to topology k
        # (n_topologies*batch, embeding_dim, seq_len)
        stacked = torch.cat([g[:, k, :, :] for k in range(n_topologies)], dim=0)

        # batch_first LSTM expects (N, seq_len, embeding_dim);
        # output is (N, seq_len, hidden_dim)
        r_output, _ = self.rnn(stacked.permute([0, 2, 1]))

        # features at the last sequence position: (N, hidden_dim)
        r_output_last = r_output[:, -1, :]

        # (N, output_size) -> (N, 1)
        scores = self.classifier(self.fc(r_output_last))

        # undo the topology-major folding: (n_topologies, batch) -> (batch, n_topologies)
        return scores.view(n_topologies, batch_size).permute(1, 0)


In [11]:
### load the NN model and run the test
### the best_path should be the parameters that you saved by the training process, for our case, it is "resultsaved_TrainOptLSTM_trainoptlstm_lr_0.001_batch_16_lba_best.pth"
best_path = "resultsaved_TrainOptLSTM_trainoptlstm_lr_0.001_batch_16_lba_best.pth"
device = torch.device("cpu")


# define the model
model = _Model(dropout = 0.2).to(device)

# load the weights onto the device the model lives on.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(best_path, map_location=device))
# eval(): BatchNorm uses its running statistics and dropout is disabled
model.eval()
#load the data
inputTest = inputTest.to(device)
print(inputTest.shape)
# inference only: skip building the autograd graph
with torch.no_grad():
    quartetsNN = model(inputTest)
# index of the highest-scoring topology for each input
_, predicted = torch.max(quartetsNN, 1)

torch.Size([1, 4, 1550])


In [12]:
### To check the predicted label
# `predicted` holds, for each input in the batch, the index of the highest-scoring
# topology: 0 -> (12|34), 1 -> (13|24), 2 -> (14|23), the split order used in
# _DoubleEmbedding.forward.
predicted

tensor([1])