<a href="https://colab.research.google.com/github/AyaAlHaj17/COEN691-project/blob/main/notebooks/ResNet/ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install -q datasets huggingface_hub pillow torch torchvision tqdm requests
!pip install -q scikit-image


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from datasets import load_dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import random
import requests
from io import BytesIO
import time
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    The same value is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and all CUDA
    generators.

    Args:
        seed: Integer seed value. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed the RNGs before anything below draws random numbers.
set_seed(seed=42)


# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")



class ChannelAttention(nn.Module):
    """Per-channel gating computed from globally pooled features.

    Each channel is averaged over its spatial extent, passed through a
    two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel weight (in ``(0, 1)``) rescales the input.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck ratio; the hidden layer has
            ``channels // reduction`` units, but never fewer than one.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Guard against a zero-width bottleneck when channels < reduction:
        # with zero hidden units the gate would degenerate to a constant 0.5.
        hidden = max(channels // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise; ``x`` must be 4-D (b, c, h, w)."""
        b, c, _, _ = x.size()
        # (b, c, 1, 1) -> (b, c) so the linear layers see one value per channel.
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class ImprovedResidualBlock(nn.Module):
    """Residual block with channel attention on the residual branch.

    The branch is conv -> BN -> ReLU -> conv -> BN -> channel attention;
    its result is added to the block input and passed through a final ReLU.
    The channel count is preserved.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_options = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.ca = ChannelAttention(channels)

    def forward(self, x):
        """Apply the residual branch to ``x`` and add the skip connection."""
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.relu(branch)

        branch = self.conv2(branch)
        branch = self.bn2(branch)
        branch = self.ca(branch)

        # Skip connection, then the closing non-linearity.
        branch += x
        return self.relu(branch)


class MultiScaleBlock(nn.Module):
    """Four parallel branches with different receptive fields, concatenated.

    The branches are a 1x1 conv, a 3x3 conv, a 5x5 conv and a 3x3 max-pool
    followed by a 1x1 conv. Their outputs are concatenated along the channel
    axis and batch-normalised, so the block maps ``channels`` to ``channels``
    and keeps the spatial size.

    Args:
        channels: Number of input and output channels (at least 4).

    Raises:
        ValueError: If ``channels`` is smaller than 4, since each of the four
            branches needs at least one output channel.
    """

    def __init__(self, channels):
        super().__init__()
        if channels < 4:
            raise ValueError(
                f"MultiScaleBlock needs at least 4 channels, got {channels}"
            )
        base = channels // 4
        # The last branch absorbs the remainder so the concatenation always
        # has exactly `channels` channels, even when channels % 4 != 0
        # (otherwise the BatchNorm below would reject the tensor).
        last = channels - 3 * base
        self.branch1 = nn.Conv2d(channels, base, 1)
        self.branch2 = nn.Conv2d(channels, base, 3, padding=1)
        self.branch3 = nn.Conv2d(channels, base, 5, padding=2)
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(channels, last, 1)
        )
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Run all branches on ``x`` and return the normalised concatenation."""
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        out = torch.cat([b1, b2, b3, b4], dim=1)
        return self.bn(out)



