In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb
import xarray
from dataset.dataset import ImageDataset, valid_test_split
from dataset.normalization import MinMax
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment
from lightning.pytorch.utilities import rank_zero_only
from models.ConvResNet_Jiang import ConvResNet, ConvResNet_dropout
from models.LightningModule import LitEstimator, LitEstimatorPoint
from tqdm import tqdm

# from pytorch_lightning.pytorch.callbacks import DeviceStatsMonitor
from utils.plotting import best_worst_plot, prediction_error_plot

In [2]:
from types import SimpleNamespace

# Run configuration. Exposed as attributes (config.batch_size, ...) to the
# cells below and pushed to wandb via vars(config) in the trainer cell.
config = SimpleNamespace(
    batch_size=2048,
    # Patch window size and sampling strides (x / y).
    patch_size=dict(x=15, y=15, stride_x=2, stride_y=2),
    # Image input variables: channel_1 ... channel_11, then DEM.
    x_vars=[f"channel_{i}" for i in range(1, 12)] + ["DEM"],
    # Prediction target(s).
    y_vars=["SIS"],
    # Non-image attributes; their count becomes the model's num_attr.
    x_features=["dayofyear", "lat", "lon", "SZA", "AZI"],
    # Independent normaliser instances for inputs and targets.
    transform=MinMax(),
    target_transform=MinMax(),
)

In [3]:
# Timestamps to use, loaded from a precomputed index file.
# NOTE(review): absolute cluster path — consider making this configurable.
TIMEINDEX_PATH = "/scratch/snx3000/kschuurm/ZARR/idxnotnan.npy"
TRAIN_YEAR = 2016
VALID_YEAR = 2017
TIME_SUBSAMPLE = 5  # keep every 5th timestamp of each split

timeindex = pd.DatetimeIndex(np.load(TIMEINDEX_PATH))
traintimeindex = timeindex[timeindex.year == TRAIN_YEAR]
# NOTE(review): the SECOND element returned by valid_test_split is used for
# validation. If the function returns (valid, test) in that order, this
# validates on the test split — confirm this is intended.
_, validtimeindex = valid_test_split(timeindex[timeindex.year == VALID_YEAR])

# Constructor arguments shared by the training and validation datasets.
common_dataset_kwargs = dict(
    x_vars=config.x_vars,
    y_vars=config.y_vars,
    x_features=config.x_features,
    transform=config.transform,
    target_transform=config.target_transform,
)

train_dataset = ImageDataset(
    **common_dataset_kwargs,
    patch_size=config.patch_size,
    timeindices=traintimeindex[::TIME_SUBSAMPLE],
    random_sample=0.5,
    batch_in_time=6,
)
# Validation keeps the training patch size but uses a coarser stride of 5.
valid_dataset = ImageDataset(
    **common_dataset_kwargs,
    patch_size={
        "x": config.patch_size["x"],
        "y": config.patch_size["y"],
        "stride_x": 5,
        "stride_y": 5,
    },
    timeindices=validtimeindex[::TIME_SUBSAMPLE],
    random_sample=None,
    batch_in_time=5,
)


model = ConvResNet(
    num_attr=len(config.x_features),
    input_channels=len(config.x_vars),
    output_channels=len(config.y_vars),
)

# Same loader settings for train and validation.
# NOTE(review): shuffle=False / num_workers=0 on the training loader —
# presumably sampling order is handled inside ImageDataset
# (random_sample / batch_in_time); confirm before enabling workers or shuffling.
loader_kwargs = dict(batch_size=config.batch_size, shuffle=False, num_workers=0)
train_dataloader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **loader_kwargs)

In [4]:
# Compute related
ACCELERATOR = "gpu"
DEVICES = -1  # -1: let Lightning use all available devices
NUM_NODES = 1
# STRATEGY = "ddp"
# "16-mixed" is the current name for 16-bit AMP; the bare "16" spelling is
# deprecated and makes Lightning emit a warning.
PRECISION = "16-mixed"

wandb_logger = WandbLogger(project="SIS_point_estimation", log_model=True)

# Only the rank-0 process updates wandb.config, so multi-process runs do not
# push the config more than once.
if rank_zero_only.rank == 0:
    wandb_logger.experiment.config.update(vars(config))

trainer = Trainer(
    # profiler="simple",
    # fast_dev_run=True,
    # num_sanity_val_steps=2,
    logger=wandb_logger,
    accelerator=ACCELERATOR,
    devices=DEVICES,
    num_nodes=NUM_NODES,
    min_epochs=1,
    max_epochs=1,
    precision=PRECISION,
    log_every_n_steps=500,
    val_check_interval=0.125,  # float: run validation every 1/8 of an epoch
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkrschuurman[0m. Use [1m`wandb login --relogin`[0m to force relogin


/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/lightning/fabric/connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/lightning/lib/python3.9/si ...
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Wrap the network in the Lightning module and run the training loop,
# validating on valid_dataloader as scheduled by the trainer.
estimator = LitEstimatorPoint(
    config=config,
    model=model,
    learning_rate=1e-3,
)
trainer.fit(
    model=estimator,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/lightning/lib/python3.9/si ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ConvResNet       | 6.4 M 
1 | metric        | MeanSquaredError | 0     
2 | other_metrics | MetricCollection | 0     
---------------------------------------------------
6.4 M     Trainable params
0    

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/scratch/snx3000/kschuurm/lightning/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
import traceback

from dataset.station_dataset import GroundstationDataset

# Evaluate on ground-station data; always close the wandb run afterwards.
try:
    stations = ['CAB', 'CAR', 'CEN', 'MIL', 'NOR', 'PAL', 'PAY', 'TAB', 'TOR', 'VIS']
    test_datasets = [
        GroundstationDataset(
            nm,
            config.y_vars,
            config.x_vars,
            config.x_features,
            # Reuse the training patch size (assumes square patches, x == y).
            patch_size=config.patch_size["x"],
            transform=config.transform,
            target_transform=config.target_transform,
        )
        for nm in stations
    ]
    test_dataset = torch.utils.data.ConcatDataset(test_datasets)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=10000, shuffle=False, num_workers=0
    )
    # NOTE(review): no model is passed, so Lightning decides which weights from
    # the preceding fit are tested (possibly a saved checkpoint) — confirm.
    trainer.test(dataloaders=test_dataloader)
except Exception:
    # Best-effort evaluation: report the full traceback instead of silently
    # swallowing it; KeyboardInterrupt etc. are no longer caught.
    print('failed')
    traceback.print_exc()
finally:
    wandb.finish()