In [1]:
#!/usr/bin/env python

import pandas as pd
import torch
import numpy as np
import xarray as xr
import time
import gc

def rmse_based_scores(da_rec, da_ref):
    """Score a reconstructed field against a reference with RMSE-based metrics.

    Both inputs are expected to carry 'time', 'lat' and 'lon' dimensions
    (they are averaged over by name below).

    Returns:
        tuple ``(rmse_t, rmse_xy, leaderboard_rmse, stability)``:
          * rmse_t -- per-time score ``1 - RMSE(t) / RMS(ref)(t)``, where both
            terms are averaged over lon/lat; named 'rmse_t'.
          * rmse_xy -- per-pixel RMSE over time (not normalized); named 'rmse_xy'.
          * leaderboard_rmse -- ``1 - RMSE / RMS(ref)`` over all dims, rounded to
            5 decimals.
          * stability -- standard deviation of ``rmse_t`` (temporal stability of
            the error), rounded to 5 decimals.
    """
    squared_error = (da_rec - da_ref)**2
    squared_ref = da_ref**2
    space = ('lon', 'lat')

    # Normalized score for every time step.
    score_t = 1.0 - squared_error.mean(dim=space)**0.5 / squared_ref.mean(dim=space)**0.5
    score_t = score_t.rename('rmse_t')

    # Raw RMSE map over time. A normalized variant,
    # 1 - RMSE(x, y) / RMS(ref)(x, y), was used previously.
    error_map = (squared_error.mean(dim='time')**0.5).rename('rmse_xy')

    # Spread of the per-time score, i.e. how stable the mapping error is in time.
    stability = score_t.std().values

    # Globally averaged normalized RMSE score (leaderboard metric).
    leaderboard = 1.0 - squared_error.mean()**0.5 / squared_ref.mean()**0.5

    return (
        score_t,
        error_map,
        np.round(leaderboard.values, 5),
        np.round(stability, 5),
    )

def prepare_oi_batch(
    obs_values,
    obs_time,
    obs_lon,
    obs_lat,
    c_time,
    c_lon,
    c_lat,
    ps={
        'time': 1,
        'lon': 2.5,
        'lat': 2.5
    },
    lt=7.,
    lx=1.,
    ly=1.,
    forecast=False,
    leadtime=0,
):
    """Tile the output grid into patches and pair each patch with nearby observations.

    The output grid is the Cartesian product of ``c_time``, ``c_lon`` and ``c_lat``.
    It is cut into half-open boxes of size ``ps['time'] x ps['lon'] x ps['lat']``,
    starting at each axis minimum and stepping while the start is <= the axis maximum.

    For each box ``[ts, te) x [los, loe) x [las, lae)``, observations are selected in
    a padded window:

    * ``forecast=False``: time in ``[ts - 2*lt, te + 2*lt)``
    * ``forecast=True``:  time in ``[ts - 4*lt - leadtime, te - leadtime)``

    In both cases, lon is in ``[los - 2*lx, loe + 2*lx)`` and lat is in
    ``[las - 2*ly, lae + 2*ly)``.

    Yields:
        ``(obs_values, obs_time, obs_lat, obs_lon, grid_time, grid_lat, grid_lon)``,
        the masked 1-D tensors for one patch. NOTE: latitude comes before longitude
        here, which is the opposite of ``torch_oi``'s parameter order, so unpack
        explicitly instead of splatting the tuple into ``torch_oi``.
    """
    grid_time, grid_lon, grid_lat = [
        g.flatten() for g in torch.meshgrid(c_time, c_lon, c_lat)
    ]

    def axis_starts(axis, step):
        # Lower patch bounds along one axis. The yielded tensor is advanced
        # in place afterwards, so each start is consumed (masks computed)
        # before this generator is resumed.
        start = axis.min()
        while start <= axis.max():
            yield start
            start += step

    def within(x, lo, hi):
        # Half-open interval membership [lo, hi).
        return x.ge(lo) & x.lt(hi)

    for ts in axis_starts(c_time, ps['time']):
        te = ts + ps['time']
        for los in axis_starts(c_lon, ps['lon']):
            loe = los + ps['lon']
            for las in axis_starts(c_lat, ps['lat']):
                lae = las + ps['lat']

                msk_grid = (within(grid_time, ts, te)
                            & within(grid_lon, los, loe)
                            & within(grid_lat, las, lae))

                if forecast:
                    # Only observations strictly before te - leadtime.
                    t_lo, t_hi = ts - 4 * lt - leadtime, te - leadtime
                else:
                    t_lo, t_hi = ts - 2 * lt, te + 2 * lt

                msk_obs = (within(obs_time, t_lo, t_hi)
                           & within(obs_lon, los - 2 * lx, loe + 2 * lx)
                           & within(obs_lat, las - 2 * ly, lae + 2 * ly))

                yield (
                    obs_values[msk_obs],
                    obs_time[msk_obs],
                    obs_lat[msk_obs],
                    obs_lon[msk_obs],
                    grid_time[msk_grid],
                    grid_lat[msk_grid],
                    grid_lon[msk_grid],
                )


