In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg19
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as compare_psnr, structural_similarity as compare_ssim
from torch.amp import autocast, GradScaler
from torchvision.utils import save_image
import torch.nn.functional as F

# Opt the CUDA caching allocator into expandable segments, which helps reduce
# memory fragmentation. This must be in place before CUDA memory is first
# allocated for it to have any effect.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True')

# Learnable Self-Attention Module
class LearnableSelfAttention(nn.Module):
    """Spatial self-attention with a learnable residual gain.

    Every spatial position attends to every other position of the same
    feature map, so the attention matrix is (H*W) x (H*W) per sample.

    Args:
        channels: Number of input (and output) channels.
        reduction: Channel reduction factor for the query/key projections;
            the projected dimension is ``max(1, channels // reduction)``.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Guard channels < reduction, which would otherwise create a 0-channel conv.
        key_channels = max(1, channels // reduction)
        self.query = nn.Conv2d(channels, key_channels, kernel_size=1)
        self.key = nn.Conv2d(channels, key_channels, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        self.scale = nn.Parameter(torch.tensor(0.1))  # Learnable residual gain

    def forward(self, x):
        """Return ``scale * attention(x) + x`` for a (B, C, H, W) input."""
        batch, channels, height, width = x.size()
        positions = height * width

        # reshape (not view) so channels_last / non-contiguous inputs also work.
        query = self.query(x).reshape(batch, -1, positions).permute(0, 2, 1)  # (B, N, Ck)
        key = self.key(x).reshape(batch, -1, positions)                       # (B, Ck, N)
        value = self.value(x).reshape(batch, -1, positions).permute(0, 2, 1)  # (B, N, C)

        # Scaled dot-product attention normalizes by sqrt of the query/key
        # dimension, not of the full channel count.
        logits = torch.bmm(query, key) / (self.query.out_channels ** 0.5)
        attention = torch.softmax(logits, dim=-1)  # (B, N, N)
        out = torch.bmm(attention, value).permute(0, 2, 1).reshape(batch, channels, height, width)

        return self.scale * out + x  # Residual connection


# Dense Residual Block
class DenseResidualBlock(nn.Module):
    """Conv3x3 -> ReLU -> Conv3x3, summed back onto the input.

    Despite the name, features are added (a plain residual block), not
    concatenated as in a DenseNet block. Channel count and spatial size are
    preserved.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **same_size)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, **same_size)

    def forward(self, x):
        residual = self.conv2(self.relu(self.conv1(x)))
        return x + residual  # Residual connection


