In [None]:
import os
import torch
import torch
import torchvision
import torchvision.transforms as T
import torchmetrics
import matplotlib.pyplot as plt
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from PIL import Image

In [None]:
experiment_name = "256_512_UNET"

# Every artefact of this run (checkpoints, sample grids) lives under ./results/<experiment_name>/.
_results_dir = "./results/{}".format(experiment_name)
if not os.path.exists(_results_dir):
    os.makedirs(_results_dir)

In [None]:
class DIV2KDataset(torch.utils.data.Dataset):
    """Paired low-/high-resolution images in a DIV2K-style directory layout.

    Expects ``root_dir/HR/<stem><ext>`` for each HR image and the matching LR image at
    ``root_dir/LR/X4/<stem>x4<ext>``. Hidden files (names starting with '.') in HR/
    are ignored.

    Args:
        root_dir: dataset root containing the ``HR`` and ``LR/X4`` sub-directories.
        lr_transform: optional callable applied to the LR PIL image.
        hr_transform: optional callable applied to the HR PIL image.
    """

    def __init__(self, root_dir, lr_transform=None, hr_transform=None):
        self.root_dir = root_dir
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform
        self.lr_dir = os.path.join(root_dir, 'LR/X4')
        self.hr_dir = os.path.join(root_dir, 'HR')
        # Sorted so that index -> image is stable across runs and machines
        # (os.listdir returns entries in arbitrary order).
        self.images = sorted(f for f in os.listdir(self.hr_dir) if not f.startswith('.'))

    def __len__(self):
        return len(self.images)

    def _paths(self, idx):
        """Return ``(lr_path, hr_path)`` for sample ``idx`` (IndexError if out of range)."""
        hr_name = self.images[idx]
        # splitext handles extensions of any length (.png, .jpeg, none at all).
        stem, ext = os.path.splitext(hr_name)
        lr_path = os.path.join(self.lr_dir, stem + "x4" + ext)
        hr_path = os.path.join(self.hr_dir, hr_name)
        return lr_path, hr_path

    @staticmethod
    def _load_rgb(path):
        """Open an image file, fully load it as RGB and close the file handle."""
        # Image.open is lazy and keeps the file open; convert() forces the load and
        # also maps greyscale/palette/RGBA files to the 3 channels the model expects.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        lr_path, hr_path = self._paths(idx)
        lr_image = self._load_rgb(lr_path)
        hr_image = self._load_rgb(hr_path)

        if self.lr_transform:
            lr_image = self.lr_transform(lr_image)
        if self.hr_transform:
            hr_image = self.hr_transform(hr_image)

        return lr_image, hr_image

In [None]:
# LR inputs and HR targets are resized to fixed squares (x4 apart) so they batch.
LR_SIZE = 128
HR_SIZE = 4 * LR_SIZE  # 512

lr_transform = T.Compose([T.Resize((LR_SIZE, LR_SIZE)), T.ToTensor()])
hr_transform = T.Compose([T.Resize((HR_SIZE, HR_SIZE)), T.ToTensor()])

root_train = "./datasets/train/"
root_val = "./datasets/val/"

# Same transforms for both splits; built in order train, then val.
train_ds, val_ds = (
    DIV2KDataset(root_dir=root, hr_transform=hr_transform, lr_transform=lr_transform)
    for root in (root_train, root_val)
)

In [None]:
batch_size = 8

# Loader settings shared by both splits; only shuffling differs.
# NOTE(review): prefetch_factor=13 with 4 workers keeps many batches resident in
# pinned memory — consider lowering if host RAM is tight.
_loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    num_workers=4,
    prefetch_factor=13,
    pin_memory_device='cuda',
)
train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dl = torch.utils.data.DataLoader(val_ds, shuffle=False, **_loader_kwargs)

In [None]:
# Smoke test: fetch one (LR, HR) pair to confirm the dataset and transforms load.
lr, hr = train_ds[0]

In [None]:
# Visual sanity check: a few LR inputs next to their HR targets.
samples = [160, 170]
# Derived from `samples` so the grid always has exactly one row per sample.
num_samples_to_plot = len(samples)
# squeeze=False keeps `axes` 2-D, so axes[row, col] also works with a single sample.
fig, axes = plt.subplots(num_samples_to_plot, 2, figsize=(10, 5 * num_samples_to_plot), squeeze=False)

for row, sample in enumerate(samples):
    lr_image, hr_image = train_ds[sample]
    # CHW -> HWC, the layout imshow expects.
    axes[row, 0].imshow(np.array(lr_image).transpose(1, 2, 0))
    axes[row, 0].set_title('LR Image')
    axes[row, 1].imshow(np.array(hr_image).transpose(1, 2, 0))
    axes[row, 1].set_title('HR Image')

