# 🎲 GAN Data Augmentation (COCO Format)

This notebook trains a **Conditional DCGAN** and generates synthetic full-scene images with COCO annotations.
The output is directly compatible with `DiceDetectionDataset` used in `3_augmentation_comparison.ipynb`.

**Output format:**
- Full scene images with dice placed on backgrounds
- COCO-format `_annotations.coco.json` file

## 1. Setup & Installation

In [None]:
!pip install roboflow torchvision matplotlib tqdm pillow --quiet

In [None]:
import os
import json
import random
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils

# Seed every random number generator used in this notebook with the same value.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")
if use_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Download Dataset & Extract Dice Crops

In [None]:
from roboflow import Roboflow

# The Roboflow API key is a credential and must not live in the notebook source.
# Read it from the environment; a key that was previously committed should be
# revoked and regenerated in the Roboflow dashboard.
api_key = os.environ.get("ROBOFLOW_API_KEY")
if not api_key:
    raise RuntimeError(
        "Roboflow API key not found. Set the ROBOFLOW_API_KEY environment "
        "variable before running this cell."
    )

# Download version 2 of the dice project in COCO format; `dataset.location`
# is the local directory the export was written to.
rf = Roboflow(api_key=api_key)
project = rf.workspace("workspace-spezm").project("dice-0sexk")
dataset = project.version(2).download("coco")
print(f"Dataset downloaded to: {dataset.location}")

In [None]:
# Side length (pixels) of the square crops the GAN is trained on.
IMG_SIZE = 64
# Crops are written to OUTPUT_DIR/<class name>/<image_id>_<annotation_id>.png
OUTPUT_DIR = 'gan_training_data'
ANNOTATION_FILE = f'{dataset.location}/train/_annotations.coco.json'
IMAGE_BASE_PATH = f'{dataset.location}/train'

with open(ANNOTATION_FILE, 'r', encoding='utf-8') as f:
    annotations = json.load(f)

# Keep only categories whose name is purely numeric (the die faces); any other
# category in the export is ignored.
categories = {cat['id']: cat['name'] for cat in annotations['categories']}
valid_categories = {k: v for k, v in categories.items() if v.isdigit()}
print(f"Categories: {valid_categories}")

image_id_to_info = {
    img['id']: {'file_name': img['file_name'], 'width': img['width'], 'height': img['height']}
    for img in annotations['images']
}

os.makedirs(OUTPUT_DIR, exist_ok=True)
for cat_name in valid_categories.values():
    os.makedirs(os.path.join(OUTPUT_DIR, cat_name), exist_ok=True)

# Group the usable annotations by image so every source image is opened (and
# closed) exactly once instead of once per bounding box.
annotations_by_image = {}
for ann in annotations['annotations']:
    if ann['category_id'] in valid_categories and ann['image_id'] in image_id_to_info:
        annotations_by_image.setdefault(ann['image_id'], []).append(ann)

class_counts = Counter()
skipped_boxes = 0
print("\nCropping dice images...")
for image_id, image_annotations in tqdm(annotations_by_image.items()):
    image_path = os.path.join(IMAGE_BASE_PATH, image_id_to_info[image_id]['file_name'])
    try:
        # convert() returns an independent in-memory copy, so the file handle
        # can be released as soon as the `with` block ends.
        with Image.open(image_path) as source:
            img = source.convert('RGB')
    except OSError as e:
        print(f"Error processing {image_path}: {e}")
        continue

    for ann in image_annotations:
        category_name = valid_categories[ann['category_id']]
        try:
            # COCO boxes are [x, y, width, height]; clamp them to the image.
            x_min, y_min, width, height = [int(b) for b in ann['bbox']]
            x_max, y_max = x_min + width, y_min + height
            x_min, y_min = max(0, x_min), max(0, y_min)
            x_max, y_max = min(img.width, x_max), min(img.height, y_max)

            # A box that is empty after clamping cannot be cropped or resized.
            if x_max <= x_min or y_max <= y_min:
                skipped_boxes += 1
                continue

            cropped = img.crop((x_min, y_min, x_max, y_max))
            resized = cropped.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)

            output_filename = f"{image_id}_{ann['id']}.png"
            output_path = os.path.join(OUTPUT_DIR, category_name, output_filename)
            resized.save(output_path)
        except (OSError, ValueError) as e:
            print(f"Error processing {image_path}: {e}")
            continue
        class_counts[category_name] += 1

