In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import wandb
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
import torchvision
from torchvision import transforms
from random import choice, shuffle
import math
import cv2
import os
# from tqdm.notebook import tqdm
from tqdm import tqdm
from matplotlib import pyplot as plt
import torch.nn.functional as F
import numpy as np

import warnings
warnings.filterwarnings('ignore')


In [None]:
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# wandb_api_key = user_secrets.get_secret("wandb_api_key")
# wandb.login(key=wandb_api_key)

In [None]:
# Pick the compute device once: use CUDA when PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Cell output: library versions plus the selected device.
torch.__version__, torchvision.__version__, DEVICE

In [None]:
# Directory that is listed to discover the input images.
BASE_DIR  = "dataset/flickr30k_images"
# Size passed to cv2.resize for every image; IMG_DIMENSIONS[0] is also used as the
# model's image_size.
IMG_DIMENSIONS = (128, 128)
# Blur kernel sizes (in pixels); the dataset picks one per sample from its index.
KERNEL_SIZES = (3, 5, 7)

In [None]:
# Collect image files only. os.listdir returns entries in arbitrary order and may
# include non-image files (e.g. CSV annotations), which cv2.imread cannot decode.
# Sorting makes the selected subset reproducible across runs and machines.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")
all_images = sorted(
    name for name in os.listdir(BASE_DIR)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)[:100]  # TODO: CHANGE (small subset for quick experiments)
all_images = [os.path.join(BASE_DIR, x) for x in all_images]
print(f"Using {len(all_images)} images from {BASE_DIR}")

if len(all_images) < 3:
    raise RuntimeError(f"Expected at least 3 images in {BASE_DIR}, found {len(all_images)}")

# cv2.imread signals failure by returning None instead of raising, so check explicitly.
test_image = cv2.imread(all_images[2])
if test_image is None:
    raise FileNotFoundError(f"Could not read image: {all_images[2]}")
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)

plt.imshow(test_image), test_image.shape

