In [5]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import wandb
import xarray
from dataset.dataset import ImageDataset, valid_test_split, SeviriDataset
from dataset.station_dataset import GroundstationDataset
from dataset.normalization import MinMax, ZeroMinMax
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from models.ConvResNet_Jiang import ConvResNet, ConvResNet_dropout
from models.LightningModule import LitEstimator, LitEstimatorPoint
from tqdm import tqdm

# from pytorch_lightning.pytorch.callbacks import DeviceStatsMonitor
from utils.plotting import best_worst_plot, prediction_error_plot
from utils.etc import benchmark

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from dask.distributed import Client

# Start a local Dask cluster (presumably used by the dask-backed xarray/zarr reads
# below). Only create it once per kernel: re-running this cell would otherwise
# spawn an additional LocalCluster each time instead of reusing the existing one.
if "client" not in globals():
    client = Client()

/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/distributed/node.py:182: Port 8787 is already in use.
Perhaps you already have a cluster running?
Hosting the HTTP server on port 37387 instead


In [3]:
from types import SimpleNamespace

# Experiment configuration. A SimpleNamespace gives attribute access
# (config.x_vars) and can be dumped with vars(config) for the W&B config.
config = SimpleNamespace(
    batch_size=2048,
    # Patch geometry handed to SeviriDataset (size and stride in x / y).
    patch_size=dict(x=15, y=15, stride_x=10, stride_y=10),
    # Image inputs: channel_1 .. channel_11, followed by "DEM".
    x_vars=[f"channel_{idx}" for idx in range(1, 12)] + ["DEM"],
    y_vars=["SIS"],
    # Scalar (non-image) input features.
    x_features=["dayofyear", "lat", "lon", "SZA", "AZI"],
    transform=ZeroMinMax(),
    target_transform=ZeroMinMax(),
    # Compute related
    ACCELERATOR="gpu",
    DEVICES=-1,  # -1: use every visible device
    NUM_NODES=1,
    # STRATEGY="ddp",
    PRECISION="32",
)

In [None]:
# SARAH-3 bounds: drop time steps whose pixel_count is flagged as -1.
sarah_bnds = xarray.open_zarr('/scratch/snx3000/kschuurm/ZARR/SARAH3_bnds.zarr').load()
sarah_bnds = sarah_bnds.isel(time=sarah_bnds.pixel_count != -1)

seviri = xarray.open_zarr("/scratch/snx3000/kschuurm/ZARR/SEVIRI_new.zarr")
seviri_time = pd.DatetimeIndex(seviri.time)

# Timestamps present in both archives, restricted to hours 11..16
# (hour > 10 and hour < 17).
timeindex = pd.DatetimeIndex(sarah_bnds.time).intersection(seviri_time)
timeindex = timeindex[(timeindex.hour > 10) & (timeindex.hour < 17)]

# Train on 2016; validation times come from 2017.
traintimeindex = timeindex[timeindex.year == 2016]
# NOTE(review): the FIRST value returned by valid_test_split is discarded and the
# SECOND is used for validation — confirm valid_test_split's return order.
_, validtimeindex = valid_test_split(timeindex[timeindex.year == 2017])

# Constructor arguments common to both datasets (temporary; deleted below).
_seviri_kwargs = dict(
    x_vars=config.x_vars,
    y_vars=config.y_vars,
    x_features=config.x_features,
    patch_size=config.patch_size,
    transform=config.transform,
    target_transform=config.target_transform,
)
train_dataset = SeviriDataset(
    **_seviri_kwargs,
    patches_per_image=2048,
    timeindices=traintimeindex,
)
# The validation set is seeded so the same patches are drawn on every run.
valid_dataset = SeviriDataset(
    **_seviri_kwargs,
    patches_per_image=1024,
    timeindices=validtimeindex,
    seed=0,
)
del _seviri_kwargs

In [7]:
# Network sized from the config: input channels = number of x_vars,
# num_attr = number of x_features, output channels = number of y_vars.
model = ConvResNet(
    input_channels=len(config.x_vars),
    num_attr=len(config.x_features),
    output_channels=len(config.y_vars),
)

