In [None]:
import sys
import subprocess
packages = ['torch', 'numpy', 'matplotlib', 'pillow', 'scikit-learn', 'pandas', 'torchvision', 'scipy']


def install_quietly(package):
    """Best-effort ``pip install -q`` of one package into this kernel's interpreter.

    Failures are reported (not raised) so the remaining packages are still tried.
    """
    pip_cmd = [sys.executable, '-m', 'pip', 'install', '-q', package]
    try:
        subprocess.check_call(pip_cmd)
        print(f"✓ {package} installed")
    except Exception as e:
        print(f"✗ {package} failed: {e}")


for package in packages:
    install_quietly(package)

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from glob import glob
from scipy.ndimage import convolve

print(f"PyTorch version: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
IMG_SIZE = 96


def create_motion_blur_kernel(size=15, angle=45):
    """Build a normalized linear motion-blur kernel.

    Marks the pixels along a line of ``size`` samples through the kernel centre
    at ``angle`` degrees, then normalizes the kernel to sum to (almost exactly) 1.

    Args:
        size: Side length of the square kernel.
        angle: Direction of the motion in degrees.

    Returns:
        A ``(size, size)`` float64 array.
    """
    kernel = np.zeros((size, size))
    center = size // 2
    angle_rad = np.deg2rad(angle)
    dx, dy = np.cos(angle_rad), np.sin(angle_rad)
    for i in range(size):
        offset = i - center
        # Round to the nearest pixel. int() truncation effectively floors these
        # (non-negative) coordinates, which shifts the line about half a pixel
        # off-centre and misaligns the blurred image with its sharp target.
        x = int(np.rint(center + offset * dx))
        y = int(np.rint(center + offset * dy))
        if 0 <= x < size and 0 <= y < size:
            kernel[y, x] = 1
    kernel = kernel / (kernel.sum() + 1e-8)
    return kernel


def apply_motion_blur(image, kernel_size=15, angle=None):
    """Blur an image with a linear motion-blur kernel.

    Args:
        image: ``(H, W)`` or ``(H, W, C)`` array, expected in [0, 1]. Each
            channel is convolved independently with 'reflect' borders.
        kernel_size: Side length of the blur kernel.
        angle: Blur direction in degrees; drawn uniformly from [0, 180) with
            the global NumPy RNG when None.

    Returns:
        float32 array with the same shape as ``image``, clipped to [0, 1].
    """
    if angle is None:
        angle = np.random.uniform(0, 180)
    kernel = create_motion_blur_kernel(kernel_size, angle)

    # Convolve in floating point so integer inputs are not truncated.
    src = np.asarray(image)
    if not np.issubdtype(src.dtype, np.floating):
        src = src.astype(np.float64)

    if src.ndim == 3:
        blurred = np.empty_like(src)
        # Blur every channel present instead of assuming exactly three.
        for c in range(src.shape[2]):
            blurred[:, :, c] = convolve(src[:, :, c], kernel, mode='reflect')
    else:
        blurred = convolve(src, kernel, mode='reflect')

    return np.clip(blurred, 0, 1).astype(np.float32)


class STL10DeblurDataset(Dataset):
    """STL-10 images paired with synthetically motion-blurred copies.

    Every image of the chosen STL-10 split is downloaded to ``./data``, resized
    to ``img_size`` and cached in memory as an ``(N, H, W, C)`` float32 array
    in [0, 1]. ``__getitem__`` returns ``(blurred, sharp)`` CHW tensors scaled
    to [-1, 1].

    Args:
        split: ``'train'`` loads the STL-10 train split; any other value loads
            the test split.
        img_size: Side length the images are resized to.
        blur_kernel_size: Largest blur kernel size. Sizes are drawn uniformly
            from ``[min(7, blur_kernel_size), blur_kernel_size]``.
        deterministic: If True, the blur for index ``idx`` is derived from
            ``(seed, idx)``, so repeated accesses return the same pair. If None
            (default), this is True for non-train splits, so evaluation and
            visualisation see identical inputs. Training keeps drawing fresh
            blur from the global NumPy RNG.
        seed: Base seed used when ``deterministic`` is True.
    """

    def __init__(self, split='train', img_size=96, blur_kernel_size=11, deterministic=None, seed=0):
        self.img_size = img_size
        self.blur_kernel_size = blur_kernel_size
        # The lower bound was fixed at 7, which made randint raise for
        # blur_kernel_size < 7.
        self.min_kernel_size = min(7, blur_kernel_size)
        self.deterministic = (split != 'train') if deterministic is None else bool(deterministic)
        self.seed = seed

        print(f"Downloading STL-10 {split} set...")
        stl10 = torchvision.datasets.STL10(
            root='./data',
            split='train' if split == 'train' else 'test',
            download=True,
            transform=transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor()
            ])
        )

        # ToTensor yields CHW in [0, 1]; cache as HWC float32 for the blur code.
        self.images = []
        for img, _ in stl10:
            self.images.append(img.permute(1, 2, 0).numpy().astype(np.float32))

        self.images = np.array(self.images)
        print(f"Loaded {len(self.images)} images from STL-10 {split} set ({img_size}x{img_size})")

    def __len__(self):
        return len(self.images)

    def _blur_params(self, idx):
        """Return ``(angle_degrees, kernel_size)`` for sample ``idx``."""
        low, high = self.min_kernel_size, self.blur_kernel_size + 1
        if self.deterministic:
            # Per-index generator: the same idx always gets the same blur.
            # Normalize negative indices so dataset[-1] and dataset[len-1] agree.
            rng = np.random.default_rng([self.seed, int(idx) % len(self)])
            angle = rng.uniform(0, 180)
            return angle, int(rng.integers(low, high))
        # Training: global NumPy RNG, same draw order as before.
        angle = np.random.uniform(0, 180)
        return angle, np.random.randint(low, high)

    def __getitem__(self, idx):
        sharp = self.images[idx]

        angle, kernel_size = self._blur_params(idx)
        blur = apply_motion_blur(sharp, kernel_size=kernel_size, angle=angle)

        # HWC [0, 1] -> CHW [-1, 1], matching the generator's Tanh output range.
        sharp_tensor = torch.from_numpy(sharp).permute(2, 0, 1).float() * 2 - 1
        blur_tensor = torch.from_numpy(blur).permute(2, 0, 1).float() * 2 - 1

        return blur_tensor, sharp_tensor