# Enhanced Generator with Dense Residual Blocks and Learnable Self-Attention
class AdvancedAttentionGenerator(nn.Module):
    """Encoder/decoder generator with a residual + self-attention bottleneck.

    Two stride-2 convolutions downsample the input by 4x. The bottleneck runs
    ``DenseResidualBlock -> LearnableSelfAttention -> DenseResidualBlock`` at
    256 channels. Two transposed convolutions then upsample back, with additive
    skip connections from the matching encoder stages. No output activation is
    applied, so the result is unbounded.

    Inputs whose height or width is not a multiple of 4 are edge-padded before
    encoding and cropped afterwards, so the output always has the input's
    spatial size (no padding happens for e.g. 256x256).

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced image.
    """

    # Total encoder downsampling (two stride-2 stages); the additive skips
    # only line up when H and W are multiples of this.
    _SIZE_MULTIPLE = 4

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), nn.ReLU())
        self.encoder2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU())
        self.encoder3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.ReLU())

        # Latent block: residual blocks around self-attention
        self.latent = nn.Sequential(
            DenseResidualBlock(256),
            LearnableSelfAttention(256),
            DenseResidualBlock(256),
        )

        self.decoder1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.decoder2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.decoder3 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a (B, in_channels, H, W) batch to a (B, out_channels, H, W) batch."""
        height, width = x.shape[-2:]
        pad_h = (-height) % self._SIZE_MULTIPLE
        pad_w = (-width) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        e1 = self.encoder1(x)   # H x W
        e2 = self.encoder2(e1)  # H/2 x W/2
        e3 = self.encoder3(e2)  # H/4 x W/4

        latent = self.latent(e3)  # H/4 x W/4

        d1 = self.decoder1(latent) + e2  # Skip connection, H/2 x W/2
        d2 = self.decoder2(d1) + e1      # Skip connection, H x W
        out = self.decoder3(d2)          # H x W

        if pad_h or pad_w:
            out = out[..., :height, :width]
        return out


# Lightweight Discriminator
class LightweightDiscriminator(nn.Module):
    """Fully convolutional patch discriminator producing a map of logits.

    Three stride-2 4x4 convolutions (64 -> 128 -> 256 channels, instance norm
    on the last two) are followed by two stride-1 4x4 convolutions (512, then
    a single output channel). LeakyReLU(0.2) sits between layers. The output
    is raw logits, with no final activation.

    Args:
        input_nc: Number of input channels.
    """

    def __init__(self, input_nc):
        super().__init__()
        layers = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two normalized downsampling stages.
        for c_in, c_out in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


# Dataset Class
class HazyDataset(Dataset):
    """Paired (hazy, clear) image dataset read from two directories.

    Files ending in ``.png``/``.jpg``/``.jpeg`` (any letter case) are
    collected from each directory. They are paired by position in sorted path
    order, so both directories must use names that sort identically.

    Args:
        hazy_dir: Directory containing the hazy input images.
        clear_dir: Directory containing the matching clear images.
        transform: Optional callable applied to each image. It is called
            separately for the hazy and the clear image, so random
            augmentations would not be synchronized between the pair.

    Raises:
        ValueError: If the two directories contain different numbers of images.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, hazy_dir, clear_dir, transform=None):
        self.hazy_paths = self._list_images(hazy_dir)
        self.clear_paths = self._list_images(clear_dir)

        if len(self.hazy_paths) != len(self.clear_paths):
            raise ValueError(f"Mismatch in number of hazy ({len(self.hazy_paths)}) and clear ({len(self.clear_paths)}) images. Please check the dataset.")

        self.transform = transform

    @classmethod
    def _list_images(cls, directory):
        """Return sorted full paths of image files directly inside ``directory``."""
        # Lower-case before matching so e.g. "IMG.PNG" is not silently skipped.
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB PIL image, closing the file handle afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.hazy_paths)

    def __getitem__(self, idx):
        hazy = self._load_rgb(self.hazy_paths[idx])
        clear = self._load_rgb(self.clear_paths[idx])
        if self.transform is not None:
            hazy = self.transform(hazy)
            clear = self.transform(clear)
        return hazy, clear
def hinge_loss_discriminator(real_output, fake_output):
    """Hinge loss for the discriminator.

    Penalizes real logits below +1 and fake logits above -1:
    ``mean(max(0, 1 - real_output)) + mean(max(0, 1 + fake_output))``.
    """
    real_term = (1 - real_output).clamp(min=0).mean()
    fake_term = (1 + fake_output).clamp(min=0).mean()
    return real_term + fake_term

def hinge_loss_generator(fake_output):
    """Generator side of the hinge GAN loss: the negated mean discriminator logit."""
    mean_logit = fake_output.mean()
    return torch.neg(mean_logit)

# Loss Functions
def perceptual_loss(output, target, vgg):
    """Return the L1 distance between ``vgg`` features of ``output`` and ``target``.

    Args:
        output: Predicted image batch.
        target: Reference image batch.
        vgg: Feature extractor applied to both batches (``output`` first).

    Returns:
        Scalar tensor: mean absolute difference of the two feature maps.

    NOTE(review): torchvision's ImageNet VGG weights expect mean/std-normalized
    inputs, while the dataset transform in this file only resizes and applies
    ToTensor. Confirm whether feeding un-normalized images is intended.
    """
    output_features = vgg(output)
    target_features = vgg(target)
    return F.l1_loss(output_features, target_features)

