### Test functions for DRP_nb module

In [None]:
import os
import torch
import numpy as np
import pandas as pd
from importlib import reload
import torch_geometric.data as tgd

In [None]:
from DRP_nb import data_imports, feature_selection, utils, splitting

In [None]:
# Re-import the DRP_nb submodules with importlib.reload so that the module
# objects bound above are refreshed without restarting the kernel.
reload(data_imports)
reload(feature_selection)
reload(utils)
reload(splitting)

## Data imports

In [None]:
# Build the input-data container. The name refers to "phos prot rna (ppr)"
# data, but only the 'phos' omic type is requested here; the drug
# representation requested is 'mol_graph'.
inp_ppr = data_imports.DrpInputData(omic_types=['phos'], drug_rep='mol_graph')
# Take out disjoint cls (presumably cell lines -- TODO confirm). The return
# value is discarded, so this presumably modifies inp_ppr in place.
inp_ppr.remove_disjoint()
# Last expression of the cell: echoes the container.
inp_ppr

## Feature selection and create data for all drugs
Here we use landmark targets that are also landmarks (ltl).

In [None]:
# Feature list computed from the phos column names (per the section note,
# landmark targets that are also landmarks).
# NOTE(review): `ltl` is not referenced again in the visible code -- the full
# inp_ppr.phos table is passed below. Confirm whether it should be subset to
# `ltl` first.
ltl = feature_selection.ltl(inp_ppr.phos.columns)
# Build the omics inputs, drug inputs and targets from the phos table, the
# marker drugs and y_df. Later cells index all three results by pair label.
x_all_phos, x_drug, y_list = utils.create_all_drugs(
    inp_ppr.phos, inp_ppr.marker_drugs, inp_ppr.y_df)

# Row index of the phos table and the full drug collection; both are handed
# to splitting.split in the next section.
_all_cls = inp_ppr.phos.index
_all_drugs = inp_ppr.all_drugs

## Data splitting and putting data in dataloaders

In [None]:
# Split settings. The pairs that have a ground-truth value are exactly the
# index of y_list.
rand_seed = 42
train_size = 0.8
batch_size = 512
pairs_with_truth_vals = y_list.index

# First split: training pairs vs. held-out pairs, using the 'cblind' split.
train_pairs, test_pairs = splitting.split(
    rand_seed,
    _all_cls,
    _all_drugs,
    pairs_with_truth_vals,
    train_size=train_size,
    split_type='cblind',
)

# Second split: divide the held-out pairs 50/50 into validation and test.
# The ids passed in are the unique prefixes before '::' in the held-out pairs.
test_cls = np.unique([pair.split('::')[0] for pair in test_pairs])
val_pairs, test_pairs = splitting.split(
    rand_seed,
    pd.Index(test_cls),
    _all_drugs,
    test_pairs,
    train_size=0.5,
    split_type='cblind',
)

# Slice the omics inputs, drug inputs and targets for each partition
# (evaluated in train, val, test order).
xo_train_phos, xo_val_phos, xo_test_phos = (
    x_all_phos.loc[part] for part in (train_pairs, val_pairs, test_pairs))
xd_train, xd_val, xd_test = (
    x_drug.loc[part] for part in (train_pairs, val_pairs, test_pairs))
y_train, y_val, y_test = (
    y_list[part] for part in (train_pairs, val_pairs, test_pairs))

In [None]:
# Wrap each partition with utils.into_dls. The omics inputs and the targets
# get an extra axis inserted at position 1 via np.expand_dims; the drug
# inputs are passed through unchanged.
# NOTE(review): the `batch_size` variable defined in the splitting cell is not
# passed here, so the training call relies on whatever default into_dls
# uses -- confirm that is intended.
train_dls = utils.into_dls([np.expand_dims(xo_train_phos, 1), xd_train, 
                            np.expand_dims(y_train, 1)])
# Test and validation use a batch size equal to the partition length.
test_dls = utils.into_dls([np.expand_dims(xo_test_phos, 1), xd_test, 
                           np.expand_dims(y_test, 1)], 
                          batch_size=len(y_test))
val_dls = utils.into_dls([np.expand_dims(xo_val_phos, 1), xd_val, 
                         np.expand_dims(y_val, 1)], 
                         batch_size=len(y_val))

In [None]:
from torch_geometric.data import batch 

In [None]:
# Echo the module object bound to the `tgd` alias (torch_geometric.data).
tgd

In [None]:
# Construct a Batch with no arguments and echo the resulting object.
tgd.Batch()

In [None]:
#dict that maps drug cl pair to graph rep (has lots of repeats)
pairs_to_graphs = {}
for pair in pairs_with_truth_vals:
    d = pair.split('::')[1]
    y = y_list.loc[pair].astype(np.float32)
    y = np.expand_dims(y, -1)
    graph = tgd.Data.clone(inp_ppr.dtg[d])
    graph.y = torch.tensor(y)
    pairs_to_graphs[pair] = graph

#map train and testing pairs to graphs in torch geo list objects 
train_graphs = tgd.Batch().from_data_list(
    [pairs_to_graphs[pair] for pair in train_pairs])
#test_graphs = tgd.Batch().from_data_list(
    #[pairs_to_graphs[pair] for pair in test_pairs])
#val_graphs = tgd.Batch().from_data_list(
##[pairs_to_graphs[pair] for pair in val_pairs])

In [None]:
# Same call shape as the tabular training dataloaders, but with the collated
# drug graphs (train_graphs) in place of the tabular drug inputs.
train_graph_dls = utils.into_dls([np.expand_dims(xo_train_phos, 1), 
                                  train_graphs, 
                                  np.expand_dims(y_train, 1)])

In [None]:
# Echo the object returned by utils.into_dls for the graph-based inputs.
train_graph_dls

In [None]:
# Show the concrete class of the collated training graphs.
type(train_graphs)

In [None]:
# Check that the collated training graphs are an instance of the
# DataDataBatch class looked up on tgd.batch. isinstance is the idiomatic
# form of a type check and also accepts subclasses.
isinstance(train_graphs, tgd.batch.DataDataBatch)