<a href="https://colab.research.google.com/github/AyaAlHaj17/COEN691-project/blob/main/notebooks/U-NET/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install -q datasets huggingface_hub pillow torch torchvision tqdm requests
!pip install -q pytorch-lightning wandb scikit-image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from datasets import load_dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import random
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


def set_seed(seed=42):
    """Seed every random generator this notebook draws from.

    Seeds Python's ``random`` (used for subset sampling and rain streaks),
    NumPy's global RNG, and PyTorch's CPU and CUDA generators, in that order.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(42)


# Run on the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

# HQ-50K is pulled from the Hugging Face Hub (requires network access).
print("Loading HQ-50K dataset from Hugging Face...")
dataset = load_dataset("YangQiee/HQ-50K", split="train")
print(f"Dataset loaded successfully! Total images: {len(dataset)}")

# Peek at the schema to see which column holds each row's image source.
print(f"Dataset columns: {dataset.column_names}")
first_row = dataset[0]
print(f"First example keys: {first_row.keys()}")
print(f"Sample entry: {first_row}")

sample_idx = 0
# NOTE(review): despite the label, this prints the whole row, not only a URL.
print(f"\nSample image URL: {dataset[sample_idx]}")


def add_gaussian_noise(image, noise_level=25):
    """Corrupt ``image`` with additive white Gaussian noise.

    Args:
        image: Float tensor, expected to hold values in [0, 1].
        noise_level: Noise standard deviation on the 0-255 scale; it is
            divided by 255 before being applied.

    Returns:
        A new tensor with the same shape as ``image``, clipped to [0, 1].
    """
    sigma = noise_level / 255.0
    perturbation = torch.randn_like(image) * sigma
    return (image + perturbation).clamp(0, 1)

def jpeg_compression(image, quality=10):
    """Round-trip ``image`` through an in-memory JPEG to add compression artifacts.

    Args:
        image: Float tensor in the layout ``to_pil_image`` accepts
            (presumably (C, H, W)), expected in [0, 1].
        quality: JPEG quality handed to Pillow; lower means stronger artifacts.

    Returns:
        Float tensor in [0, 1] decoded from the compressed JPEG bytes.
    """
    import io
    from torchvision.transforms.functional import to_pil_image, to_tensor

    # to_pil_image scales float input by 255 and casts to uint8; values outside
    # [0, 1] would overflow that cast instead of saturating, so clamp first.
    pil_img = to_pil_image(image.clamp(0, 1))
    with io.BytesIO() as buffer:
        pil_img.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)
        # to_tensor copies the decoded pixels, so closing the image afterwards
        # is safe and releases the decoder's resources promptly.
        with Image.open(buffer) as compressed:
            return to_tensor(compressed)

def downsample_image(image, scale_factor=4):
    """Bicubically shrink ``image`` by ``scale_factor`` (super-resolution input).

    Args:
        image: Tensor shaped (C, H, W), expected in [0, 1].
        scale_factor: Integer factor by which height and width are divided.

    Returns:
        Tensor shaped (C, max(1, H // scale_factor), max(1, W // scale_factor)),
        clamped to [0, 1].
    """
    h, w = image.shape[-2:]
    # Never request a zero-sized output: F.interpolate rejects it when a side
    # is smaller than scale_factor.
    lr_h = max(1, h // scale_factor)
    lr_w = max(1, w // scale_factor)
    lr_image = F.interpolate(image.unsqueeze(0), size=(lr_h, lr_w),
                             mode='bicubic', align_corners=False)
    # The bicubic kernel has negative lobes and overshoots at sharp edges;
    # clamp so the result stays in the same range as the input pixels.
    return lr_image.squeeze(0).clamp(0, 1)

def add_rain_streaks(image, num_streaks=100):
    """Overlay synthetic, near-vertical rain streaks on a copy of ``image``.

    Each streak starts at a random pixel and runs 10-20 rows downward. Every
    row jitters horizontally by at most one pixel. A streak is 1-2 pixels wide
    and paints one gray level drawn from [0.3, 0.7] into every channel.
    Randomness comes from Python's ``random`` module.

    Args:
        image: Tensor shaped (C, H, W); it is not modified.
        num_streaks: Number of streaks to draw.

    Returns:
        A new tensor with the streaks painted in.
    """
    img_copy = image.clone()
    _, h, w = img_copy.shape

    for _ in range(num_streaks):
        x = random.randint(0, w - 1)
        # For images shorter than 20 rows, h - 20 is negative and randint's
        # range would be empty (ValueError); start at row 0 instead and let
        # the bottom-edge check below truncate the streak.
        y = random.randint(0, max(0, h - 20))
        length = random.randint(10, 20)
        thickness = random.randint(1, 2)
        brightness = random.uniform(0.3, 0.7)

        for i in range(length):
            row = y + i
            if row >= h:
                break  # every later row is also off the bottom edge
            # Clamp on both sides: x + (-1) at x == 0 used to index column -1,
            # wrapping the streak onto the opposite edge of the image.
            x_pos = min(max(x + random.randint(-1, 1), 0), w - 1)
            # Slicing clips at the right edge, so a 2-wide streak at the last
            # column paints only that column.
            img_copy[:, row, x_pos:x_pos + thickness] = brightness

    return img_copy


class ImageRestorationDataset(Dataset):
    """Pair clean images with synthetically degraded copies for restoration.

    Each row's image is downloaded over HTTP, or used directly when the row's
    only column already holds a PIL image. It is run through ``transform`` to
    produce the clean target and then corrupted according to
    ``degradation_type``. ``__getitem__`` returns ``(degraded_image, clean_image)``.
    """

    # Degradations understood by ``_degrade``.
    DEGRADATION_TYPES = ('denoise', 'dejpeg', 'super_resolution', 'derain')

    def __init__(self, hf_dataset, degradation_type='denoise',
                 transform=None, subset_size=None):
        """
        Args:
            hf_dataset: Hugging Face dataset whose rows carry an image URL.
            degradation_type: 'denoise', 'dejpeg', 'super_resolution', 'derain'
            transform: Maps a PIL RGB image to a tensor; defaults to
                Resize((256, 256)) followed by ToTensor().
            subset_size: Keep a random subset of this many rows (for faster
                training). Ignored when falsy or >= len(hf_dataset). The
                subset is drawn with Python's ``random``.

        Raises:
            ValueError: If ``degradation_type`` is not one of the above.
        """
        # Fail fast here rather than on the first __getitem__ call, which may
        # run inside a DataLoader worker where the traceback is harder to read.
        if degradation_type not in self.DEGRADATION_TYPES:
            raise ValueError(f"Unknown degradation type: {degradation_type}")

        self.dataset = hf_dataset
        self.degradation_type = degradation_type
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

        if subset_size and subset_size < len(hf_dataset):
            indices = random.sample(range(len(hf_dataset)), subset_size)
            self.dataset = self.dataset.select(indices)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(degraded_image, clean_image)`` for row ``idx``."""
        source = self._resolve_source(self.dataset[idx])
        clean_image = self._load_clean(source)
        return self._degrade(clean_image), clean_image

    @staticmethod
    def _resolve_source(item):
        """Pick the image URL (or the lone column's value) out of a dataset row.

        Lookup order: an explicit 'url' key, then 'image_url', then the only
        value of a single-column row, then the first string starting with
        'http'.

        Raises:
            ValueError: If none of the above applies. Previously this case
                surfaced as an UnboundLocalError.
        """
        for key in ('url', 'image_url'):
            if key in item:
                return item[key]
        values = list(item.values())
        if len(values) == 1:
            return values[0]
        for value in values:
            if isinstance(value, str) and value.startswith('http'):
                return value
        raise ValueError(f"No image URL found in dataset row with keys {list(item)}")

    def _load_clean(self, source, max_retries=3):
        """Turn ``source`` (a URL or a PIL image) into the clean target tensor.

        Downloads are retried up to ``max_retries`` times. After that, an
        all-zero placeholder is returned so one dead link does not abort an
        epoch.
        """
        from io import BytesIO
        import requests

        if isinstance(source, Image.Image):
            rgb = source if source.mode == 'RGB' else source.convert('RGB')
            return self.transform(rgb)

        last_error = None
        for _ in range(max_retries):
            try:
                response = requests.get(source, timeout=10)
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as image:
                    rgb = image if image.mode == 'RGB' else image.convert('RGB')
                    return self.transform(rgb)
            except Exception as exc:  # network, HTTP and decode errors alike
                last_error = exc

        print(f"Failed to load image after {max_retries} attempts: {source} ({last_error})")
        return self._placeholder()

    def _placeholder(self):
        """Return an all-zero clean image shaped like this dataset's transform output."""
        # Pushing a blank image through the transform keeps the placeholder's
        # shape consistent with custom transforms, so default collation still
        # stacks it. With the default transform this is zeros(3, 256, 256).
        blank = Image.new('RGB', (256, 256))
        return torch.zeros_like(self.transform(blank))

    def _degrade(self, clean_image):
        """Apply the configured synthetic degradation to ``clean_image``."""
        kind = self.degradation_type
        if kind == 'denoise':
            return add_gaussian_noise(clean_image, noise_level=25)
        if kind == 'dejpeg':
            return jpeg_compression(clean_image, quality=10)
        if kind == 'super_resolution':
            low_res = downsample_image(clean_image, scale_factor=4)
            # Upsample back so input and target share a spatial size. Bicubic
            # can overshoot, so clamp to the valid pixel range.
            upsampled = F.interpolate(low_res.unsqueeze(0),
                                      size=clean_image.shape[1:],
                                      mode='bicubic',
                                      align_corners=False).squeeze(0)
            return upsampled.clamp(0, 1)
        if kind == 'derain':
            return add_rain_streaks(clean_image, num_streaks=100)
        raise ValueError(f"Unknown degradation type: {kind}")



class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged. Channels go
    ``in_channels -> out_channels -> out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage consumes in_channels, second stage refines out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.conv(x)

class UNetRestoration(nn.Module):
    """Four-level U-Net mapping a degraded image to its restored version.

    The encoder halves the spatial size four times (16x overall) before the
    bottleneck. The decoder mirrors it with transposed convolutions and
    skip connections. Inputs whose height or width is not a multiple of 16
    are edge-padded internally and the output is cropped back, so arbitrary
    spatial sizes are accepted. The output is clamped to [0, 1].

    Args:
        in_channels: Channels of the degraded input.
        out_channels: Channels of the restored output.
        base_channels: Width of the first encoder level. Deeper levels use
            2x, 4x, 8x and 16x this. The default (64) reproduces the original
            layer shapes and attribute names, so existing checkpoints load.
    """

    # Total downsampling factor of the four 2x2 max-pooling stages.
    _DOWNSAMPLE = 16

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * m for m in (1, 2, 4, 8, 16))

        self.enc1 = DoubleConv(in_channels, c1)
        self.enc2 = DoubleConv(c1, c2)
        self.enc3 = DoubleConv(c2, c3)
        self.enc4 = DoubleConv(c3, c4)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(c4, c5)

        # Each decoder stage sees the upsampled features concatenated with the
        # matching encoder skip, hence twice the channels on its input.
        self.upconv4 = nn.ConvTranspose2d(c5, c4, 2, stride=2)
        self.dec4 = DoubleConv(c4 * 2, c4)

        self.upconv3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = DoubleConv(c3 * 2, c3)

        self.upconv2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = DoubleConv(c2 * 2, c2)

        self.upconv1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = DoubleConv(c1 * 2, c1)

        self.out = nn.Conv2d(c1, out_channels, 1)

    def forward(self, x):
        """Restore a batch ``x`` shaped (N, in_channels, H, W).

        Returns:
            Tensor shaped (N, out_channels, H, W) with values in [0, 1].
        """
        h, w = x.shape[-2:]
        pad_h = (-h) % self._DOWNSAMPLE
        pad_w = (-w) % self._DOWNSAMPLE
        if pad_h or pad_w:
            # Without padding, odd sizes at any pooling level make the upsampled
            # decoder features and the encoder skips disagree, and torch.cat
            # fails. Replicate padding (unlike reflect) has no
            # "pad < input size" restriction.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.dec4(torch.cat([self.upconv4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.upconv3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.upconv2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.upconv1(d2), e1], dim=1))

        # Crop away any padding added above.
        out = self.out(d1)[..., :h, :w]
        return torch.clamp(out, 0, 1)


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: Network to train; it is switched to train mode.
        dataloader: Yields ``(degraded, clean)`` batches.
        criterion: Loss evaluated as ``criterion(output, clean)``.
        optimizer: Steps the model parameters after every batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []

    for degraded, clean in tqdm(dataloader, desc="Training"):
        inputs = degraded.to(device)
        targets = clean.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Yields ``(degraded, clean)`` batches.
        criterion: Loss evaluated as ``criterion(output, clean)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, avg_psnr)``: the mean of the per-batch losses, and the
        mean per-image PSNR in dB (``data_range=1.0``) over the images
        actually evaluated.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    num_batches = 0
    num_images = 0

    with torch.no_grad():
        for degraded, clean in tqdm(dataloader, desc="Validation"):
            degraded, clean = degraded.to(device), clean.to(device)

            output = model(degraded)
            total_loss += criterion(output, clean).item()
            num_batches += 1

            # Copy the whole batch to the host once instead of once per image.
            preds = output.cpu().numpy()
            targets = clean.cpu().numpy()
            for img_clean, img_pred in zip(targets, preds):
                # PSNR depends only on the element-wise MSE, so the channel
                # order of the arrays does not matter.
                total_psnr += psnr(img_clean, img_pred, data_range=1.0)
            num_images += len(preds)

    if num_batches == 0:
        raise ValueError("validation dataloader yielded no batches")

    # Divide by the images actually evaluated, not len(dataloader.dataset):
    # the two differ under drop_last=True or a custom sampler.
    return total_loss / num_batches, total_psnr / num_images


