In [72]:
import copy
import ete3
import torch
from ete3 import Tree
import numpy as np
import random
import json

In [73]:
# Encode a string as an array of integer codes
def convert_string_to_numbers(str, dict):
    '''Translate every character of a string into its integer code.

    str:  string to convert
    dict: dictionary with the relative ordering (integer code) of each char

    Returns a one-dimensional np.int64 array holding one code per character,
    in the order the characters appear in str. A character that is not a
    key of dict raises KeyError.
    '''
    # look each character up lazily; numpy consumes the stream directly
    codes = (dict[symbol] for symbol in str)
    return np.fromiter(codes, dtype=np.int64)

## IndSeq structure
### `name`
The name of this species. The original (leaf) species are named `1` to `n`; each internal species is named `n+t`, where `t` counts the internal nodes in the order they are created (`n+1`, `n+2`, ...).
### `seq`
The aligned sequence of the species. Ancestral (internal) species have sequences generated from those of their children.

In [74]:
class IndSeq():
    """One species (leaf or internal node) paired with its sequence."""

    def __init__(self, name, seq):
        # name: identifier of the species
        # seq:  sequence attached to that species
        self.name, self.seq = name, seq

## NNTree Structure

### `Sequences`
A list of `n-2` lists of `IndSeq` objects. Each inner list holds all `IndSeq` objects of one level: the first level has `n` entries and every following level has one fewer. For the validation set, only the first level is included.
### `n_node`
Number of species contained in the tree; this also determines the number of levels of the rooted tree.
### `level`
A list of `n_node-2` numeric arrays. Each array contains two 1s and 0s elsewhere; a 1 marks one of the two nodes that are joined to their common ancestor in the next level.

In [75]:
class NNTree:
    """Container for one tree sample: per-level species lists and merge labels."""

    def __init__(self, n_node, IndSeq, level):
        # sequences: the per-level lists of IndSeq objects (passed as `IndSeq`)
        # n_node:    number of species in the tree
        # level:     per-level indicator arrays
        self.sequences, self.n_node, self.level = IndSeq, n_node, level

In [76]:
def construct_train_set(tree, seq_string, num_taxon, amino_dict=None):
    """Build an NNTree sample from a tree and the sequences of its leaves.

    tree:       tree object exposing traverse("postorder"); every node whose
                name is the empty string is treated as an internal node and is
                renamed in place, so the caller's tree is mutated
    seq_string: list of sequence strings, one per leaf; entry i is given the
                name str(i+1). A single trailing newline, if present, is
                dropped before encoding
    num_taxon:  number of leaves; internal nodes are named num_taxon+1,
                num_taxon+2, ... in the order they are visited
    amino_dict: optional char -> int mapping used to encode the sequences;
                when omitted the notebook-level `dict_amino` is used

    Returns NNTree(num_taxon, Sequences, all_level), where Sequences[k] is a
    deep-copied snapshot of the species list taken just before the k-th merge
    and all_level[k] is the indicator array for that merge.
    """
    if amino_dict is None:
        # fall back to the alphabet built elsewhere in the notebook
        amino_dict = dict_amino

    # First we construct the sequence list of the first layer
    seq_list = []                       # The sequences in one level
    Sequences = []                      # The lists of sequence list for all levels
    all_level = []                      # One indicator array per level
    for i, line in enumerate(seq_string):
        # Strip the newline only when it is really there, so a final line
        # without one does not lose its last residue
        residues = line[:-1] if line.endswith("\n") else line
        seq_list.append(IndSeq(str(i+1), convert_string_to_numbers(residues, amino_dict)))

    # Now we traverse the tree and construct internal nodes while getting new levels
    # We need to give name to internal node by incrementing number
    node_so_far = num_taxon
    # The current level, start with 0
    level = 0
    for node in tree.traverse("postorder"):
        # This is when we found an internal node
        # every time a internal node is found, we construct a new level
        if node.name == "":
            # snapshot the current level before it is modified by the merge
            Sequences.append(copy.deepcopy(seq_list))
            # the individual level array, decrease by 1 for each level
            ind_level = np.zeros(num_taxon - level)
            # give name to the new node and increment the naming value
            node_so_far += 1
            node.name = str(node_so_far)
            # only the first two children of the internal node are merged
            children = node.get_children()
            left = children[0].name
            right = children[1].name

            # internal_proc builds the internal sequence, removes both
            # children from seq_list and marks their positions in ind_level
            int_seq, ind_level, seq_list = internal_proc(left, right, seq_list, ind_level)
            # we append the ind_level
            all_level.append(ind_level)
            # we add a new IndSeq object for the internal node
            seq_list.append(IndSeq(node.name, int_seq))
            level += 1
    return NNTree(num_taxon, Sequences, all_level)

