# Using foscat to forecast global ECMWF ERA5 fields

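This notebook converts hourly ERA5 temperature fields (pressure level 20) from the public ARCO-ERA5 archive into HEALPix maps, then trains a foscat HEALPix U-Net to predict how the field changes from one time step to the next.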

## TODO:

- Problem when using statistics of the previous time step: the dimensionality of the foscat statistics should be compressed.

## Installation of required packages

In [1]:
#!pip install foscat==3.1.0 
#!pip install --upgrade "xarray<=2025.4.0"

In [2]:
import numpy as np
import healpy as hp
import matplotlib.pyplot as plt
import foscat.scat_cov as sc
import foscat.Synthesis as synthe

### Choose the resolution

In [3]:
nside = 32  # HEALPix resolution of the working maps: 12*nside**2 pixels each (NESTED ordering)

In [4]:
import gcsfs
import xarray

# Public ARCO-ERA5 Zarr store (per its path: 37 pressure levels, 1-hourly, 0.25 deg),
# read anonymously from Google Cloud Storage. chunks=None opens it without dask.
era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
gcs = gcsfs.GCSFileSystem(token='anon')
era5_store = gcs.get_mapper(era5_path)
full_era5 = xarray.open_zarr(era5_store, chunks=None)
full_era5

In [5]:
# Keep a single pressure level of the ERA5 air temperature (all times, full lat/lon grid).
PRESSURE_LEVEL = 20  # value of the `level` coordinate to extract
temperature = full_era5['temperature'].sel(level=PRESSURE_LEVEL)
temperature

In [6]:
from functools import lru_cache

from scipy.interpolate import RegularGridInterpolator


@lru_cache(maxsize=8)
def _healpix_pixel_angles(nside):
    """Return (theta, phi) in radians of the centres of all NESTED pixels at `nside`.

    The result depends only on `nside` and is needed for every time step, so it is
    memoised. The returned arrays are shared between calls: do not modify them in place.
    """
    return hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)


def to_Healpix(arr_val, itime, nside, l_nside=512):
    """Interpolate one time step of a lat/lon field onto a NESTED HEALPix map.

    The field is first linearly interpolated (scipy RegularGridInterpolator default)
    at the pixel centres of an intermediate HEALPix map of resolution `l_nside`. That
    map is then block-averaged down to `nside`, or interpolated up with healpy when
    `nside > l_nside`.

    Parameters
    ----------
    arr_val : xarray.DataArray
        Field with `latitude` and `longitude` coordinates in degrees and a `time`
        dimension. The latitude grid must cover all HEALPix colatitudes (e.g. include
        both poles); otherwise the interpolator raises an out-of-bounds ValueError.
    itime : int
        Positional index along `time` of the step to convert.
    nside : int
        HEALPix resolution of the returned map.
    l_nside : int, optional
        Resolution of the intermediate map used for the repixelisation.

    Returns
    -------
    numpy.ndarray
        1-D array of 12*nside**2 values in NESTED ordering.
    """
    # Source grid expressed as HEALPix angles: theta = colatitude, phi = 360 - longitude.
    # The longitude axis is mirrored; NOTE(review): presumably to match healpy's display
    # convention -- confirm before reusing this for other purposes.
    theta_grid = (90.0 - arr_val.latitude.compute().to_numpy()) / 180. * np.pi
    phi_grid = (360.0 - arr_val.longitude.compute().to_numpy()) / 180. * np.pi
    field = arr_val.isel(time=itime).compute().to_numpy()

    # Extend the grid by one step along each axis, copying the first row/column.
    # - The extra longitude column closes the periodic phi axis: for a longitude grid
    #   0, d, ..., 360-d it lands exactly on phi = 0, so pixels between the last column
    #   and the wrap can be interpolated.
    # - The extra latitude row lies one step past the last latitude. It only closes the
    #   grid; HEALPix pixel centres never reach it.
    # The column is taken from the already row-padded array, so the corner value is
    # consistent with the appended row.
    field = np.vstack([field, field[:1, :]])
    field = np.hstack([field, field[:, :1]])
    theta_grid = np.concatenate([theta_grid, 2 * theta_grid[-1:] - theta_grid[-2:-1]])
    phi_grid = np.concatenate([phi_grid, 2 * phi_grid[-1:] - phi_grid[-2:-1]])

    interpolator = RegularGridInterpolator((theta_grid, phi_grid), field)

    # Evaluate at the pixel centres of the intermediate l_nside map (cached angles).
    theta, phi = _healpix_pixel_angles(l_nside)
    heal_im = interpolator(np.stack([theta, phi], axis=1))

    if nside > l_nside:
        # Upsampling: interpolate the intermediate map at the finer pixel centres.
        th, ph = _healpix_pixel_angles(nside)
        return hp.get_interp_val(heal_im, th, ph, nest=True)

    # Downsampling: in NESTED ordering, the (l_nside//nside)**2 sub-pixels of a coarse
    # pixel are stored contiguously. Averaging each row of this reshape therefore
    # gives the coarse map.
    return heal_im.reshape(12 * nside**2, (l_nside // nside) ** 2).mean(axis=1)

In [None]:
ntest = 1000          # consecutive hourly frames used for training
nvalid = ntest // 10  # consecutive hourly frames used for validation

train_cache = 'heal_im_%d.npy' % (ntest)
# Validation frames are consecutive hours. The '_1h' suffix keeps this cache distinct
# from older files that held validation frames sampled 10 hours apart.
valid_cache = 'valid_im_%d_1h.npy' % (ntest)

try:
    heal_im = np.load(train_cache)
    valid_im = np.load(valid_cache)
except (OSError, ValueError):
    # Cache missing or unreadable: rebuild both sets from the ERA5 archive.
    heal_im = np.zeros([ntest, 12 * nside**2])
    valid_im = np.zeros([nvalid, 12 * nside**2])
    for k in range(ntest):
        print('test ', k)
        heal_im[k] = to_Healpix(temperature, 1000000 + k, nside)

    # Normalise with statistics of the frame-to-frame increments (the U-Net target),
    # so the normalised increments have unit standard deviation.
    # NOTE(review): the offset is the median of the *increments*, not of the fields,
    # so the normalised fields themselves are not centred on zero -- confirm intent.
    increments = heal_im[1:] - heal_im[:-1]
    amp_shum = np.std(increments)
    mean_shum = np.median(increments)
    del increments  # large (ntest-1, 12*nside**2) temporary; free it

    heal_im = (heal_im - mean_shum) / amp_shum
    np.save(train_cache, heal_im)

    # Validation frames are also consecutive, so valid_im[k+1] - valid_im[k] is the
    # same one-step increment the model is trained to predict.
    for k in range(nvalid):
        print('valid ', k)
        valid_im[k] = to_Healpix(temperature, 1100000 + k, nside)

    valid_im = (valid_im - mean_shum) / amp_shum
    np.save(valid_cache, valid_im)

test  0
test  1
test  2
test  3
test  4
test  5
test  6
test  7
test  8
test  9
test  10
test  11
test  12
test  13
test  14
test  15
test  16


In [None]:
# First three training frames (consecutive time indices), titled with their actual
# timestamps from the dataset rather than hard-coded dates.
amp = 3000
first_index = 1000000  # must match the time offset used when building heal_im
plt.figure(figsize=(12, 4))
for i in range(3):
    stamp = np.datetime_as_string(temperature['time'].values[first_index + i], unit='m')
    hp.orthview(heal_im[i], cmap='coolwarm', nest=True, hold=False, sub=(1, 3, 1 + i),
                title='ECMWF ' + stamp, min=-amp, max=amp, cbar=False, norm='hist')

In [None]:
from foscat.healpix_unet_torch import HealpixUNet, fit

# HEALPix U-Net mapping one input field to one output field on the full sphere.
model = HealpixUNet(
    in_nside=nside,
    n_chan_in=1,
    out_channels=1,
    chanlist=[8, 16, 16, 16],            # NOTE(review): presumably channels per U-Net level -- confirm in foscat docs
    KERNELSZ=3,
    cell_ids=np.arange(12 * nside**2),   # every NESTED pixel id at in_nside
    task='regression',
    final_activation=None,               # NOTE(review): confirm how foscat interprets None for regression
)

In [None]:
# The foscat scattering-covariance object is used here only for its backend, whose
# bk_cast presumably converts numpy arrays into the backend's tensor type.
f = sc.funct()
cast = f.backend.bk_cast

# Inputs are the fields at step t; targets are the increments field[t+1] - field[t].
x_train = cast(heal_im[:-1])
y_train = cast(np.diff(heal_im, axis=0))
x_valid = cast(valid_im[:-1])
y_valid = cast(np.diff(valid_im, axis=0))

In [None]:
# Train the U-Net with the Adam optimizer. [:, None, :] inserts a channel axis of
# size 1 (the model was built with n_chan_in=1).
hist = fit(model, x_train[:, None, :], y_train[:, None, :],
           n_epoch=100, view_epoch=1, optimizer='ADAM', batch_size=10)

In [None]:
# Predicted increments for every training input, returned as a numpy array on the CPU.
train_batch = x_train[:, None, :]  # add the singleton channel axis
y_pred = model.predict(train_batch).cpu().numpy()
y_pred.shape

In [None]:
# True vs predicted increments for three training samples.
# Rows: true increment, U-Net prediction, prediction error. Columns: samples.
plt.figure(figsize=(12, 6.5))
amp = 3
rot = [0, 0]
for i, k in enumerate([ntest // 4, 2 * ntest // 4, 3 * ntest // 4]):
    truth = heal_im[k + 1] - heal_im[k]
    panels = [
        (truth, r'Input $\Delta_t$ temperature t=%d'),
        (y_pred[k, 0], r'U-NET $\Delta_t$ temperature t=%d'),
        (y_pred[k, 0] - truth, r'Diff $\Delta_t$ temperature t=%d'),
    ]
    for row, (field, title) in enumerate(panels):
        hp.orthview(field, rot=rot, cmap='coolwarm', nest=True, hold=False,
                    sub=(3, 3, 1 + 3 * row + i), min=-amp, max=amp, cbar=False,
                    title=title % (k + 1))

In [None]:
# Predicted increments for every validation input, returned as a numpy array on the CPU.
valid_batch = x_valid[:, None, :]  # add the singleton channel axis
y_pred = model.predict(valid_batch).cpu().numpy()
y_pred.shape

In [None]:
# True vs predicted increments between consecutive validation frames, for three samples.
# Rows: true increment, U-Net prediction, prediction error. Columns: samples.
plt.figure(figsize=(12, 6.5))
amp = 3
rot = [0, 0]
for i, k in enumerate([nvalid // 4, 2 * nvalid // 4, 3 * nvalid // 4]):
    truth = valid_im[k + 1] - valid_im[k]
    panels = [
        (truth, r'Input $\Delta_t$ temperature t=%d'),
        (y_pred[k, 0], r'U-NET $\Delta_t$ temperature t=%d'),
        (y_pred[k, 0] - truth, r'Diff $\Delta_t$ temperature t=%d'),
    ]
    for row, (field, title) in enumerate(panels):
        hp.orthview(field, rot=rot, cmap='coolwarm', nest=True, hold=False,
                    sub=(3, 3, 1 + 3 * row + i), min=-amp, max=amp, cbar=False,
                    title=title % (k + 1))