def torch_oi(
    obs_values,
    obs_time,
    obs_lon,
    obs_lat,
    grid_time,
    grid_lon,
    grid_lat,
    lt=7.,
    lx=1.,
    ly=1.,
    ps={
        'time': 1,
        'lat': 12,
        'lon': 12
    },
    noise=0.05,
):
    """Optimal interpolation of scattered observations onto target points.

    Uses the separable Gaussian covariance
        C(a, b) = exp(-((ta - tb)/lt)^2 - ((xa - xb)/lx)^2 - ((ya - yb)/ly)^2)
    and returns the analysis  BH^T (HBH^T + noise^2 I)^{-1} y.

    Args:
        obs_values, obs_time, obs_lon, obs_lat: 1-D tensors of equal length
            (observation values and their coordinates).
        grid_time, grid_lon, grid_lat: 1-D tensors of equal length (target points),
            on the same device as the observations.
        lt, lx, ly: decorrelation scales in time / lon / lat, in coordinate units.
        ps: unused; kept only for signature compatibility. Because it precedes
            ``noise``, an extra positional argument after ``ly`` lands here and is
            ignored, so always pass ``noise`` by keyword.
        noise: observation-error standard deviation; ``noise**2`` is added to the
            diagonal of the observation covariance.

    Returns:
        float32 tensor of shape ``(len(grid_time),)`` holding the analysed values
        (all zeros when there are no observations).
    """
    nobs = len(obs_time)
    if nobs == 0:
        # Nothing to interpolate: the analysis is zero everywhere.
        return torch.zeros(len(grid_time), dtype=torch.float32,
                           device=grid_time.device)

    def gauss_cov(ta, xa, ya, tb, xb, yb):
        # Pairwise covariance between point sets a (rows) and b (columns).
        return torch.exp(-((ta[:, None] - tb[None, :]) / lt)**2
                         - ((xa[:, None] - xb[None, :]) / lx)**2
                         - ((ya[:, None] - yb[None, :]) / ly)**2)

    bh_t = gauss_cov(grid_time, grid_lon, grid_lat, obs_time, obs_lon, obs_lat)
    coo = gauss_cov(obs_time, obs_lon, obs_lat, obs_time, obs_lon, obs_lat)
    # Add the observation-error variance on the diagonal in place, instead of
    # materializing a dense diagonal matrix R.
    coo.diagonal().add_(noise**2)

    # Solve (HBH^T + R) w = y rather than forming the inverse explicitly:
    # cheaper and numerically more stable for this SPD system.
    weights = torch.linalg.solve(coo, obs_values.to(coo.dtype))
    # Intermediates are released on return; callers that want cached GPU
    # memory handed back to the driver can call torch.cuda.empty_cache().
    return torch.mv(bh_t, weights).float()

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Compute device for the OI solves; fall back to CPU when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Experiment selection: "OSE" uses the 6-nadir file, anything else the 4-nadir + SWOT file.
xp = "OSE"
obs_var = 'sst_obs'
target_var = 'sst'
# obs_var = 'nadir_obs'
# target_var = 'ssh'
dt = np.timedelta64(1, 'D')  # temporal grid step
if xp == "OSE":
    path_ref_daily = '/DATASET/mbeauchamp/IMT/data/natl_ose_6nadirs.nc'
    simu_start_date = np.datetime64('2016-12-01')  # domain min time
    time_min = np.datetime64('2017-01-01')
    time_max = np.datetime64('2017-01-31')
