In [None]:
import torch
import sys
import modulus

sys.path.append('/global/cfs/cdirs/m4334/jerry/climsim3_dev/baseline_models/pao_model/training_v6/')


from pao_model import PaoModel

# Use the GPU when one is visible to this kernel; otherwise fall back to the
# CPU so the remaining cells can still run (e.g. on a node without CUDA).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def count_trainable_parameters(model):
    """Return the total element count of the parameters of `model` that require gradients.

    Parameters with ``requires_grad == False`` are excluded from the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def check_inf(model, input_dim=1399, batch_size=2):
    """Run two gradient-free forward passes of `model` on one random batch.

    Progress markers are printed before and after each call so that a hang or
    crash can be attributed to the first or the second invocation.

    Args:
        model: Callable module taking a ``(batch_size, input_dim)`` float tensor.
        input_dim: Width of the random input batch (defaults to 1399, the
            value this notebook uses everywhere).
        batch_size: Number of rows in the random batch (defaults to 2).

    Returns:
        None. The predictions are discarded; only the printed markers matter.
    """
    # Put the batch on the device that holds the model's own parameters so the
    # helper does not rely on notebook state. Only a parameter-free model
    # falls back to the notebook-level `device`.
    first_param = next(iter(model.parameters()), None)
    target_device = first_param.device if first_param is not None else device
    rand_batch = torch.rand(batch_size, input_dim).to(target_device)
    with torch.no_grad():
        # The same batch is fed twice. NOTE(review): presumably the second
        # call is there to exercise TorchScript's post-first-call execution
        # path - confirm.
        for ordinal in ('first', 'second'):
            print(f'Before {ordinal} traced_model call')
            model(rand_batch)
            print(f'After {ordinal} traced_model call')
            print(f'{ordinal} batch done')

# Build the model and switch it to eval mode *before* compiling it.
# torch.jit.trace records whichever train/eval behaviour is active while
# tracing, so calling .eval() on the traced module afterwards cannot undo a
# trace taken in training mode (this matters if the model contains any
# mode-dependent layers).
new_model = PaoModel(input_profile_num = 23, input_scalar_num = 19).to(device)
new_model.eval()
print('begin tracing and scripting')
scripted_model = torch.jit.script(new_model)
scripted_model.eval()
traced_model = torch.jit.trace(new_model, torch.rand(2, 1399).to(device))
traced_model.eval()
# Run the two-pass forward check on both compiled variants.
check_inf(scripted_model)
check_inf(traced_model)

begin tracing and scripting
Before first traced_model call
After first traced_model call
first batch done
Before second traced_model call
After second traced_model call
second batch done
Before first traced_model call
After first traced_model call
first batch done
Before second traced_model call
After second traced_model call
second batch done


In [2]:
# Restore the trained weights from the Modulus checkpoint.
pao_model_path_modulus = '/global/homes/j/jerrylin/scratch/hugging/E3SM-MMF_ne4/saved_models/climsim3_ensembles_v6/pao_model/pao_model_seed_43/model.mdlus'
modulus_model = modulus.Module.from_checkpoint(pao_model_path_modulus).to(device)

# Constructor arguments for a fresh PaoModel that will receive those weights.
pao_kwargs = dict(
    input_profile_num=23,
    input_scalar_num=19,
    target_profile_num=5,
    target_scalar_num=8,
    output_prune=True,
    strato_lev_out=12,
    loc_embedding=False,
    embedding_type="positional",
    hidden_profile_num=160,
    hidden_scalar_num=160,
)
new_model = PaoModel(**pao_kwargs).to(device)
new_model.load_state_dict(modulus_model.state_dict())

# Compile with TorchScript; the trailing .eval() is also the cell's displayed value.
new_scripted_model = torch.jit.script(new_model)
new_scripted_model.eval()

RecursiveScriptModule(
  original_name=PaoModel
  (feature_scale_list): RecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(original_name=FeatureScale)
    (1): RecursiveScriptModule(original_name=FeatureScale)
    (2): RecursiveScriptModule(original_name=FeatureScale)
    (3): RecursiveScriptModule(original_name=FeatureScale)
    (4): RecursiveScriptModule(original_name=FeatureScale)
    (5): RecursiveScriptModule(original_name=FeatureScale)
    (6): RecursiveScriptModule(original_name=FeatureScale)
    (7): RecursiveScriptModule(original_name=FeatureScale)
    (8): RecursiveScriptModule(original_name=FeatureScale)
    (9): RecursiveScriptModule(original_name=FeatureScale)
    (10): RecursiveScriptModule(original_name=FeatureScale)
    (11): RecursiveScriptModule(original_name=FeatureScale)
    (12): RecursiveScriptModule(original_name=FeatureScale)
    (13): RecursiveScriptModule(original_name=FeatureScale)
    (14): RecursiveScriptModule(original_name=

In [3]:
# Run the two-pass forward check on the scripted model that carries the checkpoint weights.
check_inf(new_scripted_model)

Before first traced_model call
After first traced_model call
first batch done
Before second traced_model call
After second traced_model call
second batch done


In [4]:
# Destination file for the TorchScript export of the scripted model.
save_path = '/global/homes/j/jerrylin/scratch/hugging/E3SM-MMF_ne4/saved_models/climsim3_ensembles_v6/pao_model/pao_model_seed_43/model.pt'

In [5]:
# Write the scripted model to save_path.
new_scripted_model.save(save_path)

In [None]:
# Re-run the two-pass forward check on traced_model (created in the first cell).
check_inf(traced_model)

In [None]:
# NOTE(review): torchscript_model_2 is not defined in any visible cell, so this
# raises NameError on a fresh kernel - define it first or remove this cell.
check_inf(torchscript_model_2)

In [None]:
# Show the Python class of the eager (un-scripted) model.
type(new_model)

In [None]:
# Serialize the whole eager model object (not just its state_dict) to
# 'new_model.pth', a path relative to the current working directory.
torch.save(new_model, 'new_model.pth')

In [None]:
# 'new_model.pth' was written by torch.save(new_model, ...), i.e. it is a full
# pickled module rather than a state_dict. Recent PyTorch releases default
# torch.load to weights_only=True, which rejects such files, so opt out
# explicitly. Only do this for files you created yourself: unpickling can
# execute arbitrary code.
loaded_model = torch.load('new_model.pth', weights_only=False)

In [None]:
# Show what type torch.load returned.
type(loaded_model)

In [None]:
# Switch the reloaded model to eval mode; the call's return value is what the cell displays.
loaded_model.eval()

In [None]:
# Run the two-pass forward check on the model reloaded from 'new_model.pth'.
check_inf(loaded_model)

In [None]:
from climsim_utils.data_utils import *
import numpy as np
import xarray as xr

In [None]:
# Open the grid description plus the input mean/max/min and output scale files.
grid_info = xr.open_dataset('/global/cfs/cdirs/m4334/jerry/climsim3_dev/grid_info/ClimSim_low-res_grid-info.nc')
input_mean, input_max, input_min = (
    xr.open_dataset(norm_file)
    for norm_file in (
        '/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_mean_v2_rh_mc_pervar.nc',
        '/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_max_v2_rh_mc_pervar.nc',
        '/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_min_v2_rh_mc_pervar.nc',
    )
)
output_scale = xr.open_dataset('/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/outputs/output_scale_std_lowerthred_v2_rh_mc.nc')

# Hand all five datasets to data_utils, then call set_to_v2_rh_mc_vars() on the result.
data = data_utils(
    grid_info=grid_info,
    input_mean=input_mean,
    input_max=input_max,
    input_min=input_min,
    output_scale=output_scale,
)
data.set_to_v2_rh_mc_vars()

In [None]:
# matplotlib is not imported by name in any cell before this one (the explicit
# import only appears further down), so import it here to keep the cell
# runnable top-to-bottom on a fresh kernel.
import matplotlib.pyplot as plt

# Histogram of data.lats.
plt.hist(data.lats)

In [None]:
# One sub-directory per saved chunk of the training set.
train_subdirs = ['11', '12', '21', '22', '31', '32', '41', '42', '51', '52', '61', '62', '71', '72']
train_set_paths = [f'/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v2_rh_mc/train_set/{x}/' for x in train_subdirs]

# Load each chunk's train_input.npy and stack them along the sample axis.
# The per-chunk arrays live only inside the temporary list, so they can be
# released once the concatenation is done.
train_input = np.concatenate(
    [np.load(f'{train_set_path}/train_input.npy') for train_set_path in train_set_paths],
    axis = 0,
)

In [None]:
# Per-column mean and standard deviation over all training samples.
# Accumulate in float64: by default NumPy reduces floating-point input in the
# input's own precision, which can be visibly inaccurate over very many rows
# (assumes train_input is float32 - the later cells comparing against
# dtype=np.float64 suggest so; for float64 input this changes nothing).
# Note that np.std with a float64 dtype needs a float64 temporary the size of
# train_input.
npy_mean = np.mean(train_input, axis = 0, dtype = np.float64)
npy_std = np.std(train_input, axis = 0, dtype = np.float64)

In [None]:
# NOTE(review): everything up to set_to_v2_rh_mc_vars() repeats an earlier cell
# verbatim (same files, same data_utils arguments); only the save_norm call on
# the last line is new.
grid_info = xr.open_dataset('/global/cfs/cdirs/m4334/jerry/climsim3_dev/grid_info/ClimSim_low-res_grid-info.nc')
input_mean = xr.open_dataset('/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_mean_v2_rh_mc_pervar.nc')
input_max = xr.open_dataset('/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_max_v2_rh_mc_pervar.nc')
input_min = xr.open_dataset('/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_min_v2_rh_mc_pervar.nc')
output_scale = xr.open_dataset('/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/outputs/output_scale_std_lowerthred_v2_rh_mc.nc')
data = data_utils(grid_info = grid_info, 
                  input_mean = input_mean, 
                  input_max = input_max, 
                  input_min = input_min, 
                  output_scale = output_scale)
data.set_to_v2_rh_mc_vars()
# save_norm(write=False) returns three values; presumably the input offset,
# input divisor and output scale, with write=False skipping any file output -
# confirm against data_utils.
input_sub, input_div, out_scale = data.save_norm(write=False)

In [None]:
# Dimensions of the concatenated training-input array.
train_input.shape

In [None]:
# Spot-check a single value: row 0, column 0.
train_input[0,0]

In [None]:
# Mean of the first 10 columns over (at most) the first num_samples rows,
# computed as an explicit sum / count.
num_samples = 20000000
head_rows = train_input[:num_samples, :10]
# Divide by the number of rows actually summed: slicing silently truncates if
# train_input has fewer than num_samples rows, and dividing by num_samples in
# that case would under-estimate the mean.
np.sum(head_rows, axis=0) / head_rows.shape[0]

In [None]:
import matplotlib.pyplot as plt

# Distribution of column 0 across all rows, using 100 bins.
first_column = train_input[:,0]
plt.hist(first_column, bins=100)
plt.show()
# Kept for reference (not executed): np.nanmean(train_input[:,:10], axis=0)

In [None]:
# First 60 entries of column 0.
train_input[:,0][:60]

In [None]:
# Column-wise mean with NumPy's default accumulator precision; show the first 60 entries.
np.mean(train_input, axis = 0)[0:60]

In [None]:
# Same column-wise mean, but accumulated in float64, for comparison with the
# default-precision result.
np.mean(train_input, axis = 0, dtype = np.float64)[0:60]

In [None]:
# 16777215 equals 2**24 - 1.
# NOTE(review): presumably this relates to float32's exact-integer range and
# the 60 entries inspected in the neighbouring cells - confirm intent or drop.
16777215/60

In [None]:
# First 60 entries of the mean computed directly from train_input.
print(npy_mean[:60])

In [None]:
# First 60 entries of input_sub (first value returned by data.save_norm), for
# comparison with npy_mean.
print(input_sub[:60])

In [None]:
# Locations of the per-variable input normalization files, keyed by role.
# Keys must be string literals and entries comma-separated; the previous form
# (bare names, no commas) was a syntax error.
path_dict = {
    'input_mean_path': '/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_mean_v2_rh_mc_pervar.nc',
    'input_max_path': '/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_max_v2_rh_mc_pervar.nc',
    'input_min_path': '/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/input_min_v2_rh_mc_pervar.nc',
}

In [None]:
# NOTE(review): scratch arithmetic; the meaning of 10089984 is not established
# by any visible cell - add context or remove.
10089984/2

In [None]:
# NOTE(review): `woah` is not defined in any visible cell, so this raises
# NameError on a fresh kernel - define it first or remove this cell.
woah.shape

### Instantiating class

The example below will save training data in both .h5 and .npy format. Adjust if you only need one format. Also adjust input_abbrev to the input data files you will use. We expanded the original '.mli.' input files to include additional features such as previous steps' information, and '.mlexpand.' was just an arbitrary name we used for the expanded input files.

The training script currently assumes that the training set is in .h5 format and the validation set is in .npy format. It is fine to keep only save_h5=True in the block below when generating training data.

In [None]:
# Locations of the grid description file and the normalization directory.
grid_path = '/global/u2/z/zeyuanhu/nvidia_codes/Climsim_private/grid_info/ClimSim_low-res_grid-info.nc'
norm_path = '/global/u2/z/zeyuanhu/nvidia_codes/Climsim_private/preprocessing/normalizations/'

grid_info = xr.open_dataset(grid_path)
# No naming issue here: these normalization-related files are just placeholders, since normalize=False is set in data_utils.
# NOTE(review): this cell only opens the datasets; no data_utils(...) call is
# visible here, so the `data` object used by the following cells presumably
# comes from an earlier cell - confirm.
input_mean = xr.open_dataset(norm_path + 'inputs/input_mean_v5_pervar.nc')
input_max = xr.open_dataset(norm_path + 'inputs/input_max_v5_pervar.nc')
input_min = xr.open_dataset(norm_path + 'inputs/input_min_v5_pervar.nc')
output_scale = xr.open_dataset(norm_path + 'outputs/output_scale_std_lowerthred_v5.nc')


In [None]:
# Display the input_vars attribute of the data_utils object.
data.input_vars

### Create training data

Below is an example of creating the training data by integrating the 7-year ClimSim simulation data. A subsampling stride of 1000 is used as an example; in our actual work we used stride_sample=1. We could not fit the full 7-year dataset into memory without subsampling. If that is also the case for you, process only a subset of the data at a time by adjusting the regexps passed to the set_regexps method. We saved 14 separate input .h5 files: for each year, we saved two files by setting start_idx=0 or 1. We have a folder such as v2_full, which contains 14 subfolders named '11', '12', '21', '22', ..., '71', '72', and each subfolder contains a train_input.h5 and a train_target.h5. How you split the saved training data does not influence training: the training script reads in all the samples and draws each batch at random from the full set.

In [None]:
# File-name patterns that make up the training split.
train_regexps = [
    'E3SM-MMF.mlexpand.000[1234567]-*-*-*.nc',  # years 1 through 7
    'E3SM-MMF.mlexpand.0008-01-*-*.nc',  # first month of year 8
]
data.set_regexps(data_split='train', regexps=train_regexps)

# Temporal subsampling stride for the training split.
data.set_stride_sample(data_split='train', stride_sample=1000)

# Build the list of files to read, starting from index 0.
data.set_filelist(data_split='train', start_idx=0)

# Write the training arrays via save_as_npy.
data.save_as_npy(
    data_split='train',
    save_path='/global/homes/z/zeyuanhu/scratch/hugging/E3SM-MMF_ne4/preprocessing/v2_example/',
)

### Create validation data

In [None]:
# File-name patterns that make up the validation split.
val_regexps = [
    'E3SM-MMF.mlexpand.0008-0[23456789]-*-*.nc',  # months 2 through 9 of year 8
    'E3SM-MMF.mlexpand.0008-1[012]-*-*.nc',  # months 10 through 12 of year 8
    'E3SM-MMF.mlexpand.0009-01-*-*.nc',  # first month of year 9
]
data.set_regexps(data_split='val', regexps=val_regexps)

# Temporal subsampling stride for the validation split.
# Alternative stride kept for reference:
# data.set_stride_sample(data_split = 'val', stride_sample = 7)
data.set_stride_sample(data_split='val', stride_sample=700)

# Build the list of files to read.
data.set_filelist(data_split='val')

# Write the validation arrays via save_as_npy.
data.save_as_npy(
    data_split='val',
    save_path='/global/homes/z/zeyuanhu/scratch/hugging/E3SM-MMF_ne4/preprocessing/v2_example/',
)

### Create test data

In [None]:
# Point the data_utils object at the directory holding the test files.
data.data_path = '/global/homes/z/zeyuanhu/scratch/hugging/E3SM-MMF_ne4/_test/'

# NOTE(review): this switches to the v4 variable set, whereas the earlier setup
# cells call set_to_v2_rh_mc_vars() - confirm this is intentional.
data.set_to_v4_vars()

# set regular expressions for selecting test data
data.set_regexps(data_split = 'test',
                 regexps = ['E3SM-MMF.mlexpand.0009-0[3456789]-*-*.nc', 
                            'E3SM-MMF.mlexpand.0009-1[012]-*-*.nc',
                            'E3SM-MMF.mlexpand.0010-*-*-*.nc',
                            'E3SM-MMF.mlexpand.0011-0[12]-*-*.nc'])
# set temporal subsampling
# data.set_stride_sample(data_split = 'test', stride_sample = 7)
data.set_stride_sample(data_split = 'test', stride_sample = 700)
# create list of files to extract data from
data.set_filelist(data_split = 'test')
# save numpy files of test data
data.save_as_npy(data_split = 'test', save_path = '/global/homes/z/zeyuanhu/scratch/hugging/E3SM-MMF_ne4/preprocessing/v2_example/')

In [None]:
# List the directory that the save_as_npy calls above use as save_path.
!ls /global/homes/z/zeyuanhu/scratch/hugging/E3SM-MMF_ne4/preprocessing/v2_example/