In [1]:
import os
import sys
import time

import random

import numpy as np
import pandas as pd

from sklearn import metrics

In [2]:
from sklearn.ensemble import RandomForestClassifier

In [3]:
from utils_extracting_encoded_data import _utils_simplified

In [4]:
# Seed the global NumPy and stdlib RNGs up front so any stochastic step that
# relies on them is repeatable. (The random forests below additionally pin
# random_state=0 explicitly.)
np.random.seed(1216)
random.seed(1000)

In [5]:
# Run configuration: HLA class, directory holding the train/valid split, and
# the sequence-encoding method passed to the loader.
hla_class, data_dir, enc_method = (
    "HLA_I",
    "../data/HLA_I_all_match/train_valid",
    "one_hot",
)

In [6]:
# Load flattened encodings for the train and validation splits. The loader
# returns one 4-tuple per split: (features, labels, n_pos, n_neg) — the last
# two are presumably class counts; confirm in _utils_simplified. The trailing
# False is passed through unchanged.
train_split, valid_split = _utils_simplified.get_data_flatten(
    hla_class, data_dir, enc_method, False
)
encoded_train, y2_train, n_pos_train, n_neg_train = train_split
encoded_valid, y2_valid, n_pos_valid, n_neg_valid = valid_split

In [7]:
# Dimensions of the flattened training feature matrix.
np.shape(encoded_train)

(23124, 2044)

In [8]:
# Dimensions of the flattened validation feature matrix.
np.shape(encoded_valid)

(7710, 2044)

In [9]:
# scikit-learn estimators expect a 1-D label vector; flatten the training
# labels (presumably stored as a single column — see the sum below).
label_train = np.ravel(y2_train)
label_train

array([0, 0, 0, ..., 1, 0, 0])

In [10]:
# Summing the label rows gives the number of positive (label 1) training
# examples, assuming labels are 0/1 as displayed above.
n_positive_train = sum(y2_train)
n_positive_train

array([3854])

In [11]:
# Held-out test set: positives and negatives live in separate CSV files and are
# encoded with the same settings as the train/valid data (positives first).
test_pos_file = "../data/HLA_I_all_match/test/test_pos.csv"
test_neg_file = "../data/HLA_I_all_match/test/test_neg.csv"

encoded_test_pos, encoded_test_neg = (
    _utils_simplified.get_pure_data_flatten(hla_class, path, enc_method, False)
    for path in (test_pos_file, test_neg_file)
)

In [12]:
# Sweep the forest size in steps of TREE_STEP trees and stop once validation
# AUC has failed to beat its best value for PATIENCE consecutive steps.
# Train AUC is recorded as well, but a random forest scored on its own training
# data is typically (near-)perfect, so it says little about generalization.
TREE_STEP = 100
PATIENCE = 2
MAX_STEPS = 49

# Inverse-frequency class weights: each class is weighted by the OTHER class's
# count. Derived from the labels rather than hard-coding the counts.
n_pos = int(label_train.sum())
class_weight = {0: n_pos, 1: label_train.size - n_pos}

# Loop-invariant: flatten the validation labels once.
label_valid = y2_valid.ravel()

train_auc = []
valid_auc = []

cnt = 0
best_valid_auc = 0
best_n_trees = None

# warm_start=True keeps already-fitted trees and only fits the newly requested
# ones on each fit() call. scikit-learn advances the RNG for the existing trees,
# so with a fixed random_state the forest matches one fit from scratch with the
# same n_estimators — same results, without refitting every tree at each step.
clf = RandomForestClassifier(class_weight=class_weight,
                             n_estimators=TREE_STEP,
                             random_state=0,
                             warm_start=True)

for i in range(1, MAX_STEPS + 1):
    n_trees = TREE_STEP * i
    print("n_trees:", n_trees)

    clf.set_params(n_estimators=n_trees)
    clf.fit(encoded_train, label_train)

    # Column 1 of predict_proba is the probability of class 1 (classes_ = [0, 1]).
    y_hat_train = clf.predict_proba(encoded_train)
    cur_auc = metrics.roc_auc_score(label_train, y_hat_train[:, 1])
    print("train AUC")
    print(cur_auc)
    train_auc.append(cur_auc)

    y_hat_valid = clf.predict_proba(encoded_valid)
    cur_auc = metrics.roc_auc_score(label_valid, y_hat_valid[:, 1])
    print("valid AUC")
    print(cur_auc)
    valid_auc.append(cur_auc)

    if cur_auc > best_valid_auc:
        best_valid_auc = cur_auc
        best_n_trees = n_trees
        cnt = 0
    else:
        cnt += 1
        if cnt == PATIENCE:
            print(f"Validation auc hasn't improved for {PATIENCE} steps of {TREE_STEP} additional trees. ")
            break

print("best n_trees:", best_n_trees, "best valid AUC:", best_valid_auc)

n_trees: 100
train AUC
1.0
valid AUC
0.7773493012763251
n_trees: 200
train AUC
1.0
valid AUC
0.7786599040106588
n_trees: 300
train AUC
1.0
valid AUC
0.7812337507002376
n_trees: 400
train AUC
1.0
valid AUC
0.7806387984678043
n_trees: 500
train AUC
1.0
valid AUC
0.7807370285696982
Validation auc hasn't improved for 2 steps of 100 additional trees. 


In [13]:
train_auc

[1.0, 1.0, 1.0, 1.0, 1.0]

In [14]:
valid_auc

[0.7773493012763251,
 0.7786599040106588,
 0.7812337507002376,
 0.7806387984678043,
 0.7807370285696982]

### Refit the best model: use the tree count with the highest validation AUC (the sweep stopped after validation AUC failed to improve for two consecutive 100-tree additions)

