In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os

# When the kernel starts inside notebooks/ or examples/, step up one directory so that
# the relative paths used in later cells resolve from the parent (presumably the repo root).
_notebook_dir_name = os.path.basename(os.getcwd())
if _notebook_dir_name in {"notebooks", "examples"}:
    os.chdir(os.pardir)

In [4]:
# The data used for prediction must be here, as well as the cmip6 mean/std statistics.
DATA_DIR = "/home/jupyter-dipti/work/processed"
# Alternative data directory: "/home/jupyter-dipti/work/AiBEDO_simultaneousPreds_Salva/Data/Predictions/EOFInput"
# Input data filename (isosph is an order 6 icosahedron, isosph5 of order 5, etc.)
filename_input = "isosph5.nonorm.ERA5_Exp8_Input.nc"
# Output (ground-truth) filename is inferred from the input filename, do not edit!
# Both naming schemes are supported:
#   "<prefix>.Input.Exp8.nc"  -> "<prefix>.Output.nc"
#       e.g. "compress.isosph.CESM2.historical.r1i1p1f1.Output.nc"
#   "<prefix>_Input.nc"       -> "<prefix>_Output.nc"
#       e.g. "isosph5.nonorm.ERA5_Exp8_Output.nc"
# The first replace handles the old scheme. After it, "Input.nc" can no longer occur,
# so the second replace only fires for the new scheme.
filename_output = filename_input.replace("Input.Exp8.nc", "Output.nc").replace("Input.nc", "Output.nc")
# Fail loudly instead of silently using the input file as "ground truth".
assert filename_output != filename_input, f"Could not infer the output filename from {filename_input!r}"

In [5]:
import xarray as xr
import numpy as np
from typing import *
import wandb
import torch
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
#import proplot as pplt
from aibedo.models import BaseModel
from aibedo.utilities.wandb_api import reload_checkpoint_from_wandb, get_run_ids_for_hyperparams
import scipy.stats

import hydra
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf, DictConfig
from aibedo.utilities.config_utils import get_config_from_hydra_compose_overrides
from aibedo.utilities.utils import rsetattr, get_logger, get_local_ckpt_path, rhasattr, rgetattr


ModuleNotFoundError: No module named 'aibedo.models'

In [5]:
# Get the appropriate device (GPU or CPU) to use
# Run on the GPU whenever one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Hydra overrides: point the datamodule at DATA_DIR and disable auxiliary input variables.
overrides = [
    "datamodule.data_dir={}".format(DATA_DIR),
    "++model.use_auxiliary_vars=False",
]

In [6]:
def concat_variables_into_channel_dim(data: xr.Dataset, variables: List[str]) -> np.ndarray:
    """Concatenate xarray variables into a numpy channel dimension (last).

    Args:
        data: dataset (or any mapping) whose entries expose ``.values`` and ``.shape``.
        variables: names of the variables to stack, in channel order; must be non-empty.

    Returns:
        float32 array of shape ``(*variable_shape, len(variables))``.

    Raises:
        ValueError: if ``variables`` is empty.
        AssertionError: if any (not only the first) variable is not two-dimensional.
    """
    if not variables:
        raise ValueError("At least one input variable is required")
    # Validate every variable, so that a mis-ranked later variable gives a clear message
    # instead of an opaque concatenate error.
    for var in variables:
        assert len(data[var].shape) == 2, "Each input data variable must have two dimensions"
    # np.stack(axis=-1) is equivalent to expand_dims(-1) on each array followed by concatenate(axis=-1).
    data_ml = np.stack([data[var].values for var in variables], axis=-1)
    return data_ml.astype(np.float32)

def get_month_of_output_data(output_xarray: xr.Dataset) -> np.ndarray:
    """Get the 0-indexed month (0-11) of every snapshot, repeated for each grid cell.

    Args:
        output_xarray: dataset exposing ``'ncells'`` and ``'time.month'``. A scalar
            ``time.month`` (a single selected timestep) is treated as one snapshot.

    Returns:
        float32 array of shape ``(n_snapshots, n_gridcells, 1)``; the trailing axis is a
        dummy channel/feature dimension.
    """
    n_gridcells = len(output_xarray['ncells'])
    # atleast_1d lets a single (scalar) timestep work too; subtract 1 for 0-indexed months.
    months = np.atleast_1d(np.asarray(output_xarray['time.month'], dtype=np.float32)) - 1
    # (T,) -> (T, 1, 1) -> repeat over the grid-cell axis -> (T, n_gridcells, 1)
    return np.repeat(months[:, None, None], n_gridcells, axis=1)

def get_pytorch_model_data(input_xarray: xr.Dataset, output_xarray: xr.Dataset, input_vars: List[str]) -> torch.Tensor:
    """Build the float32 input tensor for the ML model on the global ``device``.

    The ``input_vars`` of ``input_xarray`` are stacked along a trailing channel axis. The
    0-indexed month of each snapshot, taken from ``output_xarray``, is appended as one extra,
    final channel. That month channel is *not* used by the model; it is carried along for
    convenience because denormalizing the predictions needs it.
    """
    features = concat_variables_into_channel_dim(input_xarray, input_vars)
    month_channel = get_month_of_output_data(output_xarray)
    with_month = np.concatenate((features, month_channel), axis=-1)
    return torch.from_numpy(with_month).float().to(device)

