
Low Light Superresolution using a GAN Architecture
---



In [37]:
import os
from torchvision import transforms
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

def load_dataset():
    # Define paths
    dataset_folder = 'lol_dataset'
    low_folder = os.path.join(dataset_folder, 'low')
    high_folder = os.path.join(dataset_folder, 'high')

    # Sample loading
    transform = transforms.ToTensor()

    low_quality_imgs = []
    high_quality_imgs = []

    resize_transform = transforms.Resize((200, 300))  # Resize for low-quality images

    for filename in os.listdir(low_folder):
        low_qual_img_path = os.path.join(low_folder, filename)
        high_qual_img_path = os.path.join(high_folder, filename)

        try:
            low_qual_img = transform(Image.open(low_qual_img_path))
            high_qual_img = transform(Image.open(high_qual_img_path))

            low_qual_img = resize_transform(low_qual_img)  # Apply resize transformation for low-quality images

            low_quality_imgs.append(low_qual_img)
            high_quality_imgs.append(high_qual_img)
        except Exception as e:
            print(f"Error loading image pair {filename}: {str(e)}")

    print(f"Loaded {len(low_quality_imgs)} image pairs successfully.")

    if len(low_quality_imgs) != len(high_quality_imgs):
        print("Error: Incomplete dataset.")
        return None, None

    print(low_quality_imgs[0].shape)
    print(high_quality_imgs[0].shape)

    return low_quality_imgs, high_quality_imgs


# Split dataset into train, validation, and test sets
def split_dataset(low_quality_imgs, high_quality_imgs):
    X = list(range(len(low_quality_imgs)))
    y = list(range(len(high_quality_imgs)))

    # Split into train (70%), validation (20%), test (10%)
    X_train, X_val_test, y_train, y_val_test = train_test_split(X, y, test_size=0.3, random_state=42)
    X_val, X_test, y_val, y_test = train_test_split(X_val_test, y_val_test, test_size=1/3, random_state=42)

    # Create train, validation, and test datasets
    train_low_quality_imgs = [low_quality_imgs[i] for i in X_train]
    train_high_quality_imgs = [high_quality_imgs[i] for i in y_train]

    val_low_quality_imgs = [low_quality_imgs[i] for i in X_val]
    val_high_quality_imgs = [high_quality_imgs[i] for i in y_val]

    test_low_quality_imgs = [low_quality_imgs[i] for i in X_test]
    test_high_quality_imgs = [high_quality_imgs[i] for i in y_test]

    return train_low_quality_imgs, train_high_quality_imgs, val_low_quality_imgs, val_high_quality_imgs, test_low_quality_imgs, test_high_quality_imgs

num_images = 500

class LowLightDataset(Dataset):
    def __init__(self, low_quality_imgs, high_quality_imgs):
        self.low_quality_imgs = low_quality_imgs
        self.high_quality_imgs = high_quality_imgs

    def __len__(self):
        return len(self.low_quality_imgs)

    def __getitem__(self, idx):
        low_img = self.low_quality_imgs[idx]
        high_img = self.high_quality_imgs[idx]
        return low_img, high_img

# Call the function to load the dataset
low_quality_imgs, high_quality_imgs = load_dataset()

# Split dataset into train(70%), validation(20%), and test(10%) sets
train_low_quality_imgs, train_high_quality_imgs, val_low_quality_imgs, val_high_quality_imgs, test_low_quality_imgs, test_high_quality_imgs = split_dataset(low_quality_imgs, high_quality_imgs)

# Print dataset statistics
print("Train dataset:")
print(f"Low-quality images: {len(train_low_quality_imgs)}")
print(f"High-quality images: {len(train_high_quality_imgs)}")

print("\nValidation dataset:")
print(f"Low-quality images: {len(val_low_quality_imgs)}")
print(f"High-quality images: {len(val_high_quality_imgs)}")

print("\nTest dataset:")
print(f"Low-quality images: {len(test_low_quality_imgs)}")
print(f"High-quality images: {len(test_high_quality_imgs)}")

# Create DataLoader objects for train, validation, and test datasets
train_dataset = LowLightDataset(train_low_quality_imgs, train_high_quality_imgs)
val_dataset = LowLightDataset(val_low_quality_imgs, val_high_quality_imgs)
test_dataset = LowLightDataset(test_low_quality_imgs, test_high_quality_imgs)






Loaded 500 image pairs successfully.
torch.Size([3, 200, 300])
torch.Size([3, 400, 600])
Train dataset:
Low-quality images: 350
High-quality images: 350

Validation dataset:
Low-quality images: 100
High-quality images: 100

Test dataset:
Low-quality images: 50
High-quality images: 50


In [None]:
# Split the datasets into the batches necessary

# Split into train (70%), validation (20%), test (10%)


Mounted at /content/drive


