In [1]:
# Automatically reload edited project modules (e.g. the rtml.* package imported below)
# before each cell runs, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
# When the kernel starts inside `notebooks/`, step up one directory so that
# relative paths used later (e.g. "scripts/out") resolve from the parent directory.
if os.path.split(os.getcwd())[1] == "notebooks":
    os.chdir(os.pardir)

In [None]:
import sys
import h5py
import json
import torch
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict
import rtml.data_wrangling.constants as constants
from rtml.models.interface import get_model, is_gnn, is_graph_net, get_trainer
from rtml.models.column_handler import ColumnPreprocesser
from rtml.data_wrangling.constants import LEVELS, LAYERS, GLOBALS, OUTPUT, TRAIN_YEARS
from rtml.data_wrangling.h5_dataset import RT_HdF5_Dataset
from rtml.utils.utils import set_seed, get_name, year_string_to_list, identity
from rtml.data_wrangling.constants import TEST_YEARS, LAYERS

In [None]:
# Global display settings: large, low-dpi matplotlib figures; numpy prints floats in
# fixed-point notation and never summarizes arrays (threshold=sys.maxsize).
plt.rcParams.update({'figure.figsize': [20, 8], 'figure.dpi': 70})
np.set_printoptions(threshold=sys.maxsize, suppress=True)

In [None]:
# Data/model locations. The defaults are cluster-specific scratch paths; set the
# RTML_HDF5_DIR / RTML_MODEL_DIR environment variables to run somewhere else.
hdf5_years_dir = os.environ.get(
    "RTML_HDF5_DIR", "/miniscratch/salva.ruhling-cachay/ECC_data/snapshots/1979-2014/hdf5/inputs"
)
# Relative to the working directory (the repository root after the chdir cell above).
model_dir = os.environ.get("RTML_MODEL_DIR", "scripts/out")
year = 2011
# One HDF5 input file per year, e.g. <hdf5_years_dir>/2011.h5
h5_path = os.path.join(hdf5_years_dir, f"{year}.h5")

In [None]:
def get_lat_lon(
    data: np.ndarray = None,
    coords_path: str = '/miniscratch/venkatesh.ramesh/ECC_data/snapshots/coords_data/areacella_fx_CanESM5_amip_r1i1p1f1_gn.nc',
):
    """Read the latitude/longitude coordinates of the model grid.

    Args:
        data: unused; kept only so existing call sites keep working.
        coords_path: NetCDF file whose ``lat`` / ``lon`` indexes define the grid.

    Returns:
        dict with
          - ``'latitude'`` / ``'longitude'``: lists of the 1-D coordinate values;
          - ``'latitude_flattened'`` / ``'longitude_flattened'``: arrays of length
            ``len(lat) * len(lon)`` giving each grid cell's coordinates with latitude
            varying slowest and longitude fastest (row-major order).
    """
    # Use the dataset as a context manager so the file handle is released once the
    # coordinate values have been copied into Python lists.
    with xr.open_dataset(coords_path) as coords_data:
        lat = list(coords_data.get_index('lat'))
        lon = list(coords_data.get_index('lon'))

    # Vectorized equivalent of `for i in lat: for j in lon: (i, j)`.
    lat_var = np.repeat(np.asarray(lat), len(lon))
    lon_var = np.tile(np.asarray(lon), len(lat))
    return {'latitude': lat, 'longitude': lon, 'latitude_flattened': lat_var, 'longitude_flattened': lon_var}

### Generate and save predictions (inference on GPU)

