In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
loli_path = kagglehub.dataset_download('tanvirnwu/loli-street-low-light-image-enhancement-of-street')

print('Data source import complete.')


In [None]:
import os

In [None]:
os.listdir(loli_path)

In [None]:
loli_path = os.path.join(loli_path, "LoLI-Street Dataset")

In [None]:
import os
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class LowLightDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Path to the root directory containing 'low' and 'high' folders.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform

        # Get the list of low and high image filenames
        self.low_dir = os.path.join(root_dir, 'low')
        self.high_dir = os.path.join(root_dir, 'high')

        # Ensure the filenames in 'low' and 'high' directories match
        self.low_images = sorted(os.listdir(self.low_dir))
        self.high_images = sorted(os.listdir(self.high_dir))

        # Verify that the filenames match
        assert self.low_images == self.high_images, "Low and high image filenames do not match!"

    def __len__(self):
        return len(self.low_images)

    def __getitem__(self, idx):
        # Load low and high images
        low_image_path = os.path.join(self.low_dir, self.low_images[idx])
        high_image_path = os.path.join(self.high_dir, self.high_images[idx])

        low_image = Image.open(low_image_path).convert('RGB')
        high_image = Image.open(high_image_path).convert('RGB')

        # Apply transformations if any
        if self.transform:
            low_image = self.transform(low_image)
            high_image = self.transform(high_image)

        return low_image, high_image

In [None]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize images to a fixed size
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

In [None]:
import torch

# Define paths
train_path = os.path.join(loli_path, 'Train')
test_path = os.path.join(loli_path, 'Test')
val_path = os.path.join(loli_path, 'Val')

# Create datasets
train_dataset = LowLightDataset(root_dir=train_path, transform=transform)
val_dataset = LowLightDataset(root_dir=val_path, transform=transform)

train_dataset = torch.utils.data.Subset(train_dataset, range(5000))
val_dataset = torch.utils.data.Subset(val_dataset, range(5000))

In [None]:
from torch.utils.data import DataLoader

# Define batch size
batch_size = 32

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Get a batch of data
low_images, high_images = next(iter(train_dataloader))

# Create a figure with two subplots
fig, axes = plt.subplots(1, 2, figsize=(16, 16))

# Display Low-Light Images
axes[0].set_title("Low-Light Images")
axes[0].imshow(vutils.make_grid(low_images, nrow=8, normalize=True).permute(1, 2, 0))
axes[0].axis("off")

# Display High-Light Images
axes[1].set_title("High-Light Images")
axes[1].imshow(vutils.make_grid(high_images, nrow=8, normalize=True).permute(1, 2, 0))
axes[1].axis("off")

plt.show()


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_residual=9):  # Increased num_residual
        super(UNetGenerator, self).__init__()

        # Initial Convolution
        self.initial = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, kernel_size=7, padding=0),
            nn.InstanceNorm2d(64),
            nn.ReLU(True)
        )

        # Enhanced Downsampling with more channels
        self.down1 = DownBlock(64, 128)    # /2
        self.down2 = DownBlock(128, 256)   # /4
        self.down3 = DownBlock(256, 512)   # /8
        self.down4 = DownBlock(512, 512)

        # Bottleneck with more Efficient Residual Blocks
        self.bottleneck = nn.Sequential(*[
            EfficientResidualBlock(512) for _ in range(num_residual)
        ])

        # Enhanced Upsampling with Attention-Guided Skip Connections
        self.up1 = UpBlock(512, 512, 512)
        self.up2 = UpBlock(512, 256, 512)    # /8
        self.up3 = UpBlock(256, 128, 256)    # /4
        self.up4 = UpBlock(128, 64, 128)     # /2

        # Output Layer
        self.output = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_channels, kernel_size=7, padding=0),
            nn.Tanh()
        )

    def forward(self, x):
        # Initial convolution
        x = self.initial(x)

        # Downsampling with skip connections
        d1 = self.down1(x)   # 128
        d2 = self.down2(d1)  # 256
        d3 = self.down3(d2)  # 512
        d4 = self.down4(d3)

        # Bottleneck processing
        x = self.bottleneck(d4)

        # Upsampling with attention-guided skip connections
        x = self.up1(x, d4)   # 512
        x = self.up2(x, d3)   # 256
        x = self.up3(x, d2)   # 128
        x = self.up4(x, d1)   # 64

        return self.output(x)

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(True)
        )
        self.attention = EnhancedChannelSpatialAttention(out_channels)  # Enhanced attention

    def forward(self, x):
        x = self.conv(x)
        x = self.attention(x)
        return x

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels):
        super(UpBlock, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),  # Lightweight upsampling
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(True)
        )
        self.attention = EnhancedChannelSpatialAttention(out_channels + skip_channels)  # Enhanced attention
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x, skip):
        x = self.up(x)
        skip_resized = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip_resized], dim=1)
        x = self.attention(x)  # Attention-guided feature fusion
        return self.conv(x)