plt.tight_layout()
plt.show()

# Model

In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """U-Net-style network that upsamples its input by 4x.

    Encoder: four conv blocks (64/128/256/512 channels) separated by 2x max-pooling.
    Decoder: 2x transposed convolutions back to input resolution, each concatenated
    with the matching encoder feature map, followed by two more 2x transposed
    convolutions (up4, up5) that take the result to 4x the input resolution.

    NOTE(review): unlike a classic U-Net there is no conv block (and no nonlinearity)
    after each concatenation — each concatenated tensor feeds straight into the next
    transposed conv.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the produced image.

    Input height and width must be divisible by 8 (three pooling stages).
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        # Down-sampling layers
        self.down1 = self.conv_block(in_channels, 64)
        self.down2 = self.conv_block(64, 128)
        self.down3 = self.conv_block(128, 256)
        self.down4 = self.conv_block(256, 512)

        # Up-sampling layers; inputs of up2/up3/up4 are doubled by the skip concatenation.
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(256*2, 128, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(128*2, 64, kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose2d(64*2, out_channels, kernel_size=2, stride=2)   # 1x -> 2x
        self.up5 = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2)  # 2x -> 4x

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 conv + ReLU layers; spatial size is preserved (padding=1)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, 4H, 4W)."""
        if x.shape[-2] % 8 or x.shape[-1] % 8:
            raise ValueError(
                "UNet input height and width must be divisible by 8, got {}x{}".format(
                    x.shape[-2], x.shape[-1]))

        # Down-sampling path
        down1 = self.down1(x)
        down2 = self.down2(nn.functional.max_pool2d(down1, 2))
        down3 = self.down3(nn.functional.max_pool2d(down2, 2))
        down4 = self.down4(nn.functional.max_pool2d(down3, 2))

        # Up-sampling path with skip connections
        up1 = torch.cat([down3, self.up1(down4)], dim=1)
        up2 = torch.cat([down2, self.up2(up1)], dim=1)
        up3 = torch.cat([down1, self.up3(up2)], dim=1)

        # Beyond input resolution there is no encoder feature map to skip-connect:
        # concatenating `x` here is impossible (x is at 1x, up4 at 2x) and up5 only
        # accepts out_channels inputs, so up4 feeds up5 directly.
        up4 = self.up4(up3)
        return self.up5(up4)

## Loss

In [None]:
import torchvision.models as models

class VGGFeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG19 feature extractor for a perceptual loss.

    Keeps ``vgg19.features[:num_layers]`` with all parameters frozen
    (``requires_grad = False``), so gradients still flow to the input images but
    the VGG weights are never updated.

    Args:
        num_layers: how many modules of ``vgg19.features`` to keep. The default 35
            ends at conv5_4 *before* its ReLU (conv4_4 would be ``[:26]``).
    """

    def __init__(self, num_layers=35):
        super(VGGFeatureExtractor, self).__init__()
        # NOTE(review): newer torchvision prefers `weights=...` over `pretrained=True`;
        # kept as-is for compatibility with older torchvision releases.
        vgg19 = models.vgg19(pretrained=True)
        self.features = vgg19.features[:num_layers].eval()
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the feature map of ``x`` at the truncation point."""
        return self.features(x)


vgg = VGGFeatureExtractor().cuda()

# Built once rather than on every loss call. Both transforms operate on tensors
# here (no ToTensor) and are differentiable, so gradients reach the generator.
_VGG_PREPROCESS = T.Compose([
    T.Resize((224, 224)),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])


def g_criterion(image1, image2, vgg=vgg):
    """VGG perceptual loss: MSE between VGG feature maps of two image batches.

    Both inputs are resized to 224x224 and ImageNet-normalised before being passed
    through ``vgg``. NOTE(review): assumes float image tensors in [0, 1] with 3
    channels, e.g. (N, 3, H, W) — confirm against callers.

    Args:
        image1: first image batch.
        image2: second image batch (the loss is symmetric in its arguments).
        vgg: feature extractor; inputs are moved to the device of its parameters.

    Returns:
        Scalar loss tensor that stays attached to the autograd graph, so it can be
        added to another loss and back-propagated.
    """
    device = next(vgg.parameters()).device
    features1 = vgg(_VGG_PREPROCESS(image1).to(device))
    features2 = vgg(_VGG_PREPROCESS(image2).to(device))
    return F.mse_loss(features1, features2)

## Train Loop

In [None]:
device = 'cuda'


