In [5]:

# %% Imports

from ast import arg
from pytorch_lightning.loggers import WandbLogger

import time

from model import deep_GNN
import torch
from torch import device, nn
import torch
import os
from utils.dataset import H5GeometricDataset
from utils.eval import evaluate
from utils.train import train_one_epoch
from utils.config import *
import wandb
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
from pytorch_lightning.trainer import Trainer
from model import LitModel
import numpy as np
from torch.utils.data import ConcatDataset
from pytorch_lightning.callbacks import ModelCheckpoint



In [7]:
# DATA LOADING
# Normalisation statistics. From each stored array keep entry 0 of the leading
# axis and the first N_VAR entries of the next axis.
TIME_MEANS = np.load(TIME_MEANS_PATH)[0, :N_VAR]
MEANS = np.load(GLOBAL_MEANS_PATH)[0, :N_VAR]
STDS = np.load(GLOBAL_STDS_PATH)[0, :N_VAR]

# Standardised time means, restricted to the first HEIGHT entries of axis 1,
# with a leading singleton axis added in front.
M = torch.as_tensor((TIME_MEANS - MEANS) / STDS)[None, :, :HEIGHT]
# Copy of the global standard deviations, also with a leading singleton axis.
STD = torch.tensor(STDS)[None]

# "train" concatenates one H5 dataset per entry of YEARS (file "<year>.h5"
# inside DATA_FILE_PATH); "val" is the single file VAL_FILE. Both receive the
# same means/stds.
datasets = dict(
    train=ConcatDataset([
        H5GeometricDataset(
            os.path.join(DATA_FILE_PATH, f"{year}.h5"),
            means=MEANS,
            stds=STDS,
        )
        for year in YEARS
    ]),
    val=H5GeometricDataset(VAL_FILE, means=MEANS, stds=STDS),
)



In [8]:
from utils.config import BATCH_SIZE, BATCH_SIZE_VAL, LEARNING_RATE, MODEL_CONFIG
from utils.eval import weighted_rmse_channels
from torch.optim import Adam
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch import nn
from torch.utils.data import DistributedSampler
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader


class GraphCON(nn.Module):
    """Stack of GNN layers coupled through a two-state (X, Y) update.

    For every layer ``gnn`` in ``GNNs`` the forward pass applies

        Y <- Y + dt * (relu(gnn(X, edge_index)) - alpha * Y - gamma * X)
        X <- X + dt * Y

    optionally followed by dropout on both states.

    Args:
        GNNs: iterable of layers, each called as ``gnn(X, edge_index)``.
            A plain ``list``/``tuple`` made only of ``nn.Module`` objects is
            wrapped in ``nn.ModuleList`` so the layers are registered as
            sub-modules; any other value is stored unchanged.
        dt: step size multiplying both updates.
        alpha: coefficient of the ``Y`` term in the ``Y`` update.
        gamma: coefficient of the ``X`` term in the ``Y`` update.
        dropout: dropout probability applied to ``X`` and ``Y`` after every
            layer, or ``None`` to disable dropout.
    """

    def __init__(self, GNNs, dt=1., alpha=1., gamma=1., dropout=None):
        super().__init__()
        self.dt = dt
        self.alpha = alpha
        self.gamma = gamma
        # nn.Module only registers sub-modules that are themselves Modules.
        # A bare Python list would hide the layers' parameters from
        # .parameters(), .to(), state_dict() and therefore from the optimizer.
        if isinstance(GNNs, (list, tuple)) and all(
            isinstance(layer, nn.Module) for layer in GNNs
        ):
            GNNs = nn.ModuleList(GNNs)
        self.GNNs = GNNs  # the individual GNN layers, applied in order
        self.dropout = dropout

    def forward(self, X0, Y0, edge_index):
        """Run the coupled update once per layer.

        Args:
            X0: initial ``X`` state.
            Y0: initial ``Y`` state.
            edge_index: passed unchanged to every layer.

        Returns:
            Tuple ``(X, Y)`` holding the two states after the last layer.
        """
        # set initial values of ODEs
        X, Y = X0, Y0

        # solve ODEs using simple IMEX scheme
        for gnn in self.GNNs:
            # Y is updated first; the X update then uses the new Y.
            Y = Y + self.dt * (torch.relu(gnn(X, edge_index)) -
                               self.alpha * Y - self.gamma * X)
            X = X + self.dt * Y

            if self.dropout is not None:
                Y = F.dropout(Y, self.dropout, training=self.training)
                X = F.dropout(X, self.dropout, training=self.training)

        return X, Y


