In [1]:
from sklearn.model_selection import ShuffleSplit
import pandas as pd
import numpy as np
from collections import OrderedDict
import math

In [2]:
# Load the full response table: one row per (CELL, DRUG) pair with its AUC,
# followed by GE_* gene-expression columns and dd_* drug-descriptor columns.
# Rows are kept in stored order; no shuffling is applied at load time.
df = pd.read_parquet("top21.parquet")
df

Unnamed: 0,CELL,DRUG,AUC,GE_A1BG,GE_A1CF,GE_A2M,GE_A2ML1,GE_A3GALT2,GE_A4GALT,GE_A4GNT,...,dd_SRW10,dd_TSRW10,dd_MW,dd_AMW,dd_WPath,dd_WPol,dd_Zagreb1,dd_Zagreb2,dd_mZagreb1,dd_mZagreb2
0,CCL_100,Drug_1,0.7194,-1.440283,3.845564,-0.471557,-0.362765,0.244599,0.767670,-0.427922,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
1,CCL_1000,Drug_1,0.8588,0.831151,-0.339088,-0.395036,-0.181962,-0.408307,0.548000,0.579653,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
2,CCL_1001,Drug_1,0.8150,-1.449517,2.467445,-0.482489,-0.292453,-0.408307,-0.264778,-0.427922,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
3,CCL_1002,Drug_1,0.7922,1.158940,-0.339088,-0.099885,-0.372810,0.353417,0.421690,0.475421,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
4,CCL_1004,Drug_1,0.8194,0.863469,-0.339088,-0.449694,-0.302497,0.244599,0.569967,-0.288946,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
411299,CCL_88,Drug_999,0.9109,-0.216848,-0.087607,-0.526215,-0.332631,-0.408307,0.020793,0.822861,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
411300,CCL_889,Drug_999,0.9730,0.992737,-0.308911,-0.373173,-0.252274,-0.408307,1.124633,-0.184714,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
411301,CCL_93,Drug_999,0.9461,-1.149429,-0.278733,-0.449694,-0.312542,-0.408307,0.729228,4.783671,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
411302,CCL_961,Drug_999,0.8778,0.969654,-0.329029,-0.438762,-0.372810,-0.408307,-1.330176,-0.010994,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448


## Partitions

In [3]:
# Independent NumPy copy of the row index; this is the array that gets split
# into partitions below, so it must not alias the DataFrame's own index.
idx_vec = df.index.to_numpy().copy()
idx_vec

array([     0,      1,      2, ..., 411301, 411302, 411303])

In [4]:
# Partition sizes. Train (~80%) and test (~10%) are each rounded UP to a
# multiple of 32 (presumably the training batch size -- TODO confirm).
# Validation receives whatever rows remain, so it is not necessarily a
# multiple of 32.
n_rows = len(df)
tr_size = 32 * math.ceil(n_rows * 0.8 / 32)
te_size = 32 * math.ceil(n_rows * 0.1 / 32)
val_size = n_rows - (tr_size + te_size)
print(f"tr_size: {tr_size}, val_size: {val_size}, te_size: {te_size}")

tr_size: 329056, val_size: 41096, te_size: 41152


In [5]:
# Draw a single shuffled train/validation partition. Rows selected for
# neither partition (len(df) - tr_size - val_size == te_size of them) are
# left over and form the test set.
selector = ShuffleSplit(
    n_splits=1,
    train_size=tr_size,
    test_size=val_size,
    random_state=2022,  # fixed seed so the partition is reproducible
)
splits = selector.split(idx_vec)
train_pos, val_pos = next(splits)

# ShuffleSplit yields *positions* into idx_vec, not the values stored in it.
# Translate them into index labels so that label-based lookups such as
# `df.index.isin(train_idx)` remain correct even when `df` does not carry a
# default 0..n-1 index (with a default index the result is unchanged).
train_idx = idx_vec[train_pos]
val_idx = idx_vec[val_pos]