def predict_with_aibedo_model(aibedo_model: BaseModel, input_tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Predict with the AiBEDO model.

    The given model is put in eval mode and run with gradient tracking disabled.

    Args:
        aibedo_model: model whose ``predict`` is called with ``return_normalized_outputs=True``.
        input_tensor: model input, e.g. as built by ``get_pytorch_model_data``.

    Returns:
        A dictionary of output-variable -> prediction-tensor key->value pairs for each variable {var}.
        Keys with name {var} (e.g. 'pr') are in denormalized scale. Keys with name {var}_pre or {var}_nonorm are raw predictions of the ML model.
        To only get the raw predictions, please use aibedo_model.raw_predict(input_tensor)
        NOTE(review): the tensors may live on the GPU; call ``.cpu()`` before ``.numpy()``.
    """
    # Use the argument, not the notebook-global `model`: the old code only worked because
    # the caller happened to name its variable `model`.
    aibedo_model.eval()
    with torch.no_grad():  # No need to track the gradients during inference
        prediction = aibedo_model.predict(input_tensor, return_normalized_outputs=True)  # also returns {var}_nonorm (or {var}_pre)
    return prediction


from aibedo.interface import reload_model_from_config_and_ckpt
def load_model(config_path, config_name, ckpt_path, ckpt_name):
    """Compose a Hydra config and reload the AiBEDO model from a checkpoint.

    Args:
        config_path: Hydra config directory, passed to ``hydra.initialize``.
        config_name: config file name within ``config_path`` (e.g. 'hydra_config.yaml').
        ckpt_path: directory containing the checkpoint; also stored as ``config.ckpt_dir``.
        ckpt_name: checkpoint file name inside ``ckpt_path``.

    Returns:
        ``(model, config)``: the first element returned by
        ``reload_model_from_config_and_ckpt``, and the composed config.
    """
    overrides = [f'datamodule.data_dir={DATA_DIR}', '++model.use_auxiliary_vars=False']

    ### Load Hydra config file (clear any previous initialization so this can be called repeatedly)
    GlobalHydra.instance().clear()
    hydra.initialize(config_path=config_path, version_base=None)
    config = hydra.compose(config_name=config_name, overrides=overrides)
    config['ckpt_dir'] = ckpt_path
    config['callbacks']['model_checkpoint']['dirpath'] = config['ckpt_dir']

    ## Rename the legacy package 'aibedo_salva' to 'aibedo' in every instantiation target.
    # OmegaConf.select resolves dotted paths. DictConfig.get('model.mixer') does NOT, and
    # silently skipped nested nodes such as model.mixer.
    for node in ('model', 'datamodule', 'model.mixer', 'model.input_transform'):
        target_key = f'{node}._target_'
        target = OmegaConf.select(config, target_key)
        if target is not None:
            OmegaConf.update(config, target_key, str(target).replace('aibedo_salva', 'aibedo'))

    ## Load model (os.path.join tolerates ckpt_path with or without a trailing slash)
    reloaded = reload_model_from_config_and_ckpt(config, os.path.join(ckpt_path, ckpt_name), load_datamodule=True)
    return reloaded[0], config

In [None]:
# Root folder holding one checkpoint sub-directory per training dataset/model.
# NOTE(review): this cell is repeated verbatim at the top of the next cell; one copy can be dropped.
dataPath = '/home/jupyter-dipti/work/AiBEDO_simultaneousPreds_Salva/Data/ckpoints/'
# Ground truth data; the per-experiment input files are opened later, inside the prediction loop.
ds_output = xr.open_dataset("{}/{}".format(DATA_DIR, filename_output))


In [None]:
dataPath = '/home/jupyter-dipti/work/AiBEDO_simultaneousPreds_Salva/Data/ckpoints/'
### Load the ground truth data (month of each snapshot + coordinate template for predictions)
ds_output = xr.open_dataset(f"{DATA_DIR}/{filename_output}")  # Ground truth data
gt_era5 = xr.open_dataset(f"{DATA_DIR}/isosph5.nonorm.ERA5_Exp8_Output.nc")

gt_tas_nonorm = gt_era5.tas_nonorm
gt_ps_nonorm = gt_era5.ps_nonorm
gt_pr_nonorm = gt_era5.pr_nonorm
# Ground-truth arrays used only as templates (coords/attrs) for the predicted fields.
gt_templates = {'tas': gt_tas_nonorm, 'ps': gt_ps_nonorm, 'pr': gt_pr_nonorm}

# Checkpoint folders to skip. NOTE(review): reason for excluding E3SM-1-1 not recorded here.
EXCLUDED_MODELS = {'E3SM-1-1'}
# EOF input experiments; each produces one output file per trained model.
inFiles = ['netSurfcs_nonorm_othersRegressedOnPC', 'crelSurf_nonorm_othersRegressedOnPC',
           'cresSurf_nonorm_othersRegressedOnPC', 'netTOAcs_nonorm_othersRegressedOnPC',
           'cres_nonorm_othersRegressedOnPC', 'crel_nonorm_othersRegressedOnPC']
DATA_DIR2 = "/home/jupyter-dipti/work/AiBEDO_simultaneousPreds_Salva/Data/Predictions"
OUT_DIR = '/home/jupyter-keighan/work/AiBEDOwork/Simultaneous_Preds'
config_name = 'hydra_config.yaml'


def _first_sorted_match(names, predicate, what):
    """Return the alphabetically first name satisfying `predicate` (deterministic, unlike listdir()[0])."""
    matches = sorted(name for name in names if predicate(name))
    if not matches:
        raise FileNotFoundError(f"No {what} found")
    return matches[0]


def _prediction_like(template, prediction):
    """Copy `template`'s coords/attrs around a prediction tensor, moved to host memory first."""
    # .detach().cpu() is required: .numpy() raises on CUDA tensors.
    return template.copy(data=prediction.detach().cpu().numpy())


# The EOF input files do not depend on the trained model: open each one once, not once per model.
input_datasets = {inputs: xr.open_dataset(f"{DATA_DIR2}/ERA5_Input_EOFs_{inputs}.nc") for inputs in inFiles}

modelsName = sorted(x for x in os.listdir(dataPath) if not x.startswith('.'))
for m in modelsName:
    if m in EXCLUDED_MODELS:
        continue
    chkPath = dataPath + m + '/checkpoints/'
    print(m)
    # Skip hidden entries such as .ipynb_checkpoints when picking the run tag.
    tag = _first_sorted_match(os.listdir(chkPath), lambda x: not x.startswith('.'),
                              f"checkpoint run directory in {chkPath}")
    ckpt_path = chkPath + tag + '/'
    cnf = dataPath + m + '/wandb/'
    extCnf = _first_sorted_match(os.listdir(cnf), lambda x: x.startswith('run-2022') and x.endswith(tag),
                                 f"wandb run ending in {tag!r} in {cnf}")
    # Relative path handed to hydra.initialize via load_model.
    cnfString = '../../Data/ckpoints/' + m + '/wandb/' + extCnf + '/files'
    print(cnfString)
    ckpt_name = _first_sorted_match(os.listdir(ckpt_path), lambda x: x.startswith('nonorm_0h_epoch0'),
                                    f"'nonorm_0h_epoch0*' checkpoint in {ckpt_path}")
    print(ckpt_name)
    model, config = load_model(cnfString, config_name, ckpt_path, ckpt_name)

    for inputs, ds_input in input_datasets.items():  #### loop over input files
        input_ml = get_pytorch_model_data(ds_input, ds_output, input_vars=model.main_input_vars)
        ### Get AiBEDO predictions
        predictions_ml = predict_with_aibedo_model(model, input_ml)
        output_ds = xr.Dataset({
            f'pred_{var}_nonorm': _prediction_like(template, predictions_ml[f'{var}_nonorm'])
            for var, template in gt_templates.items()
        })
        out_path = f'{OUT_DIR}/ENSO_FNO_trained_on_{m}_{inputs}.nc'
        output_ds.to_netcdf(path=out_path, mode='w', format='NETCDF4')
print('allDone')


CESM2-FV2
../../Data/ckpoints/CESM2-FV2/wandb/run-20220926_202705-84ad4gyn/files
nonorm_0h_epoch013_seed7.ckpt
CESM2-WACCM-FV2
../../Data/ckpoints/CESM2-WACCM-FV2/wandb/run-20220926_212814-d6et07fq/files
nonorm_0h_epoch014_seed7.ckpt
CESM2-WACCM
../../Data/ckpoints/CESM2-WACCM/wandb/run-20220926_205901-koxjbllh/files
nonorm_0h_epoch014_seed7.ckpt
CESM2
../../Data/ckpoints/CESM2/wandb/run-20220926_195019-b8he6k5n/files
nonorm_0h_epoch014_seed7.ckpt
CMCC-CM2-SR5
../../Data/ckpoints/CMCC-CM2-SR5/wandb/run-20220926_215722-3b6a0h56/files
nonorm_0h_epoch014_seed7.ckpt
CanESM5
../../Data/ckpoints/CanESM5/wandb/run-20220926_222554-dtgu4t2c/files
nonorm_0h_epoch014_seed7.ckpt
FGOALS-g3
../../Data/ckpoints/FGOALS-g3/wandb/run-20220926_232552-3no05o85/files
nonorm_0h_epoch014_seed7.ckpt
GISS-E2-1-H
../../Data/ckpoints/GISS-E2-1-H/wandb/run-20220927_014857-27bwymay/files
nonorm_0h_epoch013_seed7.ckpt
MIROC-ES2L
../../Data/ckpoints/MIROC-ES2L/wandb/run-20220927_121224-3nodghgw/files
nonorm_0h_epoch