# ---- Run configuration ----
DEGRADATION_TYPE = 'denoise'  # one of: 'denoise', 'dejpeg', 'super_resolution', 'derain'
BATCH_SIZE = 8
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
# Rows sampled for training (validation samples 20% of this);
# a falsy value means the full split is used.
SUBSET_SIZE = 5000

print(f"\nTraining Configuration:")
for config_line in (
    f"Degradation Type: {DEGRADATION_TYPE}",
    f"Batch Size: {BATCH_SIZE}",
    f"Number of Epochs: {NUM_EPOCHS}",
    f"Learning Rate: {LEARNING_RATE}",
    f"Dataset Size: {SUBSET_SIZE if SUBSET_SIZE else 'Full'}",
):
    print(config_line)

# 80/20 split of the hub dataset; torch's global RNG picks the rows.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

# Materialize each split as an HF dataset, then wrap it so images are
# degraded on the fly. Validation keeps 20% of the training subset size.
train_rows = train_dataset.dataset.select(train_dataset.indices)
train_restoration = ImageRestorationDataset(
    train_rows,
    degradation_type=DEGRADATION_TYPE,
    subset_size=SUBSET_SIZE,
)

val_rows = val_dataset.dataset.select(val_dataset.indices)
val_subset_size = int(SUBSET_SIZE * 0.2) if SUBSET_SIZE else None
val_restoration = ImageRestorationDataset(
    val_rows,
    degradation_type=DEGRADATION_TYPE,
    subset_size=val_subset_size,
)


