# Precipitation forecast example

This notebook demonstrates how to perform precipitation forecasts using the Prithvi Precip model. Prithvi Precip is a version of the Prithvi-WxC foundation model that was fine-tuned on IMERG data from January 1, 2000, to December 31, 2019.

In [7]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt



## Required packages

The code for the Prithvi Precip model is available from this [GitHub](https://github.com/simonpf/prithvi_precip) repository. 

> **Disclaimer**: This is still early development. From a quality perspective the code is not yet where I would like it to be and is likely to still change in the future.

In [13]:
%pip install git+https://github.com/simonpf/prithvi_precip


In [8]:
!conda run -n prithvi_wxc pip install git+https://github.com/simonpf/prithvi_precip




## Example data

The trained Prithvi Precip model is available on HuggingFace. The repository contains the model itself as well as input data from August 26 to 31, 2020. The code below downloads both the model and the input data.

In [9]:
from huggingface_hub import snapshot_download


local_dir = snapshot_download(
    repo_id="simonpf/prithvi_precip",
    local_dir="."
)


The input data is downloaded to the ``data`` directory. This directory contains the static data that the Prithvi-WxC model requires, such as the climatology, the scaling factors, and the static input data, as well as the dynamic input data and the reference precipitation data.

## Loading the model

Prithvi Precip is a slightly modified version of the Prithvi-WxC model. It adds two components to the model:

- An observation encoder.
- An MLP head that predicts precipitation from the original model output.

The code for the model itself is part of the [pytorch_retrieve](https://github.com/simonpf/pytorch_retrieve) package. This package was originally developed to simplify building precipitation retrievals (hence the name), but it now also handles forecasts. The implementation can be found [here](https://github.com/simonpf/pytorch_retrieve/blob/8910dce482cfe4e7259415e727e9323afbfe8a33/pytorch_retrieve/models/prithvi_wxc.py#L374).

I include this mainly as a reference for anyone who wants to develop their own models on top of the Prithvi-WxC foundation model. It is not needed to understand the remainder of this notebook.


In [None]:
%env PRITHVI_DATA_PATH=data/scaling_factors

In [None]:
from pytorch_retrieve import load_model
model = load_model("prithvi_precip_obs.pt").eval()

## Input data

The input data expected by Prithvi Precip is very similar to that of the Prithvi-WxC model. Like Prithvi-WxC, the model expects a dictionary containing the dynamic input (``x``), the static input (``static``), the climatology (``climatology``), the lead time in hours (``lead_time``), and the time difference between the two input steps (``input_time``).

Additionally, the model expects two tensors, ``obs`` and ``obs_meta``, containing the satellite observations and their metadata. The metadata describes each observation and consists of:

 - The log of the observation frequency.
 - The channel offset for passive microwave channels with symmetric offsets (0 for all others).
 - The relative time within the input data window.
 - A one-hot encoding of the polarization: "H", "V", "QH", "QV", or none.
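To make the metadata layout concrete, here is a minimal sketch of how one metadata vector could be assembled from the four components listed above. The helper ``encode_obs_meta``, the field order, the use of the natural logarithm, frequencies in GHz, and encoding "no polarization" as an all-zero one-hot vector are all assumptions for illustration; the actual encoding in ``prithvi_precip`` may differ.

```python
import numpy as np

# Hypothetical order of the one-hot polarization slots; "none" maps to all zeros.
POLARIZATIONS = ["H", "V", "QH", "QV"]


def encode_obs_meta(frequency_ghz, offset_ghz, relative_time, polarization=None):
    """Build a hypothetical metadata vector for one observation layer.

    frequency_ghz: channel center frequency (GHz assumed).
    offset_ghz: symmetric channel offset, 0 for channels without one.
    relative_time: observation time relative to the input window (scale assumed).
    polarization: one of POLARIZATIONS, or None for unpolarized channels.
    """
    one_hot = np.zeros(len(POLARIZATIONS), dtype=np.float32)
    if polarization is not None:
        one_hot[POLARIZATIONS.index(polarization)] = 1.0
    # Scalar fields: log frequency, channel offset, relative time.
    scalars = np.array([np.log(frequency_ghz), offset_ghz, relative_time], dtype=np.float32)
    return np.concatenate([scalars, one_hot])
```

For example, ``encode_obs_meta(89.0, 0.0, 0.5, "V")`` yields a 7-element vector whose last four entries are ``[0, 1, 0, 0]``.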

### Loading data

Loading is again handled by a custom dataset class, ``DirectPrecipForecastWithObsDataset``.

In [14]:
from prithvi_precip.datasets import DirectPrecipForecastWithObsDataset

dataset = DirectPrecipForecastWithObsDataset(
    "data/input_data/"  # Input data downloaded from HuggingFace above.
)
initialization_time = np.datetime64("2020-08-26T12:00:00")
input_data = dataset.get_direct_forecast_input(initialization_time, 1)


In [None]:
import cartopy.crs as ccrs
from matplotlib.colors import LogNorm
from matplotlib.gridspec import GridSpec

# Layout: two map panels and a narrow column for a shared colorbar.
gs = GridSpec(1, 3, width_ratios=[1.0, 1.0, 0.05])
fig = plt.figure(figsize=(13, 4))

ax = fig.add_subplot(gs[0, 0])

## Forecast from MERRA data

As a baseline, we first run a precipitation forecast using only the MERRA-2 data as input. An important difference from the original Prithvi-WxC model is that Prithvi Precip does not require unrolling: it has been trained to predict lead times from 3 to 96 hours directly.

> **Note**: We set ``model_only=True``. This sets all observations in the input to 0, forcing the model to ignore them.
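Because the model predicts each lead time directly, the valid time of every step follows from the initialization time and the 3-hour step used in the forecast loops of this notebook. The sketch below computes these valid times and rejects requests beyond the 96-hour range the model was trained on; ``valid_times`` is a hypothetical helper, not part of ``prithvi_precip``.

```python
import numpy as np

STEP_HOURS = 3        # lead-time increment per forecast step, as in the loops below
MAX_LEAD_HOURS = 96   # longest lead time the model was trained to predict


def valid_times(initialization_time, n_steps):
    """Return the valid times of a direct forecast with n_steps 3-hourly steps."""
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1.")
    lead_hours = STEP_HOURS * np.arange(1, n_steps + 1)
    if lead_hours[-1] > MAX_LEAD_HOURS:
        raise ValueError(
            f"Lead time {lead_hours[-1]} h exceeds the trained range of {MAX_LEAD_HOURS} h."
        )
    return np.datetime64(initialization_time) + lead_hours.astype("timedelta64[h]")
```

With ``forecast_steps = 16`` the last valid time is 48 hours after initialization; 32 steps reach the full 96 hours.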

In [None]:
import torch
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"  # Use a GPU if one is available.
dtype = torch.float32
forecast_steps = 16

model = model.to(device=device, dtype=dtype)

batch_iterator = dataset.get_batched_direct_forecast_input(initialization_time, n_steps=forecast_steps, batch_size=1)
results = []

step = 1
for inpt in tqdm(batch_iterator, total=forecast_steps, desc="Running forecast"):
    lats = np.rad2deg(inpt["static"][0, 0].float().cpu().numpy())[:, 0]
    lons = np.rad2deg(inpt["static"][0, 1].float().cpu().numpy())[0]
    valid_time = initialization_time + np.timedelta64(3, "h") * step
    step += 1

    # Move input to the target device and cast to the target dtype.
    inpt = {
        name: tnsr.to(device=device, dtype=dtype) for name,tnsr in inpt.items()
    }

    # Run inference
    with torch.no_grad():
        pred = model(inpt, model_only=True)["surface_precip"]
        expected_value = pred.expected_value().cpu().float().numpy()[0, 0] # Drop batch and feature dimensions.
        p_1 = pred.probability_greater_than(1.0).cpu().float().numpy()[0, 0]
        p_10 = pred.probability_greater_than(10.0).cpu().float().numpy()[0, 0]

    result = xr.Dataset({
        "latitude": (("latitude",), lats),
        "longitude": (("longitude",), lons),
        "valid_time": valid_time,
        "surface_precip": (("latitude", "longitude"), expected_value),
        "p_1mm": (("latitude", "longitude"), p_1,),
        "p_10mm": (("latitude", "longitude"), p_10)
    })
    results.append(result)
    
results = xr.concat(results, dim="valid_time")
results["initialization_time"] = initialization_time
        

## Reference data

To have something to compare the forecasts against, we load reference precipitation data from MERRA-2 and IMERG. MERRA-2 is the dataset on which the original Prithvi-WxC model was trained; however, that model did not use precipitation as an input or output variable. IMERG, on the other hand, is the satellite precipitation product on which Prithvi Precip was trained.
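The reference files appear to be organized as one file per valid time, with the file-name patterns used in the cell below. A small sketch that wraps these patterns; ``reference_files`` is a hypothetical helper, and the root directory ``data/input_data`` is the one used elsewhere in this notebook.

```python
from pathlib import Path

import numpy as np

# File-name patterns taken from the reference-data cell below.
IMERG_PATTERN = "imerg_3/%Y/%m/imerg_%Y%m%d%H%M.nc"
MERRA_PATTERN = "merra2_precip_3/%Y/%m/merra2_precip_%Y%m%d%H%M.nc"


def reference_files(valid_time, root="data/input_data"):
    """Return the (IMERG, MERRA-2) reference file paths for one valid time."""
    if isinstance(valid_time, np.datetime64):
        # strftime needs a datetime.datetime, so convert numpy timestamps first.
        valid_time = valid_time.astype("datetime64[s]").item()
    root = Path(root)
    return root / valid_time.strftime(IMERG_PATTERN), root / valid_time.strftime(MERRA_PATTERN)
```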

In [None]:
results_imerg = []
results_merra = []

reference_data_path = Path("data/input_data/")


for time in results.valid_time.data:
    # Convert numpy.datetime64 to datetime.datetime so that strftime is available.
    time = time.astype("datetime64[s]").item()

    imerg_file = time.strftime("imerg_3/%Y/%m/imerg_%Y%m%d%H%M.nc")
    imerg_data = xr.load_dataset(reference_data_path / imerg_file)
    imerg_data["p_10mm"] = 5 < imerg_data.surface_precip
    results_imerg.append(imerg_data)

    merra_file = time.strftime("merra2_precip_3/%Y/%m/merra2_precip_%Y%m%d%H%M.nc")
    # Drop the last latitude row and rename coordinates to match the other datasets.
    merra_data = xr.load_dataset(reference_data_path / merra_file)[{"lat": slice(0, -1)}].rename(lat="latitude", lon="longitude")
    # Exceedance field computed from the MERRA-2 precipitation itself.
    merra_data["p_10mm"] = 5 < merra_data.surface_precip
    results_merra.append(merra_data)

results_imerg = xr.concat(results_imerg, dim="time").rename(time="valid_time")
results_imerg["initialization_time"] = initialization_time
results_merra = xr.concat(results_merra, dim="time").rename(time="valid_time")
results_merra["initialization_time"] = initialization_time

## Forecast results

### Global

In [None]:
from prithvi_precip.plotting import set_style, animate_results
from IPython.display import HTML
set_style()

ani = animate_results({
    "Prithvi Precip": results,
    "MERRA-2": results_merra,
    "IMERG": results_imerg
})
HTML(ani.to_jshtml())

### Hurricane Laura

The forecast period covers the landfall of Hurricane Laura in 2020, so we take a closer look at the Gulf of America.

> **Note**: The variables ``lon_min``, ``lon_max``, ``lat_min``, and ``lat_max`` define the bounding box of the region of interest (ROI). Feel free to change them to explore other regions of the forecast.
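The cell below builds the latitude and longitude masks from the grid of ``results`` and reuses them for the MERRA-2 and IMERG data, which assumes that all three datasets share the same grid. A hypothetical helper such as ``select_roi`` below computes the masks from each dataset's own coordinates instead. It is a sketch that assumes longitudes in [-180, 180), as implied by the bounding box above.

```python
import numpy as np
import xarray as xr


def select_roi(ds: xr.Dataset, lon_min, lon_max, lat_min, lat_max) -> xr.Dataset:
    """Subset ``ds`` to a latitude/longitude bounding box (bounds inclusive).

    Masks are computed from the coordinates of ``ds`` itself, so the function
    works for datasets on different grids and for ascending or descending latitudes.
    """
    lon_inds = np.flatnonzero(((ds.longitude >= lon_min) & (ds.longitude <= lon_max)).values)
    lat_inds = np.flatnonzero(((ds.latitude >= lat_min) & (ds.latitude <= lat_max)).values)
    return ds.isel(longitude=lon_inds, latitude=lat_inds)
```

Usage would then be, for example, ``select_roi(results_merra, lon_min, lon_max, lat_min, lat_max)``.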

In [None]:
lon_min = -105
lon_max = -75
lat_min = 15
lat_max = 40

# Boolean masks selecting the grid points inside the bounding box.
lon_mask = (lon_min <= results.longitude) & (results.longitude <= lon_max)
lat_mask = (lat_min <= results.latitude) & (results.latitude <= lat_max)

results_roi = results[{"latitude": lat_mask, "longitude": lon_mask}]
results_merra_roi = results_merra[{"latitude": lat_mask.data, "longitude": lon_mask.data}]
results_imerg_roi = results_imerg[{"latitude": lat_mask.data, "longitude": lon_mask.data}]

In [None]:
from prithvi_precip.plotting import set_style, animate_results
from IPython.display import HTML

ani = animate_results(
    {
        "IMERG": results_imerg_roi,
        "MERRA-2": results_merra_roi,
        "Prithvi Precip": results_roi,
    },
    include_metrics=True
)
HTML(ani.to_jshtml())

## Adding observations

As mentioned above, Prithvi Precip adds observations (``obs``) and observation metadata (``obs_meta``) to the input. Because of the way the model handles observations, they are not loaded on the global grid but on the tiling that the Prithvi-WxC model expects. The shape of the tensors is $[n_b, n_t, n_{g,lat}, n_{g,lon}, n_{obs}, 1, n_{l,lat}, n_{l,lon}]$, where

- $n_b$ is the batch size.
- $n_t$ is the number of input time steps.
- $n_{g,lat}$ is the number of tiles along the meridional direction.
- $n_{g,lon}$ is the number of tiles along the zonal direction.
- $n_{obs}$ is the number of loaded observation layers (32).
- $n_{l,lat}$ is the size of each tile along the meridional direction.
- $n_{l,lon}$ is the size of each tile along the zonal direction.
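To see how the tile axes relate to a global grid, the NumPy sketch below splits a 2D field into tiles of shape $(n_{l,lat}, n_{l,lon})$ and reassembles it. It assumes non-overlapping tiles stored in row-major order; ``tile`` and ``untile`` are hypothetical helpers, and the actual Prithvi-WxC tiling code may differ.

```python
import numpy as np


def tile(field: np.ndarray, n_l_lat: int, n_l_lon: int) -> np.ndarray:
    """Split a 2D field into an array of shape (n_g_lat, n_g_lon, n_l_lat, n_l_lon)."""
    n_lat, n_lon = field.shape
    if n_lat % n_l_lat or n_lon % n_l_lon:
        raise ValueError("Grid size must be divisible by the tile size.")
    n_g_lat, n_g_lon = n_lat // n_l_lat, n_lon // n_l_lon
    # (g_lat, l_lat, g_lon, l_lon) -> (g_lat, g_lon, l_lat, l_lon)
    return field.reshape(n_g_lat, n_l_lat, n_g_lon, n_l_lon).transpose(0, 2, 1, 3)


def untile(tiles: np.ndarray) -> np.ndarray:
    """Inverse of ``tile``: merge (n_g_lat, n_g_lon, n_l_lat, n_l_lon) into a 2D field."""
    n_g_lat, n_g_lon, n_l_lat, n_l_lon = tiles.shape
    # (g_lat, g_lon, l_lat, l_lon) -> (g_lat, l_lat, g_lon, l_lon), then merge axes.
    return tiles.transpose(0, 2, 1, 3).reshape(n_g_lat * n_l_lat, n_g_lon * n_l_lon)
```

Under these assumptions, ``untile(obs[0, 0, :, :, k, 0].numpy())`` would place observation layer ``k`` of the first sample and time step back on the global grid.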

### Plotting a single observation tile

The observation data passed into the model is already normalized using a very simple scheme. Observations are stored as reflectivities (in percent) and brightness temperatures (in Kelvin), and these values are normalized by linearly mapping the range [0, 300] to [-1, 1].
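The normalization is the linear map $x_{norm} = 2x / 300 - 1$. The sketch below implements it together with its inverse. Treating normalized values below -1 as missing is an assumption based on the masking in the tile-plotting cell further down; ``normalize_obs`` and ``denormalize_obs`` are hypothetical helpers.

```python
import numpy as np

OBS_MIN, OBS_MAX = 0.0, 300.0  # physical range mapped onto [-1, 1]


def normalize_obs(values):
    """Linearly map values in [OBS_MIN, OBS_MAX] to [-1, 1]."""
    values = np.asarray(values, dtype=np.float32)
    return 2.0 * (values - OBS_MIN) / (OBS_MAX - OBS_MIN) - 1.0


def denormalize_obs(normed):
    """Invert normalize_obs; values below -1 are treated as missing (NaN)."""
    normed = np.asarray(normed, dtype=np.float32)
    values = (normed + 1.0) * (OBS_MAX - OBS_MIN) / 2.0 + OBS_MIN
    return np.where(normed < -1.0, np.nan, values)
```

A brightness temperature of 150 K, for instance, maps to 0, and 300 K maps to 1.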

In [None]:
input_data = dataset.get_direct_forecast_input(initialization_time, 1)

In [None]:
# Indices: batch 0, time step 0, tile row 4, tile column 0, observation layer 0, singleton dim 0.
plt.pcolormesh(input_data["obs"][0, 0, 4, 0, 0, 0])
plt.colorbar()

Visualizing a single tile is generally not very helpful so the ``prithvi_precip`` package also provides a function to visualize the full grid of tiles. We use this function below to visualize all of the 32 observation layers.

Note that you may want to open the image in a new tab (right-click and choose "Open image in new tab") and zoom in to make out details.

In [None]:
obs, meta = dataset.obs_loader.load_observations(np.datetime64("2020-08-26T12:00:00"), randomize=False)

In [None]:
import torch
from matplotlib import colormaps
from matplotlib.gridspec import GridSpec
from prithvi_precip.plotting import plot_tiles, set_style

set_style()

cmap = colormaps.get_cmap("magma")
cmap.set_bad("grey")
fig = plt.figure(figsize=(20, 8))
gs = GridSpec(4, 8, wspace=0.2, hspace=0.2)

for row_ind in range(4):
    for col_ind in range(8):
        layer_ind = row_ind * 8 + col_ind
        ax = fig.add_subplot(gs[row_ind, col_ind])
        ax.set_title(f"Layer {layer_ind + 1}")
        tile = obs[:, :, layer_ind, 0].clone()
        tile[tile < -1] = torch.nan
        plot_tiles(tile, 0, 1, 2, 3, ax=ax, colorbar=False, cmap=cmap)

        ax.set_xticks([])
        ax.set_yticks([])

## Forecast with observations

In [None]:
import torch
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"  # Use a GPU if one is available.
dtype = torch.float32
forecast_steps = 32

model = model.to(device=device, dtype=dtype)

batch_iterator = dataset.get_batched_direct_forecast_input(initialization_time, n_steps=forecast_steps, batch_size=1)
results_obs = []

step = 1
for inpt in tqdm(batch_iterator, total=forecast_steps, desc="Running forecast"):
    lats = np.rad2deg(inpt["static"][0, 0].float().cpu().numpy())[:, 0]
    lons = np.rad2deg(inpt["static"][0, 1].float().cpu().numpy())[0]
    valid_time = initialization_time + np.timedelta64(3, "h") * step
    step += 1

    # Move input to the target device and cast to the target dtype.
    inpt = {
        name: tnsr.to(device=device, dtype=dtype) for name,tnsr in inpt.items()
    }

    # Run inference
    with torch.no_grad():
        pred = model(inpt, model_only=False)["surface_precip"]
        expected_value = pred.expected_value().cpu().float().numpy()[0, 0] # Drop batch and feature dimensions.
        p_1 = pred.probability_greater_than(1.0).cpu().float().numpy()[0, 0]
        p_10 = pred.probability_greater_than(10.0).cpu().float().numpy()[0, 0]

    result = xr.Dataset({
        "latitude": (("latitude",), lats),
        "longitude": (("longitude",), lons),
        "valid_time": valid_time,
        "surface_precip": (("latitude", "longitude"), expected_value),
        "p_1mm": (("latitude", "longitude"), p_1,),
        "p_10mm": (("latitude", "longitude"), p_10)
    })
    results_obs.append(result)
    
results_obs = xr.concat(results_obs, dim="valid_time")
results_obs["initialization_time"] = initialization_time
results_obs_roi = results_obs[{"latitude": lat_mask.data, "longitude": lon_mask.data}]

In [None]:
from prithvi_precip.plotting import set_style, animate_results
from IPython.display import HTML

ani = animate_results(
    {
        "IMERG": results_imerg,
        "MERRA-2": results_merra,
        "Prithvi Precip": results,
        "Prithvi Precip (Obs)": results_obs,
    },
    include_metrics=True,
    panel_width=8,
    n_cols=2
)
HTML(ani.to_jshtml())

In [None]:
from prithvi_precip.plotting import set_style, animate_results
from IPython.display import HTML

ani = animate_results(
    {
        "IMERG": results_imerg_roi,
        "MERRA-2": results_merra_roi,
        "Prithvi Precip": results_roi,
        "Prithvi Precip (Obs)": results_obs_roi,
    },
    include_metrics=True,
    n_cols=2
)
HTML(ani.to_jshtml())

## Observation-only forecast

Finally, we run an observation-only forecast by setting the ``obs_only`` flag to ``True``.
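The observation-only run below repeats the step that turns the output of one forecast step into an ``xarray.Dataset``, as the two previous runs did. A helper like the hypothetical ``to_dataset`` (not part of ``prithvi_precip``) could take over that step. It stores ``valid_time`` as a scalar coordinate, so concatenating the steps along ``valid_time`` yields an indexed time dimension.

```python
import numpy as np
import xarray as xr


def to_dataset(valid_time, lats, lons, expected_value, p_1, p_10):
    """Package the output of one forecast step as an xarray.Dataset."""
    dims = ("latitude", "longitude")
    return xr.Dataset(
        {
            "surface_precip": (dims, expected_value),
            "p_1mm": (dims, p_1),
            "p_10mm": (dims, p_10),
        },
        # valid_time is a scalar coordinate that becomes the concat dimension.
        coords={"latitude": lats, "longitude": lons, "valid_time": valid_time},
    )
```

Inside a forecast loop, ``results.append(to_dataset(valid_time, lats, lons, expected_value, p_1, p_10))`` followed by ``xr.concat(results, dim="valid_time")`` would produce the same variables as the cells in this notebook.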

In [None]:
import torch
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"  # Use a GPU if one is available.
dtype = torch.float32
forecast_steps = 32

model = model.to(device=device, dtype=dtype)

batch_iterator = dataset.get_batched_direct_forecast_input(initialization_time, n_steps=forecast_steps, batch_size=1)
results_obs_only = []

step = 1
for inpt in tqdm(batch_iterator, total=forecast_steps, desc="Running forecast"):
    lats = np.rad2deg(inpt["static"][0, 0].float().cpu().numpy())[:, 0]
    lons = np.rad2deg(inpt["static"][0, 1].float().cpu().numpy())[0]
    valid_time = initialization_time + np.timedelta64(3, "h") * step
    step += 1

    # Move input to the target device and cast to the target dtype.
    inpt = {
        name: tnsr.to(device=device, dtype=dtype) for name,tnsr in inpt.items()
    }

    # Run inference
    with torch.no_grad():
        pred = model(inpt, obs_only=True)["surface_precip"]
        expected_value = pred.expected_value().cpu().float().numpy()[0, 0] # Drop batch and feature dimensions.
        p_1 = pred.probability_greater_than(1.0).cpu().float().numpy()[0, 0]
        p_10 = pred.probability_greater_than(10.0).cpu().float().numpy()[0, 0]

    result = xr.Dataset({
        "latitude": (("latitude",), lats),
        "longitude": (("longitude",), lons),
        "valid_time": valid_time,
        "surface_precip": (("latitude", "longitude"), expected_value),
        "p_1mm": (("latitude", "longitude"), p_1,),
        "p_10mm": (("latitude", "longitude"), p_10)
    })
    results_obs_only.append(result)
    
results_obs_only = xr.concat(results_obs_only, dim="valid_time")
results_obs_only["initialization_time"] = initialization_time
results_obs_only_roi = results_obs_only[{"latitude": lat_mask.data, "longitude": lon_mask.data}]


### Global

In [None]:
from prithvi_precip.plotting import set_style, animate_results
from IPython.display import HTML

ani = animate_results(
    {
        "IMERG": results_imerg,
        "Prithvi Precip": results,
        "Prithvi Precip (Obs)": results_obs,
        "Prithvi Precip (Obs only)": results_obs_only,
    },
    include_metrics=True,
    n_cols=2,
    panel_width=8
)
HTML(ani.to_jshtml())

### Local

In [None]:
from prithvi_precip.plotting import set_style, animate_results
from IPython.display import HTML

ani = animate_results(
    {
        "IMERG": results_imerg_roi,
        "MERRA-2": results_merra_roi,
        "Prithvi Precip": results_roi,
        "Prithvi Precip (Obs)": results_obs_roi,
        "Prithvi Precip (Obs only)": results_obs_only_roi,
    },
    include_metrics=True,
    n_cols=2
)
HTML(ani.to_jshtml())