# Setup: Paths and Imports

In [5]:
# Filesystem locations used by the rest of the notebook (edit for your environment).

# Root of the local codebase; appended to sys.path so simple_runet can be imported.
path_to_project = "/scratch1/zhenq/2.SpatioTemporalSurrogate"

# Dataset directory handed to the Dataset constructor.
path_to_data = "/project2/jafarpou_227/Storage_Folder/Zhen/Data/CO2_Dataset/grid2D_512_128_gaussian"

# Checkpoint directory; created before training and passed to trainer.train().
path_to_model = "/scratch1/zhenq/2.SpatioTemporalSurrogate/checkpoint_case2D/runet_base_RMSE"

# YAML experiment configuration. The path is relative, so it is resolved against
# the current working directory of the kernel.
path_to_config = "config/case2_2D_runet.yaml"

In [6]:
# Load Packages
import yaml
import numpy as np
import sys
import os
import json
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
sys.path.append(path_to_project)

from sys import argv
from os.path import join
from torch.utils.data import DataLoader
from typing import (Callable, List, Optional, Sequence, Tuple, Union)
from simple_runet import RUNet
from simple_runet import DatasetCase2 as Dataset
from simple_runet import TrainerCase2 as Trainer
from simple_runet import get_multifield_loss, MULTIFIELD_LOSS_REGISTRY
from simple_runet import plot0

# Seed the random number generators up front so that model weight initialisation
# and the shuffled training DataLoader are repeatable after a kernel restart
# (some CUDA kernels may still be non-deterministic).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set up device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


# Configuration

In [3]:
# Parse the YAML experiment configuration (safe_load: plain data only).
with open(path_to_config, "r") as config_file:
    config = yaml.safe_load(config_file)

# Split the configuration into the three sections used by later cells.
train_config, model_config, dataset_config = (
    config[key] for key in ("train_config", "model_config", "dataset_config")
)

# Echo every section under a separator line so the run records its settings.
banner = "=============================================\n"
for section_name, section in (
    ("train_config", train_config),
    ("model_config", model_config),
    ("dataset_config", dataset_config),
):
    print(banner + section_name + ":\n", section)


train_config:
 {'learning_rate': 0.0005, 'num_epochs': 100, 'weight_decay': 0.0, 'batch_size': 4, 'verbose': 1, 'gradient_clip': True, 'gradient_clip_val': 40, 'step_size': 1000, 'gamma': 0.9}
model_config:
 {'filters': 16, 'units': [1, 1, 2], 'kernel_size': [5, 5, 1], 'padding': [2, 2, 0], 'with_control': True, 'with_states': True, 'norm_type': 'group', 'num_groups': 4, 'strides': [[2, 2, 1], [2, 2, 1]]}
dataset_config:
 {'pred_length': 8, 'year': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]}


# Load Data

In [4]:
# Load Data
# Hyper-parameters come from the sections already parsed out of the YAML config
# (for this config: batch_size 4, pred_length 8, year 0..15).
batch_size = train_config['batch_size']
pred_length = dataset_config['pred_length']
year = dataset_config['year']

# Convert each year index to a timestep index.
# NOTE(review): the factor 12 presumably means monthly snapshots — confirm.
timestep = [12 * yr for yr in year]

# Fixed sample-index split: 500 training / 50 validation / 50 testing.
training_index = list(range(500))
validate_index = list(range(500, 550))
testing_index  = list(range(550, 600))

# Build the three datasets first (same order as before) ...
training_set = Dataset(training_index, path_to_data, timestep, pred_length)
validate_set = Dataset(validate_index, path_to_data, timestep, pred_length)
testing_set = Dataset(testing_index, path_to_data, timestep, pred_length)

# ... then wrap them in loaders. Only the training loader shuffles and uses
# the configured batch size; validation and testing go one sample at a time.
train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(validate_set, batch_size=1, shuffle=False)
test_loader = DataLoader(testing_set, batch_size=1, shuffle=False)

# Report the shapes of the S, P and M tensors held by each split.
for split_set in (training_set, validate_set, testing_set):
    print(split_set.S.shape, split_set.P.shape, split_set.M.shape)


100%|██████████| 500/500 [00:55<00:00,  8.96it/s]
100%|██████████| 50/50 [00:06<00:00,  7.19it/s]
100%|██████████| 50/50 [00:06<00:00,  7.18it/s]


torch.Size([500, 16, 128, 512, 1]) torch.Size([500, 16, 128, 512, 1]) torch.Size([500, 1, 128, 512, 1])
torch.Size([50, 16, 128, 512, 1]) torch.Size([50, 16, 128, 512, 1]) torch.Size([50, 1, 128, 512, 1])
torch.Size([50, 16, 128, 512, 1]) torch.Size([50, 16, 128, 512, 1]) torch.Size([50, 1, 128, 512, 1])