class EfficientResidualBlock(nn.Module):
    def __init__(self, channels):
        super(EfficientResidualBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels),
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels),
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.InstanceNorm2d(channels)
        )
        self.attention = EnhancedChannelSpatialAttentionWithSE(channels)  # Use combined attention

    def forward(self, x):
        residual = x
        x = self.conv(x)
        x = self.attention(x)  # Apply combined attention
        return x + residual

class EnhancedChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):  # Increased reduction ratio
        super(EnhancedChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(True),
            nn.Linear(channels // reduction, channels)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg = self.avg_pool(x).view(b, c)
        max = self.max_pool(x).view(b, c)

        avg_out = self.fc(avg).view(b, c, 1, 1)
        max_out = self.fc(max).view(b, c, 1, 1)

        out = avg_out + max_out
        return x * self.sigmoid(out)

class EnhancedSpatialAttention(nn.Module):
    def __init__(self):
        super(EnhancedSpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg = torch.mean(x, dim=1, keepdim=True)
        max, _ = torch.max(x, dim=1, keepdim=True)
        combined = torch.cat([avg, max], dim=1)
        att = self.conv(combined)
        return x * self.sigmoid(att)
        

class EnhancedChannelSpatialAttention(nn.Module):
    def __init__(self, channels, reduction=16):  # Increased reduction ratio
        super(EnhancedChannelSpatialAttention, self).__init__()
        self.channel_att = EnhancedChannelAttention(channels, reduction)
        self.spatial_att = EnhancedSpatialAttention()

    def forward(self, x):
        x = self.channel_att(x)
        x = self.spatial_att(x)
        return x

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class EnhancedChannelSpatialAttentionWithSE(nn.Module):
    def __init__(self, channels, reduction=16):
        super(EnhancedChannelSpatialAttentionWithSE, self).__init__()
        self.channel_spatial_att = EnhancedChannelSpatialAttention(channels, reduction)
        self.se_att = SEBlock(channels, reduction)  # Add SE block
        self.combine = nn.Conv2d(channels * 2, channels, kernel_size=1)  # Combine features

    def forward(self, x):
        att1 = self.channel_spatial_att(x)  # Apply channel-spatial attention
        att2 = self.se_att(x)  # Apply SE attention
        combined = torch.cat([att1, att2], dim=1)  # Concatenate outputs
        return self.combine(combined)  # Combine features

In [None]:
class PatchGANDiscriminator(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channels, 64, 4, stride=2, padding=1)),
            nn.LeakyReLU(0.2),

            nn.utils.spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.utils.spectral_norm(nn.Conv2d(128, 256, 4, stride=2, padding=1)),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.utils.spectral_norm(nn.Conv2d(256, 512, 4, stride=1, padding=1)),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.utils.spectral_norm(nn.Conv2d(512, 1, 4, stride=1, padding=1)),
            nn.Dropout2d(0.3)
        )

    def forward(self, x):
        return self.model(x)

In [None]:
!pip install pytorch_msssim

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torchvision.models as models
# from pytorch_msssim import ssim

# class PerceptualLoss(nn.Module):
#     def __init__(self):
#         super(PerceptualLoss, self).__init__()
#         self.vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:16].eval()
#         for param in self.vgg.parameters():
#             param.requires_grad = False
#         self.criterion = nn.L1Loss()a
#         self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
#         self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

#     def forward(self, input, target):
#         # Ensure input and target have 3 channels
#         if input.shape[1] != 3:
#             input = input[:, :3, :, :]
#         if target.shape[1] != 3:
#             target = target[:, :3, :, :]

#         # Normalize input and target for VGG
#         input_vgg = (input - self.mean.to(input.device)) / self.std.to(input.device)
#         target_vgg = (target - self.mean.to(target.device)) / self.std.to(target.device)

#         # Extract features and compute loss
#         input_features = self.vgg(input_vgg)
#         target_features = self.vgg(target_vgg)
#         return self.criterion(input_features, target_features)


# # SSIM Loss
# class SSIMLoss(nn.Module):
#     def forward(self, input, target):
#         # Ensure input and target have 3 channels
#         if input.shape[1] != 3:
#             input = input[:, :3, :, :]
#         if target.shape[1] != 3:
#             target = target[:, :3, :, :]

#         return 1 - ssim(input, target, data_range=1.0, size_average=True)


# # Edge-Preserving Loss (Sobel-based)
# class EdgePreservingLoss(nn.Module):
#     def __init__(self):
#         super(EdgePreservingLoss, self).__init__()
#         sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
#         sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
#         self.register_buffer("sobel_x", sobel_x)
#         self.register_buffer("sobel_y", sobel_y)

#     def forward(self, input, target):
#         # Ensure input and target have the same number of channels
#         if input.shape[1] != target.shape[1]:
#             raise ValueError("Input and target must have the same number of channels")

#         # Compute edge maps
#         input_edges_x = F.conv2d(input, self.sobel_x.repeat(input.shape[1], 1, 1, 1), groups=input.shape[1], padding=1)
#         input_edges_y = F.conv2d(input, self.sobel_y.repeat(input.shape[1], 1, 1, 1), groups=input.shape[1], padding=1)
#         input_edges = torch.sqrt(input_edges_x**2 + input_edges_y**2)

#         target_edges_x = F.conv2d(target, self.sobel_x.repeat(target.shape[1], 1, 1, 1), groups=target.shape[1], padding=1)
#         target_edges_y = F.conv2d(target, self.sobel_y.repeat(target.shape[1], 1, 1, 1), groups=target.shape[1], padding=1)
#         target_edges = torch.sqrt(target_edges_x**2 + target_edges_y**2)

#         return F.l1_loss(input_edges, target_edges)


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

# Initialize TensorBoard
writer = SummaryWriter(log_dir="/kaggle/working/runs/experiment_name")

# Function to Load Checkpoints with Partial Matching
def load_checkpoint(path, generator, discriminator, device):
    checkpoint = torch.load(path, map_location=device)

    # Load model states with partial matching
    generator_state_dict = checkpoint['generator']
    discriminator_state_dict = checkpoint['discriminator']
    
    generator_dict = generator.state_dict()
    discriminator_dict = discriminator.state_dict()

    # Load only matching weights
    generator_state_dict = {k: v for k, v in generator_state_dict.items() if k in generator_dict}
    discriminator_state_dict = {k: v for k, v in discriminator_state_dict.items() if k in discriminator_dict}

    generator_dict.update(generator_state_dict)
    discriminator_dict.update(discriminator_state_dict)

    generator.load_state_dict(generator_dict, strict=False)
    discriminator.load_state_dict(discriminator_dict, strict=False)

    print(f"Checkpoint {path} loaded successfully with partial matching!")

    return checkpoint['epoch']  # Return last completed epoch


# Training Loop with Full Checkpoint Support
def train(generator, discriminator, dataloader, optimizer_G, optimizer_D, device, 
          start_epoch=0, epochs=100, checkpoint_dir="/kaggle/working/checkpoints"):
    
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    criterion_GAN = nn.BCEWithLogitsLoss().to(device)
    scaler = GradScaler()
    accumulation_steps = 4
    lambda_L1 = 100

    # Wrap with DataParallel if multiple GPUs are available
    if torch.cuda.device_count() > 1:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    for epoch in range(start_epoch, start_epoch + epochs):
        total_loss_D, total_loss_G, total_loss_GAN, total_loss_L1 = 0, 0, 0, 0
        current_epoch_in_session = epoch - start_epoch + 1
        
        with tqdm(dataloader, desc=f"Session Epoch {current_epoch_in_session}/{epochs} (Global {epoch+1})",
                  unit="batch") as pbar:
            for i, (low_imgs, high_imgs) in enumerate(pbar):
                low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)

                # Train Discriminator
                with autocast():
                    fake_imgs = generator(low_imgs).detach()
                    real_loss = criterion_GAN(discriminator(high_imgs), torch.ones_like(discriminator(high_imgs)))
                    fake_loss = criterion_GAN(discriminator(fake_imgs), torch.zeros_like(discriminator(fake_imgs)))
                    loss_D = (real_loss + fake_loss) / 2
                scaler.scale(loss_D).backward()
                if (i + 1) % accumulation_steps == 0:
                    scaler.step(optimizer_D)
                    scaler.update()
                    optimizer_D.zero_grad()

                # Train Generator
                with autocast():
                    fake_imgs = generator(low_imgs)
                    loss_GAN = criterion_GAN(discriminator(fake_imgs), torch.ones_like(discriminator(fake_imgs)))
                    loss_L1 = F.l1_loss(fake_imgs, high_imgs)
                    loss_G = loss_GAN + lambda_L1 * loss_L1
                scaler.scale(loss_G).backward()
                if (i + 1) % accumulation_steps == 0:
                    scaler.step(optimizer_G)
                    scaler.update()
                    optimizer_G.zero_grad()

                # Track losses
                total_loss_D += loss_D.item()
                total_loss_G += loss_G.item()
                total_loss_GAN += loss_GAN.item()
                total_loss_L1 += loss_L1.item()

                # Update tqdm display
                pbar.set_postfix({
                    "Loss D": f"{loss_D.item():.4f}",
                    "Loss G": f"{loss_G.item():.4f}",
                    "GAN": f"{loss_GAN.item():.4f}",
                    "L1": f"{loss_L1.item():.4f}"
                })

                # Free memory
                del fake_imgs, loss_D, loss_G, loss_GAN, loss_L1

            # Log losses after each epoch
            writer.add_scalar("Loss/Discriminator", total_loss_D / len(dataloader), epoch)
            writer.add_scalar("Loss/Generator", total_loss_G / len(dataloader), epoch)
            writer.add_scalar("Loss/GAN", total_loss_GAN / len(dataloader), epoch)
            writer.add_scalar("Loss/L1", total_loss_L1 / len(dataloader), epoch)

            print(f"Session Epoch [{current_epoch_in_session}/{epochs}] (Global {epoch+1}) "
                  f"Loss D: {total_loss_D / len(dataloader):.4f}, "
                  f"Loss G: {total_loss_G / len(dataloader):.4f}, "
                  f"GAN: {total_loss_GAN / len(dataloader):.4f}, "
                  f"L1: {total_loss_L1 / len(dataloader):.4f}")

        # Save full checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'generator': generator.module.state_dict() if isinstance(generator, nn.DataParallel) else generator.state_dict(),
                'discriminator': discriminator.module.state_dict() if isinstance(discriminator, nn.DataParallel) else discriminator.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
            }
            torch.save(checkpoint, f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}.pth")

    writer.close()

