# 🎌 Kintsugi AI - Digital Image Restoration

This notebook trains a U-Net model with attention gates for restoring damaged historical images.

**Features:**
- Synthetic degradation pipeline (scratches, noise, masks, fading)
- Attention U-Net architecture
- Composite loss (L1 + SSIM + Perceptual + Edge)
- Progressive training

---

## 1. Setup Environment

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Clone the Kintsugi AI repository
!git clone https://github.com/AleenaTahir1/Kintsugi-AI.git
%cd Kintsugi-AI

import os
PROJECT_ROOT = '/content/Kintsugi-AI'

# Create data directories
os.makedirs(f'{PROJECT_ROOT}/data/train', exist_ok=True)
os.makedirs(f'{PROJECT_ROOT}/data/val', exist_ok=True)
os.makedirs(f'{PROJECT_ROOT}/checkpoints', exist_ok=True)
print('Setup complete!')

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q opencv-python Pillow tqdm gradio matplotlib gdown

In [None]:
# Verify PyTorch and CUDA
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Download Dataset

We'll use CelebA-HQ for face restoration. You can also use your own dataset.

In [None]:
# Option 1: Download CelebA-HQ subset (recommended for quick training)
# This downloads a smaller subset for demonstration

import gdown
import zipfile

# CelebA-HQ 256x256 subset (adjust URL to your dataset)
# For full dataset, use Kaggle or official sources

# Example: Download from Google Drive (replace with your link)
# gdown.download('https://drive.google.com/uc?id=YOUR_FILE_ID', 'celeba_hq.zip', quiet=False)

# For now, we'll use a small sample
print(f"Please upload your dataset to {PROJECT_ROOT}/data/train/")
print("Or run the cell below to use sample images")

In [None]:
# Option 2: Generate sample images for testing
# This creates random gradient images with overlaid circles as placeholders

import numpy as np
from PIL import Image
import os

def create_sample_images(output_dir, num_images=100, size=256):
    """Create sample gradient images for testing."""
    os.makedirs(output_dir, exist_ok=True)
    
    for i in range(num_images):
        # Create gradient image with random colors
        x = np.linspace(0, 1, size)
        y = np.linspace(0, 1, size)
        xx, yy = np.meshgrid(x, y)
        
        # Random gradient direction and colors
        angle = np.random.uniform(0, 2 * np.pi)
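        gradient = np.cos(angle) * xx + np.sin(angle) * yy
        # Rescale to [0, 1] so the uint8 casts below cannot overflow or wrap around
        gradient = (gradient - gradient.min()) / (gradient.max() - gradient.min() + 1e-8)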
        
        # Create RGB channels
        r = (gradient * np.random.uniform(100, 255)).astype(np.uint8)
        g = ((1 - gradient) * np.random.uniform(100, 255)).astype(np.uint8)
        b = (np.abs(gradient - 0.5) * 2 * np.random.uniform(100, 255)).astype(np.uint8)
        
        img = np.stack([r, g, b], axis=-1)
        
        # Add some patterns
        if np.random.random() > 0.5:
            # Add circles
            for _ in range(np.random.randint(3, 10)):
                cx, cy = np.random.randint(0, size, 2)
                radius = np.random.randint(10, 50)
                color = np.random.randint(0, 256, 3)  # upper bound is exclusive
                yy, xx = np.ogrid[:size, :size]
                mask = (xx - cx) ** 2 + (yy - cy) ** 2 <= radius ** 2
                img[mask] = color
        
        Image.fromarray(img).save(f"{output_dir}/sample_{i:04d}.png")
    
    print(f"Created {num_images} sample images in {output_dir}")

# Create sample images
create_sample_images(f'{PROJECT_ROOT}/data/train', num_images=500)
create_sample_images(f'{PROJECT_ROOT}/data/val', num_images=50)

In [None]:
# Option 3: Upload from Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Copy your dataset from Drive
# !cp -r /content/drive/MyDrive/your_dataset/* {PROJECT_ROOT}/data/train/
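If your own images arrive as a single folder (for example after copying them from Drive), they still need to be divided between `data/train` and `data/val`. The helper below is a hypothetical sketch of one way to do that: `split_dataset`, `val_fraction`, and `seed` are illustrative names, not part of the Kintsugi AI repository. It copies (rather than moves) files and uses a seeded shuffle so the split is reproducible.

```python
import os
import random
import shutil

IMAGE_EXTS = ('.jpg', '.jpeg', '.png')

def split_dataset(src_dir, train_dir, val_dir, val_fraction=0.1, seed=0):
    """Hypothetical helper: copy images from src_dir into train_dir / val_dir.

    Returns (num_train, num_val). Non-image files are ignored.
    """
    files = sorted(f for f in os.listdir(src_dir) if f.lower().endswith(IMAGE_EXTS))
    random.Random(seed).shuffle(files)  # deterministic shuffle for a reproducible split
    n_val = max(1, int(len(files) * val_fraction)) if files else 0

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    for i, name in enumerate(files):
        dest = val_dir if i < n_val else train_dir
        shutil.copy2(os.path.join(src_dir, name), os.path.join(dest, name))
    return len(files) - n_val, n_val

# Example (the Drive path is a placeholder):
# split_dataset('/content/drive/MyDrive/your_dataset',
#               f'{PROJECT_ROOT}/data/train', f'{PROJECT_ROOT}/data/val')
```

If you also generated the placeholder images above, clear `data/train` and `data/val` first so synthetic and real images are not mixed.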

## 3. Load Source Code

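The `src/` package comes from the repository cloned in Section 1. If you skipped the clone, upload the project files (or a zip of them) instead.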

In [None]:
# If you uploaded the project as a zip
# !unzip kintsugi-ai.zip -d /content/

# Add project to path
import sys
sys.path.insert(0, PROJECT_ROOT)

# Verify imports
try:
    from src.degradation import DegradationPipeline
    from src.dataset import KintsugiDataset, ProgressiveDataLoader
    from src.model import create_model, AttentionUNet
    from src.losses import CompositeLoss, compute_psnr
    from src.trainer import Trainer
    from src.inference import Restorer
    print("All modules imported successfully!")
except ImportError as e:
    print(f"Import error: {e}")
    print("Please ensure all source files are uploaded to the project directory.")

## 4. Visualize Degradation Pipeline

In [None]:
import matplotlib.pyplot as plt
import cv2
import numpy as np
from PIL import Image

# Load a sample image
sample_path = f'{PROJECT_ROOT}/data/train/'
sample_files = sorted(f for f in os.listdir(sample_path) if f.lower().endswith(('.jpg', '.jpeg', '.png')))

if sample_files:
    img = cv2.imread(os.path.join(sample_path, sample_files[0]))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (256, 256))
    
    # Apply degradations
    pipeline = DegradationPipeline(severity=0.8)
    
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    
    axes[0, 0].imshow(img)
    axes[0, 0].set_title('Original')
    axes[0, 0].axis('off')
    
    axes[0, 1].imshow(pipeline.apply_scratches(img.copy()))
    axes[0, 1].set_title('Scratches')
    axes[0, 1].axis('off')
    
    axes[0, 2].imshow(pipeline.apply_gaussian_noise(img.copy()))
    axes[0, 2].set_title('Gaussian Noise')
    axes[0, 2].axis('off')
    
    axes[0, 3].imshow(pipeline.apply_random_mask(img.copy()))
    axes[0, 3].set_title('Random Masks')
    axes[0, 3].axis('off')
    
    axes[1, 0].imshow(pipeline.apply_color_fading(img.copy()))
    axes[1, 0].set_title('Color Fading')
    axes[1, 0].axis('off')
    
    axes[1, 1].imshow(pipeline.apply_stains(img.copy()))
    axes[1, 1].set_title('Stains')
    axes[1, 1].axis('off')
    
    axes[1, 2].imshow(pipeline.apply_folding_lines(img.copy()))
    axes[1, 2].set_title('Folding Lines')
    axes[1, 2].axis('off')
    
    # Full degradation
    axes[1, 3].imshow(pipeline(img.copy()))
    axes[1, 3].set_title('Combined')
    axes[1, 3].axis('off')
    
    plt.tight_layout()
    plt.savefig(f'{PROJECT_ROOT}/degradation_examples.png', dpi=150)
    plt.show()
else:
    print("No sample images found. Please upload images first.")
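The repository's `DegradationPipeline` is not reproduced in this notebook. As a rough, NumPy-only illustration of what such operations typically do, here is a sketch of three of them (Gaussian noise, straight-line scratches, and color fading) driven by a single `severity` knob, similar in spirit to `DegradationPipeline(severity=0.8)`. All function names and constants are hypothetical; the actual code in `src/degradation.py` may work quite differently.

```python
import numpy as np

def add_gaussian_noise(img, sigma, rng):
    """Add zero-mean Gaussian noise with std `sigma` (in 0-255 units)."""
    noisy = img.astype(np.float32) + rng.normal(0.0, sigma, img.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)

def add_scratches(img, n, rng, value=255):
    """Draw n one-pixel-wide straight lines as bright scratches."""
    out = img.copy()
    h, w = out.shape[:2]
    for _ in range(n):
        x0, x1 = rng.integers(0, w, 2)
        y0, y1 = rng.integers(0, h, 2)
        steps = max(abs(x1 - x0), abs(y1 - y0)) + 1
        xs = np.linspace(x0, x1, steps).round().astype(int)
        ys = np.linspace(y0, y1, steps).round().astype(int)
        out[ys, xs] = value  # sets every channel along the line
    return out

def fade_colors(img, amount):
    """Desaturate toward per-pixel gray, then compress contrast toward white."""
    f = img.astype(np.float32)
    gray = f.mean(axis=-1, keepdims=True)
    faded = (1 - amount) * f + amount * gray
    faded = faded * (1 - 0.3 * amount) + 255 * 0.3 * amount  # washed-out look
    return np.clip(faded, 0, 255).astype(np.uint8)

def degrade(img, severity, seed=0):
    """Compose the three degradations, each scaled by severity in [0, 1]."""
    rng = np.random.default_rng(seed)
    out = fade_colors(img, 0.5 * severity)
    out = add_scratches(out, int(round(10 * severity)), rng)
    out = add_gaussian_noise(out, 25.0 * severity, rng)
    return out
```

At `severity=0.0` every step is a no-op, so `degrade(img, 0.0)` returns the input unchanged; raising the severity increases the fading, the number of scratches, and the noise level together.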

## 5. Configure Training

In [None]:
# Training Configuration
from dataclasses import dataclass

@dataclass
class ColabConfig:
    # Data
    train_dir: str = f'{PROJECT_ROOT}/data/train'
    val_dir: str = f'{PROJECT_ROOT}/data/val'
    image_size: int = 256
    batch_size: int = 16  # Adjust based on GPU memory
    num_workers: int = 2
    
    # Model
    model_type: str = 'attention_unet'
    features: list = None  # Will be set in __post_init__
    
    # Training
    epochs: int = 50
    lr: float = 2e-4
    weight_decay: float = 1e-4
    use_amp: bool = True
    
    # Progressive training
    progressive: bool = True
    start_severity: float = 0.3
    end_severity: float = 1.0
    warmup_epochs: int = 10
    
    # Loss weights
    l1_weight: float = 1.0
    ssim_weight: float = 0.5
    perceptual_weight: float = 0.1
    edge_weight: float = 0.1
    
    # Checkpointing
    save_dir: str = f'{PROJECT_ROOT}/checkpoints'
    save_every: int = 5
    
    def __post_init__(self):
        if self.features is None:
            self.features = [64, 128, 256, 512]

config = ColabConfig()
print("Configuration:")
for key, value in vars(config).items():
    print(f"  {key}: {value}")
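The progressive-training fields (`start_severity`, `end_severity`, `warmup_epochs`) describe a curriculum: early epochs see milder damage, later ones the full severity. The sketch below assumes a linear ramp followed by a plateau; `severity_for_epoch` is a hypothetical name, and `ProgressiveDataLoader` may use a different curve.

```python
def severity_for_epoch(epoch, start=0.3, end=1.0, warmup_epochs=10):
    """Assumed schedule: ramp linearly from `start` to `end`, then hold at `end`."""
    if warmup_epochs <= 0 or epoch >= warmup_epochs:
        return end
    t = max(epoch, 0) / warmup_epochs  # fraction of warmup completed, in [0, 1)
    return start + (end - start) * t
```

With the defaults above, epoch 5 lands halfway through the warmup at a severity of about 0.65, and every epoch from 10 onward uses 1.0. Passing `config.start_severity`, `config.end_severity`, and `config.warmup_epochs` reproduces the configured range.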

## 6. Initialize Model and Data

In [None]:
import torch
from torch.utils.data import DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Create datasets
print("\nLoading datasets...")
train_dataset = KintsugiDataset(
    root_dir=config.train_dir,
    image_size=config.image_size,
    severity=config.start_severity,
    augment=True
)

val_dataset = KintsugiDataset(
    root_dir=config.val_dir,
    image_size=config.image_size,
    severity=1.0,
    augment=False
)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

# Create data loaders
train_loader = ProgressiveDataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    start_severity=config.start_severity,
    end_severity=config.end_severity,
    warmup_epochs=config.warmup_epochs
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True
)

# Create model
print("\nCreating model...")
model = create_model(
    model_type=config.model_type,
    features=config.features,
    device=device
)
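Attention U-Nets commonly use the additive attention gate from Oktay et al. (2018): the skip-connection features `x` and the decoder's gating signal `g` are projected by 1x1 convolutions, summed, passed through ReLU, reduced to one channel, and squashed by a sigmoid into a per-pixel weight that rescales `x`. The NumPy sketch below writes the 1x1 convolutions as `einsum` calls. It assumes `g` has already been resized to the spatial size of `x` and uses scalar biases for brevity; we have not checked that `src/model.py` follows this exact formulation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def attention_gate(x, g, W_x, W_g, w_psi, b=0.0, b_psi=0.0):
    """Additive attention gate with 1x1 convolutions expressed as einsums.

    x:     skip-connection features, shape (C_x, H, W)
    g:     gating signal from the decoder, shape (C_g, H, W)
    W_x:   (C_int, C_x) projection of x
    W_g:   (C_int, C_g) projection of g
    w_psi: (C_int,) projection down to a single attention channel
    Returns the gated features (C_x, H, W) and the attention map (H, W).
    """
    q = np.einsum('ic,chw->ihw', W_x, x) + np.einsum('ic,chw->ihw', W_g, g) + b
    q = np.maximum(q, 0.0)                                     # ReLU
    alpha = sigmoid(np.einsum('i,ihw->hw', w_psi, q) + b_psi)  # weights in (0, 1)
    return x * alpha[None, :, :], alpha
```

A useful sanity check: with `w_psi` set to zeros the gate outputs 0.5 everywhere, so it simply halves the skip features; learning non-zero weights is what lets the decoder suppress irrelevant regions.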

In [None]:
# Visualize a training batch
degraded, clean = next(iter(train_loader))

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for i in range(4):
    # Degraded
    img = degraded[i].permute(1, 2, 0).numpy()
    axes[0, i].imshow(img)
    axes[0, i].set_title('Degraded')
    axes[0, i].axis('off')
    
    # Clean
    img = clean[i].permute(1, 2, 0).numpy()
    axes[1, i].imshow(img)
    axes[1, i].set_title('Clean (Target)')
    axes[1, i].axis('off')

plt.suptitle('Training Batch: Degraded (top) vs Clean (bottom)')
plt.tight_layout()
plt.show()

## 7. Train Model

In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    lr=config.lr,
    weight_decay=config.weight_decay,
    epochs=config.epochs,
    use_amp=config.use_amp,
    l1_weight=config.l1_weight,
    ssim_weight=config.ssim_weight,
    perceptual_weight=config.perceptual_weight,
    edge_weight=config.edge_weight,
    save_dir=config.save_dir,
    save_every=config.save_every
)

print("Trainer initialized. Run the next cell to start training.")

In [None]:
# Train!
history = trainer.train()
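The loss weights passed to `Trainer` suggest an objective of the form `l1_weight * L1 + ssim_weight * (1 - SSIM) + perceptual_weight * L_perc + edge_weight * L_edge`. How `Trainer` actually computes each term is not shown in this notebook, so the NumPy sketch below is an assumption: it uses a simplified global SSIM (one window covering the whole image instead of the usual sliding Gaussian window), an edge term based on finite-difference gradient magnitudes, and it omits the VGG perceptual term entirely.

```python
import numpy as np

def global_ssim(a, b, data_range=1.0):
    """Simplified SSIM over the whole image (no sliding Gaussian window)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_a, mu_b = a.mean(), b.mean()
    var_a, var_b = a.var(), b.var()
    cov = ((a - mu_a) * (b - mu_b)).mean()
    return ((2 * mu_a * mu_b + c1) * (2 * cov + c2)) / (
        (mu_a ** 2 + mu_b ** 2 + c1) * (var_a + var_b + c2))

def edge_map(img):
    """Gradient magnitude from forward differences; img is (H, W, C)."""
    dy = np.diff(img, axis=0)[:, :-1]  # (H-1, W-1, C)
    dx = np.diff(img, axis=1)[:-1, :]  # (H-1, W-1, C)
    return np.sqrt(dx ** 2 + dy ** 2 + 1e-12)

def composite_loss(pred, target, l1_w=1.0, ssim_w=0.5, edge_w=0.1):
    """Weighted L1 + (1 - SSIM) + edge L1; the perceptual (VGG) term is omitted."""
    l1 = np.abs(pred - target).mean()
    ssim_term = 1.0 - global_ssim(pred, target)
    edge = np.abs(edge_map(pred) - edge_map(target)).mean()
    return l1_w * l1 + ssim_w * ssim_term + edge_w * edge
```

The default weights mirror `config.l1_weight`, `config.ssim_weight`, and `config.edge_weight`. For identical images every term is zero, and adding noise to the prediction raises all three.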

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss
axes[0].plot(history['train_loss'], label='Train')
axes[0].plot(history['val_loss'], label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True)

# PSNR
axes[1].plot(history['val_psnr'])
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('PSNR (dB)')
axes[1].set_title(f'Validation PSNR (Best: {max(history["val_psnr"]):.2f} dB)')
axes[1].axhline(y=25, color='r', linestyle='--', label='Target (25 dB)')
axes[1].legend()
axes[1].grid(True)

# SSIM
axes[2].plot(history['val_ssim'])
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('SSIM')
axes[2].set_title(f'Validation SSIM (Best: {max(history["val_ssim"]):.4f})')
axes[2].axhline(y=0.85, color='r', linestyle='--', label='Target (0.85)')
axes[2].legend()
axes[2].grid(True)

plt.tight_layout()
plt.savefig(f'{PROJECT_ROOT}/training_history.png', dpi=150)
plt.show()

print(f"\nFinal Results:")
print(f"  Best PSNR: {max(history['val_psnr']):.2f} dB")
print(f"  Best SSIM: {max(history['val_ssim']):.4f}")
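To read the 25 dB target line: PSNR is `10 * log10(MAX**2 / MSE)`, so for images scaled to [0, 1] a PSNR of 25 dB corresponds to an MSE of 10^-2.5 ≈ 3.2e-3, i.e. an RMS error of about 0.056, or roughly 14 gray levels on a 0-255 scale. The sketch below is a standalone NumPy version of that formula for spot checks; it is not the repository's `compute_psnr`, which may operate on tensors or batches.

```python
import numpy as np

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE)."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)
```

A uniform error of 0.1 on a [0, 1] image gives MSE = 0.01 and therefore exactly 20 dB; use `max_val=255.0` for uint8 images.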

## 8. Test Restoration

In [None]:
# Load best model
restorer = Restorer(
    checkpoint_path=f'{config.save_dir}/best_model.pth',
    model_type=config.model_type,
    features=config.features,
    device=device
)

In [None]:
# Test on validation images
from src.inference import compare_images

# Get some test samples
test_degraded, test_clean = next(iter(val_loader))

fig, axes = plt.subplots(3, 4, figsize=(16, 12))

for i in range(4):
    # Get single image
    degraded_np = (test_degraded[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    clean_np = (test_clean[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    
    # Restore (clamp so out-of-range outputs don't wrap when cast to uint8)
    with torch.no_grad():
        restored = restorer.model(test_degraded[i:i+1].to(device)).clamp(0, 1)
    restored_np = (restored[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    
    # Display
    axes[0, i].imshow(degraded_np)
    axes[0, i].set_title('Degraded Input')
    axes[0, i].axis('off')
    
    axes[1, i].imshow(restored_np)
    axes[1, i].set_title('Restored')
    axes[1, i].axis('off')
    
    axes[2, i].imshow(clean_np)
    axes[2, i].set_title('Ground Truth')
    axes[2, i].axis('off')

plt.suptitle('Restoration Results: Degraded → Restored → Ground Truth', fontsize=14)
plt.tight_layout()
plt.savefig(f'{PROJECT_ROOT}/restoration_results.png', dpi=150)
plt.show()
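The cell above converts CHW tensors to displayable images with `(x * 255).astype(np.uint8)`. A network's output is not guaranteed to stay inside [0, 1], and casting out-of-range floats to `uint8` can wrap around or give platform-dependent values instead of saturating. A small helper that always clips before the cast avoids that; `chw_to_uint8` is a hypothetical name, and it expects a NumPy array (for a tensor, call `.cpu().numpy()` first).

```python
import numpy as np

def chw_to_uint8(arr):
    """Convert a float CHW array in [0, 1] to an HWC uint8 image, clipping first."""
    hwc = np.transpose(np.asarray(arr, dtype=np.float32), (1, 2, 0))
    return (np.clip(hwc, 0.0, 1.0) * 255.0).round().astype(np.uint8)
```

For example, `chw_to_uint8(restored[0].cpu().numpy())` yields an image that `plt.imshow` or `Image.fromarray` can consume directly.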

## 9. Save Model to Google Drive

In [None]:
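# Copy best model to Google Drive for later use
import shutil
from google.colab import drive

drive.mount('/content/drive')  # safe to call again if Drive is already mounted (Option 3)
shutil.copy(f'{config.save_dir}/best_model.pth', '/content/drive/MyDrive/kintsugi_best_model.pth')
print("Model saved to Google Drive!")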

## 10. Launch Gradio Interface

In [None]:
# Simple Gradio interface
import gradio as gr

def restore_image(input_image):
    """Restore an uploaded image."""
    if input_image is None:
        return None
    
    # Convert to numpy
    if isinstance(input_image, Image.Image):
        input_image = np.array(input_image)
    
    # Restore
    restored = restorer.restore(input_image, target_size=512)
    
    return restored

demo = gr.Interface(
    fn=restore_image,
    inputs=gr.Image(label="Upload Damaged Image"),
    outputs=gr.Image(label="Restored Image"),
    title="🎌 Kintsugi AI - Image Restoration",
    description="Upload a damaged historical photo and watch AI restore it."
)

demo.launch(share=True)

---

## Summary

This notebook trained a Kintsugi AI model for image restoration with:

- **Architecture:** Attention U-Net with skip connections
- **Training:** Self-supervised with synthetic degradation
- **Loss:** L1 + SSIM + Perceptual (VGG) + Edge
- **Metrics:** Target PSNR > 25 dB, SSIM > 0.85

For better results:
1. Use a larger dataset (the full CelebA-HQ, or domain-specific images)
2. Train for more epochs (100+)
3. Fine-tune on real damaged images if available