## Libraries

In [None]:
import wandb
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
import time

In [None]:
from utils.visualization import *
from utils.dataset import *
from utils.miscellaneous import *
from utils.load import *
from utils.scaling import *
from models.gnn import *
from models.models import *
from training.train import *
from training.loss import *
from database.graph_creation import *

### Plot details

In [None]:
import matplotlib as mpl

# Global matplotlib defaults applied to every figure created afterwards.
mpl.rcParams.update({
    # Grid lines: black, dotted, thin.
    'grid.color': 'k',
    'grid.linestyle': ':',
    'grid.linewidth': 0.5,
    # Figure size and resolution (interactive display vs. saved files).
    'figure.figsize': [7, 5],
    'figure.dpi': 100,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    # Text sizes and font family.
    'font.size': 18,
    'legend.fontsize': 'small',
    'figure.titlesize': 'small',
    'font.family': 'serif',
})

# Figures and videos share the same output folder name.
figures_folder = 'results'
video_folder = figures_folder

## Dataset creation

In [None]:
# wandb.login()

# Read the experiment configuration from the YAML file named below
# (read_config is a project helper; its return type is not visible here).
cfg_file = "config_finetune.yaml"
config = read_config(cfg_file)

# Finish any wandb run that is still active before creating a new logger.
wandb.finish()
# mode='disabled' and config are extra keyword arguments of the logger;
# presumably they are forwarded to wandb.init — confirm.
wandb_logger = WandbLogger(log_model='all',
                           mode='disabled',
                           config=config,)

# From here on `config` is the wandb config object, not the value returned by
# read_config.
# NOTE(review): this relies on a wandb run having been initialised by the
# logger at this point — confirm for the installed Lightning/wandb versions.
config = wandb.config

In [None]:
# Reproducibility: request deterministic cuDNN kernels and keep the cuDNN
# auto-tuner off. With benchmark=True cuDNN may select a different algorithm
# from run to run, which defeats the deterministic flag and the seed below.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision('high')

# Item access, consistent with how `config` is read in the rest of this file.
L.seed_everything(config['models']['seed'])

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset_parameters = config['dataset_parameters']
# Copy the scalers entry so the call below does not receive the config's own
# object (`copy` comes from the project star-imports; presumably a shallow
# copy — confirm).
scalers = copy(config['scalers'])
selected_node_features = config['selected_node_features']
selected_edge_features = config['selected_edge_features']

# Build the train/validation/test datasets. `scalers` is rebound to the value
# returned by create_model_dataset.
train_dataset, val_dataset, test_dataset, scalers = create_model_dataset(
    scalers=scalers, device=device,
    **dataset_parameters, **selected_node_features, **selected_edge_features
)

In [None]:
temporal_dataset_parameters = config['temporal_dataset_parameters']

# Build the temporal training dataset from the training simulations.
temporal_train_dataset = to_temporal_dataset(
    train_dataset,
    **temporal_dataset_parameters)

# Short summary of the training data, one labelled line per quantity.
for _label, _value in (
    ('Number of training simulations:\t', len(train_dataset)),
    ('Number of training samples:\t', len(temporal_train_dataset)),
    ('Number of node features:\t', temporal_train_dataset[0].x.shape[-1]),
    ('Number of rollout steps:\t', temporal_train_dataset[0].y.shape[-1]),
):
    print(_label, _value)

# Model

## Creation

In [None]:
# Feature and graph sizes, read from the first temporal training sample:
# last dimension of x / edge_attr for the feature counts, first dimension
# for the node / edge counts.
num_node_features = temporal_train_dataset[0].x.size(-1)
num_edge_features = temporal_train_dataset[0].edge_attr.size(-1)
num_nodes = temporal_train_dataset[0].x.size(0)
num_edges = temporal_train_dataset[0].edge_attr.size(0)

# Settings taken from the dataset configuration.
temporal_res = dataset_parameters['temporal_res']
test_dataset_name = dataset_parameters['test_dataset_name']

# Settings taken from the temporal dataset configuration.
previous_t = temporal_dataset_parameters['previous_t']
time_start = temporal_dataset_parameters['time_start']
time_stop = temporal_dataset_parameters['time_stop']
max_rollout_steps = temporal_dataset_parameters['rollout_steps']

In [None]:
# Take a copy of the model settings before popping 'model_type' out of them
# (`copy` comes from the project star-imports; presumably shallow — confirm).
model_parameters = copy(config['models'])
model_type = model_parameters.pop('model_type')

# The 'MSGNN' model additionally receives the number of meshes of the first
# training sample as `num_scales`.
if model_type == 'MSGNN':
    num_scales = train_dataset[0].mesh.num_meshes
    model_parameters['num_scales'] = num_scales

# Look up the model class by name, instantiate it, then move it to the device.
model = get_model(model_type)(
    num_node_features=num_node_features,
    num_edge_features=num_edge_features,
    previous_t=previous_t,
    device=device,
    **model_parameters)