# Labels of every row that is NOT in the test set (train first, then val).
tr_vl_idx = np.concatenate([train_idx, val_idx])

## Feature Selection

In [6]:
# Select the LINCS1000 genes.
# The gene symbols are taken from the header row of the tab-separated file;
# each is prefixed with "GE_" to match the expression column names in `df`.
lincs1k = pd.read_csv('lincs1000', sep='\t', header=0)
lincs1k_cols = {f"GE_{gene}" for gene in lincs1k.columns.to_list()}

# Every GE_* column of `df` that is not a LINCS gene is marked for removal.
not_in_lincs_cols = {
    col for col in df.columns if col.startswith("GE_")
} - lincs1k_cols

In [7]:
# Keep only the LINCS gene-expression columns (non-GE_* columns are untouched).
# - Non-inplace drop with errors="ignore" makes this cell idempotent: running
#   it again after the columns are already gone is a no-op instead of a
#   KeyError.
# - The label set is sorted so a deterministic list (not an unordered set) is
#   handed to pandas.
df = df.drop(columns=sorted(not_in_lincs_cols), errors="ignore")
df

Unnamed: 0,CELL,DRUG,AUC,GE_AARS,GE_ABCB6,GE_ABCC5,GE_ABCF1,GE_ABCF3,GE_ABHD4,GE_ABHD6,...,dd_SRW10,dd_TSRW10,dd_MW,dd_AMW,dd_WPath,dd_WPol,dd_Zagreb1,dd_Zagreb2,dd_mZagreb1,dd_mZagreb2
0,CCL_100,Drug_1,0.7194,0.033692,-0.635926,-1.193125,-1.781994,-1.799911,-0.068052,0.849771,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
1,CCL_1000,Drug_1,0.8588,0.839550,0.021048,1.653025,0.642191,3.430577,-0.398743,-0.668217,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
2,CCL_1001,Drug_1,0.8150,-0.071994,0.289810,-0.401957,0.206186,1.219314,-0.591005,1.369082,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
3,CCL_1002,Drug_1,0.7922,-0.124837,1.327531,-0.412232,0.520109,0.156206,0.477972,0.070804,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
4,CCL_1004,Drug_1,0.8194,0.496070,1.305134,0.759108,0.711951,0.432614,-1.190863,0.130724,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
411299,CCL_88,Drug_999,0.9109,0.218643,0.700419,-1.223949,-0.875105,-1.821174,1.346997,-0.178865,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
411300,CCL_889,Drug_999,0.9730,-0.217312,-0.023745,-0.186184,0.415468,-0.524182,0.316472,-0.059024,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
411301,CCL_93,Drug_999,0.9461,2.081363,-0.494079,-0.288933,-0.508861,-1.544766,0.470282,-0.088985,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
411302,CCL_961,Drug_999,0.8778,1.909623,-0.882291,-0.453331,-0.456541,-0.141464,-0.098814,-0.857965,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448


## Build Uno AUC dataframe

In [8]:
# Rename the identifier columns to the names used by Uno:
# CELL -> Sample, DRUG -> Drug1.
df.rename(columns={"CELL": "Sample", "DRUG": "Drug1"}, inplace=True)

In [9]:
# Boolean row masks, one per partition, aligned with the rows of `df`.
# The test mask is the complement of train + validation.
row_labels = df.index
tr_mask = row_labels.isin(train_idx)
vl_mask = row_labels.isin(val_idx)
te_mask = np.logical_not(row_labels.isin(tr_vl_idx))

In [10]:
# Build the label frames: sample id, drug id and the AUC response for each
# partition, re-indexed from 0 so rows line up with the feature frames.
label_cols = ['Sample', 'Drug1', 'AUC']
Y_train, Y_val, Y_test = (
    df.loc[part_mask, label_cols].reset_index(drop=True)
    for part_mask in (tr_mask, vl_mask, te_mask)
)

