In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import sys
import torch
import torch_geometric
from torch_geometric.loader import DataLoader
from torch_geometric.nn import summary
import xarray as xr
import yaml

import Dataset
import Models
import Loss
from utils import time_func

In [2]:
print(f"Torch version: {torch.__version__}")
print(f"Cuda available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Cuda device: {torch.cuda.get_device_name()}")
print(f"Cuda version: {torch.version.cuda}")
print(f"Torch geometric version: {torch_geometric.__version__}")

Torch version: 2.1.1+cu121
Cuda available: True
Cuda device: NVIDIA A100-SXM4-40GB
Cuda version: 12.1
Torch geometric version: 2.4.0


In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
params = yaml.safe_load(open('./config/pipeline.yaml'))

DATA_PATH = params['input_subset_pre_processed']
MESH_PATH = params['input_subset_grid']

DATASET_SIZE = params['dataset_size']

TRAIN_PROP = params['train_prop']
VAL_PROP = params['val_prop']
TEST_PROP = params['test_prop']
TRAIN_VAL_TEST = [TRAIN_PROP, VAL_PROP, TEST_PROP]

TRAIN_BATCH_SIZE = params['train_batch_size']
VAL_BATCH_SIZE = params['val_batch_size']
TEST_BATCH_SIZE = params['test_batch_size']

N_FEATURES = params['n_features']
HID_CHANNELS = params['hid_channels']
N_CLASSES = params['n_classes']
N_LAYERS = params['n_layers']

FINAL_ACT = None
if params['final_act'] == "sigmoid":
    FINAL_ACT = torch.sigmoid
elif params['final_act'] == "softmax":
    FINAL_ACT = torch.softmax
elif params['final_act'] == "linear":
    FINAL_ACT = torch.nn.Linear(1, 1)
class_weights = [params['loss_weight_1'], params['loss_weight_2'], params['loss_weight_3']]
LOSS_OP = None
if params['loss_op'] == "CE":
    LOSS_OP = torch.nn.CrossEntropyLoss()
elif params['loss_op'] == "WCE":
    LOSS_OP = Loss.WeightedCrossEntropyLoss(class_weights, DEVICE)
elif params['loss_op'] == "Focal":
    LOSS_OP = Loss.FocalLoss()
elif params['loss_op'] == "Dice":
    LOSS_OP = Loss.SoftDiceLoss(class_weights)
elif params['loss_op'] == "Tversky":
    LOSS_OP = Loss.TverskyLoss(alpha=0.3, beta=0.7, smooth=1.0, class_weights=class_weights)
elif params['loss_op'] == "TverskyDice":
    LOSS_OP = Loss.TverskyDiceLoss(alpha=0.3, beta=0.7, smooth=1.0, class_weights=class_weights)
    
OPTIMIZER = None
if params['optimizer'] == "Adam":
    OPTIMIZER = torch.optim.Adam

LEARN_RATE = params['learn_rate']

EPOCHS = params['epochs']

PLOT_SHOW = params['plot_show']
PLOT_FOLDER = params['output_images_path']

TIMESTAMP = time_func.start_time()

### Dataset creation

In [5]:
random_seed = random.randint(1, 10000)
print(f"Random seed for train-val-test split: {random_seed}")

timestamp = time_func.start_time()

train_dataset = Dataset.EddyDataset(root=DATA_PATH, mesh_path=MESH_PATH, dataset_size=DATASET_SIZE, split='train', proportions=TRAIN_VAL_TEST, random_seed=random_seed)
val_dataset = Dataset.EddyDataset(root=DATA_PATH, mesh_path=MESH_PATH, dataset_size=DATASET_SIZE, split='val', proportions=TRAIN_VAL_TEST, random_seed=random_seed)
test_dataset = Dataset.EddyDataset(root=DATA_PATH, mesh_path=MESH_PATH, dataset_size=DATASET_SIZE, split='test', proportions=TRAIN_VAL_TEST, random_seed=random_seed)

time_func.stop_time(timestamp, "Datasets creation")

Random seed for train-val-test split: 5743
    Shape of node feature matrix: torch.Size([239536, 1])
    Shape of graph connectivity in COO format: torch.Size([2, 1432160])
    Shape of labels: torch.Size([239536])
  ---  Datasets creation  ---  8.460 seconds.


