In [4]:

# %% Imports

from ast import arg
from pytorch_lightning.loggers import WandbLogger

import time

from model import deep_GNN
import torch
from torch import device, nn
import torch
import os
from utils.dataset import H5GeometricDataset
from utils.eval import evaluate
from utils.train import train_one_epoch
from utils.configs.config import *
import wandb
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
from pytorch_lightning.trainer import Trainer
from model import LitModel
import numpy as np
from torch.utils.data import ConcatDataset
from pytorch_lightning.callbacks import ModelCheckpoint

from utils.configs.config_dict import get_config





Number of CPUs: 14
Memory CPU in GB: 36.0


In [5]:
# DATA LOADING
# Read the saved statistics arrays, keeping row 0 and the first N_VAR entries
# along the second axis (presumably one entry per variable -- TODO confirm
# against the layout of the .npy files).
TIME_MEANS = np.load(TIME_MEANS_PATH)[0, :N_VAR]
MEANS = np.load(GLOBAL_MEANS_PATH)[0, :N_VAR]
STDS = np.load(GLOBAL_STDS_PATH)[0, :N_VAR]

# Time means standardised with the global mean/std, cut to the first HEIGHT
# entries along axis 1, with a leading singleton axis added.
M = torch.as_tensor((TIME_MEANS - MEANS) / STDS)[:, :HEIGHT].unsqueeze(0)
# Global standard deviations as a tensor with a leading singleton axis.
STD = torch.tensor(STDS).unsqueeze(0)

datasets = {
    # One H5GeometricDataset per entry of YEARS ("<DATA_FILE_PATH>/<year>.h5"),
    # concatenated into a single training dataset.
    "train": ConcatDataset(
        [
            H5GeometricDataset(
                os.path.join(DATA_FILE_PATH, f"{year}.h5"),
                means=MEANS,
                stds=STDS,
            )
            for year in YEARS
        ]
    ),
    # Validation data comes from the single file VAL_FILE, normalised with the
    # same statistics.
    "val": H5GeometricDataset(VAL_FILE, means=MEANS, stds=STDS),
}



In [6]:
from utils.configs.config import BATCH_SIZE, BATCH_SIZE_VAL, LEARNING_RATE, MODEL_CONFIG
from utils.eval import weighted_rmse_channels
from torch.optim import Adam
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch import nn
from torch.utils.data import DistributedSampler
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader


class GraphCON(nn.Module):
    """Stack of GNN layers coupled through the GraphCON update rule.

    Two states, ``X`` and ``Y``, are carried through the stack. For every
    layer ``gnn`` they are updated as::

        Y <- Y + dt * (relu(gnn(X, edge_index)) - alpha * Y - gamma * X)
        X <- X + dt * Y

    optionally followed by dropout on both states.

    Args:
        GNNs: The layers to iterate over, each callable as
            ``gnn(X, edge_index)``. An ``nn.Module`` container (e.g.
            ``nn.ModuleList``) is stored as-is; any other iterable of modules
            (e.g. a plain ``list``) is wrapped in an ``nn.ModuleList`` so the
            layers' parameters are registered on this module.
        dt: Step size multiplying both state updates.
        alpha: Coefficient of the ``-alpha * Y`` term.
        gamma: Coefficient of the ``-gamma * X`` term.
        dropout: Dropout probability applied to both states after every
            layer, or ``None`` to disable dropout.
    """

    def __init__(self, GNNs, dt=1., alpha=1., gamma=1., dropout=None):
        super().__init__()
        self.dt = dt
        self.alpha = alpha
        self.gamma = gamma
        # A plain Python list stored on a Module is NOT registered: its layers
        # would be missing from .parameters(), .to(device) and state_dict().
        # Wrap anything that is not already a Module container.
        if not isinstance(GNNs, nn.Module):
            GNNs = nn.ModuleList(GNNs)
        self.GNNs = GNNs  # the individual GNN layers, applied in order
        self.dropout = dropout

    def forward(self, X0, Y0, edge_index):
        """Run both states through every layer.

        Args:
            X0: Initial ``X`` state.
            Y0: Initial ``Y`` state.
            edge_index: Passed unchanged as second argument to every layer.

        Returns:
            Tuple ``(X, Y)`` holding the two states after the last layer.
        """
        # solve ODEs using simple IMEX scheme
        for gnn in self.GNNs:
            # Y is updated first from the *old* X; X then uses the *new* Y.
            Y0 = Y0 + self.dt * (torch.relu(gnn(X0, edge_index)) -
                                 self.alpha * Y0 - self.gamma * X0)
            X0 = X0 + self.dt * Y0

            if self.dropout is not None:
                # Only active while the module is in training mode.
                Y0 = F.dropout(Y0, self.dropout, training=self.training)
                X0 = F.dropout(X0, self.dropout, training=self.training)

        return X0, Y0