In [15]:
# Refit using the tree count that achieved the best validation AUC in the sweep
# above instead of a value copied by hand from its printed output. The sweep
# adds 100 trees per step, so list index k corresponds to 100 * (k + 1) trees;
# np.argmax returns the first maximum, matching the sweep's strict ">" update.
n_trees = 100 * (int(np.argmax(valid_auc)) + 1)

print("n_trees:", n_trees)

# Inverse-frequency class weights (each class weighted by the other class's
# count), computed from the labels rather than hard-coded.
n_pos = int(label_train.sum())
clf = RandomForestClassifier(class_weight={0: n_pos, 1: label_train.size - n_pos},
                             n_estimators=n_trees,
                             random_state=0)

clf.fit(encoded_train, label_train)

n_trees: 300


RandomForestClassifier(class_weight={0: 3854, 1: 19270}, n_estimators=300,
                       random_state=0)

In [16]:
# Score the selected model on every split. Only column 1 of predict_proba
# (probability of class 1) is used for ROC AUC.
y_hat_train = clf.predict_proba(encoded_train)

y_hat_valid = clf.predict_proba(encoded_valid)
label_valid = y2_valid.ravel()

# The test set comes as separate positive/negative files: score each, then
# stack them with matching labels (positives first as 1, negatives as 0).
y_hat_test_pos = clf.predict_proba(encoded_test_pos)
y_hat_test_neg = clf.predict_proba(encoded_test_neg)

y_test = np.concatenate((np.ones(encoded_test_pos.shape[0], dtype=int),
                         np.zeros(encoded_test_neg.shape[0], dtype=int)))
y_hat_test = np.concatenate((y_hat_test_pos[:, 1], y_hat_test_neg[:, 1]))

# For binary labels roc_auc_score equals auc() over roc_curve(), so one call
# per split replaces the three copy-pasted roc_curve/auc blocks.
for split_name, split_labels, split_scores in (("train", label_train, y_hat_train[:, 1]),
                                               ("valid", label_valid, y_hat_valid[:, 1]),
                                               ("test", y_test, y_hat_test)):
    cur_auc = metrics.roc_auc_score(split_labels, split_scores)
    print(f"{split_name} AUC")
    print(cur_auc)

train AUC
1.0
valid AUC
0.7812337507002376
test AUC
0.772590825496647


In [17]:
# External Szeto 2020 pairs, encoded with the same settings as the test files.
szeto_file = "../data/Szeto_2020/HLA_I_szeto_2020_compatible_pairs.csv"
encoded_szeto = _utils_simplified.get_pure_data_flatten(
    hla_class, szeto_file, enc_method, False
)
np.shape(encoded_szeto)

(54, 2044)

In [18]:
# Treat every Szeto pair as a positive and reuse the held-out test negatives:
# labels are 1 for Szeto rows followed by 0 for negative rows, and scores are
# stacked in the same order.
y_hat_szeto = clf.predict_proba(encoded_szeto)

n_szeto_rows = encoded_szeto.shape[0]
n_test_neg_rows = encoded_test_neg.shape[0]
y_test_szeto = np.array([1] * n_szeto_rows + [0] * n_test_neg_rows)
y_hat_test_szeto = np.concatenate((y_hat_szeto[:, 1], y_hat_test_neg[:, 1]))

In [19]:
# ROC AUC of Szeto positives vs. held-out test negatives.
fpr, tpr, thresholds = metrics.roc_curve(y_test_szeto, y_hat_test_szeto, pos_label=1)
cur_auc = metrics.auc(fpr, tpr)
print("szeto v.s. negative test AUC", cur_auc, sep="\n")

szeto v.s. negative test AUC
0.6785897657782393


In [20]:
# Raw Szeto table (comma-separated, header on the first line).
df_szeto = pd.read_csv(szeto_file, header=0, sep=",")
len(df_szeto.index), len(df_szeto.columns)

(54, 2)

In [33]:
# Derive per-row columns from the comma-separated "tcr" field: the first field
# becomes v_allele and the length of the second field becomes len. The model's
# class-1 probability is attached as random_forest_score.
tcr_fields = [x.split(",") for x in df_szeto.tcr.tolist()]
df_szeto['v_allele'] = [fields[0] for fields in tcr_fields]
df_szeto["len"] = [len(fields[1]) for fields in tcr_fields]
df_szeto['random_forest_score'] = y_hat_szeto[:, [1]].ravel()

In [34]:
df_test_pos = pd.read_csv(test_pos_file, sep=",", header=0)
df_test_neg = pd.read_csv(test_neg_file, sep=",", header=0)

# Add the same derived columns to both test frames (previously copy-pasted per
# frame): v_allele is the first comma-separated field of "tcr", len is the
# length of the second field, and random_forest_score is the model's class-1
# probability for that row. Each tcr string is split once instead of twice.
for df_test, y_hat_split in ((df_test_pos, y_hat_test_pos),
                             (df_test_neg, y_hat_test_neg)):
    tcr_fields = [x.split(",") for x in df_test.tcr.tolist()]
    df_test["v_allele"] = [fields[0] for fields in tcr_fields]
    df_test["len"] = [len(fields[1]) for fields in tcr_fields]
    df_test["random_forest_score"] = y_hat_split[:, [1]].ravel()

In [37]:
# Persist the per-row random forest scores, one CSV per dataset, without the
# pandas index column.
for out_path, frame in (("../results/st3_szeto_random_forest.csv", df_szeto),
                        ("../results/st3_test_pos_random_forest.csv", df_test_pos),
                        ("../results/st3_test_neg_random_forest.csv", df_test_neg)):
    frame.to_csv(out_path, index=False)