In [None]:
def get_preds(ckpt: str, year: str, device='cuda'):
    """Evaluate a saved checkpoint on the given year(s) and return predictions plus auxiliary fields.

    Loads ``{model_dir}/{ckpt}.pkl``, rebuilds the evaluation dataset with the
    preprocessing settings stored in the checkpoint, restores the model weights and
    runs ``trainer.evaluate`` over the whole (unshuffled) dataset.

    Args:
        ckpt: checkpoint file name inside ``model_dir``, without the ``.pkl`` suffix.
        year: evaluation year(s); converted with ``str()`` and parsed by
            ``year_string_to_list``, so an ``int`` such as ``2012`` works as well.
        device: device the checkpoint tensors are mapped to and the trainer runs on.

    Returns:
        dict with
          - ``'preds'``, ``'targets'``: as returned by ``trainer.evaluate``;
          - ``'pressure'``: ``[..., 2]`` slice of the raw LEVELS inputs;
          - ``'layer_pressure'``: ``[..., 2]`` slice of the raw LAYERS inputs;
          - ``'cszrow'``: ``[..., 0]`` slice of the raw GLOBALS inputs. The same array is
            also returned under the old, misspelled key ``'cstrow'`` for older callers.
    """
    # torch.load unpickles the file -- only load checkpoints from a trusted source.
    model_ckpt = torch.load(f"{model_dir}/{ckpt}.pkl", map_location=torch.device(device))
    params = model_ckpt['hyper_params']
    net_params = model_ckpt['model_params']
    model_name = params['model']

    dataset_kwargs = dict(
        exp_type=params['exp_type'],
        target_type=params['target_type'],
        target_variable=params['target_variable'],
        input_transform=get_model(model_name, only_class=True)._input_transform,
        input_normalization=params['in_normalize'],
        spatial_normalization_in=params['spatial_normalization_in'],
        load_h5_into_mem=True,
    )
    dset = RT_HdF5_Dataset(years=year_string_to_list(str(year)), name='Eval', output_normalization=None, **dataset_kwargs)
    # Close the dataset even if building the trainer or evaluation fails.
    try:
        dloader = torch.utils.data.DataLoader(dset, batch_size=512, pin_memory=True, shuffle=False, num_workers=2)

        # NOTE(review): only the first HDF5 file's raw inputs are read; if `year` spans
        # several years these fields cover just the first one -- confirm before multi-year use.
        raw_inputs = dset.h5_dsets[0].get_raw_input_data()
        lvl_pressure = raw_inputs[LEVELS][..., 2]
        lay_pressure = raw_inputs[LAYERS][..., 2]
        cszrow = raw_inputs[GLOBALS][..., 0]
        print(cszrow.shape, lvl_pressure.shape, lay_pressure.shape)

        trainer_kwargs = dict(
            model_name=model_name, model_params=net_params,
            # Honor the requested device (consistent with map_location above) rather than
            # the device recorded at training time in params['device'].
            device=device, seed=params['seed'],
            model_dir=params['model_dir'],
            output_postprocesser=dset.output_variable_splitter,
        )
        if is_gnn(model_name) or is_graph_net(model_name):
            column_preprocesser = ColumnPreprocesser(
                n_layers=dset.spatial_dim[LAYERS], input_dims=dset.input_dim, **params['preprocessing_dict']
            )
            trainer_kwargs['column_preprocesser'] = column_preprocesser
            # For preprocessing types other than 'mlp' / 'mlp_projection', install the
            # preprocesser as the dataset's input transform.
            transform_name = column_preprocesser.preprocessing_type
            if transform_name not in ['mlp', 'mlp_projection']:
                dset.set_input_transform(column_preprocesser.get_preprocesser())

        print(net_params)
        trainer = get_trainer(**trainer_kwargs)
        trainer.reload_model(model_state_dict=model_ckpt['model'])
        preds, Y, _ = trainer.evaluate(dloader, verbose=True)
    finally:
        dset.close()

    return {
        'preds': preds,
        'targets': Y,
        'pressure': lvl_pressure,
        'layer_pressure': lay_pressure,
        'cszrow': cszrow,  # the key save_preds expects
        'cstrow': cszrow,  # deprecated misspelled alias, kept for backward compatibility
    }
                
                

