In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from dataset.dataset import SeviriDataset, pickle_read, MemmapSeviriDataset
from dataset.station_dataset import GroundstationDataset
from dataset.normalization import ZeroMinMax
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from models.LightningModule import LitEstimatorPoint
from tqdm import tqdm
import xarray

# from pytorch_lightning.pytorch.callbacks import DeviceStatsMonitor
from train import get_dataloaders

In [2]:
from types import SimpleNamespace

# Experiment settings, exposed as attributes (config.batch_size, config.x_vars, ...).
config = SimpleNamespace(
    batch_size=512,
    # Patch settings; handed as-is to the datasets built further down.
    patch_size=dict(x=15, y=15, stride_x=1, stride_y=1),
    # Input variables.
    x_vars=[
        "channel_1", "channel_2", "channel_3", "channel_4",
        "channel_5", "channel_6", "channel_7", "channel_8",
        "channel_9", "channel_10", "channel_11",
        "DEM",
    ],
    # Target variable(s).
    y_vars=["SIS"],
    # Additional per-sample features.
    x_features=["dayofyear", "lat", "lon", "SZA", "AZI"],
    # Separate normalizer instances for inputs and targets.
    transform=ZeroMinMax(),
    target_transform=ZeroMinMax(),
    # Compute related
    ACCELERATOR="gpu",
    DEVICES=-1,
    NUM_NODES=1,
    # STRATEGY="ddp",
    PRECISION="32",
    num_workers=24,
    val_check_interval=0.1,
)

In [18]:
from glob import glob
import os


# Training data: memory-mapped dataset built with its own defaults.
# An on-the-fly alternative is SeviriDataset(..., patches_per_image=config.batch_size,
# validation=False) with the same x_vars/y_vars/x_features/patch_size/transforms as below.
# NOTE(review): MemmapSeviriDataset() receives none of the config values
# (x_vars, patch_size, transforms) -- confirm its defaults agree with `config`.
train_dataset = MemmapSeviriDataset()

# Validation split of the SEVIRI dataset, configured from `config`.
valid_dataset = SeviriDataset(
    x_vars=config.x_vars,
    y_vars=config.y_vars,
    x_features=config.x_features,
    patch_size=config.patch_size,
    transform=config.transform,
    target_transform=config.target_transform,
    patches_per_image=2048,
    validation=True,
)

# One GroundstationDataset per station store. glob() returns paths in arbitrary
# order, so sort them to keep the station order (and therefore the ConcatDataset
# indexing) reproducible between runs.
zarr_fns = sorted(glob('../../ZARR/IEA_PVPS/IEA_PVPS_*.zarr'))
if not zarr_fns:
    # Fail with a readable message instead of ConcatDataset's bare assertion.
    raise FileNotFoundError(
        "No station stores found matching '../../ZARR/IEA_PVPS/IEA_PVPS_*.zarr'"
    )

# Station name = file name without the ".zarr" extension and "IEA_PVPS_" prefix.
# splitext only strips the final extension, so names containing dots survive.
station_names_bsrn = [
    os.path.splitext(os.path.basename(fn))[0].split('IEA_PVPS_')[-1]
    for fn in zarr_fns
]

