# 🎨 Advanced Neural Style Transfer Training
## Training AdaIN, CNN, and ViT Models on Google Colab Free GPU

This notebook trains three state-of-the-art style transfer models:
- **AdaIN**: Fast arbitrary style transfer in a single feed-forward pass (far faster than optimization-based methods)
- **CNN**: VGG-based traditional approach
- **ViT**: Vision Transformer with global context

### Dataset
- **Content**: MS-COCO 2017 (naturalistic images)
- **Style**: WikiArt (diverse artistic styles)

### Hardware Requirements
- Google Colab Free GPU (T4 with ~15GB VRAM)
- Training time: ~3-4 hours per model

---

### Fixes in This Version
- ✅ Added memory cleanup between models
- ✅ Added error handling for checkpoints
- ✅ Added dataset validation
- ✅ Optimized for Colab Free GPU constraints

## 📦 Setup and Installation

In [None]:
# Check GPU availability
!nvidia-smi

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Clone repository
!git clone https://github.com/Ab-Romia/StyleTransferApp.git
%cd StyleTransferApp

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q timm lpips matplotlib seaborn tqdm Pillow tensorboard kaggle

## 📥 Download Datasets

We'll use:
- **WikiArt**: 80K+ artistic images across 27 styles
- **MS-COCO 2017 Train**: 118K naturalistic images

In [None]:
import os
from pathlib import Path
import shutil

# Create data directories
data_dir = Path('data')
data_dir.mkdir(exist_ok=True)

content_dir = data_dir / 'coco'
style_dir = data_dir / 'wikiart'

content_dir.mkdir(exist_ok=True)
style_dir.mkdir(exist_ok=True)

In [None]:
# Download MS-COCO 2017 Train, then subsample it for faster training on Colab Free.
# Note: the full archive (about 18 GB) is downloaded and extracted first;
# only 10K images are kept afterwards to free disk space.

print("Downloading MS-COCO dataset...")
!wget -q http://images.cocodataset.org/zips/train2017.zip -P data/
print("Extracting...")
!unzip -q data/train2017.zip -d data/
!rm data/train2017.zip

# Move images properly
if Path('data/train2017').exists():
    for img in Path('data/train2017').glob('*'):
        shutil.move(str(img), str(content_dir / img.name))
    shutil.rmtree('data/train2017')

# Keep only 10K images for Colab (saves disk space)
import random
coco_images = list(content_dir.glob('*.jpg'))
if len(coco_images) > 10000:
    images_to_remove = random.sample(coco_images, len(coco_images) - 10000)
    for img in images_to_remove:
        img.unlink()
    print(f"Kept 10,000 images for training")
else:
    print(f"Using all {len(coco_images)} images")

# Validate we have enough data (count once and reuse the result)
num_content = len(list(content_dir.glob('*.jpg')))
if num_content < 1000:
    raise RuntimeError(f"Insufficient content images. Need at least 1000, got {num_content}")
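
The subsampling step above uses the global `random` state on an unsorted file list, so a fresh Colab session can keep a different 10K subset. If a repeatable subset matters, one option is to sort the paths and sample with a seeded local generator. The sketch below is only an illustration: `choose_subset` and its `seed` parameter are hypothetical names, not part of the repository.

```python
import random
import tempfile
from pathlib import Path


def choose_subset(image_dir, keep, seed=0):
    """Return `keep` image paths chosen reproducibly from `image_dir`."""
    # Sort first: glob order depends on the filesystem, so sampling from an
    # unsorted list would not be repeatable even with a fixed seed.
    paths = sorted(Path(image_dir).glob('*.jpg'))
    if len(paths) <= keep:
        return paths
    rng = random.Random(seed)  # local generator, leaves the global state alone
    return sorted(rng.sample(paths, keep))


# Demo on a throwaway directory with 20 empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for i in range(20):
        (Path(tmp) / f'{i:04d}.jpg').touch()
    first = [p.name for p in choose_subset(tmp, keep=5, seed=42)]
    second = [p.name for p in choose_subset(tmp, keep=5, seed=42)]
```

Files not returned by `choose_subset` could then be deleted with `unlink()`, exactly as the cell above does.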

In [None]:
# Download WikiArt dataset
# For best results, use Kaggle dataset. Otherwise, we'll use a curated collection.

use_kaggle = True  # Set to False to skip Kaggle and use alternative

if use_kaggle:
    try:
        print("Downloading WikiArt dataset from Kaggle...")
        print("Note: You need to upload your kaggle.json to authenticate")
        print("Get it from: https://www.kaggle.com/settings/account -> Create New API Token")
        
        # Upload kaggle.json
        from google.colab import files
        print("\nPlease upload your kaggle.json file:")
        uploaded = files.upload()
        
        # Setup Kaggle
        !mkdir -p ~/.kaggle
        !cp kaggle.json ~/.kaggle/
        !chmod 600 ~/.kaggle/kaggle.json
        
        # Download WikiArt dataset  
        !kaggle datasets download -d ipythonx/wikiart-gangogh-creating-art-gan -p data/
        !unzip -q data/wikiart-gangogh-creating-art-gan.zip -d data/wikiart/
        !rm data/wikiart-gangogh-creating-art-gan.zip
    except Exception as e:
        print(f"Kaggle download failed: {e}")
        use_kaggle = False

# Fallback: download a small curated collection.
# Shell commands run with `!` do not raise Python exceptions, so the image-count
# check also catches a Kaggle download that failed silently.
if not use_kaggle or len(list(style_dir.glob('**/*.jpg'))) < 50:
    print("\nDownloading curated style collection...")
    
    # Direct download URLs for a few famous artworks (small but high quality)
    style_urls = [
        ("https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg", "starry_night.jpg"),
        ("https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/Edvard_Munch%2C_1893%2C_The_Scream%2C_oil%2C_tempera_and_pastel_on_cardboard%2C_91_x_73_cm%2C_National_Gallery_of_Norway.jpg/800px-Edvard_Munch%2C_1893%2C_The_Scream%2C_oil%2C_tempera_and_pastel_on_cardboard%2C_91_x_73_cm%2C_National_Gallery_of_Norway.jpg", "scream.jpg"),
        ("https://upload.wikimedia.org/wikipedia/commons/thumb/5/5c/Claude_Monet_-_Water_Lilies_-_1916_-_Google_Art_Project.jpg/1280px-Claude_Monet_-_Water_Lilies_-_1916_-_Google_Art_Project.jpg", "water_lilies.jpg"),
    ]
    
    for url, name in style_urls:
        # Quote the URL and path so the shell does not interpret special characters
        !wget -q "{url}" -O "data/wikiart/{name}"
    
    print(f"Downloaded {len(style_urls)} style images")

# Validate style dataset
style_count = len(list(style_dir.glob('**/*.jpg'))) + len(list(style_dir.glob('**/*.png')))
print(f"\n✓ Content images: {len(list(content_dir.glob('*.jpg')))}")
print(f"✓ Style images: {style_count}")

if style_count < 10:
    print("\n⚠️ WARNING: Very few style images. Training quality may be limited.")
    print("   Consider uploading more style images to data/wikiart/")
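
Counting files by extension, as the validation above does, cannot tell a real image from a failed download: `wget -O` can leave a zero-byte file behind, and a server may return an HTML error page that gets saved under an image name. A minimal sketch of a cheap extra check, assuming only JPEG and PNG files matter here, is to look at the leading signature bytes. The helper names `looks_like_image` and `find_bad_images` are hypothetical.

```python
import tempfile
from pathlib import Path

# Leading signature bytes that identify the two formats checked here.
JPEG_MAGIC = b'\xff\xd8\xff'
PNG_MAGIC = b'\x89PNG\r\n\x1a\n'


def looks_like_image(path):
    """Cheap sanity check: the file starts with a JPEG or PNG signature."""
    with open(path, 'rb') as f:
        header = f.read(8)
    return header.startswith(JPEG_MAGIC) or header == PNG_MAGIC


def find_bad_images(image_dir):
    """Return files with an image extension but no valid signature."""
    candidates = [p for p in Path(image_dir).rglob('*')
                  if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
    return sorted(p for p in candidates if not looks_like_image(p))


# Demo with one plausible JPEG header and two typical failure cases.
with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / 'good.jpg'
    good.write_bytes(JPEG_MAGIC + b'\xe0' + b'\x00' * 16)
    empty = Path(tmp) / 'empty.jpg'          # zero-byte file from a failed download
    empty.touch()
    html = Path(tmp) / 'error_page.png'      # HTML error page saved as an image
    html.write_bytes(b'<html>403 Forbidden</html>')
    bad_names = [p.name for p in find_bad_images(tmp)]
```

This only inspects the header, so it will not catch truncated files. Opening each file with Pillow and calling `Image.verify()` is a stricter but slower alternative.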

## 🔍 Data Exploration and Visualization

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import random

def show_images(image_paths, titles=None, n_cols=5, figsize=(15, 3)):
    """Display a grid of images"""
    n_images = len(image_paths)
    n_rows = (n_images + n_cols - 1) // n_cols
    
    # squeeze=False always returns a 2D array of axes, even for a single image
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    
    for idx, img_path in enumerate(image_paths):
        img = Image.open(img_path).convert('RGB')
        axes[idx].imshow(img)
        axes[idx].axis('off')
        if titles:
            axes[idx].set_title(titles[idx], fontsize=10)
    
    # Hide empty subplots
    for idx in range(n_images, len(axes)):
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.show()

# Show sample content images
content_images = list(Path('data/coco').glob('*.jpg'))
sample_content = random.sample(content_images, min(10, len(content_images)))
print("📷 Sample Content Images (MS-COCO):")
show_images(sample_content, n_cols=5, figsize=(15, 6))

# Show sample style images
style_images = list(Path('data/wikiart').glob('**/*.jpg')) + list(Path('data/wikiart').glob('**/*.png'))
sample_styles = random.sample(style_images, min(10, len(style_images)))
print("\n🎨 Sample Style Images (WikiArt):")
show_images(sample_styles, n_cols=5, figsize=(15, 6))

## 🎯 Training Configuration

In [None]:
# Training configuration optimized for Colab Free GPU
CONFIG = {
    # Data
    'content_dir': 'data/coco',
    'style_dir': 'data/wikiart',
    'image_size': 256,  # Reduced for Colab Free memory
    'batch_size': 4,    # Conservative batch size for 15GB GPU
    'num_workers': 2,
    
    # Training
    'num_epochs': 20,   # Reduced for Colab session limits
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'lr_patience': 3,
    
    # Model
    'use_attention': True,
    
    # Loss weights (optimized for quality)
    'content_weight': 1.0,
    'style_weight': 100.0,
    'perceptual_weight': 0.5,
    'lpips_weight': 0.3,  # Reduced to save memory
    'tv_weight': 1e-4,
    'use_lpips': True,
    'use_multiscale': False,  # Disabled to save memory on Colab Free
    
    # Optimization
    'use_amp': True,  # Mixed precision for faster training
    'use_tensorboard': True,
    
    # Logging
    'log_interval': 50,
    'sample_interval': 200,
    'save_interval': 5,
    'keep_checkpoints': 3,
    
    # Validation split
    'val_split': 0.02,  # 2% for validation (saves time)
}

print("Training Configuration:")
print("=" * 50)
for key, value in CONFIG.items():
    print(f"{key:.<30} {value}")
print("=" * 50)
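
The `lr_patience` setting controls how long the `ReduceLROnPlateau` scheduler waits before halving the learning rate. The pure-Python sketch below is a simplified model of that behavior, assuming the scheduler's default relative threshold of `1e-4` and ignoring its cooldown, minimum learning rate, and epsilon options. `simulate_plateau_schedule` is a hypothetical helper for illustration, and the loss values are made up.

```python
def simulate_plateau_schedule(val_losses, lr=1e-4, factor=0.5, patience=3,
                              threshold=1e-4):
    """Return the learning rate in effect after each epoch's validation loss."""
    best = float('inf')
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        # An epoch counts as an improvement only if it beats the best loss
        # by a relative margin of `threshold`.
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # The rate is cut once the run of bad epochs exceeds `patience`.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


# Two improving epochs, then a plateau (illustrative numbers only).
losses = [1.00, 0.90, 0.91, 0.92, 0.90, 0.95, 0.93, 0.94]
schedule = simulate_plateau_schedule(losses)
for epoch, (loss, rate) in enumerate(zip(losses, schedule), start=1):
    print(f"epoch {epoch}: val_loss={loss:.2f} lr={rate:.1e}")
```

With `patience=3`, the rate drops on the fourth consecutive epoch without improvement, so a 20-epoch run leaves room for only a few reductions.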

## 📊 Dataset Preparation

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
from pathlib import Path

class StyleTransferDataset(Dataset):
    """Optimized dataset for style transfer training"""
    
    def __init__(self, content_dir, style_dir, image_size=256, mode='train'):
        self.content_dir = Path(content_dir)
        self.style_dir = Path(style_dir)
        self.image_size = image_size
        self.mode = mode
        
        # Collect image paths
        self.content_images = self._collect_images(self.content_dir)
        self.style_images = self._collect_images(self.style_dir)
        
        if len(self.content_images) == 0:
            raise RuntimeError(f"No content images found in {content_dir}")
        if len(self.style_images) == 0:
            raise RuntimeError(f"No style images found in {style_dir}")
        
        print(f"Loaded {len(self.content_images)} content images")
        print(f"Loaded {len(self.style_images)} style images")
        
        # Transforms
        if mode == 'train':
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
                transforms.ToTensor(),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ])
    
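    def _collect_images(self, directory):
        """Recursively collect image files, matching extensions case-insensitively."""
        valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
        # Comparing lowercased suffixes also matches mixed-case names such as
        # `.Jpg` and avoids duplicates on case-insensitive filesystems.
        images = [p for p in directory.rglob('*')
                  if p.is_file() and p.suffix.lower() in valid_extensions]
        return sorted(images)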
    
    def __len__(self):
        return len(self.content_images)
    
    def __getitem__(self, idx):
        # Load content
        content_path = self.content_images[idx]
        content_img = Image.open(content_path).convert('RGB')
        
        # Random style
        style_idx = np.random.randint(0, len(self.style_images))
        style_path = self.style_images[style_idx]
        style_img = Image.open(style_path).convert('RGB')
        
        # Transform
        content_tensor = self.transform(content_img)
        style_tensor = self.transform(style_img)
        
        # Random alpha for training
        alpha = np.random.uniform(0.5, 1.0) if self.mode == 'train' else 1.0
        
        return {
            'content': content_tensor,
            'style': style_tensor,
            'alpha': torch.tensor(alpha, dtype=torch.float32)
        }

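# Create datasets: one per mode, so validation uses the deterministic
# transforms and fixed alpha instead of the training augmentations
from torch.utils.data import Subset

print("Creating datasets...")
full_dataset = StyleTransferDataset(
    CONFIG['content_dir'],
    CONFIG['style_dir'],
    CONFIG['image_size'],
    mode='train'
)
eval_dataset = StyleTransferDataset(
    CONFIG['content_dir'],
    CONFIG['style_dir'],
    CONFIG['image_size'],
    mode='val'
)

# Split the indices once; both datasets list the same sorted files
val_size = int(len(full_dataset) * CONFIG['val_split'])
train_size = len(full_dataset) - val_size
indices = torch.randperm(len(full_dataset)).tolist()
train_dataset = Subset(full_dataset, indices[val_size:])
val_dataset = Subset(eval_dataset, indices[:val_size])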

print(f"\nTrain: {len(train_dataset)} images")
print(f"Val: {len(val_dataset)} images")

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)

print(f"\n✓ Train batches: {len(train_loader)}")
print(f"✓ Val batches: {len(val_loader)}")
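
As a sanity check on the batch counts printed above, the expected numbers follow directly from the 10K-image COCO subset and the values in `CONFIG`. The arithmetic below also converts the 3-4 hour training estimate into a rough per-step budget. The 3.5 hour figure is an assumed midpoint for illustration, not a measurement, and the estimate ignores time spent on validation and sample saving.

```python
# Values taken from this notebook's CONFIG and the 10K-image COCO subset.
num_images = 10_000
val_split = 0.02
batch_size = 4
num_epochs = 20

val_size = int(num_images * val_split)        # same rounding as the split above
train_size = num_images - val_size
train_batches = train_size // batch_size      # drop_last=True discards a partial batch
val_batches = -(-val_size // batch_size)      # ceiling division, partial batch is kept
total_steps = train_batches * num_epochs

# Assumption for illustration: a 3.5 hour run, the midpoint of the estimate.
seconds_per_step = 3.5 * 3600 / total_steps

print(f"train batches/epoch: {train_batches}, val batches: {val_batches}")
print(f"optimizer steps: {total_steps}, ~{seconds_per_step:.2f} s/step at 3.5 h")
```

If the printed batch counts differ from these, the most likely cause is that fewer than 10,000 content images survived the download step.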

## 🚀 Model Training Functions (FIXED VERSION)

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
import time
from collections import defaultdict
import gc

# Import models
from models.adain_model import AdaINStyleTransfer
from models.cnn_model import StyleTransferModel as CNNModel, VGGFeatures
from models.vit_model import StyleTransferModel as ViTModel
from models.losses import CombinedLoss

class ModelTrainer:
    """Unified trainer for all models (FIXED VERSION)"""
    
    def __init__(self, model, model_name, config, device):
        self.model = model.to(device)
        self.model_name = model_name
        self.config = config
        self.device = device
        
        # Create output directories
        self.output_dir = Path(f'outputs/{model_name}')
        self.checkpoint_dir = self.output_dir / 'checkpoints'
        self.sample_dir = self.output_dir / 'samples'
        
        for d in [self.checkpoint_dir, self.sample_dir]:
            d.mkdir(parents=True, exist_ok=True)
        
        # Loss function setup
        if model_name == 'AdaIN':
            self.criterion = CombinedLoss(
                content_weight=config['content_weight'],
                style_weight=config['style_weight'],
                perceptual_weight=config['perceptual_weight'],
                lpips_weight=config['lpips_weight'],
                tv_weight=config['tv_weight'],
                use_lpips=config['use_lpips']
            ).to(device)
            self.vgg_loss = None
        else:
            # Initialize VGG features for loss computation
            from models.cnn_model import content_loss, style_loss
            self.content_loss_fn = content_loss
            self.style_loss_fn = style_loss
            self.vgg_loss = VGGFeatures()
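            # Move the feature extractor to the training device and freeze it:
            # it is only used to compute losses and is never updated
            self.vgg_loss = self.vgg_loss.to(device)
            self.vgg_loss.eval()
            for p in self.vgg_loss.parameters():
                p.requires_grad_(False)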
            print(f"✓ VGGFeatures initialized for {model_name}")
        
        # Optimizer
        self.optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        
        # Scheduler
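        # `verbose` is deprecated in recent PyTorch releases; the current
        # learning rate is printed after every epoch in train() instead
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=config['lr_patience']
        )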
        
        # Mixed precision
        self.use_amp = config['use_amp']
        if self.use_amp:
            self.scaler = GradScaler()
        
        # Training state
        self.history = defaultdict(list)
        self.best_val_loss = float('inf')
    
    def train_epoch(self, train_loader, epoch):
        """Train for one epoch"""
        self.model.train()
        epoch_loss = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['num_epochs']}")
        
        for batch_idx, batch in enumerate(pbar):
            content = batch['content'].to(self.device)
            style = batch['style'].to(self.device)
            alpha = batch['alpha'].to(self.device)
            
            self.optimizer.zero_grad()
            
            if self.use_amp:
                with autocast():
                    loss = self._compute_loss(content, style, alpha)
                
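                self.scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the true
                # gradients, matching the non-AMP branch below
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()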
            else:
                loss = self._compute_loss(content, style, alpha)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                self.optimizer.step()
            
            epoch_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            
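            # Save samples periodically (sample_interval is counted in images);
            # max(1, ...) avoids a zero divisor when batch_size > sample_interval
            sample_every = max(1, self.config['sample_interval'] // self.config['batch_size'])
            if batch_idx % sample_every == 0:
                self._save_samples(content[:2], style[:2], epoch, batch_idx)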
        
        return epoch_loss / len(train_loader)
    
    @torch.no_grad()
    def validate(self, val_loader):
        """Validate the model"""
        self.model.eval()
        val_loss = 0
        
        for batch in tqdm(val_loader, desc="Validation", leave=False):
            content = batch['content'].to(self.device)
            style = batch['style'].to(self.device)
            alpha = batch['alpha'].to(self.device)
            
            loss = self._compute_loss(content, style, alpha)
            val_loss += loss.item()
        
        return val_loss / len(val_loader)
    
    def _compute_loss(self, content, style, alpha):
        """Compute loss based on model type (FIXED VERSION)"""
        if self.model_name == 'AdaIN':
            output = self.model(content, style, alpha=alpha.mean().item())
            loss, _ = self.criterion(output, content, style, return_components=True)
        else:
            # Format style threshold for batch processing
            batch_size = content.size(0)
            style_threshold = alpha.mean().view(1).expand(batch_size)
            
            output = self.model(content, style, style_threshold)
            
            # Use VGG features for loss computation
            output_content, output_styles = self.vgg_loss(output)
            target_content, target_styles = self.vgg_loss(content)
            _, style_features = self.vgg_loss(style)
            
            c_loss = self.content_loss_fn(output_content, target_content)
            s_loss = self.style_loss_fn(output_styles, style_features)
            
            loss = c_loss + self.config['style_weight'] * s_loss
        
        return loss
    
    def _save_samples(self, content, style, epoch, batch_idx):
        """Save sample outputs"""
        self.model.eval()
        with torch.no_grad():
            if self.model_name == 'AdaIN':
                output = self.model(content, style, alpha=1.0)
            else:
                # Use a fixed style threshold of 0.8 for every sample
                batch_size = content.size(0)
                style_threshold = torch.full((batch_size,), 0.8, device=self.device)
                output = self.model(content, style, style_threshold)
        
        # Create comparison grid
        import torchvision.utils as vutils
        samples = torch.cat([content, style, output.clamp(0, 1)], dim=0)
        grid = vutils.make_grid(samples, nrow=len(content), padding=2)
        
        save_path = self.sample_dir / f'epoch{epoch:03d}_batch{batch_idx:04d}.png'
        vutils.save_image(grid, save_path)
        self.model.train()
    
    def train(self, train_loader, val_loader):
        """Main training loop"""
        print(f"\n{'='*60}")
        print(f"Training {self.model_name} Model")
        print(f"{'='*60}\n")
        
        start_time = time.time()
        
        for epoch in range(self.config['num_epochs']):
            # Train
            train_loss = self.train_epoch(train_loader, epoch)
            self.history['train_loss'].append(train_loss)
            
            # Validate
            val_loss = self.validate(val_loader)
            self.history['val_loss'].append(val_loss)
            
            # Learning rate scheduling
            self.scheduler.step(val_loss)
            
            # Logging
            print(f"\nEpoch {epoch+1}/{self.config['num_epochs']}")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss: {val_loss:.4f}")
            print(f"  LR: {self.optimizer.param_groups[0]['lr']:.2e}")
            
            # Save checkpoint
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
                self._save_checkpoint(epoch, is_best=True)
                print(f"  ✓ Best model saved! (val_loss: {val_loss:.4f})")
            
            if (epoch + 1) % self.config['save_interval'] == 0:
                self._save_checkpoint(epoch, is_best=False)
        
        elapsed = time.time() - start_time
        print(f"\n✓ Training completed in {elapsed/3600:.2f} hours")
        print(f"✓ Best validation loss: {self.best_val_loss:.4f}")
        
        return self.history
    
    def _save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
            'history': dict(self.history)
        }
        
        if is_best:
            path = self.checkpoint_dir / 'best_model.pth'
        else:
            path = self.checkpoint_dir / f'checkpoint_epoch{epoch:03d}.pth'
        
        torch.save(checkpoint, path)

print("✓ Trainer class loaded (FIXED VERSION)")
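
`CONFIG` defines `keep_checkpoints`, but the trainer above never reads it, so periodic checkpoints accumulate on Colab's limited disk. One way to honor that setting is to delete older periodic checkpoints after each save. This is a minimal sketch under that assumption: `prune_checkpoints` is a hypothetical helper, not part of the repository, and it relies on the `checkpoint_epoch{epoch:03d}.pth` naming used in `_save_checkpoint`.

```python
import tempfile
from pathlib import Path


def prune_checkpoints(checkpoint_dir, keep=3):
    """Delete all but the `keep` newest periodic checkpoints; return removed paths."""
    # Zero-padded epoch numbers make lexicographic order match epoch order.
    # The pattern does not match best_model.pth, so that file is never touched.
    periodic = sorted(Path(checkpoint_dir).glob('checkpoint_epoch*.pth'))
    stale = periodic[:-keep] if keep > 0 else periodic
    for path in stale:
        path.unlink()
    return stale


# Demo: the four periodic checkpoints a 20-epoch run with save_interval=5 produces.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for epoch in (4, 9, 14, 19):
        (ckpt_dir / f'checkpoint_epoch{epoch:03d}.pth').touch()
    (ckpt_dir / 'best_model.pth').touch()
    removed = [p.name for p in prune_checkpoints(ckpt_dir, keep=3)]
    remaining = sorted(p.name for p in ckpt_dir.iterdir())
```

Inside the trainer, a call such as `prune_checkpoints(self.checkpoint_dir, self.config['keep_checkpoints'])` could follow the periodic save.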

## 🎯 Train Model 1: AdaIN (Recommended)

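AdaIN stylizes an image in a single feed-forward pass for any style image, which makes it the fastest and most versatile of the three models and a good fit for real-time applications.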
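
For readers new to the technique, the core operation is small: adaptive instance normalization shifts and scales each channel of the content features so that its mean and standard deviation match those of the style features. The sketch below shows that operation on random tensors, plus the `alpha` blend between content and stylized features. It is an illustration only; the repository's `AdaINStyleTransfer` may implement these steps differently, and `adain` and `blend` are hypothetical names.

```python
import torch


def adain(content_feat, style_feat, eps=1e-5):
    """Align the per-channel mean and std of content features to the style's."""
    # Statistics are taken over the spatial dimensions of (N, C, H, W) tensors.
    c_mean = content_feat.mean(dim=(2, 3), keepdim=True)
    c_std = content_feat.std(dim=(2, 3), keepdim=True) + eps
    s_mean = style_feat.mean(dim=(2, 3), keepdim=True)
    s_std = style_feat.std(dim=(2, 3), keepdim=True) + eps
    return s_std * (content_feat - c_mean) / c_std + s_mean


def blend(content_feat, style_feat, alpha=1.0):
    """Interpolate between content features (alpha=0) and AdaIN output (alpha=1)."""
    return alpha * adain(content_feat, style_feat) + (1 - alpha) * content_feat


torch.manual_seed(0)
# Random stand-ins for encoder feature maps (shapes are illustrative only).
content_feat = torch.randn(2, 8, 16, 16)
style_feat = 3.0 * torch.randn(2, 8, 16, 16) + 1.0
stylized = blend(content_feat, style_feat, alpha=1.0)
unchanged = blend(content_feat, style_feat, alpha=0.0)
```

The dataset in this notebook samples `alpha` between 0.5 and 1.0 during training, so the model sees a range of stylization strengths.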

In [None]:
# Initialize AdaIN model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on: {device}")

adain_model = AdaINStyleTransfer(use_attention=CONFIG['use_attention'])
adain_trainer = ModelTrainer(adain_model, 'AdaIN', CONFIG, device)

# Train
adain_history = adain_trainer.train(train_loader, val_loader)

# Save final model
torch.save(adain_model.state_dict(), 'outputs/AdaIN/adain_final.pth')
print("\n✓ AdaIN model training complete!")

# Clean up GPU memory
del adain_model, adain_trainer
gc.collect()              # drop unreachable Python objects first
torch.cuda.empty_cache()  # then release the freed blocks held by the CUDA cache
print("✓ Memory cleaned")

## 🎯 Train Model 2: CNN

Traditional VGG-based approach with style intensity control.
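
In VGG-based style transfer, style is usually compared through Gram matrices: channel-by-channel correlations of a feature map, which capture texture while discarding spatial layout. The sketch below shows that common formulation on random tensors. Whether the repository's `style_loss` uses exactly this normalization is an assumption, and `gram_matrix` and `gram_style_loss` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def gram_matrix(features):
    """Channel-by-channel correlations of an (N, C, H, W) feature map."""
    n, c, h, w = features.shape
    flat = features.reshape(n, c, h * w)
    # Normalizing by C*H*W keeps the scale comparable across layers.
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)


def gram_style_loss(output_feats, style_feats):
    """Sum of MSE distances between Gram matrices, one term per layer."""
    return sum(F.mse_loss(gram_matrix(o), gram_matrix(s))
               for o, s in zip(output_feats, style_feats))


torch.manual_seed(0)
# Stand-ins for feature maps from two layers (shapes are illustrative only).
style_feats = [torch.randn(1, 4, 8, 8), torch.randn(1, 8, 4, 4)]
other_feats = [torch.randn(1, 4, 8, 8), torch.randn(1, 8, 4, 4)]
same_loss = gram_style_loss(style_feats, style_feats)
diff_loss = gram_style_loss(other_feats, style_feats)
gram = gram_matrix(style_feats[0])
```

Gram-based losses tend to be numerically small, which is one reason a large `style_weight` such as the 100.0 in `CONFIG` is typical.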

In [None]:
# Initialize CNN model
cnn_model = CNNModel()
cnn_trainer = ModelTrainer(cnn_model, 'CNN', CONFIG, device)

# Train
cnn_history = cnn_trainer.train(train_loader, val_loader)

# Save final model
torch.save(cnn_model.state_dict(), 'outputs/CNN/cnn_final.pth')
print("\n✓ CNN model training complete!")

# Clean up GPU memory
del cnn_model, cnn_trainer
gc.collect()              # drop unreachable Python objects first
torch.cuda.empty_cache()  # then release the freed blocks held by the CUDA cache
print("✓ Memory cleaned")

## 🎯 Train Model 3: ViT (Vision Transformer)

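The ViT model uses self-attention to capture global context across the whole image. It is the most advanced of the three and is expected to give the best quality, at the cost of slower training and inference.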
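
The speed cost comes from self-attention comparing every image patch with every other patch. The arithmetic below gives a feel for the scale at this notebook's 256-pixel image size. The patch size of 16 is an assumption (a common ViT choice) and is not read from the repository's `vit_model`.

```python
image_size = 256      # CONFIG['image_size'] in this notebook
patch_size = 16       # assumption: a common ViT patch size, not read from the repo

patches_per_side = image_size // patch_size
num_tokens = patches_per_side ** 2
attention_entries = num_tokens ** 2   # one score per token pair, per head

# Doubling the image side quadruples the tokens and multiplies the scores by 16.
tokens_at_512 = (512 // patch_size) ** 2
growth = tokens_at_512 ** 2 // attention_entries

print(f"{num_tokens} tokens, {attention_entries:,} attention scores per head")
print(f"at 512 px: {tokens_at_512} tokens, {growth}x more attention scores")
```

This quadratic growth is one reason to keep `image_size` at 256 on a Colab Free GPU.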

In [None]:
# Initialize ViT model
vit_model = ViTModel()
vit_trainer = ModelTrainer(vit_model, 'ViT', CONFIG, device)

# Train
vit_history = vit_trainer.train(train_loader, val_loader)

# Save final model
torch.save(vit_model.state_dict(), 'outputs/ViT/vit_final.pth')
print("\n✓ ViT model training complete!")

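# Clean up GPU memory (the models are reloaded from checkpoints for evaluation)
del vit_model, vit_trainer
gc.collect()
torch.cuda.empty_cache()
print("✓ Memory cleaned")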

## 📊 Load Models for Evaluation

In [None]:
# Load checkpoints safely

def load_model_safely(model, checkpoint_path):
    """Load model with error handling"""
    if not Path(checkpoint_path).exists():
        print(f"⚠️ Warning: {checkpoint_path} not found, keeping the freshly initialized (untrained) weights")
        return False
    
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✓ Loaded checkpoint from {checkpoint_path}")
        return True
    except Exception as e:
        print(f"⚠️ Error loading checkpoint: {e}")
        return False

# Reload models for evaluation
adain_model = AdaINStyleTransfer(use_attention=CONFIG['use_attention']).to(device)
cnn_model = CNNModel().to(device)
vit_model = ViTModel().to(device)

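# Record which checkpoints actually loaded so the summary below is accurate
load_status = {
    'AdaIN': load_model_safely(adain_model, 'outputs/AdaIN/checkpoints/best_model.pth'),
    'CNN': load_model_safely(cnn_model, 'outputs/CNN/checkpoints/best_model.pth'),
    'ViT': load_model_safely(vit_model, 'outputs/ViT/checkpoints/best_model.pth'),
}

adain_model.eval()
cnn_model.eval()
vit_model.eval()

models_dict = {
    'AdaIN': adain_model,
    'CNN': cnn_model,
    'ViT': vit_model
}

missing = [name for name, ok in load_status.items() if not ok]
if missing:
    print(f"\n⚠️ No trained checkpoint loaded for: {', '.join(missing)}")
else:
    print("\n✓ All models loaded for evaluation")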

## 📊 Performance Visualizations

*(Continue with the existing visualization cells from the original notebook)*