In [None]:
import os
import torch
import torch
import torchvision
import torchvision.transforms as T
import torchmetrics
import matplotlib.pyplot as plt
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from PIL import Image

In [None]:
experiment_name = "512_DCGAN_LOSS"

# Checkpoints and sample grids are written under ./results/<experiment_name>/.
# exist_ok=True makes this cell safe to re-run and avoids a check-then-create race.
os.makedirs("./results/{}".format(experiment_name), exist_ok=True)

In [None]:
class DIV2KDataset(torch.utils.data.Dataset):
    """Paired low-/high-resolution image dataset in the DIV2K directory layout.

    Expects ``root_dir/HR/<stem><ext>`` with a matching LR image at
    ``root_dir/LR/X4/<stem>x4<ext>``. Hidden files (names starting with '.')
    in the HR directory are ignored.

    Args:
        root_dir: dataset root containing ``HR`` and ``LR/X4``.
        lr_transform: optional callable applied to the LR PIL image.
        hr_transform: optional callable applied to the HR PIL image.
    """

    def __init__(self, root_dir, lr_transform=None, hr_transform=None):
        self.root_dir = root_dir
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform
        self.lr_dir = os.path.join(root_dir, 'LR/X4')
        self.hr_dir = os.path.join(root_dir, 'HR')
        # os.listdir order is arbitrary; sorting makes index -> file deterministic
        # (the plotting cells address samples by fixed index).
        self.images = sorted(f for f in os.listdir(self.hr_dir) if not f.startswith('.'))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _lr_filename(hr_filename):
        """Map an HR file name to its LR counterpart: '0001.png' -> '0001x4.png'."""
        stem, ext = os.path.splitext(hr_filename)
        return stem + "x4" + ext

    @staticmethod
    def _load_rgb(path):
        """Open an image, force 3-channel RGB, and release the file handle."""
        # convert() returns a fully loaded copy, so the file can be closed here.
        # Forcing RGB keeps grayscale/RGBA/palette files from yielding 1 or 4 channels.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        hr_name = self.images[idx]
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self._lr_filename(hr_name)))
        hr_image = self._load_rgb(os.path.join(self.hr_dir, hr_name))

        if self.lr_transform:
            lr_image = self.lr_transform(lr_image)
        if self.hr_transform:
            hr_image = self.hr_transform(hr_image)

        return lr_image, hr_image

In [None]:
# LR inputs become 128x128 and HR targets 512x512 (a 4x factor); ToTensor
# yields float CHW tensors. NOTE(review): a fixed (h, w) Resize ignores the
# source aspect ratio — confirm the distortion is acceptable.
lr_transform = T.Compose([T.Resize((128, 128)), T.ToTensor()])
hr_transform = T.Compose([T.Resize((512, 512)), T.ToTensor()])

root_train = "./datasets/train/"
root_val = "./datasets/val/"

# Both splits share the same transform pair; only the root differs.
train_ds, val_ds = (
    DIV2KDataset(root_dir=split_root, hr_transform=hr_transform, lr_transform=lr_transform)
    for split_root in (root_train, root_val)
)

In [None]:
batch_size = 16

# Shared loader settings: 4 worker processes, each prefetching up to 13
# batches, with batches pinned for faster host->CUDA copies.
loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    num_workers=4,
    prefetch_factor=13,
    pin_memory_device='cuda',
)
train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = torch.utils.data.DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
lr, hr = train_ds[0]
# Smoke test: one sample loads, both images are 3-channel, and HR is 4x LR per side.
assert lr.shape[0] == 3 and hr.shape[0] == 3, (lr.shape, hr.shape)
assert hr.shape[-2:] == (4 * lr.shape[-2], 4 * lr.shape[-1]), (lr.shape, hr.shape)

In [None]:
# Show LR/HR pairs for a few fixed training indices, one row per sample.
samples = [160, 170]
num_samples_to_plot = len(samples)  # derived, so it can never disagree with `samples`
# squeeze=False keeps `axes` 2-D even for a single sample.
fig, axes = plt.subplots(num_samples_to_plot, 2, figsize=(10, 5 * num_samples_to_plot), squeeze=False)

for row, sample in zip(axes, samples):
    lr_image, hr_image = train_ds[sample]
    # CHW tensors -> HWC arrays for imshow
    row[0].imshow(lr_image.permute(1, 2, 0).numpy())
    row[0].set_title('LR Image')
    row[1].imshow(hr_image.permute(1, 2, 0).numpy())
    row[1].set_title('HR Image')

plt.tight_layout()
plt.show()

# Model

