In [None]:
import numpy as np
import torch
import gc
from torch.utils.data import DataLoader, random_split
import pandas as pd
from time import time
import os
import sys
import datetime
import pathlib
import logging
import argparse
import configparser
import ast
import shutil
from data.datasets import DarcyFlowDataset, KSDataset, ERA5Dataset, SSWEDataset
from models import FNO

In [7]:
from utils import train_utils

In [12]:
data_parameters = {
    'dataset_name': 'KS',
    'downscaling_factor': 1,
    'temporal_downscaling': 2,
    'pred_horizon': 5,
    't_start': 0,
    'init_steps': 5
}

training_parameters = {
    'batch_size': 32,
    'model': 'FNO',
    'uncertainty_quantification': 'dropout',
    'n_modes': (10, 12),
    'hidden_channels': 20,
    'dropout': 0.05,
    'lifting_channels': 128,
    'fourier_dropout': None,
    'projection_channels': 128
}

In [13]:
data_dir = f"data/{data_parameters["dataset_name"]}/processed/"

In [14]:
if data_parameters['dataset_name'] == 'DarcyFlow':
    train_data = DarcyFlowDataset(data_dir, test = False, downscaling_factor=int(data_parameters['downscaling_factor']))
    train_data_full_res = DarcyFlowDataset(data_dir, test = False)
    test_data = DarcyFlowDataset(data_dir, test = True)            
elif data_parameters["dataset_name"] == "KS":
    downscaling_factor = int(data_parameters['downscaling_factor'])
    temporal_downscaling_factor = int(data_parameters['temporal_downscaling'])
    pred_horizon = data_parameters['pred_horizon']
    t_start = data_parameters['t_start']
    init_steps = data_parameters['init_steps']

    assert 300 > temporal_downscaling_factor * (pred_horizon + t_start + init_steps)

    train_data = KSDataset(data_dir, test = False, downscaling_factor=downscaling_factor, mode = "autoregressive",
                pred_horizon=pred_horizon, t_start=t_start, init_steps=init_steps,
                temporal_downscaling_factor=temporal_downscaling_factor)
    test_data = KSDataset(data_dir, test = True, mode = "autoregressive",
                pred_horizon=pred_horizon, t_start=t_start, init_steps=init_steps,
                temporal_downscaling_factor=temporal_downscaling_factor)
    
elif data_parameters["dataset_name"] == "era5":
    data_dir = f"data/{data_parameters['dataset_name']}/"
    pred_horizon = data_parameters['pred_horizon']
    init_steps = data_parameters['init_steps']
    train_data = ERA5Dataset(data_dir, var = "train", init_steps = init_steps, prediction_steps = pred_horizon)
    val_data = ERA5Dataset(data_dir, var = "val", init_steps = init_steps, prediction_steps = pred_horizon)
    test_data = ERA5Dataset(data_dir, var = "test", init_steps = init_steps, prediction_steps = pred_horizon)

elif data_parameters["dataset_name"] == "SSWE":
    data_dir = f"data/{data_parameters['dataset_name']}/processed/"
    pred_horizon = data_parameters['pred_horizon']
    train_data = SSWEDataset(data_dir, test = False, pred_horizon = data_parameters["train_horizon"], return_all = True)
    test_data = SSWEDataset(data_dir, test = True, pred_horizon = pred_horizon, return_all = True)

In [15]:
if data_parameters["dataset_name"] != "SSWE":
    domain_range = train_data.get_domain_range()
else:
    # Requires Longitude and quadrature weights instead of domain range
    domain_range = (train_data.get_nlon(), train_data.get_train_weights(), test_data.get_nlon(), test_data.get_eval_weights())  
    
if data_parameters['dataset_name'] == 'DarcyFlow':
    # Validation data on full resolution
    train_data, _ = random_split(train_data, lengths = [0.8,0.2], generator = torch.Generator().manual_seed(42))
    _, val_data = random_split(train_data_full_res, lengths = [0.8,0.2], generator = torch.Generator().manual_seed(42))
elif data_parameters['dataset_name'] != 'ERA5':
    train_data, val_data = random_split(train_data, lengths = [0.8,0.2], generator = torch.Generator().manual_seed(42))
 

In [16]:
batch_size = training_parameters['batch_size']

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [17]:
in_channels = next(iter(train_loader))[0].shape[1]
out_channels = next(iter(train_loader))[1].shape[1]

model = train_utils.setup_model(training_parameters, device, in_channels, out_channels)

NameError: name 'device' is not defined