In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import numpy as np

import cartopy.crs as ccrs

import torch
import torch.nn as nn
import torch.utils.data as data
import torch_geometric
from torch.nn import Sequential as Seq, Linear, ReLU
from Networks import *
from Data_Functions import *
from matplotlib.animation import FuncAnimation
from Utils import *
from Subgrid_Funcs import *
import torch.distributed as dist

from Parallel import *
from torch.utils.data import Dataset, DataLoader
import os 
import sys
import random

Here is a set of functions and use-case snippets to help give an overview of where we would most need guidance in optimizing our code. The biggest need at the moment is in managing data throughout the training process, particularly as we aim to scale up the quantity of available training data.

I think there are a few primary areas where we become bottlenecked with data:

1. Loading data into RAM prior to training. Since the data that we have is large and stored in Xarray, the current process involves loading the full dataset into RAM prior to training. This is already a large memory cost, particularly when we train across multiple recurrent passes (further bottleneck details below). It would be helpful to get advice on addressing this memory cost, or on better practices for training from data on disk (particularly how to strike a balance between storing a copy of highly preprocessed data and efficiency).
2. When constructing the dataset, we build copies of the data for each recurrent pass, i.e. the inputs for a given step, n, would be $\Phi(t=n\Delta t:(N_{samples}+n)\Delta t)$. It would probably be more memory efficient to store a single copy of values from $0:(N_{samples}+N_{steps})\Delta t$, where $N_{samples}$ is the total number of training samples and $N_{steps}$ is the number of recurrent passes during training. I imagine that using pointers in place of copies would be better, but I am not sure quite how to go about that.
3. At the moment, when using multiple GPUs, I load a chunk of the dataset onto each GPU. We use staggered data in each chunk so that each GPU has access to data throughout the dataset. This is to avoid the bottleneck of having to load all of the data linearly and speeds up the code significantly. I am not sure if this is best practice or if the adjustments to the other points may allow for a better split across GPUs.
4. We are moving to synthesize data across several sources. For this we want to be able to train on all of the sources simultaneously, but it is not clear how best to combine them given the previous constraints.

In [1]:
# function for loading data from zarr stores

# input_vars - variables given as input to the network (usually the same as the output)
# extra_vars - additional boundary or residual values given to the network (these are replaced with ground truth during rollout)
# output_vars - variables predicted by the network (usually the input at a future time step)
# lag - number of time steps for predictions (our time step is 1 day)
# lat - latitude bounds for the region to consider
# lon - longitude bounds for the region to consider
# Nb - the number of lateral boundary points for the data
# run_type - forcing condition to use (can be "" for preindustrial control or "2x" for CO2 doubling experiment)

def gen_data_025_lateral(input_vars,extra_vars,output_vars,lag,lat,lon,Nb=2,run_type=""):
    """Open the ocean and atmosphere zarr stores and pick out the requested fields.

    The two stores are restricted to their common time range, merged, and cut
    to the requested latitude/longitude box. Nothing is converted to NumPy
    here; the returned objects are whatever ``xr.open_zarr`` produced after
    selection.

    Args:
        input_vars: short variable names (keys of ``var_dict`` below) used as
            network inputs.
        extra_vars: short variable names appended to the extra inputs as-is.
        output_vars: short variable names used as prediction targets.
        lag: number of leading time indices dropped from every output field
            (``field[lag:]``).
        lat: two-element sequence, lower and upper bound for ``yu_ocean``.
        lon: two-element sequence, lower and upper bound for ``xu_ocean``.
        Nb: width of the border that is kept in the boundary copies of the
            inputs; everything inside ``[Nb:-Nb, Nb:-Nb]`` is zeroed.
            NOTE(review): with ``Nb=0`` the slice ``0:-0`` is empty, so
            nothing is zeroed -- confirm whether that is intended.
        run_type: suffix selecting the store; ``""`` uses the plain store
            names, anything else is appended as ``"_" + run_type``.

    Returns:
        ``(inputs, extra_in, outputs)`` -- three tuples of DataArrays.
        ``extra_in`` holds the ``extra_vars`` fields followed by one
        boundary-only copy of each input field.

    Raises:
        KeyError: if a requested short name is not a key of ``var_dict``.
    """
    # Imported here because the file never imports dask explicitly; relying on
    # a star import to provide the name is fragile.
    import dask

    var_dict = {"um":"u_mean","vm":"v_mean","Tm":"T_mean",
                "ur":"u_res","vr":"v_res","Tr":"T_res",
                "u":"u","v":"v","T":"T",
                "tau_u":"tau_u","tau_v":"tau_v","tau":"tau",
                "t_ref":"t_ref"}

    suffix = "_" + run_type if run_type != "" else ""
    ocean_path = "/scratch/as15415/Data/Emulation_Data/Global_Ocean_025deg" + suffix + ".zarr"
    atmos_path = "/scratch/as15415/Data/Emulation_Data/Data_Atmos_025_deg" + suffix + ".zarr"

    with dask.config.set(**{'array.slicing.split_large_chunks': True}):
        data = xr.open_zarr(ocean_path).sortby("time")

        data_atmos = xr.open_zarr(atmos_path)
        try:
            # Best effort: some stores carry stray "xu_ocean"/"T_mean" entries
            # and need their "lon" coordinate replaced by the ocean grid's.
            data_atmos = data_atmos.drop_vars(["xu_ocean","T_mean"]).assign_coords({"lon":data.xu_ocean.data})
        except Exception:
            # Fall back to the store exactly as opened (data_atmos is only
            # rebound when the whole expression above succeeds).
            pass
        data_atmos = data_atmos.sortby("time")

        data_atmos = data_atmos.rename_dims({"lat":"yu_ocean","lon":"xu_ocean"})
        data_atmos = data_atmos.rename({"lat":"yu_ocean","lon":"xu_ocean"})

        # Restrict both datasets to the time range they have in common.
        data = data.sel(time=slice(data_atmos.time[0],data_atmos.time[-1]))
        data_atmos = data_atmos.sel(time=slice(data.time[0],data.time[-1]))

    data = xr.merge([data,data_atmos])

    data = data.sel(yu_ocean=slice(lat[0],lat[1]),xu_ocean=slice(lon[0],lon[1]))

    inputs = tuple(data[var_dict[var]] for var in input_vars)

    extra_in = [data[var_dict[var]] for var in extra_vars]
    for var in input_vars:
        boundary = data[var_dict[var]].copy(deep=True)
        # Multiplying the first time slice by 0.0 (instead of assigning 0)
        # keeps NaN wherever that slice is NaN, since 0.0 * NaN is NaN.
        boundary[:,Nb:-Nb,Nb:-Nb]  = 0.0 *boundary[0,Nb:-Nb,Nb:-Nb]
        extra_in.append(boundary)
    extra_in = tuple(extra_in)

    outputs = tuple(data[var_dict[var]][lag:] for var in output_vars)

    return inputs, extra_in, outputs


