# Setup: Paths and Imports

In [7]:
# Define Directories
# Edit these four locations to match your own environment before running.

# Parent directory of the 'simple_runet' codebase.
path_to_project = '/scratch1/zhenq/2.SpatioTemporalSurrogate'

# Root directory that holds the dataset folders.
path_to_data = (
    '/project2/jafarpou_227/Storage_Folder/Zhen/Data'
    '/CO2_Dataset/grid60_gaussian'
)

# Directory used for this run's checkpoint.
path_to_model = (
    '/scratch1/zhenq/2.SpatioTemporalSurrogate'
    '/checkpoint/runet_a_MSE_gradient'
)

# YAML configuration file; the path is relative, so it is resolved against
# the notebook's current working directory.
path_to_config = 'config/beginer_runet_a.yaml'

In [8]:
# Load Packages
import numpy as np
import sys
import os
import json
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
sys.path.append(path_to_project)

from sys import argv
from os.path import join
from torch.utils.data import DataLoader
from typing import (Callable, List, Optional, Sequence, Tuple, Union)
from simple_runet import RUNet #, RUNetParallel
from simple_runet import memory_usage_psutil
from simple_runet import DatasetCase1 as Dataset
from simple_runet import Trainer_RUNET as Trainer
from simple_runet import get_multifield_loss, MULTIFIELD_LOSS_REGISTRY

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)


cuda


# Configuration

In [5]:
import yaml

# Parse the YAML experiment configuration into plain Python objects.
with open(path_to_config, "r") as config_file:
    config = yaml.safe_load(config_file)

# Pull out the three top-level sections used by the rest of the notebook.
train_config, model_config, dataset_config = (
    config[section_name]
    for section_name in ("train_config", "model_config", "dataset_config")
)

# Echo every section under a separator line so the settings used for this
# run are recorded in the cell output.
separator = "============================================="
for section_name, section in (
    ("train_config", train_config),
    ("model_config", model_config),
    ("dataset_config", dataset_config),
):
    print(f"{separator}\n{section_name}:\n", section)


train_config:
 {'learning_rate': 0.001, 'num_epochs': 200, 'weight_decay': 0.0, 'batch_size': 2, 'verbose': 1, 'gradient_clip': False, 'gradient_clip_val': None, 'step_size': 400, 'gamma': 0.975}
model_config:
 {'filters': 16, 'units': [1, 1, 2], 'norm_type': 'group', 'num_groups': 4, 'strides': [2, 2], 'with_control': False}
dataset_config:
 {'num_years': 6, 'interval': 4, 'trainingset_folders': ['twowell_tworange_g20_z2', 'twowell_tworange_g60_z2', 'twowell_tworange_g100_z2', 'twowell_tworange_g20_z5', 'twowell_tworange_g60_z5', 'twowell_tworange_g100_z5'], 'validateset_folders': ['twowell_tworange_g20_z5']}


# Load Data

In [6]:
# Folder lists come straight from the dataset section of the configuration.
trainingset_folders = dataset_config['trainingset_folders']
validateset_folders = dataset_config['validateset_folders']
# The test set reuses the validation folders.
testingset_folders = validateset_folders

# Keep every held-out folder out of the training set so validation/test data
# never leaks into training.
#
# The filter is membership-based instead of `list.remove` inside a bare
# `try/except`: with the old form, the first validation folder that was not in
# the training list raised ValueError, which aborted the whole loop and left
# any remaining validation folders in the training set without warning (and
# every other error was swallowed as well).
#
# The slice assignment updates the list in place, so
# dataset_config['trainingset_folders'] reflects the filtered list exactly as
# it did before, and re-running the cell is a harmless no-op.
held_out_folders = set(validateset_folders or [])
trainingset_folders[:] = [
    folder for folder in trainingset_folders if folder not in held_out_folders
]

# Keyword arguments shared by the training, validation and test datasets.
dataset_kwargs = dict(
    root_to_data=path_to_data,
    num_years=dataset_config['num_years'],
    interval=dataset_config['interval'],
)


def _show_shapes(dataset):
    """Print the shapes of the dataset's `s`, `p` and `m` tensors on one line."""
    print(*(getattr(dataset, field).shape for field in ('s', 'p', 'm')))


# Training split: all training folders, shuffled, batch size from the config.
training_set = Dataset(folders=trainingset_folders, **dataset_kwargs)
train_loader = DataLoader(
    training_set,
    batch_size=train_config['batch_size'],
    shuffle=True,
)
_show_shapes(training_set)

# Validation split: split_index 0-49 of the validation folders, one sample
# per batch, in fixed order.
validate_set = Dataset(
    folders=validateset_folders, split_index=range(50), **dataset_kwargs
)
valid_loader = DataLoader(validate_set, batch_size=1, shuffle=False)
_show_shapes(validate_set)