# Both loaders share batch size and worker count; only training shuffles.
loader_kwargs = {'batch_size': BATCH_SIZE, 'num_workers': 2}
train_loader = DataLoader(train_restoration, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_restoration, shuffle=False, **loader_kwargs)


# Model, L1 reconstruction loss and Adam optimizer.
model = UNetRestoration().to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Cosine-anneal the learning rate toward 1e-6 over NUM_EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-6,
)

# Checkpointing / early-stopping bookkeeping.
best_psnr = 0          # highest validation PSNR seen so far (dB)
patience = 5           # epochs without improvement before stopping
patience_counter = 0   # consecutive epochs without improvement
train_losses = []
val_losses = []
val_psnrs = []

print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

    # One training pass and one validation pass, then advance the cosine
    # schedule (stepped once per epoch).
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_psnr = validate(model, val_loader, criterion, device)
    scheduler.step()

    for history, value in ((train_losses, train_loss),
                           (val_losses, val_loss),
                           (val_psnrs, val_psnr)):
        history.append(value)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val PSNR: {val_psnr:.2f} dB")

    # Written as a negated ">" (not "<=") so a NaN PSNR counts as
    # "no improvement".
    improved = val_psnr > best_psnr
    if not improved:
        patience_counter += 1
        print(f"⚠ No improvement for {patience_counter}/{patience} epochs")
        if patience_counter >= patience:
            print(f"\n Early stopping triggered at epoch {epoch + 1}")
            print(f"Best PSNR achieved: {best_psnr:.2f} dB")
            break
        continue

    best_psnr = val_psnr
    patience_counter = 0
    torch.save(model.state_dict(), 'best_restoration_model.pth')
    print(f"✓ Saved best model (PSNR: {best_psnr:.2f} dB)")

