# Train the full hierarchical classifier (HC) model
- Trains the hierarchical classifier (HC) on the dataset described in the main publication.
- Training takes a moderate amount of time on low-end hardware.

In [1]:
# dependencies
import copy
import gzip
import os
import pickle
import sys
import time

import networkx as nx
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold, cross_validate
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix, classification_report

# import hierarchical classification package
sys.path.append( './HC_package')
from HierarchicalClassification import *

In [2]:
# set variables
import os

# Single seed reused by the train/test split, the classifier and the resampler,
# so that Restart & Run All reproduces the same results.
seed = 34

# Directory that receives the trained models and the summary tables.
example_outputs = "./model_outputs"
# Create it up front: gzip.open and DataFrame.to_csv do not create missing
# parent directories, so writing outputs would fail on a fresh checkout.
os.makedirs(example_outputs, exist_ok=True)

In [3]:
# read features
# Load the zipped, tab-separated metadata/feature table. Rows are indexed by
# accession, but the 'SRA.Accession' column is kept (drop=False); it is removed
# explicitly when the feature matrix is built.
meta = pd.read_csv(
    './data/metadata.tsv.zip',
    sep='\t',
    compression='zip',
).set_index('SRA.Accession', drop=False)
meta

Unnamed: 0_level_0,SRA.Accession,Region,Subregion,Country,1,2,3,4,5,6,...,110469,110470,110471,110472,110474,110476,110479,110480,110483,corrected_labels
SRA.Accession,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SRR8667277,SRR8667277,Americas,Latin America and the Caribbean,Barbados,1,1,1,1,1,1,...,0,1,0,0,1,0,0,0,0,Barbados
SRR8691693,SRR8691693,Americas,Latin America and the Caribbean,Barbados,1,1,1,1,1,1,...,0,1,0,0,1,0,0,0,0,Barbados
SRR1967763,SRR1967763,Americas,Latin America and the Caribbean,Barbados,1,0,0,1,1,0,...,0,1,0,0,1,0,0,0,0,Barbados
SRR8369264,SRR8369264,Americas,Latin America and the Caribbean,Barbados,1,1,1,1,1,1,...,0,1,0,0,1,0,0,0,0,Barbados
SRR6922673,SRR6922673,Americas,Latin America and the Caribbean,Barbados,1,0,0,1,0,1,...,0,1,0,0,1,0,0,0,0,Barbados
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR1960134,SRR1960134,Asia,South-eastern Asia,Vietnam,1,0,0,1,1,1,...,0,1,0,0,1,0,0,0,0,Vietnam
SRR5220879,SRR5220879,Asia,South-eastern Asia,Vietnam,1,0,0,1,1,0,...,0,1,0,0,1,0,0,0,0,Vietnam
SRR7275070,SRR7275070,Asia,South-eastern Asia,Vietnam,0,0,0,0,1,0,...,0,1,0,0,1,0,0,0,0,Vietnam
SRR6898922,SRR6898922,Asia,South-eastern Asia,Vietnam,1,1,1,1,1,1,...,0,1,0,0,1,0,0,0,0,Vietnam


In [4]:
# read labels
# One label per line. The labels are paired positionally with the rows of the
# feature matrix by train_test_split, so they are assumed to be in the same row
# order as `meta` -- TODO confirm against how labels.txt is generated.
# Read-only mode: the original 'r+' needlessly required write permission, and
# `with` guarantees the file is closed even if reading fails.
with open('./data/labels.txt', 'r') as label_file:
    labels = label_file.read().split('\n')

# remove trailing empty value(s) produced by final newline(s); a file ending in
# more than one newline would otherwise leave a spurious "" label
while labels and labels[-1] == "":
    labels.pop()

In [5]:
# read graph
# The class hierarchy (root -> region -> subregion -> country) is stored as a
# pickled networkx graph. nx.readwrite.gpickle.read_gpickle was removed in
# NetworkX 3.0; it was a thin wrapper around pickle.load, so loading with
# pickle directly is equivalent and works on every NetworkX version.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('./data/graph.pkl', 'rb') as graph_file:
    graph = pickle.load(graph_file)
graph.nodes