def train(train_dl, model, opt, loss_fn, extra_loss_fn, epochs,
          save_epochs=(1, 10, 20, 40, 50, 100), extra_loss_weight=0.0):
    """Train ``model`` on ``train_dl`` for ``epochs`` epochs.

    At every epoch listed in ``save_epochs`` it also calls ``plot_generated_images``
    (sample grid + PSNR/SSIM on the validation set) and saves the model state dict
    to ./results/<experiment_name>/generator_<epoch>.pth.

    Args:
        train_dl: iterable of (lr, hr) batches; must support ``len()``.
        model: network mapping an LR batch to an SR batch. Batches are moved to the
            device of its parameters.
        opt: optimizer over ``model``'s parameters, or None to only measure the loss.
        loss_fn: main loss, called as ``loss_fn(sr, hr)`` (prediction, target).
        extra_loss_fn: optional auxiliary loss called as ``extra_loss_fn(sr, hr)``;
            only used when ``extra_loss_weight`` is non-zero.
        epochs: number of epochs.
        save_epochs: 1-based epochs at which to plot, score and checkpoint.
        extra_loss_weight: weight of ``extra_loss_fn`` in the total loss. The default
            0.0 leaves ``extra_loss_fn`` unused.

    Returns:
        ``(loss_sum, loss_mean)``: 1-D tensors of length ``epochs`` holding each
        epoch's summed and per-batch-mean training loss.
    """
    run_device = next(model.parameters()).device
    n_batches = len(train_dl)
    loss_sum = torch.zeros(epochs)
    for epoch in range(1, epochs + 1):
        for lr, hr in tqdm(train_dl, desc=f'Epoch {epoch}/{epochs}', leave=True):
            lr, hr = lr.to(run_device), hr.to(run_device)
            sr = model(lr)
            # (prediction, target) order, as torch.nn loss modules expect.
            loss = loss_fn(sr, hr)
            if extra_loss_fn is not None and extra_loss_weight:
                loss = loss + extra_loss_weight * extra_loss_fn(sr, hr)
            loss_sum[epoch - 1] += loss.item()
            if opt is not None:
                opt.zero_grad()
                loss.backward()
                opt.step()
        mean_loss = (loss_sum[epoch - 1] / n_batches).item()
        if epoch in save_epochs:
            psnr, ssim = plot_generated_images(model, 3, epoch, run_device)
            file_path = './results/{}/generator_{}.pth'.format(experiment_name, epoch)
            torch.save(model.state_dict(), file_path)
            print("Epoch [{}/{}]: \t Running Loss = {} \t PSNR: {} \t SSIM: {}".format(epoch, epochs, mean_loss, psnr, ssim))
        else:
            print("Epoch [{}/{}]: \t Running Loss = {}".format(epoch, epochs, mean_loss))
    return loss_sum, loss_sum / n_batches

In [None]:
def plot_generated_images(G, n_imgs, epoch, device, val_ds=val_ds):
    """Plot LR / HR / SR triplets for the first validation samples and score them.

    The figure is saved to ./results/<experiment_name>/G_<epoch>.png and shown.
    ``G``'s train/eval mode is restored afterwards.

    Args:
        G: generator mapping an LR batch to an SR batch.
        n_imgs: number of validation samples to plot (capped at ``len(val_ds)``).
        epoch: epoch number; used only in the output filename.
        device: device the samples are moved to before calling ``G``.
        val_ds: indexable dataset of (lr, hr) CHW tensors with ``len()``; defaults to
            the global ``val_ds`` as bound when this function was defined.

    Returns:
        ``(psnr, ssim)``: mean PSNR (dB) and SSIM over the plotted samples, computed
        with ``data_range=1.0`` (images assumed in [0, 1]); ``(nan, nan)`` when
        there is nothing to plot.
    """
    n_shown = min(n_imgs, len(val_ds))
    if n_shown <= 0:
        return float('nan'), float('nan')

    was_training = G.training
    G.eval()
    # squeeze=False keeps `axs` 2-D so axs[i, j] also works when n_shown == 1.
    fig, axs = plt.subplots(n_shown, 3, figsize=(15, 5 * n_shown), squeeze=False)
    psnr = 0.0
    ssim = 0.0
    try:
        with torch.no_grad():
            for i in range(n_shown):
                lr_img, hr_img = val_ds[i]
                lr_img = lr_img.to(device).unsqueeze(0)
                hr_img = hr_img.to(device).unsqueeze(0)

                # Generate super-resolved image
                sr_img = G(lr_img)

                # Fixed data_range so scores don't depend on each image's own min/max.
                psnr_value = torchmetrics.functional.image.peak_signal_noise_ratio(sr_img, hr_img, data_range=1.0).item()
                ssim_value = torchmetrics.functional.image.structural_similarity_index_measure(sr_img, hr_img, data_range=1.0).item()
                psnr += psnr_value
                ssim += ssim_value

                # imshow wants HWC on the CPU; SR is clamped so out-of-range outputs
                # don't trigger matplotlib's clipping warning.
                panels = (
                    (lr_img, 'Low-Resolution'),
                    (hr_img, 'High-Resolution'),
                    (sr_img.clamp(0, 1), 'Super-Resolved'),
                )
                for j, (img, title) in enumerate(panels):
                    axs[i, j].imshow(img.cpu().squeeze(0).permute(1, 2, 0))
                    axs[i, j].set_title(title)
                    axs[i, j].axis('off')

                axs[i, 2].text(0.5, -0.1, f'SSIM: {ssim_value:.4f}\nPSNR: {psnr_value:.2f} dB', horizontalalignment='center', verticalalignment='bottom', transform=axs[i, 2].transAxes, color='black')
    finally:
        # Restore whatever mode the caller had the model in.
        G.train(was_training)

    psnr /= n_shown
    ssim /= n_shown

    plt.tight_layout()
    plt.savefig(f'./results/{experiment_name}/G_{epoch}.png')
    plt.show()
    plt.close(fig)
    return psnr, ssim


