# 🎲 GAN Data Augmentation (COCO Format)

This notebook trains a **Conditional DCGAN** and generates synthetic full-scene images with COCO annotations.
The output is directly compatible with `DiceDetectionDataset` used in `3_augmentation_comparison.ipynb`.

**Key Features:**
- Uses local balanced annotations (`Annotations/train_image_balanced.coco.json`), a Zipfian-balanced subset of the training images
- Downloads images from Roboflow
- Trains a conditional DCGAN on dice crops
- Generates full-scene images with dice placed on backgrounds
- Outputs COCO-format `_annotations.coco.json` file

## 1. Setup & Installation

In [None]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install roboflow pillow matplotlib seaborn tqdm numpy

In [None]:
# Clone the repository on first run, then move into it
import os
if not os.path.exists('Dice-Detection'):
    !git clone https://github.com/Adr44mo/Dice-Detection.git
%cd Dice-Detection

# Add src to path
import sys
sys.path.append('./src')

## 2. Import Libraries

In [None]:
import json
import random
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils

# Import custom modules from src
from src.gan import (
    Generator,
    Discriminator,
    weights_init,
    generate_dice_image,
    extract_backgrounds,
    create_synthetic_coco_dataset
)
from src.dataset import DiceDetectionDataset  # For verification

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 3. Download Dataset from Roboflow

In [None]:
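import os
from roboflow import Roboflow

# Read the Roboflow API key from an environment variable (set ROBOFLOW_API_KEY before running)
# instead of hard-coding it in the notebook
rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])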
project = rf.workspace("workspace-spezm").project("dice-0sexk")
dataset = project.version(2).download("coco")
print(f"Dataset downloaded to: {dataset.location}")

## 4. Load Balanced Annotations

We use the local balanced annotations (`train_image_balanced.coco.json`), which describe a Zipfian-balanced subset of the training images; the image files themselves come from the Roboflow download.

In [None]:
# Load balanced annotations from local file
BALANCED_ANNOTATION_FILE = 'Annotations/train_image_balanced.coco.json'
IMAGE_BASE_PATH = f'{dataset.location}/train'

with open(BALANCED_ANNOTATION_FILE, 'r') as f:
    balanced_annotations = json.load(f)

print(f"Loaded balanced annotations from: {BALANCED_ANNOTATION_FILE}")
print(f"Images: {len(balanced_annotations['images'])}")
print(f"Annotations: {len(balanced_annotations['annotations'])}")
print(f"Categories: {len(balanced_annotations['categories'])}")

# Extract categories (only numeric dice classes)
categories = {cat['id']: cat['name'] for cat in balanced_annotations['categories']}
valid_categories = {k: v for k, v in categories.items() if v.isdigit()}
print(f"\nValid dice categories: {valid_categories}")

# Create image lookup
image_id_to_info = {
    img['id']: {'file_name': img['file_name'], 'width': img['width'], 'height': img['height']}
    for img in balanced_annotations['images']
}
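Because the annotations come from a local file while the images come from the Roboflow download, a file name present in one but not the other only surfaces later, as an error inside the cropping loop. A small check can catch this up front. The helper name `find_missing_images` is hypothetical and not part of `src`:

```python
import os

def find_missing_images(coco: dict, image_dir: str) -> list:
    """Return the file names listed in coco['images'] that do not exist in image_dir."""
    missing = []
    for img in coco["images"]:
        path = os.path.join(image_dir, img["file_name"])
        if not os.path.isfile(path):
            missing.append(img["file_name"])
    return missing
```

In the notebook it would be called as `find_missing_images(balanced_annotations, IMAGE_BASE_PATH)`; an empty list means every annotated image was downloaded.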

## 5. Extract Dice Crops

Crop each die from the downloaded images using the bounding boxes in the balanced annotations, then resize every crop to 64×64 pixels (`IMG_SIZE`).

In [None]:
IMG_SIZE = 64
OUTPUT_DIR = 'gan_training_data'

os.makedirs(OUTPUT_DIR, exist_ok=True)
for cat_name in valid_categories.values():
    os.makedirs(os.path.join(OUTPUT_DIR, cat_name), exist_ok=True)

class_counts = Counter()
print("\nCropping dice images from balanced dataset...")

for ann in tqdm(balanced_annotations['annotations']):
    category_id = ann['category_id']
    if category_id not in valid_categories:
        continue
    
    image_id = ann['image_id']
    bbox = ann['bbox']
    category_name = valid_categories[category_id]
    image_info = image_id_to_info.get(image_id)
    
    if not image_info:
        continue
    
    image_path = os.path.join(IMAGE_BASE_PATH, image_info['file_name'])
    
    try:
        img = Image.open(image_path).convert('RGB')
        x_min, y_min, width, height = [int(b) for b in bbox]
        x_max, y_max = x_min + width, y_min + height
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(img.width, x_max), min(img.height, y_max)
        
        cropped = img.crop((x_min, y_min, x_max, y_max))
        resized = cropped.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
        
        output_filename = f"{image_id}_{ann['id']}.png"
        output_path = os.path.join(OUTPUT_DIR, category_name, output_filename)
        resized.save(output_path)
        class_counts[category_name] += 1
    except Exception as e:
        print(f"Error processing {image_path}: {e}")

print("\n✅ Cropping complete!")
for cat in sorted(class_counts.keys()):
    print(f"  Class {cat}: {class_counts[cat]} images")
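COCO stores boxes as `[x, y, width, height]` with `(x, y)` the top-left corner, whereas `PIL.Image.crop` expects `(left, upper, right, lower)`. The conversion and clipping done inline in the loop above can be isolated as a pure function, which also makes it easy to skip boxes that become empty after clipping. The name `coco_bbox_to_crop_box` is illustrative:

```python
def coco_bbox_to_crop_box(bbox, img_width, img_height):
    """Convert a COCO [x, y, width, height] box to a PIL (left, upper, right, lower)
    crop box, truncating to ints and clipping to the image bounds.

    Returns None if the clipped box is empty.
    """
    x_min, y_min, width, height = (int(v) for v in bbox)
    x_max, y_max = x_min + width, y_min + height
    # Clip to the image so crops never extend past the borders
    x_min, y_min = max(0, x_min), max(0, y_min)
    x_max, y_max = min(img_width, x_max), min(img_height, y_max)
    if x_max <= x_min or y_max <= y_min:
        return None
    return (x_min, y_min, x_max, y_max)
```

For example, a box `[-5, -5, 20, 20]` in a 100×100 image is clipped to `(0, 0, 15, 15)`, and a box lying entirely outside the image returns `None`.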

## 6. GAN Architecture & Hyperparameters

In [None]:
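# Hyperparameters
LATENT_DIM = 100   # size of the noise vector z
NUM_CLASSES = 6    # dice faces 1-6
EMBED_DIM = 50     # size of the class-label embedding
NGF = 64           # base number of generator feature maps
NDF = 64           # base number of discriminator feature maps
NC = 3             # image channels (RGB)
BATCH_SIZE = 32
NUM_EPOCHS = 200
LR = 0.0002        # Adam learning rate
BETA1 = 0.5        # Adam beta1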

print(f"Latent dimension: {LATENT_DIM}")
print(f"Number of classes: {NUM_CLASSES}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Number of epochs: {NUM_EPOCHS}")

In [None]:
# Initialize models from src.gan
netG = Generator(LATENT_DIM, NUM_CLASSES, EMBED_DIM, NGF, NC).to(device)
netD = Discriminator(NUM_CLASSES, NDF, NC, IMG_SIZE).to(device)

# Apply weight initialization
netG.apply(weights_init)
netD.apply(weights_init)

print(f"Generator params: {sum(p.numel() for p in netG.parameters()):,}")
print(f"Discriminator params: {sum(p.numel() for p in netD.parameters()):,}")
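`src.gan` is not shown in this notebook, but with `IMG_SIZE = 64` a DCGAN-style generator typically upsamples a 1×1 latent input through a stack of transposed convolutions. The sketch below applies the `ConvTranspose2d` output-size formula from the PyTorch documentation to the layer layout of the reference DCGAN (kernel 4; stride 1 and padding 0 first, then stride 2 and padding 1). That `src.gan` uses exactly this layout, and that it concatenates the label embedding with the noise vector, are assumptions:

```python
def conv_transpose_out(size, kernel, stride, padding, output_padding=0, dilation=1):
    """Output height/width of a ConvTranspose2d layer (formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Assumed DCGAN generator stack: project 1x1 to 4x4, then four stride-2 upsampling layers
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1
sizes = []
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)
print(sizes)

# Assuming noise and label embedding are concatenated, the first layer's input channels:
LATENT_DIM, EMBED_DIM = 100, 50
print(LATENT_DIM + EMBED_DIM)
```

Each stride-2 layer doubles the spatial size, so five layers take the input from 1×1 to 64×64, matching `IMG_SIZE`.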

## 7. Prepare Dataset & DataLoader

In [None]:
class DiceDataset(Dataset):
    """Dataset for loading cropped dice images for GAN training."""
    
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        
        for label_idx, label_name in enumerate(['1', '2', '3', '4', '5', '6']):
            label_dir = os.path.join(root_dir, label_name)
            if os.path.exists(label_dir):
                for img_name in os.listdir(label_dir):
                    if img_name.endswith(('.png', '.jpg', '.jpeg')):
                        self.samples.append((os.path.join(label_dir, img_name), label_idx))
        print(f"Loaded {len(self.samples)} samples")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Data transforms
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

dice_dataset = DiceDataset(OUTPUT_DIR, transform=transform)
dataloader = DataLoader(dice_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)

print(f"Number of batches: {len(dataloader)}")
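Three details of this pipeline matter later. First, `DiceDataset` assigns 0-based class indices (folder `'1'` becomes label 0), so the same offset applies when conditioning the generator. Second, `Normalize([0.5]*3, [0.5]*3)` rescales pixels from [0, 1] to [-1, 1], presumably to match a `Tanh` output in `src.gan` (an assumption). Third, `drop_last=True` discards the final partial batch. The helpers below are illustrative and not part of `src`:

```python
LABELS = ['1', '2', '3', '4', '5', '6']

def label_index(face: str) -> int:
    """Map a dice face folder name ('1'..'6') to the 0-based index DiceDataset uses."""
    return LABELS.index(face)

def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """What transforms.Normalize does per channel: maps [0, 1] to [-1, 1]."""
    return (x - mean) / std

def denormalize(y: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping, e.g. to display generator output that lies in [-1, 1]."""
    return y * std + mean

def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields per epoch."""
    full, rest = divmod(num_samples, batch_size)
    return full if drop_last or rest == 0 else full + 1
```

With 100 crops and `BATCH_SIZE = 32`, for instance, `drop_last=True` yields 3 batches per epoch instead of 4.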

## 8. Training

In [None]:
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.999))

G_losses, D_losses = [], []

print("Starting Training...")
for epoch in range(NUM_EPOCHS):
    epoch_D_loss, epoch_G_loss = 0, 0
    
    for real_imgs, labels in dataloader:
        real_imgs, labels = real_imgs.to(device), labels.to(device)
        batch_size = real_imgs.size(0)
        
        # Smoothed targets: 0.9 for real images, 0.1 for fake images
        real_label = torch.full((batch_size,), 0.9, device=device)
        fake_label = torch.full((batch_size,), 0.1, device=device)
        
        # Train Discriminator
        netD.zero_grad()
        output_real = netD(real_imgs, labels)
        errD_real = criterion(output_real, real_label)
        errD_real.backward()
        
        noise = torch.randn(batch_size, LATENT_DIM, device=device)
        fake_imgs = netG(noise, labels)
        output_fake = netD(fake_imgs.detach(), labels)
        errD_fake = criterion(output_fake, fake_label)
        errD_fake.backward()
        optimizerD.step()
        
        # Train Generator (fake_imgs is not detached here, so gradients flow back into netG)
        netG.zero_grad()
        output = netD(fake_imgs, labels)
        errG = criterion(output, real_label)
        errG.backward()
        optimizerG.step()
        
        epoch_D_loss += (errD_real + errD_fake).item()
        epoch_G_loss += errG.item()
    
    G_losses.append(epoch_G_loss / len(dataloader))
    D_losses.append(epoch_D_loss / len(dataloader))
    
    if (epoch + 1) % 20 == 0 or epoch == 0:
        print(f"[{epoch+1:3d}/{NUM_EPOCHS}] Loss_D: {D_losses[-1]:.4f} | Loss_G: {G_losses[-1]:.4f}")

print("\n✅ Training complete!")
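The loop above uses smoothed targets: 0.9 for real images and 0.1 for fake ones, and the generator is also trained toward 0.9. `nn.BCELoss` computes `-(y*log(p) + (1-y)*log(1-p))` per sample, which for a fixed target `y` is minimised at `p = y`. This small pure-Python check (not part of the notebook) shows why smoothing discourages an over-confident discriminator:

```python
import math

def bce(p: float, y: float) -> float:
    """Binary cross-entropy for one prediction p in (0, 1) and target y, as in nn.BCELoss."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# With a smoothed "real" target of 0.9, pushing D's output from 0.9 to 0.99
# increases the loss instead of decreasing it.
for p in (0.5, 0.9, 0.99):
    print(f"p={p:.2f}  loss={bce(p, 0.9):.4f}")
```

A common variant smooths only the real targets and keeps the fake target at 0; the notebook smooths both.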

In [None]:
# Plot training losses
plt.figure(figsize=(10, 4))
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('training_losses.png', dpi=150)
plt.show()

## 9. Generate COCO-Format Dataset

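Compose full-scene images by placing generated dice on backgrounds taken from the training images, and write matching COCO annotations, using `extract_backgrounds` and `create_synthetic_coco_dataset` from `src.gan`.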

In [None]:
# Extract background images from downloaded dataset
BACKGROUND_DIR = 'backgrounds'
train_images_dir = f'{dataset.location}/train'

print("Extracting background samples from training images...")
bg_count = extract_backgrounds(train_images_dir, BACKGROUND_DIR, num_backgrounds=50)
print(f"Extracted {bg_count} background images")

In [None]:
# Configuration for synthetic dataset generation
SYNTHETIC_COCO_DIR = 'synthetic_coco_dataset'

# Per-class crop counts; the largest class sets the balancing target
current_counts = {str(i): class_counts.get(str(i), 0) for i in range(1, 7)}
target_count = max(current_counts.values())
# Extra dice (not scenes) each class would need to reach the target
images_to_generate = {k: max(0, target_count - v) for k, v in current_counts.items()}

print(f"Current class distribution: {current_counts}")
print(f"Target count per class: {target_count}")
# Heuristic scene count: half the total deficit (each scene can hold several dice) plus 50 extra scenes
total_synthetic_images = sum(images_to_generate.values()) // 2 + 50
print(f"Will generate approximately {total_synthetic_images} full scene images")

# Generation config
gen_config = {
    'scene_size': (640, 640),
    'dice_size_range': (60, 120),
    'dice_per_image': (1, 4),
    'num_images': total_synthetic_images
}

# Generate synthetic COCO dataset
coco_data = create_synthetic_coco_dataset(
    generator=netG,
    background_dir=BACKGROUND_DIR,
    output_dir=SYNTHETIC_COCO_DIR,
    config=gen_config,
    device=device,
    class_counts=current_counts,
    latent_dim=LATENT_DIM
)
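To make the sizing heuristic concrete, here is the same arithmetic on made-up class counts: the largest class sets the target, each class's deficit is its gap to that target, and the number of scenes is half the total deficit plus 50. Whether the generated dice actually close each per-class gap depends on how `create_synthetic_coco_dataset` uses `class_counts`, which is not shown in this notebook. The function name is illustrative:

```python
def plan_synthetic_scenes(class_counts: dict, extra_scenes: int = 50) -> tuple:
    """Reproduce the notebook's sizing heuristic.

    Returns (target_count, per-class deficit, number of scenes to generate).
    """
    current = {str(i): class_counts.get(str(i), 0) for i in range(1, 7)}
    target = max(current.values())
    deficit = {k: max(0, target - v) for k, v in current.items()}
    num_scenes = sum(deficit.values()) // 2 + extra_scenes
    return target, deficit, num_scenes

# Hypothetical crop counts, for illustration only
counts = {'1': 40, '2': 55, '3': 60, '4': 38, '5': 47, '6': 52}
target, deficit, num_scenes = plan_synthetic_scenes(counts)
print(target)      # 60
print(deficit)     # {'1': 20, '2': 5, '3': 0, '4': 22, '5': 13, '6': 8}
print(num_scenes)  # 84
```

The total deficit here is 68 dice, so the heuristic asks for 68 // 2 + 50 = 84 scenes.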

In [None]:
# Visualize sample generated scenes
import matplotlib.patches as patches

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

sample_images = random.sample(coco_data['images'], min(6, len(coco_data['images'])))

for ax, img_info in zip(axes, sample_images):
    img_path = os.path.join(SYNTHETIC_COCO_DIR, 'train', img_info['file_name'])
    img = Image.open(img_path)
    ax.imshow(img)
    
    # Draw bounding boxes
    for ann in coco_data['annotations']:
        if ann['image_id'] == img_info['id']:
            x, y, w, h = ann['bbox']
            rect = patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='lime', facecolor='none')
            ax.add_patch(rect)
            ax.text(x, y - 5, f"Dice {ann['category_id']}", color='lime', fontsize=10, weight='bold')
    
    ax.axis('off')
    ax.set_title(img_info['file_name'])

plt.suptitle('Sample Generated Scenes with COCO Annotations', fontsize=14)
plt.tight_layout()
plt.savefig('sample_coco_scenes.png', dpi=150)
plt.show()
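The internals of `create_synthetic_coco_dataset` are not shown in this notebook. As a rough, hypothetical sketch of what such a compositor has to do, the code below picks non-overlapping square boxes within the `gen_config` limits (640×640 scene, 60 to 120 px dice, 1 to 4 dice per scene) and turns them into COCO annotation entries. Pasting the generated crops onto a background is left out, and the function names are illustrative rather than taken from `src.gan`:

```python
import random

def place_boxes(scene_size, size_range, n_boxes, rng, max_tries=100):
    """Place up to n_boxes non-overlapping square boxes inside a scene.

    Returns COCO-style [x, y, w, h] boxes. A box that cannot be placed
    without overlap after max_tries attempts is skipped.
    """
    width, height = scene_size
    boxes = []
    for _ in range(n_boxes):
        for _ in range(max_tries):
            s = rng.randint(*size_range)
            x = rng.randint(0, width - s)
            y = rng.randint(0, height - s)
            # Axis-aligned rectangles overlap iff their x and y intervals both intersect
            overlaps = any(
                x < bx + bw and bx < x + s and y < by + bh and by < y + s
                for bx, by, bw, bh in boxes
            )
            if not overlaps:
                boxes.append([x, y, s, s])
                break
    return boxes

def coco_annotations(image_id, boxes, category_ids, start_id):
    """Build COCO annotation dicts for one scene."""
    return [
        {
            "id": start_id + i,
            "image_id": image_id,
            "category_id": cat,
            "bbox": box,
            "area": box[2] * box[3],
            "iscrowd": 0,
        }
        for i, (box, cat) in enumerate(zip(boxes, category_ids))
    ]

rng = random.Random(42)
boxes = place_boxes((640, 640), (60, 120), rng.randint(1, 4), rng)
anns = coco_annotations(image_id=0, boxes=boxes,
                        category_ids=[rng.randint(1, 6) for _ in boxes], start_id=0)
print(anns)
```

If the category IDs follow the face values 1 to 6, note that `DiceDataset` trained the generator with 0-based labels, so the conditioning label for face k would be k - 1.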

## 10. Verify Compatibility with DiceDetectionDataset

In [None]:
# Test loading with DiceDetectionDataset from src
try:
    synthetic_dataset = DiceDetectionDataset(
        root_dir=os.path.join(SYNTHETIC_COCO_DIR, 'train'),
        annotation_file='_annotations.coco.json',
        split='train'
    )
    
    print("✅ Successfully loaded synthetic dataset!")
    print(f"   Number of images: {len(synthetic_dataset)}")
    print(f"   Number of classes: {synthetic_dataset.num_classes}")
    print(f"   Class distribution: {synthetic_dataset.get_class_distribution()}")
    
    # Test getting an item
    image, target = synthetic_dataset[0]
    print(f"\n   Sample image shape: {image.shape}")
    print(f"   Sample boxes: {target['boxes'].shape}")
    print(f"   Sample labels: {target['labels']}")
    
except Exception as e:
    print(f"Could not test with DiceDetectionDataset: {e}")
    print("Compatibility could not be confirmed; check the images and _annotations.coco.json in the output directory.")
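Loading through `DiceDetectionDataset` confirms the files parse, but not that every box is sensible. A small structural check on the COCO dictionary can complement it. The helper `check_coco_consistency` is hypothetical and assumes the standard COCO layout with `images`, `annotations` and `categories` keys:

```python
def check_coco_consistency(coco: dict) -> list:
    """Return human-readable problems found in a COCO dict (empty list if none)."""
    problems = []
    images = {img["id"]: img for img in coco["images"]}
    category_ids = {cat["id"] for cat in coco["categories"]}
    for ann in coco["annotations"]:
        img = images.get(ann["image_id"])
        if img is None:
            problems.append(f"annotation {ann['id']}: unknown image_id {ann['image_id']}")
            continue
        if ann["category_id"] not in category_ids:
            problems.append(f"annotation {ann['id']}: unknown category_id {ann['category_id']}")
        x, y, w, h = ann["bbox"]
        # COCO boxes are [x, y, width, height] and must lie inside the image
        if w <= 0 or h <= 0 or x < 0 or y < 0 or x + w > img["width"] or y + h > img["height"]:
            problems.append(f"annotation {ann['id']}: bbox {ann['bbox']} outside image {img['id']}")
    return problems
```

Calling `check_coco_consistency(coco_data)` and getting an empty list means every annotation points at a known image and category and its box lies within the image.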

## 11. Save Model

In [None]:
MODEL_DIR = 'gan_models'
os.makedirs(MODEL_DIR, exist_ok=True)

torch.save({
    'generator_state_dict': netG.state_dict(),
    'discriminator_state_dict': netD.state_dict(),
    'epoch': NUM_EPOCHS,
    'G_losses': G_losses,
    'D_losses': D_losses,
}, os.path.join(MODEL_DIR, 'conditional_dcgan_dice.pth'))

print(f"✅ Model saved to {MODEL_DIR}/conditional_dcgan_dice.pth")

## 🚀 Summary

This notebook:
1. ✅ Downloaded images from Roboflow
2. ✅ Used local balanced annotations (`Annotations/train_image_balanced.coco.json`), a Zipfian-balanced subset of the training images
3. ✅ Extracted dice crops from the balanced subset
4. ✅ Trained a conditional DCGAN on the dice crops
5. ✅ Generated a synthetic COCO dataset of full-scene images

**Output files:**
- `synthetic_coco_dataset/train/` - Images with dice on backgrounds
- `synthetic_coco_dataset/train/_annotations.coco.json` - COCO annotations
- `gan_models/conditional_dcgan_dice.pth` - Trained GAN

**Usage in other notebooks:**
```python
from src.dataset import DiceDetectionDataset

synthetic_train = DiceDetectionDataset(
    root_dir='synthetic_coco_dataset/train',
    annotation_file='_annotations.coco.json',
    split='train'
)
```