In [1]:
from sklearn.model_selection import ShuffleSplit
import pandas as pd
import numpy as np
from collections import OrderedDict

In [2]:
# Load the full table and shuffle its rows.
# The shuffle is seeded so that restarting the kernel and re-running the
# notebook reproduces the same row order in `df` and in every frame derived
# from it (an unseeded `sample` gives a different order on each run).
df = pd.read_parquet('top21.parquet').sample(frac=1., random_state=2022)
df

Unnamed: 0,CELL,DRUG,AUC,GE_A1BG,GE_A1CF,GE_A2M,GE_A2ML1,GE_A3GALT2,GE_A4GALT,GE_A4GNT,...,dd_SRW10,dd_TSRW10,dd_MW,dd_AMW,dd_WPath,dd_WPol,dd_Zagreb1,dd_Zagreb2,dd_mZagreb1,dd_mZagreb2
355529,CCL_562,Drug_688,0.9099,-1.218680,-0.268673,0.173404,2.530088,1.768047,0.817096,-0.427922,...,10.519889,88.830250,516.16090,8.602682,5134.0,60,202,237,11.444445,8.250000
114968,CCL_712,Drug_25,0.9922,0.572614,0.576304,-0.460626,-0.372810,0.353417,-0.402071,1.448252,...,10.586357,72.249630,368.13510,7.218335,911.0,47,132,168,7.902778,4.500000
52523,CCL_834,Drug_1481,0.9051,-0.286099,-0.268673,-0.416899,-0.372810,-0.408307,-0.896328,-0.427922,...,10.115247,73.856926,331.11210,8.490053,1475.0,38,134,158,6.638889,5.444445
91069,CCL_11,Drug_210,0.6898,-1.186598,2.743773,-0.429121,-0.006940,-0.310406,0.753503,-0.158447,...,9.972780,73.420520,362.12128,8.230029,1541.0,35,128,147,8.138889,5.500000
16342,CCL_956,Drug_115,0.7639,-1.509534,3.171594,-0.438762,0.028975,-0.408307,-1.363126,-0.427922,...,10.743567,63.496136,371.10050,8.434102,1584.0,60,156,198,10.895833,5.666666
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
241718,CCL_79,Drug_46,0.9989,0.766517,-0.339088,-0.291187,-0.232185,-0.408307,0.196529,-0.219458,...,11.363973,124.092190,904.34320,7.863854,18175.0,129,366,451,19.506945,14.944445
411200,CCL_878,Drug_997,0.5901,0.669566,-0.268673,3.217836,-0.181962,-0.408307,-0.978704,0.093238,...,11.127130,109.728424,729.40880,6.571250,9767.0,102,304,375,15.145833,10.916667
17359,CCL_869,Drug_116,0.5729,1.338993,-0.278733,-0.487954,-0.372810,-0.408307,0.493083,0.232213,...,10.979053,91.811554,507.18933,7.684687,3502.0,78,206,262,12.597222,8.354167
340928,CCL_449,Drug_662,0.8607,1.172790,-0.278733,-0.274789,-0.362765,0.244599,-1.346651,-0.427922,...,9.757941,65.049950,257.02590,10.281035,628.0,25,96,112,5.166666,3.888889


## Partitions

In [3]:
# Snapshot the row labels of the shuffled `df` as an independent NumPy array
# (an explicit copy, so later changes to `df` cannot alias it).
idx_vec = df.index.to_numpy().copy()
idx_vec

array([355529, 114968,  52523, ...,  17359, 340928,  13498])

In [4]:
# Draw one random 80% / 10% split; the remaining 10% of the rows is selected
# by neither side. Note that ShuffleSplit's "test" side is used as the
# validation set here.
selector = ShuffleSplit(
    n_splits=1,
    train_size=0.8,
    test_size=0.1,
    random_state=2022,
)

# ShuffleSplit yields *positions* (0..n-1), not index labels, whereas the
# masks built further down test membership against `df.index`, i.e. against
# labels. Translate positions to labels explicitly instead of relying on the
# index happening to be a permutation of 0..n-1.
# Splitting over the *sorted* labels makes the partition independent of the
# row order of `df`; for a 0..n-1 index the resulting arrays are exactly the
# positions ShuffleSplit returned.
sorted_labels = np.sort(idx_vec)
splits = selector.split(sorted_labels)
train_pos, val_pos = next(splits)
train_idx = sorted_labels[train_pos]
val_idx = sorted_labels[val_pos]