class EnhancedResNet(nn.Module):
    """Residual image-restoration network that predicts a correction to its input.

    Pipeline: input conv -> multi-scale block -> ``num_blocks`` residual
    blocks (with three intermediate feature taps fused by a 1x1 conv) ->
    output convs. The network output is added to the input image and
    clamped to ``[0, 1]``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the restored image; must equal
            ``in_channels`` because the prediction is added to the input.
        num_blocks: Number of residual blocks in the trunk.
        features: Width (channel count) of the trunk.
    """

    def __init__(self, in_channels=3, out_channels=3, num_blocks=12, features=64):
        super().__init__()

        # Shallow feature extraction.
        self.conv_input = nn.Sequential(
            nn.Conv2d(in_channels, features, 3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.multi_scale = MultiScaleBlock(features)

        self.res_blocks = nn.ModuleList([
            ImprovedResidualBlock(features) for _ in range(num_blocks)
        ])

        # 1x1 conv that merges the three tapped trunk features back to `features`.
        self.fusion_conv = nn.Conv2d(features * 3, features, 1)

        self.conv_output = nn.Sequential(
            nn.Conv2d(features, features, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, out_channels, 3, padding=1)
        )

        self._initialize_weights()

    def forward(self, x):
        """Restore ``x`` and return an image of the same shape, clamped to [0, 1]."""
        feat = self.conv_input(x)
        feat = self.multi_scale(feat)

        # Block indices whose outputs are tapped for fusion. Computed once per
        # call; for very shallow trunks the three indices can coincide, in
        # which case fewer than three taps are collected and fusion is skipped.
        depth = len(self.res_blocks)
        tap_points = {depth // 4, depth // 2, 3 * depth // 4}

        block_feats = []
        for i, block in enumerate(self.res_blocks):
            feat = block(feat)
            if i in tap_points:
                block_feats.append(feat)

        if len(block_feats) == 3:
            # Add the fused taps to the final trunk feature instead of
            # replacing it. Replacing it would discard the output of every
            # block after the last tap, leaving those blocks without
            # gradient and wasting their compute.
            feat = feat + self.fusion_conv(torch.cat(block_feats, dim=1))

        out = self.conv_output(feat)
        # Global residual: the trunk predicts a correction to the input.
        out = x + out

        return torch.clamp(out, 0, 1)

    def _initialize_weights(self):
        """Kaiming-initialise conv weights and reset BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)



class RobustImageRestorationDataset(Dataset):
    """Dataset of (degraded, clean) image pairs downloaded on demand.

    Each row of ``hf_dataset`` is expected to carry an image URL. On access
    the image is downloaded, converted to RGB, passed through ``transform``
    and degraded according to ``degradation_type``.

    If a sample cannot be produced, a synthetic pair built from uniform
    random noise is returned instead and ``failed_count`` is incremented.
    NOTE(review): such synthetic pairs are indistinguishable from real ones
    for the caller, so a high ``failed_count`` distorts training and metrics.

    Args:
        hf_dataset: Indexable collection of dict-like rows.
        degradation_type: ``'dejpeg'`` applies JPEG compression (quality 10);
            ``'denoise'`` and any other value add Gaussian noise (level 25).
        transform: Callable mapping a PIL image to a tensor. Defaults to a
            256x256 resize followed by ``ToTensor``.
        subset_size: If set and smaller than the dataset, a random subset of
            this many rows is used.
        max_retries: Maximum number of attempts per sample.
    """

    def __init__(self, hf_dataset, degradation_type='denoise',
                 transform=None, subset_size=None, max_retries=3):
        self.dataset = hf_dataset
        self.degradation_type = degradation_type
        self.max_retries = max_retries
        # Number of __getitem__ calls that fell back to a synthetic sample.
        self.failed_count = 0

        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

        if subset_size and subset_size < len(hf_dataset):
            self.indices = random.sample(range(len(hf_dataset)), subset_size)
        else:
            self.indices = list(range(len(hf_dataset)))

        print(f"Dataset size: {len(self.indices)} images")

    def _extract_url(self, item):
        """Return the image URL stored in ``item``, or ``None`` if there is none.

        The keys ``'url'`` and ``'image_url'`` are preferred, in that order;
        otherwise the first string value starting with ``'http'`` is used.
        """
        for key in ('url', 'image_url'):
            if key in item:
                return item[key]
        for value in item.values():
            if isinstance(value, str) and value.startswith('http'):
                return value
        return None

    def _add_gaussian_noise(self, image, noise_level=25):
        """Add zero-mean Gaussian noise (sigma = noise_level / 255) and clamp to [0, 1]."""
        noise = torch.randn_like(image) * (noise_level / 255.0)
        return torch.clamp(image + noise, 0, 1)

    def _jpeg_compression(self, image, quality=10):
        """Round-trip ``image`` through an in-memory JPEG at the given quality."""
        from torchvision.transforms.functional import to_pil_image, to_tensor
        pil_img = to_pil_image(image)
        buffer = BytesIO()
        pil_img.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)
        compressed = Image.open(buffer)
        return to_tensor(compressed)

    def _load_clean_image(self, url):
        """Download ``url`` and return the transformed RGB image."""
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        image = Image.open(BytesIO(response.content))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return self.transform(image)

    def _degrade(self, clean_image):
        """Apply the configured degradation; unknown types fall back to noise."""
        if self.degradation_type == 'dejpeg':
            return self._jpeg_compression(clean_image, 10)
        return self._add_gaussian_noise(clean_image, 25)

    def _fallback_sample(self):
        """Return a synthetic (degraded, clean) pair of shape (3, 256, 256)."""
        clean = torch.rand(3, 256, 256)
        degraded = self._add_gaussian_noise(clean, 25)
        return degraded, clean

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(degraded, clean)`` for position ``idx``.

        Download and decoding errors are retried up to ``max_retries`` times
        with a 0.5 s pause between attempts. A row without any URL is a
        permanent failure and is not retried. When no real sample can be
        produced, ``failed_count`` is incremented and a synthetic pair is
        returned so that batching never breaks.
        """
        actual_idx = self.indices[idx]

        for attempt in range(self.max_retries):
            try:
                item = self.dataset[actual_idx]
                url = self._extract_url(item)
                if not url:
                    # Nothing to download: retrying (and sleeping) cannot help.
                    break

                clean_image = self._load_clean_image(url)
                degraded_image = self._degrade(clean_image)
                return degraded_image, clean_image
            except Exception:
                # Deliberate best-effort: any failure counts as a failed
                # attempt. Pause only if another attempt will follow.
                if attempt < self.max_retries - 1:
                    time.sleep(0.5)

        self.failed_count += 1
        return self._fallback_sample()


class CombinedLoss(nn.Module):
    """Weighted sum of L1 (mean absolute) and L2 (mean squared) error.

    Args:
        l1_weight: Multiplier for the L1 term. Defaults to 1.0.
        l2_weight: Multiplier for the L2 term. Defaults to 0.1.

    With the defaults the loss is ``L1 + 0.1 * L2``.
    """

    def __init__(self, l1_weight=1.0, l2_weight=0.1):
        super().__init__()
        self.l1_weight = l1_weight
        self.l2_weight = l2_weight
        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()

    def forward(self, pred, target):
        """Return the scalar combined loss between ``pred`` and ``target``."""
        l1 = self.l1_loss(pred, target)
        l2 = self.l2_loss(pred, target)
        return self.l1_weight * l1 + self.l2_weight * l2



def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Batches containing NaNs, batches whose loss is NaN or infinite, and
    batches that raise an exception are skipped. Gradients are clipped to a
    global norm of 1.0 before each optimiser step.

    Returns:
        The mean loss over the batches that were actually optimised
        (0 if no batch succeeded).
    """
    model.train()
    running_loss = 0
    batches_done = 0

    for noisy, target in tqdm(dataloader, desc="Training"):
        try:
            noisy = noisy.to(device)
            target = target.to(device)

            # Skip batches whose inputs or targets contain NaNs.
            inputs_invalid = torch.isnan(noisy).any() or torch.isnan(target).any()
            if inputs_invalid:
                continue

            optimizer.zero_grad()
            prediction = model(noisy)
            batch_loss = criterion(prediction, target)

            # Skip the update when the loss itself is not a finite number.
            if torch.isnan(batch_loss) or torch.isinf(batch_loss):
                continue

            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_loss += batch_loss.item()
            batches_done += 1

        except Exception as e:
            # Report the problem and move on to the next batch.
            print(f"Batch error: {e}")

    return running_loss / max(batches_done, 1)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Computes the mean loss per batch and the mean PSNR/SSIM per image.
    Batches with NaN inputs or a non-finite loss are excluded from the loss
    average; images whose PSNR or SSIM is NaN/inf or cannot be computed are
    excluded from the metric averages. Errors are reported and skipped
    rather than raised.

    Returns:
        Tuple ``(avg_loss, avg_psnr, avg_ssim)`` of Python floats. Each
        value is 0.0 when nothing valid was accumulated for it.
    """
    model.eval()
    total_loss = 0.0
    loss_batches = 0
    total_psnr = 0.0
    total_ssim = 0.0
    valid_samples = 0

    with torch.no_grad():
        for degraded, clean in tqdm(dataloader, desc="Validation"):
            try:
                degraded, clean = degraded.to(device), clean.to(device)

                if torch.isnan(degraded).any() or torch.isnan(clean).any():
                    continue

                output = model(degraded)
                loss = criterion(output, clean)

                if not (torch.isnan(loss) or torch.isinf(loss)):
                    total_loss += loss.item()
                    # Count only batches that contributed, so skipped batches
                    # do not drag the average towards zero.
                    loss_batches += 1

                # One device-to-host copy per batch; (b, c, h, w) -> (b, h, w, c)
                # because the metric functions take channels-last images.
                preds = np.clip(output.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                targets = np.clip(clean.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)

                for img_pred, img_clean in zip(preds, targets):
                    try:
                        psnr_val = psnr(img_clean, img_pred, data_range=1.0)
                        ssim_val = ssim(img_clean, img_pred, data_range=1.0,
                                        channel_axis=2, win_size=11)
                    except Exception as e:
                        print(f"Metric error: {e}")
                        continue

                    # PSNR is infinite for identical images; keep both
                    # accumulators free of NaN/inf.
                    metrics_ok = not any(
                        np.isnan(v) or np.isinf(v) for v in (psnr_val, ssim_val)
                    )
                    if metrics_ok:
                        total_psnr += psnr_val
                        total_ssim += ssim_val
                        valid_samples += 1

            except Exception as e:
                print(f"Validation batch error: {e}")
                continue

    avg_loss = total_loss / max(loss_batches, 1)
    avg_psnr = total_psnr / max(valid_samples, 1)
    avg_ssim = total_ssim / max(valid_samples, 1)

    # Plain floats (not NumPy scalars) so callers can store them in a
    # checkpoint without pulling NumPy objects into the pickle.
    return float(avg_loss), float(avg_psnr), float(avg_ssim)



# ---- Banner ----
print("\n" + "=" * 70)
print("ENHANCED ResNet Image Restoration Training")
print("=" * 70)


# ---- Hyperparameters ----
DEGRADATION_TYPE = 'denoise'
BATCH_SIZE = 8 if device.type != 'cpu' else 4  # smaller batches on CPU
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4
SUBSET_SIZE = 500

# ---- Configuration summary (one item per output line) ----
print(
    "\nConfiguration:",
    f"  Device: {device}",
    f"  Degradation: {DEGRADATION_TYPE}",
    f"  Batch Size: {BATCH_SIZE}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Learning Rate: {LEARNING_RATE}",
    f"  Dataset Size: {SUBSET_SIZE}",
    "=" * 70,
    sep="\n",
)

# Fetch the HQ-50K training split; report and re-raise on failure.
print("\nLoading HQ-50K dataset...")
try:
    dataset = load_dataset("YangQiee/HQ-50K", split="train")
    print(f"Dataset loaded: {len(dataset)} total images")
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Please check your internet connection and try again.")
    raise

# Draw a pool of at most 1.25 * SUBSET_SIZE rows and split it 80/20.
pool_size = min(len(dataset), SUBSET_SIZE * 1.25)
train_size = int(0.8 * pool_size)
val_size = int(0.2 * pool_size)

all_indices = list(range(len(dataset)))
random.shuffle(all_indices)
train_indices = all_indices[:train_size]
val_indices = all_indices[train_size:][:val_size]

# Wrap each split; the wrapper sub-samples again down to `subset_size`.
train_dataset = RobustImageRestorationDataset(
    dataset.select(train_indices),
    degradation_type=DEGRADATION_TYPE,
    subset_size=int(SUBSET_SIZE * 0.8),
)

val_dataset = RobustImageRestorationDataset(
    dataset.select(val_indices),
    degradation_type=DEGRADATION_TYPE,
    subset_size=int(SUBSET_SIZE * 0.2),
)

# Loaders run in the main process (num_workers=0), so no timeout is passed.
loader_options = dict(batch_size=BATCH_SIZE, num_workers=0, pin_memory=False)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

# Build the network and move it to the selected device.
print("\nInitializing Enhanced ResNet model...")
model = EnhancedResNet(num_blocks=12, features=64)
model = model.to(device)

# Report the parameter count, both exact and in millions.
param_counts = [p.numel() for p in model.parameters()]
total_params = sum(param_counts)
print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")

# Objective and optimiser.
criterion = CombinedLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=1e-4,
    betas=(0.9, 0.999),
    lr=LEARNING_RATE,
)

# Cosine decay of the learning rate over the full run, down to 1e-7.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    eta_min=1e-7,
    T_max=NUM_EPOCHS,
)

# Training loop: train, validate, step the LR schedule, checkpoint on the
# best validation PSNR and stop early after `patience` epochs without gain.
print("\n" + "="*70)
print("Starting Training")
print("="*70 + "\n")

# Start below any reachable PSNR so the first epoch always writes a
# checkpoint, even if validation yields a PSNR of exactly 0.
best_psnr = float('-inf')
best_ssim = 0
patience = 5
patience_counter = 0
train_losses, val_losses, val_psnrs, val_ssims = [], [], [], []

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 70)

    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)

    val_loss, val_psnr, val_ssim = validate(model, val_loader, criterion, device)

    # Read the learning rate that was used for this epoch *before* the
    # scheduler advances it, so the log line matches the reported losses.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_psnrs.append(val_psnr)
    val_ssims.append(val_ssim)

    print(f"LR: {current_lr:.6f}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f} | PSNR: {val_psnr:.2f} dB | SSIM: {val_ssim:.4f}")

    if val_psnr > best_psnr:
        best_psnr = val_psnr
        best_ssim = val_ssim
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Store plain floats: NumPy scalars in the pickle make the
            # checkpoint unloadable with torch.load(weights_only=True).
            'psnr': float(best_psnr),
            'ssim': float(best_ssim),
        }, 'best_enhanced_resnet.pth')
        print(f"✓ Saved best model (PSNR: {best_psnr:.2f} dB, SSIM: {best_ssim:.4f})")
    else:
        patience_counter += 1
        print(f"⚠ No improvement for {patience_counter}/{patience} epochs")

        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch + 1}")
            break

    print()

