In [1]:
# Python core
from typing import Optional, Callable, TypedDict, Union, Iterable, Tuple, NamedTuple, List
from dataclasses import dataclass
import datetime
from itertools import product
from concurrent import futures

# Scientific python
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# PyTorch
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
import pytorch_lightning as pl
from data_wrangling import get_forward_data, get_contiguous_segments, get_zarr_chunk_sequences,ERA5Dataset,ToTensor,worker_init_fn

## Data

In [2]:
ZARR = '/glade/derecho/scratch/wchapman/STAGING/All_2010_staged.zarr'

plt.rcParams['figure.figsize'] = (5, 5)
plt.rcParams['image.interpolation'] = 'none'

torch.cuda.is_available()



False

## Configure GPU

In [3]:
##############################################################################   
##+++++ configure GPUS 
##############################################################################       
gpu_id=-1
if gpu_id >= 0:
    device = "cuda"
    set_gpu(gpu_id)
    print('device available :', torch.cuda.is_available())
    print('device count: ', torch.cuda.device_count())
    print('current device: ',torch.cuda.current_device())
    print('device name: ',torch.cuda.get_device_name())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
    device = torch.device('cpu')
        
#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
        
print('using device: ', device)
##############################################################################   
##----- configure GPUS 
##############################################################################
    

using device:  cpu


In [4]:
%%time
forcing_data = get_forward_data(filename=ZARR).unify_chunks()
#this could be a problem!: 
zarr_chunk_boundaries = np.concatenate(([0], np.cumsum(forcing_data.chunks['time'])))

CPU times: user 2.46 s, sys: 135 ms, total: 2.6 s
Wall time: 2.67 s


In [5]:
Array = Union[np.ndarray, xr.DataArray]
IMAGE_ATTR_NAMES = ('historical_ERA5_images', 'target_ERA5_images')

In [6]:
datetimes = forcing_data.time.values
contiguous_segments = get_contiguous_segments(
    dt_index = datetimes,
    min_timesteps = 36 * 1.5,
    max_gap = pd.Timedelta('1 hour'))

In [7]:
zarr_chunk_sequences = get_zarr_chunk_sequences(
    n_chunks_per_disk_load=1,
    zarr_chunk_boundaries=zarr_chunk_boundaries,
    contiguous_segments=contiguous_segments)

In [8]:
torch.manual_seed(42)    
dataset = ERA5Dataset(filename=ZARR,
    zarr_chunk_sequences=zarr_chunk_sequences,
    transform=transforms.Compose([
        ToTensor(),
    ]),
)

In [9]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=5,
    num_workers=1,  # timings:  4=13.8s; 8=11.6; 10=11.3s; 11=11.5s; 12=12.6s.  10=3it/s
    worker_init_fn=worker_init_fn,
    pin_memory=True,
    #persistent_workers=True
)

In [10]:
%%time
for i, batch in enumerate(dataloader):
    print(i, batch['historical_ERA5_images'].shape)

0 torch.Size([5, 1, 61, 640, 1280])
1 torch.Size([5, 1, 61, 640, 1280])
2 torch.Size([5, 1, 61, 640, 1280])
3 torch.Size([5, 1, 61, 640, 1280])
4 torch.Size([5, 1, 61, 640, 1280])
5 torch.Size([5, 1, 61, 640, 1280])
6 torch.Size([5, 1, 61, 640, 1280])
7 torch.Size([5, 1, 61, 640, 1280])


KeyboardInterrupt: 