# Union of train and validation labels; rows outside it are the leftover rows.
tr_vl_idx = np.concatenate([train_idx, val_idx])

In [5]:
# Show the training split array together with its length.
train_idx, train_idx.shape[0]

(array([208810, 107585, 204275, ..., 201139, 168391,  55777]), 329043)

## Feature Selection

In [6]:
# Select the lincs1000 genes.
# The gene names are taken from the header row (the column names) of the
# tab-separated file and prefixed with "GE_" to match the columns of `df`.
lincs1k = pd.read_csv('lincs1000', header=0, sep='\t')
lincs1k_cols = {f"GE_{gene}" for gene in lincs1k.columns.to_list()}
# All "GE_" columns of `df` that do not appear in the lincs1000 list.
not_in_lincs_cols = {col for col in df.columns if col.startswith("GE_")} - lincs1k_cols

In [7]:
# Remove, in place, the "GE_" columns that are not in the lincs1000 list.
df.drop(columns=list(not_in_lincs_cols), inplace=True)
df

Unnamed: 0,CELL,DRUG,AUC,GE_AARS,GE_ABCB6,GE_ABCC5,GE_ABCF1,GE_ABCF3,GE_ABHD4,GE_ABHD6,...,dd_SRW10,dd_TSRW10,dd_MW,dd_AMW,dd_WPath,dd_WPol,dd_Zagreb1,dd_Zagreb2,dd_mZagreb1,dd_mZagreb2
355529,CCL_562,Drug_688,0.9099,-0.098415,0.603366,1.365327,0.886353,1.835916,0.162662,-2.475820,...,10.519889,88.830250,516.16090,8.602682,5134.0,60,202,237,11.444445,8.250000
114968,CCL_712,Drug_25,0.9922,-1.274175,0.588435,1.221479,0.764272,-0.587969,-1.306221,-0.818018,...,10.586357,72.249630,368.13510,7.218335,911.0,47,132,168,7.902778,4.500000
52523,CCL_834,Drug_1481,0.9051,0.614967,-0.553804,-0.186184,0.921233,0.517663,-1.044744,0.330459,...,10.115247,73.856926,331.11210,8.490053,1475.0,38,134,158,6.638889,5.444445
91069,CCL_11,Drug_210,0.6898,-1.565972,0.000769,0.376519,-0.690784,-0.765157,-1.022898,-0.236833,...,9.972780,73.420520,362.12128,8.230029,1541.0,35,128,147,8.138889,5.500000
16342,CCL_956,Drug_115,0.7639,-1.089224,-0.352233,-0.134809,0.223626,0.262517,-0.883244,0.630062,...,10.743567,63.496136,371.10050,8.434102,1584.0,60,156,198,10.895833,5.666666
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
241718,CCL_79,Drug_46,0.9989,-0.574003,-2.390346,-0.556080,-0.369340,-0.758066,0.154972,-0.438521,...,11.363973,124.092190,904.34320,7.863854,18175.0,129,366,451,19.506945,14.944445
411200,CCL_878,Drug_997,0.5901,-0.349420,0.976647,-0.114259,0.380588,0.177468,-0.229552,1.269215,...,11.127130,109.728424,729.40880,6.571250,9767.0,102,304,375,15.145833,10.916667
17359,CCL_869,Drug_116,0.5729,-0.455106,0.088239,-0.288933,-0.229818,0.517663,0.854806,0.929665,...,10.979053,91.811554,507.18933,7.684687,3502.0,78,206,262,12.597222,8.354167
340928,CCL_449,Drug_662,0.8607,-0.904273,0.095704,1.540001,-0.090297,2.027275,-2.321365,0.819810,...,9.757941,65.049950,257.02590,10.281035,628.0,25,96,112,5.166666,3.888889


## Build Uno AUC dataframe

In [8]:
# Rename columns to be compatible with Uno
df.rename(columns={'CELL': 'Sample', 'DRUG': 'Drug1'}, inplace=True)

In [9]:
# Boolean row masks, one per partition, aligned with the rows of `df`.
# A row is selected when its index value is contained in the given array.
row_labels = df.index
tr_mask = row_labels.isin(train_idx)
vl_mask = row_labels.isin(val_idx)
# Test rows: index value found in neither the train nor the validation array.
te_mask = np.logical_not(row_labels.isin(tr_vl_idx))

