# Classify new data using a pre-trained model

In [1]:
# dependencies 
import networkx as nx
import numpy as np
import pandas as pd

import copy 
import sys
import gzip
import pickle
import time

from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold, cross_validate
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix, classification_report

from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel

# import hierarchical classification package
sys.path.append( '../../HC_package/' )
from HierarchicalClassification import *

In [2]:
# run configuration

# input files: feature table, per-sample class labels, and the list of samples to use
data_file, class_file, include_file = "patterns.tab", "sample.location", "sample_list.txt"

# prefix prepended to the name of every output file written by this notebook
output_file = "output_"

# compute settings (not referenced by the cells visible in this notebook)
n_jobs, seed = -1, 34

In [3]:
# classification mode and thresholds
#   mode      - 'max': use the maximum predicted class probability
#   threshold - 'adaptive': minimum threshold is 1 / (number of classes in the model);
#               a fixed value (e.g. 0.5) can be given instead
mode, threshold = 'max', 'adaptive'

## load dataset and expected classes

In [4]:
# feature table: tab-separated, rows indexed by the "pattern_id" column
raw_features = pd.read_table(data_file, index_col="pattern_id")

In [5]:
# both files are headerless and tab-separated, with the first column used as the index

# labels
raw_labels = pd.read_table(class_file, header=None, index_col=0)

# samples to include - only the index is used
include = pd.read_table(include_file, header=None, index_col=0)

In [6]:
# samples requested for inclusion vs. samples available as feature columns
inc = include.index.values
in_feat = raw_features.columns

# keep the requested order, dropping any sample without a feature column
intersect = [sample for sample in inc if sample in in_feat]

print("original samples to include:", len(inc))
print("samples present in data:", len(intersect))

original samples to include: 62
samples present in data: 48


In [7]:
# keep only the feature columns of the samples found in both lists
new_features = raw_features[intersect]

In [8]:
# align the labels with the column order of the filtered features:
# joining onto an index-only frame keeps that frame's row order
lab_ind = pd.DataFrame(index=new_features.columns)
new_labels = lab_ind.join(raw_labels).loc[:, 1]

# sanity check: feature table shape, then label vector shape
print(new_features.shape, new_labels.shape, sep="\n")

(94860, 48)
(48,)


## load required prefiltering methods 

In [9]:
# random forest to select features based on GINI
class RF_FS:
    """Feature pre-filter that keeps the ``n`` highest-scoring features.

    Scores are the ``feature_importances_`` of a ``RandomForestClassifier``
    fitted on the training data. The object follows the usual
    ``fit`` / ``transform`` / ``fit_transform`` calling pattern.

    Parameters
    ----------
    n : int
        Number of top-ranked features kept by ``transform``.
    n_estimators : int
        Number of trees passed to the random forest.
    random_state : int
        Seed passed to the random forest.
    n_jobs : int
        Parallel jobs passed to the random forest.
    """

    def __init__(self, n=100, n_estimators=1000, random_state=0, n_jobs=-1):
        self.name = 'RF_FS'
        # set by fit(): DataFrame indexed by feature id with a single
        # 'scores' column, sorted from highest to lowest score
        self.scores = None
        self.n = n
        self.n_estimators = n_estimators
        self.n_jobs = n_jobs
        self.random_state = random_state

    def classifier(self):
        """Return the name of this selector."""
        return self.name

    def _require_fitted(self):
        """Raise a clear error when scores are not available yet."""
        if self.scores is None:
            raise RuntimeError(
                "RF_FS has no feature scores yet - call fit() before "
                "transform() or plot()"
            )

    def fit(self, feat, lab):
        """Fit the random forest and store the ranked feature scores.

        Parameters
        ----------
        feat : pandas.DataFrame
            Training data; its columns are the features being ranked.
        lab : array-like
            Class labels passed to the random forest together with ``feat``.

        Returns
        -------
        RF_FS
            ``self``, so calls can be chained.
        """
        # fit model
        fs = RandomForestClassifier(
            n_estimators=self.n_estimators,
            random_state=self.random_state,
            n_jobs=self.n_jobs,
        )
        fs.fit(feat, lab)

        # make score / feature id dataframe
        self.scores = pd.DataFrame(index=feat.columns.values)
        self.scores['scores'] = copy.copy(fs.feature_importances_)

        # sort scores, best first
        self.scores.sort_values(by=['scores'], ascending=False, inplace=True)

        return self

    def transform(self, feat):
        """Return a copy of ``feat`` restricted to the top ``n`` features.

        Parameters
        ----------
        feat : pandas.DataFrame
            Data whose columns include the selected feature ids.

        Returns
        -------
        pandas.DataFrame
            The selected columns, ordered from highest to lowest score.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        self._require_fitted()

        # the scores are already sorted, so the first n index entries are the
        # top hits; slicing the index is positional whatever the label type
        ind = self.scores.index[:self.n].values
        out_feat = feat.loc[:, ind]

        # hand back an independent copy of the selection
        return out_feat.copy()

    def fit_transform(self, feat, lab):
        """Fit on ``feat`` / ``lab`` and return the transformed ``feat``."""
        self.fit(feat, lab)
        return self.transform(feat)

    def plot(self):
        """Plot the sorted scores with a red vertical line at the cut-off ``n``.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        # the notebook has no explicit pyplot import, so bring it in here to
        # keep this method self-contained
        import matplotlib.pyplot as plt

        self._require_fitted()

        # plot sorted scores against their rank
        ranked = self.scores['scores']
        x = np.arange(len(ranked))
        plt.plot(x, ranked)
        plt.vlines(x=self.n, ymin=0, ymax=max(ranked), colors='red')