In [6]:
print(train_dataset.len(), val_dataset.len(), test_dataset.len())

292 36 37


In [7]:
'''
train_dataset[0]

timestamp = time_func.start_time()
features_list = [data.x for data in train_dataset]
print(np.shape(features_list))
time_func.stop_time(timestamp, "features in a list!")

timestamp = time_func.start_time()
all_features = torch.cat(features_list, dim=0)
time_func.stop_time(timestamp, "features concatenated!")

all_features.shape
'''

'\ntrain_dataset[0]\n\ntimestamp = time_func.start_time()\nfeatures_list = [data.x for data in train_dataset]\nprint(np.shape(features_list))\ntime_func.stop_time(timestamp, "features in a list!")\n\ntimestamp = time_func.start_time()\nall_features = torch.cat(features_list, dim=0)\ntime_func.stop_time(timestamp, "features concatenated!")\n\nall_features.shape\n'

In [8]:
'''
global_mean = all_features.mean(dim=0)
global_std = all_features.std(dim=0)
print(f"Mean: {global_mean}\nStd: {global_std}")
'''

'\nglobal_mean = all_features.mean(dim=0)\nglobal_std = all_features.std(dim=0)\nprint(f"Mean: {global_mean}\nStd: {global_std}")\n'

In [9]:
'''
from torch_geometric.transforms import NormalizeFeatures
transform = NormalizeFeatures()

print(train_dataset[0].x)

timestamp = time_func.start_time()

train_dataset = [transform(data) for data in train_dataset]
val_dataset = [transform(data) for data in val_dataset]
test_dataset = [transform(data) for data in test_dataset]

time_func.stop_time(timestamp, "features normalized!")
'''

'\nfrom torch_geometric.transforms import NormalizeFeatures\ntransform = NormalizeFeatures()\n\nprint(train_dataset[0].x)\n\ntimestamp = time_func.start_time()\n\ntrain_dataset = [transform(data) for data in train_dataset]\nval_dataset = [transform(data) for data in val_dataset]\ntest_dataset = [transform(data) for data in test_dataset]\n\ntime_func.stop_time(timestamp, "features normalized!")\n'

In [10]:
'''
features = [data.x for data in train_dataset]
all_features = torch.cat(features, dim=0)
mean = all_features.mean(dim=0)
std = all_features.std(dim=0)
print(f"Mean: {mean}\nStd: {std}")
'''

'\nfeatures = [data.x for data in train_dataset]\nall_features = torch.cat(features, dim=0)\nmean = all_features.mean(dim=0)\nstd = all_features.std(dim=0)\nprint(f"Mean: {mean}\nStd: {std}")\n'

### Testing some parameters and orientation of graph edges

In [11]:
if (TRAIN_PROP+VAL_PROP+TEST_PROP) != 100:
    raise ValueError(f"Sum of train-val-test proportions with value {TRAIN_PROP+VAL_PROP+TEST_PROP} is different from 100")

if FINAL_ACT == None:
    raise ValueError(f"Parameter 'final_act' is invalid with value {params['final_act']}")

if LOSS_OP == None:
    if params['loss_op'] != "Dice":
        raise ValueError(f"Parameter 'loss_op' is invalid with value {params['loss_op']}")

if OPTIMIZER == None:
    raise ValueError(f"Parameter 'optimizer' is invalid with value {params['optimizer']}")

dummy_graph = train_dataset[0]

if dummy_graph.num_features != N_FEATURES:
    raise ValueError(f"Graph num_features is different from parameter N_FEATURES: ({dummy_graph.num_features} != {N_FEATURES})")

if dummy_graph.is_directed():
    raise ValueError("Graph edges are directed!")

### Train-validation-test split

In [12]:
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=6, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=6, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=6, pin_memory=True)

print(len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset))

292 36 37


### Model instantiation

In [13]:
Model = Models.GCNModel

model = Model(
    num_features = N_FEATURES,
    hidden_dim = HID_CHANNELS,
    num_classes = N_CLASSES,
    num_layers = N_LAYERS,
    num_nodes = dummy_graph.num_nodes,   # TODO can put these in Dataset.py
    final_act = FINAL_ACT
).to(DEVICE)

model