## Evaluate

In [None]:
def evaluate(g, test_dl):
    """Average PSNR and SSIM of generator ``g`` over ``test_dl``.

    Runs under ``torch.no_grad()`` in eval mode and restores ``g``'s previous
    train/eval mode afterwards. Metrics use ``data_range=1.0`` (images assumed in
    [0, 1]). Each batch's score is weighted by its batch size so a short final
    batch does not skew the mean. NOTE: torchmetrics' functional PSNR pools all
    pixels of a batch, so the PSNR is a batch-size-weighted mean of per-batch PSNR.

    Args:
        g: generator mapping an LR batch to an SR batch.
        test_dl: iterable of (lr, hr) batches.

    Returns:
        ``(psnr, ssim)`` as Python floats.

    Raises:
        ValueError: if ``test_dl`` yields no images.
    """
    was_training = g.training
    g.eval()
    device = next(g.parameters()).device
    psnr_sum = 0.0
    ssim_sum = 0.0
    n_images = 0
    try:
        with torch.no_grad():
            for lr, hr in test_dl:
                lr = lr.to(device)
                hr_img = hr.to(device)

                sr_img = g(lr)
                batch = hr_img.shape[0]
                psnr_sum += torchmetrics.functional.image.peak_signal_noise_ratio(sr_img, hr_img, data_range=1.0).item() * batch
                ssim_sum += torchmetrics.functional.image.structural_similarity_index_measure(sr_img, hr_img, data_range=1.0).item() * batch
                n_images += batch
    finally:
        g.train(was_training)

    if n_images == 0:
        raise ValueError("evaluate() received an empty test_dl")
    return psnr_sum / n_images, ssim_sum / n_images

## Initialize Model

In [None]:
# RGB in, RGB out; generator lives on the GPU.
g = UNet(in_channels=3, out_channels=3).to('cuda')
# Pixel-wise MSE between super-resolved output and HR target.
criterion = torch.nn.MSELoss()
opt = torch.optim.Adam(g.parameters(), lr=1e-4, betas=(0.5, 0.999))

In [None]:
import torchsummary

# Layer-by-layer summary of the generator for a 3x128x128 LR input
# (the size produced by lr_transform). There is no discriminator in this notebook.
print("GENERATOR")
torchsummary.summary(g, (3, 128, 128))
print()

In [None]:
# 200 epochs; sample grids and checkpoints are written at the listed epochs.
g_loss = train(
    train_dl, g, opt, criterion, g_criterion,
    epochs=200,
    save_epochs=[1, 10, 30, 50, 80, 100, 150, 200],
)

In [None]:
# Final validation metrics of the trained generator.
psnr, ssim = evaluate(g, val_dl)
print("PSNR: {:.2f} dB, SSIM: {:.4f}".format(psnr, ssim))

In [None]:
import matplotlib.pyplot as plt

# Per-epoch mean training loss (second element returned by train()).
epoch_mean_loss = g_loss[-1]
fig, ax = plt.subplots()
ax.plot(range(1, len(epoch_mean_loss) + 1), epoch_mean_loss.tolist(), label='Generator Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('UNet Training Loss')
ax.legend()
# Saved next to this run's checkpoints and sample grids (not the CWD), so runs with
# different experiment_name values don't overwrite each other's curves.
fig.savefig(f'./results/{experiment_name}/losses.png')
plt.show()