Purpose: Troubleshoot code for peripheral stress RF.<br>
Author: Anna Pardo<br>
Date initiated: Aug. 24, 2023

In [1]:
import os
import argparse
import random
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import RandomizedSearchCV
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import ClusterCentroids
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import normalize
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import ShuffleSplit
from sklearn.preprocessing import StandardScaler

In [2]:
def load_clean_data(path_to_tpm,single_stress):
    """
    Args:
        path_to_tpm = full path to file containing raw TPM, columns for Sample, BioProject, & Treatment
        single_stress = the stressor that will be class 0
    """
    # load the TPM data
    raw_tpm = pd.read_csv(path_to_tpm,sep="\t",header="infer")
    # replace DroughtRepeat with Drought
    raw_tpm["Treatment"].mask(raw_tpm["Treatment"]=="DroughtRepeat","Drought",inplace=True)
    # drop control samples
    raw_tpm = raw_tpm[raw_tpm["Treatment"]!="Control"]
    # labeling: set single_stress to 0 and any stress to 1
    proxy = []
    for i in range(len(raw_tpm.index)):
        if raw_tpm.iloc[i,raw_tpm.columns.get_loc("Treatment")] == single_stress:
            proxy.append(0)
        else:
            proxy.append(1)
    raw_tpm["Label"] = proxy
    # return the dataframe
    return raw_tpm

def variance_threshold_selector(data):
    selector = VarianceThreshold()
    selector.fit(data)
    return data[data.columns[selector.get_support(indices=True)]]

In [3]:
def check_if_balanced(labeled_tpm):
    """
    Args:
        labeled_tpm = raw TPM with columns for Sample, BioProject, Label, Treatment
    """
    if labeled_tpm["Label"].value_counts()[0] == labeled_tpm["Label"].value_counts()[1]:
        return True
    else:
        return False

def downsample(dataframe):
    """
    Args:
        dataframe = a log TPM dataframe with a Label column and Sample set as the index
    """
    # generate a variable of value counts
    vc = dataframe["Label"].value_counts()

    # subset data to only samples labeled 1
    ones_only = dataframe[dataframe["Label"]==1]

    # downsample from the subsetted dataframe
    ds = ones_only.sample(n=vc[1],random_state=42)

    # subset original data to control samples
    zeroes = dataframe[dataframe["Label"]==0]

    # concatenate controls and downsampled stress samples
    downsampled = pd.concat([ds,zeroes])
    # return dataframe
    return downsampled

In [4]:
def pre_split_transform(raw_tpm,balanced,downsample=False):
    """
    Args:
        raw_tpm = dataframe containing raw TPM values, columns for Sample, BioProject, Treatment, Label
        balanced = Boolean variable, True or False (result of check_if_balanced())
        downsample = Boolean variable, True or False, default False (set manually outside function)
    """
    # if data have treatment column, drop it
    #if "Treatment" in raw_tpm.columns:
    #    raw_tpm = raw_tpm.drop("Treatment",axis=1)
    # temporarily, set index to Sample and drop BioProject, Label, & Treatment columns
    blt = raw_tpm[["Sample","BioProject","Treatment","Label"]]
    tpmi = raw_tpm.set_index("Sample").drop(["BioProject","Treatment","Label"],axis=1)
    # remove zero-variance genes
    vttpm = variance_threshold_selector(tpmi)
    # log-transform TPM
    vttpm_log = vttpm.apply(lambda x: np.log2(x+1))
    # downsample data if needed
    if balanced!=True:
        if downsample==True:
            # add back labels
            vttpm_log = blt[["Sample","Label"]].merge(vttpm_log.reset_index().rename(columns={"index":"Sample"}))
            # set Sample as index
            vttpm_log = vttpm_log.set_index("Sample")
            # downsample the data
            vttpm_log = downsample(vttpm_log)
    # add treatment, labels, and BioProject back in, set Sample as the index again
    labeled = blt.merge(vttpm_log.reset_index().rename(columns={"index":"Sample"}))
    labeled.set_index("Sample",inplace=True)
    # drop rows containing NaN values
    labeled = labeled.dropna(axis=0)
    # return dataframe
    return labeled

