<a href="https://colab.research.google.com/github/CIA-Oceanix/DLGD2022/blob/main/tutorial-1-hydra/DLGD2022_tuto_hydra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training an ocean forecast model: a configuration management use case


#### Install dependencies and download data

In [1]:
!pip install xarray hvplot hydra-core torchmetrics -q
!wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc
!wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_sst_y2013.1y.nc

--2022-11-15 09:52:09--  https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc
Resolving s3.us-east-1.wasabisys.com (s3.us-east-1.wasabisys.com)... 38.27.106.51, 38.27.106.53
Connecting to s3.us-east-1.wasabisys.com (s3.us-east-1.wasabisys.com)|38.27.106.51|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://s3.eu-central-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc [following]
--2022-11-15 09:52:09--  https://s3.eu-central-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ss

In [2]:
!git clone https://github.com/CIA-Oceanix/DLGD2022.git --branch tuto-hydra --depth=1

Cloning into 'DLGD2022'...
remote: Enumerating objects: 37, done.
remote: Counting objects: 100% (37/37), done.
remote: Compressing objects: 100% (33/33), done.
remote: Total 37 (delta 2), reused 14 (delta 0), pack-reused 0
Unpacking objects: 100% (37/37), done.


In [10]:
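# Note: `!cd` runs in a throwaway subshell, so the working directory does not change;
# the cells below therefore use paths relative to the notebook's root directory
!cd DLGD2022/tutorial-1-hydra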

In [3]:
import sys
sys.path.append('DLGD2022/tutorial-1-hydra')

# Project Walkthrough

In [4]:
#@title Imports
import xarray as xr
import pandas as pd
import scipy.ndimage as ndi
import numpy as np
import holoviews as hv
import torch
import hydra
import functools
hv.extension('matplotlib')

import src


In [5]:
#@title Read data
print('\n############### function reading data : src.load_ssh ###############\n' )
#help(src.load_ssh)
ssh_da = src.load_ssh()
print(ssh_da)


print('\n############### Animation preview ###############\n' )

def anim(da, name, climda=None):
    climda = climda if climda is not None else da
    clim = climda.pipe(lambda da: (da.quantile(0.005).item(), da.quantile(0.995).item()))
    return  (hv.Dataset(da)
            .to(hv.QuadMesh, ['lon', 'lat']).relabel(name)
            .options(cmap='RdBu',clim=clim, colorbar=True))

hv.output(
    anim(ssh_da.isel(time=slice(None,30, 2)), 'Sea surface height (m)'),
    holomap='gif', fps=2, dpi=50, size=150)


############### function reading data : src.load_ssh ###############

<xarray.DataArray 'ssh' (time: 365, lat: 201, lon: 201)>
[14746365 values with dtype=float64]
Coordinates:
  * lon      (lon) float64 -65.0 -64.95 -64.9 -64.85 ... -55.1 -55.05 -55.0
  * lat      (lat) float64 33.0 33.05 33.1 33.15 33.2 ... 42.85 42.9 42.95 43.0
  * time     (time) datetime64[ns] 2012-10-01T12:00:00 ... 2013-09-30T12:00:00

############### Animation preview ###############



* Data taken from the NATL60 ocean simulation (Gulf Stream region, October 2012 to September 2013)
* Sea surface topography is related to surface currents (through geostrophic balance)

**Objective: given the sea surface height (SSH) over the last N days $[D_{-N+1}, D_{-N+2}, \dots, D_{0}]$, estimate the SSH at $D_{+1}$**
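To make the objective concrete, here is a minimal NumPy sketch of how one (input, target) pair per day could be cut from a `(time, lat, lon)` array: the N past days become input channels and the next day is the single-channel target. This is an assumption about what `src.dataloaders` does, not its actual code, and `make_windows` is a hypothetical helper.

```python
import numpy as np


def make_windows(field: np.ndarray, n_past: int):
    """Cut a (time, lat, lon) array into forecasting samples.

    Hypothetical helper: each input stacks n_past consecutive days
    [D_{-N+1}, ..., D_0] as channels, and the target is the next day D_{+1}.
    """
    n_samples = field.shape[0] - n_past
    # Sample i uses days i .. i+n_past-1 as input and day i+n_past as target
    inputs = np.stack([field[i:i + n_past] for i in range(n_samples)])
    targets = np.stack([field[i + n_past:i + n_past + 1] for i in range(n_samples)])
    return inputs, targets


# Toy field: 10 days on a 4x4 grid, where each day's value equals its day index
toy = np.arange(10, dtype=float)[:, None, None] * np.ones((10, 4, 4))
x, y = make_windows(toy, n_past=5)
print(x.shape, y.shape)              # (5, 5, 4, 4) (5, 1, 4, 4)
print(x[0, :, 0, 0], y[0, 0, 0, 0])  # [0. 1. 2. 3. 4.] 5.0
```

The `(samples, N, lat, lon)` / `(samples, 1, lat, lon)` layout matches the batch shapes printed in the next cell.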

In [6]:
print('\n############### function instantiating dataloaders: src.dataloaders ###############\n' )
#help(src.dataloaders)

train_dl, val_dl = src.dataloaders(
    ssh_da,
    training_periods=[slice('2013-01-01', '2013-09-30')],
    validation_periods=[slice('2012-10-01', '2012-12-31')],
    number_of_past_days=5,
    batch_size=16
)

x, y = next(iter(train_dl))
print(f"""
input shape: \t\t\t {x.shape}
target shape: \t\t\t {y.shape} \n
number of training batches: \t {len(train_dl)}
number of validation batches: \t {len(val_dl)}
""")


############### function instantiating dataloaders: src.dataloaders ###############


input shape: 			 torch.Size([16, 5, 201, 201])
target shape: 			 torch.Size([16, 1, 201, 201]) 

number of training batches: 	 17
number of validation batches: 	 6



![](https://i.imgur.com/4xx5TdM.png)
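The batch counts printed above can be checked by hand. The sketch assumes one sample per day whose N = 5 input days and target day all fall inside the period (so a period of T days yields T - N samples), and that the last incomplete batch is kept. Both are assumptions about `src.dataloaders`, and `n_batches` is a hypothetical helper.

```python
import math
from datetime import date


def n_batches(start: date, end: date, n_past: int, batch_size: int) -> int:
    """Hypothetical count of batches for one period (both ends inclusive)."""
    n_days = (end - start).days + 1
    # Each sample needs n_past input days plus 1 target day inside the period
    n_samples = n_days - n_past
    # The last, partially filled batch is kept
    return math.ceil(n_samples / batch_size)


train = n_batches(date(2013, 1, 1), date(2013, 9, 30), n_past=5, batch_size=16)
val = n_batches(date(2012, 10, 1), date(2012, 12, 31), n_past=5, batch_size=16)
print(train, val)  # 17 6
```

Training covers 273 days (268 samples, 16.75 batches of 16) and validation 92 days (87 samples), which is consistent with the 17 and 6 batches reported by the notebook.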

In [7]:
print('\n############### function training model: src.train ###############\n\n' )
#help(src.train)

model = src.train(
    model=torch.nn.Conv2d(5, 1, kernel_size=5, padding=2),
    partial_optimizer=functools.partial(torch.optim.Adam, lr=3e-3),
    dataloaders=(train_dl, val_dl),
    n_epochs=50
)


############### function training model: src.train ###############




Epoch: 49 	 - train err: 0.037 - val err: 0.030 (base err: 0.035) (m): 100%|██████████| 50/50 [01:06<00:00,  1.33s/it]


![](https://i.imgur.com/4raqoYs.png)
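`partial_optimizer` is a `functools.partial`: the optimizer class and learning rate are fixed up front, and the model parameters are supplied later, once the model exists. Below is a minimal sketch of one training step written under that assumption. `train_step` is a hypothetical name, the MSE loss is an assumption, and this is not the actual body of `src.train`. It uses random tensors with the same channel layout as the batches above, on a smaller grid.

```python
import functools

import torch

# Same partial as in the cell above: everything except the parameters is fixed
partial_optimizer = functools.partial(torch.optim.Adam, lr=3e-3)
model = torch.nn.Conv2d(5, 1, kernel_size=5, padding=2)

# The training code completes the call with the model parameters
optimizer = partial_optimizer(model.parameters())


def train_step(model, optimizer, x, y):
    """Hypothetical single optimisation step with an MSE loss."""
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


# Random batch: 16 samples, 5 past days -> 1 target day, 32x32 grid
x = torch.randn(16, 5, 32, 32)
y = torch.randn(16, 1, 32, 32)
loss = train_step(model, optimizer, x, y)
print(type(optimizer).__name__, optimizer.defaults["lr"], loss)
```

Passing a partial rather than an optimizer instance lets the caller pick the optimizer and its hyperparameters without having to build the model first.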

In [8]:
print('############### function computing forecast metrics: src.forecast_diagnostic ###############\n' )
#help(src.forecast_diagnostic)

xrds, metrics_df = src.forecast_diagnostic(model, val_dl, number_of_forecast_days=5)
print(metrics_df.to_markdown())

############### function computing forecast metrics: src.forecast_diagnostic ###############

|                        |   day + 1 |   day + 2 |   day + 3 |   day + 4 |   day + 5 |
|:-----------------------|----------:|----------:|----------:|----------:|----------:|
| persistence error (cm) |      3.48 |      6.25 |      8.61 |     10.61 |     12.33 |
| forecast error (cm)    |      2.95 |      5.3  |      7.61 |      9.61 |     11.4  |
| improvement (%)        |     15.12 |     15.2  |     11.6  |      9.41 |      7.56 |


![](https://i.imgur.com/et6sMye.png)
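The persistence baseline in the table reuses the last observed day as the forecast for every horizon. The sketch below assumes the errors are root-mean-square errors and that the improvement row is `100 * (1 - forecast_error / persistence_error)`. Applied to the rounded table values, this formula reproduces the improvement row to within about 0.1 point. `persistence_forecast`, `rmse` and `improvement` are hypothetical helpers, not functions of `src`.

```python
import numpy as np


def persistence_forecast(past: np.ndarray, horizon: int) -> np.ndarray:
    """Repeat the last observed frame D_0 for days D_{+1} .. D_{+horizon}."""
    return np.repeat(past[-1:], horizon, axis=0)


def rmse(pred: np.ndarray, truth: np.ndarray) -> float:
    """Root-mean-square error over all days and grid points."""
    return float(np.sqrt(np.mean((pred - truth) ** 2)))


def improvement(persistence_err, forecast_err):
    """Relative error reduction (%) with respect to persistence."""
    return 100.0 * (1.0 - np.asarray(forecast_err) / np.asarray(persistence_err))


# Rounded values from the table above (cm), day + 1 .. day + 5
persistence_cm = [3.48, 6.25, 8.61, 10.61, 12.33]
forecast_cm = [2.95, 5.3, 7.61, 9.61, 11.4]
print(np.round(improvement(persistence_cm, forecast_cm), 2))  # approx. 15.23, 15.2, 11.61, 9.43, 7.54

# Toy check of the baseline on a field that increases by 1 each day
days = np.arange(8, dtype=float)[:, None, None] * np.ones((8, 3, 3))
pers = persistence_forecast(days[:5], horizon=3)  # last observed value is 4
print(rmse(pers, days[5:8]))  # errors of 1, 2 and 3 -> sqrt(14 / 3)
```

Because persistence errors grow with the horizon while the learned model is only trained for day + 1, the relative improvement is expected to shrink with lead time, as the table shows.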

# The problem

## Given a project structured like this:

```python
##################### main.py #####################
import functools

import torch

import src

# Training data
train_dl, val_dl = src.dataloaders(
    dataarray=src.load_ssh(), 
    training_periods=[slice('2013-01-01', '2013-09-30')],
    validation_periods=[slice('2012-10-01', '2012-12-31')],
    number_of_past_days=5,
    batch_size=16
)

# Train model
model = src.train(
    model=torch.nn.Conv2d(5, 1, kernel_size=5, padding=2),
    partial_optimizer=functools.partial(torch.optim.Adam, lr=8e-3),
    dataloaders=(train_dl, val_dl),
    n_epochs=50
)

# Generate metrics
xrds, metrics_df = src.forecast_diagnostic(model, val_dl, number_of_forecast_days=5)
print(metrics_df.to_markdown())
```

## Let's say I want to run some additional experiments using this project:


- What if I evaluate on the summer season and train on the rest of the year (**change train/val split**)
- What is the impact of the number of input days (**change input format**)
- What is the impact of the batch size and learning rate (**change training hyperparameters**)
- What other model architecture might work better (**change classes and parameters**)
- What results do I get if I train a model to forecast sea surface temperature instead of sea surface height (**change dataset**)

## But I don't want to:
- turn my code into spaghetti code
- break previous experiments when designing new ones


## In order to be able to:

- Reproduce "any" intermediate result
- Collaborate with other people:
    - Share my results with people using the same project
    - Run some experiments designed by someone else and "quickly understand" what has changed
    - Use some config or code developed by someone else in my experiments
 
----------

## **How would YOU do it ?**




# Possible DIY solutions


- Change the values inside main.py ?
- Duplicate main.py and change some values ?
- Add if statements for each case ?
- Move the "configurables" to a different file and read it in main.py
    * duplicate the config and change the relevant fields 
    * read arguments from the command line

----------- 

## Some personal comments:
- With duplication come maintainability issues whenever an interface changes
- With duplication + modifications, it may be hard to track which changes have been made
- Having a lot of if statements in a file may make the code difficult to read / share
- Modifying scripts in place makes it hard to remember which source code and parameters were used to produce a given result
- With plain configuration files, it may be difficult to switch between different components (e.g. network architecture) without modifying the source code


# The [Hydra](https://hydra.cc/) solution



## Restructure `main.py` into `hydra_main.py` and `config/main.yaml`


### In `config/main.yaml`:

- use `${now:'%y-%m-%d--%H-%M-%S'}` to inject the formatted date in the config
- use `${path.to.node}` to reuse a config value
- use nodes with a `_target_` key to point to Python classes or functions; `_args_` holds positional arguments and `_partial_: true` returns a `functools.partial` instead of calling the target
- use the `defaults` key to combine the config from multiple files


```yaml
##################### config/main.yaml #####################
logdir: hydra_logs/${now:'%y-%m-%d--%H-%M-%S'}
data:
    _target_: src.dataloaders
    dataarray: {_target_: "src.load_ssh"}
    training_periods:
        - {_target_: "builtins.slice", _args_:["2013-01-01", "2013-09-30"]} # slice("2013-01-01", "2013-09-30")
    validation_periods: 
        - {_target_: "builtins.slice", _args_:["2012-10-01", "2012-12-31"]}
    batch_size: 16
    number_of_past_days: 5
training:
    _target_: src.train
    model: ${model}
    partial_optimizer:
        _target_: torch.optim.Adam
        lr: 0.008
        _partial_: true # functools.partial(torch.optim.Adam, lr=0.008)
    n_epochs: 50
diagnostic:
    _target_: src.forecast_diagnostic
    number_of_forecast_days: ${data.number_of_past_days} 

defaults:
    - model: simple_conv  # Load config/model/simple_conv.yaml config in "model" key
```

```yaml
##################### config/model/simple_conv.yaml #####################
_target_: torch.nn.Conv2d
in_channels: ${data.number_of_past_days}
out_channels: 1
kernel_size: 5
padding: 2
```


-------------

### In `hydra_main.py`:

- use the `hydra.main` decorator to inject the config in the main function
- use `hydra.utils.call` to instantiate or call functions from the configuration (nodes with `_target_`)


```python
##################### hydra_main.py #####################
import hydra

@hydra.main(config_path='config', config_name='main', version_base="1.2")
def main(cfg):
    # Training data
    train_dl, val_dl = hydra.utils.call(cfg.data) 
    
    # Training model
    model = hydra.utils.call(cfg.training, dataloaders=(train_dl, val_dl))

    # Compute diagnostics
    xrds, metrics_df = hydra.utils.call(cfg.diagnostic, model=model, dl=val_dl)
    print(metrics_df.to_markdown())


if __name__ =='__main__':
    main()
```

-----------

### Run with:

```sh
python DLGD2022/tutorial-1-hydra/hydra_main.py
```



## Extending the project using hydra

### Use and save set of parameters

**Run a job and modify parameters on the fly**
```sh
python DLGD2022/tutorial-1-hydra/hydra_main.py \
                    data.number_of_past_days=7 \
                    data.batch_size=8 \
                    training.partial_optimizer.lr=0.002
```



**Or save this specific config and call it from the CLI: `config/xp/my_params.yaml`**
```yaml
# @package _global_
data:
    number_of_past_days: 7
    batch_size: 8
training:
    partial_optimizer:
        lr: 2e-3
```

Run with
```sh
python DLGD2022/tutorial-1-hydra/hydra_main.py +xp=my_params
```


### Extend the code and use it in the config



#### Write a new Python module **new_src.py**


```python
################################ new_src.py ################################
import torch
import xarray as xr


def load_sst():
    """Return the sea surface temperature as an xarray DataArray"""
    return xr.open_dataset('NATL60-CJM165_GULFSTREAM_sst_y2013.1y.nc').sst


class MyForecastModel(torch.nn.Module):
    """Stack of Conv2d/BatchNorm2d/ReLU blocks followed by a 1-channel output convolution."""

    def __init__(self, ninput, nhidden, nlayers, kernel_size=3, residual=False):
        super().__init__()
        # An odd kernel with padding kernel_size // 2 keeps the spatial size unchanged
        assert kernel_size % 2 == 1, "Please use odd kernel_size"
        self.residual = residual
        in_channel, out_channel = ninput, nhidden
        self.net = torch.nn.Sequential()

        # nlayers hidden blocks: Conv2d -> BatchNorm2d -> ReLU
        for layer in range(nlayers):
            self.net.add_module(f'conv_{layer}', torch.nn.Conv2d(in_channel, out_channel, kernel_size, padding=kernel_size//2))
            self.net.add_module(f'bn_{layer}', torch.nn.BatchNorm2d(out_channel))
            self.net.add_module(f'act_{layer}', torch.nn.ReLU())
            in_channel = out_channel

        # Final projection to a single output channel (the next-day field)
        self.net.add_module('conv_out', torch.nn.Conv2d(in_channel, 1, kernel_size, padding=kernel_size//2))

    def forward(self, x):
        out = self.net(x)
        if self.residual:
            # Predict an increment on top of the last input day (persistence + correction)
            return x[:, -1:] + out
        return out
```
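The `assert` on an odd `kernel_size` follows from the padding choice. For a stride-1 convolution with kernel size $k$ and padding $p$ on each side, the output height (and likewise the width) is:

$$
H_{\text{out}} = H_{\text{in}} + 2p - k + 1, \qquad p = \left\lfloor \frac{k}{2} \right\rfloor = \frac{k-1}{2} \text{ for odd } k \;\Rightarrow\; H_{\text{out}} = H_{\text{in}}
$$

For the `Conv2d(5, 1, kernel_size=5, padding=2)` used earlier this gives $201 + 4 - 5 + 1 = 201$, whereas an even $k$ with $p = k/2$ would give $H_{\text{in}} + 1$. Preserving the grid size is what allows the output to be compared pixel by pixel with the target and, when `residual=True`, added to the last input day `x[:, -1:]`.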

----

#### Add a new model configuration **config/model/small_cnn.yaml**


```yaml
################################ config/model/small_cnn.yaml ################################
_target_: new_src.MyForecastModel
ninput: ${data.number_of_past_days}
nhidden: 16
nlayers: 1
residual: true
```


-----------


#### Then use it directly in the CLI ...

```bash
python DLGD2022/tutorial-1-hydra/hydra_main.py \
    model=small_cnn \
    data.dataarray._target_=new_src.load_sst
```


--------

#### Or save it in a specific config file **config/xp/sst_small_cnn.yaml** and use it from the CLI


```yaml
# @package _global_ 
data:
    dataarray: {_target_: new_src.load_sst}
training:
    partial_optimizer:
        lr: 1e-3
defaults:
    - override /model: small_cnn
```


```bash
python DLGD2022/tutorial-1-hydra/hydra_main.py +xp=sst_small_cnn
```

### Run

In [13]:
!python DLGD2022/tutorial-1-hydra/hydra_main.py +xp=sst_small_cnn 

########################## Job config ##########################
model:
  _target_: new_src.MyForecastModel
  ninput: 5
  nhidden: 16
  nlayers: 1
  residual: true
logdir: hydra_logs/22-11-15--10-15-29
data:
  _target_: src.dataloaders
  dataarray:
    _target_: new_src.load_sst
  training_periods:
  - _target_: builtins.slice
    _args_:
    - '2013-01-01'
    - '2013-09-30'
  validation_periods:
  - _target_: builtins.slice
    _args_:
    - '2012-10-01'
    - '2012-12-31'
  batch_size: 16
  number_of_past_days: 5
training:
  _target_: src.train
  model:
    _target_: new_src.MyForecastModel
    ninput: 5
    nhidden: 16
    nlayers: 1
    residual: true
  partial_optimizer:
    _target_: torch.optim.Adam
    lr: 0.001
    _partial_: true
  n_epochs: 50
diagnostic:
  _target_: src.forecast_diagnostic
  number_of_forecast_days: 5

[2022-11-15 10:15:30,043][numexpr.utils][INFO] - NumExpr defaulting to 2 threads.
Epoch: 49 	 - train err: 0.511 - val err: 0.433 (base err: 0.517) (m): 100

# Takeaways

Thanks to the hydra library we were able to:
- Store another training configuration in `config/xp/my_params.yaml`
- Extend the project with a new data source and a new model and use them in new experiments

And doing this:
- We only specified the parameters that change with respect to the base experiment -> **Easier to read and to understand**
- We didn't change the code or the config of the base experiment ->  **We know we didn't break the initial training**



**To know more, go to the doc https://hydra.cc/**

# To try it yourself:
- Go to the github repo https://github.com/CIA-Oceanix/DLGD2022
- Open the notebook tutorial-1-hydra/main_notebook.ipynb
- Run some of your own experiments **without** changing existing files
- Try sharing your experiments with your friends
- Try running experiments developed by your friends

# Bonus

## Sweep over multiple parameters values

In [None]:


# We can sweep over multiple values at once
!python DLGD2022/tutorial-1-hydra/hydra_main.py --multirun \
            training.partial_optimizer.lr=0.01,0.005,0.001 # try different learning rates
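With Hydra's default basic sweeper, each comma-separated list becomes one sweep dimension and one job is launched per combination (the Cartesian product). The sketch below only enumerates the override sets that a two-parameter sweep would produce; the `data.batch_size` values are a hypothetical addition to the command above.

```python
from itertools import product

# Hypothetical sweep: 3 learning rates x 2 batch sizes
sweep = {
    "training.partial_optimizer.lr": ["0.01", "0.005", "0.001"],
    "data.batch_size": ["8", "16"],
}

# One list of "key=value" overrides per combination
jobs = [
    [f"{key}={value}" for key, value in zip(sweep, combo)]
    for combo in product(*sweep.values())
]
for job_id, overrides in enumerate(jobs):
    print(job_id, " ".join(overrides))
```

On the command line, the corresponding sweep would add `data.batch_size=8,16` after the learning rate list, yielding 3 x 2 = 6 jobs.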

## Use hydra from a notebook

In [None]:
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

# create configs programmatically
cs = ConfigStore.instance()
cs.store(
    name='medium_cnn', 
    node={
        '_target_': 'new_src.MyForecastModel',
        'ninput': '${data.number_of_past_days}',
        'nhidden': 32,
        'nlayers': 2,
    }, group='model', package='model'
)



# Compose configs explicitly
with hydra.initialize(config_path='config', version_base='1.2'):
    cfg = hydra.compose('main', overrides=['model=medium_cnn'])

print(OmegaConf.to_yaml(cfg.model))

_target_: new_src.MyForecastModel
ninput: ${data.number_of_past_days}
nhidden: 32
nlayers: 2



## Use hydra to manage your slurm job configurations : [See doc](https://hydra.cc/docs/plugins/submitit_launcher/)


## Use hydra to help with your hyperparameter search: [See doc](https://hydra.cc/docs/plugins/optuna_sweeper/)
