# GCA-ROM

This notebook trains the [GCA-ROM](https://github.com/federicopichi/gca-rom) model and evaluates its performance.

In [None]:
import torch
import torch_geometric
import numpy as np
from itertools import product

import sys
sys.path.append('../gca-rom')
from gca_rom import network, pde, loader, plotting, preprocessing, training, initialization, testing, error, gui

sys.path.append('..')
from gfn_rom import defaults

In [None]:
problem_name, variable, mu_space, n_param, dim_pde, n_comp = pde.problem(2)

# Positional arguments for network.HyperParams. The problem name and variable
# are taken from pde.problem(2) rather than re-typed as literals: the same
# `variable` is used to load the data set below, so HyperParams must agree.
# NOTE(review): the meaning of the bare integers (4, 3, 1, 2) is fixed by the
# argument order network.HyperParams expects -- confirm there before changing.
argv = [
    problem_name,
    variable,
    4,
    3,
    1,
    defaults.rate,
    defaults.latent_size,
    defaults.mapper_sizes[0],
    int(n_param * defaults.N_basis_factor),
    defaults.mapper_weight,
    2,
    n_param,
    defaults.epochs,
    n_comp,
]
HyperParams = network.HyperParams(argv)

# Use the shared project defaults for the optimiser settings.
HyperParams.learning_rate = defaults.lr
HyperParams.weight_decay = defaults.lambda_
# Seed active while the data are loaded/split; presumably this governs the
# train/test split done in preprocessing.graphs_dataset -- TODO confirm.
HyperParams.seed = defaults.split_seed

# Batch size for the data loaders. np.inf is passed straight to gca_rom;
# presumably it means "whole data set in one batch" -- TODO confirm.
# Reduce it if you run out of memory.
HyperParams.batch_size = np.inf

# Disable early stopping (tolerance 0) and learning-rate decay
# (no milestones, gamma 1 -> constant learning rate).
HyperParams.tolerance = 0
HyperParams.miles = []
HyperParams.gamma = 1

In [None]:
# Pick the compute device, then (presumably) seed the RNGs using
# HyperParams.seed -- which at this point is defaults.split_seed -- before the
# data are loaded. Call order is kept as-is.
device = initialization.set_device()
initialization.set_reproducibility(HyperParams)
# NOTE(review): set_path presumably prepares the output locations referenced
# later (e.g. HyperParams.net_dir for the checkpoint) -- confirm in gca_rom.
initialization.set_path(HyperParams)

In [None]:
# Load the .mat data file and let gca_rom build the graph loaders, scalers and
# train/test trajectory indices from it.
dataset_dir = 'data/matrix_large.mat'
dataset = loader.LoadDataset(dataset_dir, variable, dim_pde, n_comp)

(graph_loader, train_loader, test_loader, val_loader,
 scaler_all, scaler_test, xyz, VAR_all, VAR_test,
 train_trajectories, test_trajectories) = preprocessing.graphs_dataset(dataset, HyperParams)

# Cartesian product of the parameter grids in mu_space: one row per combination.
params = torch.tensor(np.array([*product(*mu_space)])).to(device)

In [None]:
# Switch from the split seed to the model seed before the network is created,
# so weight initialisation is reproducible independently of the data split.
HyperParams.seed = defaults.seed
initialization.set_reproducibility(HyperParams)

model = network.Net(HyperParams).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=HyperParams.learning_rate,
    weight_decay=HyperParams.weight_decay,
)
# With the settings chosen earlier (no milestones, gamma 1) this keeps the
# learning rate constant; it is still passed to training.train.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=HyperParams.miles,
    gamma=HyperParams.gamma,
)

In [None]:
# Reuse a previously saved network if its checkpoint exists; otherwise train.
net_path = HyperParams.net_dir + HyperParams.net_name + HyperParams.net_run + '.pt'
try:
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU can also be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(net_path, map_location=device))
except FileNotFoundError:
    print('Training network')
    training.train(model, optimizer, device, scheduler, params, train_loader, test_loader, train_trajectories, test_trajectories, HyperParams)
else:
    print('Loading saved network')

In [None]:
# Move the model and parameters to the CPU and run inference for every
# parameter sample (one index per row of `params`).
model.cpu()
params = params.cpu()
vars = "GCA-ROM"  # label passed to the error print/save helpers
all_samples = range(len(params))
results, latents_map, latents_gca = testing.evaluate(VAR_all, model, graph_loader, params, HyperParams, all_samples)

In [None]:
# testing.evaluate was run on every parameter sample and returned `results`
# (the name `results_test` was never defined). Keep only the test-split rows.
# NOTE(review): assumes `results` is indexed by sample along dim 0 and that
# VAR_test holds exactly the samples listed in test_trajectories -- confirm
# against gca_rom.preprocessing. Also confirm that inverse-scaling with
# scaler_test (rather than scaler_all) is intended for these predictions.
results_test = results[test_trajectories]
error_abs, norm = error.compute_error(results_test, VAR_test, scaler_test, HyperParams)
error.print_error(error_abs, norm, vars)
error.save_error(error_abs, norm, HyperParams, vars)