else:
    path_ref_daily = '/DATASET/mbeauchamp/IMT/data/natl_4nadirs_swot.nc'
    simu_start_date = np.datetime64('2012-10-01')  # domain min time
    time_min = np.datetime64('2012-10-22')
    time_max = np.datetime64('2012-12-02')
# In both experiments, observations and reference are read from the same file.
inputs = path_ref_daily

lon_min = -65.  # domain min longitude
lon_max = -55.  # domain max longitude
lat_min = 33.  # domain min latitude
lat_max = 43.  # domain max latitude
dx = 1 / 10  # zonal grid spatial step (in degree)
dy = 1 / 10  # meridional grid spatial step (in degree)

glon = torch.arange(lon_min, lon_max + dx,
                    dx).to(device)  # output OI longitude grid
glat = torch.arange(lat_min, lat_max + dy,
                    dy).to(device)  # output OI latitude grid
# NOTE(review): unlike glon/glat, the upper bound here is exclusive, so
# time_max itself is not on the output grid; confirm this is intended.
gtime = torch.arange((time_min - simu_start_date) / dt,
                     (time_max - simu_start_date) / dt,
                     1).to(device)  # output OI time grid (days since simu_start_date)

# OI parameters
lx = 1.  # Zonal decorrelation scale (in degree)
ly = 1.  # Meridional decorrelation scale (in degree)
lt = 7.  # Temporal decorrelation scale (in days)
noise = 0.05  # Noise level (5%)

# Crop shared by the reference and the observations: the output domain padded
# by two decorrelation scales in each dimension.
crop = dict(
    lat=slice(lat_min - 2 * ly, lat_max + 2 * ly),
    lon=slice(lon_min - 2 * lx, lon_max + 2 * lx),
    time=slice(time_min - 2 * lt * dt, time_max + 2 * lt * dt),
)

ref_daily_ds = (
    xr.open_mfdataset(path_ref_daily)
    .sel(**crop)
    .coarsen(lon=2).mean()
    .coarsen(lat=2).mean()
)
# NOTE(review): this epoch is hard-coded to 2012-10-01 even when xp == "OSE"
# (simu_start_date 2016-12-01). `time` is already sliced with datetime64 bounds
# above, so confirm this re-decoding is intended or harmless for the OSE file.
ref_daily_ds.time.attrs['units'] = 'seconds since 2012-10-01'
ref_daily_ds = xr.decode_cf(ref_daily_ds)

ds_obs = (
    xr.open_dataset(inputs)
    .sel(**crop)
    .coarsen(lon=2).mean()
    .coarsen(lat=2).mean()
)

In [6]:
# Display the cropped, 2x-coarsened reference dataset as a quick sanity check.
ref_daily_ds

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.82 MiB 8.82 MiB Shape (59, 140, 140) (59, 140, 140) Dask graph 1 chunks in 11 graph layers Data type float64 numpy.ndarray",140  140  59,

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.82 MiB 8.82 MiB Shape (59, 140, 140) (59, 140, 140) Dask graph 1 chunks in 11 graph layers Data type float64 numpy.ndarray",140  140  59,

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.82 MiB 8.82 MiB Shape (59, 140, 140) (59, 140, 140) Dask graph 1 chunks in 11 graph layers Data type float64 numpy.ndarray",140  140  59,

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.82 MiB 8.82 MiB Shape (59, 140, 140) (59, 140, 140) Dask graph 1 chunks in 11 graph layers Data type float64 numpy.ndarray",140  140  59,

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.82 MiB 8.82 MiB Shape (59, 140, 140) (59, 140, 140) Dask graph 1 chunks in 11 graph layers Data type float64 numpy.ndarray",140  140  59,

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.82 MiB 8.82 MiB Shape (59, 140, 140) (59, 140, 140) Dask graph 1 chunks in 11 graph layers Data type float64 numpy.ndarray",140  140  59,

