In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import wandb
import xarray
from dataset.dataset import ImageDataset, valid_test_split, SeviriDataset, pickle_write
from dataset.station_dataset import GroundstationDataset
from dataset.normalization import MinMax, ZeroMinMax
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from models.ConvResNet_Jiang import ConvResNet, ConvResNet_dropout
from models.LightningModule import LitEstimator, LitEstimatorPoint
from tqdm import tqdm

# from pytorch_lightning.pytorch.callbacks import DeviceStatsMonitor
from utils.plotting import best_worst_plot, prediction_error_plot
from utils.etc import benchmark

from dask.distributed import Client
import dask
# NOTE(review): a later cell output in this notebook shows a warning raised from
# distributed/client.py during .compute(), i.e. the work is sent through the
# Client created below rather than the 'synchronous' scheduler configured here.
# Presumably Client() registers itself as the default scheduler -- confirm, and
# drop whichever of the two lines is not intended.
dask.config.set(scheduler='synchronous')
# Client() is called without arguments (no scheduler address is given).
client = Client()

In [31]:
from types import SimpleNamespace


# Experiment configuration. Top-level entries are reachable as attributes
# (``config.batch_size``); nested entries such as ``patch_size`` and
# ``EarlyStopping`` remain plain dicts and are indexed with ``[...]``.
config = SimpleNamespace(
    # Data / sampling
    batch_size=2048,
    patch_size=dict(x=15, y=15, stride_x=1, stride_y=1),
    x_vars=[
        "channel_1", "channel_2", "channel_3", "channel_4",
        "channel_5", "channel_6", "channel_7", "channel_8",
        "channel_9", "channel_10", "channel_11",
        "DEM",
    ],
    y_vars=["SIS"],
    x_features=["dayofyear", "lat", "lon", "SZA", "AZI"],
    transform=ZeroMinMax(),
    target_transform=ZeroMinMax(),
    # Compute related
    num_workers=12,
    ACCELERATOR="gpu",
    DEVICES=-1,
    NUM_NODES=32,
    STRATEGY="ddp",
    PRECISION="32",
    EarlyStopping=dict(patience=5),
)

In [33]:
# Sanity check: only the top level of `config` became attributes; the nested
# EarlyStopping entry is still a dict and is read with [] indexing.
config.EarlyStopping['patience']

5

In [3]:
# Open the satellite, solar-position and SARAH-3 zarr stores.
# None of these three is explicitly .load()-ed in this cell.
seviri = xarray.open_zarr("/scratch/snx3000/kschuurm/ZARR/SEVIRI_new.zarr")
solarpos = xarray.open_zarr("/scratch/snx3000/kschuurm/ZARR/SOLARPOS_new.zarr")
sarah = xarray.open_zarr("/scratch/snx3000/kschuurm/ZARR/SARAH3_new.zarr")

# The elevation model is the only store read into memory here; missing values
# are replaced by 0 before loading, and the whole step is timed.
with benchmark("dem load"):
    dem_lazy = xarray.open_zarr("/scratch/snx3000/kschuurm/ZARR/DEM.zarr")
    dem = dem_lazy.fillna(0).load()

# Turn the `channel` dimension of channel_data into separate data variables.
# NOTE(review): `a` is re-assigned by later cells before it is read again.
a = seviri.channel_data.to_dataset(dim="channel")


dem load : 1.594 seconds


In [4]:
with benchmark('any'):
    # Reduce over x, y and channel: True wherever at least one value is NaN.
    has_nan_lazy = seviri.channel_data.isnull().any(dim=['x', 'y', 'channel'])
    # Trigger the actual computation, then keep a plain numpy copy of the mask.
    idx = has_nan_lazy.compute()
    idx_val = idx.values

/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/distributed/client.py:3162: Sending large graph of size 35.30 MiB.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


any : 957.176 seconds


In [None]:
with benchmark('mean'):
    # Fraction of NaN values over x, y and channel (0.0 means nothing is missing).
    nan_fraction = seviri.channel_data.isnull().mean(dim=['x', 'y', 'channel']).compute()
    # Keep `idx` / `idx_val` boolean, exactly as the 'any' cell defines them.
    # Storing the float fractions under these names would make the downstream
    # `~idx_val` fail, because bitwise inversion is undefined for float arrays.
    idx = nan_fraction > 0
    idx_val = idx.values

