# process a validation unitig dataset 
- a validation dataset of unitig/pattern features (94,860 patterns across 430 genomes) will be run through the hierarchical classification (HC) models for classification

In [1]:
# dependencies 
import networkx as nx
import numpy as np
import pandas as pd
import copy 
import pickle
import gzip
import time

from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold, cross_validate
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix, classification_report

from matplotlib import pyplot as plt

# import hierarchical classification package
import sys
sys.path.append( './HC_package')
from HierarchicalClassification import *

In [2]:
# notebook-wide settings
# directory used as the prefix for every results table written by this notebook
example_outputs = "./model_outputs"

# seed value (not referenced by the cells visible in this notebook)
seed = 34

In [3]:
# features: zip-compressed, tab-separated table with one row per pattern_id
validation_features = pd.read_csv(
    "./data/ncbi_patterns.tab.zip",
    sep="\t",
    index_col="pattern_id",
    compression='zip',
)

In [4]:
# labels: headerless tab-separated file; column 0 becomes the index,
# so the remaining column is addressed by its integer position label (1)
validation_labels_raw = pd.read_csv(
    "./data/ncbi_labels.tab",
    sep="\t",
    header=None,
    index_col=0,
)

In [5]:
# put the labels in the same order as the feature columns
# an empty frame indexed by the feature column names drives a left join,
# so any column name absent from the label file ends up with a NaN label
lab_ind = pd.DataFrame(index=validation_features.columns)
labels_in_feature_order = lab_ind.join(validation_labels_raw, how="left")
validation_labels = labels_in_feature_order[1]

In [6]:
# sanity check: the feature column count should equal the number of labels
for table in (validation_features, validation_labels):
    print(table.shape)

(94860, 430)
(430,)


In [7]:
# load the trained models from the gzip-compressed pickle
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load model archives that come from a trusted source
with gzip.open("./model_outputs/models.pkl.gz", mode='rb') as model_file:
    models = pickle.load(model_file)

In [8]:
# load the pickled graph object that is passed to the HC functions below
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load graph files that come from a trusted source
with open("./data/graph.pkl", mode='rb') as graph_file:
    graph = pickle.load(graph_file)

In [9]:
# load the model's feature names (parsed as integers)
feature_list_path = "./data/feature_list.txt"
feature_list = np.loadtxt(feature_list_path, dtype=int)

In [10]:
# restrict the input to the features the model uses
# rows are selected by label in feature_list order, then the table is
# transposed so the original columns become rows; .copy() detaches the
# result from validation_features
model_patterns = validation_features.loc[feature_list]
filtered_features = model_patterns.T.copy()
print(filtered_features.shape)

(430, 89481)


In [11]:
# classify new samples and report how long the call took
# perf_counter() is a monotonic, high-resolution clock, so the measured
# duration cannot be skewed by system clock adjustments (unlike time.time())
timeb4 = time.perf_counter()

# classify the validation samples
# NOTE(review): threshold = 0.500001 presumably demands a strict majority
# (just above 0.5) — confirm against classify_samples_in_hierarchy
(classification_table_validate, classifications_validate) = classify_samples_in_hierarchy(
    graph,
    filtered_features,
    models,
    mode = 'max',
    threshold = 0.500001,
)

timeafter = time.perf_counter()
total_seconds = timeafter - timeb4
print("classification took ", total_seconds, " seconds")

classification took  6.980911493301392  seconds


In [12]:
# build the per-sample table of real vs predicted labels
# the label series (named 1) becomes the 'real' column
real_class_labs = pd.DataFrame(validation_labels).rename(columns={1: 'real'})

# index-on-index merge (default inner join), then rename the
# 'classification' column to 'prediction'
merged_classifications = real_class_labs.merge(
    classifications_validate,
    left_index=True,
    right_index=True,
).rename(columns={'classification': 'prediction'})

In [13]:
# save classification per sample
# the real and predicted labels were matched on the DataFrame index, so the
# index is written too; with index = False the saved rows could not be traced
# back to individual samples
# NOTE(review): assumes the index holds the sample identifiers — confirm
merged_classifications.to_csv("%s/validation - classifications.tsv" % example_outputs,
                              sep = "\t", header = True, index = True, index_label = "sample")

In [14]:
# per-class summary statistics for the validation predictions
# NOTE: the *_train names hold validation results here; they are kept
# unchanged because the saving cell refers to summary_table_train
validation_summaries = summary_statistics_per_class(
    graph,
    validation_labels,
    classifications_validate,
    penalty=False,
)
(summary_train, summary_table_train) = validation_summaries

In [15]:
# save the validation summary table (tab-separated, header row, no index column)
summary_path = "%s/validation - summary.tsv" % example_outputs
summary_table_train.to_csv(summary_path, sep="\t", header=True, index=False)

In [16]:
# access non-hierarchical statistics per node
# verbose = True makes the call print the classes of each model as it goes
per_node_results = per_node_summary_stats(
    graph,
    validation_labels,
    filtered_features,
    models,
    verbose=True,
)
(per_node, per_class, clf_reports) = per_node_results


classes in model:['Africa' 'Americas' 'Asia' 'Europe']
classes in model:['Northern Africa' 'Sub-Saharan Africa']
classes in model:['Latin America and the Caribbean' 'Northern America']
classes in model:['Eastern Asia' 'South-eastern Asia' 'Southern Asia' 'Western Asia']
classes in model:['Eastern Europe' 'Southern Europe' 'Western Europe']
classes in model:['Bulgaria' 'Czech republic' 'Hungary' 'Poland' 'Russian federation']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


classes in model:['Barbados' 'Cuba' 'Dominica' 'Dominican republic' 'Jamaica' 'Mexico']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


classes in model:['Egypt' 'Morocco' 'Tunisia']
classes in model:['Indonesia' 'Malaysia' 'Singapore' 'Thailand' 'Vietnam']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


classes in model:['India' 'Pakistan' 'Sri lanka']
classes in model:['Greece' 'Italy' 'Malta' 'Portugal' 'Spain']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


classes in model:['Cape verde' 'Kenya' 'South africa' 'Tanzania']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


classes in model:['Cyprus' 'Saudi arabia' 'Turkey' 'United Arab Emirates']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [17]:
# save the per node and per class summaries
# both tables are written the same way: tab-separated, header row, no index column
per_node_path = "%s/validation - per_node_summary.tsv" % example_outputs
per_class_path = "%s/validation - per_class_summary.tsv" % example_outputs

for table, table_path in ((per_node, per_node_path), (per_class, per_class_path)):
    table.to_csv(table_path, sep="\t", header=True, index=False)

In [18]:
# generate overall hierarchical summary stats (for the entire dataset)
h_summary = overall_summary_stats(
    validation_labels,
    classifications_validate,
    graph,
    penalty=False,
)
print(h_summary)

{'hR': 0.35348837209302325, 'hP': 0.5211428571428571, 'hF1': 0.4212471131639723, 'hAcc': 0.46666666666666684}
