In [1]:
import torch
import xarray as xr
import numpy as np 
import hydra
import yaml
import inspect
from IPython.display import Markdown, display
from omegaconf import OmegaConf
import pytorch_lightning as pl
from collections import namedtuple
import functools as ft
from src.data import AugmentedDataset, XrDataset

In [2]:
gpu = 0  #*gpu:Union[None,int]

# Pick the requested CUDA device when one is available, otherwise fall back to the CPU.
dev = f"cuda:{gpu}" if (gpu is not None and torch.cuda.is_available()) else "cpu"
device = torch.device(dev)

print(device)

#torch.set_default_device(device)

cuda:0


# eNATL vs NATL

In [3]:
# eNATL60 (BLB002) sound-speed field, regridded over the 0-1000 m depth range.
sound_speed_path_enatl = "/DATASET/eNATL/eNATL60_BLB002_sound_speed_regrid_0_1000m.nc"
input_da_enatl = xr.open_dataset(sound_speed_path_enatl)  # an xr.Dataset, despite the `_da` suffix
coords_enatl = input_da_enatl.coords


In [4]:
# NATL60 (GULF-CJM165) sound-speed field, regridded over the 0-1000 m depth range.
sound_speed_path_natl = "/DATASET/NATL/NATL60GULF-CJM165_sound_speed_regrid_0_1000m.nc"
input_da_natl = xr.open_dataset(sound_speed_path_natl)  # an xr.Dataset, despite the `_da` suffix
coords_natl = input_da_natl.coords

In [5]:
# Latitude extent of each simulation grid, rounded to 2 decimals.
print("\n".join(
    f"{label} latitude: [{np.round(grid['lat'].values.min(),2)},{np.round(grid['lat'].values.max(),2)}]"
    for label, grid in (("eNATL", coords_enatl), ("NATL", coords_natl))
))

eNATL latitude: [32.0,43.95]
NATL latitude: [32.0,43.95]


In [6]:
# Longitude extent of each simulation grid, rounded to 2 decimals.
print("\n".join(
    f"{label} longitude: [{np.round(grid['lon'].values.min(),2)},{np.round(grid['lon'].values.max(),2)}]"
    for label, grid in (("eNATL", coords_enatl), ("NATL", coords_natl))
))

eNATL longitude: [-65.95,-54.0]
NATL longitude: [-65.95,-54.0]


In [7]:
# NaN-aware range of the sound speed (celerity) in each simulation.
print("\n".join(
    f" {label} min: {np.nanmin(ds.celerity.values)}, {label} max: {np.nanmax(ds.celerity.values)}"
    for label, ds in (("eNATL", input_da_enatl), ("NATL", input_da_natl))
))


 eNATL min: 1459.0439165829073, eNATL max: 1545.8698054910844
 NATL min: 1421.348940011407, NATL max: 1546.862179034367


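We restrict the study to the domain `lat=slice(31, 43)`, `lon=slice(-64, -55)`.

Choice: cut about 1° on each side of the simulated domain. Note: with the extents printed above (lat [32.0, 43.95], lon [-65.95, -54.0]), this holds for the northern (43.95 → 43) and eastern (-54.0 → -55) edges. About 2° is cut on the western edge (-65.95 → -64). Nothing is cut on the southern edge, since the lower bound 31 lies below the 32°N data minimum.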

In [8]:
# NaN-aware mean / standard deviation of the sound speed over each full file
# (these are the values reused as `norm_stats` in the config).
print("\n".join(
    f" {label} mean: {np.nanmean(ds.celerity.values)}, {label} std: {np.nanstd(ds.celerity.values)}"
    for label, ds in (("eNATL", input_da_enatl), ("NATL", input_da_natl))
))

 eNATL mean: 1513.9706701708644, eNATL std: 15.007288853760143
 NATL mean: 1511.844605664954, NATL std: 15.242737332216553


# Data module

## Feature development

In [9]:
def pprint_cfg(cfg):
    """Display an OmegaConf config as a YAML code block in the notebook output."""
    as_yaml = yaml.dump(OmegaConf.to_container(cfg), default_flow_style=None, indent=2)
    display(Markdown(f"```yaml\n\n{as_yaml}\n\n```"))

def get_cfg(cfg_name):
    """Compose ``cfg_name`` from ``./config/xp``, display it, and return it with
    its ``_target_`` nodes resolved by ``hydra.utils.call``."""
    with hydra.initialize('./config/xp', version_base='1.3'):
        composed = hydra.compose(config_name=cfg_name)
    pprint_cfg(composed)
    return hydra.utils.call(composed)


In [10]:
cfg = get_cfg(cfg_name="enatl_natl")  # eNATL60 (train) -> NATL60 (test) experiment

```yaml

dm:
  accoustic_var: [cutoff_freq, ecs]
  dl_kw: {batch_size: 4, num_workers: 1}
  xrds_kw:
    patch_dims: {lat: 240, lon: 240, time: 15}
    strides: {lat: 240, lon: 240, time: 1}
entrypoints: {train_dm: '${train_dm}'}
paths:
  accoustic: {test: /DATASET/NATL/NATL60GULF-CJM165_cutoff_freq_regrid_0_1000m.nc,
    train: /DATASET/eNATL/eNATL60_BLB002_cutoff_freq_regrid_0_1000m.nc}
  celerity: {test: /DATASET/NATL/NATL60GULF-CJM165_sound_speed_regrid_0_1000m.nc,
    train: /DATASET/eNATL/eNATL60_BLB002_sound_speed_regrid_0_1000m.nc}
spatial_domain:
  lat:
    _args_: [31, 43]
    _target_: builtins.slice
  lon:
    _args_: [-64, -55]
    _target_: builtins.slice
test_dm:
  dl_kw: ${dm.dl_kw}
  input_da: {accoustic_path: '${paths.accoustic.test}', accoustic_var: '${dm.accoustic_var}',
    celerity_path: '${paths.celerity.test}', spatial_domain: '${spatial_domain}'}
  norm_stats: {mean: 1511.844605664954, std: 15.242737332216553}
  time_domain:
    test:
      time:
        _args_: ['2012-10-01', '2013-09-30']
        _target_: builtins.slice
    train: null
    val: null
  xrds_kw: {patch_dims: '${dm.xrds_kw.patch_dims}', strides: '${dm.xrds_kw.strides}'}
train_dm:
  dl_kw: ${dm.dl_kw}
  input_da: {accoustic_path: '${paths.accoustic.train}', accoustic_var: '${dm.accoustic_var}',
    celerity_path: '${paths.celerity.train}', spatial_domain: '${spatial_domain}'}
  norm_stats: {mean: 1513.9706701708644, std: 15.007288853760143}
  time_domain:
    test: null
    train:
      time:
        _args_: ['2009-08-12', '2010-06-30']
        _target_: builtins.slice
    val:
      time:
        _args_: ['2009-07-01', '2009-08-11']
        _target_: builtins.slice
  xrds_kw: {patch_dims: '${dm.xrds_kw.patch_dims}', strides: '${dm.xrds_kw.strides}'}


```

In [11]:
class TransfertDataModule(BaseDataModule):
    """Data module for the transfer experiment (one simulation for fitting,
    another for testing, as configured in ``train_dm`` / ``test_dm``).

    When ``norm_stats`` is not given, the normalisation statistics are computed
    on the ``time_domain[mean_std_domain]`` slice (``mean_std_domain`` defaults
    to ``'train'``).

    NOTE(review): this cell runs before ``BaseDataModule`` is defined (In [14]),
    so it raises NameError on a fresh kernel; the class is redefined in a later cell.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Which `time_domain` entry the mean/std are computed on; defaults to 'train'.
        self.mean_std_domain = kwargs.get('mean_std_domain', 'train')

    def train_mean_std(self, variable='celerity'):
        """Return ``(mean, std)`` of ``variable`` over the ``mean_std_domain`` time slice.

        The spatial sub-domain is already applied by ``load_input``.
        """
        # `input_da` is stacked along a `variable` dimension by `load_input`'s
        # `.to_array()`, so one channel is picked with `.sel(variable=...)`.
        train_data = (
            self.input_da
            .sel(self.time_domain[self.mean_std_domain])
            .sel(variable=variable)
        )
        return train_data.mean().item(), train_data.std().item()

    def _xr_dataset(self, split, post_fn):
        """Build an ``XrDataset`` over the ``time_domain[split]`` slice of ``input_da``."""
        return XrDataset(
            self.input_da.sel(self.time_domain[split]),
            **self.xrds_kw,
            postpro_fn=post_fn,
        )

    def setup(self, stage='test'):
        """Build the datasets for ``stage``.

        ``'fit'`` builds ``train_ds`` (wrapped in ``AugmentedDataset`` when
        ``aug_kw`` is set) and ``val_ds``; any other stage builds ``test_ds``.
        """
        post_fn = self.post_fn()

        if stage == 'fit':
            self.train_ds = self._xr_dataset('train', post_fn)
            if self.aug_kw:
                self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)
            self.val_ds = self._xr_dataset('val', post_fn)
        else:
            self.test_ds = self._xr_dataset('test', post_fn)


def cosanneal_lr_adamw(lit_mod, lr, T_max, weight_decay=0.):
    """AdamW + cosine-annealing schedule for a 4DVarNet-style Lightning module.

    The gradient module and observation cost train at ``lr``; the prior cost
    trains at ``lr / 2``. Returns the dict form accepted by Lightning's
    ``configure_optimizers``.
    """
    solver = lit_mod.solver
    half_lr = lr / 2
    param_groups = [
        dict(params=solver.grad_mod.parameters(), lr=lr),
        dict(params=solver.obs_cost.parameters(), lr=lr),
        dict(params=solver.prior_cost.parameters(), lr=half_lr),
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
    return dict(optimizer=optimizer, lr_scheduler=scheduler)

# def load_and_interpolate(tgt_path, inp_path, tgt_var, inp_var, domain):
#     """
#     Load ground truth `tgt` and apply the satellites observations `inp`.
#     """
#     tgt = xr.open_dataset(tgt_path)[tgt_var].sel(domain)
#     inp = xr.open_dataset(inp_path)[inp_var].sel(domain)

#     return (
#         xr.Dataset(
#             dict(input=inp*tgt, tgt=(tgt.dims, tgt.values)),
#             inp.coords,
#         )
#         .transpose('time', 'lat', 'lon')
#         .to_array()
#     )

def matching_coords_test(cel_da, acc_da, dims=('time', 'lat', 'lon')):
    """Raise if ``cel_da`` and ``acc_da`` differ on any of the given coordinates.

    Args:
        cel_da: celerity (sound-speed) Dataset/DataArray.
        acc_da: acoustic Dataset/DataArray.
        dims: coordinate names to compare (default ``('time', 'lat', 'lon')``).

    Raises:
        ValueError: naming every coordinate whose values are not identical.
    """
    unmatched = [
        dim for dim in dims
        if not np.array_equal(cel_da[dim].values, acc_da[dim].values)
    ]
    if unmatched:
        # Build the tuple from plain `str` names: unpacking a NumPy string array
        # yields `np.str_` items, which NumPy >= 2 reprs as "np.str_('time')".
        raise ValueError(
            "Celerity and accoustic dataarrays don't have matching coordinates "
            f"on {tuple(unmatched)}"
        )
  
  
def load_input(celerity_path, accoustic_path, acc_var, domains):
    """Load the celerity and acoustic fields over a spatial domain and stack them.

    Args:
        celerity_path: NetCDF file holding the sound-speed field.
        accoustic_path: NetCDF file holding the acoustic variables.
        acc_var: one variable name, or a sequence of names (e.g. an OmegaConf
            ``ListConfig``), of the acoustic variables to keep.
        domains: mapping of dimension name -> selector passed to ``.sel``.

    Returns:
        ``xr.DataArray`` with a leading ``variable`` dimension followed by
        ``time, lon, lat, z``.

    Raises:
        ValueError: if ``time``/``lat``/``lon`` differ between the two files
            after the spatial selection.
    """
    # xarray needs a real list to select several variables: a ListConfig raises
    # KeyError, and `list("ecs")` would split a single name into characters.
    acc_var = [acc_var] if isinstance(acc_var, str) else list(acc_var)
    cel_ds = xr.open_dataset(celerity_path).sel(domains)
    acc_ds = xr.open_dataset(accoustic_path).sel(domains)[acc_var]
    matching_coords_test(cel_ds, acc_ds)

    # Order dims as t, x, y, z. `to_array()` (required by XrDataset) stacks every
    # data variable along a new `variable` dim; it is slow and broadcasts
    # variables that do not share the same dimensions.
    return (
        xr.merge([cel_ds, acc_ds], join='outer')
        .transpose('time', 'lon', 'lat', 'z')
        .to_array()
    )


def run(trainer, train_dm, test_dm, lit_mod, ckpt=None):
    """Fit on ``train_dm``, then test the best checkpoint on ``test_dm``.

    The two data modules cover distinct domains. When the trainer has a
    logger, its log directory is printed first.
    """
    logger = trainer.logger
    if logger is not None:
        print('\nLogdir:', logger.log_dir, end='\n\n')

    trainer.fit(lit_mod, datamodule=train_dm, ckpt_path=ckpt)
    trainer.test(lit_mod, datamodule=test_dm, ckpt_path='best')

NameError: name 'BaseDataModule' is not defined

In [None]:

# Input file locations (train = eNATL60, test = NATL60) from the composed config.
paths = cfg['paths']
paths

{'celerity': {'train': '/DATASET/eNATL/eNATL60_BLB002_sound_speed_regrid_0_1000m.nc', 'test': '/DATASET/NATL/NATL60-CJM165-ssh-2012-2013-1_20.nc'}, 'accoustic': {'train': '/DATASET/eNATL/eNATL60_BLB002_cutoff_freq_regrid_0_1000m.nc', 'test': '/DATASET/NATL/NATL60GULF-CJM165_cutoff_freq_regrid_0_1000m.nc'}}

In [None]:

# Spatial window (lat/lon slices) shared by the train and test data modules.
space_domain = cfg['spatial_domain']
space_domain

{'lon': slice(-64, -55, None), 'lat': slice(31, 43, None)}

In [None]:
# Training inputs: file paths and the acoustic variable names to load.
celerity_path, accoustic_path = (
    cfg.train_dm.input_da.celerity_path,
    cfg.train_dm.input_da.accoustic_path,
)
acc_var = cfg.dm.accoustic_var

In [None]:

# Train / val / test time slices of the training data module.
time_domain = cfg['train_dm']['time_domain']
time_domain


{'train': {'time': slice('2009-08-12', '2010-06-30', None)}, 'val': {'time': slice('2009-07-01', '2009-08-11', None)}, 'test': None}

In [None]:
# `acc_var` is an OmegaConf ListConfig, not a list: it must go through `list()`
# before being used to select variables from an xr.Dataset.
type(acc_var)

omegaconf.listconfig.ListConfig

In [None]:
def matching_coords_test(cel_da, acc_da, dims=('time', 'lat', 'lon')):
    """Raise if ``cel_da`` and ``acc_da`` differ on any of the given coordinates.

    Args:
        cel_da: celerity (sound-speed) Dataset/DataArray.
        acc_da: acoustic Dataset/DataArray.
        dims: coordinate names to compare (default ``('time', 'lat', 'lon')``).

    Raises:
        ValueError: naming every coordinate whose values are not identical.
    """
    unmatched = [
        dim for dim in dims
        if not np.array_equal(cel_da[dim].values, acc_da[dim].values)
    ]
    if unmatched:
        # Build the tuple from plain `str` names: unpacking a NumPy string array
        # yields `np.str_` items, which NumPy >= 2 reprs as "np.str_('time')".
        raise ValueError(
            "Celerity and accoustic dataarrays don't have matching coordinates "
            f"on {tuple(unmatched)}"
        )
  

In [None]:

def load_input(celerity_path, accoustic_path, acc_var, domains):
    """Load the celerity and acoustic fields over a spatial domain and stack them.

    Args:
        celerity_path: NetCDF file holding the sound-speed field.
        accoustic_path: NetCDF file holding the acoustic variables.
        acc_var: one variable name, or a sequence of names (e.g. an OmegaConf
            ``ListConfig``), of the acoustic variables to keep.
        domains: mapping of dimension name -> selector passed to ``.sel``
            (spatial intersection; time is left untouched).

    Returns:
        ``xr.DataArray`` with a leading ``variable`` dimension followed by
        ``time, lon, lat, z``.

    Raises:
        ValueError: if ``time``/``lat``/``lon`` differ between the two files
            after the spatial selection.
    """
    # xarray needs a real list to select several variables: a ListConfig raises
    # KeyError, and `list("ecs")` would split a single name into characters.
    acc_var = [acc_var] if isinstance(acc_var, str) else list(acc_var)
    cel_ds = xr.open_dataset(celerity_path).sel(domains)
    acc_ds = xr.open_dataset(accoustic_path).sel(domains)[acc_var]
    matching_coords_test(cel_ds, acc_ds)

    # Order dims as t, x, y, z. `to_array()` (required by XrDataset) stacks every
    # data variable along a new `variable` dim; it is slow and broadcasts
    # variables that do not share the same dimensions.
    return (
        xr.merge([cel_ds, acc_ds], join='outer')
        .transpose('time', 'lon', 'lat', 'z')
        .to_array()
    )


In [None]:
# Stack celerity + acoustic variables over the spatial window into one DataArray.
# TODO: let load_input handle the ListConfig -> list conversion itself.
data_ds = load_input(
    celerity_path,
    accoustic_path,
    list(acc_var),
    space_domain,
)
data_ds

In [None]:
# Alternative construction, for comparison with `load_input`: an explicit Dataset
# with the acoustic variables stacked as `input` and the celerity as `tgt`.
cel_ds = xr.open_dataset(celerity_path).sel(space_domain)
acc_ds = xr.open_dataset(accoustic_path).sel(space_domain)[list(acc_var)]
# Spatial intersection over the domain; the time axis is left untouched.
matching_coords_test(cel_ds, acc_ds)

# `Dataset.values` is the mapping `.values()` method, not an array, so each side
# is turned into a DataArray (which carries its own coords) before building the
# Dataset. The acoustic variables are stacked along a new `variable` dim, hence
# the `...` in `transpose`.
data_ds_bis = xr.Dataset(
    dict(input=acc_ds.to_array(), tgt=cel_ds['celerity']),
).transpose('time', 'lon', 'lat', 'z', ...)

In [None]:
# `load_input` already ends with `.to_array()`, and a DataArray has no
# `to_array` method: only convert when given a Dataset.
data_arr = data_ds if isinstance(data_ds, xr.DataArray) else data_ds.to_array()

In [None]:
data_arr.shape  # (variable, time, lon, lat, z)

(3, 365, 180, 220, 107)

In [None]:
def train_mean_std(self, variable='celerity'):
    """Return ``(mean, std)`` of ``variable`` over the ``mean_std_domain`` time slice.

    Draft of ``TransfertDataModule.train_mean_std``. ``self`` must provide
    ``input_da`` (stacked along a ``variable`` dimension), ``time_domain`` and
    ``mean_std_domain``. The spatial sub-domain is already applied by ``load_input``.
    """
    # The data module stores its data as `input_da` and its time slices as
    # `time_domain`; neither `self.input_data` nor a global `domains` exists.
    train_data = self.input_da.sel(self.time_domain[self.mean_std_domain])
    return (
        train_data
        .sel(variable=variable)
        .pipe(lambda da: (da.mean().item(), da.std().item()))
    )


In [None]:
    
def setup(self, stage='test'):
    """Draft of ``TransfertDataModule.setup``: build the datasets for ``stage``.

    ``'fit'`` builds ``train_ds`` (wrapped in ``AugmentedDataset`` when
    ``aug_kw`` is set) and ``val_ds``; any other stage builds ``test_ds``.
    Time slices come from ``self.time_domain``.
    """
    post_fn = self.post_fn()

    def make_ds(split):
        # Use the module's own time slices, not the notebook-global
        # `time_domain`, so this works for any configured data module.
        return XrDataset(
            self.input_da.sel(self.time_domain[split]),
            **self.xrds_kw,
            postpro_fn=post_fn,
        )

    if stage == 'fit':
        self.train_ds = make_ds('train')
        if self.aug_kw:
            self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)
        self.val_ds = make_ds('val')
    else:
        self.test_ds = make_ds('test')

## Run class

In [12]:
# Named (input, tgt) pair produced by the datasets' post-processing.
# Intended mapping: input = celerity data, tgt = acoustic data.
# NOTE(review): confirm this against how the data module fills the two fields.
TrainingItem = namedtuple('TrainingItem', 'input tgt')

In [14]:
class BaseDataModule(pl.LightningDataModule):
    """Minimal Lightning data module around a stacked ``xr.DataArray``.

    Subclasses implement ``train_mean_std`` (used when ``norm_stats`` is not
    given) and ``setup``.

    Args:
        input_da: data the subclasses' datasets are built from.
        time_domain: mapping ``{'train' | 'val' | 'test': selector}`` used by subclasses.
        xrds_kw: keyword arguments forwarded to the dataset constructor.
        dl_kw: dataloader keyword arguments (stored; not used in this class).
        aug_kw: augmentation keyword arguments; ``None`` means no augmentation.
        norm_stats: ``None`` (computed via ``train_mean_std``), a mapping with
            ``mean`` and ``std`` keys, or a ``(mean, std)`` pair.
        **kwargs: accepted and ignored here so subclasses can read extra options.
    """

    def __init__(self, input_da, time_domain, xrds_kw, dl_kw, aug_kw=None, norm_stats=None, **kwargs):
        super().__init__()
        self.input_da = input_da
        self.time_domain = time_domain
        self.xrds_kw = xrds_kw
        self.dl_kw = dl_kw
        self.aug_kw = aug_kw if aug_kw is not None else {}
        self._norm_stats = norm_stats

        self.train_ds = None
        self.val_ds = None
        self.test_ds = None
        self._post_fn = None

    def get_norm_stats(self):
        """Return the normalisation statistics as a ``(mean, std)`` tuple.

        The tuple is cached in ``_norm_stats`` and mirrored in ``norm_stats``.
        """
        stats = self._norm_stats
        if stats is None:
            stats = self.train_mean_std()
            print("Norm stats", stats)
        elif hasattr(stats, 'keys'):
            # Config mapping {mean: ..., std: ...}: unpacking it directly would
            # yield its keys ('mean', 'std') instead of the values.
            stats = (stats['mean'], stats['std'])
        self._norm_stats = self.norm_stats = tuple(stats)
        return self._norm_stats

    def train_mean_std(self):
        """Compute ``(mean, std)`` from the training data; must be overridden."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement train_mean_std() or be given norm_stats"
        )

    def post_fn(self):
        """Build the per-item post-processing applied by the datasets.

        The returned callable turns a raw item into a ``TrainingItem`` and
        normalises both of its fields with the same ``(mean, std)``.
        """
        m, s = self.get_norm_stats()
        normalize = lambda item: (item - m) / s
        # NOTE(review): `TrainingItem._make` needs exactly two entries (input, tgt);
        # an array stacking e.g. celerity + 2 acoustic variables has three.
        # Confirm how the channels should be split.
        return ft.partial(ft.reduce, lambda i, f: f(i), [
            TrainingItem._make,
            # TrainingItem only exposes `input` and `tgt`, so each field is
            # normalised from its own value.
            lambda item: item._replace(tgt=normalize(item.tgt)),
            lambda item: item._replace(input=normalize(item.input)),
        ])

    def setup(self, stage=None):
        """Build the datasets for ``stage`` (Lightning passes it); implemented by subclasses."""
        pass

In [47]:
class TransfertDataModule(BaseDataModule):
    """Data module for the transfer experiment (one simulation for fitting,
    another for testing, as configured in ``train_dm`` / ``test_dm``).

    When ``norm_stats`` is not given, the normalisation statistics are computed
    on the ``time_domain[mean_std_domain]`` slice (``mean_std_domain`` defaults
    to ``'train'``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Which `time_domain` entry the mean/std are computed on; defaults to 'train'.
        self.mean_std_domain = kwargs.get('mean_std_domain', 'train')

    def train_mean_std(self, variable='celerity'):
        """Return ``(mean, std)`` of ``variable`` over the ``mean_std_domain`` time slice.

        The spatial sub-domain is already applied by ``load_input``.
        """
        # `input_da` is a DataArray stacked along a `variable` dimension
        # (`load_input` ends with `.to_array()`). `da[variable]` would look up a
        # *coordinate* named e.g. 'celerity', so the channel is selected explicitly.
        train_data = (
            self.input_da
            .sel(self.time_domain[self.mean_std_domain])
            .sel(variable=variable)
        )
        return train_data.mean().item(), train_data.std().item()

    def _xr_dataset(self, split, post_fn):
        """Build an ``XrDataset`` over the ``time_domain[split]`` slice of ``input_da``."""
        return XrDataset(
            self.input_da.sel(self.time_domain[split]),
            **self.xrds_kw,
            postpro_fn=post_fn,
        )

    def setup(self, stage='test'):
        """Build the datasets for ``stage``.

        ``'fit'`` builds ``train_ds`` (wrapped in ``AugmentedDataset`` when
        ``aug_kw`` is set) and ``val_ds``; any other stage builds ``test_ds``.
        """
        post_fn = self.post_fn()

        if stage == 'fit':
            self.train_ds = self._xr_dataset('train', post_fn)
            if self.aug_kw:
                self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)
            self.val_ds = self._xr_dataset('val', post_fn)
        else:
            self.test_ds = self._xr_dataset('test', post_fn)




def matching_coords_test(cel_da, acc_da, dims=('time', 'lat', 'lon')):
    """Raise if ``cel_da`` and ``acc_da`` differ on any of the given coordinates.

    Args:
        cel_da: celerity (sound-speed) Dataset/DataArray.
        acc_da: acoustic Dataset/DataArray.
        dims: coordinate names to compare (default ``('time', 'lat', 'lon')``).

    Raises:
        ValueError: naming every coordinate whose values are not identical.
    """
    unmatched = [
        dim for dim in dims
        if not np.array_equal(cel_da[dim].values, acc_da[dim].values)
    ]
    if unmatched:
        # Build the tuple from plain `str` names: unpacking a NumPy string array
        # yields `np.str_` items, which NumPy >= 2 reprs as "np.str_('time')".
        raise ValueError(
            "Celerity and accoustic dataarrays don't have matching coordinates "
            f"on {tuple(unmatched)}"
        )
  
  
def load_input(celerity_path, accoustic_path, acc_var, spatial_domain):
    """Load the celerity and acoustic fields over a spatial domain and stack them.

    Args:
        celerity_path: NetCDF file holding the sound-speed field.
        accoustic_path: NetCDF file holding the acoustic variables.
        acc_var: one variable name, or a sequence of names (e.g. an OmegaConf
            ``ListConfig``), of the acoustic variables to keep.
        spatial_domain: mapping of dimension name -> selector passed to ``.sel``.

    Returns:
        ``xr.DataArray`` with a leading ``variable`` dimension followed by
        ``time, lon, lat, z``.

    Raises:
        ValueError: if ``time``/``lat``/``lon`` differ between the two files
            after the spatial selection.
    """
    # xarray needs a real list to select several variables: a ListConfig raises
    # KeyError, and `list("ecs")` would split a single name into characters.
    acc_var = [acc_var] if isinstance(acc_var, str) else list(acc_var)
    cel_ds = xr.open_dataset(celerity_path).sel(spatial_domain)
    acc_ds = xr.open_dataset(accoustic_path).sel(spatial_domain)[acc_var]
    matching_coords_test(cel_ds, acc_ds)

    # Order dims as t, x, y, z. `to_array()` (required by XrDataset) stacks every
    # data variable along a new `variable` dim; it is slow and broadcasts
    # variables that do not share the same dimensions.
    return (
        xr.merge([cel_ds, acc_ds], join='outer')
        .transpose('time', 'lon', 'lat', 'z')
        .to_array()
    )


def run(trainer, train_dm, test_dm, lit_mod, ckpt=None):
    """Fit on ``train_dm``, then test the best checkpoint on ``test_dm``.

    The two data modules cover distinct domains. When the trainer has a
    logger, its log directory is printed first.
    """
    logger = trainer.logger
    if logger is not None:
        print('\nLogdir:', logger.log_dir, end='\n\n')

    trainer.fit(lit_mod, datamodule=train_dm, ckpt_path=ckpt)
    trainer.test(lit_mod, datamodule=test_dm, ckpt_path='best')

In [44]:
def pprint_cfg(cfg):
    """Display an OmegaConf config as a YAML code block in the notebook output."""
    as_yaml = yaml.dump(OmegaConf.to_container(cfg), default_flow_style=None, indent=2)
    display(Markdown(f"```yaml\n\n{as_yaml}\n\n```"))

def get_cfg(cfg_name):
    """Compose ``cfg_name`` from ``./config/xp``, display it, and return it with
    its ``_target_`` nodes resolved by ``hydra.utils.call``."""
    with hydra.initialize('./config/xp', version_base='1.3'):
        composed = hydra.compose(config_name=cfg_name)
    pprint_cfg(composed)
    return hydra.utils.call(composed)

cfg = get_cfg(cfg_name="enatl_natl")  # eNATL60 (train) -> NATL60 (test) experiment

```yaml

dm:
  accoustic_var: [cutoff_freq, ecs]
  dl_kw: {batch_size: 4, num_workers: 1}
  xrds_kw:
    patch_dims: {lat: 240, lon: 240, time: 15}
    strides: {lat: 240, lon: 240, time: 1}
entrypoints: {train_dm: '${train_dm}'}
paths:
  accoustic: {test: /DATASET/NATL/NATL60GULF-CJM165_cutoff_freq_regrid_0_1000m.nc,
    train: /DATASET/eNATL/eNATL60_BLB002_cutoff_freq_regrid_0_1000m.nc}
  celerity: {test: /DATASET/NATL/NATL60GULF-CJM165_sound_speed_regrid_0_1000m.nc,
    train: /DATASET/eNATL/eNATL60_BLB002_sound_speed_regrid_0_1000m.nc}
spatial_domain:
  lat:
    _args_: [31, 43]
    _target_: builtins.slice
  lon:
    _args_: [-64, -55]
    _target_: builtins.slice
test_dm:
  dl_kw: ${dm.dl_kw}
  input_da: {accoustic_path: '${paths.accoustic.test}', accoustic_var: '${dm.accoustic_var}',
    celerity_path: '${paths.celerity.test}', spatial_domain: '${spatial_domain}'}
  norm_stats: {mean: 1511.844605664954, std: 15.242737332216553}
  time_domain:
    test:
      time:
        _args_: ['2012-10-01', '2013-09-30']
        _target_: builtins.slice
    train: null
    val: null
  xrds_kw: {patch_dims: '${dm.xrds_kw.patch_dims}', strides: '${dm.xrds_kw.strides}'}
train_dm:
  dl_kw: ${dm.dl_kw}
  input_da: {accoustic_path: '${paths.accoustic.train}', accoustic_var: '${dm.accoustic_var}',
    celerity_path: '${paths.celerity.train}', spatial_domain: '${spatial_domain}'}
  norm_stats: {mean: 1513.9706701708644, std: 15.007288853760143}
  time_domain:
    test: null
    train:
      time:
        _args_: ['2009-08-12', '2010-06-30']
        _target_: builtins.slice
    val:
      time:
        _args_: ['2009-07-01', '2009-08-11']
        _target_: builtins.slice
  xrds_kw: {patch_dims: '${dm.xrds_kw.patch_dims}', strides: '${dm.xrds_kw.strides}'}


```

In [48]:
# An OmegaConf ListConfig rather than a plain list.
type(cfg['dm']['accoustic_var'])

omegaconf.listconfig.ListConfig

In [49]:
# Everything `load_input` needs comes from the `train_dm.input_da` config section.
acc_var = cfg.train_dm.input_da.accoustic_var  # ListConfig; load_input converts it to a list
spatial_domain = cfg.train_dm.input_da.spatial_domain
accoustic_path = cfg.train_dm.input_da.accoustic_path
celerity_path = cfg.train_dm.input_da.celerity_path
input_da = load_input(celerity_path, accoustic_path, acc_var, spatial_domain)
input_da

<class 'omegaconf.listconfig.ListConfig'>


KeyError: "No variable named ['cutoff_freq', 'ecs']. Variables on the dataset include ['cutoff_freq', 'ecs', 'ecs_sound_speed', 'surface_sound_speed', 'surface_temp', 'surface_sal', 'lat', 'lon', 'time']"

In [18]:
# Generic data module built from the `train_dm` section; norm_stats are given
# explicitly, so nothing is computed at construction time.
time_domain, xrds_kw, dl_kw, norm_stats = (
    cfg.train_dm.time_domain,
    cfg.train_dm.xrds_kw,
    cfg.train_dm.dl_kw,
    cfg.train_dm.norm_stats,
)
base_dm = BaseDataModule(input_da, time_domain, xrds_kw, dl_kw, norm_stats=norm_stats)
vars(base_dm)

{'_log_hyperparams': False,
 'prepare_data_per_node': True,
 'allow_zero_length_dataloader_with_multiple_devices': False,
 'trainer': None,
 'input_da': <xarray.DataArray (variable: 3, time: 365, lon: 180, lat: 220, z: 107)>
 array([[[[[ 1.53228537e+03,  1.53229925e+03,  1.53228822e+03, ...,
             1.49892827e+03,  1.49828064e+03,  1.49773538e+03],
           [ 1.53229658e+03,  1.53231050e+03,  1.53229999e+03, ...,
             1.49898530e+03,  1.49833786e+03,  1.49779079e+03],
           [ 1.53232989e+03,  1.53234395e+03,  1.53233493e+03, ...,
             1.49910255e+03,  1.49845946e+03,  1.49788259e+03],
           ...,
           [ 1.51036462e+03,  1.50679204e+03,  1.50408346e+03, ...,
                        nan,             nan,             nan],
           [ 1.50949525e+03,  1.50674023e+03,  1.50543108e+03, ...,
                        nan,             nan,             nan],
           [ 1.50552609e+03,  1.50091945e+03,  1.49877528e+03, ...,
                        nan,   

In [19]:
trans_dm = TransfertDataModule(
    input_da, time_domain, xrds_kw, dl_kw,
    norm_stats=norm_stats,
)
vars(trans_dm)

{'_log_hyperparams': False,
 'prepare_data_per_node': True,
 'allow_zero_length_dataloader_with_multiple_devices': False,
 'trainer': None,
 'input_da': <xarray.DataArray (variable: 3, time: 365, lon: 180, lat: 220, z: 107)>
 array([[[[[ 1.53228537e+03,  1.53229925e+03,  1.53228822e+03, ...,
             1.49892827e+03,  1.49828064e+03,  1.49773538e+03],
           [ 1.53229658e+03,  1.53231050e+03,  1.53229999e+03, ...,
             1.49898530e+03,  1.49833786e+03,  1.49779079e+03],
           [ 1.53232989e+03,  1.53234395e+03,  1.53233493e+03, ...,
             1.49910255e+03,  1.49845946e+03,  1.49788259e+03],
           ...,
           [ 1.51036462e+03,  1.50679204e+03,  1.50408346e+03, ...,
                        nan,             nan,             nan],
           [ 1.50949525e+03,  1.50674023e+03,  1.50543108e+03, ...,
                        nan,             nan,             nan],
           [ 1.50552609e+03,  1.50091945e+03,  1.49877528e+03, ...,
                        nan,   

`train_mean_std` is computationally expensive (about 33 min), so the call below is commented out.

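<span style="color:red">
The returned values differ from those obtained with NumPy `nanmean` and `nanstd` in the `In [8]` cell. This is presumably because `In [8]` uses the whole eNATL file, whereas `train_mean_std` only uses the spatial sub-domain and the training time slice (to be confirmed).
</span>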

In [20]:
#trans_dm.train_mean_std()

In [21]:
# print(trans_dm.input_da.mean())
# print(trans_dm.input_da.std())

In [22]:
# Inspect the post-processing pipeline (a functools.partial over reduce).
trans_dm.post_fn()

functools.partial(<built-in function reduce>, <function BaseDataModule.post_fn.<locals>.<lambda> at 0x7f88c397d620>, [<bound method TrainingItem._make of <class '__main__.TrainingItem'>>, <function BaseDataModule.post_fn.<locals>.<lambda> at 0x7f88c397dd00>, <function BaseDataModule.post_fn.<locals>.<lambda> at 0x7f88c397d6c0>])

In [23]:
# (mean, std) tuple set by get_norm_stats from the configured norm_stats.
trans_dm.norm_stats

(1513.9706701708644, 15.007288853760143)

In [24]:
trans_dm.setup('fit')  # builds train_ds (+ augmentation if aug_kw) and val_ds

In [25]:
trans_dm.setup('test')  # builds test_ds

# Model

In [None]:
def cosanneal_lr_adamw(lit_mod, lr, T_max, weight_decay=0.):
    """AdamW + cosine-annealing schedule for a 4DVarNet-style Lightning module.

    The gradient module and observation cost train at ``lr``; the prior cost
    trains at ``lr / 2``. Returns the dict form accepted by Lightning's
    ``configure_optimizers``.
    """
    solver = lit_mod.solver
    half_lr = lr / 2
    param_groups = [
        dict(params=solver.grad_mod.parameters(), lr=lr),
        dict(params=solver.obs_cost.parameters(), lr=lr),
        dict(params=solver.prior_cost.parameters(), lr=half_lr),
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
    return dict(optimizer=optimizer, lr_scheduler=scheduler)