In [10]:
# build label
Y_train = df.iloc[tr_mask, :][['Sample', 'Drug1', 'AUC']].reset_index(drop=True)
Y_val = df.iloc[vl_mask, :][['Sample', 'Drug1', 'AUC']].reset_index(drop=True)
Y_test = df.iloc[te_mask, :][['Sample', 'Drug1', 'AUC']].reset_index(drop=True)

In [11]:
# cell features
col_cl = list(filter(lambda x: x.startswith("GE_"), df.columns))

x_train_0 = df.iloc[tr_mask, :][col_cl].reset_index(drop=True)
x_val_0 = df.iloc[vl_mask, :][col_cl].reset_index(drop=True)
x_test_0 = df.iloc[te_mask, :][col_cl].reset_index(drop=True)

In [12]:
# Preview the training-partition cell feature frame (the "GE_" columns).
x_train_0

Unnamed: 0,GE_AARS,GE_ABCB6,GE_ABCC5,GE_ABCF1,GE_ABCF3,GE_ABHD4,GE_ABHD6,GE_ABL1,GE_ACAA1,GE_ACAT2,...,GE_ZMIZ1,GE_ZMYM2,GE_ZNF131,GE_ZNF274,GE_ZNF318,GE_ZNF395,GE_ZNF451,GE_ZNF586,GE_ZNF589,GE_ZW10
0,-0.098415,0.603366,1.365327,0.886353,1.835916,0.162662,-2.475820,-1.470567,2.483160,-0.025473,...,0.784771,0.054667,1.597031,0.758591,-0.945954,-0.397134,0.057486,1.324547,-1.962975,-1.184192
1,-1.274175,0.588435,1.221479,0.764272,-0.587969,-1.306221,-0.818018,-0.444090,0.095332,1.235497,...,1.358107,2.625690,1.835522,-0.031689,0.890347,1.186781,1.031985,0.822593,0.976140,2.968377
2,0.614967,-0.553804,-0.186184,0.921233,0.517663,-1.044744,0.330459,-0.799935,-0.268913,0.234422,...,-1.167898,-0.856878,-0.255862,0.334818,-1.185471,-0.354612,-0.562649,1.234452,0.523968,0.288728
3,-1.089224,-0.352233,-0.134809,0.223626,0.262517,-0.883244,0.630062,-0.608326,2.725989,1.254749,...,-1.483649,1.036331,0.037666,-1.165569,0.318166,-0.280200,-0.846140,-3.115815,0.050265,0.146760
4,1.011290,-1.345159,-1.449997,-0.160058,-0.587969,-0.183410,-0.108958,0.336032,-0.403818,0.272925,...,-2.148387,0.545499,-0.072407,0.174471,1.622205,1.069848,0.801649,0.063226,0.427074,0.519426
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
329038,-0.481528,0.984113,1.457802,1.165396,1.198052,-1.844554,-0.348640,2.033141,0.108823,0.147790,...,0.336072,0.849347,0.496303,1.182365,1.901642,0.102491,0.925676,1.350288,0.674692,0.572664
329039,-0.574003,-2.390346,-0.556080,-0.369340,-0.758066,0.154972,-0.438521,0.418151,-0.403818,0.648328,...,-0.985095,0.101413,-0.916299,-0.558542,0.344779,0.474551,-0.403186,0.578051,-0.875610,0.129014
329040,-0.349420,0.976647,-0.114259,0.380588,0.177468,-0.229552,1.269215,0.294973,0.769860,-0.824408,...,-0.087698,0.148159,0.221121,-0.627262,-0.387080,-1.396382,-0.810704,-0.297151,0.642394,-0.509843
329041,-0.904273,0.095704,1.540001,-0.090297,2.027275,-2.321365,0.819810,0.171796,1.538821,0.898597,...,1.707095,1.048017,1.523649,0.712778,3.179069,-0.471546,1.209166,2.444290,1.848185,-1.752065


In [13]:
# drug features
col_dr = list(filter(lambda x: x.startswith("dd_"), df.columns))

x_train_1 = df.iloc[tr_mask, :][col_dr].reset_index(drop=True)
x_val_1 = df.iloc[vl_mask, :][col_dr].reset_index(drop=True)
x_test_1 = df.iloc[te_mask, :][col_dr].reset_index(drop=True)