class deep_GNN(nn.Module):
    """Linear encoder -> GraphCON stack of GCNConv layers -> linear decoder.

    Args:
        nfeat: input size of the encoder ``nn.Linear``.
        nhid: hidden size used by the encoder output, every ``GCNConv`` layer
            and the decoder input.
        nclass: output size of the decoder ``nn.Linear``.
        nlayers: number of ``GCNConv(nhid, nhid)`` layers.
        dt, alpha, gamma, dropout: forwarded unchanged to ``GraphCON``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers, dt=1., alpha=1., gamma=1., dropout=None):
        super().__init__()
        self.enc = nn.Linear(nfeat, nhid)
        # Layers are created one after another, in order, right after the encoder.
        self.GNNs = nn.ModuleList(GCNConv(nhid, nhid) for _ in range(nlayers))
        # GraphCON receives the very same ModuleList object held in self.GNNs.
        self.graphcon = GraphCON(self.GNNs, dt, alpha, gamma, dropout)
        self.dec = nn.Linear(nhid, nclass)

    def forward(self, x0, edge_index):
        """Encode ``x0``, run the GraphCON stack, and decode the final X state.

        Args:
            x0: input passed to the encoder.
            edge_index: forwarded to ``GraphCON`` (and from there to each layer).

        Returns:
            Decoder output computed from the first value returned by
            ``GraphCON``; the second returned value is discarded.
        """
        # The encoded input is used as the initial value of BOTH states.
        hidden = self.enc(x0)
        final_x, _ = self.graphcon(hidden, hidden, edge_index)
        return self.dec(final_x)

In [9]:

# Build the network from the keyword arguments stored in MODEL_CONFIG.
# NOTE(review): `deep_GNN` is both imported from `model` and defined in this
# notebook; whichever was executed last is the one used here.
gnn_model = deep_GNN(**MODEL_CONFIG)

# Wrap it in the Lightning module together with the datasets and the std tensor.
model = LitModel(
    model=gnn_model,
    datasets=datasets,
    std=STD,
)

In [10]:
dataloaders = model.train_dataloader()

# %% Training
# Manual optimisation loop that drives the Lightning module directly, without
# a Trainer: one training_step, then zero_grad / backward / optimizer step per
# batch, using the optimizer exposed as `model.optimizer`.

for i, batch in enumerate(dataloaders):
    # sample
    print(f"Batch {i}")

    # forward pass: hand the real batch index to the step (it used to be a
    # constant 0 for every batch)
    loss = model.training_step(batch, i)
    print(f"Batch: {i} --- Loss: {loss.item():.2f}")

    # backward pass
    model.optimizer.zero_grad()
    loss.backward()
    model.optimizer.step()



    

Number of CPUs: 14
Memory CPU in GB: 36.0
Batch 0


/Users/magnus/Documents/eth/PMLR/venv/lib/python3.10/site-packages/pytorch_lightning/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(1328.1466, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 1
tensor(1329.7198, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 2
tensor(1330.9776, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 3
tensor(1327.4208, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 4
tensor(1327.8705, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 5
tensor(1328.2026, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 6
tensor(1329.8271, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 7
tensor(1327.5175, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 8
tensor(1332.5893, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 9
tensor(1336.6306, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 10
tensor(1341.4285, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 11
tensor(1343.7970, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 12
tensor(1350.7523, dtype=torch.float64, grad_fn=<MeanBackward0>)
Batch 13
tensor(1358.5043, dtype=torch.float64, grad_fn=<MeanBackward