# Imports

In [1]:
import os
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose

import torchvision.models as models
from torchvision.models import vgg19
from torchvision.transforms import Normalize
from torch.cuda.amp import GradScaler, autocast

# Downsample HR images into LR

In [3]:
def downsample_image(image_path, output_path, scale_factor=4):
    with Image.open(image_path) as img:
        # Convert RGBA to RGB if necessary
        if img.mode == 'RGBA':
            img = img.convert('RGB')
        lr_img = img.resize((img.width // scale_factor, img.height // scale_factor), Image.BICUBIC)
        lr_img.save(output_path, 'JPEG')

In [5]:
# Paths
hr_folder = 'data/train/HR/'
lr_folder = 'data/train/LR/'
os.makedirs(lr_folder, exist_ok=True)

# Downsample HR images to create LR images
count = 1
for filename in os.listdir(hr_folder):
    print(count, filename)
    count+=1
    if filename.endswith(".png") or filename.endswith(".jpg"):
        hr_image_path = os.path.join(hr_folder, filename)
        lr_image_path = os.path.join(lr_folder, filename)
        downsample_image(hr_image_path, lr_image_path, scale_factor=4)

1 00d85df95f49fc.jpg
2 00f88bb445ed07.jpg
3 06daac6f0ec143.jpg
4 0eb65a862a018c.jpg
5 10f8ccb1cae577.jpg
6 11480a72b92855.jpg
7 114bc748f3bf99.jpg
8 16c6fd2cf1f624.jpg
9 173e7c2ae39b17.jpg
10 17859c46b3b92c.jpg
11 1bf005bbaf7b94.jpg
12 1c2b32fbd457a8.jpg
13 1c98a48e5510ec.jpg
14 1e0f47c2670231.jpg
15 276d36e398c069.jpg
16 29a05a0fd700ac.jpg
17 2dea37f104896b.jpg
18 2f30fad4bdb2c2.jpg
19 310d6a3cfe1a4e.jpg
20 32cd06546e924c.jpg
21 33e729741fa101.jpg
22 388839242d89f1.jpg
23 412fbb13ecded7.jpg
24 422be15adae561.jpg
25 443badcd373167.jpg
26 44616a41d8623a.jpg
27 4596781034fe02.jpg
28 45ef7ada1f32f2.jpg
29 4684b764684b67.jpg
30 488081370abe3b.jpg
31 48cc1de8b1df31.jpg
32 4def4bbec979e4.jpg
33 50864141b27141.jpg
34 51f6fe89078ccb.jpg
35 5784688ed96534.jpg
36 5ab8a4f5164040.jpg
37 5c955e1d880f03.jpg
38 5db729be093c0b.jpg
39 5f2df3303a7ae1.jpg
40 626895c87c1c41.jpg
41 66a962c5a3605c.jpg
42 6c46e7b095aa24.jpg
43 6f9b4676d9a475.jpg
44 7081e21b53ae0c.jpg
45 72a5243e6c0a17.jpg
46 72a6e6077a4698.j

# SR Generator (GSR)

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return x + self.conv2(self.prelu(self.conv1(x)))

class GSR(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_residuals=5):
        super(GSR, self).__init__()
        self.encoder = nn.Conv2d(in_channels, 64, kernel_size=5, padding=2)
        self.resnet = nn.Sequential(*[ResidualBlock(64) for _ in range(num_residuals)])
        self.decoder = nn.Conv2d(64, out_channels, kernel_size=5, padding=2)
        self.upsample = nn.Upsample(scale_factor=4, mode='bicubic')

    def forward(self, x):
        x = self.upsample(x)
        x = self.encoder(x)
        x = self.resnet(x)
        x = self.decoder(x)
        return torch.clamp(x, 0, 1)

# Discriminators (Dx and Dy)

In [7]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            self._block(64, 128, 3, 1),
            self._block(128, 256, 3, 1),
            self._block(256, 512, 3, 1),
            nn.Conv2d(512, 1, kernel_size=3, padding=1)
        )

    def _block(self, in_channels, out_channels, kernel_size, stride):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.net(x)

# LR Generator (GLR)

In [9]:
class GLR(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GLR, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(64, 128, 3, 2),
            self._block(128, 256, 3, 2),
            nn.Conv2d(256, out_channels, kernel_size=3, padding=1)
        )

    def _block(self, in_channels, out_channels, kernel_size, stride):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.net(x)

# Loss Functions

In [11]:
class SRResCycGANLosses:
    def __init__(self, vgg):
        self.vgg = vgg
        self.l1_loss = nn.L1Loss()
        self.tv_loss = TVLoss()

    def perceptual_loss(self, sr, hr):
        sr_vgg = self.vgg((sr - vgg_mean) / vgg_std)
        hr_vgg = self.vgg((hr - vgg_mean) / vgg_std)
        return self.l1_loss(sr_vgg, hr_vgg)

    def gan_loss(self, sr_pred, hr_pred):
        return F.binary_cross_entropy_with_logits(sr_pred, torch.ones_like(sr_pred)) + \
               F.binary_cross_entropy_with_logits(hr_pred, torch.zeros_like(hr_pred))

    def total_variation_loss(self, sr):
        return self.tv_loss(sr)

    def content_loss(self, sr, hr):
        return self.l1_loss(sr, hr)

    def cyclic_loss(self, lr_recon, lr):
        return self.l1_loss(lr_recon, lr)

    def total_loss(self, sr, hr, lr_recon, lr, sr_pred, hr_pred):
        l_per = self.perceptual_loss(sr, hr)
        l_gan = self.gan_loss(sr_pred, hr_pred)
        l_tv = self.total_variation_loss(sr)
        l_content = self.content_loss(sr, hr)
        l_cyclic = self.cyclic_loss(lr_recon, lr)
        return l_per + l_gan + l_tv + 10 * l_content + 10 * l_cyclic

In [13]:
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self._tensor_size(x[:, :, 1:, :])
        count_w = self._tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    def _tensor_size(self, t):
        return t.size()[1] * t.size()[2] * t.size()[3]

# Training the Model

In [15]:
def train_srrescycgan(generator_sr, generator_lr, discriminator_hr, discriminator_lr, data_loader, optimizer_g, optimizer_d, losses, num_epochs, accumulation_steps=2):
    for epoch in range(num_epochs):
        for i, (lr, hr) in enumerate(data_loader):
            lr, hr = lr.to(device), hr.to(device)

            # Train Generators
            optimizer_g.zero_grad()
            with autocast():
                sr = generator_sr(lr)
                lr_recon = generator_lr(sr)
                sr_pred = discriminator_hr(sr)
                hr_pred = discriminator_hr(hr)
                loss_g = losses.total_loss(sr, hr, lr_recon, lr, sr_pred, hr_pred) / accumulation_steps

            scaler.scale(loss_g).backward()

            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer_g)
                scaler.update()
                optimizer_g.zero_grad()

            # Train Discriminators
            optimizer_d.zero_grad()
            with autocast():
                sr_pred = discriminator_hr(sr.detach())
                hr_pred = discriminator_hr(hr)
                loss_d_hr = losses.gan_loss(sr_pred, hr_pred) / accumulation_steps

                lr_pred = discriminator_lr(lr_recon.detach())
                lr_real_pred = discriminator_lr(lr)
                loss_d_lr = losses.gan_loss(lr_pred, lr_real_pred) / accumulation_steps

                loss_d = loss_d_hr + loss_d_lr

            scaler.scale(loss_d).backward()

            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer_d)
                scaler.update()
                optimizer_d.zero_grad()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss G: {loss_g.item() * accumulation_steps}, Loss D: {loss_d.item() * accumulation_steps}")

