# Aurora inference & fine-tuning in Azure ML

This notebook explains the steps implemented in `aurora_demo_core.py`.

## Input data

#### General requirements

Aurora expects data on a 3D grid (latitude × longitude × pressure level) with a consistent spatial resolution. Input data can be interpreted as a *snapshot of the state of the atmosphere* at a particular time.

Pass input variables unnormalised, in their physical units: the model normalises each variable internally and independently.

Input data must be in the [aurora.Batch format](https://microsoft.github.io/aurora/batch.html) that includes:

1. **surface-level variables** (a dictionary)
2. **static variables** (a dictionary)
3. **atmospheric variables** (all at the same collection of pressure levels; a dictionary)
4. **metadata** describing these variables: latitudes, longitudes, the pressure levels of the atmospheric variables, and the time when these variables were recorded. This is an [aurora.Metadata](https://microsoft.github.io/aurora/batch.html#batch-metadata) object.

The dictionaries of variables map the predefined names of the variables to their numerical values. [See the full list of supported variables and their short names](https://microsoft.github.io/aurora/batch.html#batch-surf-vars).

1. The **surface-level variables** must be of the form (<span style='color: purple;'>b</span>, <span style='color: orange;'>t</span>, <span style='color: green;'>h</span>, <span style='color: red;'>w</span>) where <span style='color: purple;'>b</span> is the batch size, <span style='color: orange;'>t</span> is the history dimension, <span style='color: green;'>h</span> is the number of latitudes, and <span style='color: red;'>w</span> is the number of longitudes.

All Aurora models predict the next time step from the current time step (`surf_vars[:, 1, :, :]`) and the previous time step (`surf_vars[:, 0, :, :]`), so the history dimension <span style='color: orange;'>t</span> is 2. Note that `Metadata.time` corresponds to the current time step.

The batch size <span style='color: purple;'>b</span> corresponds to the number of *spatiotemporal samples*. For example, <span style='color: purple;'>b > 1</span> when samples come from different *ensemble members* (estimates obtained from perturbed initial conditions) that are used to predict the weather for the same target time. Aurora can consume ensemble members as independent samples stacked along the <span style='color: purple;'>b</span> dimension. In this case, the output contains an independent forecast for each member. Samples from different spatial tiles can also be used in one batch.

2. The **static variables** must be of the form (<span style='color: green;'>h</span>, <span style='color: red;'>w</span>) where <span style='color: green;'>h</span> is the number of latitudes and <span style='color: red;'>w</span> the number of longitudes. These variables do not change with time.

3. The **atmospheric variables** must be of the form (<span style='color: purple;'>b</span>, <span style='color: orange;'>t</span>, <span style='color: grey;'>c</span>, <span style='color: green;'>h</span>, <span style='color: red;'>w</span>) where <span style='color: purple;'>b</span> is the batch size, <span style='color: orange;'>t</span> the history dimension, <span style='color: grey;'>c</span> the number of pressure levels, <span style='color: green;'>h</span> the number of latitudes, and <span style='color: red;'>w</span> the number of longitudes. All atmospheric variables must contain the same collection of pressure levels in the same order.


4. **metadata** contains the following fields:

    * `Metadata.lat` is the vector of latitudes. The latitudes must be decreasing. The latitudes can either include both endpoints: `linspace(90, -90, 721)`, or not include the south pole: `linspace(90, -90, 721)[:-1]`. For curvilinear grids, this can also be a matrix, in which case the foregoing conditions apply to every column.

    * `Metadata.lon` is the vector of longitudes. The longitudes must be increasing. The longitudes must be in the range `[0, 360)`, so they can include zero and cannot include 360. For curvilinear grids, this can also be a matrix, in which case the foregoing conditions apply to every row.

    * `Metadata.atmos_levels` is a tuple of the pressure levels of the atmospheric variables, in hPa. The levels must be listed in the same order as the pressure-level dimension <span style='color: grey;'>c</span> of the atmospheric variables. Note also that `Metadata.atmos_levels` should be a tuple, not a list.

    * `Metadata.time` is a tuple containing one `datetime.datetime` per sample in the batch, giving the time of that sample's data. If the batch size <span style='color: purple;'>b</span> is one, this is a one-element tuple, e.g. `(datetime(2024, 1, 1, 12, 0),)`. Since all Aurora models require variables for the current and the previous step, `Metadata.time` corresponds to the time of the current step. Specifically, `Metadata.time[i]` corresponds to the time of `Batch.surf_vars[i, -1]`.
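
The shape and metadata rules above can be collected into a quick pre-flight check before building a batch. The sketch below is hypothetical: `check_batch_layout` is not part of the `aurora` package, it works on plain shape tuples and Python lists instead of tensors (so it runs without PyTorch), it assumes a regular (non-curvilinear) grid and a history length of 2, and it checks only a subset of the rules (for example, it does not check the pole endpoints).

```python
from datetime import datetime


def check_batch_layout(surf_shapes, static_shapes, atmos_shapes,
                       lat, lon, atmos_levels, time):
    """Hypothetical pre-flight check of (some of) the aurora.Batch rules above.

    Shapes are plain tuples; lat and lon are 1D sequences of floats (a regular,
    non-curvilinear grid). Returns a list of problems; an empty list means that
    none of the checked rules is violated.
    """
    problems = []
    h, w, c = len(lat), len(lon), len(atmos_levels)
    batch_sizes = set()

    # Surface variables: (b, t, h, w) with t = 2 (previous and current step).
    for name, shape in surf_shapes.items():
        if len(shape) == 4 and shape[1] == 2 and shape[2:] == (h, w):
            batch_sizes.add(shape[0])
        else:
            problems.append(f"surf_vars[{name!r}]: got {shape}, expected (b, 2, {h}, {w})")

    # Static variables: (h, w), with no batch or history dimension.
    for name, shape in static_shapes.items():
        if shape != (h, w):
            problems.append(f"static_vars[{name!r}]: got {shape}, expected ({h}, {w})")

    # Atmospheric variables: (b, t, c, h, w), c = number of pressure levels.
    for name, shape in atmos_shapes.items():
        if len(shape) == 5 and shape[1] == 2 and shape[2:] == (c, h, w):
            batch_sizes.add(shape[0])
        else:
            problems.append(f"atmos_vars[{name!r}]: got {shape}, expected (b, 2, {c}, {h}, {w})")

    if len(batch_sizes) > 1:
        problems.append(f"inconsistent batch sizes: {sorted(batch_sizes)}")

    # Metadata: decreasing latitudes, increasing longitudes in [0, 360).
    if any(a <= b for a, b in zip(lat, lat[1:])):
        problems.append("lat must be decreasing")
    if any(a >= b for a, b in zip(lon, lon[1:])):
        problems.append("lon must be increasing")
    if not all(0 <= x < 360 for x in lon):
        problems.append("lon must lie in [0, 360)")
    if not isinstance(atmos_levels, tuple):
        problems.append("atmos_levels must be a tuple, not a list")

    # One datetime per sample, i.e. len(time) == b.
    if not isinstance(time, tuple) or not all(isinstance(t, datetime) for t in time):
        problems.append("time must be a tuple of datetime.datetime objects")
    elif len(batch_sizes) == 1 and len(time) != next(iter(batch_sizes)):
        problems.append("time must hold one datetime per sample in the batch")

    return problems


# Example: a single-sample batch on a 17 x 32 global grid with 4 pressure levels.
lat = [90 - 11.25 * i for i in range(17)]
lon = [11.25 * j for j in range(32)]
print(check_batch_layout(
    surf_shapes={"2t": (1, 2, 17, 32)},
    static_shapes={"lsm": (17, 32)},
    atmos_shapes={"t": (1, 2, 4, 17, 32)},
    lat=lat, lon=lon,
    atmos_levels=(100, 250, 500, 850),
    time=(datetime(2024, 1, 1, 12, 0),),
))  # []
```

Returning a list of problems rather than raising on the first one makes it easier to fix several layout mistakes in one pass.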

#### Dummy data (in GPU memory)

In the first example, you will create a dummy dataset stored in GPU memory. Call `make_lowres_batch` to create your first `aurora.Batch`.

The values of all variables are random, but the latitude-longitude grid is a valid global grid, albeit at very coarse resolution (17 latitudes × 32 longitudes).
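
One way to build a global grid of this size that satisfies the metadata rules is sketched below. This is an assumption about how such a grid can be constructed, not taken from the source of `make_lowres_batch`: with both poles included, 17 latitudes and 32 longitudes give a uniform 11.25° spacing in each direction. The sketch uses plain Python lists; in a real batch these would be 1D tensors (e.g. built with `torch.linspace`).

```python
def global_grid(n_lat: int, n_lon: int) -> tuple[list[float], list[float]]:
    """Return (lat, lon) for a regular global grid.

    Latitudes run from 90 down to -90 inclusive (both poles), and longitudes
    start at 0 and stop one step short of 360, matching the rules above.
    """
    lat_step = 180.0 / (n_lat - 1)
    lon_step = 360.0 / n_lon
    lat = [90.0 - i * lat_step for i in range(n_lat)]
    lon = [j * lon_step for j in range(n_lon)]
    return lat, lon


lat, lon = global_grid(17, 32)
print(lat[0], lat[-1], lat[0] - lat[1])   # 90.0 -90.0 11.25
print(lon[0], lon[-1], lon[1] - lon[0])   # 0.0 348.75 11.25
```

Dividing 360 by `n_lon` (rather than by `n_lon - 1`) is what keeps 360 itself out of the longitude vector.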



#### ERA5 (in Azure Blob Storage)

TODO



## Inference

You will use `run_inference` to load a pretrained Aurora model, create input data, run one forward pass or multiple forward passes (if the `rollout_steps` argument is provided), and return the output of Aurora. The output has the same format as the input batch, with one exception: the history dimension <span style='color: orange;'>t</span> is equal to 1.
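
The shape rule for the output can be written down explicitly. The helper below is hypothetical (it is not part of `aurora`) and only manipulates shape tuples: the history dimension (index 1) of surface and atmospheric variables becomes 1, while static variables keep their (h, w) shape.

```python
def output_shape(shape: tuple[int, ...], is_static: bool = False) -> tuple[int, ...]:
    """Shape of a variable in Aurora's output, given its input shape.

    Surface (b, t, h, w) and atmospheric (b, t, c, h, w) variables lose their
    history: t becomes 1. Static (h, w) variables are unchanged.
    """
    if is_static:
        return shape
    return (shape[0], 1) + shape[2:]


print(output_shape((1, 2, 17, 32)))            # (1, 1, 17, 32)
print(output_shape((1, 2, 4, 17, 32)))         # (1, 1, 4, 17, 32)
print(output_shape((17, 32), is_static=True))  # (17, 32)
```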

* `model = AuroraPretrained()`

    Instantiate the model. `model` has the same architecture as Aurora, but its weights are not loaded yet.

* `model.load_checkpoint()`

    Load the pretrained model weights from Hugging Face. We will use the pretrained 0.25° Aurora, a generic model suited to ERA5-like data. Note that `Aurora()` refers to the 0.25° model fine-tuned on IFS HRES, not the generic version. [See all available model variants](https://huggingface.co/microsoft/aurora/tree/main).

    The authors recommend `AuroraPretrained()` for predictions based on ERA5 at 0.25°. This variant is also recommended as the starting point for fine-tuning to new applications.

* `model = model.to(device)`

    Move the model to a GPU (if available).

* `model.eval()`

    Set the model to evaluation mode. This switches off training-only behaviour such as dropout.

* `batch = make_lowres_batch(device=device)`

    Create a dummy input batch with `make_lowres_batch` (see *Dummy data* above).

Aurora can generate predictions one lead time ahead (a single step) or several lead times ahead (by applying the model autoregressively). The two options are implemented differently, so `run_inference` provides them as two modes.

#### One step prediction

Disable gradient tracking and run one forward pass:

```python
import torch

with torch.inference_mode():
    prediction = model(batch)
```

#### Autoregressive rollout

Run the model for `rollout_steps` steps, feeding its own outputs back in as inputs at each subsequent step. Each prediction is moved to the CPU as soon as it is produced, so predictions do not accumulate in GPU memory:

```python
import torch
from aurora import rollout

with torch.inference_mode():
    preds = [pred.to("cpu") for pred in rollout(model, batch, steps=rollout_steps)]
```
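
The bookkeeping of an autoregressive rollout can be illustrated without the model. The toy sketch below is not `aurora.rollout` and does not call Aurora: `toy_model` is a hypothetical stand-in (linear extrapolation from the two history steps), each state is a single float, and the 6-hour step is an assumption (the lead time of the 0.25° Aurora models). It shows the mechanics: each prediction becomes the new current step, the old current step becomes the previous step, and the time stamp advances by one lead time.

```python
from datetime import datetime, timedelta

LEAD_TIME = timedelta(hours=6)  # assumed step size of the model


def toy_model(history: tuple[float, float]) -> float:
    """Hypothetical stand-in for Aurora: extrapolate linearly from (previous, current)."""
    previous, current = history
    return current + (current - previous)


def toy_rollout(history: tuple[float, float], time: datetime, steps: int):
    """Yield (time, prediction) pairs, feeding each prediction back in as input."""
    for _ in range(steps):
        prediction = toy_model(history)
        time = time + LEAD_TIME
        yield time, prediction
        # Slide the history window: drop the oldest step, append the prediction.
        history = (history[1], prediction)


start = datetime(2024, 1, 1, 12, 0)  # plays the role of Metadata.time (current step)
for t, value in toy_rollout((270.0, 271.0), start, steps=3):
    print(t, value)
# 2024-01-01 18:00:00 272.0
# 2024-01-02 00:00:00 273.0
# 2024-01-02 06:00:00 274.0
```

Writing the rollout as a generator mirrors the list comprehension above: each prediction can be consumed (for example, moved to the CPU) before the next one is computed.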






## Fine-tuning

Steps for fine-tuning Aurora (sources: [Aurora GitHub Issues](https://github.com/microsoft/aurora/issues) and [Official docs: Aurora fine-tuning](https://microsoft.github.io/aurora/finetuning.html)):

1. Collect data as surface, static and atmospheric variables. For variables that the pretrained model has not seen, set normalisation statistics (a mean and a standard deviation per variable).

2. Define the loss. The authors of Aurora use a loss function based on the mean absolute error (MAE).

3. Load the pretrained model and set up the optimiser. 

4. Fine-tuning loop (repeat for several epochs):  

    a) Sample a random data point. 

    b) Run the model to get the model prediction. 

    c) Normalise predictions and target data.  

    d) Compute the loss and update model parameters. 
    
    e) Monitor how the loss changes during fine-tuning. 

5. Save the fine-tuned model. 
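
Steps 1, 2, 4c and 4d can be illustrated together on toy data. The sketch below computes a mean and standard deviation for a hypothetical new variable (`new_var` is an illustrative name, not an Aurora variable), normalises prediction and target with those statistics, and averages the absolute error. It uses plain Python lists and an unweighted MAE; the authors' loss is a weighted variant of MAE, and its weights are not reproduced here.

```python
from statistics import fmean, pstdev


def fit_stats(values: list[float]) -> tuple[float, float]:
    """Mean and (population) standard deviation used for normalisation."""
    return fmean(values), pstdev(values)


def normalise(values: list[float], mean: float, std: float) -> list[float]:
    """Map values to zero mean and unit standard deviation."""
    return [(v - mean) / std for v in values]


def normalised_mae(pred: dict[str, list[float]], target: dict[str, list[float]],
                   stats: dict[str, tuple[float, float]]) -> float:
    """Unweighted MAE over variables, computed after normalising both sides."""
    per_var = []
    for name, (mean, std) in stats.items():
        p = normalise(pred[name], mean, std)
        t = normalise(target[name], mean, std)
        per_var.append(fmean([abs(a - b) for a, b in zip(p, t)]))
    return fmean(per_var)


# Hypothetical training samples for a new variable.
samples = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
stats = {"new_var": fit_stats(samples)}
print(stats)  # {'new_var': (5.0, 2.0)}

mae = normalised_mae({"new_var": [5.0, 7.0]}, {"new_var": [7.0, 7.0]}, stats)
print(mae)  # 0.5
```

Normalising before taking the error keeps variables with large physical magnitudes from dominating the loss.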

You will use `run_finetuning` to update all model weights based on the new dataset. 

First, you will load the pretrained model weights and move them to a GPU (if available). These steps are the same as in the inference function (`run_inference`).

```python
from aurora import AuroraPretrained

model = AuroraPretrained()
model.load_checkpoint()
model = model.to(device)
```

Set the model to training mode:

`model.train()`

Create an optimiser that updates all model parameters. In real experiments, treat the learning rate as a hyperparameter and try a range of values.

`optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)`

We will return the model output from the last training step (`last_prediction`) and the loss values from all steps (`loss_history`), and plot the loss values later.

```python
last_prediction = None
loss_history: list[float] = []
```

In the loop below, you will:

1. Generate a random input batch (`batch`).
2. Clear any gradients from the previous iteration.
3. Produce predictions (`prediction`) based on input data (`batch`).
4. Compute the loss.
5. Compute gradients with respect to all parameters.
6. Use the gradients to update all model weights.
7. Append the loss value from the current iteration to `loss_history`.
8. Print the step number and the loss value.

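```python
# `steps`, `device`, `loss` and `make_lowres_batch` come from the surrounding
# `run_finetuning` context.
for step in range(steps):
    batch = make_lowres_batch(device=device)  # fresh random batch for this step

    optimizer.zero_grad()          # clear gradients from the previous step
    prediction = model(batch)      # forward pass
    loss_value = loss(prediction)
    loss_value.backward()          # gradients w.r.t. all parameters
    optimizer.step()               # update all model weights

    # Keep the latest prediction so that it can be returned after the loop.
    last_prediction = prediction

    last_loss_value = loss_value.item()  # plain Python float
    loss_history.append(last_loss_value)

    print(f"[step {step}] loss = {last_loss_value:.4f}")
```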

After `steps` iterations, return the prediction from the last step and the loss values from all steps:

`return last_prediction, loss_history`
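
Because every step sees a different random batch, the per-step values in `loss_history` are noisy, which makes the trend harder to judge when monitoring the loss (step 4e). A minimal sketch of an exponential moving average for smoothing the curve before plotting, assuming `loss_history` is a plain list of floats as produced by the loop; the smoothing factor `alpha` is an arbitrary illustrative choice.

```python
def ema(values: list[float], alpha: float = 0.1) -> list[float]:
    """Exponential moving average: s_0 = x_0, s_i = alpha * x_i + (1 - alpha) * s_(i-1)."""
    if not 0 < alpha <= 1:
        raise ValueError("alpha must be in (0, 1]")
    smoothed: list[float] = []
    for x in values:
        smoothed.append(x if not smoothed else alpha * x + (1 - alpha) * smoothed[-1])
    return smoothed


loss_history = [1.0, 0.5, 1.0, 0.5]
print(ema(loss_history, alpha=0.5))  # [1.0, 0.75, 0.875, 0.6875]
```

Smaller values of `alpha` give a smoother curve that reacts more slowly to real changes in the loss.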