In [None]:
def save_preds(preds, targets, exp='pristine', model=None, year=None,
               n_levels: int = 50, n_layers: int = 49, out_dir: str = '~/RT-DL', **kwargs):
    """Arrange predictions/targets on the lat/lon grid as an ``xr.Dataset`` and optionally save it.

    Args:
        preds: dict ``name -> array``; each array is reshaped to
            ``(snapshot, latitude, longitude, n_levels)`` and stored as ``f"{name}_preds"``.
        targets: same as ``preds``, stored as ``f"{name}_targets"``.
        exp: experiment tag used in the output file name.
        model, year: when both are given the dataset is written to
            ``{out_dir}/example_{exp}_preds_{model}_{year}.nc``; otherwise nothing is written.
        n_levels, n_layers: sizes of the ``level`` / ``layer`` axes (defaults 50 / 49).
        out_dir: output directory; ``~`` is expanded and the directory is created if missing.
        **kwargs: must contain ``pressure`` (reshaped like ``preds``), ``layer_pressure``
            (reshaped with ``n_layers``) and ``cszrow`` (reshaped to the 2-D grid). The
            legacy misspelled key ``cstrow`` is accepted in place of ``cszrow``.

    Returns:
        The assembled ``xr.Dataset``.

    Raises:
        KeyError: if neither ``cszrow`` nor ``cstrow`` (or ``pressure`` / ``layer_pressure``) is given.
    """
    if 'cszrow' in kwargs:
        cszrow = kwargs['cszrow']
    elif 'cstrow' in kwargs:
        # get_preds has returned this field under the misspelled key 'cstrow'.
        cszrow = kwargs['cstrow']
    else:
        raise KeyError("save_preds() needs a 'cszrow' (or legacy 'cstrow') keyword argument")

    lat_lon = get_lat_lon()
    lat, lon = lat_lon['latitude'], lat_lon['longitude']

    def to_grid(arr, *trailing):
        # Reshape to (snapshot, latitude, longitude, *trailing); the snapshot count is inferred.
        return arr.reshape((-1, len(lat), len(lon), *trailing))

    level_dims = ['snapshot', 'latitude', 'longitude', 'level']
    layer_dims = ['snapshot', 'latitude', 'longitude', 'layer']
    grid_dims = ['snapshot', 'latitude', 'longitude']

    data_vars = {f"{k}_preds": (level_dims, to_grid(v, n_levels)) for k, v in preds.items()}
    data_vars.update({f"{k}_targets": (level_dims, to_grid(v, n_levels)) for k, v in targets.items()})
    data_vars["pressure"] = (level_dims, to_grid(kwargs['pressure'], n_levels))
    data_vars["layer_pressure"] = (layer_dims, to_grid(kwargs['layer_pressure'], n_layers))
    data_vars["cszrow"] = (grid_dims, to_grid(cszrow))

    xr_dset = xr.Dataset(
        data_vars=data_vars,
        coords=dict(
            longitude=lon,
            latitude=lat,
            # Level/layer coordinate labels are descending: n-1, ..., 0.
            level=list(range(n_levels))[::-1],
            layer=list(range(n_layers))[::-1],
        ),
        attrs=dict(description="ML emulated RT outputs."),
    )
    if model is not None and year is not None:
        # Expand '~' explicitly rather than relying on the NetCDF backend to do it.
        out_dir = os.path.expanduser(out_dir)
        os.makedirs(out_dir, exist_ok=True)
        xr_dset.to_netcdf(os.path.join(out_dir, f'example_{exp}_preds_{model}_{year}.nc'))
    else:
        print("Not saving to NC!")
    return xr_dset

In [None]:
# Evaluation year for the run below; this overrides `year = 2011` from the config cell.
year = 2012
# Rebuild the derived path so it does not keep pointing at the previous year's file.
h5_path = os.path.join(hdf5_years_dir, str(year) + '.h5')

In [None]:
best_gn_ckpt = "0.2706valMAE_141ep_GN+READOUT_1985-90+1998-2004train_2005val_Z_7seed_15h50m_on_Aug_22_27kn4tto"
p_gn = get_preds(best_gn_ckpt, year=year, device='cuda')
# Pass the fields explicitly: get_preds has returned the `cszrow` field under the
# misspelled key 'cstrow', so `save_preds(**p_gn)` fails with KeyError('cszrow').
xr_gn_preds = save_preds(
    p_gn['preds'], p_gn['targets'],
    pressure=p_gn['pressure'],
    layer_pressure=p_gn['layer_pressure'],
    cszrow=p_gn['cszrow'] if 'cszrow' in p_gn else p_gn['cstrow'],
    model='graph_net', year=year,
)
xr_gn_preds