## What data gets passed to the model?

Right before our data gets passed to the train function, two things happen to it:

```python
geo_datalist = read_batch(index, upper_index, xTr, yTr, drugTr,\
                num_feature, num_gene, num_drug, edge_index)
```
&<br>
```python
dataset_loader, node_num, feature_dim = GeoGraphLoader.load_graph(geo_datalist, prog_args)
```

What does `read_batch` do?

In [2]:
import torch
import numpy as np
import pandas as pd
import networkx as nx

from numpy import inf
from torch_geometric.data import Data

### Loading our data

In [3]:
# Fold settings. Among the cells shown here only `fold_n` is read: it picks
# which xTr/yTr/drugTr files are loaded.
# NOTE(review): `k` and `n_fold` presumably mirror the cross-validation setup
# of the training script -- confirm before relying on them.
k, n_fold, fold_n = 5, 5, 5

# DATASET SELECTION (leave exactly one assignment active)
# dataset = 'data-drugcomb-fi'
# dataset = 'data-DrugCombDB'
dataset = 'data-nci'
# Folder of the selected dataset that holds the pre-formed .npy arrays.
form_data_path = '/'.join(('..', dataset, 'form_data'))

In [4]:
print('--- LOADING TRAINING FILES ... ---')
# Training arrays of the selected fold; the fold number is part of the file name.
xTr = np.load('../{}/form_data/xTr{}.npy'.format(dataset, fold_n))
yTr = np.load('../{}/form_data/yTr{}.npy'.format(dataset, fold_n))
drugTr = np.load('../{}/form_data/drugTr{}.npy'.format(dataset, fold_n))
# Edge list stored next to the training arrays, converted to an int64 tensor.
edge_index = torch.from_numpy(np.load('{}/edge_index.npy'.format(form_data_path))).long()

--- LOADING TRAINING FILES ... ---


In [5]:
# Number of feature values carried by every node.
num_feature = 4
# One row per drug in the lookup table, so the row count is the drug count.
dict_drug_num = pd.read_csv('../{}/filtered_data/drug_num_dict.csv'.format(dataset))
num_drug = len(dict_drug_num)
# One row per annotated gene, so the row count is the gene count.
final_annotation_gene_df = pd.read_csv('../{}/filtered_data/kegg_gene_annotation.csv'.format(dataset))
num_gene = len(final_annotation_gene_df)

In [6]:
class ReadGeoGraph():
    """Split batched numpy arrays into per-graph pieces and wrap them as ``Data`` objects."""

    def __init__(self, dir_opt):
        # Only stored; none of the methods in this class read it.
        self.dir_opt = dir_opt

    def read_feature(self, num_graph, num_feature, num_gene, num_drug, xBatch):
        """Return one ``(num_gene + num_drug, num_feature)`` array per graph.

        ``xBatch`` is reshaped to ``(num_graph, num_gene + num_drug,
        num_feature)`` and then cut along the first axis.
        """
        node_count = num_gene + num_drug
        cube = xBatch.reshape(num_graph, node_count, num_feature)
        return [cube[g, :, :] for g in range(num_graph)]

    def read_label(self, yBatch):
        """Return the first entry of every row of ``yBatch`` as a list."""
        return [row[0] for row in yBatch]

    def read_drug(self, num_graph, drugBatch):
        """Return the first ``num_graph`` rows of ``drugBatch`` as a list."""
        return [drugBatch[g, :] for g in range(num_graph)]

    def form_geo_datalist(self, num_graph, graph_feature_list, graph_label_list, graph_drug_list, edge_index):
        """Build ``num_graph`` ``Data`` objects from the per-graph lists.

        Every graph receives the same ``edge_index``. Features and the label
        are converted to float tensors (the label is first wrapped in a
        one-element array); the drug row is converted to an int tensor.
        """
        datalist = []
        for g in range(num_graph):
            feature, label, drug = (
                graph_feature_list[g], graph_label_list[g], graph_drug_list[g])
            # CONVERT [numpy] TO [torch] while assembling the graph object
            datalist.append(Data(
                x=torch.from_numpy(feature).float(),
                edge_index=edge_index,
                label=torch.from_numpy(np.array([label])).float(),
                drug_index=torch.from_numpy(drug).int(),
            ))
        return datalist