Unnamed: 0,Array,Chunk
Bytes,8.82 MiB,8.82 MiB
Shape,"(59, 140, 140)","(59, 140, 140)"
Dask graph,1 chunks in 11 graph layers,1 chunks in 11 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [None]:
from tqdm.autonotebook import tqdm

# OI results keyed by run suffix ('mapping', 'nrt', 'frsct<k>'):
# value is (gridded output dataset, global RMSE against the reference).
full_outs = {}

obs_da = ds_obs[obs_var]
# Flat indices of the finite (observed) entries of the observation cube.
obs = np.flatnonzero(np.isfinite(obs_da.values))

obs_values = torch.from_numpy(
        np.ravel(obs_da.values)[obs]).to(device)
obs_lon = torch.from_numpy(
        np.ravel(ds_obs.lon.broadcast_like(obs_da).values)[obs]).to(device)
obs_lat = torch.from_numpy(
        np.ravel(ds_obs.lat.broadcast_like(obs_da).values)[obs]).to(device)
# Observation times in units of dt since simu_start_date, like gtime.
obs_time = torch.from_numpy(
        (np.ravel(ds_obs.time.broadcast_like(obs_da).values)[obs] -
         simu_start_date) / dt).float().to(device)

for leadtime in tqdm([-1, 0, 2, 4]):
    # leadtime < 0: mapping mode (observation window on both sides of each day).
    # leadtime >= 0: forecast mode (only observations before te - leadtime;
    # see prepare_oi_batch).
    if leadtime == -1:
        suf = "mapping"
    elif leadtime == 0:
        suf = "nrt"
    else:
        suf = 'frsct' + str(leadtime)

    outputs = []
    t0 = time.time()
    with torch.no_grad():
        batches = prepare_oi_batch(obs_values,
                                   obs_time,
                                   obs_lon,
                                   obs_lat,
                                   gtime,
                                   glon,
                                   glat,
                                   {'time': 1, 'lon': 1, 'lat': 1},
                                   lt,
                                   lx,
                                   ly,
                                   forecast=(leadtime >= 0),
                                   leadtime=leadtime)
        for (b_values, b_time, b_lat, b_lon,
             g_time, g_lat, g_lon) in batches:
            # prepare_oi_batch yields lat before lon, while torch_oi takes lon
            # before lat, so pass arguments explicitly. `noise` must be a keyword:
            # positionally it would bind to torch_oi's unused `ps` parameter.
            sol = torch_oi(b_values, b_time, b_lon, b_lat,
                           g_time, g_lon, g_lat,
                           lt=lt, lx=lx, ly=ly, noise=noise)
            # Keep only what is needed to rebuild the gridded field.
            outputs.append((g_time.cpu(), g_lat.cpu(), g_lon.cpu(), sol.cpu()))
            del sol
            gc.collect()
            torch.cuda.empty_cache()
    print(f'{suf}: OI solved in {time.time() - t0:.1f} s')

    dfs = [
        pd.DataFrame({
            target_var: grid_sol.numpy(),
            'time': grid_time.numpy() * dt + simu_start_date,
            'lat': grid_lat.numpy(),
            'lon': grid_lon.numpy(),
        })
        for grid_time, grid_lat, grid_lon, grid_sol in outputs
    ]

    out_ds = pd.concat(dfs).set_index(['time', 'lat',
                                       'lon']).pipe(xr.Dataset.from_dataframe)
    ref_daily = ref_daily_ds.interp(out_ds[['time', 'lat', 'lon']].coords)
    rmse = (out_ds[target_var] -
            ref_daily[target_var]).pipe(lambda da: np.sqrt(np.mean(da**2))).compute()
    # Key by run suffix so every lead time is kept (keying by obs_var
    # overwrote earlier lead times on each iteration).
    full_outs[suf] = (out_ds, rmse)

    print([(k, v[1]) for k, v in full_outs.items()])
    _, _, mu, _ = rmse_based_scores(out_ds[target_var], ref_daily[target_var])
    print(f'{mu=}')
    out_ds.to_netcdf('OI_' + target_var + '_' + suf + '_xp_' + xp + '.nc')

  0%|                                                     | 0/4 [00:00<?, ?it/s]