<a href="https://colab.research.google.com/github/AyaAlHaj17/COEN691-project/blob/main/notebooks/DnCNN/DnCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install -q datasets huggingface_hub pillow torch torchvision tqdm requests
!pip install -q pytorch-lightning wandb scikit-image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from datasets import load_dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import random
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def set_seed(seed=42):
    """Seed every RNG this notebook draws from with the same *seed*.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

set_seed(42)

# Prefer the GPU whenever PyTorch can see one.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

print("Loading HQ-50K dataset from Hugging Face...")
dataset = load_dataset("YangQiee/HQ-50K", split="train")
print(f"Dataset loaded successfully! Total images: {len(dataset)}")

# Print the schema so the URL/image column lookup in the dataset wrapper can be checked.
print(f"Dataset columns: {dataset.column_names}")
first_example = dataset[0]
print(f"First example keys: {first_example.keys()}")


def add_gaussian_noise(image, noise_level=25):
    """Return *image* corrupted by additive zero-mean Gaussian noise, clipped to [0, 1].

    Args:
        image: Tensor whose intensities are expected on a [0, 1] scale.
        noise_level: Noise standard deviation in 8-bit units; it is divided by
            255 to obtain sigma on the [0, 1] scale.
    """
    sigma = noise_level / 255.0
    return (image + torch.randn_like(image) * sigma).clamp(0, 1)

def jpeg_compression(image, quality=10):
    """Simulate JPEG artifacts by encoding and decoding *image* in memory.

    Args:
        image: Image tensor accepted by ``to_pil_image``.
        quality: JPEG quality passed to PIL's encoder (lower means stronger artifacts).

    Returns:
        The decoded JPEG as a tensor produced by ``to_tensor``.
    """
    import io
    from torchvision.transforms.functional import to_pil_image, to_tensor

    with io.BytesIO() as jpeg_bytes:
        to_pil_image(image).save(jpeg_bytes, format='JPEG', quality=quality)
        jpeg_bytes.seek(0)
        # Decode inside the context so the buffer is still open while PIL reads it.
        with Image.open(jpeg_bytes) as decoded:
            return to_tensor(decoded)

def downsample_image(image, scale_factor=4):
    """Bicubically downsample a (C, H, W) image by an integer factor.

    Args:
        image: Tensor of shape (C, H, W) with intensities on a [0, 1] scale.
        scale_factor: Integer reduction factor applied to height and width.

    Returns:
        Tensor of shape (C, max(1, H // scale_factor), max(1, W // scale_factor)).
        Values are clamped to [0, 1] because bicubic interpolation can overshoot
        the input range, and the other degradation helpers return [0, 1] images.
    """
    height, width = image.shape[1:]
    # Never request a zero-sized output: F.interpolate rejects it when the
    # image is smaller than the factor.
    lr_size = (max(1, height // scale_factor), max(1, width // scale_factor))
    lr_image = F.interpolate(image.unsqueeze(0), size=lr_size,
                             mode='bicubic', align_corners=False)
    return lr_image.squeeze(0).clamp(0.0, 1.0)

def add_rain_streaks(image, num_streaks=100):
    """Return a copy of *image* with synthetic near-vertical rain streaks drawn on it.

    Each streak starts at a random column/row and runs 10-20 rows downward,
    jittering one column left/right per row. It is 1-2 pixels thick and uses a
    constant brightness in [0.3, 0.7] across all channels. The input tensor is
    not modified.

    Args:
        image: Tensor of shape (C, H, W).
        num_streaks: Number of streaks to draw.
    """
    img_copy = image.clone()
    _, h, w = img_copy.shape

    for _ in range(num_streaks):
        x = random.randint(0, w - 1)
        # max(0, ...) keeps the range valid for images shorter than 20 rows;
        # rows that would fall below the image are skipped below.
        y = random.randint(0, max(0, h - 20))
        length = random.randint(10, 20)
        thickness = random.randint(1, 2)
        brightness = random.uniform(0.3, 0.7)

        for i in range(length):
            row = y + i
            if row >= h:
                break  # all remaining rows are off-image too
            # Clamp on both sides: a jitter of -1 at x == 0 would otherwise
            # become index -1 and wrap the streak to the last column.
            x_pos = min(max(x + random.randint(-1, 1), 0), w - 1)
            # Slicing truncates at the right edge, matching the per-pixel bound check.
            img_copy[:, row, x_pos:x_pos + thickness] = brightness

    return img_copy



class ImageRestorationDataset(Dataset):
    """Pair clean images from a Hugging Face dataset with synthetic degradations.

    Each item is ``(degraded_image, clean_image)``. The clean image comes from
    the dataset row, either as an image URL that is downloaded or as an
    already-decoded PIL image. It is passed through ``transform``. The degraded
    image is produced on every access by the helper that matches
    ``degradation_type``.

    NOTE(review): degradations are re-sampled on each ``__getitem__`` call
    (e.g. fresh Gaussian noise), so a validation split built on this class is
    not a fixed set across epochs.
    """

    # Degradation types understood by ``_degrade``.
    _DEGRADATION_TYPES = ('denoise', 'dejpeg', 'super_resolution', 'derain')
    # Placeholder shape returned when an image cannot be loaded; matches the
    # default ``Resize((256, 256))`` transform applied to an RGB image.
    _FALLBACK_SHAPE = (3, 256, 256)

    def __init__(self, hf_dataset, degradation_type='denoise',
                 transform=None, subset_size=None):
        """
        Args:
            hf_dataset: Hugging Face dataset whose rows hold an image URL
                (or a decoded PIL image).
            degradation_type: 'denoise', 'dejpeg', 'super_resolution', 'derain'.
            transform: Transform applied to the clean PIL image. Defaults to
                ``Resize((256, 256))`` followed by ``ToTensor()``.
            subset_size: If given and smaller than the dataset, keep a random
                subset of this many rows (sampled with Python's ``random``).

        Raises:
            ValueError: If ``degradation_type`` is not supported.
        """
        # Fail at construction instead of on the first item fetched by a DataLoader.
        if degradation_type not in self._DEGRADATION_TYPES:
            raise ValueError(f"Unknown degradation type: {degradation_type}")

        self.dataset = hf_dataset
        self.degradation_type = degradation_type
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

        if subset_size and subset_size < len(hf_dataset):
            indices = random.sample(range(len(hf_dataset)), subset_size)
            self.dataset = self.dataset.select(indices)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(degraded_image, clean_image)`` for row *idx*.

        Raises:
            KeyError: If the row contains neither an image URL nor a PIL image.
        """
        item = self.dataset[idx]
        source = self._find_image_source(item)
        if source is None:
            raise KeyError(
                f"No image URL or image found in dataset row {idx}; keys: {list(item)}"
            )

        clean_image = self._load_clean_image(source)
        degraded_image = self._degrade(clean_image)
        return degraded_image, clean_image

    @staticmethod
    def _find_image_source(item):
        """Return the URL string (or PIL image) to load for one row, or None if absent."""
        if 'url' in item:
            return item['url']
        if 'image_url' in item:
            return item['image_url']
        if isinstance(item, dict) and len(item) == 1:
            return next(iter(item.values()))

        values = list(item.values())
        for value in values:
            if isinstance(value, str) and value.startswith('http'):
                return value
        # Fall back to an already-decoded image column, if the dataset has one.
        for value in values:
            if isinstance(value, Image.Image):
                return value
        return None

    def _load_clean_image(self, source, max_retries=3):
        """Return the transformed clean image for *source* (PIL image or URL).

        Downloads are retried up to ``max_retries`` times. If every attempt
        fails, a warning naming the last error is printed and a zero tensor of
        ``_FALLBACK_SHAPE`` is returned, so one bad URL does not abort a run.
        """
        from io import BytesIO
        import requests

        if isinstance(source, Image.Image):
            return self.transform(source.convert('RGB'))

        last_error = None
        for _ in range(max_retries):
            try:
                response = requests.get(source, timeout=10)
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as image:
                    return self.transform(image.convert('RGB'))
            except Exception as err:  # best-effort: network, decode and transform errors alike
                last_error = err

        print(f"Failed to load image after {max_retries} attempts: {source} ({last_error!r})")
        return torch.zeros(*self._FALLBACK_SHAPE)

    def _degrade(self, clean_image):
        """Apply the configured degradation to *clean_image*.

        Raises:
            ValueError: If ``self.degradation_type`` is not supported.
        """
        kind = self.degradation_type
        if kind == 'denoise':
            return add_gaussian_noise(clean_image, noise_level=25)
        if kind == 'dejpeg':
            return jpeg_compression(clean_image, quality=10)
        if kind == 'super_resolution':
            # Downsample, then upsample back so input and target share a shape.
            low_res = downsample_image(clean_image, scale_factor=4)
            return F.interpolate(low_res.unsqueeze(0),
                                 size=clean_image.shape[1:],
                                 mode='bicubic', align_corners=False).squeeze(0)
        if kind == 'derain':
            return add_rain_streaks(clean_image, num_streaks=100)
        raise ValueError(f"Unknown degradation type: {kind}")



class DnCNN(nn.Module):
    """Residual denoising CNN (DnCNN).

    The network predicts a residual from the input image and subtracts it. The
    restored image is clamped to [0, 1].

    Architecture:
    - a conv + ReLU head;
    - ``num_layers - 2`` conv-BatchNorm-ReLU units;
    - a conv tail that maps back to ``channels``.

    Args:
        channels: Number of input/output image channels.
        num_layers: Total number of conv layers, including head and tail.
        features: Channel width of the hidden conv layers.
    """

    def __init__(self, channels=3, num_layers=20, features=64):
        super().__init__()

        # Head: no BatchNorm follows, so the conv keeps its bias.
        modules = [
            nn.Conv2d(channels, features, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
        ]

        # Body: bias is dropped because the following BatchNorm adds its own shift.
        for _ in range(num_layers - 2):
            modules += [
                nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
            ]

        # Tail: project back to image channels (this is the predicted residual).
        modules.append(nn.Conv2d(features, channels, kernel_size=3, padding=1, bias=True))

        # The attribute name fixes the state_dict keys ("dncnn.<i>.<param>"),
        # which saved checkpoints depend on.
        self.dncnn = nn.Sequential(*modules)
        self._initialize_weights()

    def forward(self, x):
        residual = self.dncnn(x)
        return (x - residual).clamp(0, 1)

    def _initialize_weights(self):
        """Kaiming-normal conv weights, zero biases, unit/zero BatchNorm affine params."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)



def train_epoch(model, dataloader, criterion, optimizer, device, clip_grad=True, max_norm=1.0):
    """Run one optimization pass over *dataloader* and return the mean training loss.

    Args:
        model: Module to train; it is switched to ``train()`` mode.
        dataloader: Iterable of ``(degraded, clean)`` tensor batches.
        criterion: Loss function called as ``criterion(output, clean)``.
            Assumed to return a per-batch mean (as ``nn.L1Loss`` does by default).
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        clip_grad: If True, clip the global gradient norm before each step.
        max_norm: Gradient-norm threshold used when ``clip_grad`` is True.

    Returns:
        Per-sample mean loss: batch losses weighted by batch size, so a smaller
        final batch does not skew the average. NaN if the loader yields no batches.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0

    for degraded, clean in tqdm(dataloader, desc="Training"):
        degraded, clean = degraded.to(device), clean.to(device)

        optimizer.zero_grad()
        output = model(degraded)
        loss = criterion(output, clean)
        loss.backward()

        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

        optimizer.step()

        batch_size = degraded.size(0)
        loss_sum += loss.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        return float("nan")
    return loss_sum / num_samples

def validate(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* and return ``(mean_loss, mean_psnr)``.

    Args:
        model: Module to evaluate; it is switched to ``eval()`` mode and run
            without gradients.
        dataloader: Iterable of ``(degraded, clean)`` tensor batches.
        criterion: Loss function called as ``criterion(output, clean)``.
            Assumed to return a per-batch mean.
        device: Device the batches are moved to.

    Returns:
        Per-sample mean loss and mean PSNR (skimage, ``data_range=1.0``). Both
        are averaged over the samples actually evaluated, which is not
        necessarily ``len(dataloader.dataset)`` (e.g. with ``drop_last``).
        ``(nan, nan)`` if the loader yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    psnr_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        for degraded, clean in tqdm(dataloader, desc="Validation"):
            degraded, clean = degraded.to(device), clean.to(device)

            output = model(degraded)
            batch_size = output.size(0)
            loss_sum += criterion(output, clean).item() * batch_size

            # One host transfer per batch; NCHW -> NHWC for skimage.
            preds = output.cpu().numpy().transpose(0, 2, 3, 1)
            targets = clean.cpu().numpy().transpose(0, 2, 3, 1)
            for img_pred, img_clean in zip(preds, targets):
                psnr_sum += psnr(img_clean, img_pred, data_range=1.0)
            num_samples += batch_size

    if num_samples == 0:
        return float("nan"), float("nan")
    return loss_sum / num_samples, psnr_sum / num_samples


# Training configuration (kept small so training is feasible on CPU).
DEGRADATION_TYPE = 'denoise'  # one of 'denoise', 'dejpeg', 'super_resolution', 'derain'
BATCH_SIZE = 4
NUM_EPOCHS = 25
LEARNING_RATE = 1e-4
SUBSET_SIZE = 1500  # training images; the validation subset is 20% of this
print(f"\n{'='*60}")
print("DnCNN CPU-OPTIMIZED Training Configuration")
print(f"{'='*60}")
# Keep in sync with the DnCNN(...) constructor call used for training below.
print("Model: DnCNN (15 layers, 48 features)")
print(f"Device: {device} (CPU-optimized)")
print(f"Degradation Type: {DEGRADATION_TYPE}")
print(f"Batch Size: {BATCH_SIZE} (CPU-friendly)")
print(f"Number of Epochs: {NUM_EPOCHS}")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Dataset Size: {SUBSET_SIZE} (reduced for CPU)")
print("Gradient Clipping: Enabled")
print(f"{'='*60}\n")


# 80/20 train/validation split over the full dataset (indices drawn from the torch RNG).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

val_subset_size = int(SUBSET_SIZE * 0.2) if SUBSET_SIZE else None

# Materialise each split as a Hugging Face dataset via .select(), then wrap it
# so (degraded, clean) pairs are produced on access. Train is built first so
# the random subset sampling happens in the same order as before.
train_restoration = ImageRestorationDataset(
    train_dataset.dataset.select(train_dataset.indices),
    degradation_type=DEGRADATION_TYPE,
    subset_size=SUBSET_SIZE,
)
val_restoration = ImageRestorationDataset(
    val_dataset.dataset.select(val_dataset.indices),
    degradation_type=DEGRADATION_TYPE,
    subset_size=val_subset_size,
)

# Single-process loading; only the training loader shuffles.
loader_kwargs = {'batch_size': BATCH_SIZE, 'num_workers': 0}
train_loader = DataLoader(train_restoration, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_restoration, shuffle=False, **loader_kwargs)


model = DnCNN(channels=3, num_layers=15, features=48).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Approximate size assuming 4-byte (float32) parameters.
print(f"Model size: {total_params * 4 / 1024 / 1024:.2f} MB\n")


criterion = nn.L1Loss()


optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE,
                             weight_decay=1e-5, betas=(0.9, 0.999))


# LR schedule: the training loop steps `warmup_scheduler` once per epoch for
# the first `warmup_epochs` epochs and `scheduler` once per epoch afterwards.
# The cosine phase therefore receives only NUM_EPOCHS - warmup_epochs steps.
# T_max must equal that count for the decay to reach eta_min by the final
# epoch; max(1, ...) avoids a zero period for very short runs.
warmup_epochs = 3

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max(1, NUM_EPOCHS - warmup_epochs), eta_min=1e-7
)

# Created after the cosine scheduler so the LR that training starts with is
# start_factor * LEARNING_RATE.
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, total_iters=warmup_epochs
)


# Model outputs are clamped to [0, 1], so PSNR with data_range=1.0 cannot be
# negative and 0 is a safe starting "best".
best_psnr = 0
patience = 5
patience_counter = 0
train_losses, val_losses, val_psnrs = [], [], []

print("Starting CPU-OPTIMIZED DnCNN training...")
print(f"{'='*60}")
print(" CPU Training Tips:")
print(f"  • Smaller batch size ({BATCH_SIZE}) for faster processing")
print(f"  • Reduced dataset ({SUBSET_SIZE} images)")
print("  • Lighter model (15 layers, 48 features)")
print("  • Expected time: ~15-20 minutes")
print(f"{'='*60}\n")

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print(f"{'-'*60}")

    # Exactly one scheduler is stepped per epoch: linear warmup first, then cosine decay.
    if epoch < warmup_epochs:
        current_scheduler = warmup_scheduler
        print("[WARMUP PHASE]")
    else:
        current_scheduler = scheduler

    train_loss = train_epoch(model, train_loader, criterion, optimizer, device, clip_grad=True)
    val_loss, val_psnr = validate(model, val_loader, criterion, device)

    current_scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_psnrs.append(val_psnr)

    current_lr = optimizer.param_groups[0]['lr']
    print(f"LR: {current_lr:.6f} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val PSNR: {val_psnr:.2f} dB")

    # Checkpoint on validation-PSNR improvement; stop after `patience` epochs without one.
    if val_psnr > best_psnr:
        improvement = val_psnr - best_psnr
        best_psnr = val_psnr
        patience_counter = 0
        torch.save(model.state_dict(), 'best_dncnn_model.pth')
        print(f"✓ Saved best model (PSNR: {best_psnr:.2f} dB, +{improvement:.2f} dB)")
    else:
        patience_counter += 1
        print(f"⚠ No improvement for {patience_counter}/{patience} epochs")

        if patience_counter >= patience:
            print(f"\n Early stopping triggered at epoch {epoch + 1}")
            print(f"Best PSNR achieved: {best_psnr:.2f} dB")
            break

    print()

print(f"\n{'='*60}")
print("DnCNN Training Completed!")
print(f"{'='*60}")
print(f"Best PSNR: {best_psnr:.2f} dB")
print(f"Total Epochs: {len(train_losses)}")
# The baseline is the first epoch's validation PSNR; absent if no epoch ran.
if val_psnrs:
    print(f"Improvement over baseline: +{best_psnr - val_psnrs[0]:.2f} dB")
print(f"{'='*60}\n")


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: per-epoch training and validation loss.
ax1.plot(train_losses, label='Train Loss', linewidth=2)
ax1.plot(val_losses, label='Val Loss', linewidth=2)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('DnCNN OPTIMIZED: Training and Validation Loss', fontsize=14, fontweight='bold')

# Right panel: per-epoch validation PSNR with the best value marked.
ax2.plot(val_psnrs, label='Val PSNR', color='green', linewidth=2, marker='o')
ax2.set_ylabel('PSNR (dB)', fontsize=12)
ax2.set_title('DnCNN OPTIMIZED: Validation PSNR', fontsize=14, fontweight='bold')
ax2.axhline(y=best_psnr, color='r', linestyle='--', alpha=0.5, label=f'Best: {best_psnr:.2f} dB')

# Cosmetics shared by both panels (legend last so every series is included).
for ax in (ax1, ax2):
    ax.set_xlabel('Epoch', fontsize=12)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


# Reload the best checkpoint. map_location lets a checkpoint saved on one
# device (e.g. GPU) load onto the current one.
model.load_state_dict(torch.load('best_dncnn_model.pth', map_location=device))
model.eval()


num_samples = min(4, len(val_restoration))
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works when num_samples == 1.
fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

print("Testing OPTIMIZED DnCNN on sample images...")
with torch.no_grad():
    for i in range(num_samples):
        # Each access reloads the image and draws a fresh degradation.
        degraded, clean = val_restoration[i]
        degraded_input = degraded.unsqueeze(0).to(device)
        restored = model(degraded_input).squeeze(0).cpu()

        # Convert to numpy (HWC) for display and skimage metrics
        degraded_np = degraded.numpy().transpose(1, 2, 0)
        clean_np = clean.numpy().transpose(1, 2, 0)
        restored_np = restored.numpy().transpose(1, 2, 0)

        # Calculate metrics
        psnr_degraded = psnr(clean_np, degraded_np, data_range=1.0)
        psnr_restored = psnr(clean_np, restored_np, data_range=1.0)
        ssim_degraded = ssim(clean_np, degraded_np, data_range=1.0, channel_axis=2)
        ssim_restored = ssim(clean_np, restored_np, data_range=1.0, channel_axis=2)

        # Plot
        axes[i, 0].imshow(np.clip(degraded_np, 0, 1))
        axes[i, 0].set_title(f'Degraded\nPSNR: {psnr_degraded:.2f} dB | SSIM: {ssim_degraded:.3f}',
                             fontsize=11)
        axes[i, 0].axis('off')

        # `:+.2f` shows the sign of the PSNR change (e.g. "-0.35" rather than "+-0.35").
        axes[i, 1].imshow(np.clip(restored_np, 0, 1))
        axes[i, 1].set_title(f'DnCNN Restored\nPSNR: {psnr_restored:.2f} dB ({psnr_restored - psnr_degraded:+.2f}) | SSIM: {ssim_restored:.3f}',
                             fontsize=11, fontweight='bold', color='green')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(np.clip(clean_np, 0, 1))
        axes[i, 2].set_title('Ground Truth', fontsize=11)
        axes[i, 2].axis('off')

plt.tight_layout()
plt.show()



print("\n" + "="*60)
print("DnCNN CPU-OPTIMIZED - PERFORMANCE SUMMARY")
print("="*60)
print("✓ Model: DnCNN (15 layers, CPU-optimized)")
print(f"✓ Parameters: {trainable_params:,} (~{trainable_params/1e6:.1f}M)")
print(f"✓ Device: {device}")
print(f"✓ Task: {DEGRADATION_TYPE}")
print(f"✓ Best PSNR: {best_psnr:.2f} dB")
print(f"✓ Training Epochs: {len(train_losses)}")
print(f"✓ Training Data: {SUBSET_SIZE} images")
print("✓ Model saved as: best_dncnn_model.pth")
print("="*60)
# Reduced values are read from the config constants so this summary cannot
# drift from what was actually run.
print("\n CPU OPTIMIZATIONS APPLIED:")
print("  • Reduced model size: 20→15 layers, 64→48 features")
print(f"  • Smaller batch size: 8→{BATCH_SIZE}")
print(f"  • Less data: 5000→{SUBSET_SIZE} images")
print(f"  • Fewer epochs: 50→{NUM_EPOCHS}")
print("  • No multiprocessing (num_workers=0)")
print("  • Memory efficient settings")
print("="*60)