In [3]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before each cell executes and edits to the project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import torch
import wandb
from dataset.dataset import MSGDataModule, MSGDataModulePoint
from dataset.normalization import MinMax
from models.ConvResNet_Jiang import ConvResNet
from models.FNO import FNO2d
from models.LightningModule import LitEstimator, LitEstimatorPoint
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment

# from lightning.pytorch.callbacks import DeviceStatsMonitor
from train import config
from utils.plotting import best_worst_plot, prediction_error_plot

In [5]:
# from lightning.pytorch.callbacks import DeviceStatsMonitor
from train import config
from utils.plotting import best_worst_plot, prediction_error_plot

# --- Training hyperparameters -------------------------------------------
INPUT_SIZE = (15, 15)
# The dataset was sampled with overlapping (15, 15) patches; these are the
# overlaps along each axis.
INPUT_OVERLAP = dict(lat=12, lon=12)
LEARNING_RATE = 1e-3
BATCH_SIZE = 512
NUM_EPOCHS = 20
MIN_EPOCHS, MAX_EPOCHS = 1, 30

# --- Dataset ------------------------------------------------------------
# DATA_DIR
NUM_WORKERS = 12

# --- Compute ------------------------------------------------------------
ACCELERATOR = "gpu"
DEVICES = -1
NUM_NODES = 1
STRATEGY = "ddp_notebook"
PRECISION = 32

In [6]:
# Data module configured from the notebook's hyperparameter constants. The
# same MinMax transform class is applied to both inputs and targets.
dm = MSGDataModulePoint(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    patch_size=INPUT_SIZE,
    # input_overlap= INPUT_OVERLAP,
    x_vars=["channel_1"],
    transform=MinMax(),
    target_transform=MinMax(),
)

model = ConvResNet(num_attr=5)

# Take the learning rate from the notebook's own LEARNING_RATE constant, like
# every other hyperparameter in this cell. Previously this read
# `config.LEARNING_RATE`, so editing LEARNING_RATE in the notebook silently
# had no effect on training.
estimator = LitEstimatorPoint(
    model=model,
    learning_rate=LEARNING_RATE,
    dm=dm,
)

/scratch/snx3000/kschuurm/pytorch/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [7]:
wandb_logger = WandbLogger(project="SIS_estimation")

# Smoke-test configuration: with `fast_dev_run=True` Lightning runs the
# requested loop on a single batch and suppresses logging and checkpointing,
# so the logger passed here is not written to in this mode.
trainer = Trainer(
    # -- hardware --
    accelerator=ACCELERATOR,
    devices=DEVICES,
    precision=PRECISION,
    # num_nodes=NUM_NODES,
    # strategy=STRATEGY,
    # -- run length --
    fast_dev_run=True,
    num_sanity_val_steps=2,
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS,
    # -- diagnostics --
    profiler="simple",
    logger=wandb_logger,
    log_every_n_steps=200,
    # callbacks=[DeviceStatsMonitor(cpu_stats=True)],
)

/scratch/snx3000/kschuurm/pytorch/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/pytorch/lib/python3.9/site ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


In [None]:
# Hand the LightningDataModule over through the dedicated `datamodule`
# argument instead of `train_dataloaders`; Lightning then takes the loaders
# from the data module's hooks.
trainer.fit(model=estimator, datamodule=dm)

/scratch/snx3000/kschuurm/pytorch/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/pytorch/lib/python3.9/site ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                 | Params
-------------------------------------------------------
0 | model         | ConvResNet           | 6.4 M 
1 | metric        | RelativeSquaredError | 0     
2 | other_metrics | MetricCollection     | 0     
-------------------------------------------------------
6.4 M     Trainable params
0         Non-trainable params
6.4 M     Total params
25.595    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# `trainer.predict` returns one entry per batch, and each entry unpacks into
# (y_hat, y). Concatenate all batches so that `y`, `y_hat` and the indices
# below refer to every predicted sample rather than only the first batch
# (with only the first batch, the batch indices were always 0).
batch_outputs = trainer.predict(model=estimator, dataloaders=dm.val_dataloader())
y_hat_batches, y_batches = zip(*batch_outputs)
y_hat = torch.cat(y_hat_batches)
y = torch.cat(y_batches)

# Per-sample relative squared error. The denominator is clamped away from
# zero because a target of exactly 0 would otherwise produce inf/NaN, and a
# NaN would corrupt argmin/argmax. Assumes `y` is a floating-point tensor.
squared_target = (y**2).clamp_min(torch.finfo(y.dtype).eps)
relative_sq_error = (y - y_hat) ** 2 / squared_target
# Average over every non-batch dimension, whatever the output rank is
# (equivalent to `dim=(1, 2)` for 3-D outputs).
error = relative_sq_error.reshape(relative_sq_error.shape[0], -1).mean(dim=1)

# Best and worst predicted samples, and the batch each one came from.
idxmin = error.argmin()
idxmax = error.argmax()
idxminbatch = int(idxmin) // dm.batch_size
idxmaxbatch = int(idxmax) // dm.batch_size

In [None]:
# Build a figure from the targets and predictions with the project's
# `prediction_error_plot` helper, then hand it to the trainer's logger under
# the key "Prediction error".
# NOTE(review): presumably the helper returns a figure object that the
# logger's `log_image` accepts — confirm against utils.plotting.
fig = prediction_error_plot(y, y_hat)
trainer.logger.log_image(key="Prediction error", images=[fig])