In [None]:
import math
import pandas as pd
import numpy as np
import random as rd
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline

SMALL_SIZE = 12
MEDIUM_SIZE = 16
BIGGER_SIZE = 22

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

In [None]:
def create_bar_plot(data, x_title, y_title, bins='auto', figsize=(9, 8)):
    """Draw, show and return a histogram of ``data``.

    Despite the name, this bins raw values with ``Axes.hist``; it does not
    draw bars of pre-computed heights.

    Args:
        data: Values to bin. Dict value views (e.g. ``some_dict.values()``)
            are accepted and converted to a list first.
        x_title: Label for the x axis.
        y_title: Label for the y axis.
        bins: Binning spec forwarded to ``Axes.hist`` (default ``'auto'``).
        figsize: Figure size in inches, forwarded to ``plt.subplots``.

    Returns:
        The matplotlib ``Figure`` containing the histogram.
    """
    from collections.abc import ValuesView

    # Turn dict views into a plain list so Axes.hist always gets an
    # ordinary sequence rather than an opaque view object.
    if isinstance(data, ValuesView):
        data = list(data)

    fig, ax = plt.subplots(figsize=figsize)
    ax.hist(x=data, bins=bins, alpha=0.7, rwidth=0.85)
    ax.grid(False)
    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)
    plt.show()
    return fig

In [None]:
# Group the rows of a frame by drug (SMILES) or cell line

def get_pos_map(obj_list, test_df, col):
    """Collect the 0-based row positions of ``test_df`` for each object.

    Args:
        obj_list: Values of ``col`` to collect positions for. Each one gets
            an entry, which is an empty list if it never occurs.
        test_df: DataFrame to scan.
        col: Column whose values are matched against ``obj_list``.

    Returns:
        dict object -> list of row positions, in row order. Rows whose value
        is not in ``obj_list`` are ignored.
    """
    pos_map = {obj: [] for obj in obj_list}
    # enumerate yields positions, not index labels. Positions stay valid for
    # positional lookups (np.take / .iloc) even when test_df has a
    # non-default index, such as the result of DataFrame.query.
    for pos, value in enumerate(test_df[col]):
        if value in pos_map:
            pos_map[value].append(pos)
    return pos_map

In [None]:
# Order obj_list by descending spread value (callers pass standard deviations, not variances)

def sort_var(obj_list, var_list):
    """Pair each object with its spread value and order the pairs largest first.

    Args:
        obj_list: Objects to rank (e.g. drug SMILES or cell-line names).
        var_list: Values aligned position-by-position with ``obj_list``.
            Despite the name, this notebook passes standard deviations.

    Returns:
        dict object -> value, with insertion order by descending value.
        NaN values come last. Ties keep their original relative order.
        If ``obj_list`` repeats an object, its last value is kept.
    """
    var_map = {}
    for i, obj in enumerate(obj_list):
        var_map[obj] = var_list[i]

    def _descending_nan_last(item):
        # NaN compares False against everything, so a raw value key leaves
        # the sort order undefined. Give NaN the lowest rank instead.
        value = item[1]
        if math.isnan(value):
            return (0, 0.0)
        return (1, value)

    return dict(sorted(var_map.items(), key=_descending_nan_last, reverse=True))

In [None]:
def calc_stddev(obj_list, train_df, col):
    """Compute the population standard deviation of ``auc`` for each object.

    Args:
        obj_list: Values of ``col`` to compute the spread for.
        train_df: DataFrame with an ``auc`` column and a ``col`` column. Any
            index is fine; row labels are not used.
        col: Column to group by, e.g. ``'smiles'`` or ``'cell_line'``.

    Returns:
        dict object -> std (ddof=0, same as ``np.std``), ordered by
        descending std via ``sort_var``. Objects with no rows map to NaN.
        NaN ``auc`` values are skipped by the pandas reduction.
    """
    # Group by value rather than collecting index labels and passing them to
    # the positional np.take. This stays correct for frames with a
    # non-default index, such as DataFrame.query results.
    std_by_obj = train_df.groupby(col)['auc'].std(ddof=0)
    var_list = [std_by_obj.get(obj, np.nan) for obj in obj_list]
    return sort_var(obj_list, var_list)