In [11]:
# Sanity check: report how many rows landed in each partition.
print(
    f"size tr: {len(Y_train)}, val: {len(Y_val)}, test: {len(Y_test)}"
)

size tr: 329056, val: 41096, test: 41152


In [12]:
# Cell-line features: every gene-expression column (prefix "GE_"), split by
# partition and re-indexed from 0.
col_cl = [name for name in df.columns if name.startswith("GE_")]

x_train_0, x_val_0, x_test_0 = (
    df.loc[part_mask, col_cl].reset_index(drop=True)
    for part_mask in (tr_mask, vl_mask, te_mask)
)

In [13]:
# Preview the training-partition cell-line (GE_*) feature frame.
x_train_0

Unnamed: 0,GE_AARS,GE_ABCB6,GE_ABCC5,GE_ABCF1,GE_ABCF3,GE_ABHD4,GE_ABHD6,GE_ABL1,GE_ACAA1,GE_ACAT2,...,GE_ZMIZ1,GE_ZMYM2,GE_ZNF131,GE_ZNF274,GE_ZNF318,GE_ZNF395,GE_ZNF451,GE_ZNF586,GE_ZNF589,GE_ZW10
0,0.033692,-0.635926,-1.193125,-1.781994,-1.799911,-0.068052,0.849771,-0.772562,-0.053064,-1.055425,...,-0.511469,0.989585,0.129393,-0.902142,-0.280627,0.952915,-0.757549,-0.593175,-0.100459,0.306474
1,0.839550,0.021048,1.653025,0.642191,3.430577,-0.398743,-0.668217,0.130737,0.742879,0.330679,...,1.158686,0.113099,-0.457662,1.102191,-0.160869,-0.110116,-0.367750,0.410733,0.330180,1.726155
2,-0.124837,1.327531,-0.412232,0.520109,0.156206,0.477972,0.070804,0.705564,0.540520,0.542445,...,1.125449,-0.003765,-0.219171,-1.646609,0.171794,0.910394,0.039768,-0.322892,-0.401907,0.129014
3,0.496070,1.305134,0.759108,0.711951,0.432614,-1.190863,0.130724,-0.252481,-0.066554,2.294327,...,-0.187409,0.709109,-0.512698,-0.547089,0.185101,-0.110116,-0.580368,-0.709010,1.116098,0.483934
4,0.046903,1.797865,1.807148,0.328267,0.985430,1.793045,-0.168879,-0.594640,-0.862497,-0.237239,...,-0.353594,-0.401105,-0.292553,-0.421102,-1.252004,1.378127,-1.590302,-1.455506,-0.595695,-0.563081
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
329051,-0.574003,0.603366,-0.658829,0.589870,0.028633,-0.360291,0.430327,-0.019813,-0.201460,-0.381625,...,0.709987,-1.289277,-0.365934,-0.993769,-0.453612,-0.726673,-0.704395,-0.490210,-0.746418,0.821108
329052,-0.217312,-0.023745,-0.186184,0.415468,-0.524182,0.316472,-0.059024,0.842428,-1.091836,0.648328,...,0.643514,-0.155690,0.331193,1.033471,0.916960,1.792709,-0.226004,0.101838,-0.638758,0.129014
329053,2.081363,-0.494079,-0.288933,-0.508861,-1.544766,0.470282,-0.088985,0.349719,-1.091836,-0.092853,...,-0.312048,0.615618,-1.008026,-0.123316,-1.398375,-0.120746,-1.253658,-0.490210,-1.812251,-0.492097
329054,1.909623,-0.882291,-0.453331,-0.456541,-0.141464,-0.098814,-0.857965,-0.895740,-0.457780,-0.343122,...,0.635205,1.316806,0.441266,-1.245743,1.196397,1.069848,2.041920,-2.987109,0.857714,-0.137177


