In [None]:
# pip install lpips

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=0.4.0->lpips)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=0.4.0->lpips)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=0.4.0->lpips)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=0.4.0->lpips)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=0.4.0->lpips)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch>=0.4.0->lpips)
  Using cached nvidia_cufft_cu12-11

# Imports

In [None]:
import os
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose

import torchvision.models as models
from torchvision.models import vgg19
from torchvision.transforms import Normalize
from torch.cuda.amp import GradScaler, autocast
from numba import cuda

from skimage.metrics import structural_similarity as SSIM

import lpips

from google.colab import drive

In [None]:
# Mount Google Drive so the plates_data folders used below are reachable.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


# Downsample HR images to create LR images

In [None]:
def downsample_image(image_path, output_path, scale_factor=4):
    """Write a bicubic-downsampled copy of an image in JPEG format.

    Args:
        image_path: path of the source (high-resolution) image.
        output_path: destination path. The file is always encoded as JPEG,
            whatever extension the path carries.
        scale_factor: integer factor by which width and height are divided.
    """
    with Image.open(image_path) as img:
        # Pillow's JPEG writer rejects modes such as 'P', 'LA' and 'RGBA', so
        # normalise everything to RGB (the dataset loader converts to RGB anyway).
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Never resize to a 0-pixel dimension, which Pillow cannot handle.
        new_size = (max(1, img.width // scale_factor), max(1, img.height // scale_factor))
        lr_img = img.resize(new_size, Image.BICUBIC)
        lr_img.save(output_path, 'JPEG')

In [None]:
# Paths (local alternative: 'data/train/HR/' and 'data/train/LR/')
hr_folder = '/content/drive/My Drive/plates_data/HR'
lr_folder = '/content/drive/My Drive/plates_data/LR'
os.makedirs(lr_folder, exist_ok=True)

# Same extensions the SRDataset loader accepts: an HR image skipped here would
# have no LR counterpart, and the loader pairs LR/HR files by sorted position.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Downsample HR images to create LR images (each LR file keeps its HR filename).
hr_image_names = sorted(
    name for name in os.listdir(hr_folder) if name.lower().endswith(IMAGE_EXTENSIONS)
)
for count, filename in enumerate(hr_image_names, start=1):
    print(count, filename)
    downsample_image(
        os.path.join(hr_folder, filename),
        os.path.join(lr_folder, filename),
        scale_factor=4,
    )

1 1c922c98065a78.jpg
2 2b16c327561485.jpg
3 734b3bef43666c.jpg
4 342190cf1ce0b7.jpg
5 6fad048bbe77a0.jpg
6 6abbe67d883a47.jpg
7 29e9db79b78874.jpg
8 22dc35cfbd14c6.jpg
9 67105b07c148a3.jpg
10 0c7e8f2c6b3912.jpg
11 89ee46c137c513.jpg
12 64c56b57295252.jpg
13 2a851c5e267456.jpg
14 93c7d04f62e718.jpg
15 1bcd1cadaf6509.jpg
16 5f858781b96ab1.jpg
17 65b09397bdddf3.jpg
18 4072a4b7eac635.jpg
19 8f6068906b9ea1.jpg
20 16ba80fb501288.jpg
21 632e00afd37b16.jpg
22 26c35dfb79b9d0.jpg
23 343102b105ed6f.jpg
24 1716ad8450b0c7.jpg
25 6ffcbdcd75ffb4.jpg
26 7ce544eed8ab14.jpg
27 8ec4af5c6027e9.jpg
28 3b8d08bd62aff7.jpg
29 753a1c8859efec.jpg
30 8aaea606b62b45.jpg
31 67c6ac36263c84.jpg
32 0be33b642a7565.jpg
33 55095b6a3d0dd9.jpg
34 0f5beb8cb24494.jpg
35 4eb6b3f8b9f076.jpg
36 4fe57c69ffd6eb.jpg
37 18ea325f8446ee.jpg
38 b4a1585dd817d9.jpg
39 bdfe54bd62181b.jpg
40 de2097605e1da9.jpg
41 f5c34510870189.jpg
42 a092938179f40f.jpg
43 a04c890f6b936a.jpg
44 d32fe872ac2d3b.jpg
45 9cdce10764ea38.jpg
46 b730d7641dcd35.j

# SR Generator (GSR)

In [None]:
class ResidualBlock(nn.Module):
    """conv3x3 -> PReLU -> conv3x3, added back onto the block input.

    Channel count and spatial size are unchanged (same in/out channels,
    padding=1), so blocks can be stacked.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        # Attribute names form the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)

    def forward(self, x):
        hidden = self.prelu(self.conv1(x))
        residual = self.conv2(hidden)
        return x + residual

class GSR(nn.Module):
    """Super-resolution generator: bicubic upsampling, then a residual CNN refinement.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the output image.
        num_residuals: number of ResidualBlocks in the 64-channel trunk.
        scale_factor: spatial upsampling factor. Defaults to 4, matching the
            ``scale_factor=4`` used to build the LR images.

    The output is clamped to [0, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, num_residuals=5, scale_factor=4):
        super(GSR, self).__init__()
        self.encoder = nn.Conv2d(in_channels, 64, kernel_size=5, padding=2)
        self.resnet = nn.Sequential(*[ResidualBlock(64) for _ in range(num_residuals)])
        self.decoder = nn.Conv2d(64, out_channels, kernel_size=5, padding=2)
        # align_corners=False is PyTorch's documented default; stated explicitly
        # so the interpolation behaviour does not depend on an unset argument.
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bicubic', align_corners=False)

    def forward(self, x):
        x = self.upsample(x)
        x = self.encoder(x)
        x = self.resnet(x)
        x = self.decoder(x)
        return torch.clamp(x, 0, 1)

# Discriminators (Dx and Dy)

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a 1-channel logit map.

    Every convolution uses stride 1 and padding 1, so the output map has the
    same height and width as the input.
    """

    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        # Build in the same order as the Sequential indices (net.0 ... net.5) so
        # state_dict keys stay compatible with saved weights.
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers.append(self._block(c_in, c_out, 3, 1))
        layers.append(nn.Conv2d(512, 1, kernel_size=3, padding=1))
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride):
        # conv -> batch norm -> LeakyReLU(0.2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.net(x)

# LR Generator (GLR)

In [None]:
class GLR(nn.Module):
    """Low-resolution generator: maps an SR image back to LR size for the cycle loss.

    The overall spatial downsampling factor is 4 (two stride-2 stages). This
    mirrors the x4 upsampling in GSR and the ``scale_factor=4`` used to create
    the LR images, so ``GLR(GSR(lr))`` has the same size as ``lr``. The third
    block keeps resolution (stride 1); a stride-2 there would give a factor of 8.
    Parameter shapes are unaffected by stride, so existing weights still load.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(GLR, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),  # /2
            nn.LeakyReLU(0.2),
            self._block(64, 128, 3, 2),   # /4
            self._block(128, 256, 3, 1),  # resolution kept: total factor 4
            nn.Conv2d(256, out_channels, kernel_size=3, padding=1),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride):
        # conv (padding=1) -> batch norm -> LeakyReLU(0.2)
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.net(x)

# Loss Functions

In [None]:
class SRResCycGANLosses:
    """Loss terms for training the SR-ResCycGAN generators and discriminators.

    Args:
        vgg: frozen feature extractor for the perceptual loss. It is called on
            images normalised with ImageNet mean/std.

    Discriminator convention: real samples -> label 1, generated samples -> label 0.
    """

    # ImageNet per-channel statistics used to normalise images before the VGG.
    VGG_MEAN = (0.485, 0.456, 0.406)
    VGG_STD = (0.229, 0.224, 0.225)

    def __init__(self, vgg):
        self.vgg = vgg
        self.l1_loss = nn.L1Loss()
        self.tv_loss = TVLoss()
        # Owned by the instance rather than read from notebook globals.
        self.vgg_mean = torch.tensor(self.VGG_MEAN).view(1, 3, 1, 1)
        self.vgg_std = torch.tensor(self.VGG_STD).view(1, 3, 1, 1)

    def _vgg_normalize(self, x):
        return (x - self.vgg_mean.to(x.device)) / self.vgg_std.to(x.device)

    def perceptual_loss(self, sr, hr):
        """L1 distance between VGG features of ``sr`` and ``hr``."""
        return self.l1_loss(self.vgg(self._vgg_normalize(sr)), self.vgg(self._vgg_normalize(hr)))

    def gan_loss(self, sr_pred, hr_pred):
        """Discriminator loss: generated logits -> 0, real logits -> 1."""
        return F.binary_cross_entropy_with_logits(sr_pred, torch.zeros_like(sr_pred)) + \
               F.binary_cross_entropy_with_logits(hr_pred, torch.ones_like(hr_pred))

    def adversarial_loss(self, sr_pred):
        """Generator loss: push the discriminator to score generated images as real."""
        return F.binary_cross_entropy_with_logits(sr_pred, torch.ones_like(sr_pred))

    def total_variation_loss(self, sr):
        return self.tv_loss(sr)

    def content_loss(self, sr, hr):
        return self.l1_loss(sr, hr)

    def cyclic_loss(self, lr_recon, lr):
        return self.l1_loss(lr_recon, lr)

    @staticmethod
    def _match_size(x, ref):
        # Resample x to ref's spatial size only when they differ.
        if x.shape[2:] == ref.shape[2:]:
            return x
        return F.interpolate(x, size=ref.shape[2:], mode='bicubic', align_corners=False)

    def total_loss(self, sr, hr, lr_recon, lr, sr_pred, hr_pred):
        """Weighted generator objective.

        Args:
            sr: generator_sr output.
            hr: ground-truth HR batch (resampled to ``sr``'s size if needed).
            lr_recon: generator_lr(sr), the cycle reconstruction.
            lr: input LR batch (resampled to ``lr_recon``'s size if needed).
            sr_pred: discriminator logits for ``sr``.
            hr_pred: discriminator logits for ``hr``. Accepted for call
                compatibility; the generator's adversarial term only needs
                ``sr_pred``.
        """
        hr_resized = self._match_size(hr, sr)
        lr_resized = self._match_size(lr, lr_recon)

        l_per = self.perceptual_loss(sr, hr_resized)
        # Generator must try to fool the discriminator; reusing the
        # discriminator's own loss here would make both players minimise the
        # same objective.
        l_gan = self.adversarial_loss(sr_pred)
        l_tv = self.total_variation_loss(sr)
        l_content = self.content_loss(sr, hr_resized)
        l_cyclic = self.cyclic_loss(lr_recon, lr_resized)

        return l_per + l_gan + l_tv + 10 * l_content + 10 * l_cyclic

In [None]:
class TVLoss(nn.Module):
    """Total-variation regulariser on a (batch, C, H, W) tensor.

    Returns ``weight * 2 * (mean_sq_dh + mean_sq_dw) / batch``, where each mean
    divides the summed squared neighbour differences by the per-sample element
    count (C * H' * W') of the difference tensor.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        n = x.size()[0]
        diff_h = x[:, :, 1:, :] - x[:, :, :x.size()[2] - 1, :]
        diff_w = x[:, :, :, 1:] - x[:, :, :, :x.size()[3] - 1]
        h_term = torch.pow(diff_h, 2).sum() / self._tensor_size(diff_h)
        w_term = torch.pow(diff_w, 2).sum() / self._tensor_size(diff_w)
        return self.tv_loss_weight * 2 * (h_term + w_term) / n

    def _tensor_size(self, t):
        channels, height, width = t.size()[1:4]
        return channels * height * width

# Training the Model

In [None]:
# Module-level scaler kept so existing references to `scaler` still resolve;
# train_srrescycgan creates its own scaler on every call.
scaler = GradScaler()


# Define the training loop
def train_srrescycgan(generator_sr, generator_lr, discriminator_hr, discriminator_lr, data_loader,
                      optimizer_g, optimizer_d, losses, num_epochs, device=None):
    """Alternately update the generators and discriminators for ``num_epochs``.

    Each iteration has two phases:
      1. Generator step: ``losses.total_loss`` on the SR output, its cycle
         reconstruction and discriminator_hr logits, then ``optimizer_g`` steps.
      2. Discriminator step: ``losses.gan_loss(fake_pred, real_pred)`` for
         discriminator_hr (SR vs HR) plus discriminator_lr (reconstructed LR
         vs LR), then ``optimizer_d`` steps.

    Args:
        generator_sr, generator_lr, discriminator_hr, discriminator_lr: modules to train.
        data_loader: iterable yielding ``(lr, hr)`` tensor batches.
        optimizer_g: optimizer over the generator parameters.
        optimizer_d: optimizer over the discriminator parameters.
        losses: object with ``total_loss(sr, hr, lr_recon, lr, sr_pred, hr_pred)``
            and ``gan_loss(fake_pred, real_pred)``.
        num_epochs: number of passes over ``data_loader``.
        device: where batches are moved. Defaults to the device of
            ``generator_sr``'s parameters. Mixed precision is only enabled on CUDA.

    Returns:
        List of ``{'epoch', 'loss_g', 'loss_d'}`` dicts holding per-epoch mean
        losses (NaN for an epoch with no batches).
    """
    device = torch.device(device) if device is not None else next(generator_sr.parameters()).device
    use_amp = device.type == 'cuda'
    # One scaler shared by both optimizers; update() runs once per iteration,
    # after every optimizer has stepped (PyTorch AMP multi-optimizer recipe).
    grad_scaler = GradScaler(enabled=use_amp)

    for model in (generator_sr, generator_lr, discriminator_hr, discriminator_lr):
        model.train()

    # NOTE(review): discriminator_lr is trained below, but no generator term
    # uses its predictions, so it never provides feedback to generator_lr.
    history = []
    for epoch in range(num_epochs):
        sum_g = sum_d = 0.0
        n_batches = 0
        for lr, hr in data_loader:
            lr, hr = lr.to(device), hr.to(device)

            # Train Generators
            optimizer_g.zero_grad()
            with autocast(enabled=use_amp):
                sr = generator_sr(lr)
                lr_recon = generator_lr(sr)
                sr_pred = discriminator_hr(sr)
                hr_pred = discriminator_hr(hr)
                loss_g = losses.total_loss(sr, hr, lr_recon, lr, sr_pred, hr_pred)
            grad_scaler.scale(loss_g).backward()
            grad_scaler.step(optimizer_g)

            # Train Discriminators (zero_grad also discards gradients the
            # generator backward pass left on discriminator parameters).
            optimizer_d.zero_grad()
            with autocast(enabled=use_amp):
                loss_d_hr = losses.gan_loss(discriminator_hr(sr.detach()), discriminator_hr(hr))
                loss_d_lr = losses.gan_loss(discriminator_lr(lr_recon.detach()), discriminator_lr(lr))
                loss_d = loss_d_hr + loss_d_lr
            grad_scaler.scale(loss_d).backward()
            grad_scaler.step(optimizer_d)

            grad_scaler.update()

            sum_g += loss_g.item()
            sum_d += loss_d.item()
            n_batches += 1

        mean_g = sum_g / n_batches if n_batches else float('nan')
        mean_d = sum_d / n_batches if n_batches else float('nan')
        history.append({'epoch': epoch + 1, 'loss_g': mean_g, 'loss_d': mean_d})
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss G: {mean_g:.4f}, Loss D: {mean_d:.4f}")

    return history


# Load Dataset

In [None]:
class SRDataset(Dataset):
    """Paired low-/high-resolution image dataset.

    LR and HR images are matched by identical filename (the downsampling cell
    saves each LR image under its HR filename). A file present in only one
    folder is therefore ignored instead of shifting every later pair.

    Args:
        lr_dir: folder containing the low-resolution images.
        hr_dir: folder containing the high-resolution images.
        transform: optional callable applied to both the LR and the HR PIL image.

    Items are ``(lr_image, hr_image)``, as RGB PIL images or transformed outputs.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        lr_names = {f for f in os.listdir(lr_dir) if f.lower().endswith(self.IMAGE_EXTENSIONS)}
        hr_names = {f for f in os.listdir(hr_dir) if f.lower().endswith(self.IMAGE_EXTENSIONS)}
        paired = sorted(lr_names & hr_names)
        # Both attributes are kept for compatibility; index i names the same file in each.
        self.lr_images = paired
        self.hr_images = list(paired)
        self.transform = transform

    def __len__(self):
        return len(self.lr_images)

    @staticmethod
    def _load_rgb(path):
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self.lr_images[idx]))
        hr_image = self._load_rgb(os.path.join(self.hr_dir, self.hr_images[idx]))

        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image

