In [1]:
import pandas as pd
import yaml
import torch
import torch_geometric

from src.constants import CONTINUOUS_COVARIATES_PROCESSED, STATIC_COLS, TARGET_COL
from src.dataset import df_to_patient_tensors, GraphClassificationDataset, build_classification_datasets

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv
from torch_geometric.utils import remove_self_loops

from src.utils import dag_fully_connected
from src.models.temporal_gnn import GraphClassification, LSTMClassification

In [2]:
# Input locations for this notebook. The machine-specific project root and the
# split directory are each written once, so pointing the notebook at another
# checkout or another split version is a single-line change.
PROJECT_ROOT = "/home/gaurang/glycemic_control"
SPLIT_DIR = f"{PROJECT_ROOT}/data/glycaemia_project_csvs/processed_data/train_test_splits/v4_patient_split_mini"

# "mini" train / validation / test CSVs.
fpath_train = f"{SPLIT_DIR}/train_mini.csv"
fpath_valid = f"{SPLIT_DIR}/valid_mini.csv"
fpath_test = f"{SPLIT_DIR}/test_mini.csv"

# YAML file holding the model and training hyper-parameters.
fpath_config = f"{PROJECT_ROOT}/code/src/config.yaml"

In [3]:
# Read the three data splits from the CSV paths defined in the previous cell.
df_train = pd.read_csv(fpath_train)
df_valid = pd.read_csv(fpath_valid)
df_test = pd.read_csv(fpath_test)

# Parse the hyper-parameter file. yaml.safe_load only builds plain YAML types
# (dicts, lists, scalars); yaml.load with the full Loader can instantiate
# arbitrary Python objects from tagged nodes, which a flat config of numbers
# does not need.
with open(fpath_config, 'r') as f:
    config = yaml.safe_load(f)

# Display the parsed configuration as the cell output.
config

{'lstm_input_dim': 1,
 'lstm_hidden_dim': 32,
 'lstm_num_layers': 5,
 'gnn_hidden_dim': 128,
 'gnn_out_dim': 256,
 'fc_hidden_dim': 256,
 'dropout': 0.7,
 'gat_heads': 4,
 'num_classes': 2,
 'batch_size_train': 64,
 'batch_size_val': 64,
 'batch_size_test': 64,
 'num_epochs': 50,
 'patience': 10,
 'min_delta': 1e-05,
 'lr': 0.01}

In [4]:
# Convert the training frame into three parallel outputs: the feature tensors,
# their sequence lengths and the target labels (presumably one entry per
# patient, going by the helper's name -- confirm in src.dataset).
# Features here are the processed continuous covariates plus the static columns.
patient_tensors, patient_seq_lens, target_labels = df_to_patient_tensors(
    df_train,
    feature_cols=CONTINUOUS_COVARIATES_PROCESSED + STATIC_COLS,
    target_col=TARGET_COL,
)

# Pad the variable-length sequences to a common length and stack them with the
# batch dimension first. Note that `patient_tensors` is rebound from the
# helper's output to a single padded tensor.
patient_tensors = pad_sequence(patient_tensors, batch_first=True)


In [5]:
# Initial adjacency matrix handed to GraphClassificationDataset in the next
# cell: built by dag_fully_connected for 16 nodes, without self-loops.
# NOTE(review): 16 is hard-coded here and again as num_nodes when the
# GraphClassification model is created -- presumably both must agree; confirm.
init_mat_adj = dag_fully_connected(num_nodes=16, add_self_loops=False)

In [6]:
# Wrap the padded tensors, sequence lengths and labels in a graph dataset,
# passing init_mat_adj as the initial adjacency matrix.
dataset = GraphClassificationDataset(patient_tensors, patient_seq_lens, target_labels, init_adj_mat=init_mat_adj)

In [7]:
# `DataLoader` here is the torch_geometric.loader.DataLoader imported at the
# top; batch_size=2 keeps the batch small enough to inspect by eye.
# NOTE(review): `loader` is rebound to a torch.utils.data.DataLoader in a
# later cell.
loader = DataLoader(dataset, batch_size=2)