class SSIMLoss(nn.Module):
    """Differentiable ``1 - SSIM`` loss for (B, C, H, W) image batches.

    Follows the formula of ``skimage.metrics.structural_similarity`` with
    uniform (box) windows: a ``win_size x win_size`` mean filter,
    sample-covariance normalization, K1=0.01, K2=0.03, and the
    ``(win_size - 1) // 2`` border excluded from the mean. SSIM is averaged
    over channels, positions and batch.

    Unlike a NumPy-based implementation, the result stays on the autograd
    graph, so this term actually produces gradients for the generator.

    Args:
        win_size: Odd side length (>= 3) of the square averaging window.
        data_range: Dynamic range of the inputs (1.0 for images in [0, 1]).

    Raises:
        ValueError: On an invalid ``win_size``, mismatched input shapes,
            non-4D inputs, or images smaller than the window.
    """

    _K1 = 0.01
    _K2 = 0.03

    def __init__(self, win_size=3, data_range=1.0):
        super().__init__()
        if win_size < 3 or win_size % 2 == 0:
            raise ValueError(f"win_size must be an odd integer >= 3, got {win_size}")
        self.win_size = win_size
        self.data_range = data_range

    def forward(self, x, y):
        """Return ``1 - mean SSIM(x, y)`` as a 0-d tensor."""
        if x.shape != y.shape:
            raise ValueError(f"SSIMLoss inputs must share a shape, got {tuple(x.shape)} and {tuple(y.shape)}")
        win = self.win_size
        if x.dim() != 4 or min(x.shape[-2:]) < win:
            raise ValueError(f"SSIMLoss expects (B, C, H, W) inputs with H, W >= {win}, got {tuple(x.shape)}")

        # Compute in float32: under autocast the inputs may be float16, which is
        # too coarse for the small stabilizing constants below.
        x = x.float()
        y = y.float()

        def window_mean(t):
            # Box filter over fully-contained windows only ("valid" region); this
            # equals a full-size filtered map with its (win - 1) // 2 border cropped.
            return F.avg_pool2d(t, kernel_size=win, stride=1)

        n_pixels = win * win
        cov_norm = n_pixels / (n_pixels - 1)  # unbiased (sample) covariance
        ux = window_mean(x)
        uy = window_mean(y)
        vx = cov_norm * (window_mean(x * x) - ux * ux)
        vy = cov_norm * (window_mean(y * y) - uy * uy)
        vxy = cov_norm * (window_mean(x * y) - ux * uy)

        c1 = (self._K1 * self.data_range) ** 2
        c2 = (self._K2 * self.data_range) ** 2
        ssim_map = ((2 * ux * uy + c1) * (2 * vxy + c2)) / ((ux * ux + uy * uy + c1) * (vx + vy + c2))

        # Every image/channel contributes the same number of positions, so the
        # global mean equals the per-channel mean averaged over the batch.
        return 1 - ssim_map.mean()


def entropy_regularization(generator_output):
    """Mean binary entropy of ``sigmoid(generator_output)``, element-wise.

    Returns ``-mean(p*log(p + eps) + (1-p)*log(1-p + eps))`` with
    ``p = sigmoid(generator_output)`` and ``eps = 1e-8``. The value is largest
    (ln 2) when p = 0.5. Minimizing it pushes p toward 0 or 1.
    """
    eps = 1e-8
    p = torch.sigmoid(generator_output)
    q = 1 - p
    per_element = p * torch.log(p + eps) + q * torch.log(q + eps)
    return per_element.mean().neg()

