<a href="https://colab.research.google.com/github/Cathy-Guang/TB-drug-resistance/blob/main/SLE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# example of a super learner model for binary classification
from numpy import hstack
from numpy import vstack
from numpy import asarray
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from xgboost import XGBClassifier

In [14]:
# Build the list of base (level-0) models used by the super learner
def get_models():
  """Return newly constructed, unfitted base classifiers.

  Returns:
    A two-element list: an RBF-kernel ``SVC`` with probability estimates
    enabled, followed by an ``XGBClassifier``. New instances are created
    on every call.
  """
  svm_rbf = SVC(kernel='rbf',probability=True,random_state=0)
  boosted_trees = XGBClassifier(max_depth=9, eta = 0.4, gamma = 3)
  return [svm_rbf, boosted_trees]

In [3]:
# collect out-of-fold predictions from k-fold cross-validation
def get_out_of_fold_predictions(X, y, models, n_splits=10, random_state=0):
    """Build the meta-model training set from out-of-fold base-model predictions.

    For every fold, each model is refit on the training split and asked for
    ``predict_proba`` on the held-out split. The per-model outputs are
    concatenated column-wise, and the folds are then stacked row-wise.

    Args:
        X: Feature matrix; converted with ``asarray`` so array-likes that do
            not support integer-array indexing are accepted as well.
        y: Target vector aligned with the rows of ``X``.
        models: Iterable of estimators exposing ``fit`` and ``predict_proba``.
            They are fitted in place, so on return they hold the fit from the
            last fold, not a fit on the full data.
        n_splits: Number of cross-validation folds (default 10, as before).
        random_state: Seed for the shuffled split. A fixed default makes the
            meta-features reproducible between runs; pass ``None`` to get the
            previous unseeded shuffling.

    Returns:
        ``(meta_X, meta_y)``: the stacked out-of-fold predictions and the
        matching targets. Rows follow fold order, not the original row order
        of ``X``; ``meta_X`` and ``meta_y`` stay aligned with each other.
    """
    X, y = asarray(X), asarray(y)
    meta_X, meta_y = list(), list()
    # Seeded shuffle: without a random_state every run yields different folds
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    for train_ix, test_ix in kfold.split(X):
        train_X, test_X = X[train_ix], X[test_ix]
        train_y, test_y = y[train_ix], y[test_ix]
        meta_y.extend(test_y)
        # One block of predict_proba columns per base model for this fold
        fold_yhats = list()
        for model in models:
            model.fit(train_X, train_y)
            fold_yhats.append(model.predict_proba(test_X))
        # Put the models' outputs side by side for the held-out rows
        meta_X.append(hstack(fold_yhats))
    return vstack(meta_X), asarray(meta_y)

In [4]:
# Train every base model on the complete training set
def fit_base_models(X, y, models):
    """Fit each estimator in ``models`` in place on ``(X, y)``.

    Args:
        X: Training features passed unchanged to every ``fit`` call.
        y: Training targets passed unchanged to every ``fit`` call.
        models: Iterable of objects exposing ``fit(X, y)``.

    Returns:
        None. The estimators themselves are updated.
    """
    for estimator in models:
        estimator.fit(X, y)

In [5]:
# Train the meta (level-1) model
def fit_meta_model(X, y):
    """Fit and return a liblinear ``LogisticRegression`` on ``(X, y)``.

    Args:
        X: Meta-features to train on.
        y: Targets aligned with the rows of ``X``.

    Returns:
        The fitted ``LogisticRegression`` instance.
    """
    # fit() hands back the estimator itself, so it can be returned directly
    return LogisticRegression(solver='liblinear').fit(X, y)

In [6]:
# Predict with the stacked ensemble
def super_learner_predictions(X, models, meta_model):
    """Run the base models on ``X`` and feed their outputs to the meta-model.

    Args:
        X: Samples to score; passed unchanged to every base model.
        models: Iterable of fitted estimators exposing ``predict_proba``.
        meta_model: Fitted estimator exposing ``predict_proba``; it receives
            the base models' outputs concatenated column-wise, in the order
            the models are given.

    Returns:
        Whatever ``meta_model.predict_proba`` returns for the stacked input.
    """
    base_probabilities = [base.predict_proba(X) for base in models]
    stacked_features = hstack(base_probabilities)
    return meta_model.predict_proba(stacked_features)

In [15]:
from google.colab import drive
import os
import io
import pandas as pd

# Single source of truth for the Drive mount point. Building the data paths
# from the same absolute location keeps them valid even if the notebook's
# working directory is changed after mounting.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = 'TB-drug-resistance/' # change this to whatever folder your code is in
GOOGLE_DRIVE_PATH = os.path.join(DRIVE_MOUNT_POINT, 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)

# Directory that receives the result CSVs generated by this notebook
RESULTS_PATH = os.path.join(GOOGLE_DRIVE_PATH, "results/SLE_results/PCA")

Mounted at /content/drive


In [18]:
# Drugs for which a separate super learner is trained and exported
drugs = ['AMK','CAP','EMB','INH','KAN','MOXI','OFLX','PZA','RIF','STR']

# Create the output folder up front; to_csv does not create missing directories
os.makedirs(RESULTS_PATH, exist_ok=True)

for drug in drugs:
  # Training table: every column except the last is a feature, the last is the label
  dataset = pd.read_csv(os.path.join(GOOGLE_DRIVE_PATH, 'PCA/X_trainData_column_modified_PCA_'+drug+'.csv'))
  X = dataset.iloc[:,:-1].values
  y = dataset.iloc[:,-1].values

  # Fresh base models for each drug
  models = get_models()
  # Out-of-fold predictions become the meta-model's training set
  meta_X, meta_y = get_out_of_fold_predictions(X, y, models)
  print('Meta ', meta_X.shape, meta_y.shape)
  # Refit the base models on all training rows before scoring the test set
  fit_base_models(X, y, models)
  meta_model = fit_meta_model(meta_X, meta_y)

  # Test features: all columns are used as-is
  test_data = pd.read_csv(os.path.join(GOOGLE_DRIVE_PATH, 'test_data/X_testData_column_modified_PCA_'+drug+'.csv'))
  X_test = test_data.values
  y_pred = super_learner_predictions(X_test, models, meta_model)

  # The first column of the unlabeled file supplies the sample IDs
  submission_data = pd.read_csv(os.path.join(GOOGLE_DRIVE_PATH, 'test_data/Y_testData_1_nolabels_'+drug+'.csv'), delimiter=',')
  ids = submission_data.iloc[:,0].values

  # IDs and predictions are paired by row position, so a length mismatch
  # would attach scores to the wrong samples; fail loudly instead
  if len(ids) != len(y_pred):
    raise ValueError(
      'ID/prediction count mismatch for ' + drug + ': '
      + str(len(ids)) + ' ids vs ' + str(len(y_pred)) + ' predictions')

  # Column 1 of predict_proba is exported
  # NOTE(review): presumably the resistant/positive class - confirm label encoding
  output = pd.DataFrame({'ID': ids, drug: y_pred[:,1]})

  # Save results: header row is "ID,<drug>", one row per test sample
  output.to_csv(os.path.join(RESULTS_PATH, drug+'_results.csv'), index=False)



Meta  (1360, 4) (1360,)
Meta  (1339, 4) (1339,)
Meta  (3319, 4) (3319,)
Meta  (3356, 4) (3356,)
Meta  (1283, 4) (1283,)
Meta  (1337, 4) (1337,)
Meta  (690, 4) (690,)
Meta  (2941, 4) (2941,)
Meta  (3335, 4) (3335,)
Meta  (2081, 4) (2081,)
