In [1]:
import pandas as pd
import numpy as np
import xarray as xr

import matplotlib.pyplot as plt
import seaborn as sns

import gcsfs
from datetime import datetime as dt
import cftime
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/srv/conda/envs/notebook'
import sys

#from utils import load_dataset, split_dataset, split_vars

# Google Cloud Storage filesystem handle; later cells call fs.get_mapper(...) on it
# to open the zarr stores listed below.
fs = gcsfs.GCSFileSystem()
fs.ls("gs://leap-persistent-ro/sungdukyu") # List files in the bucket where the E3SM-MMF dataset is stored

['leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.grid-info.zarr',
 'leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.train.input.zarr',
 'leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.train.output.zarr',
 'leap-persistent-ro/sungdukyu/testing']

## Loading Data: .zarr --> xarray

In [2]:
# Variable names to read from the train.input / train.output zarr stores.
# state_q0001 is requested from both; load_vars_xarray stores output variables
# under an 'out_' prefix, so the two copies do not collide.
input_vars = ['cam_in_ASDIR', 'pbuf_LHFLX', 'state_q0001']
output_vars = ['cam_out_NETSW', 'cam_out_PRECC', 'state_q0001']

In [3]:
%%time
def load_vars_xarray(input_vars, output_vars, downsample=True, chunks = False):
    """Open the E3SM-MMF train input/output zarr stores and merge selected variables.

    Parameters
    ----------
    input_vars : list of str
        Variables taken from the ``train.input`` store (kept under their own names).
    output_vars : list of str
        Variables taken from the ``train.output`` store; each is stored in the
        result as ``'out_' + name``.
    downsample : bool, default True
        When true, keep every 72nd sample starting at index 36 (noted by the
        original author as one sample per day).
    chunks : bool, default False
        When true, ``chunks={}`` is passed to ``xr.open_dataset``; otherwise the
        argument is omitted.

    Returns
    -------
    xarray.Dataset
        The merged dataset with the ``sample`` dimension replaced by a ``time``
        coordinate of ``cftime.DatetimeNoLeap`` values built from ``ymd``/``tod``.
    """
    # raw files, not interpolated according to Yu suggestion
    open_kwargs = {'engine': 'zarr'}
    if chunks:
        open_kwargs['chunks'] = {}

    inp = xr.open_dataset(
        fs.get_mapper('leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.train.input.zarr'),
        **open_kwargs
    )
    output = xr.open_dataset(
        fs.get_mapper('leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.train.output.zarr'),
        **open_kwargs
    )

    # Subsample before selecting variables so everything downstream is smaller.
    if downsample:
        inp = inp.isel(sample=np.arange(36, len(inp.sample), 72))
        output = output.isel(sample=np.arange(36, len(output.sample), 72))

    ds = inp[input_vars]
    for name in output_vars:
        ds['out_' + name] = output[name]

    def to_noleap(ymd, tod):
        # ymd is a YYYYMMDD integer, tod is seconds since midnight; seconds are dropped.
        return cftime.DatetimeNoLeap(ymd//10000, ymd%10000//100, ymd%10000%100, tod // 3600, tod%3600 // 60)

    stamps = pd.DataFrame({"ymd": inp.ymd, "tod": inp.tod})
    stamps = stamps.apply(lambda row: to_noleap(row.ymd, row.tod), axis=1)

    # Overwrite the integer sample index with the timestamps, then rename it to 'time'.
    ds['sample'] = list(stamps)
    ds = ds.rename({'sample': 'time'})
    ds = ds.assign_coords({'ncol': ds.ncol})
    return ds

ds = load_vars_xarray(input_vars, output_vars)

CPU times: user 2.63 s, sys: 149 ms, total: 2.78 s
Wall time: 3.16 s


In [4]:
# sys.getsizeof only accounts for the Python wrapper object itself, not the array
# data it refers to, so these numbers do not reflect how much data the dataset holds.
print(sys.getsizeof(ds))
print(sys.getsizeof(ds.state_q0001))

112
96


In [5]:
def split_vars(var_list, out=False, dataset=None):
    """Partition variable names into non-leveled and leveled groups.

    A variable counts as "leveled" when it has more than two dimensions in
    the dataset; everything else goes in the first list.

    Parameters
    ----------
    var_list : iterable of str
        Base variable names (without any ``'out_'`` prefix).
    out : bool, default False
        When true, each name is prefixed with ``'out_'`` before lookup, and the
        prefixed name is what is returned.
    dataset : mapping-like, optional
        Object indexed by variable name whose items expose ``.shape``. When
        omitted, the notebook-global ``ds`` is used, which keeps existing
        two-argument calls working.

    Returns
    -------
    tuple of (list of str, list of str)
        ``(non_leveled, leveled)`` names, in input order.
    """
    if dataset is None:
        # Backward-compatible fallback for callers that rely on the global.
        dataset = ds
    prefix = 'out_' if out else ''
    v = []
    leveled = []
    for var in var_list:
        name = prefix + var
        if len(dataset[name].shape) > 2:
            leveled.append(name)
        else:
            v.append(name)
    return (v, leveled)

def split_input_output(ds):
    """Split a dataset into its input and output variables.

    Any data variable whose name begins with ``'out'`` is treated as an
    output; all others are inputs. Returns ``(ds[inputs], ds[outputs])``.
    """
    names = list(ds.data_vars)
    out = [name for name in names if name[:3] == 'out']
    inp = [name for name in names if name[:3] != 'out']
    return ds[inp], ds[out]

## Loading Minibatches: xarray --> batch

### Grab Xarray Chunks: xarray --> xarray batch sized

In [6]:
# Repeats the variable lists from the earlier cell and reloads the dataset
# before the batching experiments below.
input_vars = ['cam_in_ASDIR', 'pbuf_LHFLX', 'state_q0001']
output_vars = ['cam_out_NETSW', 'cam_out_PRECC', 'state_q0001']

ds = load_vars_xarray(input_vars, output_vars)
ds

In [7]:
#ds = ds.chunk(chunks={'time':292, 'ncol' : 384})
#ds.pbuf_LHFLX.data.blocks[3]

#### Wait, why do we need Dask??

Assumptions going forward: `ds` is an xarray Dataset with dimensions `time` and `ncol`, holding a mix of leveled (per-level) and non-leveled variables. We ignore spatial and temporal linkages and treat every (time, ncol) vector as i.i.d. A flat sample index runs along the `ncol` axis first and then along `time`.

In [8]:
# Total number of (time, ncol) samples addressable by a flat index.
max_i = ds.time.size * ds.ncol.size # number of vectors in S
print(max_i)
import random
# randrange excludes the upper bound, so i is always a valid flat index in [0, max_i);
# randint(0, max_i) is inclusive and could return the out-of-range value max_i.
# NOTE(review): no seed is set, so the sampled index differs between runs.
i = random.randrange(0, max_i)
i

1121280


365183

In [9]:
def get_item(ds, index):
    """Select one sample from ``ds`` using a flat index over the (time x ncol) grid.

    The flat index satisfies ``index == t * ds.ncol.size + col``, so ``ncol``
    varies fastest. Both dimensions are selected with scalars. Raises
    ``AssertionError`` when ``index`` is not below ``time.size * ncol.size``.
    """
    n_time, n_col = ds.time.size, ds.ncol.size
    assert index < n_time * n_col, "Index is outside of range"
    # Unravel the linear index into (time, column) positions.
    t, col = divmod(index, n_col)
    return ds.isel(time=t, ncol=col)

In [10]:
%%time
get_item(ds, max_i-1)

CPU times: user 1.24 ms, sys: 0 ns, total: 1.24 ms
Wall time: 1.23 ms


In [11]:
ds

In [12]:
def get_batch(ds, batch_num, batch_size = 32, dim = 'ncol'):
    """Return the ``batch_num``-th contiguous batch of ``ds`` along ``dim``.

    Batches are numbered with a flat index: consecutive batch numbers walk
    along ``dim`` in steps of ``batch_size``, then advance one position along
    the other dimension. When ``batch_size`` does not divide the size of
    ``dim`` evenly, the last batch along ``dim`` is shorter; batches never
    overlap and always start on a multiple of ``batch_size``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``time`` and ``ncol`` dimensions.
    batch_num : int
        Flat batch index.
    batch_size : int, default 32
        Number of entries along ``dim`` per batch. The default suits
        ``ncol`` = 384 = 3 * 2**7, which splits evenly over powers of two.
    dim : {'ncol', 'time'}, default 'ncol'
        Dimension that is sliced; the other one is indexed with a scalar.

    Raises
    ------
    ValueError
        If ``dim`` is not ``'ncol'`` or ``'time'``, or ``batch_size`` is not positive.
    """
    if dim not in ('ncol', 'time'):
        raise ValueError(f"dim must be 'ncol' or 'time', got {dim!r}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    dim_size = ds[dim].size
    # Integer ceil-division: a fractional batch count would shift the slice
    # boundaries and make batches overlap when batch_size does not divide dim_size.
    n_batch = -(-dim_size // batch_size)
    other_dim_batch, dim_batch = divmod(batch_num, n_batch)
    start = dim_batch * batch_size
    stop = min(start + batch_size, dim_size)

    if dim == 'ncol':
        print(f"ncol from {start}-{stop}; time={other_dim_batch}")
        return ds.isel(ncol=slice(start, stop), time=other_dim_batch)
    return ds.isel(time=slice(start, stop), ncol=other_dim_batch)

In [13]:
%%time
get_batch(ds, 345)

ncol from 288-320; time=28
CPU times: user 957 µs, sys: 0 ns, total: 957 µs
Wall time: 966 µs


### Stacking: xarray --> numpy array... BOTTLENECK 4 seconds! :(

In [14]:
# manual into minibatch
ds = load_vars_xarray(input_vars, output_vars, chunks=False)
batch = get_batch(ds, 0)
print(sys.getsizeof(batch))
print(sys.getsizeof(batch.cam_in_ASDIR.data))
X, Y = split_input_output(batch)
batch

ncol from 0-32; time=0
112
112


In [18]:
%%time
arr = X.to_stacked_array("v", sample_dims=["ncol"])
arr.shape

CPU times: user 682 ms, sys: 375 ms, total: 1.06 s
Wall time: 1.34 s


(32, 62)

In [16]:
print(sys.getsizeof(arr))
#print(sys.getsizeof(arr.data)) # 16 Megabytes
#print(sys.getsizeof(arr.values)) # 16 Megabytes

96


In [17]:
arr.shape # YAY - batch_size x in_dims! 

(32, 62)

#### Tried front-loading the stacking for the whole dataset, but it crashed because the result is many GB

In [None]:
#ds = ds.stack({'batch':{'ncol'}})
#ds = ds.to_stacked_array("mlvar", sample_dims=["batch"], name='mli')

In [23]:

# X, Y = split_dataset(ds)
#X.to_stacked_array("v", sample_dims=["ncol"])

KeyboardInterrupt: 

In [201]:
xarr = ds.isel(time=5, ncol=slice(14, 24))
xarr

In [202]:
%%time
xarr.to_stacked_array("v", sample_dims=["ncol"])

CPU times: user 1.78 s, sys: 1.01 s, total: 2.79 s
Wall time: 3.69 s


## Conversions of batches into various forms

### Generator

In [None]:
def gen():
    """Yield one ``(input, output)`` pair of arrays per file in ``filelist``.

    For each ``mli`` file the matching ``mlo`` file is located by replacing
    ``'.mli.'`` with ``'.mlo.'`` in the path. Two tendency variables are
    derived, both datasets are normalized/scaled, and each is stacked into a
    2-D array with ``batch`` as the sample dimension.

    NOTE(review): relies on ``filelist``, ``vars_mli``, ``vars_mlo``,
    ``mli_mean``, ``mli_max``, ``mli_min`` and ``mlo_scale``, none of which are
    defined in the cells visible here.
    NOTE(review): the opened datasets are never explicitly closed.
    """
    for file in filelist:
        # read mli
        ds = xr.open_dataset(file, engine='netcdf4')
        ds = ds[vars_mli]

        # read mlo
        dso = xr.open_dataset(file.replace('.mli.','.mlo.'), engine='netcdf4')

        # make mlo variables: ptend_t and ptend_q0001
        # (difference between mlo and mli state divided by 1200; presumably the
        # time step in seconds -- TODO confirm)
        dso['ptend_t'] = (dso['state_t'] - ds['state_t'])/1200 # T tendency [K/s]
        dso['ptend_q0001'] = (dso['state_q0001'] - ds['state_q0001'])/1200 # Q tendency [kg/kg/s]
        dso = dso[vars_mlo]

        # normalization, scaling
        ds = (ds-mli_mean)/(mli_max-mli_min)
        dso = dso*mlo_scale

        # stack
        #ds = ds.stack({'batch':{'sample','ncol'}}) # this line was for data files that include 'sample' dimension
        ds = ds.stack({'batch':{'ncol'}})
        ds = ds.to_stacked_array("mlvar", sample_dims=["batch"], name='mli')
        #dso = dso.stack({'batch':{'sample','ncol'}})
        dso = dso.stack({'batch':{'ncol'}})
        dso = dso.to_stacked_array("mlvar", sample_dims=["batch"], name='mlo')

        yield (ds.values, dso.values) # generating a tuple of (input, output)

### Torch Dataloader

In [18]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

In [24]:
def get_item(ds, index):
    """Select one sample from ``ds`` by flat index, keeping ``ncol`` as a length-1 dimension.

    The flat index satisfies ``index == t * ds.ncol.size + col``. The column
    is selected with a one-element list rather than a scalar, so the result
    still has an ``ncol`` dimension. No explicit range check is done here;
    this definition replaces the earlier ``get_item`` in the notebook.
    """
    n_col = ds.ncol.size
    # Unravel the linear index into (time, column) positions.
    t, col = divmod(index, n_col)
    return ds.isel(time=t, ncol=[col])

In [25]:
def xarr_to_num(batch): # this is what needs fixing!!
    """Convert an xarray batch into a pair of stacked arrays ``(X, Y)``.

    The batch is split into input and output variables, and each half is
    stacked along a new ``"v"`` dimension with ``ncol`` as the sample
    dimension. The ``.data`` of each stacked array is returned. The author
    flags this step as the performance bottleneck.
    """
    stacked_x, stacked_y = (
        part.to_stacked_array("v", sample_dims=["ncol"])
        for part in split_input_output(batch)
    )
    return stacked_x.data, stacked_y.data

In [26]:
%%time
# Time the full path for a single sample: reload, select by flat index, stack.
i = 5
ds = load_vars_xarray(input_vars, output_vars, chunks=False)
item = get_item(ds, i)
# arr is the (X, Y) tuple returned by xarr_to_num, not a single array.
arr = xarr_to_num(item)
arr

CPU times: user 2.58 s, sys: 1.23 s, total: 3.81 s
Wall time: 6.22 s


(array([[2.36712391e-02, 1.59473319e+01, 1.52635864e-06, 1.51713157e-06,
         1.50599154e-06, 1.50305306e-06, 1.49129121e-06, 1.48804727e-06,
         1.50053707e-06, 1.51420816e-06, 1.53359911e-06, 1.51273951e-06,
         1.37479784e-06, 1.25656976e-06, 1.21852404e-06, 1.22556620e-06,
         1.23969603e-06, 1.26160567e-06, 1.27633541e-06, 1.29765723e-06,
         2.09165354e-06, 5.23632810e-06, 9.14774245e-06, 1.22523793e-05,
         1.78239034e-05, 3.08388750e-05, 4.85934562e-05, 7.39302969e-05,
         1.11556188e-04, 1.63352869e-04, 2.31616018e-04, 3.19801025e-04,
         4.31475370e-04, 5.68454119e-04, 7.40290604e-04, 9.39499661e-04,
         1.16504894e-03, 1.43078763e-03, 1.76282431e-03, 2.11967426e-03,
         2.48124999e-03, 2.86422676e-03, 3.12534903e-03, 3.41342055e-03,
         3.68315439e-03, 3.94559087e-03, 4.22140036e-03, 4.50410976e-03,
         4.96764525e-03, 5.57184314e-03, 5.94211258e-03, 6.17732949e-03,
         6.43215444e-03, 6.69945628e-03, 6.95327891

In [57]:
class MyDataset(Dataset):
    """Map-style torch ``Dataset`` over the flattened (time x ncol) grid.

    Each index addresses one (time, ncol) sample, fetched with ``get_item``.
    The input variables and the ``out``-prefixed variables are each stacked
    into one feature vector and returned as float32 tensors.
    """

    def __init__(self, input_vars, output_vars, chunk_size = 'auto'):
        """Load the data and set up per-instance state.

        Parameters
        ----------
        input_vars, output_vars : list of str
            Variable names forwarded to ``load_vars_xarray``.
        chunk_size : default 'auto'
            Forwarded to ``Dataset.chunk(chunks=...)`` for the rechunked copy
            kept in ``self.chunks``.
        """
        self.ds = load_vars_xarray(input_vars, output_vars, chunks=False)
        # Rechunk this instance's own dataset rather than a notebook-global
        # `ds`, so the class also works on a fresh kernel and never picks up
        # an unrelated dataset.
        self.chunks = self.ds.chunk(chunks=chunk_size)
        #self.vars, self.leveled_vars = split_vars(input_vars)
        #self.out_vars, self.out_leveled = split_vars(output_vars, out=True)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        ## maybe store global dataset mean and std here so that not taking batch means
        # Placeholder statistics (zero mean, unit std); 62 is the stacked
        # input width seen earlier in the notebook for the default variables.
        self.variable_mean = np.zeros((1, 62))
        self.variable_std = np.ones((1, 62))

    def __len__(self):
        """Number of (time, ncol) samples."""
        return( self.ds.time.size * self.ds.ncol.size )

    def __getitem__(self, idx):
        """Return the ``(X, Y)`` float32 tensors for flat index ``idx``.

        Both tensors are created directly on ``self.device``.
        """
        xarr = get_item(self.ds, idx)
        X, Y = split_input_output(xarr)
        X = X.to_stacked_array("v", sample_dims=["ncol"]).data
        Y = Y.to_stacked_array("v", sample_dims=["ncol"]).data
        #X = (batch - batch.mean(axis=0)) / batch.std(axis=0)
        X = torch.tensor(X, device=self.device, dtype=torch.float32)
        #batch = np.concatenate(dimvars, axis=0).T
        Y = torch.tensor(Y, device=self.device, dtype=torch.float32)
        return(X, Y)

In [58]:
# Build the dataset and a shuffling loader that draws one sample per batch.
d = MyDataset(input_vars, output_vars)
dataloader = DataLoader(d, batch_size=1, shuffle=True)

In [59]:
d[0]

(tensor([[2.7105e-02, 7.8169e+01, 1.4841e-06, 1.4736e-06, 1.4525e-06, 1.4392e-06,
          1.4281e-06, 1.4339e-06, 1.4447e-06, 1.4151e-06, 1.4082e-06, 1.3991e-06,
          1.2833e-06, 1.2035e-06, 1.2033e-06, 1.2150e-06, 1.2564e-06, 1.2815e-06,
          1.2821e-06, 1.2917e-06, 1.8498e-06, 3.0511e-06, 5.3337e-06, 8.8857e-06,
          1.4328e-05, 2.2197e-05, 3.1389e-05, 5.2703e-05, 8.1712e-05, 1.2874e-04,
          1.8742e-04, 2.6024e-04, 3.5626e-04, 4.6913e-04, 6.0638e-04, 7.6703e-04,
          9.1635e-04, 1.0284e-03, 1.1332e-03, 1.2918e-03, 1.4838e-03, 1.6882e-03,
          2.1142e-03, 2.4468e-03, 2.8813e-03, 2.9321e-03, 3.2465e-03, 3.6840e-03,
          4.1014e-03, 4.9582e-03, 6.2123e-03, 7.1295e-03, 7.6929e-03, 8.2727e-03,
          9.1327e-03, 9.9513e-03, 1.0124e-02, 1.0189e-02, 1.0222e-02, 1.0260e-02,
          1.0309e-02, 1.0392e-02]], device='cuda:0'),
 tensor([[8.3459e+02, 0.0000e+00, 1.4841e-06, 1.4736e-06, 1.4525e-06, 1.4392e-06,
          1.4281e-06, 1.4339e-06, 1.4447e-06

In [60]:
%%time
X_sample, Y_sample = next(iter(dataloader))

CPU times: user 2.69 s, sys: 1.1 s, total: 3.79 s
Wall time: 4.21 s


### TensorFlow dataset

In [None]:
# Wrap the `gen` generator as a tf.data.Dataset yielding (input, output) float64 pairs
# with 124 and 128 columns respectively and a variable number of rows.
# NOTE(review): `tf` is not imported in any cell visible here.
# NOTE(review): newer TensorFlow releases document `output_signature` as the replacement
# for `output_types`/`output_shapes` -- confirm against the installed version.
dataset = tf.data.Dataset.from_generator(gen, 
  output_types=(tf.float64, tf.float64),
  output_shapes=((None,124),(None,128))
 )