In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch, wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger


from dataset.dataset import MSGDataModule
from models.FNO import FNO2d
from models.LightningModule import LitEstimator

# from pytorch_lightning.callbacks import DeviceStatsMonitor
from train import config
from utils.plotting import best_worst_plot, prediction_error_plot

In [10]:
# Data module: batch size, worker count and patch size all come from train.config.
dm = MSGDataModule(
    patch_size=config.INPUT_SIZE,
    num_workers=config.NUM_WORKERS,
    batch_size=config.BATCH_SIZE,
)

# FNO2d network. Channel counts for input/output come from train.config; the
# mode pair and the `channels` value are fixed in this notebook.
fno_kwargs = dict(
    input_channels=config.INPUT_CHANNELS,
    output_channels=config.OUTPUT_CHANNELS,
    modes=(16, 16),
    channels=10,
)
model = FNO2d(**fno_kwargs)

# LitEstimator receives the network, the data module and the learning rate;
# it is the object later handed to the Trainer.
estimator = LitEstimator(model=model, dm=dm, learning_rate=config.LEARNING_RATE)

/scratch/snx3000/kschuurm/pytorch/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [11]:
# Weights & Biases logger for the "SIS_estimation" project.
wandb_logger = WandbLogger(project="SIS_estimation")

# Accelerator, devices, precision and epoch bounds are read from train.config;
# the remaining options are fixed for this notebook.
trainer_kwargs = dict(
    accelerator=config.ACCELERATOR,
    devices=config.DEVICES,
    precision=config.PRECISION,
    min_epochs=config.MIN_EPOCHS,
    max_epochs=config.MAX_EPOCHS,
    logger=wandb_logger,
    log_every_n_steps=200,
    num_sanity_val_steps=2,
    profiler="simple",
    # Optional debugging aids (disabled):
    # fast_dev_run=True,
    # callbacks=[DeviceStatsMonitor(cpu_stats=True)],
)
trainer = Trainer(**trainer_kwargs)

/scratch/snx3000/kschuurm/pytorch/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/pytorch/lib/python3.9/site ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train the estimator. `dm` is a data module, so it is passed through the
# dedicated `datamodule` argument rather than being labelled as a train
# dataloader; Lightning then pulls the train/validation loaders from it.
trainer.fit(model=estimator, datamodule=dm)

/scratch/snx3000/kschuurm/pytorch/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/pytorch/lib/python3.9/site ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type                 | Params
------------------------------------------------
0 | model  | FNO2d                | 104 K 
1 | metric | RelativeSquaredError | 0     
------------------------------------------------
104 K     Trainable params
0         Non-trainable params
104 K     Total params
0.418     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:

# Run inference over the validation loader. `trainer.predict` returns one entry
# per batch, each unpackable as (y_hat, y). Gather every batch instead of only
# the first one so the best/worst search covers the whole validation set and
# the batch indices computed below are meaningful.
predictions = trainer.predict(model=estimator, dataloaders=dm.val_dataloader())
y_hat_batches, y_batches = zip(*predictions)
y_hat = torch.cat(list(y_hat_batches), dim=0)
y = torch.cat(list(y_batches), dim=0)

# Per-sample relative squared error, averaged over the two trailing dimensions.
# NOTE(review): this divides by y**2, so any zero-valued target produces
# inf/nan and will dominate argmin/argmax -- confirm targets are never zero.
error = torch.mean((y - y_hat) ** 2 / y ** 2, dim=(1, 2))
idxmin = error.argmin()
idxmax = error.argmax()

# Index of the batch containing each extreme sample (integer floor division;
# assumes the validation loader is not shuffled and uses dm.batch_size).
idxminbatch = int(idxmin) // dm.batch_size
idxmaxbatch = int(idxmax) // dm.batch_size


In [None]:
# Best patch: fetch the batch holding the lowest-error sample and plot its
# target next to the matching prediction.
_, y_array = dm.val_dataset.get_xarray_batch(idxminbatch)
fig1 = best_worst_plot(y_array.isel(sample=int(idxmin % dm.batch_size)),
                      y_hat[idxmin, :, :],
                      loss=error[idxmin],
                      metric='RelativeSquaredError',
                      best=True)

# Worst patch: same procedure for the highest-error sample. The prediction must
# be indexed with idxmax (it previously reused idxmin, so the "worst" figure
# compared the worst target against the best sample's prediction).
# NOTE(review): `loss` is passed as a tensor above and as a numpy value here;
# left unchanged because best_worst_plot's expected type is not visible.
_, y_array = dm.val_dataset.get_xarray_batch(idxmaxbatch)
fig2 = best_worst_plot(y_array.isel(sample=int(idxmax % dm.batch_size)),
                      y_hat[idxmax, :, :],
                      loss=error[idxmax].numpy(),
                      metric='RelativeSquaredError',
                      best=False)

# Upload both figures through the trainer's logger.
trainer.logger.log_image(key='Best patch', images=[fig1])
trainer.logger.log_image(key='Worst patch', images=[fig2])


In [None]:
# Build the prediction-error figure from the targets and predictions, then
# upload it through the trainer's logger.
fig3 = prediction_error_plot(y, y_hat)
trainer.logger.log_image(images=[fig3], key='Prediction error')