class deep_GNN(nn.Module):
    """Linear encoder -> GraphCON-coupled GCN stack -> linear decoder.

    Args:
        nfeat: Size of the input features fed to the encoder.
        nhid: Hidden size used by the encoder output and every GCN layer.
        nclass: Size of the decoder output.
        nlayers: Number of ``GCNConv`` layers in the stack.
        dt, alpha, gamma, dropout: Forwarded unchanged to ``GraphCON``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers, dt=1., alpha=1., gamma=1., dropout=None):
        super().__init__()
        self.enc = nn.Linear(nfeat, nhid)
        # nlayers hidden-to-hidden graph convolutions, built in order.
        self.GNNs = nn.ModuleList([GCNConv(nhid, nhid) for _ in range(nlayers)])
        # The coupling module iterates over the very same layer container.
        self.graphcon = GraphCON(self.GNNs, dt, alpha, gamma, dropout)
        self.dec = nn.Linear(nhid, nclass)

    def forward(self, x0, edge_index):
        """Encode ``x0``, run the GraphCON stack, and decode the final X state."""
        # The encoded input serves as the initial value of BOTH ODE states.
        encoded = self.enc(x0)
        # Only the X state at the final step is used; the Y state is dropped.
        final_x, _ = self.graphcon(encoded, encoded, edge_index)
        return self.dec(final_x)

In [7]:

# Build the network from the keyword arguments stored in MODEL_CONFIG, then hand
# it to LitModel together with the datasets and the std tensor
# (LitModel is imported from `model`; presumably a LightningModule -- confirm).
gnn_model = deep_GNN(**MODEL_CONFIG)
model = LitModel(
    datasets=datasets,
    std=STD,
    model=gnn_model,
)

In [8]:
dataloaders = model.train_dataloader()

# %% Training
# Manual smoke-test loop: `training_step` is driven by hand with a plain Adam
# optimizer instead of going through a Lightning `Trainer`.
DEBUG_MAX_BATCHES = 6  # batches processed per epoch before stopping early

optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# Make the mode explicit so re-running this cell after an evaluation cell still
# trains with train-mode behaviour (e.g. dropout enabled).
model.train()
for epoch in range(EPOCHS):
    print(f"Epoch {epoch}")

    for i, batch in enumerate(dataloaders):
        # Pass the real batch index; a constant 0 would make every batch look
        # like the first one to any batch_idx-dependent logic in training_step.
        loss = model.training_step(batch, i)
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Batch: {i} --- Loss: {loss.item():.2f}")
        if i + 1 >= DEBUG_MAX_BATCHES:
            break


    

Epoch 0
Number of CPUs: 14
Memory CPU in GB: 36.0


/Users/magnus/Documents/eth/PMLR/venv/lib/python3.10/site-packages/pytorch_lightning/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


Batch: 0 --- Loss: 1348.65
Batch: 1 --- Loss: 1333.63
Batch: 2 --- Loss: 1321.51
Batch: 3 --- Loss: 1304.86
Batch: 4 --- Loss: 1293.21
Batch: 5 --- Loss: 1281.34
Batch: 6 --- Loss: 1273.12
Batch: 7 --- Loss: 1259.26
Batch: 8 --- Loss: 1251.77
Batch: 9 --- Loss: 1241.88
Batch: 10 --- Loss: 1234.55
Batch: 11 --- Loss: 1221.88
Batch: 12 --- Loss: 1211.87
Batch: 13 --- Loss: 1200.87
Batch: 14 --- Loss: 1191.17
Batch: 15 --- Loss: 1174.37
Batch: 16 --- Loss: 1158.46
Batch: 17 --- Loss: 1135.89
Batch: 18 --- Loss: 1117.50
Batch: 19 --- Loss: 1094.26
Batch: 20 --- Loss: 1078.08
Batch: 21 --- Loss: 1062.65
Batch: 22 --- Loss: 1052.87
Batch: 23 --- Loss: 1037.44
Batch: 24 --- Loss: 1020.20
Batch: 25 --- Loss: 1005.43
Batch: 26 --- Loss: 990.45
Batch: 27 --- Loss: 973.33
Batch: 28 --- Loss: 953.65
Batch: 29 --- Loss: 941.89
Batch: 30 --- Loss: 929.26
Batch: 31 --- Loss: 911.95
Batch: 32 --- Loss: 892.51
Batch: 33 --- Loss: 877.89
Batch: 34 --- Loss: 858.37
Batch: 35 --- Loss: 839.13
Batch: 36 --