# Open each store through the path glob() already returned rather than
# rebuilding the path from the parsed station name.
bsrn_datasets = [
    GroundstationDataset(
        fn,
        config.y_vars,
        config.x_vars,
        config.x_features,
        config.patch_size['x'],
        config.transform,
        config.target_transform,
    )
    for fn in tqdm(zarr_fns)
]
bsrn_dataset = torch.utils.data.ConcatDataset(bsrn_datasets)


  0%|          | 0/14 [00:00<?, ?it/s][A
  7%|▋         | 1/14 [00:02<00:33,  2.54s/it][A
 14%|█▍        | 2/14 [00:03<00:19,  1.61s/it][A
 21%|██▏       | 3/14 [00:04<00:13,  1.26s/it][A
 29%|██▊       | 4/14 [00:05<00:12,  1.24s/it][A
 36%|███▌      | 5/14 [00:06<00:10,  1.20s/it][A
 43%|████▎     | 6/14 [00:07<00:08,  1.07s/it][A
 50%|█████     | 7/14 [00:08<00:07,  1.07s/it][A
 57%|█████▋    | 8/14 [00:10<00:07,  1.27s/it][A
 64%|██████▍   | 9/14 [00:11<00:05,  1.15s/it][A
 71%|███████▏  | 10/14 [00:12<00:04,  1.15s/it][A
 79%|███████▊  | 11/14 [00:13<00:03,  1.20s/it][A
 86%|████████▌ | 12/14 [00:15<00:02,  1.28s/it][A
 93%|█████████▎| 13/14 [00:15<00:01,  1.16s/it][A
100%|██████████| 14/14 [00:16<00:00,  1.21s/it][A


In [19]:

from lightning import LightningDataModule


class DataModule(LightningDataModule):
    """Data module pairing the training set with the ground-station validation set.

    Args:
        train_dataset: Dataset served by ``train_dataloader``. It is loaded with
            ``batch_size=None`` (automatic batching disabled), so every item is
            passed on as returned by the dataset -- presumably already a full
            batch; confirm against the dataset implementation.
        val_dataset: Stored on the instance; currently not served by
            ``val_dataloader``.
        bsrn_dataset: Dataset served by ``val_dataloader``.
        batch_size: Stored on the instance; neither loader reads it.
        train_num_workers: Worker processes for the training loader.
        val_num_workers: Worker processes for the validation loader;
            ``None`` reuses ``train_num_workers``.
        bsrn_batch_size: Batch size of the validation loader.
    """

    def __init__(self, train_dataset, val_dataset, bsrn_dataset, batch_size,
                 train_num_workers=12, val_num_workers=None, bsrn_batch_size=4096):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.bsrn_dataset = bsrn_dataset
        self.num_dataloaders = 1
        self.batch_size = batch_size
        self.train_num_workers = train_num_workers
        # Loader settings live on the instance so the class no longer depends
        # on a notebook-global `config`.
        self.val_num_workers = train_num_workers if val_num_workers is None else val_num_workers
        self.bsrn_batch_size = bsrn_batch_size

    def train_dataloader(self):
        """Return the shuffled training loader (automatic batching disabled)."""
        return DataLoader(
            self.train_dataset,
            batch_size=None,
            shuffle=True,
            num_workers=self.train_num_workers,
        )

    def val_dataloader(self):
        """Return a one-element list holding the ground-station validation loader."""
        bsrn_loader = DataLoader(
            self.bsrn_dataset,
            batch_size=self.bsrn_batch_size,
            shuffle=False,
            num_workers=self.val_num_workers,
        )
        return [bsrn_loader]


dm = DataModule(
    train_dataset,
    valid_dataset,
    bsrn_dataset,
    config.batch_size,
    val_num_workers=config.num_workers,
)


In [16]:

# Stop when 'val_loss' stops improving.
# NOTE(review): patience counts validation checks, not epochs; with
# val_check_interval=0.1 ten checks span roughly one epoch -- confirm intended.
early_stopping = EarlyStopping(monitor='val_loss', patience=10, mode='min')

# Weights & Biases logger; log_model=True also logs checkpoints to W&B.
wandb_logger = WandbLogger(name='final', project="SIS_point_estimation", log_model=True)

if rank_zero_only.rank == 0:  # only update the wandb.config on the rank 0 process
    wandb_logger.experiment.config.update(vars(config))

# Keep the three best checkpoints by 'val_loss' plus the most recent one.
mc_sarah = ModelCheckpoint(
    monitor='val_loss',
    mode='min',  # stated explicitly so it matches the early-stopping criterion
    every_n_epochs=1,
    save_top_k=3,
    save_last=True,
)

trainer_sarah = Trainer(
    logger=wandb_logger,
    accelerator=config.ACCELERATOR,
    devices=config.DEVICES,
    num_nodes=config.NUM_NODES,  # previously declared in config but never passed on
    min_epochs=1,
    max_epochs=15,
    precision=config.PRECISION,
    log_every_n_steps=1000,
    val_check_interval=config.val_check_interval,
    callbacks=[early_stopping, mc_sarah],
    max_time="00:03:50:00",  # wall-clock limit in DD:HH:MM:SS
)

/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/lightning-env/lib/python3. ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
from torchmetrics import MeanSquaredError


# Build the point estimator from the shared config, passing a fresh
# MeanSquaredError metric instance.
estimator = LitEstimatorPoint(config=config, learning_rate=1e-6, metric=MeanSquaredError())

# Run training; the data module provides the train and validation loaders.
trainer_sarah.fit(model=estimator, datamodule=dm)

In [None]:
# Echo the training dataset object to inspect its repr (debugging aid).
train_dataset

In [15]:
# Finish the W&B run held by the logger.
wandb_logger.experiment.finish()



VBox(children=(Label(value='0.432 MB of 0.432 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁
loss_step,▃▁▅▄▃▄▂█▇▂▄
trainer/global_step,▁▂▂▃▄▅▅▆▇▇▇▇▇████

0,1
epoch,0.0
loss_step,0.20192
trainer/global_step,5499.0