print("Loading STL-10 dataset (96x96 high-quality images with synthetic motion blur)...")
train_dataset, test_dataset = (
    STL10DeblurDataset(split=split_name, img_size=IMG_SIZE, blur_kernel_size=15)
    for split_name in ('train', 'test')
)

BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE//2, shuffle=False)

print(f"\nTraining samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Image size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch size: {BATCH_SIZE}")

# Preview: top row = sharp targets, bottom row = their blurred inputs.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for col in range(5):
    blur, sharp = train_dataset[col]
    for row, (tensor, title) in enumerate(((sharp, 'Sharp'), (blur, 'Motion Blurred'))):
        ax = axes[row, col]
        ax.imshow(((tensor + 1) / 2).permute(1, 2, 0).numpy().clip(0, 1))
        ax.set_title(title)
        ax.axis('off')
plt.suptitle('STL-10 Training Pairs (Sharp vs Motion Blurred)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
class ConvBlock(nn.Module):
    """Conv2d -> [InstanceNorm2d] -> [Dropout] -> [LeakyReLU(0.2)] unit.

    With the defaults (k=4, s=2, p=1) the convolution halves H and W. The conv
    bias is dropped when normalization follows, because InstanceNorm would
    cancel it. Layers live in ``self.net`` in the order listed above.
    """

    def __init__(self, in_c, out_c, k=4, s=2, p=1, norm=True, act=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_c, out_c, k, s, p, bias=not norm)
        # (enabled?, factory) pairs, applied in order after the convolution.
        optional_layers = (
            (norm, lambda: nn.InstanceNorm2d(out_c)),
            (dropout > 0, lambda: nn.Dropout(dropout)),
            (act, lambda: nn.LeakyReLU(0.2, inplace=True)),
        )
        self.net = nn.Sequential(conv, *(make() for enabled, make in optional_layers if enabled))

    def forward(self, x):
        return self.net(x)


class DeconvBlock(nn.Module):
    """ConvTranspose2d -> InstanceNorm2d -> [Dropout] -> ReLU upsampling unit.

    With the defaults (k=4, s=2, p=1) the transposed convolution doubles H and W.
    """

    def __init__(self, in_c, out_c, k=4, s=2, p=1, dropout=0.0):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_c, out_c, k, s, p)
        dropout_layers = [nn.Dropout(dropout)] if dropout > 0 else []
        self.net = nn.Sequential(
            upsample,
            nn.InstanceNorm2d(out_c),
            *dropout_layers,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class Generator(nn.Module):
    """U-Net generator with skip connections and a Tanh output head.

    Four ConvBlock encoder stages downsample the input. Three DeconvBlock
    stages upsample it, each output being concatenated with the matching
    encoder feature map. A final transposed conv plus Tanh maps back to
    ``in_channels``. Every decoder output is bilinearly resized to its skip
    partner (and the result to the input size), so inputs whose size is not a
    multiple of 16 still line up.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, 64, norm=False)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512, norm=False)

        # Decoder input widths include the concatenated skip channels.
        self.dec1 = DeconvBlock(512, 256, dropout=0.5)
        self.dec2 = DeconvBlock(256 + 256, 128, dropout=0.5)
        self.dec3 = DeconvBlock(128 + 128, 64)
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(64 + 64, in_channels, 4, 2, 1),
            nn.Tanh(),
        )

    @staticmethod
    def _resize_like(t, ref):
        """Bilinearly resize ``t`` to the spatial size (H, W) of ``ref``."""
        return F.interpolate(t, size=ref.shape[2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)

        h = e4
        for decode, skip in ((self.dec1, e3), (self.dec2, e2), (self.dec3, e1)):
            h = torch.cat([self._resize_like(decode(h), skip), skip], dim=1)
        return self._resize_like(self.dec4(h), x)


G = Generator()
print("Generator Architecture (U-Net with Skip Connections):")
print(f"Total parameters: {sum(p.numel() for p in G.parameters()):,}")

# Forward-only shape smoke test: no autograd graph is needed.
test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
with torch.no_grad():
    test_output = G(test_input)
print(f"Input shape: {test_input.shape} -> Output shape: {test_output.shape}")
assert test_output.shape == test_input.shape, "Generator must preserve the input shape"

In [None]:
class PatchDiscriminatorNet(nn.Module):
    """Conditional PatchGAN: scores patches of the channel-concatenated pair (x, y).

    Three stride-2 conv stages (InstanceNorm on all but the first, LeakyReLU
    0.2 throughout) are followed by a stride-1 conv to a 1-channel score map.
    ``in_c`` must equal the combined channel count of ``x`` and ``y``.
    """

    def __init__(self, in_c=6):
        super().__init__()

        def down(c_in, c_out, normalize):
            # One stride-2 stage: Conv2d [-> InstanceNorm2d] -> LeakyReLU.
            stage = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if normalize:
                stage.append(nn.InstanceNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        self.model = nn.Sequential(
            *down(in_c, 64, False),
            *down(64, 128, True),
            *down(128, 256, True),
            nn.Conv2d(256, 1, 4, 1, 1),
        )

    def forward(self, x, y):
        return self.model(torch.cat([x, y], dim=1))


class MultiScalePatchDiscriminator(nn.Module):
    """Two PatchGAN discriminators: one at full resolution, one at half.

    ``forward`` returns ``(full_res_scores, half_res_scores)``. The half-scale
    inputs are produced by bilinear downsampling with factor 0.5.
    """

    def __init__(self, in_c=6):
        super().__init__()
        self.d1 = PatchDiscriminatorNet(in_c)  # full resolution
        self.d2 = PatchDiscriminatorNet(in_c)  # half resolution

    @staticmethod
    def _half(t):
        """Bilinearly downsample ``t`` by a factor of 2 in H and W."""
        return F.interpolate(t, scale_factor=0.5, mode='bilinear', align_corners=False)

    def forward(self, x_blur, x_sharp):
        full_scale = self.d1(x_blur, x_sharp)
        half_scale = self.d2(self._half(x_blur), self._half(x_sharp))
        return full_scale, half_scale


D = MultiScalePatchDiscriminator()
d_param_count = sum(p.numel() for p in D.parameters())
print("Multi-Scale PatchGAN Discriminator:")
print(f"Total parameters: {d_param_count:,}")

In [None]:
def gan_loss(pred, target):
    """Least-squares GAN loss: mean squared error between scores and targets."""
    residual = pred - target
    return residual.pow(2).mean()


class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG19 feature maps of two images.

    Inputs are expected in the generator's [-1, 1] range. They are mapped to
    [0, 1] and then normalized with the ImageNet channel mean/std that
    torchvision's pretrained VGG19 weights were trained with. Without that
    normalization the features are computed on out-of-distribution inputs.
    Features come from the first 16 layers of ``vgg19().features``.
    """

    # ImageNet RGB statistics expected by torchvision's pretrained weights.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features[:16]
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        # Buffers follow .to(device); non-persistent keeps them out of state_dict.
        self.register_buffer('mean', torch.tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('std', torch.tensor(self.IMAGENET_STD).view(1, 3, 1, 1), persistent=False)

    def _prepare(self, t):
        """Map a [-1, 1] image batch to ImageNet-normalized VGG input."""
        return ((t + 1) / 2 - self.mean) / self.std

    def forward(self, x, y):
        return torch.mean((self.vgg(self._prepare(x)) - self.vgg(self._prepare(y))) ** 2)


def psnr(img1, img2, pixel_max=1.0):
    """Peak signal-to-noise ratio (dB) between two image tensors.

    The MSE is averaged over every element, so for batched inputs this is the
    PSNR of the pooled MSE, not the mean of per-image PSNRs.

    Args:
        img1, img2: Tensors with broadcast-compatible shapes.
        pixel_max: Maximum possible pixel value (1.0 for [0, 1] images).

    Returns:
        0-d tensor on the inputs' device. Identical inputs give a capped 100 dB
        instead of infinity.
    """
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        # Create the cap on the same device/dtype as the regular result.
        return torch.tensor(100.0, device=mse.device, dtype=mse.dtype)
    return 20 * torch.log10(pixel_max / torch.sqrt(mse))


print("Loss Functions initialized: GAN Loss, VGG Perceptual Loss, L1 Loss, PSNR")

In [None]:
G = Generator().to(device)
D = MultiScalePatchDiscriminator().to(device)
vgg_loss = VGGPerceptualLoss().to(device)

lr_G = 1e-4
lr_D = 1e-4
# Adam with beta1=0.5 (the usual GAN setting), one optimizer per network.
opt_G, opt_D = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net, lr in ((G, lr_G), (D, lr_D))
)

# Halve both learning rates every 50 epochs.
scheduler_G, scheduler_D = (
    torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.5)
    for opt in (opt_G, opt_D)
)