In [77]:
def internal_proc(left, right, species_seq, ind_level):
    """Merge two child species into the sequence of their common ancestor.

    left, right: names of the two children to merge
    species_seq: list of objects with `.name` and `.seq`; both children are
                 removed from this list in place
    ind_level:   indexable indicator for the current level; the positions of
                 the two children are set to 1 in place

    Returns (int_seq, ind_level, species_seq). int_seq is a list with one
    entry per site of the left child's sequence: the shared value where both
    children agree, otherwise a random pick between the two.

    Raises ValueError if either child name is not present in species_seq.
    """
    # first we need to find the index of left and right child in the
    # species_seq because we search by name
    left_ind, right_ind = -1, -1
    for i in range(len(species_seq)):
        if species_seq[i].name == left:
            left_ind = i
        if species_seq[i].name == right:
            right_ind = i
        if left_ind != -1 and right_ind != -1:
            break
    if left_ind == -1 or right_ind == -1:
        # without this check -1 would silently address the last element
        raise ValueError(
            "child not found in species list: left=%r, right=%r" % (left, right))

    # Generate the internal sequence site by site
    left_seq = species_seq[left_ind].seq
    right_seq = species_seq[right_ind].seq
    int_seq = []
    for ii in range(len(left_seq)):
        a, b = left_seq[ii], right_seq[ii]
        int_seq.append(a if a == b else random.choice([a, b]))

    # now we change the ind_level
    ind_level[left_ind] = 1
    ind_level[right_ind] = 1

    # Remove both children, highest index first, so the second deletion is
    # not affected by the shift caused by the first one. The children may
    # appear in either order in species_seq.
    for idx in sorted((left_ind, right_ind), reverse=True):
        del species_seq[idx]

    return int_seq, ind_level, species_seq

## Loading Data

In [91]:
# get name of the script
nameScript = "get_tree.py"
# get json file name of the script
nameJson = "model_param.json"

print("------------------------------------------------------------------------")
print("File proprocessing for 5-taxon trees")
print("------------------------------------------------------------------------")
print("Executing " + nameScript + " following " + nameJson, flush = True)

# read the JSON configuration; the context manager closes the file handle
with open(nameJson) as jsonFile:
    dataJson = json.load(jsonFile)

data_root = dataJson["dataRoot"]         # data folder
model_root = dataJson["modelRoot"]       # folder to save the data

label_files = dataJson["labelFile"]      # file with labels
sequence_files = dataJson["matFile"]     # file with sequences
tree_files = dataJson["treeFile"]        # file with tree structure
num_taxon = dataJson["numTaxon"]         # Number of taxa

# optional entry, falls back to a default file name
summary_file = dataJson.get("summaryFile", "summary_file.txt")


print("------------------------------------------------------------------------") 
print("Loading Sequence Data in " + sequence_files, flush = True)
print("Loading Tree Data in " + tree_files, flush = True)

# we read the sequences as a list of strings (one line each, newline kept)
with open(data_root+sequence_files, 'r') as f:
    seq_string = f.readlines()

# one tree description per line
with open(data_root+tree_files, 'r') as f:
    tree_newick = f.readlines()
# length of the first sequence line without its trailing newline
seq_length = len(seq_string[0])-1
# one sample per tree line
num_sample = len(tree_newick)
print("------------------------------------------------------------------------") 
print("Number of samples: " + str(num_sample), flush = True)

------------------------------------------------------------------------
File proprocessing for 5-taxon trees
------------------------------------------------------------------------
Executing get_tree.py following model_param.json
------------------------------------------------------------------------
Loading Sequence Data in sequences12062021.in
Loading Tree Data in trees12062021.in
------------------------------------------------------------------------
Number of samples: 10000


In [79]:
# Distinct symbols of the first sequence line (its last character, the
# newline, is dropped), sorted into a list.
# NOTE(review): only the first sequence is inspected; a symbol that first
# appears in a later sequence would be missing from dict_amino.
strL = sorted(set(seq_string[0][:-1]))

# relative order: each symbol maps to its rank in the sorted alphabet
dict_amino = {}
for ii, c in enumerate(strL):
    dict_amino[c] = ii

## Construct the sets

In [80]:
# Split by sample index: the first 90% of the samples form the training set,
# the next 5% the test set. Each sample owns `num_taxon` consecutive lines of
# seq_string, so the slice width follows the configured number of taxa
# instead of a hard-coded 5.
train_set, test_set = [], []
n_train_end = int(num_sample*0.9)
n_test_end = int(num_sample*0.95)
for i in range(n_train_end):
    new_tree = construct_train_set(Tree(tree_newick[i]),
                                   seq_string[i*num_taxon:(i+1)*num_taxon],
                                   num_taxon)
    train_set.append(new_tree)
for ii in range(n_train_end, n_test_end):
    new_tree = construct_train_set(Tree(tree_newick[ii]),
                                   seq_string[ii*num_taxon:(ii+1)*num_taxon],
                                   num_taxon)
    test_set.append(new_tree)

## Model

In [81]:
class ResNetModule(torch.nn.Module):
    '''Residual block applied independently to every site.

    The dense per-site layers are expressed as Conv1d layers with a window
    (kernel) size of one. The block is two repetitions of
    Conv1d -> BatchNorm1d -> ReLU, and its output is added to the input.
    '''

    def __init__(self, channel_count):
        super().__init__()
        stages = []
        # two identical Conv/BatchNorm/ReLU stages, channel count preserved
        for _ in range(2):
            stages += [
                torch.nn.Conv1d(channel_count, channel_count, 1),
                torch.nn.BatchNorm1d(channel_count),
                torch.nn.ReLU(),
            ]
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        # skip connection: input plus the transformed input
        transformed = self.layers(x)
        return x + transformed