print("\n" + "="*70)
print("Training Completed!")
print("="*70)
print(f"Best PSNR: {best_psnr:.2f} dB")
print(f"Best SSIM: {best_ssim:.4f}")
print(f"Total Epochs: {len(train_losses)}")
print("="*70)


# One row of three panels: losses, validation PSNR, validation SSIM.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Panel 0: train and validation loss per epoch.
loss_ax = axes[0]
for series, series_label in ((train_losses, 'Train Loss'), (val_losses, 'Val Loss')):
    loss_ax.plot(series, label=series_label, linewidth=2)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Progress: Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# Panels 1 and 2: a metric curve plus a dashed line at its best value.
metric_panels = (
    (axes[1], val_psnrs, best_psnr, 'purple', 'o', 'PSNR (dB)', 'Validation PSNR'),
    (axes[2], val_ssims, best_ssim, 'green', 's', 'SSIM', 'Validation SSIM'),
)
for panel, history, best_value, curve_color, curve_marker, y_label, panel_title in metric_panels:
    panel.plot(history, color=curve_color, linewidth=2, marker=curve_marker)
    panel.axhline(y=best_value, color='r', linestyle='--', alpha=0.5)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()


# Reload the best checkpoint (if one was written) before the qualitative test.
checkpoint_path = 'best_enhanced_resnet.pth'
if os.path.exists(checkpoint_path):
    # map_location: lets a checkpoint saved on GPU load on a CPU-only runtime.
    # weights_only=False: the file is produced by this script (trusted) and
    # stores optimizer state and metric scalars besides the weights. Do not
    # use this flag for checkpoints from untrusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
