In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch, wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger


from dataset.dataset import MSGDataModule, MSGDataModulePoint
from dataset.normalization import MinMax
from models.FNO import FNO2d
from models.ConvResNet_Jiang import ConvResNet
from models.LightningModule import LitEstimator, LitEstimatorPoint


# from pytorch_lightning.callbacks import DeviceStatsMonitor
from train import config
from utils.plotting import best_worst_plot, prediction_error_plot

In [None]:
# Alternative setup (currently disabled): FNO2d model with the MSGDataModule
# data module, sized from the shared training config.
# model = FNO2d(
#     modes=(16, 16),
#     input_channels=config.INPUT_CHANNELS,
#     output_channels=config.OUTPUT_CHANNELS,
#     channels=10,
# )
# wandb_logger = WandbLogger(project="SIS_estimation")
# dm = MSGDataModule(
#     batch_size=config.BATCH_SIZE,
#     num_workers=config.NUM_WORKERS,
#     patch_size=config.INPUT_SIZE,
# )

# Active setup: point data module restricted to the single input variable
# 'channel_1', with MinMax normalization on both inputs and targets.
dm = MSGDataModulePoint(
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    patch_size=config.INPUT_SIZE,
    x_vars=['channel_1'],
    transform=MinMax(),
    target_transform=MinMax(),
)

# NOTE(review): the meaning of num_attr=5 is defined in
# models.ConvResNet_Jiang and is not visible here — TODO document why 5.
model = ConvResNet(num_attr=5)

# Lightning wrapper that receives the model, the learning rate from config
# and the data module.
estimator = LitEstimatorPoint(
    model=model,
    learning_rate=config.LEARNING_RATE,
    dm=dm,
)

In [None]:
# Weights & Biases logger; runs are recorded under the "SIS_estimation" project.
wandb_logger = WandbLogger(project="SIS_estimation")

# Trainer configured from the shared training config.
# NOTE(review): per the Lightning docs, fast_dev_run=True limits the run to a
# single batch and suppresses loggers — confirm this debug flag is still
# wanted, since later cells call trainer.logger.log_image.
trainer = Trainer(
    profiler="simple",
    fast_dev_run=True,
    # callbacks=[DeviceStatsMonitor(cpu_stats=True)],  # disabled; needs the DeviceStatsMonitor import
    num_sanity_val_steps=2,
    logger=wandb_logger,
    accelerator=config.ACCELERATOR,
    devices=config.DEVICES,
    min_epochs=config.MIN_EPOCHS,
    max_epochs=config.MAX_EPOCHS,
    precision=config.PRECISION,
    log_every_n_steps=200,
)

In [None]:
# Train the estimator. The data module is passed through the dedicated
# `datamodule` keyword rather than `train_dataloaders`, which is meant for
# plain dataloaders/iterables.
trainer.fit(model=estimator, datamodule=dm)

In [None]:

# Run inference over the whole validation loader. `trainer.predict` returns one
# entry per batch, and each entry is unpacked as a (y_hat, y) pair. All batches
# are concatenated so the best/worst indices below are global sample indices;
# taking only element [0] would restrict the ranking to the first batch and
# make the batch-index arithmetic always evaluate to 0.
batch_outputs = trainer.predict(model=estimator, dataloaders=dm.val_dataloader())
y_hat = torch.cat([batch_y_hat for batch_y_hat, _ in batch_outputs], dim=0)
y = torch.cat([batch_y for _, batch_y in batch_outputs], dim=0)

# Per-sample relative squared error, averaged over dims 1 and 2.
# NOTE(review): samples where y contains exact zeros produce inf/nan here and
# can dominate argmin/argmax — confirm whether the targets can be zero.
error = torch.mean((y - y_hat) ** 2 / y ** 2, dim=(1, 2))
idxmin = error.argmin()
idxmax = error.argmax()

# Index of the validation batch that contains each extreme sample
# (integer floor division instead of float division + torch.floor).
idxminbatch = int(idxmin) // dm.batch_size
idxmaxbatch = int(idxmax) // dm.batch_size


In [None]:
# Best patch: fetch the batch holding the lowest-error sample and plot its
# target against the matching prediction.
_, y_array = dm.val_dataset.get_xarray_batch(idxminbatch)
fig1 = best_worst_plot(
    y_array.isel(sample=int(idxmin % dm.batch_size)),
    y_hat[idxmin, :, :],
    loss=error[idxmin],
    metric='RelativeSquaredError',
    best=True,
)

# Worst patch: same procedure for the highest-error sample. The prediction
# must be indexed with idxmax (it previously reused idxmin, so the "worst"
# figure compared the worst target against the best prediction).
_, y_array = dm.val_dataset.get_xarray_batch(idxmaxbatch)
fig2 = best_worst_plot(
    y_array.isel(sample=int(idxmax % dm.batch_size)),
    y_hat[idxmax, :, :],
    loss=error[idxmax].numpy(),
    metric='RelativeSquaredError',
    best=False,
)

# Send both figures to the trainer's logger.
trainer.logger.log_image(key='Best patch', images=[fig1])
trainer.logger.log_image(key='Worst patch', images=[fig2])


In [None]:
# Build the prediction-error figure from the targets and predictions computed
# above, then send it to the trainer's logger under the key 'Prediction error'.
fig3 = prediction_error_plot(y, y_hat)
trainer.logger.log_image(key='Prediction error', images=[fig3])