print("\n✅ Cropping complete!")
if skipped_boxes:
    print(f"  Skipped {skipped_boxes} empty or out-of-bounds boxes")
for cat in sorted(class_counts.keys()):
    print(f"  Class {cat}: {class_counts[cat]} images")

## 3. GAN Architecture

In [None]:
# Hyperparameters shared by the network definitions and the training loop.
LATENT_DIM = 100   # length of the random noise vector fed to the generator
NUM_CLASSES = 6    # number of conditioning labels (rows of the label embeddings)
EMBED_DIM = 50     # width of the generator's label embedding
NGF = 64           # base channel width of the generator
NDF = 64           # base channel width of the discriminator
NC = 3             # image channels (training crops are loaded as RGB)
BATCH_SIZE = 32    # batch size used by the training DataLoader
NUM_EPOCHS = 200   # number of passes over the training crops
LR = 0.0002        # Adam learning rate for both optimizers
BETA1 = 0.5        # Adam beta1 for both optimizers (beta2 is fixed at 0.999)

In [None]:
class Generator(nn.Module):
    """Conditional DCGAN generator.

    A class label is embedded, concatenated with the noise vector and treated
    as a 1x1 feature map. Five transposed convolutions then upsample it
    1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels, ending in a Tanh so the output lies
    in [-1, 1].

    Args:
        latent_dim: Length of the noise vector.
        num_classes: Number of distinct labels.
        embed_dim: Width of the label embedding.
        ngf: Base channel width; hidden stages use 8x, 4x, 2x and 1x of it.
        nc: Number of output image channels.
    """

    def __init__(self, latent_dim, num_classes, embed_dim, ngf, nc):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embed_dim)

        # Channel widths of the hidden stages, from the 1x1 input onwards.
        widths = [latent_dim + embed_dim, ngf * 8, ngf * 4, ngf * 2, ngf]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first stage expands 1x1 -> 4x4; later stages double the size.
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsampling step to `nc` channels, squashed into [-1, 1].
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate one image per (noise row, label) pair.

        Args:
            noise: Tensor of shape (batch, latent_dim).
            labels: Integer tensor of shape (batch,) indexing the embedding.

        Returns:
            Tensor of shape (batch, nc, 64, 64) with values in [-1, 1].
        """
        conditioned = torch.cat([noise, self.label_embedding(labels)], dim=1)
        as_feature_map = conditioned.view(conditioned.size(0), -1, 1, 1)
        return self.main(as_feature_map)

In [None]:
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    The label is embedded into an img_size x img_size plane and stacked onto
    the image as an extra channel. Four stride-2 convolutions halve the
    resolution each time and a final 4x4 convolution plus Sigmoid yields a
    probability. With img_size=64 this reduces to a single score per image.

    Args:
        num_classes: Number of distinct labels.
        ndf: Base channel width; stages use 1x, 2x, 4x and 8x of it.
        nc: Number of image channels (the label plane adds one more).
        img_size: Side length of the square input images.
    """

    def __init__(self, num_classes, ndf, nc, img_size):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)
        self.img_size = img_size

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        # The first stage has no batch norm.
        layers = [nn.Conv2d(nc + 1, ndf, 4, 2, 1, bias=False), leaky()]
        for multiplier in (1, 2, 4):
            c_in = ndf * multiplier
            c_out = c_in * 2
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                leaky(),
            ]
        # Collapse the remaining 4x4 map into one probability.
        layers += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score images as real or fake given their labels.

        Args:
            img: Tensor of shape (batch, nc, img_size, img_size).
            labels: Integer tensor of shape (batch,).

        Returns:
            Flat tensor of probabilities in [0, 1].
        """
        side = self.img_size
        label_plane = self.label_embedding(labels).view(labels.size(0), 1, side, side)
        scores = self.main(torch.cat([img, label_plane], dim=1))
        return scores.view(-1)

## 4. Dataset & DataLoader

In [None]:
class DiceDataset(Dataset):
    """Image-folder dataset of dice crops, one sub-directory per die face.

    Expects ``root_dir/<face>/<image>`` with ``<face>`` in ``CLASS_NAMES``.
    The label of a sample is the index of its face in ``CLASS_NAMES``
    (folder '1' -> 0, ..., folder '6' -> 5). Missing class folders are skipped.

    Args:
        root_dir: Directory containing the per-class sub-directories.
        transform: Optional callable applied to each loaded PIL image.
    """

    CLASS_NAMES = ('1', '2', '3', '4', '5', '6')
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []  # list of (image path, label index) pairs

        for label_idx, label_name in enumerate(self.CLASS_NAMES):
            label_dir = os.path.join(root_dir, label_name)
            if not os.path.isdir(label_dir):
                continue
            # os.listdir order is arbitrary; sorting keeps the sample order,
            # and therefore seeded shuffling, reproducible across machines.
            for img_name in sorted(os.listdir(label_dir)):
                # Match extensions case-insensitively so '.JPG' etc. are kept.
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(label_dir, img_name), label_idx))
        print(f"Loaded {len(self.samples)} samples")

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied.
        """
        img_path, label = self.samples[idx]
        # convert() copies the pixels, so the file can be closed right away.
        with Image.open(img_path) as source:
            image = source.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Geometric sizing first, then light augmentation, then tensor conversion.
