In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import torch
import torch_geometric
from torch_geometric.loader import DataLoader
from torch_geometric.nn import summary
import yaml

import Dataset
import Models
from utils import time_func

In [2]:
# Summarise the runtime environment (library versions and CUDA status) so the
# notebook output records what it was executed with.
env_report = [
    f"Torch version: {torch.__version__}",
    f"Cuda available: {torch.cuda.is_available()}",
]
if torch.cuda.is_available():
    # The device name can only be queried when a CUDA device is present.
    env_report.append(f"Cuda device: {torch.cuda.get_device_name()}")
env_report.append(f"Cuda version: {torch.version.cuda}")
env_report.append(f"Torch geometric version: {torch_geometric.__version__}")
print("\n".join(env_report))

Torch version: 2.0.1+cu117
Cuda available: False
Cuda version: 11.7
Torch geometric version: 2.3.1


In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cpu')

In [4]:
# Load the pipeline configuration. The context manager closes the file handle
# as soon as the YAML has been parsed (a bare open() would leave it dangling).
with open('./config/pipeline.yaml') as config_file:
    params = yaml.safe_load(config_file)

# Dataset root and mesh/grid file handed to the dataset constructor.
DATA_PATH = params['input_subset_pre_processed']
MESH_PATH = params['input_subset_grid']

# Train/validation/test proportions (checked further down to sum to 100).
TRAIN_PROP = params['train_prop']
VAL_PROP = params['val_prop']
TEST_PROP = params['test_prop']

# Batch sizes of the three data loaders.
TRAIN_BATCH_SIZE = params['train_batch_size']
VAL_BATCH_SIZE = params['val_batch_size']
TEST_BATCH_SIZE = params['test_batch_size']

# Model sizes: input channels, hidden channels and output channels.
N_FEATURES = params['n_features']
HID_CHANNELS = params['hid_channels']
N_CLASSES = params['n_classes']

# Final activation; stays None for an unknown name so the validation cell can
# report the offending configuration value.
FINAL_ACT = None
if params['final_act'] == "sigmoid":
    FINAL_ACT = torch.sigmoid
elif params['final_act'] == "linear":
    # NOTE(review): this is a trainable 1 -> 1 Linear layer, not an identity
    # function -- confirm this is what the model expects for "linear".
    FINAL_ACT = torch.nn.Linear(1, 1)

# Loss criterion; None marks an unsupported name.
LOSS_OP = None
if params['loss_op'] == "CE":
    LOSS_OP = torch.nn.CrossEntropyLoss()

# Optimizer *class* (instantiated later, once the model exists); None marks an
# unsupported name.
OPTIMIZER = None
if params['optimizer'] == "Adam":
    OPTIMIZER = torch.optim.Adam

LEARN_RATE = params['learn_rate']

# TODO use these
PLOT_SHOW = params['plot_show']
PLOT_VERTICAL = params['plot_vertical']

# Start of the notebook-wide timer, stopped in the "Computation time check" cell.
TIMESTAMP = time_func.start_time()

### Dataset creation

In [5]:
# Build the train, validation and test datasets (in that order) from the same
# root and mesh file, and time the whole construction.
timestamp = time_func.start_time()
train_dataset, val_dataset, test_dataset = (
    Dataset.EddyDataset(root=DATA_PATH, mesh_path=MESH_PATH, split=split_name)
    for split_name in ('train', 'val', 'test')
)
time_func.stop_time(timestamp, "Datasets creation")

train:  256  val:  73  test:  36
    Shape of node feature matrix: torch.Size([757747, 1])
    Shape of graph connectivity in COO format: torch.Size([2, 4537526])
    Shape of labels: torch.Size([757747])
train:  256  val:  73  test:  36
train:  256  val:  73  test:  36
  ---  Datasets creation  ---  0.307 seconds.


In [6]:
# Number of graphs reported by each split, printed as "train val test".
print(*(split.len() for split in (train_dataset, val_dataset, test_dataset)))

256 73 36


In [7]:
# Display one validation sample (index 25 is an arbitrary pick) to inspect the
# shapes of its node features, edge index and labels.
val_dataset[25]

