## Predict sea ice concentration with sea surface flux using a transformer with multi-head attention
This notebook serves as an example of a basic workflow with the `s2spy` and `lilio` packages. <br>
We will predict sea ice concentration in the Arctic at subseasonal time scales, using the ERA5 dataset and a transformer with multi-head attention.

This recipe includes the following steps:
- Define a calendar (`lilio`)
- Download/load input data
- Map the calendar to the data (`lilio`)
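- Train-test split (70%/30%) (`torch`)
- Fit the preprocessor on the training samples and preprocess the data
- Create a transformer with multi-head attention (`torch`)
- Tune hyper-parameters with W&B
- Train and evaluate the model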

In [9]:
from datetime import date

import lilio
import numpy as np
import xarray as xr
from pathlib import Path
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as f

#### Define a calendar with `lilio` to specify the time range for targets and precursors.

In [6]:
# create a custom calendar (NWP-style, for constrained forecasts/hindcasts)
# based on the time of interest
calendar = lilio.Calendar(anchor="01-01", allow_overlap=True)
# determine the number of whole weeks in the period of interest
w_start = date(1979, 1, 1)
w_end = date(2017, 12, 31)
days = (w_end - w_start).days
weeks = days // 7
# add one 7-day target interval per week
for _ in range(weeks):
    calendar.add_intervals("target", length="7d")

In [None]:
# load data
# expand "~" explicitly to the user's home directory
data_folder = Path('~/AI4S2S/data').expanduser()
precursor_field = xr.open_dataset(data_folder / 'rad_daily')
target_field = xr.open_dataset(data_folder / 'sic_daily')

#### Map the calendar to the data

In [None]:
# map calendar to data
calendar.map_to_data(precursor_field)
calendar.visualize(show_length=True)

In [None]:
# use 70% of the anchor years as training samples
years = calendar.get_intervals().index
train_samples = round(len(years) * 0.7)
# start year, taken as the last entry of the anchor-year index
start_year = years[-1]

#### Fit preprocessor with training samples and preprocess data
Remove trend and take anomalies for the precursor field.
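
The idea behind this step could be sketched as follows. This is a minimal NumPy illustration on synthetic data, not the preprocessor used in this notebook: the helper names `fit_preprocessor` and `apply_preprocessor` are hypothetical, and a single linear trend per grid point is an assumption (real daily data would usually also need a seasonal climatology removed). The point it shows is that the trend is estimated from the training samples only and then applied unchanged to the test samples.

```python
import numpy as np


def fit_preprocessor(train: np.ndarray) -> dict:
    """Fit a linear trend per grid point on training data of shape (time, points).

    Hypothetical helper, for illustration only.
    """
    t = np.arange(train.shape[0])
    # with a 2D y, polyfit fits every column independently: result is (2, points)
    slope, intercept = np.polyfit(t, train, deg=1)
    return {"slope": slope, "intercept": intercept}


def apply_preprocessor(data: np.ndarray, params: dict, t0: int = 0) -> np.ndarray:
    """Subtract the fitted trend line; t0 is the time index of the first sample."""
    t = np.arange(t0, t0 + data.shape[0])[:, None]
    return data - (params["slope"] * t + params["intercept"])


# synthetic field: linear trend plus noise, shape (time, points)
rng = np.random.default_rng(0)
n_time, n_points = 100, 4
time = np.arange(n_time)[:, None]
field = 0.05 * time + rng.normal(size=(n_time, n_points))

# fit on the first 70% only, then apply to both parts
n_train = round(n_time * 0.7)
params = fit_preprocessor(field[:n_train])
train_anom = apply_preprocessor(field[:n_train], params)
test_anom = apply_preprocessor(field[n_train:], params, t0=n_train)
```

Fitting on the training part alone avoids leaking information from the test period into the anomalies.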

#### Train-test split based on the anchor years (70%/30% split)
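A chronological 70%/30% split over anchor years might look like the sketch below. The list of years is a hypothetical stand-in for the anchor-year index of the calendar, chosen to match the 1979 to 2017 period used above; the variable names `train_years` and `test_years` are illustrative.

```python
# hypothetical anchor years, standing in for the calendar's anchor-year index
years = list(range(1979, 2018))
train_samples = round(len(years) * 0.7)

# chronological split: earliest years for training, the rest for testing
train_years = years[:train_samples]
test_years = years[train_samples:]

print(f"train: {train_years[0]}-{train_years[-1]} ({len(train_years)} years)")
print(f"test: {test_years[0]}-{test_years[-1]} ({len(test_years)} years)")
```

Splitting by whole years rather than by shuffled samples keeps the test period strictly later than the training period.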

#### Create transformer with multi-head attention using PyTorch

In [None]:
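def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    """Compute softmax(Q K^T / sqrt(d)) V for batched 3D inputs."""
    # similarity scores between every query and key: (batch, len_q, len_k)
    temp = query.bmm(key.transpose(1, 2))
    # scale by the square root of the feature dimension
    scale = query.size(-1) ** 0.5
    softmax = f.softmax(temp / scale, dim=-1)
    # weighted sum of the values
    return softmax.bmm(value)
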

class AttentionHead(nn.Module):
    """Single attention head with learned query, key and value projections."""

    def __init__(self, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        # dim_q must equal dim_k, otherwise the query-key product (bmm) fails
        self.q = nn.Linear(dim_in, dim_q)
        self.k = nn.Linear(dim_in, dim_k)
        self.v = nn.Linear(dim_in, dim_k)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        return scaled_dot_product_attention(self.q(query), self.k(key), self.v(value))


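class MultiHeadAttention(nn.Module):
    """Run several attention heads in parallel and project the concatenated output."""

    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, dim_q, dim_k) for _ in range(num_heads)]
        )
        # project the concatenated head outputs back to the input dimension
        self.linear = nn.Linear(num_heads * dim_k, dim_in)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        # concatenate head outputs along the feature axis before projecting
        return self.linear(
            torch.cat([h(query, key, value) for h in self.heads], dim=-1)
        )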
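
To turn the attention block into a transformer, it is typically wrapped in an encoder layer with a feed-forward network, residual connections and layer normalization. The sketch below is one possible layout, not the model used in this notebook. To stay self-contained it uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the custom `MultiHeadAttention` class above; the class name `EncoderLayer` and all dimensions are hypothetical.

```python
import torch
from torch import Tensor, nn


class EncoderLayer(nn.Module):
    """Hypothetical encoder layer: self-attention and feed-forward sub-layers,
    each followed by a residual connection and layer normalization."""

    def __init__(self, dim_model: int = 32, num_heads: int = 4,
                 dim_ff: int = 64, dropout: float = 0.1):
        super().__init__()
        # built-in attention used here as a stand-in for the custom class
        self.attention = nn.MultiheadAttention(
            dim_model, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(dim_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, dim_model)
        )
        self.norm2 = nn.LayerNorm(dim_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src: Tensor) -> Tensor:
        # self-attention: query, key and value are all the input sequence
        attn_out, _ = self.attention(src, src, src)
        src = self.norm1(src + self.dropout(attn_out))
        return self.norm2(src + self.dropout(self.feed_forward(src)))


layer = EncoderLayer()
x = torch.rand(8, 10, 32)  # (batch, sequence length, features)
out = layer(x)
print(out.shape)
```

The output keeps the input shape, so several such layers can be stacked.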

In [None]:
# positional embedding
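
Attention by itself has no notion of order, so position information is usually added to the input sequence. One common option is the fixed sinusoidal encoding sketched below. Choosing it over a learned embedding (for example `nn.Embedding`) is an assumption made for illustration, and the function name `position_encoding` is hypothetical.

```python
import torch
from torch import Tensor


def position_encoding(seq_len: int, dim_model: int) -> Tensor:
    """Return sinusoidal positional encodings of shape (1, seq_len, dim_model)."""
    pos = torch.arange(seq_len, dtype=torch.float).reshape(1, -1, 1)
    dim = torch.arange(dim_model, dtype=torch.float).reshape(1, 1, -1)
    # each pair of dimensions (2i, 2i+1) shares the frequency 1 / 10000^(2i / dim_model)
    pair_index = torch.div(dim, 2, rounding_mode="floor")
    phase = pos / (1e4 ** (2 * pair_index / dim_model))
    # sine on even dimensions, cosine on odd dimensions
    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))


pe = position_encoding(seq_len=10, dim_model=32)
x = torch.rand(8, 10, 32)  # (batch, sequence length, features)
x = x + pe                 # broadcast over the batch dimension
```

Because the encoding is added rather than concatenated, the feature dimension of the input stays unchanged.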

#### Hyper-parameter tuning with W&B
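A W&B sweep is driven by a configuration dictionary that names the search method, the metric to optimize and the parameter space. The sketch below shows what such a configuration could look like. All parameter names, values, the metric name `val_loss` and the project name are hypothetical placeholders, and the `wandb` calls are left commented out because they need an account and a user-defined training function.

```python
# import wandb  # needed only for the commented-out calls below

# hypothetical search space; names and values are placeholders
sweep_config = {
    "method": "random",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"values": [1e-4, 1e-3, 1e-2]},
        "num_heads": {"values": [2, 4, 8]},
        "dim_k": {"values": [16, 32, 64]},
        "batch_size": {"values": [16, 32]},
    },
}

# sweep_id = wandb.sweep(sweep_config, project="sic-transformer")  # hypothetical project
# wandb.agent(sweep_id, function=train, count=20)  # `train` is your own training function
```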

#### Train model
Implement an early-stopping function.
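
Early stopping could be sketched as a small helper that tracks the best validation loss and signals a stop once it has not improved for a given number of epochs. The class name `EarlyStopper`, its parameters and the loss values are hypothetical; in a real training loop the losses would come from evaluating the model on the validation set.

```python
class EarlyStopper:
    """Signal a stop when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            # improvement: remember the new best loss and reset the counter
            self.best_loss = val_loss
            self.counter = 0
        else:
            # no improvement this epoch
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopper(patience=2)
val_losses = [1.0, 0.8, 0.85, 0.9, 0.7]  # made-up validation losses per epoch
stopped_epoch = None
for epoch, loss in enumerate(val_losses):
    if stopper.should_stop(loss):
        stopped_epoch = epoch
        break
```

With a patience of 2, training stops after two consecutive epochs without improvement, so the final loss in the list is never reached.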

#### Evaluate model
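Two simple scores for comparing predictions with observations are the root-mean-square error and the Pearson correlation. The sketch below computes both with NumPy on made-up values; the function names and the choice of these two metrics are assumptions, not the evaluation used in this notebook.

```python
import numpy as np


def rmse(pred: np.ndarray, obs: np.ndarray) -> float:
    """Root-mean-square error over all elements."""
    return float(np.sqrt(np.mean((pred - obs) ** 2)))


def correlation(pred: np.ndarray, obs: np.ndarray) -> float:
    """Pearson correlation between the flattened prediction and observation."""
    return float(np.corrcoef(pred.ravel(), obs.ravel())[0, 1])


# made-up anomalies, for illustration only
obs = np.array([0.1, -0.2, 0.3, 0.0])
pred = np.array([0.2, -0.1, 0.2, 0.1])

print(f"RMSE: {rmse(pred, obs):.3f}")
print(f"correlation: {correlation(pred, obs):.3f}")
```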