In [None]:
class Discriminator(nn.Module):
    """Real/fake classifier for square RGB images.

    Two 5x5 stride-2 convolutions (LeakyReLU 0.3 + dropout 0.3 each), then a
    linear layer producing one raw logit per image (no sigmoid; the notebook
    pairs it with BCEWithLogitsLoss).

    Args:
        input_size: side length of the square input images. Defaults to 512,
            which reproduces the original 2097152-feature linear layer.
    """

    def __init__(self, input_size=512):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)

        # Derive the flattened feature size instead of hard-coding it, so other
        # input resolutions work (512 -> 256 -> 128, i.e. 128 * 128 * 128).
        side = self._conv_out(self._conv_out(input_size))
        self.fc = nn.Linear(128 * side * side, 1)

    @staticmethod
    def _conv_out(n, kernel_size=5, stride=2, padding=2):
        """Output side length of a Conv2d with the given geometry for input side n."""
        return (n + 2 * padding - kernel_size) // stride + 1

    def forward(self, xb):
        out = F.leaky_relu(self.conv1(xb), negative_slope=0.3)
        # F.dropout defaults to training=True; tie it to the module mode so
        # dropout is disabled under .eval().
        out = F.dropout(out, p=0.3, training=self.training)
        out = F.leaky_relu(self.conv2(out), negative_slope=0.3)
        out = F.dropout(out, p=0.3, training=self.training)
        out = torch.flatten(out, start_dim=1)
        return self.fc(out)

In [None]:
class Generator(nn.Module):
    """4x upscaling generator: 3x3 conv (3->128), then two nearest-neighbour
    2x upsample + 3x3 conv stages (128->64, 64->3).

    BatchNorm and LeakyReLU(0.3) follow the first two convolutions; the output
    goes through tanh.

    NOTE(review): tanh produces values in [-1, 1], while the HR targets come
    from ToTensor in [0, 1] — confirm this range mismatch is intended.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (and the Sequential indices inside conv2/conv3) are
        # the state_dict keys of saved checkpoints — keep them stable.
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=128)
        self.conv2 = self._upsample_conv(128, 64)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.conv3 = self._upsample_conv(64, 3)

    @staticmethod
    def _upsample_conv(in_channels, out_channels):
        """2x nearest-neighbour upsample followed by a same-size 3x3 conv."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, xb):
        hidden = F.leaky_relu(self.bn1(self.conv1(xb)), negative_slope=0.3)
        hidden = F.leaky_relu(self.bn2(self.conv2(hidden)), negative_slope=0.3)
        return torch.tanh(self.conv3(hidden))

## Loss

In [None]:
import torchvision.models as models

class VGGFeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG19 feature extractor for a perceptual loss.

    Keeps ``vgg19.features[:35]``: every layer up to and including conv5_4,
    stopping before its ReLU. All parameters are frozen.
    """

    def __init__(self):
        super(VGGFeatureExtractor, self).__init__()
        # IMAGENET1K_V1 is the weight set the deprecated `pretrained=True` loaded.
        vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        # features[34] is conv5_4 (features[35] is its ReLU), so [:35] ends at conv5_4.
        self.features = vgg19.features[:35].eval()
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.features(x)


vgg = VGGFeatureExtractor().cuda()


def g_criterion(image1, image2, vgg=vgg):
    """Generator content loss: pixel-wise MSE between ``image1`` and ``image2``.

    Args:
        image1: generated (super-resolved) batch.
        image2: target (high-resolution) batch; moved to ``image1``'s device.
        vgg: feature extractor reserved for a perceptual term; currently unused.

    Returns:
        Scalar loss tensor, still attached to the autograd graph. It must not
        be converted with ``.item()``: a Python float would only add a constant
        to the generator loss and contribute no gradient.
    """
    return F.mse_loss(image1, image2.to(image1.device))

## Train Loop

In [None]:
def train(G, D, train_dl, criterion, g_criterion, D_opt, G_opt, epochs, save_epochs=(1, 10, 30, 50), start_idx=0):
    """Adversarially train generator ``G`` against discriminator ``D``.

    Args:
        G: generator mapping LR batches to SR batches.
        D: discriminator returning one logit per image.
        train_dl: iterable of ``(lr, hr)`` batches with a ``len()``.
        criterion: adversarial loss on D's logits vs. 0/1 labels.
        g_criterion: content loss ``g_criterion(sr, hr)`` added to G's loss.
        D_opt, G_opt: optimizers for D and G.
        epochs: number of epochs to run in this call.
        save_epochs: 1-based epoch numbers *relative to start_idx* at which
            samples are plotted and a generator checkpoint is saved.
        start_idx: absolute index of the first epoch (for resuming); used in
            progress labels and checkpoint names.

    Returns:
        (losses_generator, losses_discriminator): float tensors of length
        ``epochs`` holding per-epoch mean losses.
    """
    losses_generator = torch.zeros(epochs)
    losses_discriminator = torch.zeros(epochs)

    device = next(G.parameters()).device
    last_epoch = start_idx + epochs
    G.train()
    D.train()

    for epoch in range(start_idx, last_epoch):
        g_running = 0.0
        d_running = 0.0

        for lr, hr in tqdm(train_dl, desc=f'Epoch {epoch+1}/{last_epoch}', leave=True):
            lr = lr.to(device)
            hr = hr.to(device)
            batch = lr.shape[0]
            y_fake = torch.zeros(batch, 1, device=device)
            y_real = torch.ones(batch, 1, device=device)

            # One generator forward per batch, reused for both updates
            # (G's BatchNorm running stats are updated once per batch).
            sr = G(lr)

            # Discriminator step. Detaching sr keeps this backward pass out of
            # the generator's graph.
            D_opt.zero_grad()
            loss_d_fake = criterion(D(sr.detach()), y_fake)
            loss_d_real = criterion(D(hr), y_real)
            loss_discriminator = loss_d_fake + loss_d_real
            loss_discriminator.backward()
            D_opt.step()
            d_running += loss_discriminator.item()

            # Generator step: fool the freshly updated D and match HR content.
            G_opt.zero_grad()
            loss_g1 = criterion(D(sr), y_real)
            loss_g2 = g_criterion(sr, hr)
            loss_generator = 0.1 * (loss_g1 + 0.01 * loss_g2)
            loss_generator.backward()
            G_opt.step()
            g_running += loss_generator.item()

        n_batches = max(len(train_dl), 1)  # guard against an empty loader
        g_running /= n_batches
        # D's loss is the sum of a fake and a real term; halve it to report the per-term mean.
        d_running /= 2 * n_batches
        losses_generator[epoch - start_idx] = g_running
        losses_discriminator[epoch - start_idx] = d_running

        if epoch + 1 - start_idx in save_epochs:
            psnr, ssim = plot_generated_images(G, 3, epoch + 1, device)

            file_path = './results/{}/generator_{}.pth'.format(experiment_name, epoch + 1)
            torch.save(G.state_dict(), file_path)
            print("Epoch {}/{} Loss G: {} Loss D: {} PSNR: {} SSIM: {}".format(
                epoch + 1, last_epoch, g_running, d_running, psnr, ssim))
        else:
            print("Epoch {}/{} Loss G: {} Loss D: {}".format(epoch + 1, last_epoch, g_running, d_running))

    file_path = './results/{}/generator.pth'.format(experiment_name)
    torch.save(G.state_dict(), file_path)
    file_path = './results/{}/discriminator.pth'.format(experiment_name)
    torch.save(D.state_dict(), file_path)
    return losses_generator, losses_discriminator

In [None]:
def plot_generated_images(G, n_imgs, epoch, device, val_ds=val_ds):
    """Plot LR / HR / SR triplets for the first ``n_imgs`` samples of ``val_ds``.

    The grid is saved to ``./results/<experiment_name>/G_<epoch>.png`` and shown.

    Args:
        G: generator; run in eval mode under ``torch.no_grad()``, then restored
            to the train/eval mode it had on entry.
        n_imgs: number of samples to plot (capped at ``len(val_ds)``).
        epoch: epoch number used in the output file name.
        device: device the generator lives on.
        val_ds: dataset yielding ``(lr, hr)`` tensor pairs.

    Returns:
        (psnr, ssim): means over the plotted samples; (0.0, 0.0) if none.
    """
    n_shown = min(n_imgs, len(val_ds))
    if n_shown <= 0:
        return 0.0, 0.0

    was_training = G.training
    G.eval()
    # squeeze=False keeps axs 2-D even when only one row is plotted.
    fig, axs = plt.subplots(n_shown, 3, figsize=(15, 5 * n_shown), squeeze=False)
    psnr_total = 0.0
    ssim_total = 0.0
    try:
        with torch.no_grad():
            for i in range(n_shown):
                lr_img, hr_img = val_ds[i]
                lr_img = lr_img.to(device).unsqueeze(0)
                hr_img = hr_img.to(device).unsqueeze(0)

                sr_img = G(lr_img)

                # NOTE(review): no data_range is passed, so torchmetrics infers it
                # from the tensors; consider data_range=1.0 for [0, 1] images.
                psnr_value = torchmetrics.functional.image.peak_signal_noise_ratio(sr_img, hr_img).item()
                ssim_value = torchmetrics.functional.image.structural_similarity_index_measure(sr_img, hr_img).item()
                psnr_total += psnr_value
                ssim_total += ssim_value

                panels = (
                    (lr_img, 'Low-Resolution'),
                    (hr_img, 'High-Resolution'),
                    # Clamp for display only: imshow expects float RGB in [0, 1].
                    (sr_img.clamp(0, 1), 'Super-Resolved'),
                )
                for ax, (img, title) in zip(axs[i], panels):
                    ax.imshow(img.squeeze(0).permute(1, 2, 0).cpu().numpy())
                    ax.set_title(title)
                    ax.axis('off')

                axs[i, 2].text(0.5, -0.1, f'SSIM: {ssim_value:.4f}\nPSNR: {psnr_value:.2f} dB', horizontalalignment='center', verticalalignment='bottom', transform=axs[i, 2].transAxes, color='black')
    finally:
        G.train(was_training)

    psnr = psnr_total / n_shown
    ssim = ssim_total / n_shown

    fig.tight_layout()
    fig.savefig(f'./results/{experiment_name}/G_{epoch}.png')
    plt.show()
    plt.close(fig)
    return psnr, ssim


## Evaluate

In [None]:
def evaluate(g, test_dl):
    """Mean PSNR and SSIM of generator ``g`` over the batches of ``test_dl``.

    Metrics are computed per batch with torchmetrics, then averaged with
    batch-size weights so a smaller final batch does not skew the result.
    Runs under ``torch.no_grad()``; ``g``'s train/eval mode is restored.

    Returns:
        (psnr, ssim) as floats; (nan, nan) if ``test_dl`` yields no samples.
    """
    was_training = g.training
    g.eval()
    device = next(g.parameters()).device
    psnr_sum = 0.0
    ssim_sum = 0.0
    n_seen = 0
    try:
        # no_grad: evaluation must not build autograd graphs (memory at 512x512).
        with torch.no_grad():
            for lr, hr in test_dl:
                lr = lr.to(device)
                hr_img = hr.to(device)

                sr_img = g(lr)
                batch = lr.shape[0]
                psnr_sum += batch * torchmetrics.functional.image.peak_signal_noise_ratio(sr_img, hr_img).item()
                ssim_sum += batch * torchmetrics.functional.image.structural_similarity_index_measure(sr_img, hr_img).item()
                n_seen += batch
    finally:
        g.train(was_training)

    if n_seen == 0:
        return float('nan'), float('nan')
    return psnr_sum / n_seen, ssim_sum / n_seen

## Initialize Model

In [None]:
# Build G before D (same order as before, so seeded initialisation is unchanged).
g, d = Generator().to('cuda'), Discriminator().to('cuda')

ADAM_BETAS = (0.5, 0.999)
optimizerD = torch.optim.Adam(params=d.parameters(), lr=5e-4, betas=ADAM_BETAS)
optimizerG = torch.optim.Adam(params=g.parameters(), lr=2e-3, betas=ADAM_BETAS)

# The discriminator's forward ends in a linear layer (raw logits), so the
# sigmoid is folded into the loss.
criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
import torchsummary

print("DISCRIMINATOR")
# torchsummary.summary(d, (3, 256, 256))
print()
print("GENERATOR")
# torchsummary.summary(g, (3, 64, 64))
print()

In [None]:
# First training run: epochs 1-200, with samples/checkpoints at the listed epochs.
g_loss, d_loss = train(
    g, d, train_dl, criterion, g_criterion, optimizerD, optimizerG,
    epochs=200,
    save_epochs=[1, 10, 20, 40, 50, 70, 80, 100, 150, 200],
)

In [None]:
# Validation metrics after the first run.
psnr, ssim = evaluate(g, val_dl)
print("PSNR: {:.2f} dB, SSIM: {:.4f}".format(psnr, ssim))

In [None]:
import matplotlib.pyplot as plt

# Per-epoch mean losses from the first run.
fig, ax = plt.subplots()
ax.plot(g_loss, label='Generator Loss')
ax.plot(d_loss, label='Discriminator Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('GAN Training Loss')
ax.legend()
fig.savefig('losses.png')
plt.show()

In [None]:
g_loss, d_loss = train(g, d, train_dl, criterion, g_criterion, optimizerD, optimizerG, 200, save_epochs = [1, 10, 50, 100, 150, 200], start_idx=200)

In [None]:
g_loss += g_loss_
d_loss += d_loss_

import matplotlib.pyplot as plt

plt.plot(g_loss, label='Generator Loss')
plt.plot(d_loss, label='Discriminator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('GAN Training Loss')
plt.legend()
plt.savefig('losses.png')
plt.show()

In [None]:
# Validation metrics after the resumed run.
psnr, ssim = evaluate(g, val_dl)
print("PSNR: {:.2f} dB, SSIM: {:.4f}".format(psnr, ssim))

In [None]:
# g_eval = g.load_state_dict(torch.load(f"./results/{experiment_name}/generator.pth"))
# g.eval()