model = model.to(device)

In [None]:
lr_info = config['lr_info']
trainer_options = copy(config['trainer_options'])

# info for testing dataset
temporal_test_dataset_parameters = get_temporal_test_dataset_parameters(
    config, temporal_dataset_parameters)

# Validation uses the test-time temporal settings with rollout_steps=-1.
temporal_val_dataset = to_temporal_dataset(
    val_dataset,
    rollout_steps=-1,
    **temporal_test_dataset_parameters)

# Wrap the model in the Lightning module and move it to the device.
plmodule = LightningTrainer(
    model, lr_info, trainer_options, temporal_test_dataset_parameters)
plmodule = plmodule.to(device)

# Data module serving the training and validation temporal datasets.
pldatamodule = DataModule(
    temporal_train_dataset,
    temporal_val_dataset,
    batch_size=trainer_options['batch_size'])

# Sum the element counts of all parameter tensors of the model.
print("Total number of paramters:", sum(map(torch.numel, model.parameters())))

## Training

In [None]:
# Define callbacks
# Keep only the single best checkpoint, ranked by the highest "val_CSI_005";
# the commented line is the loss-based alternative.
checkpoint_callback = ModelCheckpoint(dirpath='lightning_logs/models',
                                      # monitor="val_loss", mode='min',
                                      monitor="val_CSI_005", mode='max',
                                      save_top_k=1)
curriculum_callback = CurriculumLearning(max_rollout_steps, patience=5)
# Early stopping monitors the same metric as the checkpoint callback.
early_stopping = EarlyStopping('val_CSI_005', mode='max',
                               patience=trainer_options['patience'])
# Created but not registered: it is commented out in the callback list below.
batch_size_finder = CurriculumBatchSizeFinder(max_rollout_steps, init_val=4,
                                              steps_per_trial=1, max_trials=3)
wandb_logger.watch(model, log="all", log_graph=False)

# Load trained model
# Constructor arguments handed to the Lightning module when it is rebuilt
# from a checkpoint.
plmodule_kwargs = {'model': model,
                   'lr_info': lr_info,
                   'trainer_options': trainer_options,
                   'temporal_test_dataset_parameters': temporal_test_dataset_parameters}

if 'saved_model' in config:
    # load_from_checkpoint is a classmethod that returns a new module, so it
    # is called on the class and its result replaces `plmodule`. Calling it on
    # the existing instance and discarding the result leaves the restored
    # module unused (and may be rejected outright by Lightning 2.x).
    plmodule = LightningTrainer.load_from_checkpoint(
        config['saved_model'], map_location=device, **plmodule_kwargs)
    model = plmodule.model.to(device)

# Define trainer
trainer = L.Trainer(accelerator="auto", devices='auto',
                    max_epochs=trainer_options['max_epochs'],
                    gradient_clip_val=1,
                    # log_every_n_steps=50,
                    # enable_progress_bar=False,
                    # accumulate_grad_batches=4,
                    # profiler="simple",
                    precision='16-mixed',
                    logger=wandb_logger,
                    callbacks=[checkpoint_callback,
                               curriculum_callback,
                               early_stopping,
                               # batch_size_finder
                               ])

In [None]:
# Train and get trained model
# trainer.fit(plmodule, pldatamodule)

# Load the best model checkpoint
# plmodule.model = plmodule.load_from_checkpoint(checkpoint_callback.best_model_path, map_location=device, **plmodule_kwargs).model.to(device)

# model = plmodule.model.to(device)

## Testing

In [None]:
# Number of test simulations, and the size of the second dimension of the
# first simulation's WD attribute.
test_size = len(test_dataset)
maximum_time = test_dataset[0].WD.shape[1]

# Reference times for the '<test_dataset_name>_test' simulations, looked up
# with the overview CSV (presumably the numerical solver's run times — confirm).
numerical_times = get_numerical_times(
    test_dataset_name + '_test',
    test_size,
    temporal_res,
    maximum_time,
    **temporal_test_dataset_parameters,
    overview_file='database/overview.csv',
)

In [None]:
# Move every test simulation to the compute device before building the
# temporal test dataset.
test_dataset = [data.to(device) for data in test_dataset]
temporal_test_dataset = to_temporal_dataset(
    test_dataset, rollout_steps=-1, **temporal_test_dataset_parameters)

# One sample per batch, in dataset order.
test_dataloader = DataLoader(temporal_test_dataset, batch_size=1, shuffle=False)

# Time the whole prediction pass with perf_counter: it is monotonic and
# high-resolution, whereas time.time() can jump when the system clock is
# adjusted and would then corrupt the measured duration.
start_time = time.perf_counter()
predicted_rollout = trainer.predict(plmodule, dataloaders=test_dataloader)
# Average elapsed time per temporal test sample.
prediction_times = (time.perf_counter() - start_time) / len(temporal_test_dataset)

