In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from skimage.metrics import peak_signal_noise_ratio
import os
import cv2

# Function to apply gamma correction to a numpy array image
def apply_gamma_correction(img, gamma):
    """Gamma-correct an 8-bit image with a lookup table.

    Each pixel value ``v`` is mapped to ``round(255 * (v / 255) ** (1 / gamma))``,
    so ``gamma > 1`` brightens the image and ``gamma < 1`` darkens it.

    Args:
        img: ``uint8`` array of any shape (e.g. HxW or HxWxC).
        gamma: Positive gamma value.

    Returns:
        A new ``uint8`` array with the same shape as ``img``.

    Raises:
        ValueError: if ``gamma`` is not positive.
        TypeError: if ``img`` is not a ``uint8`` array.
    """
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")
    img = np.asarray(img)
    if img.dtype != np.uint8:
        raise TypeError(f"expected a uint8 image, got dtype {img.dtype}")
    levels = np.arange(256) / 255.0
    # Round rather than truncate: a bare astype() floors every level, biasing the result darker.
    table = np.clip(np.rint(levels ** (1.0 / gamma) * 255.0), 0, 255).astype(np.uint8)
    # Integer-array indexing applies the 256-entry table element-wise, like cv2.LUT on uint8 input.
    return table[img]

# Custom Dataset Class
class CustomDataset(Dataset):
    """Paired low-light / reference image dataset.

    Item ``idx`` is ``(low_image, high_image)`` loaded from
    ``low_image_paths[idx]`` and ``high_image_paths[idx]``. Pairing is purely
    positional, so the two lists must be given in matching order.

    Args:
        low_image_paths: Paths of the low-light input images.
        high_image_paths: Paths of the reference images (same length).
        transform: Optional callable applied to both PIL images after the
            optional gamma step (e.g. a torchvision ``Compose``).
        gamma: Optional gamma. When truthy, both images are mapped through
            ``v -> round(255 * (v / 255) ** (1 / gamma))`` before ``transform``;
            ``None`` (or ``0``) skips the step.

    Raises:
        ValueError: if the path lists differ in length or ``gamma`` is negative.
    """

    def __init__(self, low_image_paths, high_image_paths, transform=None, gamma=None):
        # Explicit checks instead of ``assert`` so they still run under ``python -O``.
        if len(low_image_paths) != len(high_image_paths):
            raise ValueError("Low and High image directories must contain the same number of images.")
        if gamma is not None and gamma < 0:
            raise ValueError(f"gamma must be non-negative, got {gamma}")
        self.low_image_paths = low_image_paths
        self.high_image_paths = high_image_paths
        self.transform = transform
        self.gamma = gamma

    @staticmethod
    def _gamma_table(gamma):
        """Return the 256-entry ``uint8`` lookup table for *gamma* (must be > 0)."""
        if gamma <= 0:
            raise ValueError(f"gamma must be positive, got {gamma}")
        levels = np.arange(256) / 255.0
        # Round instead of truncating: a bare astype() floors every level, biasing the image darker.
        return np.clip(np.rint(levels ** (1.0 / gamma) * 255.0), 0, 255).astype(np.uint8)

    @staticmethod
    def _apply_table(image, table):
        """Map every pixel of a ``uint8`` array through *table* (numpy equivalent of ``cv2.LUT``)."""
        image = np.asarray(image)
        if image.dtype != np.uint8:
            raise TypeError(f"expected a uint8 image, got dtype {image.dtype}")
        return table[image]

    @staticmethod
    def _load_rgb(path):
        """Open *path* as an RGB PIL image, closing the underlying file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def gamma_correction(self, image, gamma=1.0):
        """Return a gamma-corrected copy of the ``uint8`` array *image*.

        Raises:
            ValueError: if ``gamma`` is not positive.
            TypeError: if ``image`` is not ``uint8``.
        """
        return self._apply_table(image, self._gamma_table(gamma))

    def __len__(self):
        return len(self.low_image_paths)

    def __getitem__(self, idx):
        low_image = self._load_rgb(self.low_image_paths[idx])
        high_image = self._load_rgb(self.high_image_paths[idx])

        if self.gamma:
            # Build the table once and reuse it for both images of the pair.
            table = self._gamma_table(self.gamma)
            low_image = Image.fromarray(self._apply_table(low_image, table))
            high_image = Image.fromarray(self._apply_table(high_image, table))

        if self.transform:
            low_image = self.transform(low_image)
            high_image = self.transform(high_image)

        return low_image, high_image

import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm stages joined to a skip connection.

    Spatial size is preserved (``padding=1``). When ``in_channels`` differs
    from ``out_channels``, the skip path goes through a 1x1 convolution so it
    can be summed with the main path. The sum passes through a final ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which parameters are initialised.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        needs_projection = in_channels != out_channels
        self.adjust_channels = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1) if needs_projection else None
        )

    def forward(self, x):
        shortcut = x if self.adjust_channels is None else self.adjust_channels(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halving height and width), then a ResidualBlock."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.residual_block = ResidualBlock(in_channels, out_channels)

    def forward(self, x):
        return self.residual_block(self.pool(x))

class Up(nn.Module):
    """Decoder stage: upsample, align to the skip tensor, concatenate, refine.

    ``in_channels`` is the channel count after concatenation. The upsampled
    input ``x1`` is expected to carry ``in_channels // 2`` channels, matching
    the transposed convolution, and the skip tensor ``x2`` supplies the rest.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        half = in_channels // 2
        self.up = nn.ConvTranspose2d(half, half, kernel_size=2, stride=2)
        self.residual_block = ResidualBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        dy = x2.shape[2] - upsampled.shape[2]
        dx = x2.shape[3] - upsampled.shape[3]
        # F.pad takes (left, right, top, bottom) for the last two dims; an odd
        # remainder goes to the right/bottom edge, and negative amounts crop.
        left = dx // 2
        top = dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.residual_block(merged)