# Initialize Models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = UNetGenerator().to(device)
discriminator = PatchGANDiscriminator().to(device)

# Initialize Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Try to Load Latest Checkpoint
checkpoint_dir = "/kaggle/working/checkpoints"
checkpoint_path = f"/kaggle/input/cvpr_gan/pytorch/default/3/checkpoint_epoch_400.pth"
start_epoch = 0

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, device)
    print(f"Resuming training from epoch {start_epoch}")

# Continue Training from the Loaded Checkpoint
train(generator, discriminator, train_dataloader, optimizer_G, optimizer_D,
      device, start_epoch=start_epoch, epochs=1, checkpoint_dir=checkpoint_dir)


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms

class LowLightTestDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Path to the directory containing low-light images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform

        # Get the list of image filenames in the directory
        self.image_paths = sorted(os.listdir(root_dir))  # assuming all images are in the root directory

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image
        image_path = os.path.join(self.root_dir, self.image_paths[idx])
        image = Image.open(image_path).convert('RGB')

        # Apply transformations if any
        if self.transform:
            image = self.transform(image)

        return image


In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms

# Define transformations for the test images (resize, normalize, etc.)
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to the expected size
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
])

# Create the test dataset
test_dataset = LowLightTestDataset(root_dir=test_path, transform=transform)

