In [77]:
import os
import urllib
import urllib.request
from pathlib import Path

import dianna
import lilio
import numpy as np
import onnx
import onnxruntime as ort
import pandas as pd
import xarray as xr
from dianna import visualization
from dianna.utils.downloader import download
from matplotlib import pyplot as plt
from s2spy import RGDR
from s2spy import preprocess
from skl2onnx import to_onnx
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

# Seed NumPy's global random generator so NumPy-based randomness is repeatable across runs.
np.random.seed(0)


In [78]:
# Number of precursor intervals added to the calendar.
periods_of_interest = 8

# Calendar for the time of interest: anchor "08-01", overlapping intervals allowed.
calendar = lilio.Calendar(anchor="08-01", allow_overlap=True)
# Target interval(s) with a length of 30 days.
calendar.add_intervals("target", length="30d")
# Precursor intervals: length "1M", separated by a gap of "1M", repeated
# `periods_of_interest` times.
calendar.add_intervals("precursor", length="1M", gap="1M", n=periods_of_interest)

In [None]:
# Open the precursor data from the public ARCO-ERA5 zarr store, using automatic
# chunking and an anonymous access token.
era5_store = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
ar_full_37_1h = xr.open_zarr(
    era5_store,
    chunks='auto',
    storage_options={'token': 'anon'},
)

# Sea surface temperature: keep every 6th time step first, then restrict the
# time axis to 1959-01-01 .. 2021-12-31.
da = (
    ar_full_37_1h["sea_surface_temperature"]
    .isel(time=slice(None, None, 6))
    .sel(time=slice("1959-01-01T00:00:00", "2021-12-31T00:00:00"))
)

# Spatial subset: latitude 50 .. -10 and longitude 180 .. 240.
sst_f = da.sel(
    latitude=slice(50, -10),
    longitude=slice(180, 240),
)

# Show the selected field as the cell output.
sst_f

In [None]:
# Wrap the selected SST DataArray in a Dataset whose single data variable is named 'sst',
# which is the variable name the later cells index with.
precursor_field = sst_f.to_dataset(name='sst')

# Display the dataset as the cell output.
precursor_field

In [None]:
# URL of the dataset from zenodo
# (The SST file from the same record is kept below for reference only; it is not downloaded.)
# sst_url = "https://zenodo.org/record/8186914/files/sst_daily_1959-2021_5deg_Pacific_175_240E_25_50N.nc"
t2m_url = "https://zenodo.org/record/8186914/files/t2m_daily_1959-2021_2deg_clustered_226_300E_30_70N.nc"
# sst_field = "sst_daily_1959-2021_5deg_Pacific_175_240E_25_50N.nc"
t2m_field = "t2m_daily_1959-2021_2deg_clustered_226_300E_30_70N.nc"

# urllib.request.urlretrieve(sst_url, sst_field)
# Download the target file only when it is not already present, so re-running the
# notebook does not fetch it again. The download goes to a temporary ".part" file
# that is renamed on success, so an interrupted transfer never leaves a truncated
# file under the final name.
t2m_path = Path(t2m_field)
if not t2m_path.exists():
    t2m_partial = t2m_path.with_name(t2m_path.name + ".part")
    urllib.request.urlretrieve(t2m_url, t2m_partial)
    t2m_partial.replace(t2m_path)

# Show the local file location as the cell output.
t2m_path

In [None]:
# load data
# precursor_field = xr.open_dataset(sst_field)
# NOTE(review): the line above is disabled; `precursor_field` is instead the dataset built
# from the zarr store in the earlier cells of this notebook.
# Open the downloaded t2m file as the target dataset.
target_field = xr.open_dataset(t2m_field)

# Display the precursor dataset as the cell output.
precursor_field

In [None]:
# Convert Kelvin to Celsius
KELVIN_OFFSET = 273.15  # value subtracted to go from Kelvin to degrees Celsius


def kelvin_to_celsius(dataset, variable):
    """Convert ``dataset[variable]`` from Kelvin to Celsius in place, at most once.

    The converted variable is tagged with ``units="degC"``. A variable that already
    carries that tag is left untouched, so re-running this cell does not subtract
    the offset a second time.

    Args:
        dataset: dataset holding the variable; it is modified in place.
        variable: name of the data variable to convert.
    """
    if dataset[variable].attrs.get("units") != "degC":
        dataset[variable] = (dataset[variable] - KELVIN_OFFSET).assign_attrs(units="degC")


kelvin_to_celsius(precursor_field, "sst")
kelvin_to_celsius(target_field, "t2m")

# Display the converted precursor variable as the cell output.
precursor_field["sst"]

In [None]:
# map calendar to data
# NOTE(review): presumably this derives the calendar's anchor years from the time
# coverage of the precursor dataset — confirm against the lilio documentation.
calendar.map_to_data(precursor_field)
# Draw the calendar; `show_length=True` is passed to the visualisation.
calendar.visualize(show_length=True)

In [85]:
# Share of the calendar's index entries that is used for training.
train_fraction = 0.7