def read_batch(index, upper_index, x_input, y_input, drug_input,
               num_feature, num_gene, num_drug, edge_index):
    """Build the list of graph ``Data`` objects for rows ``index:upper_index``.

    Args:
        index: first row of the batch (inclusive).
        upper_index: end row of the batch (exclusive). It may run past the end
            of the inputs; the batch is then simply shorter.
        x_input: 2-D array with one flattened feature vector per row; each row
            is reshaped to ``(num_gene + num_drug, num_feature)``.
        y_input: 2-D array whose first column holds the label of each row.
        drug_input: 2-D array with one row of drug indices per sample.
        num_feature: number of feature values per node.
        num_gene: number of gene nodes.
        num_drug: number of drug nodes.
        edge_index: edge tensor attached unchanged to every graph.

    Returns:
        A list with one ``Data`` object per row in the batch.
    """
    # FORMING BATCH FILES
    dir_opt = '/datainfo'
    print('--------------' + str(index) + ' to ' + str(upper_index) + '--------------')
    xBatch = x_input[index : upper_index, :]
    yBatch = y_input[index : upper_index, :]
    drugBatch = drug_input[index : upper_index, :]
    print(xBatch.shape)
    print(yBatch.shape)
    print(drugBatch.shape)
    # PREPARE LOADING LISTS OF [features, labels, drugs, edge_index]
    print('READING BATCH GRAPHS TO LISTS ...')
    # Count the rows that were actually sliced. Slicing silently truncates at
    # the end of the array, so `upper_index - index` over-counts on a short
    # final batch and would make the reshape in read_feature fail.
    num_graph = xBatch.shape[0]
    reader = ReadGeoGraph(dir_opt)
    # For each row there is a [num_gene + num_drug by num_feature] array
    # (2034 by 4 in this notebook's data-nci example), i.e. every node
    # (gene or drug) carries 4 numbers: DoubleDrug, SingleDrug, RNA, CNV
    graph_feature_list = reader.read_feature(num_graph, num_feature, num_gene, num_drug, xBatch)
    # put the scores in a list
    graph_label_list = reader.read_label(yBatch)
    # split the drug array into a list with one row per graph
    graph_drug_list = reader.read_drug(num_graph, drugBatch)
    # for each x in batch, create a graph with
    #   label=score
    #   drug_index=[DRG A, DRG B]
    #   one node per gene/drug, each with num_feature features
    #   connections according to edge_index (all drugs and genes, drugs and
    #   targets bidirectional, genes depending on graph)
    geo_datalist = reader.form_geo_datalist(
        num_graph, graph_feature_list, graph_label_list, graph_drug_list, edge_index)
    return geo_datalist

In [7]:
# Row range [index, upper_index) of the training arrays used as the example batch.
index, upper_index = 0, 64


In [8]:
# Shapes of the first sample's feature vector, label and drug row.
print(xTr[0].shape, yTr[0].shape, drugTr[0].shape, sep='\n')
# Raw label and drug row of that first sample.
print(yTr[0], drugTr[0], sep='\n')
# Overall shape of the edge list and its first column.
print(edge_index[:, :].shape, edge_index[:, 0], sep='\n')

(8136,)
(1,)
(2,)
[2.11111111]
[2022. 2030.]
torch.Size([2, 18610])
tensor([  17, 1240])


In [9]:
# Build the graphs for the example batch in one call.
geo_datalist = read_batch(
    index=index,
    upper_index=upper_index,
    x_input=xTr,
    y_input=yTr,
    drug_input=drugTr,
    num_feature=num_feature,
    num_gene=num_gene,
    num_drug=num_drug,
    edge_index=edge_index,
)
# One graph per row in the batch.
len(geo_datalist)

--------------0 to 64--------------
(64, 8136)
(64, 1)
(64, 2)
READING BATCH GRAPHS TO LISTS ...


64