def compute_total_loss(generator_output, target, discriminator, hazy, vgg, fake_output1=None):
    """Return the weighted generator objective for one batch.

    Terms (weight x loss):
      * 120 x L1 content loss between ``generator_output`` and ``target``
      * 100 x L1 distance between ``vgg`` features of the two images
      * 70 x ``SSIMLoss``
      * 1.0 x ``entropy_regularization`` of the output
      * 1 x generator hinge loss on the discriminator logits for the
        conditional fake pair ``(hazy, generator_output)``

    The discriminator's own loss is deliberately NOT part of this objective.
    Minimizing it here would work against the adversarial game; it has to be
    optimized separately on detached fakes with its own optimizer.

    Args:
        generator_output: Image batch produced by the generator.
        target: Ground-truth clear batch, same shape as ``generator_output``.
        discriminator: Conditional discriminator taking
            ``cat([hazy, image], dim=1)``. Only called when ``fake_output1``
            is not supplied.
        hazy: Hazy input batch that conditions the discriminator.
        vgg: Feature extractor for the perceptual term.
        fake_output1: Precomputed
            ``discriminator(cat([hazy, generator_output], dim=1))``; computed
            here when omitted.

    Returns:
        Scalar tensor with the total generator loss.
    """
    # Pixel-wise content loss (L1)
    content_loss = F.l1_loss(generator_output, target)

    # Perceptual loss on VGG features (named so it does not shadow the
    # module-level ``perceptual_loss`` function)
    vgg_loss = F.l1_loss(vgg(generator_output), vgg(target))

    # Adversarial term: the generator wants the conditional discriminator to
    # score the (hazy, dehazed) pair as real.
    if fake_output1 is None:
        fake_output1 = discriminator(torch.cat([hazy, generator_output], dim=1))
    adversarial_loss = hinge_loss_generator(fake_output1)

    # Structural similarity and entropy regularization
    ssim_loss = SSIMLoss()(generator_output, target)
    entropy_loss = entropy_regularization(generator_output)

    total_loss = (
        120 * content_loss
        + 100 * vgg_loss
        + 70 * ssim_loss
        + 1.0 * entropy_loss
        + adversarial_loss
    )
    return total_loss




def _train_one_epoch(generator, discriminator, train_loader, vgg, optimizer_G, optimizer_D, scaler, device, use_amp):
    """Run one epoch of alternating discriminator / generator updates."""
    generator.train()
    discriminator.train()
    for hazy, clear in train_loader:
        hazy, clear = hazy.to(device), clear.to(device)

        # Discriminator step: real pair (hazy, clear) vs. detached fake pair.
        optimizer_D.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            dehazed = generator(hazy)
            real_logits = discriminator(torch.cat([hazy, clear], dim=1))
            fake_logits = discriminator(torch.cat([hazy, dehazed.detach()], dim=1))
            d_loss = hinge_loss_discriminator(real_logits, fake_logits)
        scaler.scale(d_loss).backward()
        scaler.step(optimizer_D)

        # Generator step against the just-updated discriminator. Gradients that
        # reach D's parameters here are cleared by the next optimizer_D.zero_grad().
        optimizer_G.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            fake_output = discriminator(torch.cat([hazy, dehazed], dim=1))
            total_loss = compute_total_loss(dehazed, clear, discriminator, hazy, vgg, fake_output)
        scaler.scale(total_loss).backward()
        scaler.step(optimizer_G)

        # A single scaler serves both optimizers; update it once per iteration.
        scaler.update()


def _validate(generator, val_loader, device):
    """Return the mean (PSNR, SSIM) of ``generator`` over ``val_loader``."""
    generator.eval()
    psnr_values, ssim_values = [], []
    with torch.no_grad():
        for hazy, clear in val_loader:
            # Clamp to [0, 1] to match data_range=1.0 used by both metrics.
            dehazed = generator(hazy.to(device)).clamp(0.0, 1.0)

            # (B, C, H, W) -> (B, H, W, C): skimage expects channels-last images.
            clear_np = clear.float().cpu().numpy().transpose(0, 2, 3, 1)
            dehazed_np = dehazed.float().cpu().numpy().transpose(0, 2, 3, 1)

            # Score each image separately so batch sizes > 1 work.
            for ref, out in zip(clear_np, dehazed_np):
                psnr_values.append(compare_psnr(ref, out, data_range=1.0))

                # skimage requires an odd window no larger than the image.
                win_size = min(7, ref.shape[0], ref.shape[1])
                if win_size % 2 == 0:
                    win_size -= 1
                ssim_values.append(
                    compare_ssim(ref, out, data_range=1.0, channel_axis=-1, win_size=win_size)
                )

    if not psnr_values:
        raise ValueError("Validation loader produced no samples.")
    return float(np.mean(psnr_values)), float(np.mean(ssim_values))


