# classify new data using pre-run model

In [1]:
# dependencies 
import networkx as nx
import numpy as np
import pandas as pd

import copy 
import sys
import gzip
import pickle
import time

from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold, cross_validate
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix, classification_report

from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel

# import hierarchical classification package
sys.path.append( '../../HC_package/' )
from HierarchicalClassification import *

In [2]:
# run configuration

# inputs
data_file = "patterns.tab"          # feature table, read into raw_features
class_file = "sample.location"      # expected class per sample, read into raw_labels
include_file = "sample_list.txt"    # list of samples to classify, read into include

# prefix prepended to the name of every file written by this notebook
output_file = "output_"

# NOTE(review): seed and n_jobs are not referenced by the cells shown here
seed = 34
n_jobs = -1

In [3]:
# classification mode and threshold

# 'max': use the maximum predicted class probability
mode = 'max'

# 'adaptive': the minimum threshold is 1 / (number of classes in the model);
# a fixed value such as 0.5 can be given instead
threshold = 'adaptive'

## load dataset and expected classes

In [4]:
# features: tab-separated table indexed by its "pattern_id" column.
# The column labels are compared against the include list below, i.e. each
# column is expected to be one sample.
raw_features  = pd.read_csv(data_file, sep = "\t", index_col = "pattern_id")

In [5]:
# both files are header-less, tab-separated and indexed by their first column
headerless_tsv = {"sep": "\t", "header": None, "index_col": 0}

# expected class per sample (the label is read from column 1 further down)
raw_labels = pd.read_csv(class_file, **headerless_tsv)

# samples to include - only the index is used
include = pd.read_csv(include_file, **headerless_tsv)

In [6]:
# requested samples that actually have a column in the feature table,
# kept in the order of the include list
inc = include.index.values
in_feat = raw_features.columns

intersect = [sample for sample in inc if sample in in_feat]

print("original samples to include:", len(inc))
print("samples present in data:", len(intersect))

original samples to include: 35
samples present in data: 35


In [7]:
# restrict the feature table to the included samples (columns), in the order
# of `intersect`; the labels are aligned to these columns in the next cell
new_features = raw_features.loc[:, intersect]

In [8]:
# align the labels with the feature columns: an empty frame indexed by the
# sample ids is left-joined with the label table, so the result follows the
# column order of new_features (a sample without a label row would get NaN)
lab_ind = pd.DataFrame(index=new_features.columns)
labels_in_feature_order = lab_ind.join(raw_labels, how="left")
new_labels = labels_in_feature_order[1]

# sanity check: number of samples must agree between features and labels
for table in (new_features, new_labels):
    print(table.shape)

(94860, 35)
(35,)


## load required prefiltering methods 

In [9]:
# random forest to select features based on GINI
class RF_FS:
    """Feature selector keeping the ``n`` features a random forest ranks highest.

    ``fit`` trains a ``RandomForestClassifier`` and stores its
    ``feature_importances_`` (impurity based, i.e. Gini with the default
    criterion) per input column, sorted from most to least important.
    ``transform`` then returns the columns of the ``n`` best features.

    NOTE(review): this class is defined right before the pickled models are
    loaded - presumably so that pickled objects referring to ``RF_FS`` can be
    restored. For that reason no new instance attributes are introduced here;
    confirm before changing the attribute layout.
    """

    def __init__(self, n=100, n_estimators=1000, random_state=0, n_jobs=-1):
        """Store the selector settings.

        n            -- number of top-ranked features kept by ``transform``
        n_estimators -- number of trees of the ranking forest
        random_state -- seed passed to the forest
        n_jobs       -- parallel jobs passed to the forest
        """
        self.name = 'RF_FS'
        # set by fit(): frame with one 'scores' column, indexed by feature id,
        # sorted by descending importance; None until fit() has been called
        self.scores = None
        self.n = n
        self.n_estimators = n_estimators
        self.n_jobs = n_jobs
        self.random_state = random_state

    def classifier(self):
        """Return the name of this selector."""
        return self.name

    def _require_fitted(self):
        """Raise a clear error when the importance scores are not available."""
        if self.scores is None:
            raise RuntimeError(
                "RF_FS has not been fitted: call fit() before transform() or plot()"
            )

    def fit(self, feat, lab):
        """Rank the columns of ``feat`` by random-forest importance.

        feat -- feature table; its ``columns`` are used as feature ids
        lab  -- class labels, one per row of ``feat``

        Returns ``self`` so calls can be chained.
        """
        # fit model
        fs = RandomForestClassifier(
            n_estimators=self.n_estimators,
            random_state=self.random_state,
            n_jobs=self.n_jobs,
        )
        fs.fit(feat, lab)

        # score / feature id table, best feature first
        self.scores = pd.DataFrame(index=feat.columns.values)
        self.scores['scores'] = copy.copy(fs.feature_importances_)
        self.scores.sort_values(by=['scores'], ascending=False, inplace=True)

        return self

    def transform(self, feat):
        """Return a copy of ``feat`` reduced to the ``n`` top-ranked features.

        Columns are returned in rank order (most important first). When fewer
        than ``n`` features were scored, all of them are returned.

        Raises RuntimeError when called before ``fit``.
        """
        self._require_fitted()

        # .iloc makes the selection explicitly positional, whatever the dtype
        # of the feature ids used as index labels
        ind = self.scores['scores'].iloc[:self.n].index.values
        out_feat = feat.loc[:, ind]

        return copy.copy(out_feat)

    def fit_transform(self, feat, lab):
        """Fit on ``feat``/``lab`` and return ``feat`` reduced to the top features."""
        self.fit(feat, lab)
        return self.transform(feat)

    def plot(self):
        """Plot the sorted importance scores with a red line at the cut-off ``n``.

        Raises RuntimeError when called before ``fit``.

        NOTE(review): ``plt`` is not imported by the import cell of this
        notebook; presumably matplotlib.pyplot arrives through the star import
        of HierarchicalClassification - confirm, otherwise this raises NameError.
        """
        self._require_fitted()

        # plot sorted scores
        x = np.arange(len(self.scores['scores']))
        plt.plot(x, self.scores['scores'])
        plt.vlines(x=self.n, ymin=0, ymax=max(self.scores['scores']),
                   colors='red')