In [None]:
class Flickr30kDataset(Dataset):
    """Dataset of (blurred, sharp) image pairs built from a list of image paths.

    Each item is read with OpenCV, resized to ``IMG_DIMENSIONS``, blurred with a
    kernel size chosen from ``KERNEL_SIZES`` by the sample index, and returned as
    a pair of RGB tensors produced by ``transforms.ToTensor``.

    Args:
        img_paths: Image file paths. The list is copied, never modified in place.
        train: If True use the first ``train_ratio`` share of the shuffled paths,
            otherwise use the remainder.
        test: Unused; kept so existing call sites keep working.
        blur_type: ``"gaussian"`` or ``"box"``.
        cache_size: Maximum number of resized images kept in memory.
        train_ratio: Fraction of the paths assigned to the training split.
        seed: Seed for the shuffle that precedes the split. Datasets created with
            the same ``img_paths`` and ``seed`` share one permutation, so a
            ``train=True`` and a ``train=False`` instance are disjoint and
            together cover every path. ``None`` gives a non-reproducible shuffle.

    Raises:
        ValueError: If ``blur_type`` is not ``"gaussian"`` or ``"box"``.
    """

    def __init__(self, img_paths=all_images, train=True, test=False, blur_type="gaussian",
                 cache_size=50,
                 train_ratio=0.8,
                 seed=42):
        import random  # only `choice` and `shuffle` are imported at file level

        # Fail at construction time rather than on the first __getitem__ call.
        if blur_type not in ("gaussian", "box"):
            raise ValueError("Invalid blur type")

        # Shuffle a private copy with a dedicated RNG. Shuffling the caller's list
        # in place would reorder it between the creation of the train and the
        # validation dataset, letting the two splits overlap.
        self.img_paths = list(img_paths)
        random.Random(seed).shuffle(self.img_paths)
        self.dataset_length = len(self.img_paths)
        self.blur = blur_type

        # index -> resized BGR uint8 array, consumed by __getitem__
        self.image_cache = {}
        self.cache_size = min(cache_size, len(self.img_paths))

        # Pre-generate blur kernels
        self._setup_blur_kernels()

        # One split index for both branches so no path falls between the splits.
        split = math.floor(train_ratio * self.dataset_length)
        if train:
            self.image_paths = self.img_paths[:split]
        else:
            self.image_paths = self.img_paths[split:]

        self._preload_images()

    @staticmethod
    def _read_resized(path):
        """Read ``path`` with OpenCV and resize it to ``IMG_DIMENSIONS`` (BGR order)."""
        img = cv2.imread(path)
        if img is None:  # cv2.imread returns None instead of raising on failure
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.resize(img, IMG_DIMENSIONS)

    def _preload_images(self):
        """Cache the first ``cache_size`` resized images of this split in memory."""
        for i, path in enumerate(self.image_paths[:self.cache_size]):
            self.image_cache[i] = self._read_resized(path)

    def _setup_blur_kernels(self):
        """Pre-create one 2-D blur kernel tensor per entry of ``KERNEL_SIZES``."""
        self.blur_kernels = {}
        for ksize in KERNEL_SIZES:
            if self.blur == "gaussian":
                kernel = cv2.getGaussianKernel(ksize, 0)
                kernel = np.outer(kernel, kernel.transpose())
            else:  # box blur
                kernel = np.ones((ksize, ksize)) / (ksize * ksize)

            self.blur_kernels[ksize] = torch.from_numpy(kernel).float()

    def __len__(self):
        """Number of images in this split."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(blur_tensor, base_tensor)`` for the image at ``index``."""
        # Use the preloaded copy when available, otherwise load on demand.
        base_image = self.image_cache.get(index)
        if base_image is None:
            base_image = self._read_resized(self.image_paths[index])

        # Cycle through the configured kernel sizes by sample index.
        ksize = KERNEL_SIZES[index % len(KERNEL_SIZES)]
        if self.blur == "gaussian":
            blur_image = cv2.GaussianBlur(base_image, (ksize, ksize), 0)
        elif self.blur == "box":
            blur_image = cv2.blur(base_image, (ksize, ksize))
        else:
            raise ValueError("Invalid blur type")

        # Convert to RGB (cvtColor returns new arrays, so the cache is not modified)
        base_image = cv2.cvtColor(base_image, cv2.COLOR_BGR2RGB)
        blur_image = cv2.cvtColor(blur_image, cv2.COLOR_BGR2RGB)

        # Convert to tensor
        base_tensor = transforms.ToTensor()(base_image)
        blur_tensor = transforms.ToTensor()(blur_image)
        return blur_tensor, base_tensor

In [None]:
# Training split; reshuffle every epoch so batch composition changes between epochs.
train_dataset = Flickr30kDataset(train=True, test=False)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2,
                          pin_memory=True, persistent_workers=True)

# Validation split. The loader must wrap val_dataset: wrapping train_dataset here
# would report validation metrics on the training images.
val_dataset = Flickr30kDataset(train=False, test=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2,
                        pin_memory=True, persistent_workers=True)

In [None]:

# ====================== IMPROVED DEBLURRING MODEL ======================
class SharpDeblurViT(nn.Module):
    """Convolutional encoder/decoder with a transformer bottleneck for deblurring.

    Pipeline: three depthwise-separable encoder blocks (overall stride 4), a
    transformer encoder over the flattened bottleneck features, two upsampling
    decoder blocks, a reconstruction head that also sees the input image, and an
    additive edge-refinement branch.

    Args:
        image_size: Side length of the (square) images the positional embedding is
            sized for. Other input sizes still work: the embedding is resampled to
            the actual bottleneck resolution.

    Note:
        The reconstruction head ends in a sigmoid, but the edge branch is added
        afterwards, so the returned tensor is not limited to [0, 1].
    """

    def __init__(self, image_size=IMG_DIMENSIONS[0]):
        super().__init__()
        self.enc1 = self._make_encoder_block(3, 32, kernel_size=5, stride=1)
        self.enc2 = self._make_encoder_block(32, 64, stride=2)
        self.enc3 = self._make_encoder_block(64, 128, stride=2)

        # enc1 keeps the resolution (stride 1); enc2 and enc3 each halve it, so the
        # bottleneck side is image_size / 4. Sizing the embedding for this grid
        # avoids resampling it on every forward pass at the nominal image size.
        self.bottleneck_size = image_size // 4
        num_patches = self.bottleneck_size ** 2
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, 128))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=128,
                nhead=8,
                dim_feedforward=256,
                dropout=0.1,
                activation='gelu',
                norm_first=True,
                batch_first=True
            ),
            num_layers=2
        )

        # Decoder: two x2 upsampling stages bring the features back to input size
        self.dec1 = self._make_decoder_block(128, 64, scale_factor=2)
        self.dec2 = self._make_decoder_block(64, 32, scale_factor=2)

        # Final reconstruction from decoder features concatenated with the input
        self.final_conv = nn.Sequential(
            nn.Conv2d(32 + 3, 32, kernel_size=3, padding=1),  # Input + skip
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        # Edge enhancement module
        self.edge_enhancer = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 3, kernel_size=3, padding=1)
        )

    def _make_encoder_block(self, in_c, out_c, kernel_size=3, stride=2):
        """Depthwise conv + pointwise conv, followed by BatchNorm, GELU and Dropout2d."""
        return nn.Sequential(
            # Depthwise convolution
            nn.Conv2d(in_c, in_c, kernel_size, stride=stride, padding=kernel_size//2, groups=in_c),
            # Pointwise convolution
            nn.Conv2d(in_c, out_c, 1),
            nn.BatchNorm2d(out_c),
            nn.GELU(),
            nn.Dropout2d(0.1)
        )

    def _make_decoder_block(self, in_c, out_c, scale_factor=2):
        """Bilinear upsampling, 3x3 conv, BatchNorm, GELU and a residual block."""
        return nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True),
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.GELU(),
            ResidualBlock(out_c, out_c)
        )

    def forward(self, x):
        """Map a batch of 3-channel images to deblurred images of the same size."""
        x0 = x  # Save input for the skip connection into final_conv

        # Encoder pathway
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)

        # Transformer processing
        B, C, H, W = e3.shape
        patches = e3.flatten(2).transpose(1, 2)  # [B, N, C]

        # Resample the positional embedding when the input size differs from the
        # size the model was built for.
        pos_embed = self.pos_embed
        if pos_embed.shape[1] != H * W:
            side = self.bottleneck_size
            pos_embed = nn.functional.interpolate(
                pos_embed.reshape(1, side, side, C).permute(0, 3, 1, 2),
                size=(H, W), mode='bilinear', align_corners=True
            ).permute(0, 2, 3, 1).reshape(1, H * W, C)

        # Apply positional embedding
        patches = patches + pos_embed
        transformed = self.transformer(patches)
        # reshape (not view): the transposed tensor is not contiguous
        bottleneck = transformed.transpose(1, 2).reshape(B, C, H, W)

        # Decoder pathway
        d1 = self.dec1(bottleneck)  # 64 channels, half the input resolution
        d2 = self.dec2(d1)          # 32 channels, input resolution

        # For input sizes not divisible by 4 the decoder output is slightly larger
        # than the input; align it so the concatenation below is valid.
        if d2.shape[-2:] != x0.shape[-2:]:
            d2 = nn.functional.interpolate(
                d2, size=x0.shape[-2:], mode='bilinear', align_corners=True
            )

        # Skip connection + final reconstruction
        d2 = torch.cat([d2, x0], dim=1)  # Combine features with original input
        output = self.final_conv(d2)

        # Edge refinement
        edge = self.edge_enhancer(output)
        return output + 0.3 * edge  # Sharpened output


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm stages whose result is added to a shortcut of the input.

    The shortcut is the identity when the channel counts match and a 1x1
    convolution otherwise. GELU is applied after the first stage and again after
    the addition.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main branch
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.GELU()

        # Shortcut branch: project only when the channel count changes
        self.shortcut = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        """Return ``activation(main_branch(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.activation(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        return self.activation(hidden + residual)

In [None]:
class SharpnessLoss(nn.Module):
    """Weighted sum of a pixel loss, a frequency-magnitude loss and a Sobel edge loss.

    ``loss = alpha * pixel + beta * freq + gamma * edge`` where

    * ``pixel`` is ``L1 + 0.5 * MSE`` between prediction and target,
    * ``freq`` is the L1 distance between the magnitudes of the 2-D real FFTs; it
      is evaluated only in training mode and then only with probability
      ``freq_prob`` per call, otherwise it contributes 0,
    * ``edge`` is the L1 distance between per-channel Sobel gradient magnitudes.

    Args:
        alpha: Weight of the pixel loss.
        beta: Weight of the frequency loss.
        gamma: Weight of the edge loss.
        freq_prob: Probability of evaluating the frequency loss on a training call.

    Inputs to ``forward`` are expected to have 3 channels (the Sobel kernels are
    applied with ``groups=3``).
    """

    def __init__(self, alpha=1.0, beta=0.5, gamma=0.2, freq_prob=0.3):
        super().__init__()
        self.alpha = alpha  # Pixel loss weight
        self.beta = beta    # Frequency loss weight
        self.gamma = gamma  # Edge loss weight
        self.freq_prob = freq_prob
        # Retained for backward compatibility; forward uses the functional API.
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

        sobel_x = torch.tensor([[[[1, 0, -1], [2, 0, -2], [1, 0, -1]]]], dtype=torch.float32)
        sobel_y = torch.tensor([[[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]], dtype=torch.float32)

        # One copy of each kernel per colour channel (depthwise convolution).
        self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
        self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))

        # Retained for backward compatibility; not read by forward.
        self.fft_cache_size = None

    def _edge_magnitude(self, image):
        """Per-channel Sobel gradient magnitude of ``image``."""
        # Follow the input's device and dtype so the loss also works when the
        # module itself was not moved with .to(device) or the input is not float32.
        kernel_x = self.sobel_x.to(device=image.device, dtype=image.dtype)
        kernel_y = self.sobel_y.to(device=image.device, dtype=image.dtype)
        grad_x = F.conv2d(image, kernel_x, padding=1, groups=3)
        grad_y = F.conv2d(image, kernel_y, padding=1, groups=3)
        # The small constant keeps the sqrt gradient finite where both terms are 0.
        return torch.sqrt(grad_x**2 + grad_y**2 + 1e-6)

    @staticmethod
    def _fft_magnitude(image):
        """Magnitude of the orthonormal 2-D real FFT of ``image``."""
        # Half-precision FFT support is limited, so reduced-precision inputs
        # (e.g. produced under autocast) are upcast first.
        if image.dtype in (torch.float16, torch.bfloat16):
            image = image.float()
        return torch.abs(torch.fft.rfft2(image, norm='ortho'))

    def forward(self, pred, target):
        """Return the scalar combined loss between ``pred`` and ``target``."""
        pixel_loss = F.l1_loss(pred, target) + 0.5 * F.mse_loss(pred, target)

        # Frequency loss is sampled to save compute: training mode only, and then
        # only on a freq_prob share of the calls.
        if self.training and torch.rand(1).item() < self.freq_prob:
            freq_loss = F.l1_loss(self._fft_magnitude(pred), self._fft_magnitude(target))
        else:
            freq_loss = 0.0

        edge_loss = F.l1_loss(self._edge_magnitude(pred), self._edge_magnitude(target))

        return (self.alpha * pixel_loss +
                self.beta * freq_loss +
                self.gamma * edge_loss)

In [None]:
class DeblurTrainer:
    """Training and validation driver for a deblurring model.

    Args:
        model: The ``nn.Module`` to train. It is moved to ``config['device']``.
        train_loader: Iterable of ``(blur, sharp)`` batches used for training.
        val_loader: Iterable of ``(blur, sharp)`` batches used for validation.
        config: Dict with the keys ``device``, ``lr``, ``weight_decay`` and
            ``epochs``; ``grad_accum`` is optional and defaults to 1.

    Weights & Biases logging is optional: metrics, images and checkpoints are
    only sent when a run is active (``wandb.init`` has been called). Without an
    active run the trainer still trains, prints progress and writes checkpoints.
    """

    def __init__(self, model, train_loader, val_loader, config):
        self.device = config['device']
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.current_epoch = 0
        self.grad_accum = max(1, int(config.get('grad_accum', 1)))

        on_cuda = torch.device(self.device).type == 'cuda'
        # Mixed precision and the fused optimizer are only enabled on CUDA.
        self.use_amp = on_cuda

        # Keep a handle on the uncompiled module: the torch.compile wrapper
        # prefixes state_dict keys with "_orig_mod.", which would not load back
        # into a plain model instance.
        self._raw_model = model.to(self.device)
        self.model = self._raw_model
        if hasattr(torch, 'compile'):
            self.model = torch.compile(self._raw_model, mode='reduce-overhead')

        # Loss function with sharpness emphasis. It owns buffers (Sobel kernels),
        # so it has to live on the same device as the model outputs.
        self.loss_fn = SharpnessLoss(
            alpha=1.0,
            beta=0.7,  # Higher weight for frequency loss
            gamma=0.3  # Edge preservation
        ).to(self.device)

        # Optimizer
        self.optimizer = optim.AdamW(
            self._raw_model.parameters(),
            lr=config['lr'],
            weight_decay=config['weight_decay'],
            fused=on_cuda  # Faster fused implementation
        )

        # Learning rate scheduler. The scheduler is stepped once per optimizer
        # step, i.e. once per accumulation group, so size the cycle accordingly.
        steps_per_epoch = max(1, math.ceil(len(train_loader) / self.grad_accum))
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=config['lr'],
            steps_per_epoch=steps_per_epoch,
            epochs=config['epochs'],
            pct_start=0.3
        )

        # Mixed precision training (a disabled scaler is a transparent no-op)
        self.scaler = GradScaler(enabled=self.use_amp)

        # Initialize wandb
        # wandb.init(
        #     project=config['project_name'],
        #     config=config,
        #     name=config['run_name'],
        #     reinit=True
        # )
        # wandb.watch(model, log='all', log_freq=100)

    @staticmethod
    def _wandb_active():
        """True when a wandb run has been started and logging calls are valid."""
        return wandb.run is not None

    def _log(self, payload):
        """Send ``payload`` to wandb if a run is active; otherwise do nothing."""
        if self._wandb_active():
            wandb.log(payload)

    def _save_checkpoint(self, filename):
        """Write the uncompiled model's weights to ``filename`` and sync to wandb."""
        torch.save(self._raw_model.state_dict(), filename)
        if self._wandb_active():
            wandb.save(filename)

    @staticmethod
    def _global_ssim(pred, target, c1=0.01 ** 2, c2=0.03 ** 2):
        """Per-image SSIM computed from whole-image statistics (single window).

        This is a cheap approximation of windowed SSIM. ``c1`` and ``c2`` are the
        usual stabilising constants for images in [0, 1].
        """
        dims = [1, 2, 3]
        mu_x = pred.mean(dim=dims)
        mu_y = target.mean(dim=dims)
        # Biased estimators throughout so variance and covariance are consistent.
        var_x = pred.var(dim=dims, unbiased=False)
        var_y = target.var(dim=dims, unbiased=False)
        cov_xy = (pred * target).mean(dim=dims) - mu_x * mu_y
        numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
        denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
        return numerator / denominator

    def train_epoch(self, epoch):
        """Run one pass over ``train_loader`` and return the mean batch loss."""
        self.model.train()
        self.loss_fn.train()
        total_loss = 0.0
        grad_accum = self.grad_accum
        num_batches = len(self.train_loader)

        # Start every epoch from clean gradients.
        self.optimizer.zero_grad()

        for i, (blur, sharp) in enumerate(tqdm(self.train_loader, desc="Train Batch")):
            blur = blur.to(self.device, non_blocking=True)
            sharp = sharp.to(self.device, non_blocking=True)

            # Mixed precision training
            with autocast(enabled=self.use_amp):
                pred = self.model(blur)
                loss = self.loss_fn(pred, sharp) / grad_accum

            # Backpropagation
            self.scaler.scale(loss).backward()

            batch_loss = loss.item() * grad_accum
            total_loss += batch_loss

            # Step after every full accumulation group and also on the last batch,
            # so a trailing partial group is applied instead of leaking into the
            # next epoch.
            if (i + 1) % grad_accum == 0 or (i + 1) == num_batches:
                # Gradient clipping (on unscaled gradients)
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                # Update weights
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                # Update learning rate
                self.scheduler.step()

                # Log batch loss
                self._log({
                    "batch_loss": batch_loss,
                    "lr": self.scheduler.get_last_lr()[0]
                })

        return total_loss / max(1, num_batches)

    def evaluate(self, epoch):
        """Evaluate on ``val_loader``.

        Returns:
            Dict with the per-sample averages ``'loss'``, ``'PSNR'`` (dB) and
            ``'SSIM'``. All values are 0.0 when the loader yields no samples.
        """
        self.model.eval()
        self.loss_fn.eval()
        self.current_epoch = epoch
        total_loss = 0.0
        total_psnr = 0.0
        total_ssim = 0.0
        count = 0

        with torch.no_grad():
            for blur, sharp in tqdm(self.val_loader, desc="Val Batch"):
                blur = blur.to(self.device, non_blocking=True)
                sharp = sharp.to(self.device, non_blocking=True)

                pred = self.model(blur)
                pred = torch.clamp(pred, 0, 1)
                batch_size = blur.size(0)

                # Calculate loss
                loss = self.loss_fn(pred, sharp)
                total_loss += loss.item() * batch_size

                # Per-image PSNR; the clamp keeps a perfect match finite.
                mse = ((pred - sharp) ** 2).mean(dim=[1, 2, 3]).clamp_min(1e-10)
                total_psnr += (10 * torch.log10(1.0 / mse)).sum().item()

                # Per-image SSIM (global approximation)
                total_ssim += self._global_ssim(pred, sharp).sum().item()
                count += batch_size

                # Log sample images every 3 epochs
                if epoch % 3 == 0 and count < 16:  # Log first few batches
                    self.log_sample_images(blur, sharp, pred)

        denom = max(1, count)  # avoid dividing by zero on an empty loader
        return {
            'loss': total_loss / denom,
            'PSNR': total_psnr / denom,
            'SSIM': total_ssim / denom
        }

    def log_sample_images(self, blur, sharp, pred):
        """Log up to three (blurred, target, prediction) triplets to wandb."""
        if not self._wandb_active():
            return

        # Convert to float32 numpy arrays on the CPU
        blur_np = blur.detach().float().cpu().numpy()
        sharp_np = sharp.detach().float().cpu().numpy()
        pred_np = pred.detach().float().cpu().numpy()

        # Log first 3 samples
        images = []
        for i in range(min(3, blur.size(0))):
            # Calculate sharpness metrics
            sharpness_gt = self.calculate_sharpness(sharp[i])
            sharpness_pred = self.calculate_sharpness(pred[i])

            images.append(wandb.Image(
                blur_np[i].transpose(1, 2, 0),
                caption=f"Blurred Input (Epoch {self.current_epoch})"
            ))
            images.append(wandb.Image(
                sharp_np[i].transpose(1, 2, 0),
                caption=f"Sharp Target | Sharpness: {sharpness_gt:.3f}"
            ))
            images.append(wandb.Image(
                pred_np[i].transpose(1, 2, 0),
                caption=f"Predicted | Sharpness: {sharpness_pred:.3f}"
            ))

        wandb.log({"results": images})

    def calculate_sharpness(self, image_tensor):
        """Return the variance of the Laplacian of a single (C, H, W) image."""
        # Convert to grayscale
        if image_tensor.shape[0] == 3:  # RGB
            gray = 0.2989 * image_tensor[0] + 0.5870 * image_tensor[1] + 0.1140 * image_tensor[2]
        else:
            gray = image_tensor[0]

        # Create proper 4D tensor for input (B,C,H,W)
        gray_input = gray.unsqueeze(0).unsqueeze(0)

        # 4D Laplacian kernel (out_channels, in_channels, H, W) on the same device
        # and with the same dtype as the input
        laplacian_kernel = torch.tensor([[[[0, 1, 0],
                                           [1, -4, 1],
                                           [0, 1, 0]]]],
                                        dtype=gray_input.dtype,
                                        device=image_tensor.device)

        # Apply convolution
        laplacian = F.conv2d(gray_input, laplacian_kernel, padding=1)

        return torch.var(laplacian).item()

    def fit(self):
        """Train for ``config['epochs']`` epochs with PSNR-based early stopping.

        Writes ``best_model.pth`` whenever validation PSNR improves and
        ``final_model.pth`` at the end.
        """
        # -inf so the first epoch always produces a checkpoint, even if PSNR < 0.
        best_psnr = float('-inf')
        early_stop_counter = 0
        patience = 8

        for epoch in tqdm(range(self.config['epochs']), desc="Epoch"):
            self.current_epoch = epoch

            # Train for one epoch
            train_loss = self.train_epoch(epoch)

            # Validate
            val_metrics = self.evaluate(epoch)

            # Log metrics
            self._log({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_metrics['loss'],
                "val_psnr": val_metrics['PSNR'],
                "val_ssim": val_metrics['SSIM']
            })

            # Print progress
            print(f"Epoch {epoch+1}/{self.config['epochs']} | "
                  f"Train Loss: {train_loss:.5f} | "
                  f"Val Loss: {val_metrics['loss']:.5f} | "
                  f"PSNR: {val_metrics['PSNR']:.2f} dB | "
                  f"SSIM: {val_metrics['SSIM']:.4f}")

            # Early stopping and model saving
            if val_metrics['PSNR'] > best_psnr:
                best_psnr = val_metrics['PSNR']
                self._save_checkpoint("best_model.pth")
                print(f"Saved best model with PSNR: {best_psnr:.2f} dB")
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                if early_stop_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        # Save final model
        self._save_checkpoint("final_model.pth")
        if self._wandb_active():
            wandb.finish()

In [None]:
# Run settings consumed by the model constructor and the trainer.
config = dict(
    # Experiment identification (used by the optional wandb setup)
    project_name='image-deblurring-sharp',
    run_name='sharp-vit-128',
    # Data / model geometry
    image_size=IMG_DIMENSIONS[0],
    batch_size=8,
    # Optimisation
    grad_accum=4,
    epochs=50,
    lr=3e-4,
    weight_decay=1e-5,
    # Hardware
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [None]:
# Build the network for the configured input resolution.
model = SharpDeblurViT(image_size=config['image_size'])

# Count all parameters and store the figure (in millions) with the run settings.
total_params = sum(map(torch.numel, model.parameters()))
config['parameters'] = f"{total_params/1e6:.2f}M"
print(f"Model parameters: {config['parameters']}")

# Create datasets and loaders (replace with your data)

# Resize-then-tensor pipeline at the configured image size.
transform = Compose([
    Resize((config['image_size'],) * 2),
    ToTensor(),
])

# Wire model, loaders and settings into the trainer; training itself is started
# by un-commenting the call below.
trainer = DeblurTrainer(model, train_loader, val_loader, config)
# trainer.fit()