GCNModel(
  (conv_layers): ModuleList(
    (0): GCNConv(1, 32)
    (1-3): 3 x GCNConv(32, 32)
    (4): GCNConv(32, 3)
  )
)

In [14]:
dummy_graph.to(DEVICE)
print(summary(model, dummy_graph))

+---------------------------+----------------------------+----------------+----------+
| Layer                     | Input Shape                | Output Shape   | #Param   |
|---------------------------+----------------------------+----------------+----------|
| GCNModel                  | [239536, 239536]           | [239536, 3]    | 3,331    |
| ├─(conv_layers)ModuleList | --                         | --             | 3,331    |
| │    └─(0)GCNConv         | [239536, 1], [2, 1432160]  | [239536, 32]   | 64       |
| │    └─(1)GCNConv         | [239536, 32], [2, 1432160] | [239536, 32]   | 1,056    |
| │    └─(2)GCNConv         | [239536, 32], [2, 1432160] | [239536, 32]   | 1,056    |
| │    └─(3)GCNConv         | [239536, 32], [2, 1432160] | [239536, 32]   | 1,056    |
| │    └─(4)GCNConv         | [239536, 32], [2, 1432160] | [239536, 3]    | 99       |
+---------------------------+----------------------------+----------------+----------+


### Optimizer

In [15]:
OPTIMIZER = OPTIMIZER(model.parameters(), lr=LEARN_RATE)

### Dice Loss

In [16]:
if params['loss_op'] == "Dice":
    
    timestamp = time_func.start_time()

    tot_counts = [0, 0, 0]
    for batch in train_loader:
        batch = batch.to(DEVICE)
        
        unique, counts = torch.unique(batch.y, return_counts=True)
        
        # TODO - I don't really like this, it just informs me whether something is wrong and then does it anyway
        if 0 not in unique:
            print("Error: class 0 not present in batch")
        elif 1 not in unique:
            print("Error: class 1 not present in batch")
        elif 2 not in unique:
            print("Error: class 2 not present in batch")
        else:
            for class_idx in unique:
                tot_counts[class_idx] += counts[class_idx].item()

    time_func.stop_time(timestamp, "Unique counted!")
    
    freq = [c/np.sum(tot_counts) for c in tot_counts]
    freq_inv = [1/f for f in freq]
    class_weights = [f/np.sum(freq_inv) for f in freq_inv]
    print(freq_inv, "- freq_inv")
    print(class_weights, "- class_weights")
    LOSS_OP = Loss.SoftDiceLoss(class_weights)

  ---  Unique counted!  ---  13.499 seconds.
[1.1468968602862453, 14.861141429462894, 16.449415899966535] - freq_inv
[0.03533539178958915, 0.45786528242784147, 0.5067993257825694] - class_weights


### Train function

In [17]:
def train():
    model.train()
    total_loss = 0

    for batch in train_loader:
        batch = batch.to(DEVICE)

        # zero the parameter gradients
        OPTIMIZER.zero_grad()

        # forward + loss
        pred = model(batch)
        loss = LOSS_OP(pred, batch.y)
        
        # If you try the Soft Dice Score, use this(even if the loss stays constant)
        #loss.requires_grad = True
        #loss = torch.tensor(loss.item(), requires_grad=True)

        total_loss += loss.item() * batch.num_graphs
        
        # backward + optimize
        loss.backward()
        OPTIMIZER.step()

    average_loss = total_loss / len(train_loader.dataset)
    return average_loss

### Evaluation function

In [18]:
@torch.no_grad()
def evaluate(loader):
    model.eval()
    total_loss = 0

    for batch in loader:
        batch = batch.to(DEVICE)

        # forward + loss
        pred = model(batch)
        loss = LOSS_OP(pred, batch.y)

        total_loss += loss.item() * batch.num_graphs
    
    average_loss = total_loss / len(loader.dataset)
    return average_loss

### Computation time check

In [19]:
time_func.stop_time(TIMESTAMP, "Computation before training finished!")

  ---  Computation before training finished!  ---  23.455 seconds.


In [20]:
'''
from time import time
import multiprocessing as mp

for num_workers in range(2, mp.cpu_count(), 2):
    train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True)

    start = time()
    for epoch in range(1, 3):
        for i, data in enumerate(train_loader, 0):
            pass
        for i, data in enumerate(val_loader, 0):
            pass
    end = time()
    print("Finish with: {} second, num_workers={}".format(end - start, num_workers))
'''