In [82]:
class SequenceModule(torch.nn.Module):
    """Embed integer-coded sequences and mix per-taxon features with their sum.

    length_dict: number of distinct symbols (size of the embedding table)
    embed_dim:   embedding dimension, also the channel count of both
                 residual modules
    level:       stored on the instance; not used by forward()
    """

    def __init__(self, length_dict, embed_dim, level):
        super().__init__()
        self.embedding_layer = torch.nn.Embedding(length_dict, embed_dim)
        self._res_module_1 = ResNetModule(embed_dim)
        self._res_module_2 = ResNetModule(embed_dim)
        self.embedding_dim = embed_dim
        self.level = level

    def forward(self, x):
        """Return one feature map per taxon, stacked along dim 1.

        x is an integer tensor indexed as [batch, taxon, site] (presumably
        the encoded alignment -- confirm against the caller). The result has
        shape (batch, taxon, embed_dim, site).
        """
        # embedding appends the channel axis last; move it in front of the
        # site axis because Conv1d expects (batch, channels, length)
        x = self.embedding_layer(x).permute([0, 1, 3, 2])

        # FIXME
        # Change later, only first level for now!
        # The same residual module is applied to each taxon separately
        per_taxon = [self._res_module_1(x[:, t, :, :]) for t in range(x.shape[1])]

        # sum over taxa, accumulated left to right
        pooled = per_taxon[0]
        for features in per_taxon[1:]:
            pooled = pooled + features
        aggregated_sum = self._res_module_2(pooled)

        # every taxon receives the shared aggregate on top of its own features
        return torch.stack([features + aggregated_sum for features in per_taxon], dim=1)

In [83]:
class targetModule(torch.nn.Module):
    """Score every taxon of a sample with a SequenceModule + LSTM head.

    embed_dim:   embedding dimension passed to the SequenceModule and used as
                 the LSTM input size
    level:       accepted for interface compatibility; the SequenceModule is
                 still built for level 0 (only the first level is handled)
    num_layers:  number of stacked LSTM layers
    output_size: width of the fully connected layer before the classifier
    dropout:     dropout between LSTM layers
    hidden_dim:  LSTM hidden size; defaults to embed_dim when not given
    """

    def __init__(self, embed_dim, level, num_layers = 3, output_size = 20, dropout=0.0,
                 hidden_dim=None):
        super().__init__()

        if hidden_dim is None:
            # no hidden size supplied: reuse the embedding width
            hidden_dim = embed_dim

        self.sequence_model = SequenceModule(20, embed_dim, 0)
        self.hidden_dim = hidden_dim
        self.output_size = output_size
        self.embed_dim = embed_dim
        self.fc = torch.nn.Linear(hidden_dim, self.output_size)

        self.classifier = torch.nn.Linear(self.output_size, 1)
        self.rnn = torch.nn.LSTM(embed_dim, hidden_dim,
                                 num_layers, dropout=dropout,
                                 batch_first=True)
        self.rnn.flatten_parameters()

    def forward(self, x):
        """Return a (batch, taxon) tensor with one score per taxon."""
        batch_size, num_taxa = x.size()[0], x.size()[1]

        # (batch, taxon, embed_dim, site)
        g = self.sequence_model(x)
        # fold the taxon axis into the batch axis: (batch*taxon, embed_dim, site)
        X = g.reshape(num_taxa*batch_size, self.embed_dim, -1)
        # the LSTM is batch_first, so it wants (batch*taxon, site, embed_dim)
        r_output, hidden = self.rnn(X.permute([0, 2, 1]))
        # keep only the output of the last site
        r_output_last = r_output[:, -1, :]

        out = r_output_last.contiguous().view(-1, self.hidden_dim)

        # (batch*taxon, output_size)
        output = self.fc(out)

        # (batch*taxon, 1)
        X_combined = self.classifier(output)

        # back to one row per sample
        X_combined = X_combined.view(batch_size, num_taxa)

        return X_combined



In [110]:
# Show the first-level indicator array of the first training sample
train_set[0].level[0]

array([1., 0., 0., 0., 1.])

# Only the first layer for now!!!

In [116]:
# Training arrays for the first level only:
#   train_seq[i, ii, :] is the sequence of species ii in sample i
#   train_label[i, :]   is the first-level indicator array of sample i
train_seq = np.zeros((len(train_set)*num_taxon, seq_length), dtype=np.int64)
train_label = np.zeros((len(train_set), num_taxon), dtype=np.int64)
for i in range(len(train_set)):
    train_label[i,:] = train_set[i].level[0]
    for ii in range(len(train_set[i].sequences[0])):
        # offset by the sample index so each sample fills its own rows
        # instead of overwriting rows 0..num_taxon-1
        train_seq[i*num_taxon + ii,:] = train_set[i].sequences[0][ii].seq.reshape((1, seq_length))
train_seq = train_seq.reshape((len(train_set), -1, seq_length))  