l1_criterion = nn.L1Loss()

print("Models initialized:")
print(f"Generator params: {sum(p.numel() for p in G.parameters()):,}")
print(f"Discriminator params: {sum(p.numel() for p in D.parameters()):,}")
print(f"Learning rates - G: {lr_G}, D: {lr_D}")

In [None]:
num_epochs = 100

history = {'d_loss': [], 'g_loss': [], 'psnr': [], 'l1': []}

lambda_l1 = 100.0
lambda_vgg = 0.01
lambda_gan = 1.0

# Smoothed LSGAN targets for the discriminator (real -> 0.9, fake -> 0.1).
REAL_LABEL = 0.9
FAKE_LABEL = 0.1


def discriminator_step(blur, sharp):
    """Run one LSGAN update of D on a (blur, sharp) batch and return the D loss."""
    opt_D.zero_grad()

    # Fakes are produced without a graph: this step only updates D.
    with torch.no_grad():
        fake_sharp = G(blur)

    real_outs = D(blur, sharp)
    fake_outs = D(blur, fake_sharp)

    # ones_like/zeros_like already allocate on the scores' device.
    d_loss_real = sum(gan_loss(out, torch.full_like(out, REAL_LABEL)) for out in real_outs)
    d_loss_fake = sum(gan_loss(out, torch.full_like(out, FAKE_LABEL)) for out in fake_outs)
    d_loss = 0.5 * (d_loss_real + d_loss_fake)

    d_loss.backward()
    torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=1.0)
    opt_D.step()
    return d_loss