In [None]:
def filter_data(train_df, obj_list, col, threshold):
    """Drop rows whose ``col`` value has a low ``auc`` standard deviation.

    Args:
        train_df: DataFrame with ``auc`` and ``col`` columns.
        obj_list: Candidate values of ``col`` to evaluate.
        col: Grouping column, e.g. ``'smiles'`` or ``'cell_line'``.
        threshold: Objects whose std is strictly below this are removed.

    Returns:
        The subset of ``train_df`` rows whose ``col`` value survived, with
        the original index labels preserved.
    """
    stddev_map = calc_stddev(obj_list, train_df, col)

    # `not std < threshold` (rather than `std >= threshold`) keeps objects
    # whose std is NaN, matching the original skip rule.
    filtered_list = [obj for obj, std in stddev_map.items() if not std < threshold]

    # Filter on `col` itself so any grouping column works, not only
    # 'smiles' / 'cell_line'.
    return train_df[train_df[col].isin(filtered_list)]

In [None]:
def create_per_drug_data(train_df, drug_name_map, dataset="gdsc2", out_dir="../data/training_files"):
    """Write one header-less, tab-separated training file per drug.

    Args:
        train_df: DataFrame with a ``smiles`` column; all of its columns are
            written, without the index.
        drug_name_map: Mapping from SMILES string to drug name, used in the
            output file name. Raises KeyError for a SMILES it lacks.
        dataset: Dataset tag embedded in the file name.
        out_dir: Directory the files are written to.

    Files are named ``{out_dir}/train_{dataset}_{drug_name}.txt``.
    """
    # A single groupby pass replaces one query per drug. Each group keeps
    # its rows in their original order. Rows with a NaN SMILES are skipped.
    for drug, drug_train_df in train_df.groupby('smiles'):
        drug_name = drug_name_map[drug]
        # NOTE(review): a drug name containing '/' would be treated as a
        # sub-directory; confirm names are filename-safe.
        drug_train_df.to_csv(f"{out_dir}/train_{dataset}_{drug_name}.txt", sep="\t", header=False, index=False)

In [None]:
# Dataset tag used to build the input paths below.
dataset = "gdsc2"

# Two-column, header-less index files; only the second column is kept.
cell_lines = list(pd.read_csv(f"../data/cell2ind_{dataset}.txt", sep="\t", header=None, names=['I', 'C'])['C'])
drug_list = list(pd.read_csv(f"../data/drug2ind_{dataset}_all.txt", sep="\t", header=None, names=['I', 'D'])['D'])

# Header-less training rows of (cell_line, smiles, auc).
all_df = pd.read_csv(f"../data/train_{dataset}.txt", sep="\t", header=None, names=['cell_line', 'smiles', 'auc'])

# Drug metadata. The map names are the reverse of their contents:
#   drug_smiles_map : isomeric SMILES -> drug name
#   drug_name_map   : drug name       -> isomeric SMILES
drug_info = pd.read_csv("../data/master_druglist_smiles_final.csv")[['name', 'isomeric_smiles']]
drug_smiles_map = {smiles: name for smiles, name in zip(drug_info['isomeric_smiles'], drug_info['name'])}
drug_name_map = {name: smiles for name, smiles in zip(drug_info['name'], drug_info['isomeric_smiles'])}

#cell_line_info_df = pd.read_csv("../data/CCLE/sample_info.csv", sep=",")[["CCLE_Name", "primary_disease"]]

In [None]:
# Keep only drugs whose AUC standard deviation over their training rows is at least 0.15.
filtered_train_df = filter_data(all_df, drug_list, col='smiles', threshold=0.15)

In [None]:
# Spread of AUC for the drugs that survived filtering.
# - Use only drugs still present in the filtered frame; removed drugs would
#   otherwise show up as NaN entries.
# - reset_index makes row labels equal row positions. query() keeps the
#   original labels, which position-based lookups would misread.
filtered_drugs = filtered_train_df['smiles'].unique().tolist()
stddev_map = calc_stddev(filtered_drugs, filtered_train_df.reset_index(drop=True), 'smiles')
gdsc2_drug_hist = create_bar_plot(list(stddev_map.values()), 'Std Dev', '# of Drugs')

In [None]:
# One training file per remaining drug. drug_smiles_map goes from SMILES to
# drug name, which is the lookup create_per_drug_data performs.
create_per_drug_data(train_df=filtered_train_df, drug_name_map=drug_smiles_map)