print(f"\nTraining completed! Best PSNR: {best_psnr:.2f} dB")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: loss curves; right panel: validation PSNR.
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Val Loss')
ax2.plot(val_psnrs, label='Val PSNR', color='green')

for axis, y_label, panel_title in (
    (ax1, 'Loss', 'Training and Validation Loss'),
    (ax2, 'PSNR (dB)', 'Validation PSNR'),
):
    axis.set_xlabel('Epoch')
    axis.set_ylabel(y_label)
    axis.set_title(panel_title)
    axis.legend()
    axis.grid(True)

plt.tight_layout()
plt.show()


# Reload the best checkpoint written during training. The file only exists if
# validation PSNR improved at least once, so guard against its absence.
checkpoint_path = 'best_restoration_model.pth'
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on another device load here.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print(f"⚠ {checkpoint_path} not found; visualizing the last-trained weights instead.")
model.eval()


def _plot_restoration_examples(net, data, n_rows, run_device):
    """Plot degraded / restored / ground-truth triplets with per-image PSNR.

    Args:
        net: Restoration model, already in eval mode.
        data: Indexable dataset returning ``(degraded, clean)`` CHW tensors.
        n_rows: Number of examples (figure rows) to draw; must be >= 1.
        run_device: Device the model runs on.
    """
    # squeeze=False keeps `axes` 2-D even when n_rows == 1.
    _, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    with torch.no_grad():
        for row in range(n_rows):
            degraded, clean = data[row]
            restored = net(degraded.unsqueeze(0).to(run_device)).squeeze(0).cpu()

            degraded_np = degraded.numpy().transpose(1, 2, 0)
            clean_np = clean.numpy().transpose(1, 2, 0)
            restored_np = restored.numpy().transpose(1, 2, 0)

            psnr_degraded = psnr(clean_np, degraded_np, data_range=1.0)
            psnr_restored = psnr(clean_np, restored_np, data_range=1.0)

            panels = (
                (degraded_np, f'Degraded (PSNR: {psnr_degraded:.2f} dB)'),
                (restored_np, f'Restored (PSNR: {psnr_restored:.2f} dB)'),
                (clean_np, 'Ground Truth'),
            )
            for ax, (img, title) in zip(axes[row], panels):
                # Degraded inputs may fall slightly outside [0, 1].
                ax.imshow(np.clip(img, 0, 1))
                ax.set_title(title)
                ax.axis('off')

    plt.tight_layout()
    plt.show()


# Show up to four validation examples (fewer if the split is smaller).
num_samples = min(4, len(val_restoration))
if num_samples > 0:
    _plot_restoration_examples(model, val_restoration, num_samples, device)
else:
    print("No validation samples available to visualize.")

# Final run summary, framed by 60-character rules.
for summary_line in (
    "\n" + "="*60,
    "Image Restoration Training Complete!",
    "="*60,
    f"✓ Model trained on {DEGRADATION_TYPE} task",
    f"✓ Best validation PSNR: {best_psnr:.2f} dB",
    f"✓ Model saved as: best_restoration_model.pth",
    "="*60,
):
    print(summary_line)