# Classify new data using a pre-trained model

In [1]:
# dependencies 
import networkx as nx
import numpy as np
import pandas as pd

import copy 
import sys
import gzip
import pickle
import time

from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold, cross_validate
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix, classification_report

# import hierarchical classification package
sys.path.append( '../../HC_package/' )
from HierarchicalClassification import *

In [2]:
# ---- run configuration ----
# general settings: random seed and number of parallel jobs
seed = 34
n_jobs = -1

# input files
include_file = "sample_list.txt"   # samples to include; read as an index
class_file = "sample.location"     # tab-separated, no header row: sample -> label
data_file = "patterns.tab"         # tab-separated features, indexed by "pattern_id"

# prefix prepended to the name of each output file
output_file = "output_"

In [3]:
# classification mode and threshold
# threshold: 'adaptive' uses a minimum threshold of 1 / (number of classes in the model);
#            a fixed value (e.g. 0.5) can be given instead
threshold = 'adaptive'
# mode: 'max' uses the maximum predicted class probability
mode = 'max'

## load dataset and expected classes

In [4]:
# feature table: tab-separated, rows keyed by the "pattern_id" column
raw_features = pd.read_csv(
    data_file,
    sep="\t",
    index_col="pattern_id",
)

In [5]:
# both files are tab-separated with no header row; the first column becomes the index
_no_header_tsv = dict(sep="\t", header=None, index_col=0)

# expected class label per sample
raw_labels = pd.read_csv(class_file, **_no_header_tsv)

# samples to include - only the index of this table is used
include = pd.read_csv(include_file, **_no_header_tsv)

In [6]:
# requested samples vs. samples that have a column in the feature table
inc = include.index.values
in_feat = raw_features.columns

# keep the requested samples that are present, in the order of the include list
intersect = [sample for sample in inc if sample in in_feat]

print("original samples to include:", len(inc))
print("samples present in data:", len(intersect))

original samples to include: 35
samples present in data: 35


In [7]:
# restrict the feature table to the columns of the included samples
new_features = raw_features[intersect]

In [8]:
# align the labels with the sample order of the feature columns:
# an empty frame indexed by sample is left-joined with the label table
lab_ind = pd.DataFrame(index=new_features.columns)
labels_aligned = lab_ind.join(raw_labels)
new_labels = labels_aligned[1]  # column 1 holds the label

# sanity check: both should cover the same number of samples
for table in (new_features, new_labels):
    print(table.shape)

(94860, 35)
(35,)


In [9]:
# read the gzip-compressed pickle holding the trained model(s)
# NOTE(review): unpickling can execute arbitrary code - only load files from a trusted source
with gzip.open("../../model_outputs/models.pkl.gz", 'rb') as model_fh:
    models = pickle.load(model_fh)

In [10]:
# read the pickled graph that is passed to the classification and summary functions
# NOTE(review): unpickling can execute arbitrary code - only load files from a trusted source
with open("../../data/graph.pkl", 'rb') as graph_fh:
    graph = pickle.load(graph_fh)

In [11]:
# feature names used by the model, parsed as integers
train_feat = np.loadtxt(
    "../../data/feature_list.txt",
    dtype=int,
)

In [12]:
# select the rows for the model's features (train_feat), then transpose so that
# samples become rows and features become columns
model_feature_rows = new_features.loc[train_feat]
input_converted = copy.copy(model_feature_rows.T)

print(input_converted.shape)
input_converted

(35, 89481)


pattern_id,1,2,3,4,5,6,7,12,13,14,...,110468,110469,110470,110471,110472,110474,110476,110479,110480,110483
ERR2278721,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278723,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278728,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278729,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278733,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278739,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278742,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278743,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278749,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0
ERR2278713,1,1,1,1,1,1,1,0,1,1,...,0,0,1,0,0,1,0,0,0,0


## classify samples

In [13]:
# classify samples and report how long the call took
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() can jump if the system clock is adjusted mid-run
timeb4 = time.perf_counter()

# classify the new samples; returns a per-sample table and the classifications
# that the summary cells below consume
(classification_table_test, classifications_test) = classify_samples_in_hierarchy(
    graph,
    input_converted,
    models,
    mode = mode,
    threshold = threshold,
)

timeafter = time.perf_counter()
total_seconds = timeafter - timeb4
print("classification took ", total_seconds, " seconds")

classification took  2.588233709335327  seconds


In [14]:
# per-class hierarchical statistics of the classifications against the expected labels
per_class_result = summary_statistics_per_class(
    graph,
    new_labels,
    classifications_test,
    penalty=False,
    macro_inc_all=False,
)
(summary_test, summary_table_test) = per_class_result

# report the macro and micro entries of the summary
for averaging in ('macro', 'micro'):
    print(summary_test[averaging])

{'hR': 0.8444444444444446, 'hP': 0.8467032967032968, 'F1': 0.8455723620085919}
{'hR': 0.7596899224806202, 'hP': 0.620253164556962, 'F1': 0.6829268292682927}


In [15]:
# convert the per-class statistics into a table: one row per class, one column per metric

# create an empty table with the metrics of interest as columns
col_names = ["hP", "hR", "hF1", "n", "nP", "nT", "root_dist"]
per_class_stats = summary_test['per_class']
class_table = pd.DataFrame(
    columns = col_names,
    index = per_class_stats)