def train(generator, discriminator, train_loader, val_loader, vgg, num_epochs, checkpoint_dir, patience):
    """Train ``generator`` adversarially against ``discriminator`` with early stopping.

    Each iteration first updates the discriminator with the hinge loss on the
    real pair ``(hazy, clear)`` versus the detached fake pair
    ``(hazy, dehazed)``. It then updates the generator with
    ``compute_total_loss``.

    After every epoch the generator is scored on ``val_loader``, using PSNR
    and SSIM on outputs clamped to [0, 1]. The best-PSNR weights are saved to
    ``<checkpoint_dir>/best_generator.pth``. Training stops after ``patience``
    consecutive epochs without a PSNR improvement. Mixed precision is enabled
    only when CUDA is available.

    Returns:
        The best validation PSNR reached.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    generator.to(device)
    discriminator.to(device)
    # VGG is a frozen feature extractor: eval mode and no parameter gradients.
    vgg.to(device).eval().requires_grad_(False)
    scaler = GradScaler(device.type, enabled=use_amp)

    optimizer_G = optim.AdamW(generator.parameters(), lr=5e-5)
    optimizer_D = optim.AdamW(discriminator.parameters(), lr=1e-4)
    best_psnr = float("-inf")
    early_stop_count = 0
    checkpoint_path = os.path.join(checkpoint_dir, "best_generator.pth")

    for epoch in range(num_epochs):
        _train_one_epoch(generator, discriminator, train_loader, vgg,
                         optimizer_G, optimizer_D, scaler, device, use_amp)

        avg_psnr, avg_ssim = _validate(generator, val_loader, device)
        print(f"Epoch {epoch+1}/{num_epochs}, PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}")

        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            early_stop_count = 0
            torch.save(generator.state_dict(), checkpoint_path)
        else:
            early_stop_count += 1
            if early_stop_count >= patience:
                print("Early stopping triggered.")
                break

    return best_psnr


if __name__ == "__main__":
    # Paths
    train_hazy_dir = "/content/drive/MyDrive/Updated_Code/NH_Dataset/Dataset/train/hazy"
    train_clear_dir = "/content/drive/MyDrive/Updated_Code/NH_Dataset/Dataset/train/label"
    val_hazy_dir = "/content/drive/MyDrive/Updated_Code/NH_Dataset/Dataset/test/hazy"
    val_clear_dir = "/content/drive/MyDrive/Updated_Code/NH_Dataset/Dataset/test/label"
    checkpoint_dir = "/content/drive/MyDrive/Updated_Code/NH_Dataset/28thjuly_2025"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Dataset and DataLoader: every image is resized to 256x256 and converted to a tensor.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    train_dataset = HazyDataset(train_hazy_dir, train_clear_dir, transform=transform)
    val_dataset = HazyDataset(val_hazy_dir, val_clear_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Models
    generator = AdvancedAttentionGenerator(3, 3)
    # The discriminator sees (hazy, image) pairs concatenated on channels: 3 + 3 = 6.
    discriminator = LightweightDiscriminator(6)
    # First 16 modules of pretrained VGG-19's feature stack, used for the perceptual loss.
    vgg = vgg19(weights="IMAGENET1K_V1").features[:16].eval()

    # Train
    train(generator, discriminator, train_loader, val_loader, vgg, num_epochs=500, checkpoint_dir=checkpoint_dir, patience=50)