In [10]:
# FORMING BATCH FILES (same steps as read_batch, run one at a time)
dir_opt = '/datainfo'
# NOTE(review): this overwrites the notebook-level `form_data_path` that the
# loading cell used; re-running that cell afterwards reads a different folder.
form_data_path = './datainfo/form_data'
print('--------------{} to {}--------------'.format(index, upper_index))
# Cut the same row range out of all three training arrays.
xBatch, yBatch, drugBatch = (
    arr[index : upper_index, :] for arr in (xTr, yTr, drugTr)
)
print(xBatch.shape, yBatch.shape, drugBatch.shape, sep='\n')
# PREPARE LOADING LISTS OF [features, labels, drugs, edge_index]
print('READING BATCH GRAPHS TO LISTS ...')
num_graph = upper_index - index
# One feature matrix per graph; show the first one.
graph_feature_list = ReadGeoGraph(dir_opt).read_feature(
    num_graph, num_feature, num_gene, num_drug, xBatch)
print(graph_feature_list[0])

--------------0 to 64--------------
(64, 8136)
(64, 1)
(64, 2)
READING BATCH GRAPHS TO LISTS ...
[[ 0.  0. 29. -1.]
 [ 0.  0. 11.  1.]
 [ 0.  0.  6.  1.]
 ...
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]]


In [11]:
# What read_feature does internally: unflatten every row into a
# (num_node, num_feature) matrix, then split along the first axis.
num_node = num_gene + num_drug
print(xBatch.shape)
xBatch = xBatch.reshape(num_graph, num_node, num_feature)
print(xBatch.shape)
graph_feature_list = [xBatch[g, :, :] for g in range(num_graph)]
# Feature row of node 1890 in the first graph.
graph_feature_list[0][1890]

(64, 8136)
(64, 2034, 4)


array([0., 0., 2., 0.])

In [12]:
# What read_label does internally: keep the first entry of every label row.
print(yBatch.shape)
yBatch_list = [row[0] for row in yBatch]
graph_label_list = yBatch_list
print(graph_label_list)

(64, 1)
[2.111111111111111, -3.333333333333333, -8.11111111111111, 4.888888888888889, -18.77777777777778, -5.444444444444445, 1.0, -3.0, -0.3333333333333333, -5.555555555555555, 5.444444444444445, 4.222222222222222, -7.444444444444445, -4.111111111111111, 11.77777777777778, -7.555555555555555, -3.6666666666666665, -8.11111111111111, -1.8888888888888888, -2.111111111111111, -15.77777777777778, 11.77777777777778, -8.333333333333334, -4.111111111111111, -5.222222222222222, 1.7777777777777777, -2.2222222222222223, 8.777777777777779, -3.7777777777777777, -4.333333333333333, -3.555555555555556, -15.555555555555555, 2.2222222222222223, -3.111111111111111, -5.888888888888889, -2.7777777777777777, -10.11111111111111, -12.333333333333334, -4.666666666666667, -4.444444444444445, 0.5555555555555556, 2.4444444444444446, -0.2222222222222222, -5.222222222222222, -3.0, -3.2222222222222223, -4.222222222222222, 21.444444444444443, -1.7777777777777777, -5.444444444444445, 1.3333333333333333, -2.777777777

In [13]:
# What read_drug does internally: one drug row per graph.
print(drugBatch[0].shape)
graph_drug_list = [drugBatch[g, :] for g in range(num_graph)]
graph_drug_list[0].shape

(2,)


(2,)

In [15]:
# What form_geo_datalist does internally: one Data object per graph, all
# sharing the same edge_index.
geo_datalist = []
for i in range(num_graph):
    # CONVERT [numpy] TO [torch]
    graph_feature = torch.from_numpy(graph_feature_list[i]).float()
    graph_label = torch.from_numpy(np.array([graph_label_list[i]])).float()
    graph_drug = torch.from_numpy(graph_drug_list[i]).int()
    geo_data = Data(
        x=graph_feature,
        edge_index=edge_index,
        label=graph_label,
        drug_index=graph_drug,
    )
    geo_datalist += [geo_data]

# Number of nodes in the first graph.
geo_datalist[0].x.shape[0]

2034