In [1]:
from pprint import pprint

import hydra
import torch
from omegaconf import OmegaConf

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_checkpoint_dict(path, map_location=None):
    """Load a serialized PyTorch checkpoint file and return its contents.

    Parameters
    ----------
    path : str or os.PathLike
        Location of the checkpoint file (e.g. ``TrainedModels/<XP>.ckpt``).
    map_location : optional
        Forwarded to ``torch.load``. Pass ``'cpu'`` to load a checkpoint that
        was saved from GPU tensors on a machine without CUDA. ``None`` (the
        default) keeps ``torch.load``'s default device handling, which matches
        the previous behaviour of this helper.

    Returns
    -------
    object
        Whatever was stored in the file. For the checkpoints used in this
        notebook it is indexed with ``'hyper_parameters'``, so it is presumably
        a Lightning-style dict.

    Notes
    -----
    ``torch.load`` unpickles the file, which can execute arbitrary code.
    Only load checkpoints from trusted sources.
    """
    return torch.load(path, map_location=map_location)

In [3]:
def compare(a, b):
    """Print a line-by-line comparison of the entries of ``a`` against ``b``.

    One line is printed per key of ``a``, in ``a``'s iteration order::

        [PD] key : a_value | b_value

    ``P`` is ``'X'`` when ``key`` is absent from ``b``. ``D`` is ``'X'`` when
    the key is absent or when the two values compare unequal. A missing
    ``b`` value is displayed as ``None``. Keys that appear only in ``b`` are
    not reported.

    Parameters
    ----------
    a : Mapping
        Reference mapping (here: checkpoint hyper-parameters). Its keys drive
        the output.
    b : Mapping
        Mapping to compare against. It must support ``key in b`` and
        ``b[key]``; a plain dict works, and so presumably does the OmegaConf
        config returned by ``hydra.compose``.
    """
    # default=0 so an empty ``a`` prints nothing instead of raising ValueError.
    width = max((len(str(key)) for key in a), default=0) + 2
    for key in a:
        # Test membership explicitly: a key that exists in ``b`` with the value
        # None must not be reported as missing.
        present = key in b
        b_val = b[key] if present else None
        flag_present = ' ' if present else 'X'
        flag_different = 'X' if (not present or a[key] != b_val) else ' '
        print(f'[{flag_present}{flag_different}] {str(key):{width}}: {a[key]} | {b_val}')

In [4]:
# Experiment names. The selected one is used both as the checkpoint file stem
# under TrainedModels/ and as the Hydra `+xp=` override when loading.
XP_list = [
    '4DVarNet-SSH-SST',
    '4DVarNet-SSH-only',
    'U-Net-SSH-SST',
    'U-Net-SSH-only',
]

# Which experiment to inspect; change this index to compare another model.
XP_INDEX = 3
XP = XP_list[XP_INDEX]  # -> 'U-Net-SSH-only'

In [5]:
# Trained model checkpoint for the selected experiment.
checkpoint_path = f'TrainedModels/{XP}.ckpt'
a = get_checkpoint_dict(checkpoint_path)

# Hydra config composed for the same experiment, to compare against the
# hyper-parameters stored in the checkpoint.
overrides = [f'+xp={XP}', '+entrypoint=train', '+file_paths=_LOCAL_imt']
with hydra.initialize(version_base=None, config_path='hydra_config'):
    b = hydra.compose(overrides=overrides)



In [7]:
# Checkpoint hyper-parameters vs. the `params` section of the composed config.
ckpt_hparams = a['hyper_parameters']
cfg_params = b['params']
compare(ckpt_hparams, cfg_params)

[ X] files_cfg                    : {'oi_path': '/gpfsstore/rech/yrf/commun/NATL60/NATL/oi/ssh_NATL60_swot_4nadir.nc', 'oi_var': 'ssh_mod', 'obs_mask_path': '/gpfsstore/rech/yrf/commun/NATL60/NATL/data_new/dataset_nadir_0d_swot.nc', 'obs_mask_var': 'ssh_mod', 'gt_path': '/gpfsdsstore/projects/rech/yrf/commun/NATL60/NATL/ref_new/NATL60-CJM165_NATL_ssh_y2013.1y.nc', 'gt_var': 'ssh', 'u_path': '/gpfsdsstore/projects/rech/yrf/commun/NATL60/NATL/ref_new/NATL60-CJM165_NATL_u_y2013.1y.nc', 'u_var': 'u', 'v_path': '/gpfsdsstore/projects/rech/yrf/commun/NATL60/NATL/ref_new/NATL60-CJM165_NATL_v_y2013.1y.nc', 'v_var': 'v'} | {'oi_path': '${file_paths.oi_swot_4nadir}', 'oi_var': 'ssh_mod', 'obs_mask_path': '${file_paths.pseudo_obs}', 'obs_mask_var': 'ssh_mod', 'gt_path': '${file_paths.natl_ssh_daily}', 'gt_var': 'ssh', 'u_path': '${file_paths.natl_u_daily}', 'u_var': 'u', 'v_path': '${file_paths.natl_v_daily}', 'v_var': 'v'}
[  ] iter_update                  : [0, 200, 400, 600, 1000, 1500, 8000] 

In [14]:
# Shallow conversion of the datamodule section: nested entries such as
# slice_win are still shown with raw `${...}` interpolations. These rely on
# custom resolvers (`div`, `adjust_testslices`) that this notebook does not
# appear to register — verify before trying resolve=True.
dict(b['datamodule'])

{'_target_': 'dataloading_uv.FourDVarNetDataModule',
 'slice_win': {'lat': '${div:240,${datamodule.resize_factor}}', 'lon': '${div:240,${datamodule.resize_factor}}', 'time': '${params.dT}'},
 'strides': {'lat': 20, 'lon': 20, 'time': 1},
 'train_slices': [{'_target_': 'builtins.slice', '_args_': ['2013-02-04', '2013-09-30']}],
 'test_slices': [{'_target_': 'builtins.slice', '_args_': "${adjust_testslices:['2012-10-22', '2012-12-02'],${params.dT}}"}],
 'val_slices': [{'_target_': 'builtins.slice', '_args_': ['2013-01-01', '2013-02-04']}],
 'oi_path': '/DATASET/NATL/ssh_NATL60_swot_4nadir.nc',
 'oi_var': 'ssh_mod',
 'obs_mask_path': '/DATASET/NATL/dataset_nadir_0d_swot.nc',
 'obs_mask_var': 'ssh_mod',
 'gt_path': '/DATASET/NATL/NATL60-CJM165_NATL_ssh_y2013.1y.nc',
 'gt_var': 'ssh',
 'sst_path': '/DATASET/NATL/NATL60-CJM165_NATL_sst_y2013.1y.nc',
 'sst_var': 'sst',
 'u_path': '/DATASET/NATL/NATL60-CJM165_NATL_u_y2013.1y.nc',
 'u_var': 'u',
 'v_path': '/DATASET/NATL/NATL60-CJM165_NATL_v_y2