In [8]:
# Pull the first mini-batch from the loader and display its summary.
# NOTE(review): `it` and `batch` are overwritten by the torch DataLoader cell
# further down, so cells that read `batch` depend on execution order.
it = iter(loader)
batch = next(it)
batch

DataBatch(x=[14, 24], edge_index=[2, 480], edge_attr=[480, 1], y=[2], seq_lens=[2], batch=[14], ptr=[3])

In [9]:
# Labels stored on the batch (one entry per graph in the batch).
batch.y

tensor([1, 1])

In [10]:
# Build the train / validation / test datasets in a single call.
# NOTE(review): feature_cols here is only CONTINUOUS_COVARIATES_PROCESSED,
# whereas the df_to_patient_tensors cell above also adds STATIC_COLS --
# confirm the difference is intentional.
dataset_train, dataset_val, dataset_test = build_classification_datasets(
    df_train=df_train,
    df_val=df_valid,
    df_test=df_test,
    feature_cols=CONTINUOUS_COVARIATES_PROCESSED,
    target_col=TARGET_COL,
)

In [11]:
# Plain torch DataLoader over dataset_train; each batch is unpacked into
# three items: inputs, labels and sequence lengths.
# NOTE(review): this rebinds `loader`, `it` and `batch` from the PyG cells
# above. From here on `batch` is whatever this loader yields, not a PyG
# batch object, so later cells that read `.x` / `.edge_index` from it depend
# on an earlier execution order -- confirm.
loader = torch.utils.data.DataLoader(dataset_train, batch_size=2)
it = iter(loader)
batch = next(it)
x, y, seq_lens = batch

In [12]:
# Display the raw input tensor of the sampled batch.
# NOTE(review): this prints the whole tensor; `x.shape` would be a more
# compact sanity check.
x