# Flatten the per-batch outputs into one flat list.
predicted_rollout = [item for roll in predicted_rollout for item in roll]

In [None]:
# Analysis object built from the predictions, their timing and the test data.
spatial_analyser = SpatialAnalysis(
    predicted_rollout,
    prediction_times,
    test_dataset,
    **temporal_test_dataset_parameters)

rollout_loss = spatial_analyser._get_rollout_loss(type_loss='MAE', only_where_water=False)
model_times = spatial_analyser.prediction_times

# Speed-up statistics computed from the numerical times and the model times.
avg_speedup, std_speedup = get_speed_up(numerical_times, model_times)
print(f'mean speed-up: {avg_speedup:.2f}\nstd speed-up: {std_speedup:.3f}')

# CSI at the two water thresholds.
for _label, _threshold in (('CSI 0.05m: ', 0.05), ('CSI 0.3m: ', 0.3)):
    print(_label, spatial_analyser._get_CSI(water_threshold=_threshold))

# Average the loss over its first dimension, then report entry 0 and the rest.
_mean_loss = rollout_loss.mean(0)
print('water depth error: ', _mean_loss[0].item())
# NOTE(review): .item() requires a one-element tensor, so this assumes exactly
# one entry after the first — confirm against the model outputs.
print('discharge error: ', _mean_loss[1:].item())

# Plots

## Exploratory analysis (single simulation)
Rank the simulations in the dataset by loss to find the best and the worst ones.

Then you can plot a summary of any single simulation.

In [None]:
# Plot the RMSE of each simulation, ranked by loss, and keep the returned value
# (presumably the simulation ids in ranked order — confirm).
sorted_ids = spatial_analyser.plot_loss_per_simulation(
    type_loss='RMSE',
    ranking='loss',
    only_where_water=False,
    water_thresholds=[0.05, 0.3],
)

### Summary

In [None]:
# Index of the test simulation inspected in the cells below.
id_dataset = 5

# # rotate sample to check invariance
# angle = -135
# test_dataset[id_dataset] = rotate_data_sample(test_dataset[id_dataset], angle, 
#                                               selected_node_features, selected_edge_features)

# Plotter for the selected simulation; model and sample are moved to the device.
rollout_plotter = PlotRollout(
    model.to(device),
    test_dataset[id_dataset].to(device),
    scalers=scalers,
    type_loss='RMSE',
    **temporal_test_dataset_parameters)

# The trailing semicolon keeps the notebook from echoing the return value.
rollout_plotter.plot_BC();

In [None]:
# Explore the rollout of the selected simulation at time_step=-1 (presumably
# the last time step — confirm) on scale 0, with a logarithmic scale.
# The commented line below is the multiscale alternative.
fig = rollout_plotter.explore_rollout(time_step=-1, scale=0, logscale=True)
# fig = rollout_plotter.explore_multiscale_rollout(time_step=-1, variable='V', logscale=True)

### Plot WD and V for a single simulation

In [None]:
# Time values at which the rollout is compared (presumably time-step indices —
# TODO confirm the unit).
plot_times = [11, 23, 35]

# Compare the h rollout at those times on scale 0; the commented line below is
# the equivalent comparison for v.
rollout_plotter.compare_h_rollout(plot_times, scale=0)
# rollout_plotter.compare_v_rollout(plot_times, scale=0, logscale=True)

### Compare flood arrival times (FAT)

In [None]:
# Compare flood arrival times on scale 0 using a water threshold of 0.05.
rollout_plotter.compare_FAT(water_threshold=0.05, scale=0)

### Video

In [None]:
# rollout_plotter.mesh_scale_plot(scale=0)
# rollout_plotter.create_video(logscale=True)
# rollout_plotter.save_video(f'results/SWEGNN_test_{id_dataset:02d}', fps=7)

### Boundary conditions

In [None]:
# Plot the boundary conditions through the analysis object; the trailing
# semicolon keeps the notebook from echoing the return value.
spatial_analyser._plot_BCs();

### Spatial (F1 and CSI) and regression metrics for full dataset

In [None]:
mpl.rcParams.update({'font.size': 18})

# Two side-by-side panels: CSI on the left, rollout loss on the right.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))

_, CSI = spatial_analyser.plot_CSI_rollouts(
    water_thresholds=[0.05, 0.3],
    ax=axs[0])
# NaN-aware mean over axis 1, followed by a plain mean over the new axis 1.
print(np.nanmean(CSI, axis=1).mean(1))

# _, F1 = spatial_analyser.plot_F1_rollouts(water_thresholds=[0.05, 0.3], ax=axs[0])
# print(np.nanmean(F1, 1).mean(1))

# _ = spatial_analyser._plot_rollouts(type_loss='RMSE', ax=axs[1])

_ = spatial_analyser._plot_rollouts(type_loss='MAE', ax=axs[1])

# No grid on the CSI panel.
axs[0].grid(False)

plt.tight_layout()