In [10]:
# load model and associated data
# NOTE(review): absolute path on the original author's machine - adjust before running elsewhere
pkl = "/home/sbayliss/Desktop/projects/PHE_Salmonella/notebooks/github/HierarchicalML/optimised_model/optimised_model_data.pkl"

# NOTE(review): unpickling executes arbitrary code - only load model files from a trusted source
with gzip.open(pkl, 'rb') as f:
    # three objects are read back-to-back from the same stream, in this order
    models, train_features, graph = (pickle.load(f) for _ in range(3))

In [11]:
# select the rows listed in train_features, then transpose so the original
# columns (samples) become rows; work on an independent copy
selected_rows = new_features.loc[train_features]
input_converted = selected_rows.T.copy()
print(input_converted.shape)
input_converted

(48, 94860)


pattern_id,1,2,3,4,5,6,7,8,12,13,...,110471,110472,110474,110476,110477,110479,110480,110481,110482,110483
SRR7777155,0,0,0,0,0,0,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777160,0,0,0,1,1,0,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777161,1,0,0,1,1,0,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777165,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777162,0,0,1,0,1,0,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777163,0,0,0,0,1,0,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777172,1,0,0,0,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777167,0,0,0,0,1,0,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777170,0,1,1,0,1,0,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
SRR7777178,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0


## classify samples

In [12]:
# classify the new samples and report the wall-clock time taken
timeb4 = time.time()

classification_table_test, classifications_test = classify_samples_in_hierarchy(
    graph,
    input_converted,
    models,
    mode=mode,
    threshold=threshold,
)

timeafter = time.time()
total_seconds = timeafter - timeb4
print("classification took ", total_seconds, " seconds")

classification took  0.5498299598693848  seconds


In [13]:
# per class statistics, comparing expected labels with the classifications
summary_test, summary_table_test = summary_statistics_per_class(
    graph,
    new_labels,
    classifications_test,
    penalty=False,
    macro_inc_all=False,
)

# report the macro entry first, then the micro entry
for average in ('macro', 'micro'):
    print(summary_test[average])

{'hR': 0.9907407407407408, 'hP': 0.9907407407407408, 'F1': 0.9907407407407408}
{'hR': 0.9736842105263158, 'hP': 0.9487179487179487, 'F1': 0.9610389610389611}


In [14]:
# convert the per-class summary into a table

# create df: one row per class, one column per reported value
col_names = ["hP", "hR", "hF1", "n", "nP", "nT", "root_dist"]
per_class = summary_test['per_class']
class_table = pd.DataFrame(columns=col_names, index=per_class)

# fill table
for class_name in per_class:
    for stat in col_names:
        class_table.loc[class_name, stat] = per_class[class_name][stat]

# sort on root_dist, with class names as the tie-breaker.
# The second sort must be stable to keep the name order within equal
# root_dist values; the default 'quicksort' gives no such guarantee.
class_table.sort_index(inplace=True)
class_table.sort_values(by=['root_dist'], kind='mergesort', inplace=True)

In [15]:
# save per class hierarchical summary values as a tab-separated file
per_class_file = "%sper_class.tsv" % output_file
class_table.to_csv(per_class_file, sep="\t", header=True)

In [16]:
# overall hierarchical statistics across all samples
Hsummary = overall_summary_stats(
    new_labels,
    classifications_test,
    graph,
    penalty=False,
)
print(Hsummary)

{'hR': 0.9722222222222222, 'hP': 0.9722222222222222, 'hF1': 0.9722222222222222, 'hAcc': 0.9722222222222223}


