In [4]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from dataset.dataset import SeviriDataset, pickle_read
from dataset.station_dataset import GroundstationDataset
from dataset.normalization import ZeroMinMax
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from models.ConvResNet_Jiang import ConvResNet_batchnormMLP
from models.FCN import residual_FCN
from models.LightningModule import LitEstimatorPoint
from tqdm import tqdm
import xarray

# from pytorch_lightning.pytorch.callbacks import DeviceStatsMonitor
from train import get_dataloaders

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
from types import SimpleNamespace

config = {
    "batch_size": 2048,
    "patch_size": {
        "x": 15,
        "y": 15,
        "stride_x": 1,
        "stride_y": 1,
    },
    "x_vars": [
        "channel_1",
        "channel_2",
        "channel_3",
        "channel_4",
        "channel_5",
        "channel_6",
        "channel_7",
        "channel_8",
        "channel_9",
        "channel_10",
        "channel_11",
        "DEM",
    ],
    "y_vars": ["SIS"],
    "x_features": ["dayofyear", "lat", "lon", 'SZA', "AZI"],
    "transform": ZeroMinMax(),
    "target_transform": ZeroMinMax(),
    # Compute related
    'ACCELERATOR': "gpu",
    'DEVICES': -1,
    'NUM_NODES': 1,
    # 'STRATEGY': "ddp",
    'PRECISION': "32",
    'num_workers':12,
}
config = SimpleNamespace(**config)

In [27]:
sarah_nulls = xarray.open_zarr('/scratch/snx3000/kschuurm/ZARR/SARAH3_nulls.zarr')
timeindex = sarah_nulls['any'].where((sarah_nulls['nullssum'] > 5000).compute(), drop=True).time.values
timeindex = pd.DatetimeIndex(timeindex)
# timeindex = timeindex[(timeindex.hour >10) & (timeindex.hour <17)]
traintimeindex = timeindex[(timeindex.year <= 2021)]
validtimeindex = timeindex[(timeindex.year  == 2022)]

train_dataset = SeviriDataset(
    x_vars=config.x_vars,
    y_vars=config.y_vars,
    x_features=config.x_features,
    patch_size=config.patch_size,
    transform=config.transform,
    target_transform=config.target_transform,
    patches_per_image=config.batch_size,
    timeindices=traintimeindex,
)
valid_dataset = SeviriDataset(
    x_vars=config.x_vars,
    y_vars=config.y_vars,
    x_features=config.x_features,
    patch_size=config.patch_size,
    transform=config.transform,
    target_transform=config.target_transform,
    patches_per_image=config.batch_size,
    timeindices=validtimeindex,
    seed=0,
)

sampler setup : 0.271 seconds


In [None]:
X, x, y = train_dataset[0]
print(X.shape, x.shape, y.shape)


In [28]:

early_stopping = EarlyStopping('val_loss')


wandb_logger = WandbLogger(name='Emulator', project="SIS_point_estimation", log_model=True)

if rank_zero_only.rank == 0:  # only update the wandb.config on the rank 0 process
    wandb_logger.experiment.config.update(vars(config))

mc_sarah = ModelCheckpoint(
    every_n_epochs=1, save_top_k = 3
) 

# trainer_station = Trainer(
#     # profiler="simple",
#     # fast_dev_run=True,
#     # num_sanity_val_steps=2,
#     logger=wandb_logger,
#     accelerator=ACCELERATOR,
#     devices=DEVICES,
#     min_epochs=1,
#     max_epochs=100,
#     precision=PRECISION,
#     log_every_n_steps=500,
#     check_val_every_n_epoch=5,
#     callbacks=[early_stopping, mc],
# )


trainer_sarah = Trainer(
    logger=wandb_logger,
    accelerator=config.ACCELERATOR,
    devices=config.DEVICES,
    min_epochs=1,
    max_epochs=35,
    precision=config.PRECISION,
    log_every_n_steps=500,
    # val_check_interval=1,
    callbacks=[early_stopping, mc_sarah],
    max_time="00:05:00:00"
)

/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/lightning/pytorch/loggers/wandb.py:391: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.


MisconfigurationException: ModelCheckpoint(save_top_k=3, monitor=None) is not a valid configuration. No quantity for top_k to track.

In [None]:
train_dataloaders = DataLoader(train_dataset, shuffle=True, batch_size=None, num_workers=config.num_workers)
valid_dataloaders = DataLoader(valid_dataset, shuffle=False, batch_size=None, num_workers=config.num_workers)

In [26]:

estimator = LitEstimatorPoint(
    config=config,
    learning_rate=0.0001,
)
trainer_sarah.fit(
    estimator, train_dataloaders=train_dataloaders, val_dataloaders=valid_dataloaders,
    # ckpt_path='/scratch/snx3000/kschuurm/irradiance_estimation/train/SIS_point_estimation/tt2pie1v/checkpoints/epoch=30-step=4020.ckpt'
)

[autoreload of dataset.dataset failed: Traceback (most recent call last):
  File "/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 500, in superreload
    update_generic(old_obj, new_obj)
  File "/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 397, in update_generic
    update(a, b)
  File "/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 365, in update_class
    update_instances(old, new)
  File "/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 319, in update_instances
    refs = gc.get_referrers(old)
KeyboardInterrupt
]
/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-p

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


2022-08-18T17:00:00.000000000


IndexError: index 697 is out of bounds for axis 1 with size 658

In [None]:
stations = ['CAB', 'CAR', 'CEN' ,'MIL', 'NOR', 'PAL', 'PAY', 'TAB', 'TOR', 'VIS']

test_datasets = [GroundstationDataset(nm, 
                                      config.y_vars, 
                                      config.x_vars, 
                                      config.x_features, 
                                      patch_size=15,
                                      transform=config.transform,
                                      target_transform=config.target_transform) 
                 for nm in stations] 

In [None]:
test_dataset = torch.utils.data.ConcatDataset(test_datasets)

train_ds, valid_ds = random_split(test_dataset, [0.7, 0.3])

train_dl = DataLoader(train_ds, batch_size=2048, shuffle=True, num_workers= 5)
valid_dl = DataLoader(valid_ds, batch_size=2048, shuffle=False, num_workers=5)