/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/distributed/client.py:3162: Sending large graph of size 37.85 MiB.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


In [11]:
# Time stamps where the NaN mask is False, i.e. entries without any missing
# value (assumes idx_val is the boolean mask produced by the 'any' cell).
idx.time[~idx_val].values


array(['2016-01-01T00:00:00.000000000', '2016-01-01T00:15:00.000000000',
       '2016-01-01T00:30:00.000000000', ...,
       '2022-12-31T23:15:00.000000000', '2022-12-31T23:30:00.000000000',
       '2022-12-31T23:45:00.000000000'], dtype='datetime64[ns]')

In [12]:
# Persist the time stamps whose mask entry is False so the (slow) NaN scan
# does not have to be repeated.
notnan_times = idx.time[~idx_val].values
np.save('/scratch/snx3000/kschuurm/ZARR/idxnotnan_seviri.npy', notnan_times)

In [None]:
# (band name, dataset variable name) pairs, listed in channel_<n> order.
_band_to_channel = (
    ("VIS006", "channel_1"),
    ("VIS008", "channel_2"),
    ("IR_016", "channel_3"),
    ("IR_039", "channel_4"),
    ("WV_062", "channel_5"),
    ("WV_073", "channel_6"),
    ("IR_087", "channel_7"),
    ("IR_097", "channel_8"),
    ("IR_108", "channel_9"),
    ("IR_120", "channel_10"),
    ("IR_134", "channel_11"),
)

# Lookup used to translate band names into the `channel_<n>` names of x_vars.
trans = dict(_band_to_channel)

In [None]:
# Benchmark: sample single pixels with one vectorised (point-wise) isel versus
# a Python loop, and verify that both variants return the same values.
i = np.random.randint(0, 400, size=1000)  # lon samples
j = np.random.randint(0, 400, size=1000)  # lat samples
# Positions along `channel`; both variants must select the same channels,
# otherwise their results cannot be compared.
channel_idx = [1, 3, 4]

with benchmark('seviri load'):
    a = seviri.isel(time=0).load()  # seviri like zarr dataset

with benchmark('seviri rename'):
    # Translate band names into the channel_<n> names used by config.x_vars.
    nms = a.channel.values
    nms_trans = [trans[nm] for nm in nms]
    a['channel'] = nms_trans
    x_vars_available = [var for var in config.x_vars if var in nms_trans]
    a = a.sel(channel=x_vars_available).rename({'y': 'lat', 'x': 'lon'})
# Printed outside the timed block so the repr does not count towards the timing.
print(a)

with benchmark('no loop'):
    # Indexers sharing the new `sample` dim select points pairwise
    # (lon=i[s], lat=j[s]) instead of the outer product.
    u_da = (
        a.isel(channel=channel_idx)
        .isel(lon=xarray.DataArray(i, dims=['sample']),
              lat=xarray.DataArray(j, dims=['sample']))
        .to_dataarray()              # `a` is a Dataset: Dataset.values is the Mapping method
        .transpose(..., 'sample')    # put `sample` last to match the stacked loop result
    )
    u = u_da.values
print(u_da)

with benchmark('loop'):
    v = []
    for k, h in zip(i, j):  # k: lon index, h: lat index (same pairing as above)
        s = a.isel(channel=channel_idx).isel(lat=h).isel(lon=k).to_dataarray().values
        v.append(s)
    v = np.stack(v, axis=-1)  # new last axis is the sample axis

print(u.shape, v.shape)
# equal_nan=True: missing values at the same position count as equal.
print(np.array_equal(u, v, equal_nan=True))

In [None]:
# Benchmark: cut square patches with one vectorised isel versus a per-sample
# loop, and verify that both variants return the same values.
patch = 15  # patch edge length in pixels
i = np.random.randint(0, 400, size=5000)  # lon samples (first index of each patch)
j = np.random.randint(0, 400, size=5000)  # lat samples (first index of each patch)

# Per-sample index ranges; the dim name states which axis each one indexes.
slice_i = xarray.DataArray(i[:, None] + np.arange(patch), dims=['sample', 'lon'])
slice_j = xarray.DataArray(j[:, None] + np.arange(patch), dims=['sample', 'lat'])