In [9]:

# Stop training when the monitored "val_loss" stops improving (EarlyStopping
# defaults otherwise). NOTE(review): assumes LitEstimatorPoint logs "val_loss".
early_stopping = EarlyStopping('val_loss')

# log_model=True makes the logger upload the ModelCheckpoint files to W&B.
wandb_logger = WandbLogger(project="SIS_point_estimation", log_model=True)

if rank_zero_only.rank == 0:  # only update the wandb.config on the rank 0 process
    wandb_logger.experiment.config.update(vars(config))

# save_top_k=-1 keeps every checkpoint written (no pruning):
# mc_station writes every 10 epochs, mc_sarah every epoch.
mc_station = ModelCheckpoint(every_n_epochs=10, save_top_k=-1)
mc_sarah = ModelCheckpoint(every_n_epochs=1, save_top_k=-1)

# Trainer for the SARAH-target stage. max_time uses Lightning's "DD:HH:MM:SS"
# format, i.e. a 2-hour wall-clock cap on top of max_epochs=35.
trainer_sarah = Trainer(
    logger=wandb_logger,
    accelerator=config.ACCELERATOR,
    devices=config.DEVICES,
    min_epochs=1,
    max_epochs=35,
    precision=config.PRECISION,
    log_every_n_steps=500,
    callbacks=[early_stopping, mc_sarah],
    max_time="00:02:00:00",
)

/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/lightning/pytorch/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/lightning/lib/python3.9/si ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# batch_size=None disables DataLoader's automatic batching: each item the dataset
# yields is passed through unchanged (presumably already a batch of
# patches_per_image patches — confirm in SeviriDataset).
train_dataloaders = DataLoader(train_dataset, batch_size=None, num_workers=0, shuffle=True)
valid_dataloaders = DataLoader(valid_dataset, batch_size=None, num_workers=0, shuffle=False)

In [None]:

# Stage 1: train the model on the SEVIRI/SARAH datasets built above.
estimator = LitEstimatorPoint(config=config, model=model, learning_rate=1e-4)
trainer_sarah.fit(
    estimator,
    train_dataloaders=train_dataloaders,
    val_dataloaders=valid_dataloaders,
    # To resume from a checkpoint, pass e.g.:
    # ckpt_path='/scratch/snx3000/kschuurm/irradiance_estimation/train/SIS_point_estimation/tt2pie1v/checkpoints/epoch=30-step=4020.ckpt'
)

In [None]:
# Ground-station codes; one GroundstationDataset is built per station.
stations = ['CAB', 'CAR', 'CEN', 'MIL', 'NOR', 'PAL', 'PAY', 'TAB', 'TOR', 'VIS']

# Derive the station patch size from the config instead of repeating the literal
# 15, so it cannot silently diverge from the SEVIRI training patches. The station
# dataset is given a single int here, so require a square patch.
assert config.patch_size["x"] == config.patch_size["y"], "station patches must be square"
station_patch_size = config.patch_size["x"]

# NOTE(review): y_vars is passed before x_vars positionally — confirm this matches
# GroundstationDataset's parameter order.
test_datasets = [
    GroundstationDataset(
        nm,
        config.y_vars,
        config.x_vars,
        config.x_features,
        patch_size=station_patch_size,
        transform=config.transform,
        target_transform=config.target_transform,
    )
    for nm in stations
]

In [None]:
test_dataset = torch.utils.data.ConcatDataset(test_datasets)

train_ds, valid_ds = random_split(test_dataset, [0.7, 0.3])

train_dl = DataLoader(train_ds, batch_size=2048, shuffle=True, num_workers= 5)
valid_dl = DataLoader(valid_ds, batch_size=2048, shuffle=False, num_workers=5)

In [None]:
estimator = LitEstimatorPoint(
    model=model,
    learning_rate=0.001,
    config=config,
)
trainer.fit(
    estimator, train_dataloaders=train_dl, val_dataloaders=valid_dl
)

In [None]:

test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=10000, shuffle=False, num_workers=0
)
trainer.test(dataloaders=test_dataloader)
wandb.finish()