# fill the table cell by cell
for class_name in per_class_stats:
    for metric in col_names:
        class_table.loc[class_name, metric] = per_class_stats[class_name][metric]

# sort on class names, then on root_dist.
# The second sort must be stable so that classes with equal root_dist keep their
# name order; the default quicksort does not guarantee this. Re-assigning instead
# of sorting in place also makes the cell safe to re-run.
class_table = (
    class_table
    .sort_index()
    .sort_values(by=['root_dist'], kind='mergesort')
)

In [16]:
# write the per-class hierarchical summary as a tab-separated file with a header row
per_class_path = "%sper_class.tsv" % output_file
class_table.to_csv(per_class_path, sep="\t", header=True)

In [17]:
# overall hierarchical statistics for the classified samples
Hsummary = overall_summary_stats(
    new_labels,
    classifications_test,
    graph,
    penalty=False,
)
print(Hsummary)

{'hR': 0.7047619047619048, 'hP': 0.7115384615384616, 'hF1': 0.7081339712918662, 'hAcc': 0.7142857142857139}


In [18]:
# per-sample table: expected label next to the classification table
labels_named = raw_labels.rename(columns = {1:"labels"})
labels_and_classes = pd.concat([labels_named, classification_table_test], axis=1)

# keep only the classified samples, in the order of new_labels
classification_table_renamed = labels_and_classes.loc[new_labels.index]

per_sample_path = "%sper_sample.tsv" % output_file
classification_table_renamed.to_csv(per_sample_path, sep="\t", header=True)

classification_table_renamed

Unnamed: 0,labels,classification,Americas,Europe,Africa,Asia,Latin America and the Caribbean,Northern America,Eastern Europe,Western Europe,...,Turkey,United Arab Emirates,India,Pakistan,Sri lanka,Indonesia,Malaysia,Singapore,Thailand,Vietnam
ERR2278721,Poland,Hungary,0.07,0.677,0.009,0.244,,,0.491,0.039,...,,,,,,,,,,
ERR2278723,Poland,Poland,0.037,0.707,0.009,0.247,,,0.535,0.029,...,,,,,,,,,,
ERR2278728,Poland,Poland,0.036,0.701,0.02,0.243,,,0.519,0.035,...,,,,,,,,,,
ERR2278729,Poland,Poland,0.034,0.688,0.009,0.269,,,0.563,0.019,...,,,,,,,,,,
ERR2278733,Poland,Italy,0.027,0.854,0.031,0.088,,,0.167,0.009,...,,,,,,,,,,
ERR2278739,Poland,Hungary,0.031,0.744,0.008,0.217,,,0.551,0.022,...,,,,,,,,,,
ERR2278742,Poland,Poland,0.039,0.727,0.008,0.226,,,0.568,0.038,...,,,,,,,,,,
ERR2278743,Poland,Poland,0.023,0.785,0.004,0.188,,,0.59,0.004,...,,,,,,,,,,
ERR2278749,Poland,Eastern Europe,0.017,0.489,0.095,0.399,,,0.528,0.021,...,,,,,,,,,,
ERR2278713,Poland,Greece,0.018,0.657,0.154,0.171,,,0.38,0.005,...,,,,,,,,,,


In [19]:
# save the overall hierarchical summary, one "name \t value" line per statistic
hsum_file = "%shsummary.tab" % output_file

# Write through print(..., file=f) instead of re-pointing sys.stdout: if anything
# raised inside the loop, the global stdout would otherwise stay redirected to a
# closed file for the rest of the session. The bytes written are unchanged.
with open(hsum_file, 'w') as f:
    for stat_name in Hsummary:
        print(stat_name, "\t", Hsummary[stat_name], file=f)

In [20]:
# standard (non-hierarchical) assessment metrics per node
node_stats = per_node_summary_stats(graph, new_labels, input_converted, models, verbose=True)
(test_summary_per_node, test_summary_per_class, test_clf_tables) = node_stats

# save the per-class table as a tab-separated file with a header row
nonhier_path = "%snonhier_per_class.tsv" % output_file
test_summary_per_class.to_csv(nonhier_path, sep="\t", header=True)

classes in model:['Africa' 'Americas' 'Asia' 'Europe']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


classes in model:['Northern Africa' 'Sub-Saharan Africa']
classes in model:['Latin America and the Caribbean' 'Northern America']
classes in model:['Eastern Asia' 'South-eastern Asia' 'Southern Asia' 'Western Asia']
classes in model:['Eastern Europe' 'Southern Europe' 'Western Europe']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


classes in model:['Bulgaria' 'Czech republic' 'Hungary' 'Poland' 'Russian federation']
classes in model:['Barbados' 'Cuba' 'Dominica' 'Dominican republic' 'Jamaica' 'Mexico']
classes in model:['Egypt' 'Morocco' 'Tunisia']
classes in model:['Indonesia' 'Malaysia' 'Singapore' 'Thailand' 'Vietnam']
classes in model:['India' 'Pakistan' 'Sri lanka']
classes in model:['Greece' 'Italy' 'Malta' 'Portugal' 'Spain']
classes in model:['Cape verde' 'Kenya' 'South africa' 'Tanzania']
classes in model:['Cyprus' 'Saudi arabia' 'Turkey' 'United Arab Emirates']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
