
# Importing altimetry data

### Install dependencies

In [1]:
!pip install xarray[complete] eccodes -q



### Download NetCDF files

#### Download NATL60 data

In [2]:
!wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc

--2022-10-04 07:31:40--  https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc
Resolving s3.us-east-1.wasabisys.com (s3.us-east-1.wasabisys.com)... 38.27.106.51, 38.27.106.53
Connecting to s3.us-east-1.wasabisys.com (s3.us-east-1.wasabisys.com)|38.27.106.51|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://s3.eu-central-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc [following]
--2022-10-04 07:31:40--  https://s3.eu-central-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc
Resolving s3.eu-central-1.wasabisys.com (s3.eu-central-1.wasabisys.com)... 130.117.252.28, 130.117.252.10, 130.117.252.32, ...
Connecting to s3.eu-central-1.wasabisys.com (s3.eu-central-1.wasabisys.com)|130.117.252.28|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 118023544 (113M) [application/x-netcdf]
Saving to: ‘NATL60-CJM165_GULFSTREAM_ss

# Example: PyTorch data loading

## Concepts

To create the training data, two PyTorch classes are used:

- `torch.utils.data.Dataset`: a collection of elementary items, each holding the data for one training sample
    - Example: `(x, y)`: one image `x` + its corresponding label `y`
    - Example (forecasting): `(ssh_t, ssh_t_plus_1)`: past SSH + future SSH
    - Example (downscaling): `(ssh_low_res, ssh_high_res)`: ...

- `torch.utils.data.DataLoader`: an iterable that takes a `Dataset` and a `batch_size` as input and assembles the items into batches:
    - Example: if the dataset returns `(x, y)` with `x~(channel, height, width)` and `y~(label)`, `DataLoader(dataset, batch_size=batch_size)` yields `(bx, by)` with `bx~(batch_size, channel, height, width)` and `by~(batch_size, label)`
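The two classes above can also be illustrated with a custom `Dataset` subclass instead of `TensorDataset`. This is a minimal sketch, assuming random tensors in place of real SSH maps; the class name `ForecastDataset` is hypothetical. It only needs `__len__` and `__getitem__`, and the `DataLoader` stacks the returned items along a new first (batch) dimension.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ForecastDataset(Dataset):
    """Hypothetical dataset returning (ssh_t, ssh_t_plus_1) pairs
    from a (time, height, width) tensor."""

    def __init__(self, ssh: torch.Tensor):
        self.ssh = ssh

    def __len__(self) -> int:
        # N time steps yield N - 1 (t, t + 1) pairs
        return self.ssh.shape[0] - 1

    def __getitem__(self, idx: int):
        # Input: map at time idx, target: map at time idx + 1
        return self.ssh[idx], self.ssh[idx + 1]


# Synthetic stand-in for SSH maps: 10 time steps on a 32 x 32 grid
ssh = torch.randn(10, 32, 32)
dataset = ForecastDataset(ssh)
loader = DataLoader(dataset, batch_size=4)

# Fetch the first batch only
bx, by = next(iter(loader))
print(len(dataset), bx.shape, by.shape)
# 9 torch.Size([4, 32, 32]) torch.Size([4, 32, 32])
```

Since 9 items do not divide evenly into batches of 4, the last batch holds a single item; passing `drop_last=True` to `DataLoader` discards it.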


## Practical considerations

A simple way to create a dataset is to stack all training items along the first dimension and then instantiate a `torch.utils.data.TensorDataset`. Below is an example of a dataset that returns `(ssh_t, ssh_t_plus_1)` pairs:

In [30]:
import pandas as pd
import torch
import xarray as xr

# Open the NATL60 reference file: decode_times=False keeps the raw time values,
# which are then converted with pd.to_datetime
ref_ds = xr.open_dataset(
    'NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc', decode_times=False
).assign_coords(time=lambda ds: pd.to_datetime(ds.time))

# Create stacked items by shifting the time axis by one step
ssh_t = ref_ds.ssh.isel(time=slice(None, -1))        # t0 ... tN-1
ssh_t_plus_1 = ref_ds.ssh.isel(time=slice(1, None))  # t1 ... tN

# Convert to torch tensors (time is the first dimension)
ssh_t_tensor = torch.from_numpy(ssh_t.values)
ssh_t_plus_1_tensor = torch.from_numpy(ssh_t_plus_1.values)

# Create the dataset: item i is (ssh at t_i, ssh at t_{i+1})
torch_dataset = torch.utils.data.TensorDataset(ssh_t_tensor, ssh_t_plus_1_tensor)

ssh_t_item, ssh_t_plus_1_item = torch_dataset[0]
print("Item sizes")
print(f'ssh_t_item, {ssh_t_item.size()}')
print(f'ssh_t_plus_1_item, {ssh_t_plus_1_item.size()}')
print()

# Create the dataloader
torch_dataloader = torch.utils.data.DataLoader(torch_dataset, batch_size=8)

print("Batch sizes")
# Fetch only the first batch
ssh_t_batch, ssh_t_plus_1_batch = next(iter(torch_dataloader))
print(f'ssh_t_batch, {ssh_t_batch.size()}')
print(f'ssh_t_plus_1_batch, {ssh_t_plus_1_batch.size()}')

Item sizes
ssh_t_item, torch.Size([201, 201])
ssh_t_plus_1_item, torch.Size([201, 201])

Batch sizes
ssh_t_batch, torch.Size([8, 201, 201])
ssh_t_plus_1_batch, torch.Size([8, 201, 201])
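The one-step time shift used to build the pairs can be checked without downloading the NetCDF file. The sketch below assumes a small synthetic `xarray.DataArray` whose dimensions are named `time`, `lat`, `lon` (hypothetical names, not read from the NATL60 file) and verifies that item `i` pairs time step `i` with time step `i + 1`.

```python
import numpy as np
import torch
import xarray as xr

# Synthetic stand-in for ref_ds.ssh: 5 maps on a 3 x 3 grid
# (dimension names are assumptions for this illustration)
ssh = xr.DataArray(
    np.arange(5 * 3 * 3, dtype=np.float32).reshape(5, 3, 3),
    dims=("time", "lat", "lon"),
)

ssh_t = ssh.isel(time=slice(None, -1))        # t0 ... tN-1
ssh_t_plus_1 = ssh.isel(time=slice(1, None))  # t1 ... tN

dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(ssh_t.values), torch.from_numpy(ssh_t_plus_1.values)
)

# Item 2 should pair time step 2 (input) with time step 3 (target)
x, y = dataset[2]
print(
    len(dataset),
    torch.equal(x, torch.from_numpy(ssh.values[2])),
    torch.equal(y, torch.from_numpy(ssh.values[3])),
)
# 4 True True
```

The target of item `i` is the input of item `i + 1`, which is exactly what slicing the time axis with `slice(None, -1)` and `slice(1, None)` produces.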


### TODO: create the train, validation and test datasets
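A possible starting point, given as a sketch rather than the prescribed solution: split the dataset into contiguous blocks in time with `torch.utils.data.Subset`. The 70/15/15 fractions are an arbitrary assumption, and random tensors stand in for `torch_dataset`. A random split (`torch.utils.data.random_split`) would also work mechanically, but consecutive SSH maps are strongly correlated, so a random split would place near-identical days in both the training and test sets.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Synthetic stand-in for torch_dataset: 100 (ssh_t, ssh_t_plus_1) pairs on a 16 x 16 grid
ssh = torch.randn(101, 16, 16)
torch_dataset = TensorDataset(ssh[:-1], ssh[1:])

# Hypothetical split fractions: 70 % train, 15 % validation, rest test
n = len(torch_dataset)
n_train = n * 70 // 100
n_val = n * 15 // 100

# Contiguous blocks in time: no index appears in more than one split
train_ds = Subset(torch_dataset, range(0, n_train))
val_ds = Subset(torch_dataset, range(n_train, n_train + n_val))
test_ds = Subset(torch_dataset, range(n_train + n_val, n))

# Shuffle only the training batches; keep validation and test in time order
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8)
test_dl = DataLoader(test_ds, batch_size=8)

print(len(train_ds), len(val_ds), len(test_ds))
# 70 15 15
```

Shuffling inside the training block is harmless because each item already contains its own `(t, t + 1)` pair; only the boundary between the splits needs to respect the time order.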