# Test split: split_index 50-99 of the testing folders, one sample per batch,
# in fixed order.
test_set = Dataset(
    folders=testingset_folders, split_index=range(50, 100), **dataset_kwargs
)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
_show_shapes(test_set)

# Bundle the loaders in (train, valid, test) order.
data_loaders = (train_loader, valid_loader, test_loader)


torch.Size([499, 7, 64, 64, 20]) torch.Size([499, 7, 64, 64, 20]) torch.Size([499, 1, 64, 64, 20])
torch.Size([50, 7, 64, 64, 20]) torch.Size([50, 7, 64, 64, 20]) torch.Size([50, 1, 64, 64, 20])
torch.Size([50, 7, 64, 64, 20]) torch.Size([50, 7, 64, 64, 20]) torch.Size([50, 1, 64, 64, 20])


# Build Model

In [9]:
# Build Model
# Instantiate the network from the model section of the config and move it to
# the selected device.
model = RUNet(**model_config).to(device)

# Count only the parameters that will receive gradient updates.
trainable_params = 0
for param in model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()

# Reported in millions of parameters.
print(f'Number of total trainable parameters: {trainable_params/1e6} M')


Number of total trainable parameters: 3.034674 M


# Build Trainer

In [10]:
# Weight of the regularizer term; stored in train_config, which is handed to
# the Trainer below (presumably where the weight is read — confirm in Trainer).
train_config.update(regularizer_weight=0.001)

# Gradient-based regularizer built through the multifield-loss factory.
regularizer_options = dict(
    filter_type='sobel',
    loss_type='rel_l1',
    mode='both',
    reduce_dims=[1, 2, 3, 4],
)
regularizer = get_multifield_loss('gradient', **regularizer_options)

# The trainer combines an MSE pixel loss with the regularizer above.
trainer = Trainer(
    model=model,
    train_config=train_config,
    pixel_loss=nn.MSELoss(),
    regularizer=regularizer,
    device=device,
)


# Start Training

In [None]:
# Training Loop
# Make sure the checkpoint directory exists before training starts.
# exist_ok=True turns this into one call that is a no-op when the directory is
# already there, avoiding the check-then-create race of a separate
# os.path.exists() test.
os.makedirs(path_to_model, exist_ok=True)

# Run training on the train/validation loaders. path_to_model is passed to the
# trainer (presumably the location where checkpoints are written — confirm in
# Trainer.train). The call returns the training and validation losses.
train_loss, valid_loss = trainer.train(train_loader, valid_loader, path_to_model)


Training: 100%|██████████| 250/250 [00:47<00:00,  5.21it/s, loss=0.00681, loss_pixel=0.00561, loss_auxillary=1.19]
Validing: 100%|██████████| 50/50 [00:01<00:00, 35.80it/s, loss=0.00212, loss_pixel=0.00143, loss_auxillary=0.685]


Epoch 001: Train - loss: 0.0068 | loss_pixel: 0.0056 | loss_auxillary: 1.1937 | Valid - loss: 0.0021 | loss_pixel: 0.0014 | loss_auxillary: 0.6851


Training: 100%|██████████| 250/250 [00:46<00:00,  5.35it/s, loss=0.00258, loss_pixel=0.00191, loss_auxillary=0.668]
Validing: 100%|██████████| 50/50 [00:01<00:00, 35.92it/s, loss=0.00222, loss_pixel=0.00166, loss_auxillary=0.562]


Epoch 002: Train - loss: 0.0026 | loss_pixel: 0.0019 | loss_auxillary: 0.6682 | Valid - loss: 0.0022 | loss_pixel: 0.0017 | loss_auxillary: 0.5623


Training: 100%|██████████| 250/250 [00:46<00:00,  5.38it/s, loss=0.00211, loss_pixel=0.00153, loss_auxillary=0.581]
Validing: 100%|██████████| 50/50 [00:01<00:00, 31.30it/s, loss=0.00153, loss_pixel=0.001, loss_auxillary=0.524]   


Epoch 003: Train - loss: 0.0021 | loss_pixel: 0.0015 | loss_auxillary: 0.5805 | Valid - loss: 0.0015 | loss_pixel: 0.0010 | loss_auxillary: 0.5242


Training: 100%|██████████| 250/250 [00:46<00:00,  5.37it/s, loss=0.002, loss_pixel=0.00144, loss_auxillary=0.561]  
Validing: 100%|██████████| 50/50 [00:01<00:00, 37.06it/s, loss=0.00154, loss_pixel=0.00102, loss_auxillary=0.524] 


Epoch 004: Train - loss: 0.0020 | loss_pixel: 0.0014 | loss_auxillary: 0.5607 | Valid - loss: 0.0015 | loss_pixel: 0.0010 | loss_auxillary: 0.5244


Training:  50%|█████     | 125/250 [00:23<00:22,  5.48it/s, loss=0.00184, loss_pixel=0.00129, loss_auxillary=0.543]