In [1]:
import pandas as pd
import numpy as np
from math import sqrt
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from statistics import mean, stdev

In [2]:
class Tree():
    def __init__(self,name,rank):
        self.subtree = []
        self.mean = 0
        self.name = name
        self.rank = rank
        self.pan_stats = []
    def enter(self,lineage):
        if len(lineage) == 1:
            self.mean = int(lineage[0])
            return
        
        result = self.find(lineage[0])
        if result == None:
            new_node = Tree(lineage[0],self.rank+1)
            lineage = lineage[1:]
            new_node.enter(lineage)
            self.subtree.append(new_node)
        else:
            lineage = lineage[1:]
            self.subtree[result].enter(lineage)
        self.averaging()
    
    def find(self,target_name):
        for i in range(0,len(self.subtree),1):
            if self.subtree[i].name == target_name:
                return i
        return None
    
    def averaging(self):
        add = 0
        for i in range(0,len(self.subtree),1):
            add += self.subtree[i].mean
        self.mean = add/len(self.subtree)

    def output(self):
        pan_stats = []
        pan_stats = self.pull(pan_stats)
        return pan_stats
        
    def pull(self,pan_stats):
        if str(self.name) != "nan":
            pan_stats.append([self.rank,str(self.name),self.mean])
        if self.subtree == []:
            return pan_stats
        for i in range(0,len(self.subtree),1):
            pan_stats = self.subtree[i].pull(pan_stats)
        return pan_stats

In [3]:
class GCNlibrary():
    def __init__(self):
        self.lib = {}

    def record(self, rank, name, mean):
        if rank in list(self.lib.keys()):
            self.lib[rank][name] = mean
        else:
            rank_dict = {}
            rank_dict[name] = mean
            self.lib[rank] = rank_dict
    
    def search(self, index, name):
        return self.lib.get(index).get(name,"N/A")

    def fit(self,pan_stats):
        for i in range(0,len(pan_stats),1):
            rank = pan_stats[i][0]
            name = pan_stats[i][1]
            mean = float(pan_stats[i][2])
            self.record(rank,name,mean)
            
    def predict(self,taxa):
        ultimean = self.lib.get(0).get('prokaryotes')
        output = []
        for lineage in taxa:
            for index in range(len(lineage),0,-1):
                result = self.search(index,lineage[index-1])
                if result == 'N/A':
                    if index == 1:
                        output.append(ultimean)
                        break
                    else:
                        continue
                else:
                    output.append(result)
                    # lineage.append(search)
                    break
        return output

In [4]:
class TaxAvg(Tree,GCNlibrary):
    def __init__(self):
        self.prokaryotes = Tree("prokaryotes",0)
        self.model = GCNlibrary()
        self.pan_stats = []
        
    def fit(self,da):
        for lineage in da:
            self.prokaryotes.enter(lineage)
        self.pan_stats = self.prokaryotes.output()
        self.model.fit(self.pan_stats)
        
    def predict(self,X):
        pred = self.model.predict(X)
        return pred

In [5]:
def test_rmse(model,X_test,Y_test):
    test_preds = model.predict(X_test)
    mse = mean_squared_error(Y_test, test_preds)
    rmse = sqrt(mse)
    return rmse

In [6]:
performance = {}
for database in ["rdp","silva","pr2"]:
    if database in ["rdp","silva"]:
        da = pd.read_csv("taxa/"+database+"_full_length_taxa.csv")
        X = da.iloc[:,[5,6,7,8,9,10]]
        X = X.values.tolist()
        Y = da.iloc[:,[4]]
        Y = Y.values.tolist()
        da = da.iloc[:,[5,6,7,8,9,10,4]]
        da = da.values.tolist()
    else: # pr2 contains more columns
        da = pd.read_csv("taxa/"+database+"_full_length_taxa.csv")
        X = da.iloc[:,[5,6,7,8,9,10,11,12]]
        X = X.values.tolist()
        Y = da.iloc[:,[4]]
        Y = Y.values.tolist()
        da = da.iloc[:,[5,6,7,8,9,10,11,12,4]]
        da = da.values.tolist()
    performance[database+"_TA"] = []
    multiplicand = int(len(da)*0.2) #5-fold cross-validation
    for i in range(0,5,1):
        X_test = X[i*multiplicand:(i+1)*multiplicand]
        Y_test = Y[i*multiplicand:(i+1)*multiplicand]
        train = [da[0:i*multiplicand],da[(i+1)*multiplicand:len(da)]]
        train = sum(train,[])
        model = TaxAvg()
        model.fit(train)
        performance[database+"_TA"].append(test_rmse(model,X_test,Y_test))

In [7]:
pd.DataFrame(performance)

Unnamed: 0,rdp_TA,silva_TA,pr2_TA
0,1.199023,1.258218,1.327511
1,1.191911,1.223447,1.290085
2,1.130917,1.168445,1.216464
3,1.143604,1.19429,1.274543
4,1.210921,1.172372,1.177912


In [8]:
pd.DataFrame(performance).to_csv("performance/TA_full_length.csv",index=False)