In [2]:
pip install numpy matplotlib pillow torch torchvision scikit-image

Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import math
import glob
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Set non-interactive backend
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
import random
import time
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.transforms import functional as TF
import torchvision.models as models
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import skimage.metrics
from torchvision.models import vgg19, VGG19_Weights

print("[INFO] Libraries imported.")

# LPIPS for evaluation
class LPIPS(nn.Module):
    """Perceptual distance computed on frozen VGG-19 feature maps.

    Both inputs are normalised with the ``mean``/``std`` buffers and passed
    through the first 30 modules of ``vgg19(...).features``. Activations
    captured at the module indices in ``self.layers`` are compared with a
    mean absolute difference and combined using ``self.weights``.

    NOTE(review): this omits the learned linear heads of the reference LPIPS
    implementation, so values are presumably not comparable with published
    LPIPS numbers — confirm before reporting.
    """

    def __init__(self):
        super().__init__()
        # Frozen backbone: evaluation mode, no gradients for its parameters.
        self.vgg = vgg19(weights=VGG19_Weights.DEFAULT).features[:30].eval()
        self.vgg.requires_grad_(False)
        # Per-channel normalisation statistics, broadcastable over NCHW.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        # One weight per tapped layer; registered as a Parameter, so it is
        # trainable unless the caller freezes it.
        self.weights = nn.Parameter(torch.ones(5))
        # Indices (within self.vgg) whose outputs are compared.
        self.layers = [0, 5, 10, 19, 28]

    def _normalize(self, x):
        """Return ``x`` shifted and scaled by the stored mean/std buffers."""
        return (x - self.mean) / self.std

    def forward(self, x, y):
        """Return the weighted sum of per-layer L1 feature distances.

        Args:
            x: First image batch, presumably NCHW RGB in [0, 1] — TODO confirm.
            y: Second image batch with the same layout as ``x``.

        Returns:
            A scalar tensor holding the weighted distance.
        """
        feat_a = self._normalize(x)
        feat_b = self._normalize(y)
        taps = []
        for idx, module in enumerate(self.vgg):
            feat_a = module(feat_a)
            feat_b = module(feat_b)
            if idx in self.layers:
                taps.append((feat_a, feat_b))
        # Distances are evaluated only after the complete pass, exactly as
        # before. NOTE(review): if the backbone's activations run in place,
        # the captured tensors may be modified by the following module.
        per_layer = [F.l1_loss(a, b) for a, b in taps]
        total = 0
        for weight, dist in zip(self.weights, per_layer):
            total = total + weight * dist
        return total

# LayerNorm2d
class LayerNorm2d(nn.Module):
    """Apply ``nn.LayerNorm`` across the channel axis of a 4-D tensor.

    Args:
        num_features: Size of the channel axis (dimension 1 of the input).
    """

    def __init__(self, num_features):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)

    def forward(self, x):
        """Normalise ``x`` over dimension 1 and return it in the same layout."""
        # Move channels to the last axis (where LayerNorm operates), normalise,
        # then move them back to dimension 1.
        channels_last = x.permute(0, 2, 3, 1)
        normalised = self.norm(channels_last)
        return normalised.permute(0, 3, 1, 2)

# SpatialResidualModule
class SpatialResidualModule(nn.Module):
    """Residual branch gated by a single-channel spatial attention map.

    The input passes through conv -> PReLU -> conv; a 1x1 conv followed by a
    sigmoid turns those features into a one-channel mask that multiplies them.
    The gated features are scaled by a learnable scalar (initialised to 0.1)
    and added back to the input.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        feature_layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        ]
        self.conv = nn.Sequential(*feature_layers)
        gate_layers = [
            nn.Conv2d(channels, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.spatial_att = nn.Sequential(*gate_layers)
        self.scale = nn.Parameter(torch.tensor([0.1], dtype=torch.float32))

    def forward(self, x):
        """Return ``x`` plus the scaled, attention-gated convolutional features."""
        features = self.conv(x)
        gate = self.spatial_att(features)
        gated = features * gate
        return x + self.scale * gated

# EnhancedResidualBlock
class EnhancedResidualBlock(nn.Module):
    """Residual block: conv -> norm -> PReLU -> conv -> norm -> spatial residual.

    The branch output is multiplied by a learnable scalar (initialised to 0.1)
    before being added to the block input.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Construction order is kept so seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = LayerNorm2d(channels)
        self.norm2 = LayerNorm2d(channels)
        self.act = nn.PReLU()
        self.spatial_residual = SpatialResidualModule(channels)
        self.scale = nn.Parameter(torch.tensor([0.1], dtype=torch.float32))

    def forward(self, x):
        """Return ``x`` plus the scaled output of the residual branch."""
        first_stage = self.act(self.norm1(self.conv1(x)))
        second_stage = self.norm2(self.conv2(first_stage))
        branch = self.spatial_residual(second_stage)
        return x + self.scale * branch