tensor([[[-0.8244, -0.6240, -0.1155,  0.7369,  0.9792,  0.8945,  1.0972,
           1.9676,  1.7440,  1.6664,  1.7042,  1.1466, -0.0492, -0.0633,
           0.5351,  0.5048,  0.4311,  0.3404,  1.5141,  0.8249,  0.1092,
          -0.8244,  0.0000,  0.0000],
         [-0.4491, -0.3683, -0.0417,  1.6283,  2.4366,  3.6446,  4.0667,
           3.4946,  2.7069,  2.1011,  0.9488,  0.3014,  0.2205,  0.0923,
           0.0569,  0.0569,  0.0417, -0.0595, -0.1607, -0.2619, -0.4491,
          -0.4491,  0.0000,  0.0000],
         [ 0.9747,  0.6709,  2.7215,  2.6962,  2.4177,  2.0886,  1.6329,
           3.5316,  0.4177,  0.1392,  0.4684, -0.2152, -1.0506, -0.7722,
          -0.2658,  0.0127, -0.2911, -0.5190, -0.7215, -0.3165,  1.7342,
           0.6962,  0.0000,  0.0000]],

        [[-0.8244, -0.8244, -0.8244, -0.6866, -0.5797, -0.4738, -0.3118,
          -0.2262,  0.3454,  0.3336,  0.2766,  0.2960, -0.3124, -0.3102,
          -0.2992, -0.2992,  3.3921, -0.5101,  0.1186,  2.6333,  0.1186,
        

In [13]:
# Instantiate both classifiers from the same config dictionary.
# NOTE(review): num_nodes=16 repeats the literal used for dag_fully_connected;
# presumably the two must match -- confirm.
model = GraphClassification(config, num_nodes=16, pretraining=False)
model_lstm = LSTMClassification(config)

In [14]:
# Forward pass of the LSTM classifier on the sampled batch, then display the
# raw outputs (recorded as a 2x2 tensor; presumably per-class logits, since
# they are fed to CrossEntropyLoss in the next cell).
out = model_lstm(x, seq_lens)
out

tensor([[-0.0621,  0.0656],
        [-0.0275,  0.0468]], grad_fn=<AddmmBackward0>)

In [15]:
# Cross-entropy between the model outputs and the batch labels.
# Note: `loss` names the criterion module itself; the computed value is only
# displayed, not stored.
loss = torch.nn.CrossEntropyLoss()
loss(out, y)

tensor(0.6440, grad_fn=<NllLossBackward0>)

In [17]:
# NOTE(review): `gnn_input` is not assigned in any visible cell, so on a fresh
# kernel this line raises NameError. The recorded output comes from an earlier
# session -- restore the defining cell or remove this one.
gnn_input[0].shape

torch.Size([2, 2])

In [16]:
# Max over dim 1 of the outputs: returns (values, indices), where `indices`
# is the arg-max column for each row (the predicted class per sample).
torch.max(out, 1)

torch.return_types.max(
values=tensor([0.0656, 0.0468], grad_fn=<MaxBackward0>),
indices=tensor([1, 1]))

In [17]:
# Call the model's GNN sub-module directly.
# NOTE(review): depends on `gnn_input` (never assigned in the visible cells)
# and on `data` (assigned only in a later cell), so it cannot run top-to-bottom.
# It also rebinds `out`, which previously held the LSTM classifier outputs.
out = model.gnn(gnn_input, 1, data.edge_index)

In [18]:
# Shape of the first element of the GNN result.
out[0].shape

torch.Size([1, 16])

In [20]:
# Pass the first element of the GNN result through the model's `fc` head.
model.fc(out[0])

tensor([[-0.0421,  0.0925]], grad_fn=<AddmmBackward0>)

In [22]:
# Unpack a batch into node features, edges and sequence lengths, then pack the
# sequences.
# NOTE(review): in top-to-bottom order `batch` was last assigned from the
# torch.utils.data.DataLoader cell and unpacked as three items, so the
# attribute reads below (`.x`, `.edge_index`, ...) would need the PyG batch
# from the earlier cell instead -- confirm.
data = batch

x = data.x
edge_index = data.edge_index
edge_attr = data.edge_attr
seq_lens = data.seq_lens

# Reinterpret the features as (2, -1, 16); both the 2 and the 16 are hard-coded.
x = x.view(2, -1, 16)
# NOTE(review): the integer index `3` drops the last dimension, which does not
# match the recorded `x.data.shape` of [46, 3] in the next cell; the saved
# output suggests a slice such as `:3` was run -- confirm which is intended.
x = pack_padded_sequence(x[:, :, 3], seq_lens, batch_first=True)

In [23]:
# Shape of the flattened `.data` tensor of the packed sequence.
x.data.shape

torch.Size([46, 3])

In [56]:
# Slice the leading rows of the node features.
# NOTE(review): this reads `data.seq_len`, while every other cell (and the
# DataBatch repr above) uses `seq_lens` -- confirm the attribute name.
s = data.x[:data.seq_len]
s.shape

torch.Size([22, 16])

In [51]:
# NOTE(review): this cell is saved with a RuntimeError -- `seq_lens` has 2
# entries but the input was seen as batch size 1. Presumably `s` holds a single
# sequence while `seq_lens` still covers the whole batch; fix or remove the
# cell before sharing.
s = pack_padded_sequence(s, seq_lens, batch_first=True)

RuntimeError: Expected `len(lengths)` to be equal to batch_size, but got 2 (batch_size=1)

In [37]:
# Run `lstm` on the packed input and inspect the final hidden state.
# NOTE(review): `lstm` is not assigned in any visible cell (presumably a
# torch.nn.LSTM from a deleted cell), so this fails on a fresh kernel.
out, (hn,cn) = lstm(x)
hn.shape

torch.Size([4, 2, 64])

In [39]:
# Same information as `hn.shape` in the previous cell.
hn.size()

torch.Size([4, 2, 64])

In [58]:
# Check whether `data` is a PyG mini-batch object. `torch_geometric` is
# imported in the top import cell rather than here, so all dependencies are
# visible in one place.
# The recorded result is False, which is consistent with `batch` having been
# rebound to the output of the plain torch DataLoader -- confirm.
isinstance(data, torch_geometric.data.batch.Batch)

False