def generator_step(blur, sharp):
    """Run one update of G (adversarial + VGG + L1) and return (g_loss, g_l1, fake_sharp)."""
    opt_G.zero_grad()

    fake_sharp = G(blur)

    # Freeze D so the adversarial term only builds gradients for G. Any D
    # gradients would be discarded by opt_D.zero_grad() on the next step.
    D.requires_grad_(False)
    try:
        fake_outs = D(blur, fake_sharp)
        g_gan = sum(gan_loss(out, torch.ones_like(out)) for out in fake_outs)
        g_vgg = vgg_loss(fake_sharp, sharp)
        g_l1 = l1_criterion(fake_sharp, sharp)

        g_loss = lambda_gan * g_gan + lambda_vgg * g_vgg + lambda_l1 * g_l1
        g_loss.backward()
    finally:
        D.requires_grad_(True)

    torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=1.0)
    opt_G.step()
    return g_loss, g_l1, fake_sharp


print(f"Loss weights: L1={lambda_l1}, VGG={lambda_vgg}, GAN={lambda_gan}")
print(f"Training for {num_epochs} epochs...")
print("-" * 60)

for epoch in range(num_epochs):
    G.train()
    D.train()

    epoch_d_loss = 0.0
    epoch_g_loss = 0.0
    epoch_psnr = 0.0
    epoch_l1 = 0.0

    for blur, sharp in train_loader:
        blur = blur.to(device)
        sharp = sharp.to(device)

        d_loss = discriminator_step(blur, sharp)
        g_loss, g_l1, fake_sharp = generator_step(blur, sharp)

        # Tensors are in [-1, 1]; PSNR is measured in [0, 1].
        with torch.no_grad():
            batch_psnr = psnr((fake_sharp + 1) / 2, (sharp + 1) / 2)

        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_loss.item()
        epoch_psnr += batch_psnr.item()
        epoch_l1 += g_l1.item()

    scheduler_G.step()
    scheduler_D.step()

    # Guard against an empty loader (e.g. an empty dataset).
    n_batches = max(len(train_loader), 1)
    epoch_d_loss /= n_batches
    epoch_g_loss /= n_batches
    epoch_psnr /= n_batches
    epoch_l1 /= n_batches

    history['d_loss'].append(epoch_d_loss)
    history['g_loss'].append(epoch_g_loss)
    history['psnr'].append(epoch_psnr)
    history['l1'].append(epoch_l1)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1:3d}/{num_epochs}] D_loss: {epoch_d_loss:.4f} | G_loss: {epoch_g_loss:.4f} | L1: {epoch_l1:.4f} | PSNR: {epoch_psnr:.2f} dB")