In [10]:
# load model and associated data
# NOTE(review): this is an absolute, machine-specific path - point it at the
# location of the pickled model before running the notebook elsewhere.
# SECURITY: unpickling can execute arbitrary code; only load model files that
# come from a trusted source.
pkl = "/home/sbayliss/Desktop/projects/PHE_Salmonella/notebooks/github/HierarchicalML/optimised_model/optimised_model_data.pkl"
with gzip.open(pkl, 'rb') as f:
    
    # three objects are read back to back from the same stream, so the order
    # of these calls must match the order in which they were written
    models = pickle.load(f)  # handed to the classification and per-node summary calls below
    train_features = pickle.load(f)  # used below to select rows (patterns) of new_features
    graph = pickle.load(f)  # presumably the class hierarchy (networkx is imported) - confirm

In [11]:
# bring the new data into the layout used for classification:
# keep the rows (patterns) listed in train_features, then transpose so that
# samples become rows and patterns become columns
# NOTE(review): assumes every entry of train_features is a pattern_id present
# in new_features - confirm for new datasets
training_patterns = new_features.loc[train_features]
input_converted = copy.copy(training_patterns.T)
print(input_converted.shape)
input_converted

(35, 94860)


pattern_id,1,2,3,4,5,6,7,8,12,13,...,110471,110472,110474,110476,110477,110479,110480,110481,110482,110483
ERR2278721,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278723,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278728,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278729,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278733,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278739,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278742,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278743,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278749,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0
ERR2278713,1,1,1,1,1,1,1,1,0,1,...,0,0,1,0,0,0,0,0,0,0


## classify samples

In [12]:
# classify the new samples and report the wall-clock time taken
timeb4 = time.time()

# yields a per-sample table and the classifications used by the summary cells
classification_table_test, classifications_test = classify_samples_in_hierarchy(
    graph,
    input_converted,
    models,
    mode=mode,
    threshold=threshold,
)

timeafter = time.time()
total_seconds = timeafter - timeb4
print("classification took ", total_seconds, " seconds")

classification took  0.7167212963104248  seconds


In [13]:
# hierarchical statistics per class, comparing expected labels with the
# classifications obtained above
summary_test, summary_table_test = summary_statistics_per_class(
    graph,
    new_labels,
    classifications_test,
    penalty=False,
    macro_inc_all=False,
)

# show the macro- and micro-averaged entries of the summary
for averaging in ('macro', 'micro'):
    print(summary_test[averaging])

{'hR': 0.8015873015873017, 'hP': 0.8015873015873017, 'F1': 0.8015873015873017}
{'hR': 0.7213114754098361, 'hP': 0.5641025641025641, 'F1': 0.6330935251798561}


In [14]:
# convert the per-class summary into a table (one row per class)

# create df - columns are the statistics copied from each per-class entry
col_names = ["hP", "hR", "hF1", "n", "nP", "nT", "root_dist"]
class_table =  pd.DataFrame(
    columns = col_names,
    index = summary_test['per_class'])

# fill table cell by cell from summary_test['per_class'][class][statistic]
for i in summary_test['per_class']:
    for j in col_names:
        class_table.loc[i, j] = summary_test['per_class'][i][j]

# sort on root_dist, with classes of equal root_dist ordered by name:
# sort by name first, then by root_dist with a STABLE algorithm - the default
# quicksort does not guarantee that the name order survives the second sort
class_table.sort_index(inplace=True)
class_table.sort_values(by=['root_dist'], kind='mergesort', inplace=True)