In [14]:
# Preview the training-partition drug feature frame (the "dd_" columns).
x_train_1

Unnamed: 0,dd_ABC,dd_ABCGG,dd_nAcid,dd_nBase,dd_SpAbs_A,dd_SpMax_A,dd_SpDiam_A,dd_SpAD_A,dd_SpMAD_A,dd_LogEE_A,...,dd_SRW10,dd_TSRW10,dd_MW,dd_AMW,dd_WPath,dd_WPol,dd_Zagreb1,dd_Zagreb2,dd_mZagreb1,dd_mZagreb2
0,30.002792,22.111973,0,0,49.796196,2.406517,4.799943,49.796196,1.310426,4.571155,...,10.519889,88.830250,516.16090,8.602682,5134.0,60,202,237,11.444445,8.250000
1,17.955917,14.437097,0,0,28.621237,2.621543,5.192865,28.621237,1.300965,4.073557,...,10.586357,72.249630,368.13510,7.218335,911.0,47,132,168,7.902778,4.500000
2,19.856009,15.247416,0,0,33.322075,2.474047,4.812980,33.322075,1.332883,4.165300,...,10.115247,73.856926,331.11210,8.490053,1475.0,38,134,158,6.638889,5.444445
3,21.589062,17.307938,0,0,33.643410,2.622182,5.244363,33.643410,1.246052,4.255863,...,10.743567,63.496136,371.10050,8.434102,1584.0,60,156,198,10.895833,5.666666
4,26.403017,22.130955,0,0,41.330150,2.468933,4.936377,41.330150,1.215593,4.441976,...,10.457746,83.520010,520.05950,10.197246,3364.0,53,178,206,13.625000,7.236111
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
329038,45.864094,36.531483,0,0,71.380844,2.727273,5.454546,71.380844,1.230704,5.000399,...,11.459282,98.156990,807.34660,7.273393,12587.0,121,326,403,23.951390,12.145833
329039,52.686478,38.476750,0,1,89.640010,2.748132,5.284868,89.640010,1.337911,5.146177,...,11.363973,124.092190,904.34320,7.863854,18175.0,129,366,451,19.506945,14.944445
329040,42.488075,26.499352,0,1,68.264770,2.652757,5.119367,68.264770,1.312784,4.922651,...,11.127130,109.728424,729.40880,6.571250,9767.0,102,304,375,15.145833,10.916667
329041,14.280035,11.600451,0,0,23.430252,2.385046,4.742016,23.430252,1.301681,3.839144,...,9.757941,65.049950,257.02590,10.281035,628.0,25,96,112,5.166666,3.888889


In [15]:
# Write every partition to a single HDF5 file (opened in "w" mode, so an
# existing file of the same name is replaced), followed by an empty "model"
# node that only carries metadata attributes.
with pd.HDFStore('top21_uno_v2.h5', "w") as store:
    # labels
    store.put("y_train", Y_train, format="table")
    store.put("y_val", Y_val, format="table")
    # The "*_test" keys hold the held-out test frames (rows selected by
    # `te_mask`), not a second copy of the validation frames.
    store.put("y_test", Y_test, format="table")

    # features: suffix 0 = cell ("GE_") columns, suffix 1 = drug ("dd_") columns
    store.put("x_train_0", x_train_0, format="table")
    store.put("x_train_1", x_train_1, format="table")
    store.put("x_val_0", x_val_0, format="table")
    store.put("x_val_1", x_val_1, format="table")
    store.put("x_test_0", x_test_0, format="table")
    store.put("x_test_1", x_test_1, format="table")

    # model info
    # NOTE(review): the attribute names and key layout below presumably follow
    # what the Uno data loader expects — confirm against the consumer.
    cl_width = len(col_cl)  # number of cell feature columns
    dd_width = len(col_dr)  # number of drug feature columns
    store.put("model", pd.DataFrame())
    store.get_storer("model").attrs.input_features = OrderedDict(
        [("cell.rnaseq", "cell.rnaseq"), ("drug1.descriptors", "drug.descriptors")]
    )
    store.get_storer("model").attrs.feature_shapes = OrderedDict(
        [("cell.rnaseq", (cl_width,)), ("drug.descriptors", (dd_width,))]
    )