# Load Dataset

In [17]:
class SRDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_images = sorted([file for file in os.listdir(lr_dir) if file.endswith(('.png', '.jpg', '.jpeg'))])
        self.hr_images = sorted([file for file in os.listdir(hr_dir) if file.endswith(('.png', '.jpg', '.jpeg'))])
        self.transform = transform

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        lr_image_path = os.path.join(self.lr_dir, self.lr_images[idx])
        hr_image_path = os.path.join(self.hr_dir, self.hr_images[idx])
        
        lr_image = Image.open(lr_image_path).convert('RGB')
        hr_image = Image.open(hr_image_path).convert('RGB')
        
        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)
        
        return lr_image, hr_image

In [29]:
# Paths
lr_dir = 'data/train/LR/'
hr_dir = 'data/train/HR/'

# Dataset and DataLoader
# Define a transform to resize images to 256x256 and convert to tensor
transform = Compose([
    Resize((256, 256)),
    ToTensor()
])

dataset = SRDataset(lr_dir, hr_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Example usage
for lr, hr in data_loader:
    print(lr.shape, hr.shape)

torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])


# Add VGG19

In [21]:
# Define VGG19 model for perceptual loss
class VGG19(nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        vgg = vgg19(pretrained=True)
        self.features = nn.Sequential(*list(vgg.features.children())[:36]).eval()  # Use first 36 layers

        # Freeze parameters
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.features(x)

In [23]:
# Normalize input images for VGG19
vgg_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
vgg_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)

# Run Training

In [25]:
# Model, Optimizer, and Losses
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator_sr = GSR().to(device)
generator_lr = GLR().to(device)
discriminator_hr = Discriminator().to(device)
discriminator_lr = Discriminator().to(device)
vgg = VGG19().to(device)
optimizer_g = torch.optim.Adam(list(generator_sr.parameters()) + list(generator_lr.parameters()), lr=1e-4, betas=(0.9, 0.999))
optimizer_d = torch.optim.Adam(list(discriminator_hr.parameters()) + list(discriminator_lr.parameters()), lr=1e-4, betas=(0.9, 0.999))
losses = SRResCycGANLosses(vgg=vgg)

# Train the model
train_srrescycgan(generator_sr, generator_lr, discriminator_hr, discriminator_lr, data_loader, optimizer_g, optimizer_d, losses, num_epochs=50)



OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 

In [27]:
torch.cuda.empty_cache()