In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
!cp "/content/drive/MyDrive/archive.zip" /content/

In [3]:
!unzip -q "/content/drive/MyDrive/archive.zip" -d /content/

In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np

In [5]:
# Custom Dataset Class
class UnderwaterImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.raw_images = sorted(os.listdir(os.path.join(root_dir, '/content/Train/Raw')))
        self.reference_images = sorted(os.listdir(os.path.join(root_dir, '/content/Train/Reference')))

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, idx):
        raw_image_path = os.path.join(self.root_dir, '/content/Train/Raw', self.raw_images[idx])
        reference_image_path = os.path.join(self.root_dir, '/content/Train/Reference', self.reference_images[idx])

        raw_image = Image.open(raw_image_path).convert("RGB")
        reference_image = Image.open(reference_image_path).convert("RGB")

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image


In [6]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

dataset = UnderwaterImageDataset(root_dir='Dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [7]:
# UNet Generator
class UNetGenerator(nn.Module):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        self.encoder1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)  # Downsample
        self.encoder2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.decoder1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # Upsample
        self.decoder2 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        enc1 = F.leaky_relu(self.encoder1(x))
        enc2 = F.leaky_relu(self.encoder2(enc1))
        dec1 = F.relu(self.decoder1(enc2))
        dec2 = self.decoder2(dec1)
        return dec2

class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [8]:
# L1 Loss Function
def l1_loss(y_true, y_pred):
    return F.l1_loss(y_true, y_pred)

In [9]:
# Hyperparameters
lambda_l1 = 100
learning_rate = 0.0001
num_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)

In [10]:
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (real_images, target_images) in enumerate(dataloader):
        real_images = real_images.to(device)
        target_images = target_images.to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        fake_images = generator(real_images)

        real_pairs = torch.cat((real_images, target_images), dim=1)
        fake_pairs = torch.cat((real_images, fake_images.detach()), dim=1)

        D_real = discriminator(real_pairs)
        D_fake = discriminator(fake_pairs)

        # Clamping outputs
        eps = 1e-8  # A small value to avoid log(0)
        loss_D = -torch.mean(torch.log(D_real + eps) + torch.log(1 - D_fake + eps))
        loss_D.backward()
        optimizer_D.step()


        # Train Generator
        optimizer_G.zero_grad()

        # Computation of G loss
        D_fake_for_G = discriminator(fake_pairs)
        loss_G_GAN = -torch.mean(torch.log(D_fake_for_G))
        loss_G_L1 = l1_loss(target_images, fake_images)
        loss_G = loss_G_GAN + lambda_l1 * loss_G_L1

        loss_G.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch}/{num_epochs}], Loss D: {loss_D.item()}, Loss G: {loss_G.item()}')

Epoch [0/50], Loss D: 0.570707380771637, Loss G: 17.290164947509766
Epoch [1/50], Loss D: 0.7995550036430359, Loss G: 22.42626953125
Epoch [2/50], Loss D: 0.6717220544815063, Loss G: 20.798009872436523
Epoch [3/50], Loss D: 0.6300781965255737, Loss G: 20.150964736938477
Epoch [4/50], Loss D: 0.21730884909629822, Loss G: 29.954578399658203
Epoch [5/50], Loss D: 0.1986531764268875, Loss G: 24.817581176757812
Epoch [6/50], Loss D: 0.29798245429992676, Loss G: 20.87131690979004
Epoch [7/50], Loss D: 0.2675427496433258, Loss G: 23.649805068969727
Epoch [8/50], Loss D: 0.2299765944480896, Loss G: 21.938749313354492
Epoch [9/50], Loss D: 0.34901124238967896, Loss G: 19.72670555114746
Epoch [10/50], Loss D: 1.1516358852386475, Loss G: 13.272932052612305
Epoch [11/50], Loss D: 0.9599515199661255, Loss G: 16.538795471191406
Epoch [12/50], Loss D: 0.15175946056842804, Loss G: 30.893850326538086
Epoch [13/50], Loss D: 0.25608018040657043, Loss G: 29.184165954589844
Epoch [14/50], Loss D: 0.5663152

In [11]:
import torch
import torch.nn.functional as F

# MSE Calculation
def mse(image1, image2):
    return F.mse_loss(image1, image2)

# PSNR Calculation
def psnr(image1, image2, max_val=1.0):
    mse_value = mse(image1, image2)
    psnr_value = 10 * torch.log10(max_val ** 2 / mse_value)
    return psnr_value

# SSIM Calculation
def ssim(image1, image2, C1=0.01**2, C2=0.03**2):
    mu1 = F.avg_pool2d(image1, kernel_size=11, stride=1, padding=5)
    mu2 = F.avg_pool2d(image2, kernel_size=11, stride=1, padding=5)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.avg_pool2d(image1 ** 2, kernel_size=11, stride=1, padding=5) - mu1_sq
    sigma2_sq = F.avg_pool2d(image2 ** 2, kernel_size=11, stride=1, padding=5) - mu2_sq
    sigma12 = F.avg_pool2d(image1 * image2, kernel_size=11, stride=1, padding=5) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


In [12]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Define the custom dataset
class PairedImageDataset(Dataset):
    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform
        self.image_filenames = os.listdir(raw_dir)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_filename = self.image_filenames[idx]
        raw_image_path = os.path.join(self.raw_dir, image_filename)
        reference_image_path = os.path.join(self.reference_dir, image_filename)

        raw_image = Image.open(raw_image_path).convert('RGB')
        reference_image = Image.open(reference_image_path).convert('RGB')

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

raw_dir = '/content/Test/Raw'
reference_dir = '/content/Test/Reference'
test_dataset = PairedImageDataset(raw_dir, reference_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

import torch

# Set the generator to evaluation mode
generator.eval()

# Initialize accumulators for MSE, PSNR, SSIM
mse_total = 0.0
psnr_total = 0.0
ssim_total = 0.0
num_samples = len(test_loader)

for raw_image, reference_image in test_loader:
    with torch.no_grad():
        generated_image = generator(raw_image)

    generated_image = (generated_image + 1) / 2
    reference_image = (reference_image + 1) / 2

    mse_value = mse(generated_image, reference_image)
    psnr_value = psnr(generated_image, reference_image)
    ssim_value = ssim(generated_image, reference_image)

    mse_total += mse_value.item()
    psnr_total += psnr_value.item()
    ssim_total += ssim_value.item()

mean_mse = mse_total / num_samples
mean_psnr = psnr_total / num_samples
mean_ssim = ssim_total / num_samples

print(f"Mean MSE: {mean_mse}")
print(f"Mean PSNR: {mean_psnr}")
print(f"Mean SSIM: {mean_ssim}")

Mean MSE: 0.01947583749057039
Mean PSNR: 18.910640776784795
Mean SSIM: 0.8239411484254034