NodeView(('root', 'Americas', 'Latin America and the Caribbean', 'Barbados', 'Europe', 'Eastern Europe', 'Bulgaria', 'Africa', 'Sub-Saharan Africa', 'Cape verde', 'Asia', 'Eastern Asia', 'China', 'Cuba', 'Western Asia', 'Cyprus', 'Czech republic', 'Dominica', 'Dominican republic', 'Northern Africa', 'Egypt', 'Western Europe', 'France', 'Southern Europe', 'Greece', 'Hungary', 'Southern Asia', 'India', 'South-eastern Asia', 'Indonesia', 'Italy', 'Jamaica', 'Kenya', 'Malaysia', 'Malta', 'Mexico', 'Morocco', 'Pakistan', 'Poland', 'Portugal', 'Russian federation', 'Saudi arabia', 'Singapore', 'South africa', 'Spain', 'Sri lanka', 'Tanzania', 'Thailand', 'Tunisia', 'Turkey', 'United Arab Emirates', 'Northern America', 'United states', 'Vietnam'))

In [6]:
# The feature matrix is everything except the identifier and the geographic
# label columns. These four columns must exist (a KeyError is raised otherwise);
# 'corrected_labels' is dropped only when present.
label_and_id_columns = ['SRA.Accession', 'Region', 'Subregion', 'Country']
features = meta.drop(columns=label_and_id_columns)
features = features.drop(columns=['corrected_labels'], errors='ignore')

In [7]:
# Hold out 25% of the samples as a test set. Stratifying on the (country)
# labels keeps each country's share approximately equal in both splits; the
# fixed seed makes the split reproducible.
(train_features, test_features,
 train_labels, test_labels) = train_test_split(
    features,
    labels,
    test_size=0.25,
    stratify=labels,
    random_state=seed,
)

In [8]:
# Write the training feature (column) names, one per line, in column order.
save_feat = np.asarray(train_features.columns)
np.savetxt('./data/feature_list.txt', save_feat, fmt="%s")

In [9]:
# Base classifier passed to fit_hierarchical_classifier for the HC.
# NOTE(review): RandomForestClassifier is not imported explicitly in the import
# cell; it presumably comes from the HierarchicalClassification star import
# (sklearn's random forest) -- confirm.
classifier = RandomForestClassifier(
    random_state=seed,
    n_jobs=-1,
    n_estimators=1000,
)

In [10]:
# Resampler handed to fit_hierarchical_classifier as `subsampler`.
# The meaning of sampling_strategy='mean' is defined by the HC package --
# presumably classes are resampled toward the mean class size (TODO confirm).
resampler = RandomBalancingSampler(
    random_state=seed,
    sampling_strategy='mean',
)

In [11]:
# fit hierarchical classifier
# Time the fit with time.perf_counter(): it is monotonic and high-resolution,
# so the elapsed time is not skewed by system clock adjustments as time.time()
# can be.
start_time = time.perf_counter()
models = fit_hierarchical_classifier(graph, train_labels, train_features, classifier,
                                     subsampler=resampler, verbose=True)
train_time = time.perf_counter() - start_time
print(" - Training Time(s): ", train_time)

 - starting root
 - processing: root
	 - classes in model:['Africa' 'Americas' 'Asia' 'Europe']
 - starting Americas
 - processing: Americas
	 - classes in model:['Latin America and the Caribbean' 'Northern America']
 - starting Europe
 - processing: Europe
	 - classes in model:['Eastern Europe' 'Southern Europe' 'Western Europe']
 - starting Africa
 - processing: Africa
	 - classes in model:['Northern Africa' 'Sub-Saharan Africa']
 - starting Asia
 - processing: Asia
	 - classes in model:['Eastern Asia' 'South-eastern Asia' 'Southern Asia' 'Western Asia']
 - starting Latin America and the Caribbean
 - processing: Latin America and the Caribbean
	 - classes in model:['Barbados' 'Cuba' 'Dominica' 'Dominican republic' 'Jamaica' 'Mexico']
 - starting Northern America
 - tip or non-branching node
 - starting Eastern Europe
 - processing: Eastern Europe
	 - classes in model:['Bulgaria' 'Czech republic' 'Hungary' 'Poland' 'Russian federation']
 - starting Western Europe
 - tip or non-branchi

In [12]:
# Serialise the fitted per-node models to a gzip-compressed pickle.
models_path = "%s/models.pkl.gz" % example_outputs
with gzip.open(models_path, 'wb') as models_out:
    pickle.dump(models, models_out)

