In [1]:
# Reload edited project modules (e.g. the local `data` module) automatically
# before each cell executes, so changes to them are picked up without a restart.
%reload_ext autoreload
%autoreload 2
# Turn off jedi-based tab completion in the IPython completer.
%config Completer.use_jedi=False

In [2]:
import torch
import pytorch_lightning as pl
from data import *
# from pytorch_msssim import ssim
# from kornia.losses import PSNRLoss, SSIM
from pytorch_lightning.metrics.functional import psnr
from pytorch_lightning.metrics.regression import SSIM
from pytorch_lightning.loggers import TensorBoardLogger

In [4]:
class DeblurModelBase(pl.LightningModule):
    """Shared training / validation / optimizer logic for deblurring models.

    Subclasses only need to implement ``forward``; this base class supplies an
    MSE loss, an Adam optimizer, loss logging, and PSNR/SSIM metrics during
    validation.

    Args:
        data_module: Data module that must expose ``transforms`` (a sequence of
            transforms passed to ``augment_image_pair``) and
            ``denormalize_func`` (applied to outputs/targets before metrics).
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, data_module: DeblurDataModule, lr: float = 0.001):
        # nn.Module / LightningModule internals must be initialised before any
        # attribute assignment: assigning a submodule such as MSELoss first
        # raises "cannot assign module before Module.__init__() call".
        super().__init__()
        self.data_module  = data_module
        self.transforms   = self.data_module.transforms
        self.loss_func    = torch.nn.MSELoss()
        # (metric callable, log name) pairs evaluated in validation_step.
        self.metrics      = [(psnr, 'PSNR'), (SSIM(kernel_size=5), 'SSIM')]
        # Not used in this class; presumably set by subclasses — TODO confirm.
        self.model_config = None
        self.lr           = lr

    def training_step(self, batch, batch_idx):
        """Augment an (input, target) pair, compute the loss, and log it.

        Returns:
            A ``pl.TrainResult`` wrapping the loss to minimise.
        """
        x, y   = augment_image_pair(batch, self.transforms)
        out    = self(x)
        loss   = self.loss_func(out, y)
        result = pl.TrainResult(loss)

        # Log results to progress bar and logger. The Result API keyword is
        # `prog_bar` (same as in validation_step).
        result.log('train_loss', loss, on_step=True, on_epoch=True,
                   prog_bar=True, logger=True, sync_dist=True)

        return result

    def validation_step(self, batch, batch_idx):
        """Compute validation loss plus image-quality metrics for one batch.

        Only ``self.transforms[0]`` is applied to the batch (assumed to be the
        resize transform — TODO confirm against DeblurDataModule). Metrics are
        computed on outputs and targets passed through ``denormalize_func`` and
        clamped to [0, 1].

        Returns:
            A ``pl.EvalResult`` that checkpoints on the validation loss and
            logs ``val_loss`` plus one entry per metric.
        """
        x, y = augment_image_pair(batch, [self.transforms[0]])  # Only apply resize
        out  = self(x)
        loss = self.loss_func(out, y)
        result = pl.EvalResult(checkpoint_on=loss)
        log_dict = {'val_loss': loss}

        # Calculate metrics like PSNR, SSIM, etc. on denormalised images.
        # NOTE(review): no data_range is passed although inputs are clamped to
        # [0, 1]; check that each metric's default range handling is intended.
        if self.metrics:
            out = self.data_module.denormalize_func(out).clamp(0.0, 1.0)
            y   = self.data_module.denormalize_func(y).clamp(0.0, 1.0)
            for metric_fn, metric_name in self.metrics:
                log_dict[metric_name] = metric_fn(out, y)

        result.log_dict(log_dict, on_step=False, on_epoch=True,
                        prog_bar=True, logger=True, sync_dist=True)

        return result

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        """Map an input batch to a prediction; must be implemented by subclasses."""
        raise NotImplementedError

In [5]:
class SampleModel(DeblurModelBase):
    """Toy model that multiplies its input by a single learnable scalar.

    Presumably a placeholder for exercising the ``DeblurModelBase`` training
    and validation plumbing — TODO confirm intended use.
    """

    def __init__(self, data_module: DeblurDataModule, lr: float = 0.001):
        super().__init__(data_module, lr)
        # One trainable scalar, initialised to 1.0 so the model starts as identity.
        self.a = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """Return ``x`` scaled by the learnable parameter ``a``."""
        scale = self.a
        return scale * x