Data(x=[757747, 1], edge_index=[2, 4537526], y=[757747])

### Testing some parameters and orientation of graph edges

In [8]:
# Fail fast on an inconsistent configuration before any training starts.

# The proportions are expressed as percentages and must add up to 100.
prop_total = TRAIN_PROP + VAL_PROP + TEST_PROP
if prop_total != 100:
    raise ValueError(f"Sum of train-val-test proportions with value {prop_total} is different from 100")

# The configuration cell leaves these as None when the configured name is not
# supported; identity checks avoid relying on the objects' __eq__.
if FINAL_ACT is None:
    raise ValueError(f"Parameter 'final_act' is invalid with value {params['final_act']}")

if LOSS_OP is None:
    raise ValueError(f"Parameter 'loss_op' is invalid with value {params['loss_op']}")

if OPTIMIZER is None:
    raise ValueError(f"Parameter 'optimizer' is invalid with value {params['optimizer']}")

# Use the first training graph as a reference sample for structural checks.
dummy_graph = train_dataset[0]

if dummy_graph.num_features != N_FEATURES:
    raise ValueError(f"Graph num_features is different from parameter N_FEATURES: ({dummy_graph.num_features} != {N_FEATURES})")

# The pipeline expects undirected graphs.
if dummy_graph.is_directed():
    raise ValueError("Graph edges are directed!")

### Train-validation-test data loaders

In [9]:
# Wrap each split in a DataLoader. Only the training loader shuffles; the
# validation and test loaders keep the dataset order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=TRAIN_BATCH_SIZE)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=VAL_BATCH_SIZE)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=TEST_BATCH_SIZE)

# Number of graphs behind each loader, printed as "train val test".
print(*(len(loader.dataset) for loader in (train_loader, val_loader, test_loader)))

256 73 36


### Model instantiation

In [10]:
# Architecture used for this run.
Model = Models.GUNet

# Constructor arguments come from the configuration, except the node count,
# which is read from the reference graph.
model_kwargs = dict(
    in_channels=N_FEATURES,
    hidden_channels=HID_CHANNELS,
    out_channels=N_CLASSES,
    num_nodes=dummy_graph.num_nodes,   # TODO can put these in Dataset.py
    final_act=FINAL_ACT,
)

# Build the network and move it to the selected device.
model = Model(**model_kwargs).to(DEVICE)

model

GUNet instantiated!
	Middle act: relu
	Final act: torch


GUNet(
  (unet): GraphUNet(1, 32, 3, depth=3, pool_ratios=[0.002639403389257892, 0.5, 0.5])
)

In [11]:
# Without CUDA, print a layer-by-layer summary of the model evaluated on the
# reference graph; with CUDA, print the allocator's memory summary instead.
if not torch.cuda.is_available():
    print(summary(model, dummy_graph))
else:
    print(torch.cuda.memory_summary())

  C = torch.sparse.mm(A, B)


+-------------------------------+---------------------------------------------------+---------------------------------------------------------+----------+
| Layer                         | Input Shape                                       | Output Shape                                            | #Param   |
|-------------------------------+---------------------------------------------------+---------------------------------------------------------+----------|
| GUNet                         | [757747, 757747]                                  | [757747, 3]                                             | 5,539    |
| ├─(unet)GraphUNet             | [757747, 1], [2, 4537526]                         | [757747, 3]                                             | 5,539    |
| │    └─(down_convs)ModuleList | --                                                | --                                                      | 3,232    |
| │    │    └─(0)GCNConv        | [757747, 1], [2, 4537526], [4537526]

### Optimizer

In [12]:
# OPTIMIZER starts out as the optimizer *class* chosen in the configuration
# cell and is rebound here to an instance. Recover the class from an existing
# instance so that re-running this cell (e.g. after re-creating the model)
# builds a fresh optimizer instead of failing by calling an instance.
optimizer_cls = OPTIMIZER if isinstance(OPTIMIZER, type) else type(OPTIMIZER)
OPTIMIZER = optimizer_cls(model.parameters(), lr=LEARN_RATE)

### Train function