# Sorted index values of the calendar intervals; the smallest one marks the start.
years = sorted(calendar.get_intervals().index)
start_year = years[0]

# Number of training samples: 70% of all entries, rounded to an integer.
train_samples = round(train_fraction * len(years))

In [None]:
# Set up the preprocessor: rolling window of 25, monthly timescale, linear
# detrending and climatology subtraction.
preprocessor = preprocess.Preprocessor(
    rolling_window_size=25,
    timescale="monthly",
    detrend="linear",
    subtract_climatology=True,
)

# Training period: `train_samples` consecutive years starting at `start_year`.
last_train_year = start_year + train_samples - 1
train_period = slice(str(start_year), str(last_train_year))

# Fit on the training period only; the remaining years are not seen during fitting.
preprocessor.fit(precursor_field.sel(time=train_period))

In [None]:
# preprocess the whole precursor field
# The preprocessor was fitted on the training period only; here it is applied to the
# complete precursor dataset.
precursor_field_prep = preprocessor.transform(precursor_field)

In [None]:
# Resample both fields with the calendar defined above.
# NOTE(review): the precursor is resampled after preprocessing, whereas the target
# field is resampled without passing through the preprocessor — confirm this is intended.
precursor_field_resample = lilio.resample(calendar, precursor_field_prep)
target_field_resample = lilio.resample(calendar, target_field)

In [None]:
# select variables and intervals
# Precursor: the 'sst' variable of the resampled precursor dataset.
precursor_field_sel = precursor_field_resample['sst']
# Target: the 't2m' variable of the resampled target dataset, restricted to cluster 3.
target_series_sel = target_field_resample['t2m'].sel(cluster=3)

In [None]:
# Location of the ONNX model that the prediction checks in the following cells load.
# The directory can be overridden with the AI4S2S_MODEL_DIR environment variable so the
# notebook is not tied to one machine; the original location remains the fallback.
model_dir = os.environ.get("AI4S2S_MODEL_DIR", "/Users/clairedonnelly/AI4S2S/cookbook/workflow")
model_path = Path(model_dir, "model.onnx")


In [None]:
idx = 6 # explained instance
# NOTE(review): `data_test`, `classes` and `target_test` are not defined in any cell
# visible in this notebook — confirm where they are supposed to come from.
# Add a leading axis so the single selected instance is passed with one extra dimension.
data_instance = data_test[idx][np.newaxis, ...]
# precheck ONNX predictions
# NOTE(review): `run_model` is only defined in the cell after this one, so on a fresh
# kernel this cell fails unless that cell has been executed first.
pred_onnx = run_model(data_instance)
# Pick the class whose entry in the model output is largest.
pred_class = classes[np.argmax(pred_onnx)]
print("The predicted class is:", pred_class)
print("The actual class is:", classes[np.argmax(target_test.iloc[idx])])

In [None]:
import onnxruntime as ort

# Inference sessions that were already created, keyed by model location, so repeated
# calls reuse one session instead of loading the model file again on every call.
# Re-running this cell empties the cache.
_SESSION_CACHE = {}


def run_model(data):
    """Run the ONNX model stored at ``model_path`` on ``data``.

    Args:
        data: array-like model input; it is converted to a ``float32`` NumPy array
            before inference.

    Returns:
        The first output of the model, as returned by ``InferenceSession.run``.
    """
    # Hand onnxruntime a plain path string rather than a pathlib object.
    model_file = os.fspath(model_path)
    sess = _SESSION_CACHE.get(model_file)
    if sess is None:
        sess = _SESSION_CACHE[model_file] = ort.InferenceSession(model_file)

    # Only the first declared input and the first declared output are used.
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    onnx_input = {input_name: np.asarray(data).astype(np.float32)}
    pred_onnx = sess.run([output_name], onnx_input)[0]

    return pred_onnx

In [None]:
# Explain the model output for `data_instance` with dianna's tabular explainer, using the
# 'rise' method in classification mode.
# NOTE(review): `X_train`, `input_features` and `species` are not defined in any cell
# visible in this notebook, and class names are taken from `species` here while the
# pre-check cell uses `classes` — confirm which names are correct.
explanation = dianna.explain_tabular(run_model, input_tabular=data_instance, method='rise',
                                     mode ='classification', training_data = X_train.to_numpy(),
                                     feature_names=input_features.columns, class_names=species)

In [None]:
# NOTE(review): `data_instance` already received a leading axis when it was created;
# indexing with `None` here adds a second one — confirm this extra dimension is intended.
run_model(data_instance[None,...])

In [None]:
from dianna.visualization import plot_tabular

# Plot the explanation entry selected by the arg-max of `predictions`, labelled with the
# columns of `X_test` and passing num_features=10.
# NOTE(review): `predictions` and `X_test` are not defined in any cell visible in this
# notebook — confirm where they are supposed to come from.
_ = plot_tabular(explanation[np.argmax(predictions)], X_test.columns, num_features=10)