In [None]:
import torch
import torch.nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import pytorch_lightning as pl

from torch_enhance.datasets import BSDS300, Set14, Set5
from torch_enhance.models import MYSRCNN, SRCNN, MYSRCNN2, MYSRCNN3
from torch_enhance import metrics

from torchvision.utils import save_image

import cv2
import matplotlib.pyplot as plt

In [None]:
class Module(pl.LightningModule):
    """LightningModule wrapper that trains an image-restoration model with MSE loss.

    Each batch is an ``(lr, hr)`` pair: the wrapped ``model`` is applied to ``lr``
    and its output ``sr`` is compared against ``hr``. The train/val/test steps
    share one implementation that computes MSE loss plus MAE and PSNR metrics
    and logs them under a stage prefix (``train_*``, ``val_*``, ``test_*``).

    Args:
        model: the network to optimize; called as ``model(lr)``.
        learning_rate: Adam learning rate. Defaults to ``1e-3``, the value that
            was previously hard-coded, so ``Module(model)`` behaves as before.
    """

    def __init__(self, model, learning_rate=1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """Optimize all parameters of this module with Adam."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def _shared_step(self, batch, stage):
        """Run one ``(lr, hr)`` batch, log loss/MAE/PSNR as ``{stage}_*``, return the loss.

        Args:
            batch: a 2-item sequence ``(lr, hr)`` of input and target tensors.
            stage: log-key prefix, one of ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            The mean-squared-error loss between the model output and ``hr``.
        """
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")

        # metrics
        mae = metrics.mae(sr, hr)
        psnr = metrics.psnr(sr, hr)

        # Logs (keys are identical to the previous per-stage implementations)
        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_mae", mae)
        self.log(f"{stage}_psnr", psnr)

        return loss

    def training_step(self, batch, batch_idx):
        """Training step; the returned loss is what Lightning back-propagates."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Validation step; logs ``val_loss``, ``val_mae`` and ``val_psnr``."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Test step; logs ``test_loss``, ``test_mae`` and ``test_psnr``."""
        return self._shared_step(batch, "test")

In [None]:
# Hyper parameters
SEED = 42
# NOTE(review): scale_factor=1 presumably gives LR inputs at the same resolution
# as the HR targets (restoration rather than upscaling) — confirm this is intended.
scale_factor = 1
channels = 3
TRAIN_BATCH_SIZE = 32
NUM_WORKERS = 4

# Seed Python/NumPy/torch RNGs so weight init and data shuffling are reproducible.
pl.seed_everything(SEED)

# Setup dataloaders
train_dataset = BSDS300(scale_factor=scale_factor)
val_dataset = Set14(scale_factor=scale_factor)
test_dataset = Set5(scale_factor=scale_factor)
# Shuffle only the training set: without it every epoch iterates BSDS300 in the
# same fixed order, which hurts SGD. Val/test order does not affect results.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
val_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=NUM_WORKERS)
test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=NUM_WORKERS)

# Define model
model = MYSRCNN3(scale_factor, channels)
module = Module(model)

In [None]:
# Train
MAX_EPOCHS = 200

# "high" lets float32 matmuls use faster reduced-precision internal math
# (e.g. TF32) where the hardware supports it.
torch.set_float32_matmul_precision("high")
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    # "auto" picks a GPU when one is available and falls back to CPU otherwise,
    # whereas a hard-coded "gpu" makes the notebook fail on CPU-only machines.
    accelerator="auto",
)
trainer.fit(
    module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
# Evaluate the weights as they are at the end of fit() on the held-out Set5 split.
trainer.test(module, dataloaders=test_dataloader)

In [None]:
# Uncomment to save the current model
# torch.save(model.state_dict(), "model_path.pt")

In [None]:
# Uncomment to load model from path
# model = MYSRCNN3(scale_factor, channels)
# model.load_state_dict(torch.load("SRCNN_weight_400_3.pt"))
# model.eval()