In [15]:
# write the per-class hierarchical summary table as a tab-separated file
per_class_path = "%sper_class.tsv" % output_file
class_table.to_csv(per_class_path, sep="\t", header=True)

In [16]:
# overall hierarchical statistics over all classified samples
Hsummary = overall_summary_stats(
    new_labels,
    classifications_test,
    graph,
    penalty=False,
)
print(Hsummary)

{'hR': 0.6761904761904762, 'hP': 0.6761904761904762, 'hF1': 0.6761904761904762, 'hAcc': 0.6761904761904761}


In [17]:
# per-sample results: expected label side by side with the classifier output
expected_labels = raw_labels.rename(columns={1: "labels"})
labels_and_results = pd.concat([expected_labels, classification_table_test], axis=1)

# keep only the classified samples, in the order of new_labels
classification_table_renamed = labels_and_results.loc[new_labels.index]

per_sample_path = "%sper_sample.tsv" % output_file
classification_table_renamed.to_csv(per_sample_path, sep="\t", header=True)

classification_table_renamed

Unnamed: 0,labels,classification,Americas,Europe,Africa,Asia,Latin America and the Caribbean,Northern America,Eastern Europe,Western Europe,...,Turkey,United Arab Emirates,India,Pakistan,Sri lanka,Indonesia,Malaysia,Singapore,Thailand,Vietnam
ERR2278721,Poland,Spain,0.006667,0.7985,0.0,0.194833,,,0.325,0.0,...,,,,,,,,,,
ERR2278723,Poland,Poland,0.0,0.921889,0.0,0.078111,,,0.511667,0.02,...,,,,,,,,,,
ERR2278728,Poland,Poland,0.0,0.879964,0.006667,0.113369,,,0.585,0.06,...,,,,,,,,,,
ERR2278729,Poland,Poland,0.0,0.8805,0.0,0.1195,,,0.525,0.04,...,,,,,,,,,,
ERR2278733,Poland,Spain,0.003333,0.990833,0.0,0.005833,,,0.1,0.01,...,,,,,,,,,,
ERR2278739,Poland,Poland,0.0,0.9025,0.0,0.0975,,,0.631667,0.03,...,,,,,,,,,,
ERR2278742,Poland,Poland,0.0,0.89925,0.0,0.10075,,,0.56,0.04,...,,,,,,,,,,
ERR2278743,Poland,Poland,0.0,0.901952,0.0,0.098048,,,0.75,0.02,...,,,,,,,,,,
ERR2278749,Poland,Poland,0.01,0.880536,0.014286,0.095179,,,0.79,0.03,...,,,,,,,,,,
ERR2278713,Poland,Poland,0.0,0.596028,0.09225,0.311722,,,0.67,0.0,...,,,,,,,,,,


In [18]:
# save the overall hierarchical summary statistics
hsum_file = "%shsummary.tab" % output_file

# One line per statistic: name, tab, value (print's default separator adds a
# space on either side of the tab, as in the files written previously).
# Printing straight to the file handle avoids swapping sys.stdout, which would
# stay redirected to a closed file if anything failed inside the loop.
with open(hsum_file, 'w') as f:
    for i in Hsummary:
        print(i, "\t", Hsummary[i], file=f)

In [19]:
# standard (non-hierarchical) assessment metrics, computed per node of the hierarchy
test_summary_per_node, test_summary_per_class, test_clf_tables = per_node_summary_stats(
    graph,
    new_labels,
    input_converted,
    models,
    verbose=True,
)

# only the per-class table is written to disk
nonhier_path = "%snonhier_per_class.tsv" % output_file
test_summary_per_class.to_csv(nonhier_path, sep="\t", header=True)

classes in model:['Africa' 'Americas' 'Asia' 'Europe']
classes in model:['Northern Africa' 'Sub-Saharan Africa']
classes in model:['Latin America and the Caribbean' 'Northern America']
classes in model:['Eastern Asia' 'South-eastern Asia' 'Southern Asia' 'Western Asia']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


classes in model:['Eastern Europe' 'Southern Europe' 'Western Europe']
classes in model:['Bulgaria' 'Czech republic' 'Hungary' 'Poland' 'Russian federation']
classes in model:['Barbados' 'Cuba' 'Dominica' 'Dominican republic' 'Jamaica' 'Mexico']
classes in model:['Egypt' 'Morocco' 'Tunisia']
classes in model:['Indonesia' 'Malaysia' 'Singapore' 'Thailand' 'Vietnam']
classes in model:['India' 'Pakistan' 'Sri lanka']
classes in model:['Greece' 'Italy' 'Malta' 'Portugal' 'Spain']
classes in model:['Cape verde' 'Kenya' 'South africa' 'Tanzania']
classes in model:['Cyprus' 'Saudi arabia' 'Turkey' 'United Arab Emirates']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