# Create the test dataloader
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# import torch
# import matplotlib.pyplot as plt

# def test(generator, dataloader, device):
#     generator.eval()

#     with torch.no_grad():
#         low_imgs = next(iter(dataloader)).to(device)
#         generated_high_imgs = generator(low_imgs)

#         low_imgs = low_imgs.cpu().numpy()
#         generated_high_imgs = generated_high_imgs.cpu().numpy()

#         low_imgs = (low_imgs + 1) / 2
#         generated_high_imgs = (generated_high_imgs + 1) / 2

#         fig, axes = plt.subplots(len(low_imgs), 2, figsize=(8, 4 * len(low_imgs)))

#         for i in range(len(low_imgs)):
#             axes[i, 0].imshow(low_imgs[i].transpose(1, 2, 0))
#             axes[i, 0].set_title(f"Low {i+1}")
#             axes[i, 0].axis('off')

#             axes[i, 1].imshow(generated_high_imgs[i].transpose(1, 2, 0))
#             axes[i, 1].set_title(f"Generated High {i+1}")
#             axes[i, 1].axis('off')

#         plt.tight_layout()
#         plt.show()

# # Perform test
# test(generator, test_dataloader, device)


In [None]:
import torch
import matplotlib.pyplot as plt

def validate(generator, dataloader, device):
    generator.eval()

    with torch.no_grad():
        low_imgs, high_imgs = next(iter(dataloader))
        low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)

        generated_high_imgs = generator(low_imgs)

        generated_high_imgs = generated_high_imgs.cpu().numpy()
        high_imgs = high_imgs.cpu().numpy()
        low_imgs = low_imgs.cpu().numpy()

        generated_high_imgs = (generated_high_imgs + 1) / 2
        high_imgs = (high_imgs + 1) / 2
        low_imgs = (low_imgs + 1) / 2

        num_images = len(low_imgs)
        fig, axes = plt.subplots(num_images, 3, figsize=(10, 5 * num_images))

        for i in range(num_images):
            axes[i, 0].imshow(high_imgs[i].transpose(1, 2, 0))
            axes[i, 0].set_title(f"Original High {i+1}")
            axes[i, 0].axis('off')

            axes[i, 1].imshow(low_imgs[i].transpose(1, 2, 0))
            axes[i, 1].set_title(f"Low {i+1}")
            axes[i, 1].axis('off')

            axes[i, 2].imshow(generated_high_imgs[i].transpose(1, 2, 0))
            axes[i, 2].set_title(f"Generated High {i+1}")
            axes[i, 2].axis('off')

        plt.tight_layout()
        plt.show()