# function for generating a set of lagged inputs
# step - index of the recurrent pass; s and e are both shifted forward by lag*step
# s - starting index that the dataset is indexing from
# e - ending index that the dataset is indexing from (note, the data may go past this index
#    when including outputs or future recurrent passes)
# interval - subsampling between training snapshots
# lag - number of time steps for predictions (our time step is 1 day)
# hist - number of previous time steps included in the network input (currently only use hist = 1)
# inputs - array of xarray DataArrays for each input variable as inputs are defined above
# extra_in - array of xarray DataArrays for each boundary or residual variable given to the network (these are replaced with ground truth during rollout)



def gen_data_in(step,s,e,interval,lag,hist,inputs,extra_in):
    """Assemble the network input array for one recurrent pass.

    The window ``[s:e:interval]`` is shifted forward by ``lag*step``. Every
    field in ``inputs`` and then every field in ``extra_in`` is sliced with
    that window, converted with ``.to_numpy()`` and stacked along a new last
    axis. For ``hist > 0`` the ``inputs`` fields are additionally sliced at
    windows moved back by ``hist*lag, (hist-1)*lag, ..., lag`` and appended
    along axis 3 (so this path assumes the stacked array is 4-D).

    Returns the stacked NumPy array.
    """
    offset = lag*step
    start = s + offset
    stop = e + offset

    def _window(field, back=0):
        # Slice one field at the (optionally back-shifted) window and load it.
        return field[start-back:stop-back:interval].to_numpy()

    channels = [_window(field) for field in inputs]
    channels.extend(_window(field) for field in extra_in)
    data_in = np.stack(channels, -1)

    for i in range(hist):
        back = (hist-i)*lag
        lagged = [np.expand_dims(_window(field, back), -1) for field in inputs]
        data_in = np.concatenate((data_in, *lagged), axis=3)

    return data_in


# function for generating the corresponding set of lagged outputs
# step - index of the recurrent pass; s and e are both shifted forward by lag*step
# s - starting index that the dataset is indexing from
# e - ending index that the dataset is indexing from (note, the data may go past this index
#    when including outputs or future recurrent passes)
# lag - number of time steps for predictions (our time step is 1 day)
# interval - subsampling between training snapshots
# outputs - array of xarray DataArrays for each output variable as outputs are defined above

def gen_data_out(step,s,e,lag,interval,outputs):
    """Assemble the target array for one recurrent pass.

    The window ``[s:e:interval]`` is shifted forward by ``lag*step``; every
    field in ``outputs`` is sliced with it, converted with ``.to_numpy()``
    and stacked along a new last axis.

    Note that ``lag`` comes before ``interval`` here, the reverse of
    ``gen_data_in``.
    """
    offset = lag*step
    window = slice(s + offset, e + offset, interval)
    return np.stack([field[window].to_numpy() for field in outputs], -1)