In [14]:
# Drug features: every descriptor column (prefix "dd_"), split by partition
# and re-indexed from 0.
col_dr = [name for name in df.columns if name.startswith("dd_")]

x_train_1, x_val_1, x_test_1 = (
    df.loc[part_mask, col_dr].reset_index(drop=True)
    for part_mask in (tr_mask, vl_mask, te_mask)
)

In [15]:
# Preview the training-partition drug-descriptor (dd_*) feature frame.
x_train_1

Unnamed: 0,dd_ABC,dd_ABCGG,dd_nAcid,dd_nBase,dd_SpAbs_A,dd_SpMax_A,dd_SpDiam_A,dd_SpAD_A,dd_SpMAD_A,dd_LogEE_A,...,dd_SRW10,dd_TSRW10,dd_MW,dd_AMW,dd_WPath,dd_WPol,dd_Zagreb1,dd_Zagreb2,dd_mZagreb1,dd_mZagreb2
0,0.178943,0.026685,-0.296681,0.307032,0.187441,0.380043,0.267720,0.187441,0.301024,0.372629,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
1,0.178943,0.026685,-0.296681,0.307032,0.187441,0.380043,0.267720,0.187441,0.301024,0.372629,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
2,0.178943,0.026685,-0.296681,0.307032,0.187441,0.380043,0.267720,0.187441,0.301024,0.372629,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
3,0.178943,0.026685,-0.296681,0.307032,0.187441,0.380043,0.267720,0.187441,0.301024,0.372629,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
4,0.178943,0.026685,-0.296681,0.307032,0.187441,0.380043,0.267720,0.187441,0.301024,0.372629,...,0.573775,0.374900,-0.013642,-0.550884,-0.055783,-0.031402,0.219639,0.261155,-0.394765,0.054656
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
329051,0.610574,0.441011,-0.296681,-0.398811,0.675958,0.333467,0.075551,0.675958,0.489068,0.641463,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
329052,0.610574,0.441011,-0.296681,-0.398811,0.675958,0.333467,0.075551,0.675958,0.489068,0.641463,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
329053,0.610574,0.441011,-0.296681,-0.398811,0.675958,0.333467,0.075551,0.675958,0.489068,0.641463,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448
329054,0.610574,0.441011,-0.296681,-0.398811,0.675958,0.333467,0.075551,0.675958,0.489068,0.641463,...,0.655956,0.901895,0.634924,0.838391,-0.055783,0.409001,0.679669,0.746909,-0.115743,0.431448


In [16]:
# Write all partitions plus model metadata into a single HDF5 file
# (opened in "w" mode, so any existing file of that name is replaced).
with pd.HDFStore('top21_uno_v2.h5', "w") as store:
    # Labels first, then the (cell, drug) feature pair of each partition.
    partitions = (
        ("y_train", Y_train),
        ("y_val", Y_val),
        ("y_test", Y_test),
        ("x_train_0", x_train_0),
        ("x_train_1", x_train_1),
        ("x_val_0", x_val_0),
        ("x_val_1", x_val_1),
        ("x_test_0", x_test_0),
        ("x_test_1", x_test_1),
    )
    for key, frame in partitions:
        store.put(key, frame, format="table")

    # model info
    cl_width = len(col_cl)
    dd_width = len(col_dr)

    # An empty frame stored under "model" serves only as a carrier for the
    # metadata attributes set below.
    store.put("model", pd.DataFrame())
    model_attrs = store.get_storer("model").attrs

    # Input name -> feature type. Presumably the layout the Uno loader
    # expects -- TODO confirm against the consumer.
    model_attrs.input_features = OrderedDict(
        [("cell.rnaseq", "cell.rnaseq"), ("drug1.descriptors", "drug.descriptors")]
    )
    # Feature type -> per-sample shape (number of columns of that type).
    model_attrs.feature_shapes = OrderedDict(
        [("cell.rnaseq", (cl_width,)), ("drug.descriptors", (dd_width,))]
    )