# Perform validation
validate(generator, val_dataloader, device)


In [None]:
eval_path = "/kaggle/input/lli-dataset/LLI_dataset"

# Create datasets
eval_dataset = LowLightDataset(root_dir=eval_path, transform=transform)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
validate(generator, eval_dataloader, device)

In [None]:
# import shutil

# # Zip the checkpoints directory
# shutil.make_archive('/kaggle/working/checkpoints', 'zip', '/kaggle/working/checkpoints')

# # Create a download link for the zip file
# from IPython.display import FileLink
# FileLink('/kaggle/working/checkpoints.zip')

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

generator_params = count_parameters(generator)
discriminator_params = count_parameters(discriminator)

print(f"Generator Parameters: {generator_params:,}")
print(f"Discriminator Parameters: {discriminator_params:,}")


In [None]:
!pip install lpips scikit-image piq
!pip install lpips

In [4]:
import torch
import numpy as np
import lpips
import torchvision.transforms as transforms
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from piq import niqe
import matplotlib.pyplot as plt

# Load LPIPS model
lpips_model = lpips.LPIPS(net='alex').to(device)

# Transform for LPIPS & NIQE
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),  # Resize for NIQE compatibility
])

def compute_metrics(generator, dataloader, device):
    generator.eval()
    
    psnr_list, ssim_list, lpips_list, niqe_list = [], [], [], []
    
    with torch.no_grad():
        for low_imgs, high_imgs in dataloader:
            low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)
            generated_high_imgs = generator(low_imgs).cpu().numpy()
            high_imgs = high_imgs.cpu().numpy()

            # Normalize to [0,1] for SSIM and PSNR
            generated_high_imgs = (generated_high_imgs + 1) / 2
            high_imgs = (high_imgs + 1) / 2

            for i in range(len(low_imgs)):
                gt = np.clip(high_imgs[i].transpose(1, 2, 0), 0, 1)
                pred = np.clip(generated_high_imgs[i].transpose(1, 2, 0), 0, 1)

                # PSNR
                psnr_value = psnr(gt, pred, data_range=1.0)
                psnr_list.append(psnr_value)

                # SSIM
                ssim_value = ssim(gt, pred, data_range=1.0, multichannel=True)
                ssim_list.append(ssim_value)

                # LPIPS (convert to tensor)
                gt_tensor = transform(gt).unsqueeze(0).to(device)
                pred_tensor = transform(pred).unsqueeze(0).to(device)
                lpips_value = lpips_model(gt_tensor, pred_tensor).item()
                lpips_list.append(lpips_value)

                # NIQE
                niqe_value = niqe(torch.tensor(pred).permute(2, 0, 1).unsqueeze(0))
                niqe_list.append(niqe_value.item())

    print(f"Avg PSNR: {np.mean(psnr_list):.4f}")
    print(f"Avg SSIM: {np.mean(ssim_list):.4f}")
    print(f"Avg LPIPS: {np.mean(lpips_list):.4f}")
    print(f"Avg NIQE: {np.mean(niqe_list):.4f}")

    return {
        "PSNR": np.mean(psnr_list),
        "SSIM": np.mean(ssim_list),
        "LPIPS": np.mean(lpips_list),
        "NIQE": np.mean(niqe_list),
    }

# Run Evaluation
metrics = compute_metrics(generator, val_dataloader, device)


ImportError: cannot import name 'niqe' from 'piq' (/usr/local/lib/python3.10/dist-packages/piq/__init__.py)