In [13]:
def train(verbose=False):
    """Run one training epoch over ``train_loader`` and return the average loss.

    Relies on the notebook-level ``model``, ``train_loader``, ``OPTIMIZER``,
    ``LOSS_OP`` and ``DEVICE``.

    Args:
        verbose: When True, print the running weighted loss total and the
            current batch loss after every batch. Off by default so a full
            epoch does not flood the cell output.

    Returns:
        The sum of the batch losses, each weighted by the number of graphs in
        its batch, divided by the number of graphs in the training dataset.
    """
    model.train()
    total_loss = 0

    for batch in train_loader:
        batch = batch.to(DEVICE)

        # Clear the gradients accumulated by the previous step.
        OPTIMIZER.zero_grad()

        # Forward pass and loss.
        pred = model(batch)
        loss = LOSS_OP(pred, batch.y)

        # NOTE: an earlier Soft Dice experiment re-wrapped the loss with
        # requires_grad=True at this point. Re-creating the loss as a new
        # tensor disconnects it from the model parameters, which matches the
        # constant loss observed then; keep the loss differentiable instead.

        # Backward pass and parameter update.
        loss.backward()
        OPTIMIZER.step()

        # Weight the batch loss by its number of graphs so that a smaller
        # final batch does not skew the epoch average.
        batch_loss = loss.item()
        total_loss += batch_loss * batch.num_graphs
        if verbose:
            print(total_loss, batch_loss)

    # Average loss = weighted total / number of training graphs.
    average_loss = total_loss / len(train_loader.dataset)
    return average_loss

#loss = train()
#print("Train loss, debug: ", loss)

In [14]:
# Disabled sanity check, kept as a string literal so that nothing executes:
# the enclosed code would count, for each loader, the node labels greater
# than 2. The cell only echoes the string back as its output.
'''
train_count = 0
for batch in train_loader:
    for i in range(batch.y.shape[0]):
        if batch.y[i] > 2:
            train_count += 1
print(train_count)

val_count = 0
for batch in val_loader:
    for i in range(batch.y.shape[0]):
        if batch.y[i] > 2:
            val_count += 1
print(val_count)

test_count = 0
for batch in test_loader:
    for i in range(batch.y.shape[0]):
        if batch.y[i] > 2:
            test_count += 1
print(test_count)
'''

'\ntrain_count = 0\nfor batch in train_loader:\n    for i in range(batch.y.shape[0]):\n        if batch.y[i] > 2:\n            train_count += 1\nprint(train_count)\n\nval_count = 0\nfor batch in val_loader:\n    for i in range(batch.y.shape[0]):\n        if batch.y[i] > 2:\n            val_count += 1\nprint(val_count)\n\ntest_count = 0\nfor batch in test_loader:\n    for i in range(batch.y.shape[0]):\n        if batch.y[i] > 2:\n            test_count += 1\nprint(test_count)\n'

In [15]:
# Spot-check the labels of two individual nodes in the first test batch.
batch = next(iter(test_loader))
for node_idx in (1818, 4681):
    print(batch.y[node_idx])

tensor(1)
tensor(0)


### Evaluation function

In [16]:
@torch.no_grad()
def evaluate(loader):
    """Return the average loss of ``model`` over every graph served by ``loader``.

    Gradient tracking is disabled by the decorator and the model is switched
    to evaluation mode. Each batch loss is weighted by the number of graphs in
    the batch, and the weighted sum is divided by the size of the loader's
    dataset. Uses the notebook-level ``model``, ``LOSS_OP`` and ``DEVICE``.
    """
    model.eval()
    weighted_losses = []

    for graphs in loader:
        graphs = graphs.to(DEVICE)

        # Forward pass and loss for this batch.
        batch_loss = LOSS_OP(model(graphs), graphs.y)
        weighted_losses.append(batch_loss.item() * graphs.num_graphs)

    return sum(weighted_losses) / len(loader.dataset)

### Computation time check

In [18]:
# Report the time elapsed since TIMESTAMP was taken in the configuration cell,
# i.e. the cost of everything that runs before training.
time_func.stop_time(TIMESTAMP, "Computation before training finished!")

  ---  Computation before training finished!  ---  4.403 seconds.