else:
    print(f"Warning: {checkpoint_path} not found; evaluating the current model weights.")
model.eval()


num_samples = 4
# squeeze=False keeps `axes` two-dimensional for any number of rows.
fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

print("\nTesting on sample images...")
with torch.no_grad():
    for i in range(num_samples):
        try:
            degraded, clean = val_dataset[i]
            degraded_input = degraded.unsqueeze(0).to(device)
            restored = model(degraded_input).squeeze(0).cpu()

            # (c, h, w) -> (h, w, c) for matplotlib and the metric functions.
            degraded_np = np.clip(degraded.numpy().transpose(1, 2, 0), 0, 1)
            clean_np = np.clip(clean.numpy().transpose(1, 2, 0), 0, 1)
            restored_np = np.clip(restored.numpy().transpose(1, 2, 0), 0, 1)

            psnr_degraded = psnr(clean_np, degraded_np, data_range=1.0)
            psnr_restored = psnr(clean_np, restored_np, data_range=1.0)
            # Same SSIM window as validate() so the numbers are comparable.
            ssim_degraded = ssim(clean_np, degraded_np, data_range=1.0,
                                 channel_axis=2, win_size=11)
            ssim_restored = ssim(clean_np, restored_np, data_range=1.0,
                                 channel_axis=2, win_size=11)

            axes[i, 0].imshow(degraded_np)
            axes[i, 0].set_title(f'Degraded\nPSNR: {psnr_degraded:.2f} dB | SSIM: {ssim_degraded:.3f}')
            axes[i, 0].axis('off')

            axes[i, 1].imshow(restored_np)
            axes[i, 1].set_title(f'Restored\nPSNR: {psnr_restored:.2f} dB (+{psnr_restored-psnr_degraded:.2f}) | SSIM: {ssim_restored:.3f}',
                                color='green', fontweight='bold')
            axes[i, 1].axis('off')

            axes[i, 2].imshow(clean_np)
            axes[i, 2].set_title('Ground Truth')
            axes[i, 2].axis('off')
        except Exception as e:
            # Report the failure and blank the row instead of hiding it.
            print(f"Skipping sample {i}: {e}")
            for ax in axes[i]:
                ax.axis('off')
            continue

plt.tight_layout()
plt.savefig('test_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*70)
print("="*70)
print(f"✓ Best model saved: best_enhanced_resnet.pth")
print(f"✓ Training curves saved: training_curves.png")
print(f"✓ Test results saved: test_results.png")
print(f"✓ Final PSNR: {best_psnr:.2f} dB")
print(f"✓ Final SSIM: {best_ssim:.4f}")
print("="*70)