sizing_steps = [
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
]
augmentation_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
]
tensor_steps = [
    transforms.ToTensor(),
    # Per-channel mean/std of 0.5 rescales pixels to match the generator's Tanh range.
    transforms.Normalize([0.5] * 3, [0.5] * 3),
]
transform = transforms.Compose(sizing_steps + augmentation_steps + tensor_steps)

dice_dataset = DiceDataset(OUTPUT_DIR, transform=transform)
# drop_last=True discards a trailing partial batch so every batch has BATCH_SIZE items.
dataloader = DataLoader(
    dice_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

## 5. Training

In [None]:
def weights_init(m):
    """Initialise one module in place; meant to be passed to ``Module.apply``.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights from N(1, 0.02)
    and zero biases. Every other module is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Build both networks and move them to the selected device. The statement
# order is significant: module construction and weights_init both draw from
# the seeded torch RNG, so reordering these lines changes the initial weights.
netG = Generator(LATENT_DIM, NUM_CLASSES, EMBED_DIM, NGF, NC).to(device)
netD = Discriminator(NUM_CLASSES, NDF, NC, IMG_SIZE).to(device)
# Re-initialise the Conv*/BatchNorm* layers; weights_init leaves every other
# module (including the label embeddings) at its default initialisation.
netG.apply(weights_init)
netD.apply(weights_init)

# Report the total number of parameters in each network.
print(f"Generator params: {sum(p.numel() for p in netG.parameters()):,}")
print(f"Discriminator params: {sum(p.numel() for p in netD.parameters()):,}")

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output, with one Adam
# optimizer per network.
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.999))

# Mean generator / discriminator loss per epoch, used for plotting afterwards.
G_losses, D_losses = [], []

# With drop_last=True a dataset smaller than one batch yields zero batches,
# which would otherwise surface later as a ZeroDivisionError in the averages.
num_batches = len(dataloader)
if num_batches == 0:
    raise RuntimeError(
        "The training dataloader produced no batches; check that the crop "
        "directory contains at least BATCH_SIZE images."
    )

# The generation cell switches the generator to eval mode; make sure both
# networks are in training mode if this cell is re-run afterwards.
netG.train()
netD.train()

print("Starting Training...")
for epoch in range(NUM_EPOCHS):
    epoch_D_loss, epoch_G_loss = 0.0, 0.0

    for real_imgs, labels in dataloader:
        real_imgs, labels = real_imgs.to(device), labels.to(device)
        batch_size = real_imgs.size(0)

        # Smoothed targets: 0.9 for real and 0.1 for fake instead of hard 1/0.
        real_label = torch.ones(batch_size, device=device) * 0.9
        fake_label = torch.zeros(batch_size, device=device) + 0.1

        # Train Discriminator: real batch first, then a generated batch.
        netD.zero_grad()
        output_real = netD(real_imgs, labels)
        errD_real = criterion(output_real, real_label)
        errD_real.backward()

        noise = torch.randn(batch_size, LATENT_DIM, device=device)
        fake_imgs = netG(noise, labels)
        # detach() keeps the discriminator update from back-propagating into
        # the generator.
        output_fake = netD(fake_imgs.detach(), labels)
        errD_fake = criterion(output_fake, fake_label)
        errD_fake.backward()
        optimizerD.step()

        # Train Generator: push the updated discriminator's verdict on the
        # same fake batch towards the "real" target.
        netG.zero_grad()
        output = netD(fake_imgs, labels)
        errG = criterion(output, real_label)
        errG.backward()
        optimizerG.step()

        epoch_D_loss += (errD_real + errD_fake).item()
        epoch_G_loss += errG.item()

    G_losses.append(epoch_G_loss / num_batches)
    D_losses.append(epoch_D_loss / num_batches)

    # Log the first epoch and every 20th epoch after that.
    if (epoch + 1) % 20 == 0 or epoch == 0:
        print(f"[{epoch+1:3d}/{NUM_EPOCHS}] Loss_D: {D_losses[-1]:.4f} | Loss_G: {G_losses[-1]:.4f}")

print("\n✅ Training complete!")

In [None]:
# Plot the per-epoch mean losses of both networks and save the figure to disk.
fig, ax = plt.subplots(figsize=(10, 4))
for series, series_name in ((G_losses, "Generator"), (D_losses, "Discriminator")):
    ax.plot(series, label=series_name)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True, alpha=0.3)
fig.savefig('training_losses.png', dpi=150)
plt.show()

## 6. Generate COCO-Format Dataset

Create full scene images with dice on backgrounds and COCO annotations.

In [None]:
# Extract background images from original dataset
BACKGROUND_DIR = 'backgrounds'
MAX_BACKGROUNDS = 50  # upper bound on the number of images copied
os.makedirs(BACKGROUND_DIR, exist_ok=True)

print("Extracting background samples from training images...")
train_images_dir = f'{dataset.location}/train'
bg_count = 0

# os.listdir returns files in arbitrary order; sort so the same images are
# picked on every run and on every machine.
for img_file in sorted(os.listdir(train_images_dir)):
    if bg_count >= MAX_BACKGROUNDS:
        break
    if not img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
        continue
    # Re-encode as RGB JPEG under a sequential name; the context manager
    # releases the source file handle immediately.
    with Image.open(os.path.join(train_images_dir, img_file)) as source:
        source.convert('RGB').save(os.path.join(BACKGROUND_DIR, f'bg_{bg_count:04d}.jpg'))
    bg_count += 1

print(f"Extracted {bg_count} background images")

In [None]:
# Configuration
SYNTHETIC_COCO_DIR = 'synthetic_coco_dataset'
SCENE_SIZE = (640, 640)        # size each background is resized to
DICE_SIZE_RANGE = (60, 120)    # inclusive bounds for a pasted die's side length
DICE_PER_IMAGE = (1, 4)        # inclusive bounds for dice attempted per scene

synthetic_train_dir = os.path.join(SYNTHETIC_COCO_DIR, 'train')
os.makedirs(synthetic_train_dir, exist_ok=True)

# Calculate images needed per class: top every class up to the size of the
# largest one.
face_names = [str(face) for face in range(1, 7)]
current_counts = {name: class_counts.get(name, 0) for name in face_names}
target_count = max(current_counts.values())
images_to_generate = {
    name: max(0, target_count - existing)
    for name, existing in current_counts.items()
}

print(f"Target count per class: {target_count}")
total_deficit = sum(images_to_generate.values())
total_synthetic_images = total_deficit // 2 + 50  # Estimate
print(f"Will generate approximately {total_synthetic_images} full scene images")

In [None]:
def generate_dice_image(generator, class_idx, size):
    """Generate a single dice image of the given class and resize it.

    Args:
        generator: Conditional generator called as ``generator(noise, label)``.
        class_idx: Zero-based class index used as the conditioning label.
        size: Side length in pixels of the returned square image.

    Returns:
        A ``size`` x ``size`` PIL image built from the generator output.

    The generator is run in eval mode and its previous train/eval mode is
    restored afterwards, so calling this does not affect later training.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(1, LATENT_DIM, device=device)
            label = torch.tensor([class_idx], device=device)
            # (C, H, W) in [-1, 1] -> (H, W, C) numpy array
            fake_img = generator(noise, label)[0].cpu().numpy().transpose(1, 2, 0)
    finally:
        generator.train(was_training)

    # Map [-1, 1] to [0, 255]. Round and clip BEFORE the uint8 cast: clipping
    # after the cast cannot undo wrap-around, and a bare cast truncates.
    scaled = (fake_img + 1) / 2 * 255
    pixels = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    pil_img = Image.fromarray(pixels)
    return pil_img.resize((size, size), Image.LANCZOS)

def check_overlap(new_box, existing_boxes, min_distance=10):
    """Return True if ``new_box`` comes too close to any existing box.

    Boxes are ``[x_min, y_min, x_max, y_max]``. Two boxes conflict when they
    intersect after each existing box is grown by ``min_distance`` on every
    side. A gap of exactly ``min_distance`` does not count as a conflict.
    """
    def conflicts_with(box):
        # The boxes are far enough apart if the new box lies entirely beyond
        # the padded box on at least one side.
        clear = (
            new_box[0] >= box[2] + min_distance
            or new_box[2] <= box[0] - min_distance
            or new_box[1] >= box[3] + min_distance
            or new_box[3] <= box[1] - min_distance
        )
        return not clear

    return any(conflicts_with(box) for box in existing_boxes)

# Load backgrounds. Sorted because os.listdir order is arbitrary, which would
# make random.choice pick different files for the same SEED.
background_files = sorted(
    f for f in os.listdir(BACKGROUND_DIR) if f.endswith(('.jpg', '.png'))
)
if not background_files:
    raise RuntimeError(
        f"No background images found in '{BACKGROUND_DIR}'; "
        "run the background extraction cell first."
    )

# COCO format structures
coco_images = []
coco_annotations = []
# NOTE(review): assumes category ids 1-6 match the ids used by the real
# dataset's COCO export — TODO confirm before mixing the two datasets.
coco_categories = [{'id': i, 'name': str(i)} for i in range(1, 7)]

image_id = 1
annotation_id = 1
generated_per_class = {str(i): 0 for i in range(1, 7)}

# Random positions tried per die before giving up on placing it.
MAX_PLACEMENT_ATTEMPTS = 20

print("\n🎨 Generating synthetic COCO dataset...")

for scene_idx in tqdm(range(total_synthetic_images)):
    # Load random background and resize. resize() returns a new image, so the
    # source file can be closed as soon as the `with` block ends.
    # NOTE(review): backgrounds are unmodified training images, so any real
    # dice they contain end up in the scene without annotations — confirm this
    # is acceptable for the downstream detector.
    bg_file = random.choice(background_files)
    with Image.open(os.path.join(BACKGROUND_DIR, bg_file)) as bg_source:
        scene = bg_source.convert('RGB').resize(SCENE_SIZE, Image.LANCZOS)

    # Determine number of dice and which classes to generate
    num_dice = random.randint(*DICE_PER_IMAGE)

    # Prioritize underrepresented classes; once every class has reached its
    # quota, fall back to sampling uniformly from all six.
    needed_classes = [k for k, v in images_to_generate.items() if generated_per_class[k] < v]
    if not needed_classes:
        needed_classes = [str(i) for i in range(1, 7)]

    placed_boxes = []
    scene_annotations = []

    for _ in range(num_dice):
        # Select class; labels fed to the generator are zero-based.
        class_name = random.choice(needed_classes)
        class_idx = int(class_name) - 1

        # Random dice size and position
        dice_size = random.randint(*DICE_SIZE_RANGE)

        # Try to place dice without overlap; a die that cannot be placed
        # within the attempt budget is silently dropped.
        for _attempt in range(MAX_PLACEMENT_ATTEMPTS):
            x = random.randint(0, SCENE_SIZE[0] - dice_size)
            y = random.randint(0, SCENE_SIZE[1] - dice_size)
            new_box = [x, y, x + dice_size, y + dice_size]

            if check_overlap(new_box, placed_boxes):
                continue

            # Generate and paste dice
            dice_img = generate_dice_image(netG, class_idx, dice_size)
            scene.paste(dice_img, (x, y))

            placed_boxes.append(new_box)
            scene_annotations.append({
                'id': annotation_id,
                'image_id': image_id,
                'category_id': int(class_name),
                'bbox': [x, y, dice_size, dice_size],  # COCO format: x, y, w, h
                'area': dice_size * dice_size,
                'iscrowd': 0
            })
            annotation_id += 1
            generated_per_class[class_name] += 1
            break

    # Scenes where no die could be placed are discarded, so image ids stay
    # contiguous.
    if scene_annotations:
        # Save image
        img_filename = f"synthetic_{image_id:05d}.jpg"
        scene.save(os.path.join(SYNTHETIC_COCO_DIR, 'train', img_filename))

        coco_images.append({
            'id': image_id,
            'file_name': img_filename,
            'width': SCENE_SIZE[0],
            'height': SCENE_SIZE[1]
        })
        coco_annotations.extend(scene_annotations)
        image_id += 1

# Save COCO annotations
coco_data = {
    'images': coco_images,
    'annotations': coco_annotations,
    'categories': coco_categories
}

annotation_path = os.path.join(SYNTHETIC_COCO_DIR, 'train', '_annotations.coco.json')
with open(annotation_path, 'w', encoding='utf-8') as f:
    json.dump(coco_data, f, indent=2)

print(f"\n✅ Generated {len(coco_images)} images with {len(coco_annotations)} annotations")
print("\nPer-class generation counts:")
for k, v in generated_per_class.items():
    print(f"  Class {k}: {v} dice")

In [None]:
# Visualize sample generated scenes
import matplotlib.patches as patches

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Show at most six randomly chosen scenes, one per subplot.
preview_count = min(6, len(coco_images))
sample_images = random.sample(coco_images, preview_count)

# Bucket the annotations by image id, preserving their original order.
annotations_by_image = {}
for ann in coco_annotations:
    annotations_by_image.setdefault(ann['image_id'], []).append(ann)

for ax, img_info in zip(axes, sample_images):
    img_path = os.path.join(SYNTHETIC_COCO_DIR, 'train', img_info['file_name'])
    img = Image.open(img_path)
    ax.imshow(img)

    # Draw bounding boxes with their class label just above the top-left corner
    for ann in annotations_by_image.get(img_info['id'], []):
        x, y, w, h = ann['bbox']
        ax.add_patch(
            patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='lime', facecolor='none')
        )
        ax.text(x, y - 5, f"Dice {ann['category_id']}", color='lime', fontsize=10, weight='bold')

    ax.axis('off')
    ax.set_title(img_info['file_name'])

