# Notebook for loading the Hydra config and checking the dataloaders

### Uses the Hydra Compose API (`hydra.compose`)

In [None]:
# Load IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell executes, so edits to project code are picked up.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import os
import sys
from pprint import pprint
import numpy as np
import torch
import xarray as xr
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
import matplotlib.pyplot as plt

import hydra
from hydra import compose, initialize
from hydra.utils import instantiate, get_class
from omegaconf import OmegaConf

sys.path.append('..')
from main import FourDVarNetRunner
from hydra_main import FourDVarNetHydraRunner

## Choose xp

In [None]:
# Relative path to the Hydra configuration directory.
config_path = "../hydra_config"

# Show the entries of the `xp` sub-directory. os.listdir returns names in
# arbitrary order, so sort them to get a stable, easy-to-scan listing.
pprint(sorted(os.listdir(os.path.join(config_path, "xp"))))

In [None]:
# Config selections; each one is passed as a `<name>=<value>` override when
# the Hydra config is composed.
xp = "sla_glorys"
training = "glorys"
entrypoint = "train"
file_paths = "hal"

## Load experiment config

In [None]:
# Compose the "main" config with the selections made above and dump the
# result as YAML for inspection.
with initialize(config_path=config_path):
    cfg = compose(
        config_name="main",
        overrides=[
            f"{group}={choice}"
            for group, choice in (
                ("xp", xp),
                ("entrypoint", entrypoint),
                ("training", training),
                ("file_paths", file_paths),
            )
        ],
    )
    print(OmegaConf.to_yaml(cfg))

## Reproduce hydra_main.py

In [None]:
# Seed first, before anything else is built; the seed is None when the
# config has no `seed` key.
seed_everything(seed=cfg.get('seed'))

# Build the datamodule described by the config and prepare it.
dm = instantiate(cfg.datamodule)
dm.setup()

# Resolve the Lightning module class named in the config.
lit_mod_cls = get_class(cfg.lit_mod_cls)

runner = FourDVarNetHydraRunner(cfg.params, dm, lit_mod_cls)

In [None]:
# Fetch the three dataloaders from the datamodule.
train_dl, val_dl, test_dl = (
    dm.train_dataloader(),
    dm.val_dataloader(),
    dm.test_dataloader(),
)

# Length of each loader, in train / val / test order.
print(*map(len, (train_dl, val_dl, test_dl)))

In [None]:
def _print_field_stats(tag, field):
    """Print mean/std/min/max of the non-zero entries of `field`, then its NaN count.

    Args:
        tag: Short label inserted into every printed line (e.g. 'obs').
        field: NumPy array to summarise.
    """
    # Zeros are excluded from the statistics; the selection is computed once.
    # NaN entries compare unequal to 0, so they remain in the selection.
    nonzero = field[field != 0]
    if nonzero.size:
        print(f'mean {tag} : ', nonzero.mean())
        print(f'std {tag}  : ', nonzero.std())
        print(f'min {tag}  : ', nonzero.min())
        print(f'max {tag}  : ', nonzero.max())
    else:
        # min()/max() raise ValueError on an empty selection, so report it
        # instead of crashing the cell.
        print(f'no non-zero values in {tag}')
    print(f'NaNs {tag} : ', np.isnan(field).sum())


# Only the first training batch is inspected.
batch = next(iter(train_dl))

# Convert every tensor of the batch to a NumPy array on the host.
targets_OI, inputs_Mask, inputs_obs, targets_GT = (
    tensor.cpu().numpy() for tensor in batch
)

_print_field_stats('obs', inputs_obs)
print('---')
_print_field_stats('oi', targets_OI)

In [None]:
# Size of axis 1 of the observations; one subplot column is drawn per index.
n_times = int(inputs_obs.shape[1])

# One subplot row per field: (array, title prefix, extra imshow arguments).
panels = (
    (inputs_obs, "Input obs", {}),
    (inputs_Mask, "Input mask", {}),
    (targets_OI, "Target OI", {"vmin": -2, "vmax": 2}),
    (targets_GT, "Target GT", {"vmin": -2, "vmax": 2}),
)

# squeeze=False keeps `ax` two-dimensional even when n_times == 1; with the
# default, a single column yields a 1-D array and ax[row, i] raises IndexError.
fig, ax = plt.subplots(len(panels), n_times, figsize=(16, 16), squeeze=False)

# Only the first sample of the batch (index 0) is displayed.
for row, (field, label, imshow_kwargs) in enumerate(panels):
    for i in range(n_times):
        ax[row, i].imshow(field[0, i], **imshow_kwargs)
        ax[row, i].set_title(f"{label} time {i}")

# Called without arguments, so the subplot spacing parameters are left as is.
plt.subplots_adjust()