# Broadcast both indexers to the common dims (sample, lat, lon).
slice_ji, slice_ij = xarray.broadcast(slice_j, slice_i)

a = sarah[['SIS', 'SID']].isel(time=0).load()  # seviri like zarr dataset

with benchmark('no loop'):
    u = (
        a.isel(lon=slice_ij, lat=slice_ji)
        .to_dataarray()
        .transpose('variable', 'sample', 'lat', 'lon')
        .values
    )

with benchmark('loop'):
    X = []
    # `s` is a dedicated loop index; re-using `i` here would overwrite the
    # sample array defined above.
    for s in range(len(i)):
        X_ = (
            a.isel(lat=slice_j[s], lon=slice_i[s])
            .to_dataarray()
            .transpose('variable', 'lat', 'lon')
            .values
        )
        X.append(X_)
    X = np.stack(X, axis=1)  # -> (variable, sample, lat, lon)

print(u.shape, X.shape)
# equal_nan=True: missing values at the same position count as equal.
print(np.array_equal(u, X, equal_nan=True))

# dataset testing

In [18]:
# Collect the constructor arguments first: everything except the sampling
# settings (patches_per_image, seed) comes straight from `config`.
seviri_dataset_kwargs = dict(
    x_vars=config.x_vars,
    y_vars=config.y_vars,
    x_features=config.x_features,
    patch_size=config.patch_size,
    transform=config.transform,
    target_transform=config.target_transform,
    patches_per_image=1000,
    seed=0,
)
dataset = SeviriDataset(**seviri_dataset_kwargs)

In [26]:
# Fetch item 0; an item unpacks into three parts.
# Presumably X = image patches, x = extra features, y = targets -- confirm in SeviriDataset.
X, x, y = dataset[0]

In [27]:
# Fetch the same index a second time, under new names, to compare with the first fetch.
X2, x2, y2 = dataset[0]

In [30]:
# True when both fetches of dataset[0] returned identical targets.
# Only y is compared here; X and x are not checked.
(y==y2).all()

tensor(True)

In [None]:
# Time a single item fetch. Only the fetch itself is inside the timed block,
# so formatting and printing do not distort the measurement.
with benchmark('asdf'):
    X, x, y = dataset[0]
# Show shapes rather than dumping the full tensors.
print(X.shape, x.shape, y.shape)

In [None]:
def my_collate_fn(batch):
    """Merge a list of ``(X, x, y)`` items into one ``(X, x, y)`` triple.

    Each of the three parts is concatenated along dimension 0 across all
    items of ``batch``, so the items may hold different numbers of rows.
    """

    def _cat_part(position):
        # Gather the part at `position` from every item and join along dim 0.
        return torch.cat([item[position] for item in batch], dim=0)

    return _cat_part(0), _cat_part(1), _cat_part(2)

# batch_size=None switches off the DataLoader's automatic batching, so items
# are yielded one at a time instead of being collated into batches.
# my_collate_fn is not passed to this loader.
dl = DataLoader(
    dataset,
    batch_size=None,
    num_workers=1,
    shuffle=True,
)

In [None]:
# Walk through the loader once and report the shape of each yielded part.
for X, x, y in tqdm(dl):
    print(*(part.shape for part in (X, x, y)))

In [None]:
# Ground stations used for testing; only CAB is enabled at the moment.
stations = ['CAB']  # 'CAR', 'CEN' ,'MIL', 'NOR', 'PAL', 'PAY', 'TAB', 'TOR', 'VIS'

# Keyword arguments shared by every station dataset. Only the 'x' entry of
# the patch size is passed on here.
station_dataset_kwargs = dict(
    patch_size=config.patch_size['x'],
    transform=config.transform,
    target_transform=config.target_transform,
)

# One dataset per station, in the same order as `stations`.
test_datasets = [
    GroundstationDataset(
        station,
        config.y_vars,
        config.x_vars,
        config.x_features,
        **station_dataset_kwargs,
    )
    for station in stations
]

# Station name -> loader over that station's dataset, in a fixed order.
test_dataloaders = {
    station: DataLoader(station_ds, batch_size=10000, shuffle=False)
    for station, station_ds in zip(stations, test_datasets)
}