In [13]:
# Classify the TRAINING samples down the hierarchy. Per the verbose log, each
# node routes samples by maximum predicted probability (mode='max'); samples
# whose probability does not exceed the 0.51 threshold stay fixed at that node.
classification_table_train, classifications_train = classify_samples_in_hierarchy(
    graph,
    train_features,
    models,
    mode='max',
    threshold=0.51,
    verbose=True,
)

 - traversing the DAG using max probability per node, threshold > 0.51
 - processing node: root
	 - descendants:['Americas', 'Europe', 'Africa', 'Asia']
	 - classifying node: root
	 - classes in model:['Africa' 'Americas' 'Asia' 'Europe']
		 -  284 samples assigned to Africa
		 -  294 samples assigned to Americas
		 -  709 samples assigned to Asia
		 -  397 samples assigned to Europe
		 -  50 samples fixed at root
 - processing node: Americas
	 - descendants:['Latin America and the Caribbean', 'Northern America']
	 - classifying node: Americas
	 - classes in model:['Latin America and the Caribbean' 'Northern America']
		 -  277 samples assigned to Latin America and the Caribbean
		 -  17 samples assigned to Northern America
		 -  0 samples fixed at Americas
 - processing node: Europe
	 - descendants:['Eastern Europe', 'Western Europe', 'Southern Europe']
	 - classifying node: Europe
	 - classes in model:['Eastern Europe' 'Southern Europe' 'Western Europe']
		 -  132 samples assigned to

In [14]:
# Per-class hierarchical statistics for the training predictions
# (penalty=False; see the HC package for what the penalty controls).
summary_train, summary_table_train = summary_statistics_per_class(
    graph, train_labels, classifications_train, penalty=False
)

In [15]:
# Write the per-class training summary as a tab-separated file (no index column).
training_summary_path = "%s/training_summary.tsv" % example_outputs
summary_table_train.to_csv(training_summary_path, sep="\t", header=True, index=False)

In [16]:
# Classify the held-out TEST samples with the same traversal settings as the
# training data (mode='max', threshold=0.51; verbose is not passed here),
# then compute the per-class hierarchical statistics for them.
classification_table_test, classifications_test = classify_samples_in_hierarchy(
    graph, test_features, models, mode='max', threshold=0.51
)

summary_test, summary_table_test = summary_statistics_per_class(
    graph, test_labels, classifications_test, penalty=False
)

In [17]:
# Write the per-class test summary as a tab-separated file (no index column).
test_summary_path = "%s/test_summary.tsv" % example_outputs
summary_table_test.to_csv(test_summary_path, sep="\t", header=True, index=False)

In [18]:
# Non-hierarchical statistics for each node's own model, evaluated on the test
# set. Returns per-node and per-class tables plus the per-node reports.
per_node, per_class, clf_reports = per_node_summary_stats(
    graph, test_labels, test_features, models, verbose=True
)


classes in model:['Africa' 'Americas' 'Asia' 'Europe']
classes in model:['Northern Africa' 'Sub-Saharan Africa']
classes in model:['Latin America and the Caribbean' 'Northern America']
classes in model:['Eastern Asia' 'South-eastern Asia' 'Southern Asia' 'Western Asia']
classes in model:['Eastern Europe' 'Southern Europe' 'Western Europe']
classes in model:['Bulgaria' 'Czech republic' 'Hungary' 'Poland' 'Russian federation']
classes in model:['Barbados' 'Cuba' 'Dominica' 'Dominican republic' 'Jamaica' 'Mexico']
classes in model:['Egypt' 'Morocco' 'Tunisia']
classes in model:['Indonesia' 'Malaysia' 'Singapore' 'Thailand' 'Vietnam']
classes in model:['India' 'Pakistan' 'Sri lanka']
classes in model:['Greece' 'Italy' 'Malta' 'Portugal' 'Spain']
classes in model:['Cape verde' 'Kenya' 'South africa' 'Tanzania']
classes in model:['Cyprus' 'Saudi arabia' 'Turkey' 'United Arab Emirates']


In [19]:
# Write the per-node and per-class summaries as tab-separated files (no index column).
for summary_frame, summary_path in (
    (per_node, "%s/per_node_summary.tsv" % example_outputs),
    (per_class, "%s/per_class_summary.tsv" % example_outputs),
):
    summary_frame.to_csv(summary_path, sep="\t", header=True, index=False)

In [20]:
# Overall hierarchical metrics (hR, hP, hF1, hAcc) for the held-out TEST set.
# These are computed from test_labels / classifications_test, not from the
# entire dataset.
h_summary = overall_summary_stats(test_labels, classifications_test, graph, penalty=False)
print(h_summary)

{'hR': 0.7449625791594704, 'hP': 0.8702084734364492, 'hF1': 0.8027295285359801, 'hAcc': 0.7941853770869316}