print("-" * 60)
print("Training complete!")

In [None]:
# (history key, line label, line color or None for default, y-label, title),
# laid out row-major over the 2x2 grid.
panel_specs = [
    ('d_loss', 'Discriminator', None, 'Loss', 'Discriminator Loss'),
    ('g_loss', 'Generator', 'orange', 'Loss', 'Generator Loss'),
    ('l1', 'L1', 'red', 'L1 Loss', 'L1 Reconstruction Loss'),
    ('psnr', 'PSNR', 'green', 'PSNR (dB)', 'Training PSNR'),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for ax, (key, label, color, ylabel, title) in zip(axes.flat, panel_specs):
    line_style = {'label': label} if color is None else {'label': label, 'color': color}
    ax.plot(history[key], **line_style)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
G.eval()

N_EVAL = 10  # number of test images to score (capped at the test-set size)

test_results = []
all_psnr_blur = []
all_psnr_deblur = []

with torch.no_grad():
    for i in range(min(N_EVAL, len(test_dataset))):
        blur, sharp = test_dataset[i]
        blur = blur.unsqueeze(0).to(device)
        sharp = sharp.unsqueeze(0).to(device)

        fake_sharp = G(blur)

        # Map [-1, 1] -> [0, 1]. All three keep their batch dim so PSNR
        # compares identically shaped tensors (no implicit broadcasting).
        blur_01 = (blur + 1) / 2
        sharp_01 = (sharp + 1) / 2
        fake_01 = (fake_sharp + 1) / 2

        psnr_blur = psnr(blur_01, sharp_01).item()
        psnr_deblur = psnr(fake_01, sharp_01).item()

        all_psnr_blur.append(psnr_blur)
        all_psnr_deblur.append(psnr_deblur)

        # Keep numbers numeric; formatting happens only when printing.
        test_results.append({
            'Image': i + 1,
            'PSNR Blurred': psnr_blur,
            'PSNR Deblurred': psnr_deblur,
            'Improvement': psnr_deblur - psnr_blur,
        })

results_df = pd.DataFrame(test_results)
print("Test Results:")
print(results_df.to_string(
    index=False,
    float_format='{:.2f}'.format,
    formatters={'Improvement': '{:+.2f}'.format},
))
print(f"\n{'='*50}")
print(f"Average PSNR (Blurred):   {np.mean(all_psnr_blur):.2f} dB")
print(f"Average PSNR (Deblurred): {np.mean(all_psnr_deblur):.2f} dB")
print(f"Average Improvement:      {np.mean(all_psnr_deblur) - np.mean(all_psnr_blur):+.2f} dB")

In [None]:
G.eval()

n_samples = min(5, len(test_dataset))
row_labels = ('Blurred', 'Deblurred', 'Ground Truth')


def to_display(t):
    """Convert a CHW tensor in [-1, 1] to an HWC NumPy image in [0, 1] for imshow."""
    return ((t + 1) / 2).permute(1, 2, 0).cpu().numpy().clip(0, 1)


# squeeze=False keeps `axes` 2-D even when n_samples == 1.
fig, axes = plt.subplots(3, n_samples, figsize=(3*n_samples, 9), squeeze=False)

with torch.no_grad():
    for i in range(n_samples):
        blur, sharp = test_dataset[i]
        fake_sharp = G(blur.unsqueeze(0).to(device))

        panels = (to_display(blur), to_display(fake_sharp.squeeze(0)), to_display(sharp))
        for row, img in enumerate(panels):
            ax = axes[row, i]
            ax.imshow(img)
            # ax.axis('off') would also hide the y-label, so the row labels
            # would never appear. Remove only ticks and frame instead.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if i == 0:
                ax.set_ylabel(row_labels[row], fontsize=12)

plt.suptitle('DeblurGAN Results: Blurred → Deblurred → Ground Truth', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
save_path = 'deblur_gan_model.pth'
torch.save({
    'generator': G.state_dict(),
    'discriminator': D.state_dict(),
    # Optimizer states let training be resumed from this checkpoint.
    'opt_G': opt_G.state_dict(),
    'opt_D': opt_D.state_dict(),
    'history': history,
    'config': {
        'img_size': IMG_SIZE,
        'lambda_l1': lambda_l1,
        'lambda_vgg': lambda_vgg,
        'lambda_gan': lambda_gan,
        'num_epochs': num_epochs,
        'batch_size': BATCH_SIZE,
        'lr_G': lr_G,
        'lr_D': lr_D,
    }
}, save_path)
print(f"Model saved to {save_path}")
# history['psnr'] holds per-epoch training PSNR; it is empty if num_epochs == 0.
if history['psnr']:
    print(f"Final PSNR: {history['psnr'][-1]:.2f} dB")