'\nfrom time import time\nimport multiprocessing as mp\n\nfor num_workers in range(2, mp.cpu_count(), 2):\n    train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=True)\n    val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True)\n\n    start = time()\n    for epoch in range(1, 3):\n        for i, data in enumerate(train_loader, 0):\n            pass\n        for i, data in enumerate(val_loader, 0):\n            pass\n    end = time()\n    print("Finish with: {} second, num_workers={}".format(end - start, num_workers))\n'

### Epoch training, validation and testing

In [21]:
timestamp = time_func.start_time()

train_loss = []
valid_loss = []

for epoch in range(EPOCHS):
    t_loss = train()
    v_loss = evaluate(val_loader)
    print(f'Epoch: {epoch+1:03d}, Train running loss: {t_loss:.4f}, Val running loss: {v_loss:.4f}')
    train_loss.append(t_loss)
    valid_loss.append(v_loss)

time_func.stop_time(timestamp, "Training Complete!")

metric = evaluate(test_loader)
print(f'Metric for test: {metric:.4f}')

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.64 GiB. GPU 0 has a total capacty of 39.39 GiB of which 2.54 GiB is free. Process 756407 has 28.11 GiB memory in use. Including non-PyTorch memory, this process has 8.67 GiB memory in use. Of the allocated memory 5.83 GiB is allocated by PyTorch, and 2.36 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

### Comparison plot for train/validation loss

In [None]:
plt.figure(figsize=(8, 4))
plt.plot(train_loss, label='Train loss')
plt.plot(valid_loss, label='Validation loss')
plt.legend(title="Loss type: " + params['loss_op'])

if PLOT_SHOW:
    plt.show()
else:
    plt.savefig(PLOT_FOLDER+"/train_val_losses_demo.png")
    plt.close()

### Graphical comparison model prediction/ground truth

In [None]:
timestamp = time_func.start_time()
DEVICE=torch.device('cpu')
model = model.to(DEVICE)

In [None]:
model.eval()
with torch.no_grad():
    batch = next(iter(test_loader))
    batch = batch.to(DEVICE)
    pred = model(batch)

In [None]:
mesh = xr.open_dataset(MESH_PATH)
mesh_lon = mesh.lon[mesh.nodes].values
mesh_lat = mesh.lat[mesh.nodes].values

In [None]:
this_target = batch.y[:mesh.dims['nodes_subset']]
_, this_pred = torch.max(pred[:mesh.dims['nodes_subset']], dim=1)

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(12, 12))

im = axes[0].scatter(mesh_lon, mesh_lat, c=this_target, s=1)
im2 = axes[1].scatter(mesh_lon, mesh_lat, c=this_pred, s=1)

if PLOT_SHOW:
    plt.show()
else:
    plt.savefig(PLOT_FOLDER + "/pred_vs_ground_demo.png")
    plt.close()

time_func.stop_time(timestamp, "pred_vs_ground plot created!")

### Accuracy calculation

In [None]:
# Running it on cuda is a huge improvement
DEVICE=torch.device('cuda')
model = model.to(DEVICE)

In [None]:
timestamp = time_func.start_time()

model.eval()
with torch.no_grad():
    tot_background = 0
    correct_pred = 0
    tot_pred = len(test_loader.dataset)*dummy_graph.num_nodes

    for batch in test_loader:
        batch = batch.to(DEVICE)

        pred = model(batch)

        _, indices = torch.max(pred, dim=1)

        tot_background += (batch.y == 0).sum().item()

        # This works because the values in the indices correspond to the values in batch.y
        correct_pred += (indices == batch.y).sum().item()

    print(f"Total background cells:\t{tot_background}")
    print(f"Correct predictions:\t{correct_pred}")
    print(f"Total predictions:\t{tot_pred}")
    print(f"GCN accuracy:\t{correct_pred/tot_pred*100:.2f}%")

time_func.stop_time(timestamp, "Accuracy calculated!")

In [None]:
# Save the model's state dictionary to a file
#PATH = "gcn_modelone.pth"
#torch.save(model.state_dict(), PATH)