class OutConv(nn.Module):
    """Final 1x1 convolution projecting feature maps to ``out_channels`` channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """Residual U-Net with five down/up stages and a sigmoid output.

    Args:
        n_channels: Number of input channels.
        n_classes: Number of output channels; outputs pass through
            ``torch.sigmoid`` and so lie in ``(0, 1)``.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Stages are registered one by one as inc, down1..down5, up1..up5, outc.
        # This keeps the same state_dict keys and the same parameter
        # initialisation order as explicit attribute assignment.
        self.inc = ResidualBlock(n_channels, 64)
        down_specs = ((64, 128), (128, 256), (256, 512), (512, 512), (512, 512))
        for stage, (cin, cout) in enumerate(down_specs, start=1):
            setattr(self, f"down{stage}", Down(cin, cout))
        up_specs = ((1024, 512), (1024, 256), (512, 128), (256, 64), (128, 64))
        for stage, (cin, cout) in enumerate(up_specs, start=1):
            setattr(self, f"up{stage}", Up(cin, cout))
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        encoders = (self.down1, self.down2, self.down3, self.down4, self.down5)
        decoders = (self.up1, self.up2, self.up3, self.up4, self.up5)

        # skips becomes [stem output, down1 output, ..., down5 output].
        skips = [self.inc(x)]
        for encode in encoders:
            skips.append(encode(skips[-1]))

        # Start from the deepest map; consume skips from deepest to shallowest.
        out = skips.pop()
        for decode in decoders:
            out = decode(out, skips.pop())
        return torch.sigmoid(self.outc(out))

# Paths to your low and high light images
train_low_image_directory = '/kaggle/input/loldataset-4/LOLdataset/our485/low'
train_high_image_directory = '/kaggle/input/loldataset-4/LOLdataset/our485/high'
eval_low_image_directory = '/kaggle/input/loldataset-4/LOLdataset/eval15/low'
eval_high_image_directory = '/kaggle/input/loldataset-4/LOLdataset/eval15/high'


def _list_png_paths(directory):
    """Return the full paths of the ``.png`` files in *directory*, sorted by filename.

    ``os.listdir`` order is arbitrary and filesystem-dependent, while
    CustomDataset pairs low/high images purely by index. Both lists must
    therefore be sorted the same way.
    """
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory)) if name.endswith('.png')]


def _check_same_filenames(low_paths, high_paths, split):
    """Raise ValueError unless *low_paths* and *high_paths* name the same files in the same order."""
    low_names = [os.path.basename(path) for path in low_paths]
    high_names = [os.path.basename(path) for path in high_paths]
    if low_names != high_names:
        unmatched = sorted(set(low_names).symmetric_difference(high_names))
        raise ValueError(
            f"{split}: low/high image filenames do not pair up; "
            f"{len(unmatched)} unmatched, e.g. {unmatched[:5]}"
        )


# Collecting image paths (sorted so that index i refers to the same filename in 'low' and 'high')
train_low_image_paths = _list_png_paths(train_low_image_directory)
train_high_image_paths = _list_png_paths(train_high_image_directory)