In [17]:
# classification per sample: the expected label (column renamed to "labels")
# side by side with the classification table, rows in the order of new_labels
expected = raw_labels.rename(columns={1: "labels"})
side_by_side = pd.concat([expected, classification_table_test], axis=1)
classification_table_renamed = side_by_side.loc[new_labels.index]

per_sample_file = "%sper_sample.tsv" % output_file
classification_table_renamed.to_csv(per_sample_file, sep="\t", header=True)

classification_table_renamed

Unnamed: 0,labels,classification,Americas,Europe,Africa,Asia,Latin America and the Caribbean,Northern America,Eastern Europe,Western Europe,...,Turkey,United Arab Emirates,India,Pakistan,Sri lanka,Indonesia,Malaysia,Singapore,Thailand,Vietnam
SRR7777155,Singapore,Indonesia,0.043333,0.2325,0.278262,0.445905,,,,,...,,,,,,0.340833,0.142833,0.264333,0.202,0.05
SRR7777160,Singapore,Thailand,0.0,0.0175,0.02,0.9625,,,,,...,,,,,,0.1285,0.310167,0.08,0.461333,0.02
SRR7777161,Singapore,Singapore,0.01,0.108433,0.152472,0.729095,,,,,...,,,,,,0.1285,0.232833,0.4575,0.141167,0.04
SRR7777165,Singapore,Malaysia,0.0,0.04,0.0,0.96,,,,,...,,,,,,0.025,0.6675,0.231667,0.065833,0.01
SRR7777162,Singapore,Singapore,0.0,0.151381,0.077238,0.771381,,,,,...,,,,,,0.2435,0.238667,0.276833,0.201,0.04
SRR7777163,Singapore,Singapore,0.01,0.126381,0.112976,0.750643,,,,,...,,,,,,0.221833,0.206167,0.393333,0.178667,0.0
SRR7777172,Singapore,Singapore,0.0,0.036667,0.078333,0.885,,,,,...,,,,,,0.081,0.246167,0.536,0.136833,0.0
SRR7777167,Singapore,Singapore,0.01,0.209833,0.126643,0.653524,,,,,...,,,,,,0.220167,0.122833,0.429333,0.187667,0.04
SRR7777170,Singapore,Singapore,0.01,0.045,0.033429,0.911571,,,,,...,,,,,,0.109333,0.268667,0.533,0.059,0.03
SRR7777178,Singapore,Singapore,0.002222,0.002222,0.01,0.985556,,,,,...,,,,,,0.083,0.104,0.6355,0.1775,0.0


In [18]:
# save the overall hierarchical summary values, one "key <tab> value" line each
hsum_file = "%shsummary.tab" % output_file

with open(hsum_file, 'w') as f:
    for i in Hsummary:
        # write straight to the file instead of swapping sys.stdout: if an
        # exception occurred mid-loop, the swap left stdout pointing at a
        # closed file. The bytes written are identical.
        print(i, "\t", Hsummary[i], file=f)

In [19]:
# standard (non-hierarchical) assessment metrics, computed per node
test_summary_per_node, test_summary_per_class, test_clf_tables = per_node_summary_stats(
    graph,
    new_labels,
    input_converted,
    models,
    verbose=True,
)

# save the per-class table as a tab-separated file
nonhier_file = "%snonhier_per_class.tsv" % output_file
test_summary_per_class.to_csv(nonhier_file, sep="\t", header=True)

classes in model:['Africa' 'Americas' 'Asia' 'Europe']
classes in model:['Northern Africa' 'Sub-Saharan Africa']
classes in model:['Latin America and the Caribbean' 'Northern America']
classes in model:['Eastern Asia' 'South-eastern Asia' 'Southern Asia' 'Western Asia']


  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)


classes in model:['Eastern Europe' 'Southern Europe' 'Western Europe']
classes in model:['Bulgaria' 'Czech republic' 'Hungary' 'Poland' 'Russian federation']
classes in model:['Barbados' 'Cuba' 'Dominica' 'Dominican republic' 'Jamaica' 'Mexico']
classes in model:['Egypt' 'Morocco' 'Tunisia']
classes in model:['Indonesia' 'Malaysia' 'Singapore' 'Thailand' 'Vietnam']
classes in model:['India' 'Pakistan' 'Sri lanka']
classes in model:['Greece' 'Italy' 'Malta' 'Portugal' 'Spain']
classes in model:['Cape verde' 'Kenya' 'South africa' 'Tanzania']
classes in model:['Cyprus' 'Saudi arabia' 'Turkey' 'United Arab Emirates']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