In [18]:
def split_prep_peripheral(c0_bpstr,dataframe,balance="Up"):
    """
    Args:
        c0_bpstr = string containing BioProject(s) (delimited with commas if more than one) to hold out for testing for class 0
        dataframe = log TPM dataframe with Sample, Label, BioProject, Treatment columns (or Sample as index)
        balance = str: "none", "Up" (downsampling will be done before splitting, outside of this function)
    """

    # in case Sample isn't already a column, reset the index and rename the column to Sample
    if "Sample" not in dataframe.columns:
        dataframe = dataframe.reset_index().rename(columns={"index":"Sample"})
    # parse the BioProject string for class 0 test set
    if "," in c0_bpstr:
        c0bp = c0_bpstr.split(",")
    else:
        c0bp = [c0_bpstr]
    # find what percentage of total samples for class 0 these BioProjects contain
    bpsamp = dataframe[dataframe["BioProject"].isin(c0bp)]
    if len(bpsamp["Treatment"].unique())>1:
        bpsamp = bpsamp[bpsamp["Label"]==0]
    nsamp_test = len(bpsamp["Sample"].unique())
    nsamp_all = len(dataframe[dataframe["Label"]==0]["Sample"].unique())
    percent = nsamp_test/nsamp_all
    # how many samples in class 1 comprise this percentage of class 1, +/- 5%?
    class1 = dataframe[dataframe["Label"]==1]
    upperp = percent + 0.02
    lowerp = percent - 0.02
    print("Upper percentage:",upperp)
    print("Lower percentage:",lowerp)
    # these are our upper and lower bounds of acceptable sample numbers

    # for class 1, set up a dataframe of numbers of samples in each BioProject
    bp = []
    nsamp = []
    pct = []
    for b in list(class1["BioProject"].unique()):
        bp.append(b)
        df = class1[class1["BioProject"]==b]
        nsamp.append(len(df["Sample"].unique()))
        pct.append(len(df["Sample"].unique())/len(class1["Sample"].unique()))
    nsbp = pd.DataFrame(list(zip(bp,nsamp,pct)),columns=["BioProject","N_samples","Percentage"])

    # pick BioProjects for the testing set, the sum of whose samples fall within the acceptable bounds
    inlist = random.sample(list(nsbp["BioProject"]),len(list(nsbp["BioProject"])))
    bpfortest = []
    provisionalbp = []
    provisionalpct = []
    toohigh = []
    for b in inlist:
        p = float(nsbp.loc[nsbp["BioProject"]==b,"Percentage"])
        if len(provisionalbp)==0:
            print(len(provisionalbp))
            if upperp>p>lowerp:
                bpfortest.append(b)
                break
            elif p<lowerp:
                provisionalbp.append(b)
                provisionalpct.append(p)
            elif p>upperp:
                toohigh.append(b)
        else:
            print(sum(provisionalpct))
            if upperp>sum(provisionalpct)+p>lowerp:
                provisionalbp.append(b)
                bpfortest = provisionalbp
                break
            elif sum(provisionalpct)+p<lowerp:
                provisionalbp.append(b)
                provisionalpct.append(p)
            elif p>upperp:
                toohigh.append(b)
    #print(bpfortest)

    # for the future: find what actual percent of the class 1 data the testing BioProjects comprise
    tc1 = class1[class1["BioProject"].isin(bpfortest)]
    n = len(tc1["Sample"].unique())
    pc1 = (n/len(class1["Sample"].unique()))*100
    # also save the percentage of the class 0 BioProjects
    pc0 = percent*100

    # combine an omnibus list of the BioProjects comprising the test set for both classes
    alltestbp = bpfortest+c0bp

    # split test from train data
    test = dataframe[dataframe["BioProject"].isin(alltestbp)]
    # pull out training data
    train = dataframe[~dataframe["Sample"].isin(test["Sample"])]
    # for both sets, make Sample the index again
    test = test.set_index("Sample")
    train = train.set_index("Sample")
    # drop BioProject and Treatment columns from both sets
    test = test.drop(["BioProject","Treatment"],axis=1)
    train = train.drop(["BioProject","Treatment"],axis=1)
    # generate X_train, X_test, y_train, and y_test
    ## where X = gene expression values and y = class labels
    train_X = train.drop("Label",axis=1)
    y_train = train["Label"]
    test_X = test.drop("Label",axis=1)
    y_test = test["Label"]
    # if upsampling: do the upsampling using SMOTE
    if balance=="Up":
        sm = SMOTE(random_state=42)
        train_X, y_train = sm.fit_resample(train_X,y_train)
    # for X_train and X_test: scale data to a z-score
    scalar = StandardScaler()
    X_train = scalar.fit_transform(train_X)
    X_test = scalar.fit_transform(test_X)
    # return training and test data as well as percentages (for saving with the hyperparameters later)
    return X_train, y_train, X_test, y_test, pc0, pc1