eval_low_image_paths = _list_png_paths(eval_low_image_directory)
eval_high_image_paths = _list_png_paths(eval_high_image_directory)

_check_same_filenames(train_low_image_paths, train_high_image_paths, 'train')
_check_same_filenames(eval_low_image_paths, eval_high_image_paths, 'eval')

# Gamma correction value
gamma_value = 2.2

# Transformations
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
])

# Datasets and DataLoaders
train_dataset = CustomDataset(train_low_image_paths, train_high_image_paths, transform=transform, gamma=gamma_value)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

eval_dataset = CustomDataset(eval_low_image_paths, eval_high_image_paths, transform=transform, gamma=gamma_value)
eval_loader = DataLoader(eval_dataset, batch_size=4, shuffle=False)

# Number of independently initialised models trained per configuration, and
# the loss shared by all of them. The models, optimizers and schedulers are
# rebuilt for every configuration inside the hyperparameter search loop, so
# they are not created here; an extra set of num_models UNets on the GPU
# would be discarded before ever being used.
num_models = 5
criterion = nn.MSELoss()

# Training and Evaluation Functions
def train(model, criterion, optimizer, loader, scaler, accumulation_steps=1, max_norm=1.0):
    """Run one epoch of mixed-precision training with gradient accumulation.

    Args:
        model: Network to train; batches are moved to the device of its first parameter.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        loader: Sized iterable of ``(low, high)`` batches.
        scaler: GradScaler used for loss scaling; ``unscale_``/``step``/``update``
            are called once per optimizer step.
        accumulation_steps: Number of batches whose gradients are summed before
            each optimizer step. A trailing partial window at the end of the
            epoch is stepped too, so its gradients are not thrown away. Its loss
            is still divided by ``accumulation_steps``.
        max_norm: Gradient-norm clipping threshold; ``None`` disables clipping.

    Returns:
        The mean per-batch loss over the epoch (the undivided criterion value).

    Raises:
        ValueError: if ``accumulation_steps < 1`` or ``loader`` is empty.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train() received an empty loader")

    model.train()
    device = next(model.parameters()).device
    # torch.cuda.amp.autocast only affects CUDA ops, so enable it just for CUDA models.
    use_amp = device.type == "cuda"
    running_loss = 0.0
    optimizer.zero_grad()

    for i, (low, high) in enumerate(loader):
        low = low.to(device)
        high = high.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(low)
            loss = criterion(outputs, high) / accumulation_steps

        scaler.scale(loss).backward()

        # Step at the end of every full window, and on the last batch. Otherwise
        # a trailing partial window would be wiped by the next epoch's
        # zero_grad() without ever being applied.
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        running_loss += loss.item() * accumulation_steps

    return running_loss / num_batches

def evaluate(model, criterion, loader):
    """Compute the per-sample average loss and mean PSNR of ``model`` on ``loader``.

    Args:
        model: Network to evaluate. It is switched to eval mode, and batches are
            moved to the device of its first parameter.
        criterion: Loss used for the reported validation loss. It is assumed to be
            mean-reduced over the batch (``nn.MSELoss()``'s default). Each batch's
            value is weighted by its batch size, so a smaller final batch does
            not skew the average.
        loader: Iterable of ``(low, high)`` batches of NCHW image tensors.

    Returns:
        ``(avg_loss, avg_psnr)``. PSNR uses ``data_range=1.0`` because the sigmoid
        outputs and the ``ToTensor`` targets both lie in ``[0, 1]``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    device = next(model.parameters()).device
    total_loss = 0.0
    num_samples = 0
    psnr_list = []

    with torch.no_grad():
        for low, high in loader:
            low = low.to(device)
            high = high.to(device)

            outputs = model(low)
            batch_size = low.size(0)
            total_loss += criterion(outputs, high).item() * batch_size
            num_samples += batch_size

            # skimage expects channel-last numpy arrays.
            outputs_np = outputs.permute(0, 2, 3, 1).cpu().numpy()
            high_np = high.permute(0, 2, 3, 1).cpu().numpy()

            for target, prediction in zip(high_np, outputs_np):
                psnr_list.append(peak_signal_noise_ratio(target, prediction, data_range=1.0))

    if num_samples == 0:
        raise ValueError("evaluate() received an empty loader")

    return total_loss / num_samples, float(np.mean(psnr_list))

# Hyperparameter Tuning
import itertools

param_grid = {
    'lr': [0.0001, 0.0005, 0.001],
    'batch_size': [4, 8],
    'accumulation_steps': [1, 4]
}

best_params = {}
best_val_loss = float('inf')

num_epochs = 30
# Module-level scaler kept for any later code that refers to it; the search
# below gives every model its own GradScaler.
scaler = torch.cuda.amp.GradScaler()

# itertools.product visits configurations in the same order as three nested
# loops would (accumulation_steps varies fastest, lr slowest).
for lr, batch_size, accumulation_steps in itertools.product(
    param_grid['lr'], param_grid['batch_size'], param_grid['accumulation_steps']
):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
    models = [UNet(n_channels=3, n_classes=3).cuda() for _ in range(num_models)]
    optimizers = [optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) for model in models]
    schedulers = [optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5) for optimizer in optimizers]
    # The models train independently, so each gets its own loss scale. With a
    # shared scaler, an overflow in one model (or in a previous configuration)
    # would shrink the scale used by all the others.
    scalers = [torch.cuda.amp.GradScaler() for _ in models]

    # PSNR of each epoch's best model (lowest validation loss), one entry per epoch.
    all_psnr_values = []

    for epoch in range(num_epochs):
        train_losses = []
        for model, optimizer, model_scaler in zip(models, optimizers, scalers):
            train_loss = train(model, criterion, optimizer, train_loader, model_scaler, accumulation_steps=accumulation_steps, max_norm=1.0)
            train_losses.append(train_loss)

        avg_train_loss = np.mean(train_losses)

        val_losses = []
        psnr_values = []
        for model in models:
            val_loss, avg_psnr = evaluate(model, criterion, eval_loader)
            val_losses.append(val_loss)
            psnr_values.append(avg_psnr)

        best_model_idx = int(np.argmin(val_losses))
        val_loss = val_losses[best_model_idx]
        all_psnr_values.append(psnr_values[best_model_idx])

        print(f"LR: {lr}, Batch Size: {batch_size}, Accumulation Steps: {accumulation_steps} - Epoch [{epoch+1}/{num_epochs}] - Avg Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, PSNR: {psnr_values[best_model_idx]:.4f}")

        # Step every scheduler on its own model's loss. The loop variable must not
        # be named ``val_loss``: that would overwrite the best model's loss, and
        # the comparison below would then use the last model's loss instead.
        for scheduler, model_val_loss in zip(schedulers, val_losses):
            scheduler.step(model_val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_params = {
                'lr': lr,
                'batch_size': batch_size,
                'accumulation_steps': accumulation_steps
            }

    avg_epoch_psnr = np.mean(all_psnr_values) if all_psnr_values else 0.0
    print(f"Hyperparameters (LR: {lr}, Batch Size: {batch_size}, Accumulation Steps: {accumulation_steps}) - Average PSNR over {num_epochs} epochs: {avg_epoch_psnr:.4f}")

print(f"Best Hyperparameters: {best_params}")


LR: 0.0001, Batch Size: 4, Accumulation Steps: 1 - Epoch [1/30] - Avg Train Loss: 0.0186, Val Loss: 0.0124, PSNR: 21.2690
LR: 0.0001, Batch Size: 4, Accumulation Steps: 1 - Epoch [2/30] - Avg Train Loss: 0.0101, Val Loss: 0.0108, PSNR: 20.9596
LR: 0.0001, Batch Size: 4, Accumulation Steps: 1 - Epoch [3/30] - Avg Train Loss: 0.0090, Val Loss: 0.0103, PSNR: 21.2878
LR: 0.0001, Batch Size: 4, Accumulation Steps: 1 - Epoch [4/30] - Avg Train Loss: 0.0084, Val Loss: 0.0098, PSNR: 21.9843
LR: 0.0001, Batch Size: 4, Accumulation Steps: 1 - Epoch [5/30] - Avg Train Loss: 0.0081, Val Loss: 0.0095, PSNR: 22.1693
LR: 0.0001, Batch Size: 4, Accumulation Steps: 1 - Epoch [6/30] - Avg Train Loss: 0.0080, Val Loss: 0.0102, PSNR: 21.8825
LR: 0.0001, Batch Size: 4, Accumulation Steps: 1 - Epoch [7/30] - Avg Train Loss: 0.0076, Val Loss: 0.0092, PSNR: 21.9760
LR: 0.0001, Batch Size: 4, Accumulation Steps: 1 - Epoch [8/30] - Avg Train Loss: 0.0073, Val Loss: 0.0096, PSNR: 22.4214
LR: 0.0001, Batch Size: 