plt.suptitle('Sample Generated Scenes with COCO Annotations', fontsize=14)
plt.tight_layout()
plt.savefig('sample_coco_scenes.png', dpi=150)
plt.show()

## 7. Verify Compatibility with DiceDetectionDataset

In [None]:
# Test loading with DiceDetectionDataset
# This is a best-effort smoke test: any failure is reported and swallowed so
# the rest of the notebook can still run.
import sys
# NOTE(review): this puts ./src on the import path, but the import below uses
# the package-qualified name `src.dataset`, which resolves from the working
# directory instead — confirm which of the two is intended.
sys.path.append('./src')

try:
    from src.dataset import DiceDetectionDataset

    # Point the project dataset class at the freshly written synthetic split.
    synthetic_dataset = DiceDetectionDataset(
        root_dir=os.path.join(SYNTHETIC_COCO_DIR, 'train'),
        annotation_file='_annotations.coco.json',
        split='train'
    )

    print(f"âœ… Successfully loaded synthetic dataset!")
    print(f"   Number of images: {len(synthetic_dataset)}")
    print(f"   Number of classes: {synthetic_dataset.num_classes}")
    print(f"   Class distribution: {synthetic_dataset.get_class_distribution()}")

    # Test getting an item
    # presumably `image` is a tensor and `target` a dict holding 'boxes' and
    # 'labels' tensors, as the accesses below require — verify against
    # src/dataset.py.
    image, target = synthetic_dataset[0]
    print(f"\n   Sample image shape: {image.shape}")
    print(f"   Sample boxes: {target['boxes'].shape}")
    print(f"   Sample labels: {target['labels']}")