In [6]:
def get_tuned_rf(X_train, y_train, random_grid):
    rf = RandomForestClassifier()
    rf_random = RandomizedSearchCV(estimator=rf,
                                  param_distributions=random_grid,
                                  n_iter=100,
                                  cv=5,
                                  verbose=2,
                                  random_state=42,
                                  n_jobs=-1)
    rf_random.fit(X_train, y_train)
    hyper = rf_random.best_params_
    rfclf_tune = RandomForestClassifier(n_estimators=hyper["n_estimators"],
                                min_samples_split=hyper["min_samples_split"],
                                    min_samples_leaf=hyper["min_samples_leaf"],
                                   max_features=hyper["max_features"],
                                   max_depth=hyper["max_depth"],
                                   bootstrap=hyper["bootstrap"])
    return rfclf_tune

In [7]:
def get_scores(y_test,y_pred,filename,AUC,bal,sampling,pc0,pc1):
    """
    Args:
        y_test = true labels of test data
        y_pred = predicted labels of test data
        filename = file name with extension (JSON)
        AUC = float, AUC score output from ROC curve calculations
        bal = whether the data are balanced (True or False)
        sampling = Up or Down
        pc0 = float, percentage of class 0 data used in test set
        pc1 = float, percentage of class 1 data used in test set
    """
    f1 = list(f1_score(y_test,y_pred,average=None))
    prec = list(precision_score(y_test,y_pred,average=None))
    rec = list(recall_score(y_test,y_pred,average=None))
    # construct dictionary of F1 and accuracy scores, precision, and recall
    scores = {"Accuracy":accuracy_score(y_test,y_pred),"F1_class_0":f1[0],"F1_class_1":f1[1],
            "Precision_class_0":prec[0],"Precision_class_1":prec[1],"Recall_class_0":rec[0],"Recall_class_1":rec[1],"AUC":AUC,"Data Balanced":bal,
            "Sampling":sampling,"Percent class 0 in test":pc0,"Percent class 1 used in test":pc1}
    # write dictionary to a JSON file
    with open(filename,"w+") as outfile:
        json.dump(scores,outfile,indent=4)

In [8]:
# load the data
## use Heat as example class 0
tpm_file = "../../data/rawtpm_bptreat_noPEG.tsv"
single_stress = "Heat"
cleaned_tpm = load_clean_data(tpm_file,single_stress)
cleaned_tpm.head()

Unnamed: 0,Sample,BioProject,Treatment,Zm00001eb000010,Zm00001eb000020,Zm00001eb000050,Zm00001eb000060,Zm00001eb000070,Zm00001eb000080,Zm00001eb000100,...,Zm00001eb442820,Zm00001eb442840,Zm00001eb442850,Zm00001eb442870,Zm00001eb442890,Zm00001eb442910,Zm00001eb442960,Zm00001eb442980,Zm00001eb443030,Label
0,SRR11933261,PRJNA637522,Drought,12.553818,2.321077,0.04252,12.932676,5.253755,11.105837,0.409268,...,0.0,0.0,0.0,0.0,0.309501,0.0,0.0,0.0,0.0,1
1,SRR11933272,PRJNA637522,Drought,16.255838,3.110372,0.405226,7.214039,1.902461,2.346186,0.170305,...,0.127878,0.0,0.0,0.0,6.703281,0.0,0.0,0.0,0.0,1
2,SRR11933250,PRJNA637522,Drought,9.028815,2.984479,0.0,3.092442,2.586555,16.186141,0.0,...,0.0,0.0,0.0,0.0,0.417565,0.0,0.254123,0.0,1.213349,1
4,SRR11933040,PRJNA637522,Drought,10.371251,2.799099,0.0,1.280629,3.771234,19.717683,0.143764,...,0.012158,0.0,0.0,0.0,9.625225,0.0,0.0,0.0,2.352959,1
5,SRR11932822,PRJNA637522,Drought,37.430009,27.508819,0.0,29.510498,7.005587,0.367545,0.314919,...,0.0,0.0,0.287114,0.0,0.0,0.0,0.0,0.0,1.604105,1


In [11]:
len(cleaned_tpm.index)

1173

In [12]:
bal = check_if_balanced(cleaned_tpm)
bal

False

In [13]:
sampling = "Up"
if sampling == "Down":
    ds = True
    us = "none"
elif sampling == "Up":
    ds = False
    us = "Up"
else:
    ds = False
    us = "none"

In [16]:
log_tpm = pre_split_transform(cleaned_tpm,bal,ds)
log_tpm.head()

