In [2]:
gt = "/home/egauillard/extreme_events_forecasting/earthfomer_mediteranean_spi/src/model/experiments/earthformer_era_20240816_102733/inference_plots/all_ground_truths.nc"
pred = "/home/egauillard/extreme_events_forecasting/earthfomer_mediteranean_spi/src/model/experiments/earthformer_era_20240816_102733/inference_plots/all_predictions.nc"
clim = "/home/egauillard/extreme_events_forecasting/earthfomer_mediteranean_spi/src/model/experiments/earthformer_era_20240816_102733/inference_plots/all_climatology.nc"


In [3]:
import numpy as np

In [4]:
import xarray as xr

gt = xr.open_dataset(gt)
pred = xr.open_dataset(pred)
clim = xr.open_dataset(clim)

In [5]:
rmse_clim = np.sqrt(((gt - clim) ** 2).mean(dim='time'))

print("RMSE climatology", np.mean(rmse_clim))
print("RMSE std", np.std(rmse_clim))
np.mean(rmse_clim)

RMSE climatology <xarray.Dataset>
Dimensions:  ()
Data variables:
    spi      float64 0.9529
RMSE std <xarray.Dataset>
Dimensions:  ()
Data variables:
    spi      float64 0.632


In [6]:
rmse = np.sqrt(((gt - pred) ** 2).mean(dim='time'))
std_rmse = np.std(rmse)

print(f"RMSE: {np.mean(rmse)}")
print(f"STD RMSE: {std_rmse}")

np.mean(rmse)

RMSE: <xarray.Dataset>
Dimensions:  ()
Data variables:
    spi      float64 0.9458
STD RMSE: <xarray.Dataset>
Dimensions:  ()
Data variables:
    spi      float64 0.6133


In [7]:
gt.spi.values


array([[[[ 6.87663929e-01,  5.46409501e-01,  1.12014839e-03,
          -1.34186606e+00, -2.16326752e+00]],

        [[-3.00909207e-01, -3.35205019e-01, -7.93762065e-01,
          -4.99811705e-01,  4.13140627e-01]],

        [[ 5.75782145e-01,  3.80136077e-01, -2.12910585e-01,
          -1.58675404e-01, -8.98679059e-02]],

        ...,

        [[            nan,             nan,             nan,
                      nan,             nan]],

        [[            nan,             nan,             nan,
                      nan,             nan]],

        [[            nan,             nan,             nan,
                      nan,             nan]]],


       [[[            nan,             nan,             nan,
                      nan,             nan]],

        [[-3.00909207e-01, -3.35205019e-01, -7.93762065e-01,
          -4.99811705e-01,  4.13140627e-01]],

        [[ 5.75782145e-01,  3.80136077e-01, -2.12910585e-01,
          -1.58675404e-01, -8.98679059e-02]],

        ...,

In [8]:
da = gt.spi

In [9]:
def transform_data(da):
    # Create a mask for non-NaN values
    mask = ~np.isnan(da).any(dim=['latitude', 'longitude'])

    # Function to get valid time indices for each sample
    def get_valid_times(sample_mask):
        return np.where(sample_mask)[0]

    # Apply this function to each sample
    valid_time_indices = xr.apply_ufunc(
        get_valid_times,
        mask,
        input_core_dims=[['time']],
        output_core_dims=[['valid_time']],
        vectorize=True
    )

    # Create a new DataArray with only valid times for each sample
    new_da = da.isel(time=valid_time_indices)
    return new_da




In [10]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.ticker as mticker

def plot(new_da, name):
    # Assuming new_da has dimensions (sample, valid_time, latitude, longitude)
    num_samples = new_da.sizes['sample']

    for i, sample in enumerate(range(num_samples)):
        # Select data for the current sample
        sample_data = new_da.isel(sample=sample)
        
        # Create a figure for each sample
        fig = plt.figure(figsize=(5*sample_data.sizes['valid_time'], 6))  # Increased height for colorbar
        
        # Create gridspec with two rows: one for maps, one for colorbar
        gs = fig.add_gridspec(2, sample_data.sizes['valid_time'], height_ratios=[20, 1], hspace=0.1)
        
        # Create a separate axes for the maps
        map_axes = [fig.add_subplot(gs[0, i], projection=ccrs.PlateCarree()) 
                    for i in range(sample_data.sizes['valid_time'])]
        
        # Get the overall min and max for consistent colorbar
        vmin = sample_data.min().item()
        vmax = sample_data.max().item()
        
        for time_idx, ax in enumerate(map_axes):
            data = sample_data.isel(valid_time=time_idx)
            
            lats = [30] + list(data.latitude.values) + [45]
            lons = [-10] + list(data.longitude.values) + [40]

            im = ax.imshow(data, vmin=vmin, vmax=vmax, cmap='viridis',
                        extent=[lons[0], lons[-1], lats[0], lats[-1]],
                        transform=ccrs.PlateCarree())
            
            ax.coastlines(resolution='50m', color='black', linewidth=0.5)
            ax.add_feature(cfeature.BORDERS, linestyle=':', color='black', linewidth=0.5)
            ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='lightgrey', alpha=0.3)
            ax.add_feature(cfeature.OCEAN, edgecolor='black', facecolor='lightblue', alpha=0.3)
            
            gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                            linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
            gl.xlocator = mticker.FixedLocator(range(-10, 41, 10))
            gl.ylocator = mticker.FixedLocator(range(30, 46, 5))
            gl.top_labels = False
            gl.right_labels = False
            
            ax.set_title(f'Time step {time_idx}')
        
        # Add a common colorbar at the bottom
        cbar_ax = fig.add_subplot(gs[1, :])
        plt.colorbar(im, cax=cbar_ax, orientation='horizontal')
        
        plt.suptitle(f'Sample {sample}', y=0.98)
        plt.tight_layout()
        plt.savefig(f'{name}_sample_{sample}_plot.png', dpi=300, bbox_inches='tight')
        plt.close()
        if i > 10:
            break

SyntaxError: expected ':' (1939926497.py, line 6)

In [None]:

# do it for the 3 datasets
new_da_gt = transform_data(gt.spi)
plot(new_da_gt, "gt")

new_da_pred = transform_data(pred.spi)
plot(new_da_pred, "pred")

new_da_clim = transform_data(clim.spi)
plot(new_da_clim, "clim")

NameError: name 'transform_data' is not defined