# EnhancedResidualGroup
class EnhancedResidualGroup(nn.Module):
    """Stack of ``EnhancedResidualBlock`` modules with a group-level skip.

    The input runs through ``n_blocks`` residual blocks and a trailing 3x3
    convolution; the result is multiplied by a learnable scalar (initialised
    to 0.1) and added to the group input.

    Args:
        channels: Number of input and output channels.
        n_blocks: How many residual blocks the group contains.
    """

    def __init__(self, channels, n_blocks):
        super().__init__()
        self.body = nn.Sequential(
            *(EnhancedResidualBlock(channels) for _ in range(n_blocks))
        )
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.scale = nn.Parameter(torch.tensor([0.1], dtype=torch.float32))

    def forward(self, x):
        """Return ``x`` plus the scaled output of the block stack and conv."""
        branch = self.conv(self.body(x))
        return x + self.scale * branch

# EnhancedESPCN
class EnhancedESPCN(nn.Module):
    """Sub-pixel upsampling head producing a 3-channel image.

    Three 3x3 convolutions (``in_channels`` -> 128 -> 64 ->
    ``3 * scale_factor ** 2``) with PReLU between them, followed by
    ``nn.PixelShuffle(scale_factor)``.

    Args:
        in_channels: Channel count of the incoming feature map.
        scale_factor: Spatial upscaling factor applied by the pixel shuffle.
    """

    def __init__(self, in_channels, scale_factor=2):
        super().__init__()
        shuffle_channels = 3 * (scale_factor ** 2)
        stages = [
            nn.Conv2d(in_channels, 128, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(64, shuffle_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the upscaled 3-channel tensor for feature map ``x``."""
        upscaled = self.net(x)
        return upscaled

# UltraEnhancedSR
class UltraEnhancedSR(nn.Module):
    """Super-resolution network with a deep residual trunk and a shallow bypass.

    Two paths are upscaled independently and merged:

    * deep path — a 128-channel head, five ``EnhancedResidualGroup`` modules of
      ten blocks each, a 3x3 convolution whose output is added to the head
      features, then ``EnhancedESPCN`` upsampling;
    * direct path — two convolutions and a pixel shuffle applied straight to
      the input image.

    The two 3-channel results are concatenated and refined by a small
    two-convolution stack into the final image.

    Args:
        scale: Spatial upscaling factor used by both paths.
    """

    def __init__(self, scale=2):
        super().__init__()
        # Construction order is kept so seeded initialisation is unchanged.
        self.head = nn.Sequential(
            nn.Conv2d(3, 128, 3, padding=1),
            nn.PReLU(),
        )
        groups = [EnhancedResidualGroup(128, 10) for _ in range(5)]
        self.body = nn.ModuleList(groups)
        self.global_residual = nn.Conv2d(128, 128, 3, padding=1)
        self.upscale = EnhancedESPCN(128, scale)
        shuffle_channels = 3 * (scale ** 2)
        self.direct_path = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(16, shuffle_channels, 3, padding=1),
            nn.PixelShuffle(scale),
        )
        self.refine = nn.Sequential(
            nn.Conv2d(6, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
        )

    def forward(self, x):
        """Return the super-resolved image for the low-resolution batch ``x``."""
        bypass = self.direct_path(x)

        head_features = self.head(x)
        trunk = head_features
        for group in self.body:
            trunk = group(trunk)
        # Long skip connection around the whole trunk.
        merged = head_features + self.global_residual(trunk)
        deep_image = self.upscale(merged)

        # Stack both 3-channel candidates along channels for the refiner.
        stacked = torch.cat([bypass, deep_image], dim=1)
        return self.refine(stacked)

# RobustSRDataset
class RobustSRDataset(Dataset):
    """Dataset yielding synthetically degraded LR / HR image pairs.

    Every file matched by ``hr_dir/*`` is treated as a high-resolution image.
    It is resized to ``hr_size`` x ``hr_size`` (aspect ratio is not
    preserved), optionally blurred and colour-jittered, bicubically
    downsampled by ``scale`` and corrupted with signal-dependent Gaussian
    noise to form the low-resolution input.

    Args:
        hr_dir: Directory containing the high-resolution images.
        scale: Downsampling factor between HR and LR.
        augment: Enable random blur, colour jitter, flips and 90-degree
            rotations. Noise is added regardless of this flag.
        hr_size: Side length of the square HR target. Defaults to 1024, the
            value that was previously hard-coded.

    Raises:
        ValueError: If ``hr_dir`` contains no files.
    """

    def __init__(self, hr_dir, scale=2, augment=True, hr_size=1024):
        self.hr_paths = sorted(glob.glob(os.path.join(hr_dir, '*')))
        if not self.hr_paths:
            raise ValueError(f"No files found in directory: {hr_dir}. Please check the path or ensure the directory contains images.")
        self.scale = scale
        self.augment = augment
        self.hr_size = hr_size
        self.color_jitter = transforms.ColorJitter(0.15, 0.15, 0.15, 0.05)
        self.blur_kernels = [
            ImageFilter.GaussianBlur(0.5),
            ImageFilter.GaussianBlur(0.8),
            ImageFilter.GaussianBlur(1.0),
            ImageFilter.BoxBlur(1),
            ImageFilter.BoxBlur(2)
        ]

    def __len__(self):
        """Return the number of files found in the HR directory."""
        return len(self.hr_paths)

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor, basename)`` for sample ``idx``.

        If anything fails while loading or degrading the image, the error is
        printed and an all-zero pair named ``"blank_fallback"`` is returned
        instead, so one unreadable file does not abort the epoch.
        """
        hr_path = self.hr_paths[idx]
        lr_size = self.hr_size // self.scale
        try:
            # The context manager closes the file handle once the pixels
            # have been copied into the converted image.
            with Image.open(hr_path) as source:
                hr = source.convert('RGB')
            # One unconditional resize; the earlier size check led to the
            # same call and was redundant.
            hr = TF.resize(hr, [self.hr_size, self.hr_size])
            if self.augment:
                blur_kernel = random.choice(self.blur_kernels)
                hr_blurred = hr.filter(blur_kernel)
                if random.random() > 0.5:
                    hr_blurred = self.color_jitter(hr_blurred)
            else:
                hr_blurred = hr
            lr = hr_blurred.resize((lr_size, lr_size), Image.BICUBIC)
            lr_np = np.array(lr).astype(np.float32) / 255.0
            # Noise whose standard deviation grows with the square root of
            # the pixel intensity; applied to every sample.
            noise_level = np.random.uniform(0.005, 0.015)
            noise = np.random.normal(0, noise_level, lr_np.shape) * np.sqrt(lr_np + 0.001)
            lr_np = np.clip(lr_np + noise, 0, 1)
            lr = Image.fromarray((lr_np * 255).astype(np.uint8))
            if self.augment:
                # Geometric augmentations are applied to both images alike so
                # the pair stays aligned.
                if random.random() > 0.5:
                    lr, hr = TF.hflip(lr), TF.hflip(hr)
                if random.random() > 0.5:
                    lr, hr = TF.vflip(lr), TF.vflip(hr)
                if random.random() > 0.5:
                    angle = random.choice([90, 180, 270])
                    lr, hr = TF.rotate(lr, angle), TF.rotate(hr, angle)
            return TF.to_tensor(lr), TF.to_tensor(hr), os.path.basename(hr_path)
        except Exception as e:
            # Deliberate best-effort: report and substitute a blank pair.
            print(f"Error loading {hr_path}: {e}")
            lr = torch.zeros(3, lr_size, lr_size)
            hr = torch.zeros(3, self.hr_size, self.hr_size)
            return lr, hr, "blank_fallback"

# Setup Training
def setup_advanced_training(train_dir='DF2K_train_HR', val_dir='DF2K_valid_HR',
                            epochs=150, warmup_epochs=5):
    """Build the data loaders, model, loss, optimizer, scheduler and grad scaler.

    Creates the ``results`` and ``checkpoints`` directories, wraps the model in
    ``nn.DataParallel`` when more than one GPU is visible, and configures AdamW
    with a linear warm-up followed by cosine decay.

    Args:
        train_dir: Directory of high-resolution training images.
        val_dir: Directory of high-resolution validation images.
        epochs: Total number of epochs the learning-rate schedule spans. Pass
            the same value to ``train_advanced_model`` so the cosine decay
            ends when training does.
        warmup_epochs: Number of initial epochs over which the learning rate
            ramps up linearly.

    Returns:
        Tuple ``(train_loader, val_loader, model, criterion, optimizer,
        scheduler, scaler)``.

    Raises:
        ValueError: If either dataset directory is missing or empty.
    """
    print("[INFO] Setting up advanced training...")
    os.makedirs("results", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Validate dataset directories
    if not os.path.isdir(train_dir):
        raise ValueError(f"Training directory {train_dir} does not exist.")
    if not os.path.isdir(val_dir):
        raise ValueError(f"Validation directory {val_dir} does not exist.")

    train_dataset = RobustSRDataset(train_dir, scale=2)
    val_dataset = RobustSRDataset(val_dir, scale=2, augment=False)

    print(f"[INFO] Loaded {len(train_dataset)} training images and {len(val_dataset)} validation images.")

    train_loader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    model = UltraEnhancedSR(scale=2).cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        print(f"[INFO] Using {torch.cuda.device_count()} GPUs")
    criterion = nn.L1Loss().cuda()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=2e-4,
        weight_decay=1e-4,
        betas=(0.9, 0.999)
    )
    scaler = GradScaler()

    # Guard against a zero-length decay phase when epochs <= warmup_epochs.
    decay_span = max(1, epochs - warmup_epochs)

    def lr_lambda(epoch):
        """Return the LR multiplier: linear warm-up, then cosine decay."""
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / decay_span))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    print("[INFO] Advanced training setup complete.")
    return train_loader, val_loader, model, criterion, optimizer, scheduler, scaler

# Calculate Metrics
def calculate_metrics(sr, hr):
    """Compute PSNR, SSIM and the VGG perceptual distance for one image pair.

    The perceptual model is created on first use and cached as an attribute
    of this function so the backbone is loaded only once.

    Args:
        sr: Super-resolved output. The leading dimension is squeezed and the
            rest permuted as (C, H, W), so a batch of one is expected.
        hr: Ground-truth tensor with the same layout as ``sr``.

    Returns:
        Tuple ``(psnr, ssim, lpips)``. PSNR and SSIM are computed on ``sr``
        clamped to [0, 1]; the perceptual distance uses the unclamped output.
    """
    device = sr.device
    if not hasattr(calculate_metrics, "lpips_model"):
        calculate_metrics.lpips_model = LPIPS()
    # Follow the inputs' device rather than assuming CUDA; this is a no-op
    # once the model already lives there.
    lpips_model = calculate_metrics.lpips_model.to(device)

    # Autocast may hand over half-precision outputs; evaluate in float32.
    sr = sr.detach().float()
    hr = hr.detach().to(device).float()

    sr_np = sr.squeeze(0).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    hr_np = hr.squeeze(0).permute(1, 2, 0).cpu().numpy()
    psnr = skimage.metrics.peak_signal_noise_ratio(hr_np, sr_np, data_range=1.0)
    # Only channel_axis is passed: multichannel is its deprecated predecessor
    # and was redundant alongside it.
    ssim = skimage.metrics.structural_similarity(
        hr_np, sr_np,
        win_size=11,
        channel_axis=2,
        data_range=1.0,
        gaussian_weights=True
    )
    with torch.no_grad():
        lpips = lpips_model(sr, hr).item()
    return psnr, ssim, lpips

# Train Model
def _save_model_weights(model, path):
    """Save ``model``'s state dict to ``path``, unwrapping ``nn.DataParallel``."""
    target = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(target.state_dict(), path)


def train_advanced_model(train_loader, val_loader, model, criterion, optimizer, scheduler, scaler, epochs=150):
    """Train ``model`` with mixed precision and gradient accumulation.

    Each epoch runs one pass over ``train_loader`` (an optimizer step every
    8 batches, plus one for any trailing partial window), then evaluates on
    ``val_loader``, writes comparison figures for the first three validation
    images, saves a checkpoint, updates ``history.npy`` and re-plots the
    training curves. The weights with the best validation PSNR are saved
    separately.

    Args:
        train_loader: Loader yielding ``(lr, hr, name)`` training batches.
        val_loader: Loader yielding ``(lr, hr, name)`` validation batches; its
            dataset must expose ``hr_paths``.
        model: Network to train, optionally wrapped in ``nn.DataParallel``.
        criterion: Loss applied to ``(sr, hr)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per epoch.
        scaler: ``GradScaler`` used for loss scaling.
        epochs: Number of epochs to run.

    Returns:
        Tuple ``(model, history)`` where ``history`` holds the per-epoch
        losses and metrics plus the best PSNR and its epoch.
    """
    print("[INFO] Starting advanced training...")
    best_psnr = 0
    best_epoch = 0
    history = {
        'train_loss': [], 'val_loss': [],
        'psnr': [], 'ssim': [], 'lpips': [],
        'best_psnr': 0, 'best_epoch': 0
    }
    start_time = time.time()
    accum_steps = 8

    # The last window of an epoch may be shorter than accum_steps; it still
    # gets its own optimizer step instead of being thrown away.
    num_batches = len(train_loader)
    tail = num_batches % accum_steps
    tail_start = num_batches - tail

    # Basenames of the validation images that get a comparison figure; built
    # once with the same basename logic the dataset uses for its names.
    preview_names = {os.path.basename(p) for p in val_loader.dataset.hr_paths[:3]}

    for epoch in range(epochs):
        epoch_start_time = time.time()
        current_lr = optimizer.param_groups[0]['lr']
        model.train()
        train_loss = 0
        batch_count = 0
        optimizer.zero_grad()
        for i, (lr, hr, _) in enumerate(train_loader):
            lr, hr = lr.cuda(), hr.cuda()
            with autocast():
                sr = model(lr)
                loss = criterion(sr, hr)
            # Divide by the window length so the accumulated gradient is the
            # mean over the window rather than the sum.
            window = tail if i >= tail_start else accum_steps
            scaler.scale(loss / window).backward()
            if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
                # Unscale before clipping so the threshold applies to the
                # true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            train_loss += loss.item()
            batch_count += 1
            if (i + 1) % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs} | Batch {i+1}/{len(train_loader)} | Loss: {loss.item():.4f}")
        model.eval()
        val_loss = 0
        psnr_total = 0
        ssim_total = 0
        lpips_total = 0
        with torch.no_grad():
            for lr, hr, fname in val_loader:
                lr, hr = lr.cuda(), hr.cuda()
                with autocast():
                    sr = model(lr)
                    val_loss_batch = criterion(sr, hr)
                val_loss += val_loss_batch.item()
                psnr, ssim, lpips = calculate_metrics(sr, hr)
                psnr_total += psnr
                ssim_total += ssim
                lpips_total += lpips
                if fname[0] in preview_names:
                    save_comparison(lr, sr, hr, f"epoch_{epoch+1}_{fname[0]}")
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        val_batches = max(len(val_loader), 1)
        avg_train_loss = train_loss / max(batch_count, 1)
        avg_val_loss = val_loss / val_batches
        avg_psnr = psnr_total / val_batches
        avg_ssim = ssim_total / val_batches
        avg_lpips = lpips_total / val_batches
        scheduler.step()
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['psnr'].append(avg_psnr)
        history['ssim'].append(avg_ssim)
        history['lpips'].append(avg_lpips)
        epoch_time = time.time() - epoch_start_time
        total_time = time.time() - start_time
        print(f"\nEpoch {epoch+1}/{epochs} Summary:")
        print(f"Time: {epoch_time:.2f}s | Total: {str(datetime.timedelta(seconds=int(total_time)))}")
        print(f"LR: {current_lr:.8f}")
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}")
        print(f"PSNR: {avg_psnr:.2f} | SSIM: {avg_ssim:.4f} | LPIPS: {avg_lpips:.4f}")
        _save_model_weights(model, f"checkpoints/model_epoch_{epoch+1}.pth")
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            best_epoch = epoch + 1
            history['best_psnr'] = best_psnr
            history['best_epoch'] = best_epoch
            best_model_path = f"checkpoints/best_model_epoch_{epoch+1}_psnr_{avg_psnr:.2f}.pth"
            _save_model_weights(model, best_model_path)
            print(f"[INFO] New best model saved with PSNR: {avg_psnr:.2f}")
        # The history dict is stored as a pickled object array; loading it
        # back requires allow_pickle=True.
        np.save("history.npy", history)
        plot_training_curves(history, epoch+1)
        print(f"Best PSNR so far: {best_psnr:.2f} at epoch {best_epoch}")
        print("-" * 80)
    final_model_path = "checkpoints/final_model.pth"
    _save_model_weights(model, final_model_path)
    print(f"[INFO] Training complete. Best PSNR: {best_psnr:.2f} at epoch {best_epoch}")
    print(f"[INFO] Final model saved as '{final_model_path}'")
    return model, history

# Save Comparison
def save_comparison(lr, sr, hr, filename):
    """Write a side-by-side LR / SR / HR figure to ``results/<filename>.png``.

    Args:
        lr: Low-resolution input; the leading dimension is squeezed, so a
            batch of one is expected.
        sr: Model output with the same layout; clamped to [0, 1] for display.
        hr: Ground-truth tensor with the same layout.
        filename: Output name without the ``.png`` suffix.
    """
    os.makedirs("results", exist_ok=True)

    def to_image(tensor, clamp=False):
        # (1, C, H, W) tensor -> (H, W, C) float32 array for imshow.
        data = tensor.squeeze(0).cpu()
        if clamp:
            data = data.clamp(0, 1)
        return data.permute(1, 2, 0).numpy().astype(np.float32)

    panels = (
        ('Low Resolution', to_image(lr)),
        ('Super Resolution', to_image(sr, clamp=True)),
        ('High Resolution (Ground Truth)', to_image(hr)),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for axis, (title, image) in zip(axes, panels):
        axis.imshow(image)
        axis.set_title(title)
        axis.axis('off')
    plt.tight_layout()
    plt.savefig(f"results/{filename}.png", dpi=300, bbox_inches='tight')
    plt.close()

# Plot Training Curves
def plot_training_curves(history, current_epoch):
    """Save a 2x2 figure of loss, PSNR, SSIM and LPIPS curves.

    Args:
        history: Mapping with ``'train_loss'``, ``'val_loss'``, ``'psnr'``,
            ``'ssim'`` and ``'lpips'`` sequences, one entry per epoch.
        current_epoch: Epoch number used in the title and in the output name
            ``results/training_curves_epoch_<current_epoch>.png``.
    """
    os.makedirs("results", exist_ok=True)
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    def finish(ax, ylabel, title, with_legend=False):
        # Shared axis decoration, in the same call order for every panel.
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if with_legend:
            ax.legend()
        ax.grid(True)

    # Top-left: training and validation loss.
    loss_ax = axes[0, 0]
    loss_ax.plot(history['train_loss'], label='Train Loss')
    loss_ax.plot(history['val_loss'], label='Validation Loss')
    finish(loss_ax, 'Loss', 'Loss Curves', with_legend=True)

    # Top-right: PSNR with the 32 dB target drawn as a dashed reference line.
    psnr_ax = axes[0, 1]
    psnr_ax.plot(history['psnr'])
    psnr_ax.axhline(y=32, color='r', linestyle='--', label='Target PSNR (32dB)')
    finish(psnr_ax, 'PSNR (dB)', 'PSNR Progress', with_legend=True)

    # Bottom row: single-series panels.
    single_series = (
        (axes[1, 0], 'ssim', 'SSIM', 'SSIM Progress'),
        (axes[1, 1], 'lpips', 'LPIPS', 'LPIPS Progress (Lower is Better)'),
    )
    for ax, key, ylabel, title in single_series:
        ax.plot(history[key])
        finish(ax, ylabel, title)

    fig.suptitle(f'Training Progress (Epoch {current_epoch})', fontsize=16)
    plt.tight_layout()
    plt.savefig(f"results/training_curves_epoch_{current_epoch}.png", dpi=300, bbox_inches='tight')
    plt.close()

# Main Execution
if __name__ == "__main__":
    # Seed torch, NumPy and Python's random module before anything draws from
    # them; seeding must precede setup so model initialisation is covered.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    # Release any cached GPU memory left over from earlier work.
    torch.cuda.empty_cache()
    # Build loaders, model, loss, optimizer, scheduler and grad scaler.
    train_loader, val_loader, model, criterion, optimizer, scheduler, scaler = setup_advanced_training()
    # Run the full training loop; checkpoints and figures are written as it goes.
    trained_model, history = train_advanced_model(
        train_loader,
        val_loader,
        model,
        criterion,
        optimizer,
        scheduler,
        scaler,
        epochs=150
    )
    # Final plot, labelled with the number of epochs recorded in the history.
    plot_training_curves(history, len(history['train_loss']))
    print("[INFO] Training completed successfully!")

[INFO] Libraries imported.
[INFO] Setting up advanced training...
[INFO] Loaded 3450 training images and 100 validation images.
[INFO] Using 2 GPUs
[INFO] Advanced training setup complete.
[INFO] Starting advanced training...
Epoch 1/150 | Batch 50/3450 | Loss: 0.5687
Epoch 1/150 | Batch 100/3450 | Loss: 0.5586
Epoch 1/150 | Batch 150/3450 | Loss: 0.3954
Epoch 1/150 | Batch 200/3450 | Loss: 0.2434
Epoch 1/150 | Batch 250/3450 | Loss: 0.2952
Epoch 1/150 | Batch 300/3450 | Loss: 0.2663
Epoch 1/150 | Batch 350/3450 | Loss: 0.1907
Epoch 1/150 | Batch 400/3450 | Loss: 0.1639
Epoch 1/150 | Batch 450/3450 | Loss: 0.1559
Epoch 1/150 | Batch 500/3450 | Loss: 0.1640
Epoch 1/150 | Batch 550/3450 | Loss: 0.1106
Epoch 1/150 | Batch 600/3450 | Loss: 0.1088
Epoch 1/150 | Batch 650/3450 | Loss: 0.0998
Epoch 1/150 | Batch 700/3450 | Loss: 0.0862
Epoch 1/150 | Batch 750/3450 | Loss: 0.1430
Epoch 1/150 | Batch 800/3450 | Loss: 0.1092
Epoch 1/150 | Batch 850/3450 | Loss: 0.0807
Epoch 1/150 | Batch 900/345

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:09<00:00, 60.5MB/s] 



Epoch 1/150 Summary:
Time: 5103.60s | Total: 1:25:03
LR: 0.00004000
Train Loss: 0.1081
Val Loss: 0.0500
PSNR: 22.91 | SSIM: 0.7033 | LPIPS: 3.7862
[INFO] New best model saved with PSNR: 22.91
Best PSNR so far: 22.91 at epoch 1
--------------------------------------------------------------------------------
Epoch 2/150 | Batch 50/3450 | Loss: 0.0463
Epoch 2/150 | Batch 100/3450 | Loss: 0.0745
Epoch 2/150 | Batch 150/3450 | Loss: 0.0411
Epoch 2/150 | Batch 200/3450 | Loss: 0.0518
Epoch 2/150 | Batch 250/3450 | Loss: 0.0567
Epoch 2/150 | Batch 300/3450 | Loss: 0.0668
Epoch 2/150 | Batch 350/3450 | Loss: 0.0662
Epoch 2/150 | Batch 400/3450 | Loss: 0.0483
Epoch 2/150 | Batch 450/3450 | Loss: 0.0256
Epoch 2/150 | Batch 500/3450 | Loss: 0.0539
Epoch 2/150 | Batch 550/3450 | Loss: 0.0452
Epoch 2/150 | Batch 600/3450 | Loss: 0.0335
Epoch 2/150 | Batch 650/3450 | Loss: 0.0410
Epoch 2/150 | Batch 700/3450 | Loss: 0.0327
Epoch 2/150 | Batch 750/3450 | Loss: 0.0397
Epoch 2/150 | Batch 800/3450 | L