# Task 4 Fine-Tuning Implementation

This notebook contains the fine-tuning implementation for generating new class-conditional images of the selected damage type. It is designed to run on Google Colab.

## Step 1: Upload Dataset

**Option A: Upload via Colab**

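Run the cell below and upload a `.zip` archive of the `bottle` folder from your local machine. The archive is extracted into the notebook's working directory, so it must itself contain the path (with an `image/` subfolder holding one directory per damage type):
`Defect_Spectrum/DS-MVTec/bottle/`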

In [None]:
# Option A: Direct upload
import os
from google.colab import files
import shutil

# Create the directory the uploaded archive is expected to populate.
image_root = 'Defect_Spectrum/DS-MVTec/bottle/image'
os.makedirs(image_root, exist_ok=True)

print("Please upload your bottle dataset as a zip file...")
print("Create a zip of: Defect_Spectrum/DS-MVTec/bottle/")
print("(Should contain image/ subfolder with damage type subdirectories)")

uploaded = files.upload()
for filename in uploaded:
    # Archives are unpacked relative to the working directory, so the zip must
    # contain the Defect_Spectrum/... prefix itself. The format is passed
    # explicitly because shutil's extension lookup is case-sensitive (e.g. .ZIP).
    if filename.lower().endswith('.zip'):
        shutil.unpack_archive(filename, '.', format='zip')
    else:
        print(f"Skipping {filename}: expected a .zip archive")

# Warn early if extraction did not produce any damage-type folders.
damage_dirs = sorted(
    name for name in os.listdir(image_root)
    if os.path.isdir(os.path.join(image_root, name))
)
if damage_dirs:
    print(f"Found damage-type folders: {damage_dirs}")
else:
    print(f"WARNING: no damage-type folders under {image_root}; check the zip's internal paths.")

## Step 2: Install Dependencies

In [None]:
!pip install -q diffusers transformers accelerate peft safetensors
print("Dependencies installed")

## Step 3: Dataset Class Definition

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, Optional, List
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF

@dataclass(frozen=True)
class LocalDatasetConfig:
    """Configuration for local DefectSpectrum dataset loading.

    Consumed by ``DefectSpectrumLocalDataset``. For the list-valued filter
    fields, ``None`` (or an empty list) disables filtering at that directory level.
    """
    resolution: int = 512  # images are resized to resolution x resolution
    augment: bool = True  # enable random flip / rotation augmentation
    seed: int = 0  # base seed for the augmentation RNG
    normalize_to_neg1_pos1: bool = True  # map [0, 1] tensors to [-1, 1]
    
    load_masks: bool = False  # NOTE(review): not read by the dataset class in this notebook
    dataset_sources: Optional[List[str]] = None  # top-level folder names to keep, e.g. ["DS-MVTec"]
    product_classes: Optional[List[str]] = None  # product folder names to keep, e.g. ["bottle"]
    damage_types: Optional[List[str]] = None  # folder names under <product>/image/ to keep
    max_samples_per_damage_type: Optional[int] = None  # per-folder cap; a falsy value means no cap
    damage_type_to_class_id: Optional[Dict[str, int]] = None  # damage-type folder name -> integer class id


class DefectSpectrumLocalDataset(Dataset[Dict[str, Any]]):
    """PyTorch Dataset over a local DefectSpectrum folder tree.

    Expected layout::

        <root_dir>/<dataset_source>/<product_class>/image/<damage_type>/<file>.{png,jpg,jpeg}

    Each item is a dict with ``pixel_values`` (a 3 x resolution x resolution float
    tensor), ``class_id`` (int, or None when the damage type is missing from
    ``cfg.damage_type_to_class_id`` or no mapping is configured) and ``meta``
    describing the source file.
    """

    _IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir: str, cfg: "LocalDatasetConfig"):
        self.root_dir = Path(root_dir)
        self.cfg = cfg
        self.samples = self._scan_directory()

    @staticmethod
    def _subdirs(parent: Path) -> List[Path]:
        """Immediate subdirectories of ``parent``, sorted by path for a stable order."""
        return sorted(child for child in parent.iterdir() if child.is_dir())

    def _scan_directory(self) -> List[Dict[str, Any]]:
        """Collect metadata for every image that passes the config filters.

        Directories and files are visited in sorted order, so sample indices are
        identical across machines and filesystems.
        """
        cfg = self.cfg
        samples: List[Dict[str, Any]] = []

        for dataset_dir in self._subdirs(self.root_dir):
            dataset_source = dataset_dir.name
            if cfg.dataset_sources and dataset_source not in cfg.dataset_sources:
                continue

            for product_dir in self._subdirs(dataset_dir):
                product_class = product_dir.name
                if cfg.product_classes and product_class not in cfg.product_classes:
                    continue

                image_dir = product_dir / "image"
                if not image_dir.is_dir():
                    continue

                for damage_type_dir in self._subdirs(image_dir):
                    damage_type = damage_type_dir.name
                    if cfg.damage_types and damage_type not in cfg.damage_types:
                        continue

                    damage_samples = [
                        {
                            "image_path": str(img_path),
                            "dataset_source": dataset_source,
                            "product_class": product_class,
                            "damage_type": damage_type,
                            "filename": img_path.name,
                        }
                        for img_path in sorted(damage_type_dir.iterdir())
                        if img_path.is_file() and img_path.suffix.lower() in self._IMAGE_SUFFIXES
                    ]

                    if cfg.max_samples_per_damage_type:
                        damage_samples = damage_samples[:cfg.max_samples_per_damage_type]

                    samples.extend(damage_samples)

        return samples

    def _augmentation_rng(self, index: int) -> np.random.RandomState:
        """RNG for one augmentation draw of sample ``index``.

        Seeding from ``cfg.seed`` and ``index`` alone would give each sample the
        same flip/rotation in every epoch. A draw from torch's RNG is mixed in, so
        augmentations differ between epochs. DataLoader workers reseed torch from
        the main process each epoch, which keeps runs reproducible once torch is seeded.
        """
        epoch_entropy = int(torch.randint(0, 2**31 - 1, (1,)).item())
        return np.random.RandomState([self.cfg.seed, index, epoch_entropy])

    def _apply_transforms(self, img: "Image.Image", index: int) -> "Image.Image":
        """Convert to RGB, resize to a square ``cfg.resolution``, optionally augment."""
        # Non-RGB inputs (including grayscale photos) are converted without any
        # intensity rescaling, so their pixel values are preserved.
        if img.mode != "RGB":
            img = img.convert("RGB")

        img = TF.resize(img, [self.cfg.resolution, self.cfg.resolution],
                        interpolation=TF.InterpolationMode.BICUBIC)

        if self.cfg.augment:
            rng = self._augmentation_rng(index)

            # Horizontal flip (50% chance)
            if rng.rand() < 0.5:
                img = TF.hflip(img)

            # Random rotation within ±15 degrees (50% chance)
            if rng.rand() < 0.5:
                angle = rng.uniform(-15, 15)
                img = TF.rotate(img, angle, interpolation=TF.InterpolationMode.BILINEAR)

        return img

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        sample_meta = self.samples[index]

        # The context manager releases the file handle once pixels are decoded.
        with Image.open(sample_meta["image_path"]) as raw_img:
            original_size = raw_img.size
            original_mode = raw_img.mode
            pixel_values = TF.to_tensor(self._apply_transforms(raw_img, index=index))

        if self.cfg.normalize_to_neg1_pos1:
            pixel_values = pixel_values * 2.0 - 1.0

        class_id = None
        mapping = self.cfg.damage_type_to_class_id
        if mapping is not None:
            class_id = mapping.get(sample_meta["damage_type"], None)

        return {
            "pixel_values": pixel_values,
            "class_id": class_id,
            "meta": {
                "index": index,
                "dataset_source": sample_meta["dataset_source"],
                "product_class": sample_meta["product_class"],
                "damage_type": sample_meta["damage_type"],
                "filename": sample_meta["filename"],
                "original_size": original_size,
                "original_mode": original_mode,
            },
        }

print("✓ Dataset class defined")

## Step 4: Training Class Definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
import numpy as np
from PIL import Image
import warnings
import json
import csv
import matplotlib.pyplot as plt
from datetime import datetime

# NOTE(review): this silences every warning for the whole session, including
# library deprecation notices; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')


class DefectDiffusionTrainer:
    """LoRA fine-tuning of Stable Diffusion's UNet for class-conditional defect images.

    Conditioning uses a learned ``nn.Embedding``. It maps each damage-type id to a
    single 768-d token, which is passed to the UNet as ``encoder_hidden_states``;
    no text encoder is loaded. Only the UNet LoRA adapters and the class embedding
    are trained, and the VAE stays frozen.

    Learning-rate schedule: the LR is constant when ``num_epochs <= 100``.
    Otherwise it is set to 5e-5 at (0-based) epoch 100 and to 2e-5 at epoch 175.
    """

    # (0-based epoch, new learning rate) pairs; applied only when num_epochs > 100.
    _LR_DECAY_STEPS = ((100, 5e-5), (175, 2e-5))
    # Validation images and periodic checkpoints are written every N epochs.
    _EVAL_EVERY = 25

    def __init__(
        self,
        output_dir: str = "./outputs",
        num_epochs: int = 100,
        batch_size: int = 4,
        learning_rate: float = 1e-4,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Store hyperparameters and create ``output_dir``; models/data are set up later.

        Args:
            output_dir: Directory for checkpoints, validation images and logs.
            num_epochs: Number of passes over the dataset.
            batch_size: Images per optimisation step.
            learning_rate: Initial AdamW learning rate.
            lora_rank: LoRA rank ``r`` for the UNet attention projections.
            lora_alpha: LoRA alpha; effective scaling is ``lora_alpha / lora_rank``.
            device: Torch device string; defaults to CUDA when available.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.device = device

        # Mean training loss per epoch, appended by train().
        self.loss_history = []

        # Human-readable description of the schedule applied in train().
        if num_epochs <= 100:
            lr_schedule = f"Constant {learning_rate} (augmentation enhancement)"
        else:
            lr_schedule = f"{learning_rate} (0-99), 5e-5 (100-174), 2e-5 (175+)"

        self.config = {
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "lora_rank": lora_rank,
            "lora_alpha": lora_alpha,
            "device": str(device),
            "start_time": datetime.now().isoformat(),
            "lr_decay_schedule": lr_schedule
        }

        print("=" * 70)
        print("INITIALIZING DEFECT DIFFUSION TRAINER V3")
        print("=" * 70)
        print(f"Output directory: {self.output_dir}")
        print(f"Device: {self.device}")
        print(f"Epochs: {num_epochs}")
        print(f"Batch size: {batch_size}")
        print(f"Learning rate: {lr_schedule}")
        print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}")

    def setup_models(self, model_id: str = "runwayml/stable-diffusion-v1-5"):
        """Load the frozen VAE, wrap the UNet with LoRA, create class embeddings and the scheduler.

        Args:
            model_id: Hugging Face Hub id (or local path) of a Stable Diffusion v1.x
                checkpoint with ``vae`` and ``unet`` subfolders.
                NOTE(review): the ``runwayml`` repo has reportedly been removed from
                the Hub; if loading fails, try
                ``"stable-diffusion-v1-5/stable-diffusion-v1-5"``.
        """
        print("\n" + "=" * 70)
        print("STEP 1: Loading Models")
        print("=" * 70)

        self.config["model_id"] = model_id

        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch.float32
        ).to(self.device)
        self.vae.requires_grad_(False)
        self.vae.eval()
        print(f"VAE loaded (frozen)")

        print("\n Loading UNet...")
        self.unet = UNet2DConditionModel.from_pretrained(
            model_id,
            subfolder="unet",
            torch_dtype=torch.float32
        )

        print("Adding LoRA adapters...")
        # LoRA on the attention query/key/value/output projections only.
        lora_config = LoraConfig(
            r=self.lora_rank,
            lora_alpha=self.lora_alpha,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            lora_dropout=0.1,
            bias="none",
            init_lora_weights=True
        )

        self.config.update({
            "lora_target_modules": list(lora_config.target_modules),
            "lora_dropout": lora_config.lora_dropout,
            "lora_scaling_factor": lora_config.lora_alpha / lora_config.r,
        })

        self.unet = get_peft_model(self.unet, lora_config)
        self.unet.to(self.device)

        trainable_params = sum(p.numel() for p in self.unet.parameters() if p.requires_grad)
        print(f"UNet loaded with LoRA ({trainable_params/1e6:.2f}M trainable params)")

        print("\nCreating class embeddings...")
        # 4 entries must cover every id in the damage-type map built in setup_data
        # (ids 0-3); 768 matches the width the UNet's cross-attention expects here.
        self.class_embedder = nn.Embedding(
            num_embeddings=4,
            embedding_dim=768
        ).to(self.device)
        nn.init.normal_(self.class_embedder.weight, mean=0.0, std=0.02)

        class_params = self.class_embedder.weight.numel()
        print(f"Class embeddings created ({class_params} params)")

        self.config.update({
            "unet_lora_params": trainable_params,
            "class_embedding_params": class_params,
            "total_trainable_params": trainable_params + class_params,
        })

        print("\nSetting up noise scheduler...")
        self.scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False
        )
        print(f"DDPM scheduler ready (1000 timesteps)")

        print(f"\nTotal trainable parameters: {(trainable_params + class_params)/1e6:.2f}M")

    def _class_name(self, class_id: int) -> str:
        """Reverse lookup of the damage-type name for ``class_id``."""
        return next(name for name, cid in self.damage_type_to_class_id.items() if cid == class_id)

    def _count_classes(self, samples: List[Dict[str, Any]]) -> Dict[int, int]:
        """Count samples per class id from their ``damage_type`` metadata only.

        Every mapped class id is present in the result, with 0 if it has no samples.

        Raises:
            ValueError: if a sample's damage type has no entry in the mapping.
        """
        counts = {class_id: 0 for class_id in self.damage_type_to_class_id.values()}
        for sample in samples:
            damage_type = sample["damage_type"]
            if damage_type not in self.damage_type_to_class_id:
                raise ValueError(
                    f"Damage type {damage_type!r} has no class id; "
                    f"known types: {sorted(self.damage_type_to_class_id)}"
                )
            counts[self.damage_type_to_class_id[damage_type]] += 1
        return counts

    def setup_data(self):
        """Build the DS-MVTec ``bottle`` dataset and its shuffled DataLoader.

        Per-class counts come from the scanned file metadata, so no image is
        decoded just to report the class distribution.

        Raises:
            ValueError: if a scanned damage-type folder is missing from
                ``damage_type_to_class_id``. Its samples would get ``class_id=None``
                and could not be batched.
        """
        print("\n" + "=" * 70)
        print("STEP 2: Loading Dataset")
        print("=" * 70)

        dataset_root = Path("./Defect_Spectrum")

        self.damage_type_to_class_id = {
            "broken_large": 0,
            "broken_small": 1,
            "contamination": 2,
            "good": 3
        }

        config = LocalDatasetConfig(
            resolution=512,
            augment=True,
            seed=42,
            normalize_to_neg1_pos1=True,
            load_masks=False,
            dataset_sources=["DS-MVTec"],
            product_classes=["bottle"],
            damage_types=None,
            damage_type_to_class_id=self.damage_type_to_class_id
        )

        self.dataset = DefectSpectrumLocalDataset(
            root_dir=str(dataset_root),
            cfg=config
        )

        print(f"Dataset loaded: {len(self.dataset)} samples")

        class_counts = self._count_classes(self.dataset.samples)

        print("\nSamples per class:")
        for class_id, count in sorted(class_counts.items()):
            print(f"  Class {class_id} ({self._class_name(class_id):15s}): {count} samples")

        self.config.update({
            "dataset_size": len(self.dataset),
            "dataset_sources": ["DS-MVTec"],
            "product_classes": ["bottle"],
            "resolution": 512,
            "augmentation": True,
            "damage_types": list(self.damage_type_to_class_id.keys()),
            "class_distribution": {str(class_id): count for class_id, count in class_counts.items()},
        })

        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=self.device == "cuda"
        )

        print(f"\nDataLoader created: {len(self.dataloader)} batches per epoch")

    def setup_optimizer(self):
        """Create one AdamW optimizer over the LoRA weights and the class embedding."""
        print("\n" + "=" * 70)
        print("STEP 3: Setting Up Optimizer")
        print("=" * 70)

        trainable_params = list(self.unet.parameters()) + list(self.class_embedder.parameters())
        trainable_params = [p for p in trainable_params if p.requires_grad]

        self.optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=0.01
        )

        self.config.update({
            "optimizer": "AdamW",
            "optimizer_betas": [0.9, 0.999],
            "optimizer_weight_decay": 0.01,
        })

        print(f" AdamW optimizer configured")
        print(f" Learning rate: {self.learning_rate}")
        print(f" Trainable parameters: {sum(p.numel() for p in trainable_params)/1e6:.2f}M")

    def train_epoch(self, epoch: int):
        """Run one pass over the dataloader and return the mean per-batch MSE loss.

        The objective is noise prediction. Images are encoded with the frozen VAE
        and noised at uniformly sampled timesteps. The UNet, conditioned on the
        class-embedding token, is then regressed onto the added noise.
        """
        self.unet.train()
        self.class_embedder.train()

        epoch_loss = 0.0
        progress_bar = tqdm(self.dataloader, desc=f"Epoch {epoch+1}/{self.num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            images = batch["pixel_values"].to(self.device)
            class_ids = batch["class_id"].to(self.device)

            # (B, 768) -> (B, 1, 768): a one-token "sequence" for cross-attention.
            class_embeddings = self.class_embedder(class_ids)
            class_embeddings = class_embeddings.unsqueeze(1)

            with torch.no_grad():
                latents = self.vae.encode(images).latent_dist.sample()
                # SD v1 VAE latent scaling factor (undone in validate()).
                latents = latents * 0.18215

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, self.scheduler.config.num_train_timesteps,
                (images.shape[0],),
                device=self.device
            ).long()

            noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

            noise_pred = self.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=class_embeddings
            ).sample

            loss = F.mse_loss(noise_pred, noise)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

        avg_loss = epoch_loss / len(self.dataloader)
        return avg_loss

    @torch.no_grad()
    def validate(self, epoch: int):
        """Sample 5 images per class with a 50-step DDPM loop and save them as PNGs.

        Each class uses its own ``torch.Generator``, seeded with
        ``42 + class_id + epoch``. This keeps validation reproducible without
        reseeding the global RNG that drives training noise, timesteps and
        shuffling. Files are named ``epoch{NNN}_class{id}_{damage_type}_{i}.png``.
        """
        print(f"\nGenerating validation samples (epoch {epoch+1})...")

        self.unet.eval()
        self.class_embedder.eval()

        num_samples_per_class = 5
        self.scheduler.set_timesteps(50)

        for class_id in sorted(self.damage_type_to_class_id.values()):
            class_ids = torch.full((num_samples_per_class,), class_id, device=self.device)
            class_embeddings = self.class_embedder(class_ids).unsqueeze(1)

            generator = torch.Generator(device=self.device).manual_seed(42 + class_id + epoch)
            # 4 x 64 x 64 latents decode to 512 x 512 images with the SD VAE.
            latents = torch.randn(
                num_samples_per_class, 4, 64, 64,
                generator=generator,
                device=self.device
            )

            for t in self.scheduler.timesteps:
                timesteps = torch.full((num_samples_per_class,), t, device=self.device).long()

                noise_pred = self.unet(
                    latents,
                    timesteps,
                    encoder_hidden_states=class_embeddings
                ).sample

                latents = self.scheduler.step(noise_pred, t, latents, generator=generator).prev_sample

            latents = latents / 0.18215
            images = self.vae.decode(latents).sample

            # [-1, 1] -> [0, 255] uint8, channels-last for PIL.
            images = (images / 2 + 0.5).clamp(0, 1)
            images = images.cpu().permute(0, 2, 3, 1).numpy()
            images = (images * 255).astype(np.uint8)

            damage_type = self._class_name(class_id)

            for i, img in enumerate(images):
                img_pil = Image.fromarray(img)
                save_path = self.output_dir / f"epoch{epoch+1:03d}_class{class_id}_{damage_type}_{i}.png"
                img_pil.save(save_path)

        print(f" Validation samples saved to {self.output_dir}")

    def save_checkpoint(self, epoch: int):
        """Save model, class-embedding and optimizer state for ``epoch`` (0-based).

        NOTE(review): ``self.unet.state_dict()`` on the PEFT-wrapped UNet includes
        the frozen base weights, so each checkpoint is large. Saving only the LoRA
        weights would shrink it, but would change the checkpoint format that
        downstream loaders may rely on.
        """
        checkpoint_path = self.output_dir / f"checkpoint_epoch{epoch+1:03d}.pt"

        torch.save({
            "epoch": epoch,
            "unet_state_dict": self.unet.state_dict(),
            "class_embedder_state_dict": self.class_embedder.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "damage_type_to_class_id": self.damage_type_to_class_id,
            "loss_history": self.loss_history,
            "config": self.config,
        }, checkpoint_path)

        print(f" Checkpoint saved: {checkpoint_path}")

    def save_config(self):
        """Add loss summary and end time to ``self.config`` and write it as JSON."""
        config_path = self.output_dir / "training_config.json"

        self.config.update({
            "total_epochs_trained": len(self.loss_history),
            "final_loss": self.loss_history[-1] if self.loss_history else None,
            "min_loss": min(self.loss_history) if self.loss_history else None,
            "max_loss": max(self.loss_history) if self.loss_history else None,
            "end_time": datetime.now().isoformat(),
        })

        with open(config_path, 'w') as f:
            json.dump(self.config, f, indent=2)

        print(f" Config saved: {config_path}")

    def save_loss_history(self):
        """Write per-epoch losses to ``loss_history.csv`` and plot them to ``loss_curve.png``."""
        csv_path = self.output_dir / "loss_history.csv"
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Epoch', 'Loss'])
            for epoch, loss in enumerate(self.loss_history, 1):
                writer.writerow([epoch, loss])

        print(f" Loss history saved: {csv_path}")

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(range(1, len(self.loss_history) + 1), self.loss_history,
                marker='o', linewidth=2, markersize=4, color='#2E86AB')
        ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
        ax.set_ylabel('Average Loss (MSE)', fontsize=12, fontweight='bold')
        ax.set_title('Training Loss Over Time', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        plot_path = self.output_dir / "loss_curve.png"
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

        print(f" Loss plot saved: {plot_path}")

    def _lr_decay_at(self, epoch: int) -> Optional[float]:
        """Return the LR to switch to at the start of ``epoch``, or None to keep the current one."""
        if self.num_epochs <= 100:
            return None
        for start_epoch, new_lr in self._LR_DECAY_STEPS:
            if epoch == start_epoch:
                return new_lr
        return None

    def train(self):
        """Run all epochs, then write the final checkpoint, loss history and config.

        Every ``_EVAL_EVERY`` epochs, validation images and a checkpoint are
        written. The final checkpoint is saved separately only when the last
        epoch did not already produce one.
        """
        print("\n" + "=" * 70)
        print("STEP 4: Training")
        print("=" * 70)

        last_checkpoint_epoch = None
        for epoch in range(self.num_epochs):
            new_lr = self._lr_decay_at(epoch)
            if new_lr is not None:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = new_lr
                print(f" Learning rate reduced to {new_lr:g} (epoch {epoch})")

            avg_loss = self.train_epoch(epoch)
            self.loss_history.append(avg_loss)
            print(f"\nEpoch {epoch+1}/{self.num_epochs} - Average Loss: {avg_loss:.4f}")

            if (epoch + 1) % self._EVAL_EVERY == 0:
                self.validate(epoch)
                self.save_checkpoint(epoch)
                last_checkpoint_epoch = epoch

        # Avoid rewriting the same multi-GB file when the last epoch already saved it.
        final_epoch = self.num_epochs - 1
        if final_epoch >= 0 and last_checkpoint_epoch != final_epoch:
            self.save_checkpoint(final_epoch)

        # Save training artifacts
        self.save_loss_history()
        self.save_config()

        print("\n" + "=" * 70)
        print("TRAINING COMPLETE!")
        print("=" * 70)
        if self.loss_history:
            min_loss = min(self.loss_history)
            print(f"Final loss: {self.loss_history[-1]:.4f}")
            print(f"Min loss: {min_loss:.4f} at epoch {self.loss_history.index(min_loss)+1}")
        print(f"Outputs saved to: {self.output_dir}")

print("Trainer class defined with LR decay")

## Step 5: Run Training

In [None]:
# Check GPU availability
import torch
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    # Report the first visible device and the CUDA version torch was built against.
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"CUDA version: {torch.version.cuda}")

In [None]:
# Initialize the trainer (LoRA fine-tuning of pretrained SD v1.5 with LR decay).
trainer = DefectDiffusionTrainer(
    output_dir="./outputs/bottle_diffusion",
    num_epochs=250,
    batch_size=4,        # 4 for T4 GPU
    learning_rate=1e-4,
    lora_rank=16,        # Use rank=16 for better capacity
    lora_alpha=32,       # Maintain 2.0 scaling factor
)

# Setup everything
trainer.setup_models()
trainer.setup_data()
trainer.setup_optimizer()

# Run training; writes validation images, checkpoints, loss history and config
# to output_dir, which the following cells read.
trainer.train()

## Step 6: View Results

Inspect the generated validation samples and the training loss curve.

In [None]:
# Display the loss curve and the most recent validation samples.
# IPython's Image is imported under an alias: a bare `Image` would rebind the PIL
# `Image` that the dataset and trainer code use, breaking them if they run again.
from pathlib import Path
from IPython.display import Image as IPyImage, display

results_dir = Path('outputs/bottle_diffusion')

# Show loss curve
loss_curve_path = results_dir / 'loss_curve.png'
if loss_curve_path.exists():
    print("Training Loss Curve:")
    display(IPyImage(filename=str(loss_curve_path)))

# Show validation images from the latest epoch that has any (5 per class).
# Names look like epoch{NNN}_class...; zero-padding makes the string max the
# latest epoch while epochs stay below 1000.
validation_pngs = sorted(results_dir.glob('epoch*.png'))
if validation_pngs:
    latest_prefix = max(p.name.split('_')[0] for p in validation_pngs)
    images = [p for p in validation_pngs if p.name.startswith(latest_prefix + '_')]
    print(f"\nLatest validation samples ({latest_prefix}, {len(images)} images):")
    print("Showing first 20 (5 per class × 4 classes):")
    for img_path in images[:20]:
        print(img_path.name)
        display(IPyImage(filename=str(img_path), width=250))
else:
    print("No validation images found yet; run training first.")

In [None]:
# View training config
import json

from pathlib import Path

config_path = Path('outputs/bottle_diffusion/training_config.json')
if config_path.exists():
    config = json.loads(config_path.read_text())

    print("Training Configuration:")
    print("=" * 50)
    for key, value in config.items():
        print(f"{key:30s}: {value}")
else:
    print(f"No training config found at {config_path}; run training first.")

## Step 7: Download Results

Download only the essentials: the final checkpoint, the training config, the loss data, and the validation images.

In [None]:
from google.colab import files

import zipfile
from pathlib import Path

results_dir = Path('outputs/bottle_diffusion')

# Download the latest checkpoint rather than assuming exactly 250 epochs.
# Zero-padded epoch numbers make the sorted order chronological.
# NOTE(review): checkpoints include the full UNet state dict plus optimizer state,
# so the file is large; browser downloads from Colab may be slow or fail.
checkpoints = sorted(results_dir.glob('checkpoint_epoch*.pt'))
if not checkpoints:
    raise FileNotFoundError(f"No checkpoint found in {results_dir}; run training first.")
final_checkpoint = checkpoints[-1]
print("Downloading final checkpoint...")
files.download(str(final_checkpoint))

# Download training artifacts
print("\nDownloading training config and loss data...")
for artifact in ('training_config.json', 'loss_history.csv', 'loss_curve.png'):
    files.download(str(results_dir / artifact))

# Zip validation images with the stdlib instead of shelling out to `zip`.
print("\nPreparing validation images...")
with zipfile.ZipFile('validation_images.zip', 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    for png_path in sorted(results_dir.glob('epoch*.png')):
        archive.write(png_path)
files.download('validation_images.zip')

print("\n Downloads complete:")
print(f"  - {final_checkpoint.name} (final model)")
print("  - training_config.json (hyperparameters + LR schedule)")
print("  - loss_history.csv (per-epoch loss)")
print("  - loss_curve.png (loss visualization)")
print("  - validation_images.zip (all validation images - 5 per class)")