In [None]:
# Load the dataset

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Generator
class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)

        # Decoder
        self.tconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.tconv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.tconv3 = nn.ConvTranspose2d(64, out_channels, kernel_size=3, padding=1)

        # Skip connections
        self.skip1 = nn.Conv2d(64, 128, kernel_size=1)
        self.skip2 = nn.Conv2d(128, 64, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)

        # Skip connections
        s1 = self.skip1(x1)
        s2 = self.skip2(x2)

        # Decoder
        x = self.tconv1(x3)
        x = x + s1
        x = self.tconv2(x)
        x = x + s2
        x = self.tconv3(x)

        return torch.tanh(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = F.leaky_relu(self.conv3(x), 0.2)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)
print("DataLoaders created.")

# Models
generator = Generator().to(device)
discriminator = Discriminator().to(device)
print("Models loaded.")

# Loss function
criterion = nn.BCEWithLogitsLoss()
print("Loss function loaded.")

# Optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=0.0001)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
print("Optimizers created.")

# Training loop
num_epochs = 100
d_losses = []
g_losses = []

# Plot losses
def plot_loss(d_losses, g_losses):
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss")
    plt.plot(g_losses, label="G")
    plt.plot(d_losses, label="D")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

for epoch in range(num_epochs):
    print(f"Epoch [{epoch}/{num_epochs}]")

    for i, (low_res, high_res) in enumerate(train_dataloader):
        print(f"Batch [{i+1}/{len(train_dataloader)}]")

        # Move data to device
        low_res = low_res.to(device)
        high_res = high_res.to(device)

        ##############################
        # Train the discriminator
        ##############################

        # Forward pass real samples through discriminator
        real_outputs = discriminator(high_res)
        real_labels = torch.ones_like(real_outputs).to(device)
        d_loss_real = criterion(real_outputs, real_labels)

        # Generate fake images and forward pass through discriminator
        generated_imgs = generator(low_res)
        fake_outputs = discriminator(generated_imgs.detach())
        fake_labels = torch.zeros_like(fake_outputs).to(device)
        d_loss_fake = criterion(fake_outputs, fake_labels)

        # Compute discriminator loss
        d_loss = d_loss_real + d_loss_fake

        # Backpropagation and optimization
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        ##############################
        # Train the generator
        ##############################

        # Forward pass fake images through discriminator
        fake_outputs = discriminator(generated_imgs)
        g_loss = criterion(fake_outputs, real_labels)

        # Backpropagation and optimization
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Save losses for plotting
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

        print(f"Batch [{i+1}/{len(train_dataloader)}], Generator Loss: {g_loss.item()}, Discriminator Loss: {d_loss.item()}")

    # Validation
    with torch.no_grad():
        val_loss = 0.0
        num_val_samples = 0

        for i, (low_res, high_res) in enumerate(val_dataloader):
            low_res = low_res.to(device)
            high_res = high_res.to(device)

            batch_size = low_res.size(0)  # Get the actual batch size

            generated_imgs = generator(low_res[:batch_size])
            fake_outputs = discriminator(generated_imgs)
            real_labels = torch.ones_like(fake_outputs).to(device)  # Adjust real_labels to match the batch size
            g_loss = criterion(fake_outputs, real_labels[:batch_size])
            val_loss += g_loss.item() * batch_size
            num_val_samples += batch_size

        val_loss /= num_val_samples

    print(f"Validation Loss: {val_loss}")



    print(f"Epoch [{epoch}/{num_epochs}], Generator Loss: {g_loss.item()}, Discriminator Loss: {d_loss.item()}, Validation Loss: {val_loss}")

# Plot losses
plot_loss(d_losses, g_losses)


Device: cpu
DataLoaders created.
Models loaded.
Loss function loaded.
Optimizers created.
Epoch [0/100]
Batch [1/22]
Batch [1/22], Generator Loss: 0.708132266998291, Discriminator Loss: 1.3897626399993896
Batch [2/22]
Batch [2/22], Generator Loss: 0.6980065703392029, Discriminator Loss: 1.3783408403396606
Batch [3/22]
Batch [3/22], Generator Loss: 0.684342622756958, Discriminator Loss: 1.3774781227111816
Batch [4/22]
Batch [4/22], Generator Loss: 0.6663116216659546, Discriminator Loss: 1.3830461502075195
Batch [5/22]
Batch [5/22], Generator Loss: 0.6438714265823364, Discriminator Loss: 1.400578498840332
Batch [6/22]
Batch [6/22], Generator Loss: 0.6354162693023682, Discriminator Loss: 1.41600501537323
Batch [7/22]
Batch [7/22], Generator Loss: 0.6358683109283447, Discriminator Loss: 1.4221951961517334
Batch [8/22]
Batch [8/22], Generator Loss: 0.6440486311912537, Discriminator Loss: 1.421299934387207
Batch [9/22]
Batch [9/22], Generator Loss: 0.6566146016120911, Discriminator Loss: 1.4

In [None]:
# Validation of the network

In [None]:
# Evaluate the network on the test images

In [25]:
import torch

# Check if CUDA is available
if torch.cuda.is_available():
    print("CUDA is available.")

    # Get the number of available GPUs
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")

    # Print GPU device information
    for gpu in range(num_gpus):
        print(f"GPU {gpu}: {torch.cuda.get_device_name(gpu)}")
else:
    print("CUDA is not available.")


CUDA is not available.


In [34]:
!py -m pip uninstall torch

^C