# Build Model

In [7]:
# Build Model
# Instantiate the network from the YAML model section and move it to the chosen device.
model = RUNet(**model_config).to(device)

# Count only the parameters that require gradients.
trainable_params = 0
for param in model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
print(f'Number of total trainable parameters: {trainable_params/1e6} M')


Number of total trainable parameters: 3.002194 M


# Build Trainer

In [8]:
# Store the regularizer weight on the shared train_config dict.
# NOTE(review): presumably read by Trainer when a regularizer is supplied — confirm.
train_config['regularizer_weight'] = 0.001

# Pixel-wise relative-MSE loss, reduced over dims 1-4.
pixel_loss = get_multifield_loss(
    'pixel',
    loss_type='rel_mse',
    mode='both',
    reduce_dims=[1, 2, 3, 4],
)

# Sobel-gradient relative-L1 loss. It is built here but NOT passed to the
# trainer below, so the current run uses the pixel loss only.
regularizer = get_multifield_loss(
    'gradient',
    filter_type='sobel',
    loss_type='rel_l1',
    mode='both',
    reduce_dims=[1, 2, 3, 4],
)

trainer_kwargs = {
    'model': model,
    'train_config': train_config,
    'pixel_loss': pixel_loss,
    # 'regularizer': regularizer,  # uncomment to pass the gradient regularizer
    'device': device,
}
trainer = Trainer(**trainer_kwargs)


# Start Training

In [None]:
# Training Loop
# Create the checkpoint directory if needed. exist_ok=True replaces the separate
# existence check, which avoids the check-then-create race and keeps this cell
# safe to re-run.
os.makedirs(path_to_model, exist_ok=True)
print(path_to_model)

# Train using the training and validation loaders. The checkpoint directory is
# handed to the trainer (presumably where it writes checkpoints — confirm).
loss_tracker_dict = trainer.train(train_loader, valid_loader, path_to_model)


/scratch1/zhenq/2.SpatioTemporalSurrogate/checkpoint_case2D/runet_base_RMSE


Training:  25%|██▌       | 251/1000 [02:52<08:32,  1.46it/s, loss=0.00918, loss_pixel=0.00918, loss_auxillary=0]

# Testing

In [16]:
# Run the trainer's test pass on the test loader; the returned mapping is read
# through its 'losses' and 'tensors' entries.
test_results = trainer.test(test_loader)
test_losses, test_tensors = test_results['losses'], test_results['tensors']

# Summarise every returned tensor as a (name, shape) pair.
tensor_shapes = []
for tensor_name, tensor in test_tensors.items():
    tensor_shapes.append((tensor_name, tensor.shape))
print(tensor_shapes)


Testing: 100%|██████████| 400/400 [02:15<00:00,  2.95it/s, loss=0.0007, loss_pixel=0.0007, loss_auxillary=0.0000]


[('preds', torch.Size([400, 8, 2, 128, 512, 1])), ('outputs', torch.Size([400, 8, 2, 128, 512, 1])), ('states', torch.Size([400, 2, 128, 512, 1])), ('static', torch.Size([400, 1, 128, 512, 1]))]


In [18]:
# 'outputs' is used as the reference and 'preds' as the model prediction.
trues = test_tensors['outputs']
preds = test_tensors['preds']
print(trues.shape, preds.shape)

# Axis 2 selects the field: index 1 feeds the s_* tensors, index 0 the p_* tensors.
# NOTE(review): presumably s = saturation and p = pressure — confirm against the dataset.
s_trues = trues[:, :, 1]
s_preds = preds[:, :, 1]
print(s_trues.shape, s_preds.shape)

p_trues = trues[:, :, 0]
p_preds = preds[:, :, 0]
print(p_trues.shape, p_preds.shape)


torch.Size([400, 8, 2, 128, 512, 1]) torch.Size([400, 8, 2, 128, 512, 1])
torch.Size([400, 8, 128, 512, 1]) torch.Size([400, 8, 128, 512, 1])
torch.Size([400, 8, 128, 512, 1]) torch.Size([400, 8, 128, 512, 1])


In [None]:
# Which entry to visualise: `index` and `layer` are forwarded positionally to plot0.
figsize = (20, 2)
index = 0
layer = 0

# Display options shared by both figures: diverging colormap with the error range set to +/-0.1.
plot_kwargs = dict(
    figsize=figsize,
    error_cmap='seismic',
    error_vmin=-0.1,
    error_vmax=0.1,
)

# p_* field first, then s_* field (kept as two statements so cell output is unchanged).
plot0(p_trues, p_preds, index, layer, **plot_kwargs)
plot0(s_trues, s_preds, index, layer, **plot_kwargs)