In [None]:
# Paths
lr_dir = '/content/drive/My Drive/plates_data/LR'
hr_dir = '/content/drive/My Drive/plates_data/HR'

# GSR upsamples by 4, so LR inputs must be a quarter of the HR size. Resizing
# both to 256x256 would feed GSR an upscaled LR image and yield 1024x1024 SR outputs.
HR_SIZE = 256
SCALE_FACTOR = 4
LR_SIZE = HR_SIZE // SCALE_FACTOR

lr_transform = Compose([Resize((LR_SIZE, LR_SIZE)), ToTensor()])
hr_transform = Compose([Resize((HR_SIZE, HR_SIZE)), ToTensor()])


class PairedTransform(Dataset):
    """Wrap an (lr, hr) dataset and apply a separate transform to each image."""

    def __init__(self, base, lr_transform, hr_transform):
        self.base = base
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        lr_image, hr_image = self.base[idx]
        return self.lr_transform(lr_image), self.hr_transform(hr_image)


dataset = PairedTransform(SRDataset(lr_dir, hr_dir), lr_transform, hr_transform)
data_loader = DataLoader(dataset, batch_size=3, shuffle=True)  # Build Batch

# Sanity-check one batch instead of decoding the whole dataset.
lr_batch, hr_batch = next(iter(data_loader))
print(lr_batch.shape, hr_batch.shape)

torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3, 3, 256, 256]) torch.Size([3, 3, 256, 256])
torch.Size([3,

# VGG19 feature extractor for the perceptual loss

In [None]:
# Define VGG19 model for perceptual loss
class VGG19(nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        vgg = models.vgg19(pretrained=True)
        self.features = nn.Sequential(*list(vgg.features.children())[:36]).eval()  # Use first 36 layers

        # Freeze parameters
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.features(x)

In [None]:
# ImageNet per-channel statistics for VGG19 input normalisation, shaped
# (1, 3, 1, 1) so they broadcast over (batch, 3, H, W) image batches.
vgg_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
vgg_std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)

# Run Training

In [None]:
# Model, Optimizer, and Losses: run on the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Networks (constructed in a fixed order, then moved to the chosen device).
generator_sr = GSR().to(device)
generator_lr = GLR().to(device)
discriminator_hr = Discriminator().to(device)
discriminator_lr = Discriminator().to(device)
vgg = VGG19().to(device)

# Generators share one optimizer; discriminators another, with a smaller step size.
optimizer_g = torch.optim.Adam(
    [*generator_sr.parameters(), *generator_lr.parameters()],
    lr=2e-4,
    betas=(0.9, 0.999),
)
optimizer_d = torch.optim.Adam(
    [*discriminator_hr.parameters(), *discriminator_lr.parameters()],
    lr=5e-5,
    betas=(0.9, 0.999),
)
losses = SRResCycGANLosses(vgg=vgg)

# Train the model
train_srrescycgan(
    generator_sr, generator_lr, discriminator_hr, discriminator_lr,
    data_loader, optimizer_g, optimizer_d, losses,
    num_epochs=10,
)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 198MB/s]


lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred shape: torch.Size([3, 1, 256, 256])
lr_pred shape: torch.Size([3, 1, 128, 128]), lr_real_pred sh

In [None]:
# Save the generator and discriminator weights, plus the optimizer states so
# training can be resumed.
# NOTE(review): this path is relative to the working directory, not the mounted
# Drive; presumably lost when the Colab runtime ends — confirm where it should live.
weights_dir = 'data/weights'
# torch.save does not create missing directories.
os.makedirs(weights_dir, exist_ok=True)

checkpoints = {
    'generator_sr.pth': generator_sr.state_dict(),
    'generator_lr.pth': generator_lr.state_dict(),
    'discriminator_hr.pth': discriminator_hr.state_dict(),
    'discriminator_lr.pth': discriminator_lr.state_dict(),
    'optimizer_g.pth': optimizer_g.state_dict(),
    'optimizer_d.pth': optimizer_d.state_dict(),
}
for filename, state in checkpoints.items():
    torch.save(state, os.path.join(weights_dir, filename))

In [None]:
# Release cached-but-unused CUDA memory held by PyTorch's allocator.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Reset the CUDA device through numba to free GPU memory.
# Uses its own name so the torch `device` defined for training is not clobbered.
# NOTE(review): resetting the context likely invalidates PyTorch's CUDA tensors
# and models as well — run only once GPU work in this session is finished; confirm.
numba_device = cuda.get_current_device()
numba_device.reset()

# Calculate Metrics

## PSNR

In [None]:
def calculate_psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio (dB) between two same-shaped image tensors.

    Args:
        img1, img2: image tensors, expected to hold values in [0, max_val].
        max_val: maximum possible pixel value (1.0 for images in [0, 1]).

    Returns:
        PSNR in dB as a Python float; ``inf`` when the images are identical.
        For a batch this is the PSNR of the batch-mean MSE.
    """
    # Compute in float32: in half precision a small MSE can underflow to 0 and
    # report a spurious infinite PSNR.
    mse = F.mse_loss(img1.detach().float(), img2.detach().float())
    if mse.item() == 0:
        return float('inf')
    return (10 * torch.log10(max_val ** 2 / mse)).item()

## SSIM

In [None]:
def calculate_ssim(img1, img2, data_range=1.0):
    """Structural similarity between two image tensors.

    Args:
        img1, img2: tensors shaped (C, H, W) or (N, C, H, W). A batch is scored
            image by image and the mean SSIM is returned.
        data_range: span of valid pixel values. The default 1.0 suits images in
            [0, 1], such as GSR outputs, which are clamped to [0, 1]. A fixed range
            is used instead of img1's own min/max so low-contrast images are not
            scored on a stretched scale.

    Returns:
        Mean SSIM as a Python float.

    Raises:
        ValueError: if the inputs are empty or have different batch sizes.
    """
    def to_hwc_batch(img):
        # detach + float32 so .numpy() works for tensors that require grad or are half precision.
        arr = img.detach().float().cpu()
        if arr.dim() == 3:
            arr = arr.unsqueeze(0)
        return arr.permute(0, 2, 3, 1).numpy()  # (N, H, W, C)

    batch1, batch2 = to_hwc_batch(img1), to_hwc_batch(img2)
    if len(batch1) == 0 or len(batch1) != len(batch2):
        raise ValueError("calculate_ssim expects two non-empty inputs with the same batch size")

    # `SSIM` is skimage's structural_similarity (imported under that name);
    # channel_axis replaces the removed `multichannel` flag.
    scores = [SSIM(a, b, channel_axis=-1, data_range=data_range) for a, b in zip(batch1, batch2)]
    return float(sum(scores) / len(scores))

## LPIPS

In [None]:
# LPIPS networks cached per backbone so the pretrained weights load only once.
_LPIPS_MODELS = {}


def calculate_lpips(img1, img2, net='alex', normalize=True):
    """LPIPS perceptual distance between two image tensors (lower = more similar).

    Args:
        img1, img2: (N, 3, H, W) image tensors.
        net: LPIPS backbone name passed to ``lpips.LPIPS``.
        normalize: if True, inputs are taken to be in [0, 1] and LPIPS rescales
            them to the [-1, 1] range it expects. Set False for inputs already in [-1, 1].

    Returns:
        Mean LPIPS distance over the batch as a Python float.
    """
    model = _LPIPS_MODELS.get(net)
    if model is None:
        model = lpips.LPIPS(net=net).eval()
        _LPIPS_MODELS[net] = model
    # Run on the inputs' device so GPU tensors do not meet a CPU model.
    model = model.to(img1.device)
    with torch.no_grad():
        lpips_value = model(img1.float(), img2.to(img1.device).float(), normalize=normalize)
    return lpips_value.mean().item()