except Exception as e:
    # Covers a missing module as well as any error raised while loading, so
    # the message below is not proof that the dataset is actually valid.
    print(f"Could not test with DiceDetectionDataset: {e}")
    print("The dataset structure is still COCO-compatible.")

## 8. Save Model

In [None]:
# Persist both networks' weights together with the loss history.
MODEL_DIR = 'gan_models'
os.makedirs(MODEL_DIR, exist_ok=True)

checkpoint = dict(
    generator_state_dict=netG.state_dict(),
    discriminator_state_dict=netD.state_dict(),
    epoch=NUM_EPOCHS,
    G_losses=G_losses,
    D_losses=D_losses,
)
checkpoint_path = os.path.join(MODEL_DIR, 'conditional_dcgan_dice.pth')
torch.save(checkpoint, checkpoint_path)

print(f"âœ… Model saved to {MODEL_DIR}/conditional_dcgan_dice.pth")

## 🚀 Usage

The synthetic dataset is now ready for use with `3_augmentation_comparison.ipynb`:

```python
from src.dataset import DiceDetectionDataset

synthetic_train = DiceDetectionDataset(
    root_dir='synthetic_coco_dataset/train',
    annotation_file='_annotations.coco.json',
    split='train'
)
```

**Generated files:**
- `synthetic_coco_dataset/train/` - Images with dice on backgrounds
- `synthetic_coco_dataset/train/_annotations.coco.json` - COCO annotations
- `gan_models/conditional_dcgan_dice.pth` - Trained GAN