In [None]:
# sample code for loading a dataset for a region of the globe


#select the inputs and additional boundary information

# key into inpt_dict: "3" selects um, vm, Tm
exp_num_in = "3"
# key into extra_dict: "12" selects tau_u, tau_v, t_ref
exp_num_extra = "12"
# key into out_dict: "2" selects um, vm, Tm
exp_num_out = "2"

#Choose region for training
region = "Gulf_Stream_Ext"

#Choose subsampling interval (we use 1 in all our cases)
interval = 1

#Choose number of samples for training and validation
N_samples = 4000
N_val = 300


# choose the number of previous steps included in the input (we use hist = 0, this is from previous exploration
# and kept for potential future exploration)
hist = 0
# lag - number of time steps between predictions (our time step is 1 day)
lag = 1
# set number of recurrent passes used during training
steps = 4

# set number of boundary points given during training
Nb = 4

# get domain boundaries for the region
if region == "Gulf_Stream_Ext":
    lat = [27, 50]
    lon = [-82,-35]
else:
    # Fail here rather than with a NameError on lat/lon further down.
    raise ValueError("unknown region: " + repr(region))


# index ranges: training starts after the history window, validation follows training
s_train = lag*hist
e_train = s_train + N_samples*interval
e_test = e_train + interval*N_val



device = "cpu"


inpt_dict = {"1":["um","vm"],"2":["um","vm","ur","vr"],"3":["um","vm","Tm"],
            "4":["um","vm","ur","vr","Tm","Tr"],"5":["ur","vr"],"6":["ur","vr","Tr"],
            "7":["Tm"],"8":["Tm","Tr"],"9":["u","v"],"10":["u","v","T"],
            "11":["tau_u","tau_v"],"12":["tau_u","tau_v","t_ref"]} 
extra_dict = {"1":["ur","vr"],"2":["ur","vr","Tm"],
            "3":["Tm"],"4":["ur","vr","Tm","Tr"],"5":[],"6":["um","vm"],
             "7":["um","vm","Tm"], "8": ["um","vm","Tm","Tr"],
              "9":["ur","vr","tau_u","tau_v"],"10":["tau_u","tau_v"],
              "11":["t_ref"],"12":["tau_u","tau_v","t_ref"],
             "13":["ur","vr","Tr","tau_u","tau_v","t_ref"]} 
out_dict = {"1":["um","vm"],"2":["um","vm","Tm"],"3":["ur","vr"],
           "4":["ur","vr","Tr"],"5":["u","v"],"6":["u","v","T"]}

# grid cell areas for the selected region
grids = xr.open_dataset('/scratch/zanna/data/CM2_grids/Grid_cm25_Vertices.nc')
grids = grids.sel({"yu_ocean":slice(lat[0],lat[1]),"xu_ocean":slice(lon[0],lon[1])})

area = torch.from_numpy(grids["area_C"].to_numpy()).to(device=device)

# resolve the experiment numbers to lists of short variable names
inputs = inpt_dict[exp_num_in]
extra_in = extra_dict[exp_num_extra]
outputs = out_dict[exp_num_out]

str_in = "".join(name + "_" for name in inputs)
str_ext = "".join(name + "_" for name in extra_in)
str_out = "".join(name + "_" for name in outputs)

print("inputs: " + str_in)
print("extra inputs: " + str_ext)
print("outputs: " + str_out)

# channel counts; N_extra also counts the boundary copy of each input
# that gen_data_025_lateral appends to the extra inputs
N_atm = len(extra_in)
N_in = len(inputs)
N_extra = N_atm + N_in
N_out = len(outputs)

num_in = int((hist+1)*N_in + N_extra)


# from here on inputs/extra_in/outputs hold DataArrays instead of names
inputs, extra_in, outputs = gen_data_025_lateral(inputs,extra_in,outputs,lag,lat,lon,Nb)

# wet mask: 1.0 where no input field is NaN at the first time index, 0.0 elsewhere
wet = xr.zeros_like(inputs[0][0])
# The loop variable must not be called `data`: that would rebind the
# module-level alias from `import torch.utils.data as data`.
for field in inputs:
    wet +=np.isnan(field[0])
wet = np.isnan(xr.where(wet==0,np.nan,0))
wet = np.nan_to_num(wet.to_numpy())
wet = torch.from_numpy(wet).type(torch.float32).to(device="cpu")

# one (input, output) array pair per recurrent pass, each shifted by lag*pass
data_in_train = []
data_out_train = []
for step_idx in range(steps):
    data_in_train.append(gen_data_in(step_idx,s_train,e_train,
                                     interval,lag,
                                     hist,inputs,extra_in))
    data_out_train.append(gen_data_out(step_idx,s_train,e_train,
                                       lag,interval,
                                       outputs))

train_data = data_CNN_steps_Lateral(data_in_train,data_out_train,
                                        steps,wet,N_atm,Nb,device=device)