In [75]:
import pandas as pd
import numpy as np
import xarray as xr

import matplotlib.pyplot as plt
import seaborn as sns

import gcsfs
from datetime import datetime as dt
import cftime
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/srv/conda/envs/notebook'
import sys

#from utils import load_dataset, split_dataset, split_vars

# One GCS filesystem handle, reused by the zarr loaders below via fs.get_mapper
fs = gcsfs.GCSFileSystem()
fs.ls("gs://leap-persistent-ro/sungdukyu") # List files in the bucket where the E3SM-MMF dataset is stored

['leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.grid-info.zarr',
 'leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.train.input.zarr',
 'leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.train.output.zarr',
 'leap-persistent-ro/sungdukyu/testing']

## Loading Data: .zarr --> xarray

In [76]:
# Variable names to pull from the train input / output zarr stores.
# load_vars_xarray stores each output variable as 'out_<name>', so the same
# field (here state_q0001) can appear on both sides without a name clash.
input_vars = ['cam_in_ASDIR', 'pbuf_LHFLX', 'state_q0001']
output_vars = ['cam_out_NETSW', 'cam_out_PRECC', 'state_q0001']

In [77]:
def load_vars_xarray(input_vars, output_vars, downsample=True, chunks = True, lat=None, lon=None):
    """Load selected E3SM-MMF train inputs/outputs into a single xarray Dataset.

    Input variables keep their names; output variables are stored as
    ``'out_' + name`` so the same field (e.g. ``state_q0001``) can appear on
    both sides. The integer ``sample`` dimension is replaced by a cftime
    ``time`` coordinate decoded from the input store's ``ymd``/``tod`` fields,
    and per-column ``lat``/``lon`` coordinates are attached on ``ncol``.

    Parameters
    ----------
    input_vars, output_vars : list of str
        Variable names in the train input / output zarr stores.
    downsample : bool, default True
        Reduce to one entry per 72-sample window ("every 1 day"). With
        ``chunks=True`` each window is averaged; otherwise the sample at offset
        36 of each window is kept.
    chunks : bool, default True
        Open the stores lazily with dask chunks of 720 samples.
    lat, lon : array-like, optional
        Per-column latitude / longitude. When omitted they are read from the
        grid-info zarr and rounded to 2 decimals (same as ``load_latlon``), so
        this function no longer depends on globals defined in a later cell.

    Returns
    -------
    xarray.Dataset
        Dimensions ``time`` and ``ncol`` (plus any level dimension of the
        selected variables), with ``lat``/``lon`` coordinates on ``ncol``.
    """
    # raw files, not interpolated according to Yu suggestion
    store = 'leap-persistent-ro/sungdukyu/E3SM-MMF_ne4'
    open_kwargs = {'engine': 'zarr'}
    if chunks:
        open_kwargs['chunks'] = {'sample': 720}
    inp = xr.open_dataset(fs.get_mapper(store + '.train.input.zarr'), **open_kwargs)
    output = xr.open_dataset(fs.get_mapper(store + '.train.output.zarr'), **open_kwargs)

    ds = inp[input_vars]
    for var in output_vars:
        ds['out_' + var] = output[var]

    if downsample:  # might as well do first
        samples_per_window, offset = 72, 36
        # Build the indices from the ORIGINAL sample count. The previous version
        # recomputed them from len(inp.sample) after `inp` had already been
        # subsampled, so the chunks=False branch kept almost no samples.
        keep = np.arange(offset, inp.sizes['sample'], samples_per_window)
        inp = inp.isel(sample=keep)  # only needed below for the time labels
        if chunks:
            print("Daily average")
            ds = ds.coarsen(sample=samples_per_window).mean()
        else:
            print("Noon each day")
            ds = ds.isel(sample=keep)

    # Decode ymd as YYYYMMDD and tod as seconds of the day (same arithmetic as before).
    times = [
        cftime.DatetimeNoLeap(ymd // 10000, ymd % 10000 // 100, ymd % 100,
                              tod // 3600, tod % 3600 // 60)
        for ymd, tod in zip(inp['ymd'].values.tolist(), inp['tod'].values.tolist())
    ]
    # rename sample to reformatted time column
    ds['sample'] = times
    ds = ds.rename({'sample': 'time'})
    ds = ds.assign_coords({'ncol': ds.ncol})

    if lat is None or lon is None:
        grid = xr.open_dataset(fs.get_mapper('gs://' + store + '.grid-info.zarr'), engine='zarr')
        if lat is None:
            lat = grid.lat.values.round(2)
        if lon is None:
            lon = grid.lon.values.round(2)
    ds = ds.assign_coords({'lat': ('ncol', np.asarray(lat)), 'lon': ('ncol', np.asarray(lon))})

    return ds

In [78]:
%%time
# Load the selected variables reduced to daily means (chunks=True -> lazy, dask-backed)
ds = load_vars_xarray(input_vars, output_vars, downsample=True, chunks=True)



Daily average
CPU times: user 1.61 s, sys: 147 ms, total: 1.75 s
Wall time: 2.55 s


In [79]:
# Inspect the Dataset; daily coarsening already happened inside load_vars_xarray (chunks=True)
ds

Unnamed: 0,Array,Chunk
Bytes,8.55 MiB,30.00 kiB
Shape,"(2920, 384)","(10, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.55 MiB 30.00 kiB Shape (2920, 384) (10, 384) Dask graph 292 chunks in 5 graph layers Data type float64 numpy.ndarray",384  2920,

Unnamed: 0,Array,Chunk
Bytes,8.55 MiB,30.00 kiB
Shape,"(2920, 384)","(10, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.55 MiB,30.00 kiB
Shape,"(2920, 384)","(10, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.55 MiB 30.00 kiB Shape (2920, 384) (10, 384) Dask graph 292 chunks in 5 graph layers Data type float64 numpy.ndarray",384  2920,

Unnamed: 0,Array,Chunk
Bytes,8.55 MiB,30.00 kiB
Shape,"(2920, 384)","(10, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,513.28 MiB,1.76 MiB
Shape,"(2920, 60, 384)","(10, 60, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 513.28 MiB 1.76 MiB Shape (2920, 60, 384) (10, 60, 384) Dask graph 292 chunks in 5 graph layers Data type float64 numpy.ndarray",384  60  2920,

Unnamed: 0,Array,Chunk
Bytes,513.28 MiB,1.76 MiB
Shape,"(2920, 60, 384)","(10, 60, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.55 MiB,30.00 kiB
Shape,"(2920, 384)","(10, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.55 MiB 30.00 kiB Shape (2920, 384) (10, 384) Dask graph 292 chunks in 5 graph layers Data type float64 numpy.ndarray",384  2920,

Unnamed: 0,Array,Chunk
Bytes,8.55 MiB,30.00 kiB
Shape,"(2920, 384)","(10, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.55 MiB,30.00 kiB
Shape,"(2920, 384)","(10, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.55 MiB 30.00 kiB Shape (2920, 384) (10, 384) Dask graph 292 chunks in 5 graph layers Data type float64 numpy.ndarray",384  2920,

Unnamed: 0,Array,Chunk
Bytes,8.55 MiB,30.00 kiB
Shape,"(2920, 384)","(10, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,513.28 MiB,1.76 MiB
Shape,"(2920, 60, 384)","(10, 60, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 513.28 MiB 1.76 MiB Shape (2920, 60, 384) (10, 60, 384) Dask graph 292 chunks in 5 graph layers Data type float64 numpy.ndarray",384  60  2920,

Unnamed: 0,Array,Chunk
Bytes,513.28 MiB,1.76 MiB
Shape,"(2920, 60, 384)","(10, 60, 384)"
Dask graph,292 chunks in 5 graph layers,292 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [80]:
def split_vars(var_list, out=False, data=None):
    """Partition variable names into single-level and multi-level ("leveled") ones.

    Parameters
    ----------
    var_list : list of str
        Variable names as passed to ``load_vars_xarray``.
    out : bool, default False
        If True, the names refer to output variables and are looked up with the
        ``'out_'`` prefix that ``load_vars_xarray`` adds.
    data : mapping of name -> array-like, optional
        Dataset to inspect. Defaults to the notebook-level ``ds`` for backward
        compatibility; pass it explicitly to avoid hidden global state.

    Returns
    -------
    (v, leveled) : tuple of lists
        Names (prefixed when ``out``) whose arrays have at most 2 dimensions,
        and names whose arrays have more than 2 dimensions.
    """
    # `ds` is only evaluated when no dataset was passed in.
    source = ds if data is None else data
    names = ['out_' + var if out else var for var in var_list]
    v = [name for name in names if len(source[name].shape) <= 2]
    leveled = [name for name in names if len(source[name].shape) > 2]
    return (v, leveled)

def split_input_output(ds):
    """Split a combined Dataset into (inputs, outputs).

    Output variables are recognised by the ``'out_'`` prefix that
    ``load_vars_xarray`` gives them; every other data variable is an input.
    Matching the full prefix (not just the first three characters "out") keeps
    an input whose name merely starts with "out" on the input side.

    Returns
    -------
    (ds[inputs], ds[outputs]), with variables in their original order.
    """
    names = list(ds.data_vars)
    outputs = [name for name in names if name.startswith('out_')]
    inputs = [name for name in names if not name.startswith('out_')]
    return (ds[inputs], ds[outputs])

### Integrate spatial information

In [81]:
def load_latlon(decimals=2):
    """Return per-column (lat, lon) arrays from the E3SM-MMF grid-info store.

    Parameters
    ----------
    decimals : int, default 2
        Number of decimals to round the coordinates to (2 matches the previous
        hard-coded behaviour).

    Returns
    -------
    (lat, lon) : tuple of numpy arrays
        The store's ``lat`` and ``lon`` values, rounded.
    """
    mapper = fs.get_mapper("gs://leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.grid-info.zarr")
    # The context manager releases the store once the values are materialised.
    with xr.open_dataset(mapper, engine='zarr') as ds_grid:
        lat = ds_grid.lat.values.round(decimals)
        lon = ds_grid.lon.values.round(decimals)
    return (lat, lon)

lat, lon = load_latlon()

#### Filter by latitude/longitude

In [82]:
def select_region(condition, lat=None, lon=None):
    """Return the ``ncol`` indices whose (lat, lon) satisfy ``condition``.

    Parameters
    ----------
    condition : callable
        Called once per column with a pandas row exposing ``.lat`` and
        ``.lon``; its result is interpreted as a boolean.
    lat, lon : array-like, optional
        Column coordinates. Unless both are given they are loaded with
        ``load_latlon``; pass them to avoid re-reading the grid store.

    Returns
    -------
    list of int
        Positional column indices for which ``condition`` is truthy.
    """
    if lat is None or lon is None:
        lat, lon = load_latlon()
    latlon = pd.DataFrame({"lat": lat, "lon": lon})
    mask = latlon.apply(condition, axis=1).astype(bool)
    return list(latlon.index[mask.to_numpy()])  # the indices of the matching latlons

def split_ds_by_area(ds, condition):
    """Split ``ds`` along ``ncol`` into (columns matching ``condition``, the rest).

    The grid is loaded once and ``condition`` is evaluated once per column, so
    the two parts are disjoint and together cover every column.
    """
    lat, lon = load_latlon()
    match = select_region(condition, lat, lon)
    matched = set(match)
    unmatch = [i for i in range(len(lat)) if i not in matched]
    return (ds.isel(ncol=match), ds.isel(ncol=unmatch))

In [68]:
# EXAMPLE: columns within 30 degrees of the equator (|lat| < 30) vs. the rest
f = lambda row : abs(row.lat) < 30

In [69]:
# Spatial hold-out: matching columns -> train, non-matching columns -> test
train, test = split_ds_by_area(ds, f)

In [70]:
# Columns with |lat| < 30 (192 columns in the output)
train

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.28 MiB 15.00 kiB Shape (2920, 192) (10, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  2920,

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.28 MiB 15.00 kiB Shape (2920, 192) (10, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  2920,

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,256.64 MiB,900.00 kiB
Shape,"(2920, 60, 192)","(10, 60, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 256.64 MiB 900.00 kiB Shape (2920, 60, 192) (10, 60, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  60  2920,

Unnamed: 0,Array,Chunk
Bytes,256.64 MiB,900.00 kiB
Shape,"(2920, 60, 192)","(10, 60, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.28 MiB 15.00 kiB Shape (2920, 192) (10, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  2920,

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.28 MiB 15.00 kiB Shape (2920, 192) (10, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  2920,

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,256.64 MiB,900.00 kiB
Shape,"(2920, 60, 192)","(10, 60, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 256.64 MiB 900.00 kiB Shape (2920, 60, 192) (10, 60, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  60  2920,

Unnamed: 0,Array,Chunk
Bytes,256.64 MiB,900.00 kiB
Shape,"(2920, 60, 192)","(10, 60, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [72]:
# Remaining columns, |lat| >= 30
test

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.28 MiB 15.00 kiB Shape (2920, 192) (10, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  2920,

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.28 MiB 15.00 kiB Shape (2920, 192) (10, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  2920,

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,256.64 MiB,900.00 kiB
Shape,"(2920, 60, 192)","(10, 60, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 256.64 MiB 900.00 kiB Shape (2920, 60, 192) (10, 60, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  60  2920,

Unnamed: 0,Array,Chunk
Bytes,256.64 MiB,900.00 kiB
Shape,"(2920, 60, 192)","(10, 60, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.28 MiB 15.00 kiB Shape (2920, 192) (10, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  2920,

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.28 MiB 15.00 kiB Shape (2920, 192) (10, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  2920,

Unnamed: 0,Array,Chunk
Bytes,4.28 MiB,15.00 kiB
Shape,"(2920, 192)","(10, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,256.64 MiB,900.00 kiB
Shape,"(2920, 60, 192)","(10, 60, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 256.64 MiB 900.00 kiB Shape (2920, 60, 192) (10, 60, 192) Dask graph 292 chunks in 6 graph layers Data type float64 numpy.ndarray",192  60  2920,

Unnamed: 0,Array,Chunk
Bytes,256.64 MiB,900.00 kiB
Shape,"(2920, 60, 192)","(10, 60, 192)"
Dask graph,292 chunks in 6 graph layers,292 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


### Frontload Computation: load into memory if size fits

In [73]:
ds.nbytes / 1e9 # total size in GB; decides whether loading everything into memory is feasible

1.112342336

In [84]:
from dask.diagnostics import ProgressBar

# Materialise every dask-backed variable of `ds` in memory, with a progress bar.
# This is where the lazy daily means from load_vars_xarray are actually computed.
# Dataset.load() modifies `ds` in place, so later cells see in-memory data.
with ProgressBar():
    ds.load()

[########################################] | 100% Completed | 12m 39s


## Loading Minibatches: xarray --> batch

### Grab Xarray Chunks: xarray --> xarray batch sized

In [7]:
# (disabled) rechunk experiment: 292 time steps x all 384 columns per chunk
#ds = ds.chunk(chunks={'time':292, 'ncol' : 384})
#ds.pbuf_LHFLX.data.blocks[3]

#### Wait, why do we need Dask??

Assumptions going forward: `ds` is an xarray Dataset with coordinates `time` and `ncol`, holding a mix of single-level and multi-level variables. We ignore spatial and temporal correlations and treat every (time, ncol) column vector as i.i.d. A linear sample index enumerates `ncol` fastest and then `time`, i.e. `index = t * n_ncol + col`.

In [8]:
import random

max_i = ds.ncol.size * ds.time.size  # number of (time, ncol) sample vectors
print(max_i)
# randrange excludes its upper bound, so i is always a valid linear index < max_i
i = random.randrange(max_i)
i

1121280


1051639

In [93]:
def get_item(ds, index):
    """Select the single (time, ncol) sample addressed by a linear index.

    The linear index enumerates ``ncol`` fastest, i.e.
    ``index == t * ds.ncol.size + col``.

    Raises
    ------
    AssertionError
        If ``index`` is not below ``ds.time.size * ds.ncol.size``.
    """
    n_cols = ds.ncol.size
    assert index < ds.time.size * n_cols, "Index is outside of range"
    t, col = divmod(index, n_cols)
    return ds.isel(time=t, ncol=col)

In [10]:
%%time
# Fetch the last valid linear index (max_i - 1)
get_item(ds, max_i-1)

CPU times: user 1.28 ms, sys: 0 ns, total: 1.28 ms
Wall time: 1.23 ms


In [94]:
def get_batch(ds, batch_num, batch_size = 32, dim = 'ncol'):
    """Return minibatch ``batch_num`` of contiguous samples along ``dim``.

    Batches follow the same linear layout as ``get_item``. ``dim`` is split into
    ``ceil(ds[dim].size / batch_size)`` batches, and the last one may be shorter.
    Once those batches are exhausted, the index advances along the other
    dimension.

    Parameters
    ----------
    ds : Dataset with ``time`` and ``ncol`` dimensions.
    batch_num : int
        Linear batch index.
    batch_size : int, default 32
        Samples per batch along ``dim`` (384 = 3 * 2**7 columns, so powers of
        two up to 128 split ``ncol`` evenly).
    dim : {'ncol', 'time'}, default 'ncol'
        Dimension to slice; the other dimension is indexed by a single integer.

    Raises
    ------
    ValueError
        If ``dim`` is neither ``'ncol'`` nor ``'time'`` (previously returned None).
    """
    if dim not in ('ncol', 'time'):
        raise ValueError(f"dim must be 'ncol' or 'time', got {dim!r}")
    size = ds[dim].size
    # Integer ceil-division. The old float division produced fractional,
    # overlapping slice bounds whenever batch_size did not divide size.
    n_batch = -(-size // batch_size)
    other_dim_batch, dim_batch = divmod(batch_num, n_batch)
    start = dim_batch * batch_size
    stop = min(start + batch_size, size)
    if dim == 'ncol':
        print(f"ncol from {start}-{stop}; time={other_dim_batch}")
        return ds.isel(ncol=slice(start, stop), time=other_dim_batch)
    return ds.isel(time=slice(start, stop), ncol=other_dim_batch)

In [95]:
%%time
# Batch 345 with the default batch_size=32 along ncol
get_batch(ds, 345)

ncol from 288-320; time=28
CPU times: user 1.08 ms, sys: 40 µs, total: 1.12 ms
Wall time: 1.11 ms


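### Stacking: xarray --> numpy array (bottleneck: `to_stacked_array` took ~5 s wall time in In[21], which ran before `ds.load()`, versus ~20 ms in In[97], after loading)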

In [96]:
# manual into minibatch
batch = get_batch(ds, 0)
print(sys.getsizeof(batch))
print(sys.getsizeof(batch.cam_in_ASDIR.data))
X, Y = split_input_output(batch)
batch

ncol from 0-32; time=0
112
112


In [97]:
%%time
# Stack inputs into a 2-D (ncol, feature) array; each level of a multi-level variable becomes one feature
arr = X.to_stacked_array("v", sample_dims=["ncol"])
arr.shape

CPU times: user 17.2 ms, sys: 2.08 ms, total: 19.3 ms
Wall time: 21.3 ms


(32, 62)

In [98]:
# getsizeof reports the DataArray wrapper only, not its data (arr.nbytes gives the data size)
print(sys.getsizeof(arr))
#print(sys.getsizeof(arr.data)) # 16 Megabytes
#print(sys.getsizeof(arr.values)) # 16 Megabytes

96


In [17]:
arr.shape # (batch_size, n_input_features)

(32, 62)

#### Tried stacking the whole dataset up front; it crashed because the stacked array needs many GB of memory

In [18]:
# (disabled) stacking the whole dataset at once ran out of memory; kept for reference
#ds = ds.stack({'batch':{'ncol'}})
#ds = ds.to_stacked_array("mlvar", sample_dims=["batch"], name='mli')

In [19]:

# (disabled) split + stack of the full dataset; superseded by per-batch stacking
# X, Y = split_dataset(ds)
#X.to_stacked_array("v", sample_dims=["ncol"])

In [20]:
# A 10-column slice at one time step, used to time to_stacked_array in the next cell
xarr = ds.isel(time=5, ncol=slice(14, 24))
xarr

In [21]:
%%time
# NOTE(review): this ran (In[21]) before the ds.load() cell (In[84]); the multi-second
# wall time likely includes lazy reads -- re-time after loading before drawing conclusions
xarr.to_stacked_array("v", sample_dims=["ncol"])

CPU times: user 2.04 s, sys: 1.17 s, total: 3.2 s
Wall time: 4.99 s


## Conversions of batches into various forms

### Generator

In [22]:
def gen():
    """Yield one (input, output) pair of stacked numpy arrays per file.

    Adapted from another pipeline. It relies on notebook-level names that are
    NOT defined in this notebook (``filelist``, ``vars_mli``, ``vars_mlo``,
    ``mli_mean``, ``mli_max``, ``mli_min``, ``mlo_scale``). It reads
    ``*.mli.*`` / ``*.mlo.*`` netCDF files rather than the zarr stores above.

    Yields
    ------
    (numpy.ndarray, numpy.ndarray)
        ``(mli, mlo)``, each stacked with ``batch`` (built from ``ncol``) as the
        sample dimension and one column per variable/level.
    """
    for file in filelist:
        # `with` closes both netCDF handles once the pair has been consumed;
        # the previous version left two open files behind per iteration.
        with xr.open_dataset(file, engine='netcdf4') as raw_in, \
                xr.open_dataset(file.replace('.mli.', '.mlo.'), engine='netcdf4') as raw_out:
            # read mli
            ds_in = raw_in[vars_mli]

            # make mlo variables: ptend_t and ptend_q0001
            # NOTE(review): 1200 is presumably the model timestep in seconds -- confirm
            raw_out['ptend_t'] = (raw_out['state_t'] - ds_in['state_t']) / 1200  # T tendency [K/s]
            raw_out['ptend_q0001'] = (raw_out['state_q0001'] - ds_in['state_q0001']) / 1200  # Q tendency [kg/kg/s]
            ds_out = raw_out[vars_mlo]

            # normalization, scaling
            ds_in = (ds_in - mli_mean) / (mli_max - mli_min)
            ds_out = ds_out * mlo_scale

            # stack (a list keeps dimension order deterministic if more dims are added)
            ds_in = ds_in.stack({'batch': ['ncol']})
            ds_in = ds_in.to_stacked_array("mlvar", sample_dims=["batch"], name='mli')
            ds_out = ds_out.stack({'batch': ['ncol']})
            ds_out = ds_out.to_stacked_array("mlvar", sample_dims=["batch"], name='mlo')

            # .values is evaluated before suspending, i.e. while the files are still open
            yield (ds_in.values, ds_out.values)  # generating a tuple of (input, output)

### Torch Dataloader

In [86]:
# Re-inspect the Dataset before building tensors from it
ds

In [103]:
%%time
import torch  # also imported here so this cell runs on a fresh kernel (the torch import cell comes later)

# Single-column batch (batch_size=1) converted to float32 tensors.
# The device falls back to CPU instead of the previously hard-coded 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch = get_batch(ds, 324, 1)
X, Y = split_input_output(batch)
X = X.to_stacked_array("v", sample_dims=["ncol"]).values
Y = Y.to_stacked_array("v", sample_dims=["ncol"]).values
X = torch.tensor(X, device=device, dtype=torch.float32)
Y = torch.tensor(Y, device=device, dtype=torch.float32)

ncol from 324-325; time=0
CPU times: user 31.1 ms, sys: 1.05 ms, total: 32.2 ms
Wall time: 31.5 ms


In [104]:
# (1, 62) float32 tensor: the two single-level inputs, then the 60 state_q0001 levels
X

tensor([[6.9022e-01, 6.4177e+01, 1.4482e-06, 1.4273e-06, 1.3772e-06, 1.3208e-06,
         1.2743e-06, 1.2605e-06, 1.2669e-06, 1.2840e-06, 1.3485e-06, 1.4484e-06,
         1.4654e-06, 1.4297e-06, 1.3907e-06, 1.3600e-06, 1.3280e-06, 1.3215e-06,
         1.3276e-06, 1.3878e-06, 1.4228e-06, 2.4018e-06, 6.5749e-06, 9.9586e-06,
         1.0047e-05, 1.1068e-05, 1.3776e-05, 1.9508e-05, 2.8774e-05, 4.4119e-05,
         6.7585e-05, 1.0138e-04, 1.5071e-04, 2.1822e-04, 3.1029e-04, 4.3376e-04,
         5.9584e-04, 7.8916e-04, 1.0090e-03, 1.2619e-03, 1.5507e-03, 1.8482e-03,
         2.1329e-03, 2.4052e-03, 2.7050e-03, 2.9954e-03, 3.2995e-03, 3.6188e-03,
         3.9257e-03, 4.1684e-03, 4.3542e-03, 4.4863e-03, 4.6727e-03, 4.8765e-03,
         5.1133e-03, 5.3703e-03, 5.6767e-03, 5.9696e-03, 6.3691e-03, 6.6860e-03,
         6.8863e-03, 7.0488e-03]], device='cuda:0')

In [87]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

In [112]:
class MyDataset(Dataset):
    """Map-style torch Dataset over every (time, ncol) sample of a combined Dataset.

    Sample ``idx`` is decoded like ``get_item``: ``ncol`` varies fastest, so
    ``idx == t * n_ncol + col``. Each item is an ``(X, Y)`` pair of float32
    tensors of shape ``(1, n_features)``; the ``ncol`` axis is kept with
    length 1. Tensors are placed on CUDA when available, otherwise on CPU.
    """

    def __init__(self, ds):
        self.X_ds, self.Y_ds = split_input_output(ds)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Placeholder normalisation stats; not applied in __getitem__ yet.
        # 62 is the stacked input width for the current variable selection
        # (2 single-level inputs + 60 levels of state_q0001).
        self.variable_mean = np.zeros((1, 62))
        self.variable_std = np.ones((1, 62))

    def __len__(self):
        return self.X_ds.time.size * self.X_ds.ncol.size

    def __getitem__(self, idx):
        # Use this dataset's own column count. The previous version read the
        # notebook-global `ds`, which breaks for any other Dataset (e.g. the
        # 192-column `train` / `test` regional splits).
        t, col = divmod(idx, self.X_ds.ncol.size)
        X = self.X_ds.isel(time=t, ncol=[col])
        Y = self.Y_ds.isel(time=t, ncol=[col])

        X = X.to_stacked_array("v", sample_dims=["ncol"]).values
        Y = Y.to_stacked_array("v", sample_dims=["ncol"]).values

        X = torch.as_tensor(X, device=self.device, dtype=torch.float32)
        Y = torch.as_tensor(Y, device=self.device, dtype=torch.float32)
        return (X, Y)

In [113]:
# One item per (time, ncol) sample of ds, shuffled into minibatches of 32.
# Each item is (1, n_features), so default collation yields (32, 1, n_features) batches.
d = MyDataset(ds)
dataloader = DataLoader(d, batch_size=32, shuffle=True)

In [114]:
%%time
# Cost of one item: two isel calls plus two to_stacked_array calls
d[1503]

CPU times: user 32.8 ms, sys: 1.86 ms, total: 34.6 ms
Wall time: 37.1 ms


(tensor([[ 9.2766e-01, -4.9767e-01,  1.4728e-06,  1.4250e-06,  1.3276e-06,
           1.2436e-06,  1.1400e-06,  1.0619e-06,  1.0184e-06,  1.0561e-06,
           1.1603e-06,  1.3021e-06,  1.4517e-06,  1.5212e-06,  1.5381e-06,
           1.5434e-06,  1.5266e-06,  1.5167e-06,  1.5181e-06,  1.5224e-06,
           1.9029e-06,  3.9588e-06,  6.0397e-06,  6.1680e-06,  5.9012e-06,
           5.6712e-06,  6.0622e-06,  7.2396e-06,  8.8883e-06,  1.0821e-05,
           1.3040e-05,  1.6128e-05,  2.0336e-05,  2.4144e-05,  2.6510e-05,
           2.9723e-05,  3.9142e-05,  5.4668e-05,  7.8857e-05,  1.0298e-04,
           1.3485e-04,  1.7922e-04,  2.2610e-04,  2.6713e-04,  3.0284e-04,
           3.3600e-04,  3.6952e-04,  4.0281e-04,  4.3344e-04,  4.6309e-04,
           4.8926e-04,  5.1498e-04,  5.4001e-04,  5.6491e-04,  5.8982e-04,
           6.1702e-04,  6.4506e-04,  6.7137e-04,  6.9694e-04,  7.1842e-04,
           7.1162e-04,  6.9523e-04]], device='cuda:0'),
 tensor([[2.9556e+01, 3.8167e-09, 1.4728e-06

In [115]:
%%time
# One shuffled batch = 32 separate __getitem__ calls, collated by the DataLoader
X_sample, Y_sample = next(iter(dataloader))

CPU times: user 1.58 s, sys: 51.2 ms, total: 1.63 s
Wall time: 1.03 s


### Tensorflow dataset 

In [31]:
# NOTE(review): this cell fails as run (NameError: `tf` is never imported in this
# notebook), and `gen` depends on names that are not defined here. The
# (None, 124) / (None, 128) shapes come from that other pipeline's variable
# lists, not from the 62-wide inputs built above -- update them before use.
dataset = tf.data.Dataset.from_generator(gen, 
  output_types=(tf.float64, tf.float64),
  output_shapes=((None,124),(None,128))
 )

NameError: name 'tf' is not defined