Unnamed: 0_level_0,BioProject,Treatment,Label,Zm00001eb000010,Zm00001eb000020,Zm00001eb000050,Zm00001eb000060,Zm00001eb000070,Zm00001eb000080,Zm00001eb000100,...,Zm00001eb442810,Zm00001eb442820,Zm00001eb442840,Zm00001eb442850,Zm00001eb442870,Zm00001eb442890,Zm00001eb442910,Zm00001eb442960,Zm00001eb442980,Zm00001eb443030
Sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SRR11933261,PRJNA637522,Drought,1,3.760627,1.731651,0.060075,3.8004,2.644723,3.597631,0.494946,...,0.227968,0.0,0.0,0.0,0.0,0.389017,0.0,0.0,0.0,0.0
SRR11933272,PRJNA637522,Drought,1,4.109013,2.039269,0.490802,3.038092,1.537277,1.742518,0.226885,...,0.148026,0.173611,0.0,0.0,0.0,2.945473,0.0,0.0,0.0,0.0
SRR11933250,PRJNA637522,Drought,1,3.326079,1.994391,0.0,2.032962,1.842599,4.103174,0.0,...,0.0,0.0,0.0,0.0,0.0,0.503415,0.0,0.326679,0.0,1.146231
SRR11933040,PRJNA637522,Drought,1,3.507319,1.925657,0.0,1.189432,2.254362,4.372791,0.193789,...,0.236712,0.017435,0.0,0.0,0.0,3.409421,0.0,0.0,0.0,1.745435
SRR11932822,PRJNA637522,Drought,1,5.264161,4.833336,0.0,4.931234,3.001007,0.451588,0.394974,...,0.178043,0.0,0.0,0.36414,0.0,0.0,0.0,0.0,0.0,1.380788


In [19]:
X_train, y_train, X_test, y_test, pc0, pc1 = split_prep_peripheral(single_stress,log_tpm,us)

Upper percentage: 0.02
Lower percentage: -0.02
0


In [24]:
pc1

0.6085192697768762

In [25]:
log_tpm[log_tpm["BioProject"].isin(["PRJNA906711","PRJNA291919"])]

Unnamed: 0_level_0,BioProject,Treatment,Label,Zm00001eb000010,Zm00001eb000020,Zm00001eb000050,Zm00001eb000060,Zm00001eb000070,Zm00001eb000080,Zm00001eb000100,...,Zm00001eb442810,Zm00001eb442820,Zm00001eb442840,Zm00001eb442850,Zm00001eb442870,Zm00001eb442890,Zm00001eb442910,Zm00001eb442960,Zm00001eb442980,Zm00001eb443030
Sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SRR2144414,PRJNA291919,Drought,1,4.058038,2.036268,0.000000,3.577774,0.775624,4.478831,0.0,...,0.371722,0.100631,0.0,0.0,0.000000,0.862596,0.0,1.481934,0.0,1.872555
SRR2144415,PRJNA291919,Drought,1,3.547407,1.870083,0.000000,3.218024,1.772667,3.905955,0.0,...,0.361880,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.0,0.790822
SRR2144416,PRJNA291919,Drought,1,3.910170,1.370586,0.000000,3.738116,1.986165,4.266589,0.0,...,0.602690,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.0,0.904620
SRR2144417,PRJNA291919,Drought,1,4.131511,1.730484,0.243162,4.034691,0.902260,4.563333,0.0,...,0.000000,0.074362,0.0,0.0,0.000000,1.220902,0.0,0.000000,0.0,1.504433
SRR2144418,PRJNA291919,Drought,1,2.388742,7.394438,0.396618,4.749235,2.806444,5.637815,0.0,...,0.000000,2.357471,0.0,0.0,0.664503,0.000000,0.0,0.000000,0.0,1.938114
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CML228D1D,PRJNA906711,Drought,1,1.279532,0.000000,0.000000,0.000000,0.000000,1.386638,0.0,...,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.0,0.000000
CML333D3D,PRJNA906711,Drought,1,3.403289,0.000000,0.000000,0.000000,1.013190,2.675093,0.0,...,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.0,0.000000
P39D3D,PRJNA906711,Drought,1,2.402443,0.832973,0.000000,0.000000,0.000000,0.200860,0.0,...,0.000000,0.221100,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.0,0.000000
M162WD3D,PRJNA906711,Drought,1,2.149165,0.464142,0.000000,0.000000,1.411628,1.530760,0.0,...,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.0,0.000000
