# Synthetic Image Generation: Pin + Background Composition

This notebook demonstrates how to create synthetic images by compositing a pin image onto different backgrounds.

## Use Cases:
- **Data Augmentation**: Generate diverse training data
- **Testing**: Create controlled test scenarios
- **Evaluation**: Test the detection pipeline against known ground truth
- **Prototyping**: Quickly visualize the pin in different environments

## Workflow:
1. Load pin image (with optional mask)
2. Load background images
3. Apply transformations (scale, rotation, flip, brightness, contrast)
4. Composite pin onto backgrounds at a random or specified position
5. Save synthetic images and bounding-box metadata
6. (Optional) Test with detection pipeline

## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont, ImageEnhance
import matplotlib.pyplot as plt
from typing import List, Tuple, Optional, Dict
import random

print("✓ Imports successful")

## 2. Helper Functions for Image Composition

In [None]:
def load_pin_with_mask(pin_path: str, mask_path: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load pin image and its mask.
    
    Args:
        pin_path: Path to pin image
        mask_path: Optional path to mask. If None, a mask is generated by thresholding,
            which assumes the pin sits on a near-white background.
    
    Returns:
        pin_img: RGB image as numpy array
        mask: Binary mask as numpy array
    """
    # Load pin image
    pin_img = cv2.imread(pin_path)
    if pin_img is None:
        raise ValueError(f"Could not load pin image from {pin_path}")
    pin_img = cv2.cvtColor(pin_img, cv2.COLOR_BGR2RGB)
    
    # Load or create mask
    if mask_path is not None:
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Could not load mask from {mask_path}")
        # Ensure binary mask
        mask = (mask > 127).astype(np.uint8) * 255
    else:
        # Create automatic mask (assume white/bright background)
        gray = cv2.cvtColor(pin_img, cv2.COLOR_RGB2GRAY)
        _, mask = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
        
        # Clean up mask with morphological operations
        kernel = np.ones((5, 5), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    
    return pin_img, mask


def transform_pin(pin_img: np.ndarray, 
                  mask: np.ndarray,
                  scale: float = 1.0,
                  rotation: float = 0.0,
                  flip_horizontal: bool = False,
                  brightness: float = 1.0,
                  contrast: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
    """
    Apply transformations to pin image and mask.
    
    Args:
        pin_img: RGB image
        mask: Binary mask
        scale: Scale factor (e.g., 0.5 = half size, 2.0 = double size)
        rotation: Rotation angle in degrees (counter-clockwise, as in PIL's Image.rotate)
        flip_horizontal: Whether to flip horizontally
        brightness: Brightness adjustment (1.0 = no change)
        contrast: Contrast adjustment (1.0 = no change)
    
    Returns:
        Transformed pin image and mask
    """
    h, w = pin_img.shape[:2]
    
    # Convert to PIL for easier transformations
    pin_pil = Image.fromarray(pin_img)
    mask_pil = Image.fromarray(mask)
    
    # Flip
    if flip_horizontal:
        pin_pil = pin_pil.transpose(Image.FLIP_LEFT_RIGHT)
        mask_pil = mask_pil.transpose(Image.FLIP_LEFT_RIGHT)
    
    # Rotate
    if rotation != 0:
        pin_pil = pin_pil.rotate(rotation, expand=True, fillcolor=(0, 0, 0))
        mask_pil = mask_pil.rotate(rotation, expand=True, fillcolor=0)
    
    # Scale
    if scale != 1.0:
        new_w = int(pin_pil.width * scale)
        new_h = int(pin_pil.height * scale)
        pin_pil = pin_pil.resize((new_w, new_h), Image.LANCZOS)
        mask_pil = mask_pil.resize((new_w, new_h), Image.LANCZOS)
    
    # Brightness and contrast
    if brightness != 1.0:
        enhancer = ImageEnhance.Brightness(pin_pil)
        pin_pil = enhancer.enhance(brightness)
    
    if contrast != 1.0:
        enhancer = ImageEnhance.Contrast(pin_pil)
        pin_pil = enhancer.enhance(contrast)
    
    # Convert back to numpy
    pin_transformed = np.array(pin_pil)
    mask_transformed = np.array(mask_pil)
    
    # Ensure mask is binary
    mask_transformed = (mask_transformed > 127).astype(np.uint8) * 255
    
    return pin_transformed, mask_transformed


def composite_pin_on_background(background: np.ndarray,
                                pin_img: np.ndarray,
                                mask: np.ndarray,
                                position: Optional[Tuple[int, int]] = None,
                                blend_edges: bool = True,
                                edge_blur_radius: int = 2) -> Tuple[np.ndarray, Dict]:
    """
    Composite pin onto background at specified position.
    
    Args:
        background: Background image (RGB)
        pin_img: Pin image (RGB)
        mask: Binary mask for pin
        position: (x, y) position for pin center. If None, places randomly.
        blend_edges: Whether to blend edges for smoother composite
        edge_blur_radius: Radius for edge blending
    
    Returns:
        Composite image and metadata dictionary
    """
    bg_h, bg_w = background.shape[:2]
    pin_h, pin_w = pin_img.shape[:2]
    
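    # Determine position
    if position is None:
        # Random center such that the pin lies fully inside the background
        # (collapses to a single position when the pin is larger than the background)
        min_x, min_y = pin_w // 2, pin_h // 2
        x = random.randint(min_x, max(min_x, bg_w - pin_w + min_x))
        y = random.randint(min_y, max(min_y, bg_h - pin_h + min_y))
    else:
        x, y = position
    
    # Unclamped top-left corner of the pin in background coordinates
    left = x - pin_w // 2
    top = y - pin_h // 2
    
    # Visible part of the pin, clamped to the background
    x1, y1 = max(0, left), max(0, top)
    x2, y2 = min(bg_w, left + pin_w), min(bg_h, top + pin_h)
    if x2 <= x1 or y2 <= y1:
        raise ValueError(f"Pin at center ({x}, {y}) lies entirely outside the background")
    
    # Matching crop of the pin (offsets are non-zero when it overhangs the left/top edge)
    pin_x1 = x1 - left
    pin_y1 = y1 - top
    pin_x2 = pin_x1 + (x2 - x1)
    pin_y2 = pin_y1 + (y2 - y1)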
    
    # Create composite
    result = background.copy()
    
    # Get regions
    pin_region = pin_img[pin_y1:pin_y2, pin_x1:pin_x2]
    mask_region = mask[pin_y1:pin_y2, pin_x1:pin_x2]
    
    # Blend edges if requested
    if blend_edges and edge_blur_radius > 0:
        mask_region = cv2.GaussianBlur(mask_region, 
                                       (edge_blur_radius * 2 + 1, edge_blur_radius * 2 + 1), 
                                       0)
    
    # Normalize mask to [0, 1]
    mask_normalized = mask_region.astype(float) / 255.0
    mask_normalized = np.expand_dims(mask_normalized, axis=2)
    
    # Composite
    bg_region = result[y1:y2, x1:x2]
    result[y1:y2, x1:x2] = (pin_region * mask_normalized + 
                            bg_region * (1 - mask_normalized)).astype(np.uint8)
    
    # Metadata
    metadata = {
        'bbox': [x1, y1, x2, y2],
        'center': [x, y],
        'pin_size': [pin_w, pin_h],
        'background_size': [bg_w, bg_h]
    }
    
    return result, metadata


def visualize_composite(image: np.ndarray, 
                       metadata: Dict,
                       title: str = "Synthetic Image"):
    """
    Visualize composite image with bounding box.
    """
    plt.figure(figsize=(12, 8))
    plt.imshow(image)
    
    # Draw bounding box
    x1, y1, x2, y2 = metadata['bbox']
    rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, 
                         fill=False, color='red', linewidth=2)
    plt.gca().add_patch(rect)
    
    # Draw center point
    cx, cy = metadata['center']
    plt.plot(cx, cy, 'r+', markersize=15, markeredgewidth=2)
    
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

print("✓ Helper functions defined")
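
The `'bbox'` returned by `composite_pin_on_background` covers the whole pin image rectangle. After a rotation with `expand=True`, that rectangle includes empty corners, so it can be noticeably larger than the visible pin. The following is a minimal sketch of deriving a tighter box from the mask instead; `tight_bbox_from_mask` and its `offset_x`/`offset_y` parameters are hypothetical names, not part of the helpers above.

```python
import numpy as np
from typing import List, Optional


def tight_bbox_from_mask(mask: np.ndarray,
                         offset_x: int = 0,
                         offset_y: int = 0) -> Optional[List[int]]:
    """Return [x1, y1, x2, y2] enclosing all non-zero mask pixels, or None if the mask is empty."""
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    # x2/y2 are exclusive, matching the slicing convention used for 'bbox'
    return [int(xs.min()) + offset_x, int(ys.min()) + offset_y,
            int(xs.max()) + 1 + offset_x, int(ys.max()) + 1 + offset_y]


demo_mask = np.zeros((10, 12), dtype=np.uint8)
demo_mask[2:5, 3:9] = 255  # foreground spans rows 2..4 and columns 3..8
print(tight_bbox_from_mask(demo_mask, offset_x=100, offset_y=50))  # [103, 52, 109, 55]
```

In the compositing step, the offsets would be the `x1, y1` corner of the pasted region and the mask would be the cropped `mask_region` before blurring.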

## 3. Configuration

In [None]:
# ==================== CONFIGURATION ====================

# Pin configuration
PIN_IMAGE_PATH = "../data/pins/my_pin.jpg"  # Path to your pin image
PIN_MASK_PATH = None  # Optional: path to pin mask (None = auto-generate)

# Background images
BACKGROUND_PATHS = [
    "../data/backgrounds/outdoor_scene.jpg",
    "../data/backgrounds/indoor_table.jpg",
    "../data/backgrounds/wall.jpg",
    "../data/backgrounds/desk.jpg",
]

# Transformation settings
TRANSFORMATIONS = [
    # Format: (scale, rotation, flip_horizontal, brightness, contrast)
    (1.0, 0, False, 1.0, 1.0),      # Original
    (0.8, 15, False, 1.0, 1.0),     # Smaller, rotated
    (1.2, -10, False, 1.1, 1.0),    # Larger, rotated, brighter
    (1.0, 0, True, 0.9, 1.1),       # Flipped, darker, more contrast
    (0.9, 30, False, 1.0, 0.9),     # Smaller, rotated, less contrast
]

# Positioning
# Set to None for random positioning, or specify (x, y) coordinates
POSITIONS = None  # Will place randomly
# POSITIONS = [(400, 300), (600, 400), ...]  # Specific positions

# Composition settings
BLEND_EDGES = True
EDGE_BLUR_RADIUS = 2

# Output settings
OUTPUT_DIR = "../data/synthetic_images"
SAVE_METADATA = True  # Save JSON with bounding box info

print("✓ Configuration set")
print(f"Pin: {PIN_IMAGE_PATH}")
print(f"Backgrounds: {len(BACKGROUND_PATHS)} images")
print(f"Transformations: {len(TRANSFORMATIONS)} variants")
print(f"Output: {OUTPUT_DIR}")
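
Each `TRANSFORMATIONS` entry is a positional tuple, so it is easy to swap brightness and contrast by accident. One possible alternative, sketched here under the assumption that named fields are preferred, is a `NamedTuple`; the class name `PinTransform` is hypothetical.

```python
from typing import NamedTuple


class PinTransform(NamedTuple):
    """Hypothetical named form of one TRANSFORMATIONS entry."""
    scale: float = 1.0
    rotation: float = 0.0
    flip_horizontal: bool = False
    brightness: float = 1.0
    contrast: float = 1.0


example = PinTransform(scale=0.8, rotation=15)

# A NamedTuple still unpacks like the plain tuples used in the generation loop
scale, rotation, flip, brightness, contrast = example
print(example)
```

Because a `NamedTuple` is still a tuple, entries written this way should work with the existing unpacking in the generation loop without further changes.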

## 4. Load Pin Image

In [None]:
# Load pin
print("Loading pin image...")
pin_img, pin_mask = load_pin_with_mask(PIN_IMAGE_PATH, PIN_MASK_PATH)

print(f"✓ Pin loaded: {pin_img.shape}")
print(f"✓ Mask loaded: {pin_mask.shape}")

# Visualize pin and mask
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(pin_img)
axes[0].set_title("Pin Image")
axes[0].axis('off')

axes[1].imshow(pin_mask, cmap='gray')
axes[1].set_title("Pin Mask")
axes[1].axis('off')

# Show masked pin
masked_pin = pin_img.copy()
masked_pin[pin_mask == 0] = 0
axes[2].imshow(masked_pin)
axes[2].set_title("Masked Pin")
axes[2].axis('off')

plt.tight_layout()
plt.show()
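
When the mask is auto-generated, the threshold only works if the pin's background is close to white. A quick numeric check can complement the visual inspection above. This is a minimal sketch; `mask_coverage` is a hypothetical helper and the 5%/95% bounds are illustrative assumptions, not tuned values.

```python
import numpy as np


def mask_coverage(mask: np.ndarray) -> float:
    """Fraction of pixels marked as foreground (non-zero) in the mask."""
    if mask.size == 0:
        return 0.0
    return float(np.count_nonzero(mask)) / mask.size


demo = np.zeros((10, 10), dtype=np.uint8)
demo[2:8, 2:8] = 255  # 36 of 100 pixels are foreground

coverage = mask_coverage(demo)
if not 0.05 <= coverage <= 0.95:  # illustrative bounds, adjust to your data
    print(f"Suspicious mask coverage: {coverage:.2f}")
else:
    print(f"Mask coverage looks plausible: {coverage:.2f}")
```

A coverage near 0 or 1 usually suggests that the thresholding picked up almost nothing or almost everything, in which case supplying an explicit mask via `PIN_MASK_PATH` is the safer option.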

## 5. Load Background Images

In [None]:
# Load all backgrounds
print("Loading background images...")
backgrounds = []

for bg_path in BACKGROUND_PATHS:
    if os.path.exists(bg_path):
        bg = cv2.imread(bg_path)
        if bg is not None:
            bg = cv2.cvtColor(bg, cv2.COLOR_BGR2RGB)
            backgrounds.append((bg_path, bg))
            print(f"  ✓ {bg_path}: {bg.shape}")
        else:
            print(f"  ✗ Failed to load: {bg_path}")
    else:
        print(f"  ✗ File not found: {bg_path}")

print(f"\n✓ Loaded {len(backgrounds)} backgrounds")

# Visualize backgrounds
if len(backgrounds) > 0:
    n_cols = min(4, len(backgrounds))
    n_rows = (len(backgrounds) + n_cols - 1) // n_cols
    
    # squeeze=False always returns a 2-D array of axes, so flatten() works for any grid size
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False)
    axes = axes.flatten()
    
    for idx, (bg_path, bg) in enumerate(backgrounds):
        axes[idx].imshow(bg)
        axes[idx].set_title(Path(bg_path).name)
        axes[idx].axis('off')
    
    # Hide extra subplots
    for idx in range(len(backgrounds), len(axes)):
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.show()
else:
    print("\n⚠ No backgrounds loaded. Please check your BACKGROUND_PATHS.")
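
Backgrounds can differ a lot in resolution, and a pin that is larger than a background gets cropped during compositing. The sketch below estimates the largest scale at which the pin still occupies a bounded share of a background; `max_fit_scale` and its `max_fraction` parameter are hypothetical and not used elsewhere in this notebook.

```python
from typing import Tuple


def max_fit_scale(pin_size: Tuple[int, int],
                  bg_size: Tuple[int, int],
                  max_fraction: float = 0.5) -> float:
    """Largest scale at which the pin spans at most max_fraction of each background side.

    Sizes are (width, height). Rotation with expand=True enlarges the pin
    canvas, so a rotated pin needs a somewhat smaller scale than this.
    """
    pin_w, pin_h = pin_size
    bg_w, bg_h = bg_size
    return min(bg_w * max_fraction / pin_w, bg_h * max_fraction / pin_h)


# A 400x200 pin on a 1000x800 background: width is the limiting side
print(max_fit_scale((400, 200), (1000, 800)))  # 1.25
```

Clamping each transformation's `scale` to this value per background is one way to keep the pin fully visible.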

## 6. Generate Synthetic Images

In [None]:
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Generating synthetic images...")
print(f"Total combinations: {len(backgrounds)} backgrounds × {len(TRANSFORMATIONS)} transforms = {len(backgrounds) * len(TRANSFORMATIONS)} images\n")

synthetic_images = []
image_count = 0

for bg_idx, (bg_path, background) in enumerate(backgrounds):
    bg_name = Path(bg_path).stem
    
    for trans_idx, (scale, rotation, flip, brightness, contrast) in enumerate(TRANSFORMATIONS):
        # Transform pin
        pin_transformed, mask_transformed = transform_pin(
            pin_img, pin_mask,
            scale=scale,
            rotation=rotation,
            flip_horizontal=flip,
            brightness=brightness,
            contrast=contrast
        )
        
        # Determine position
        if POSITIONS is not None and isinstance(POSITIONS, list):
            position = POSITIONS[image_count % len(POSITIONS)]
        else:
            position = None  # Random
        
        # Composite
        composite, metadata = composite_pin_on_background(
            background,
            pin_transformed,
            mask_transformed,
            position=position,
            blend_edges=BLEND_EDGES,
            edge_blur_radius=EDGE_BLUR_RADIUS
        )
        
        # Generate filename
        filename = f"synthetic_{bg_name}_t{trans_idx}_{image_count:04d}.jpg"
        output_path = os.path.join(OUTPUT_DIR, filename)
        
        # Save image
        composite_bgr = cv2.cvtColor(composite, cv2.COLOR_RGB2BGR)
        cv2.imwrite(output_path, composite_bgr)
        
        # Save metadata
        if SAVE_METADATA:
            import json
            # Swap only the file extension; str.replace would also alter '.jpg' elsewhere in the path
            metadata_path = os.path.splitext(output_path)[0] + '.json'
            metadata_full = {
                'image_path': output_path,
                'background': bg_path,
                'pin': PIN_IMAGE_PATH,
                'transformations': {
                    'scale': scale,
                    'rotation': rotation,
                    'flip_horizontal': flip,
                    'brightness': brightness,
                    'contrast': contrast
                },
                'bbox': metadata['bbox'],
                'center': metadata['center'],
                'pin_size': metadata['pin_size'],
                'background_size': metadata['background_size']
            }
            with open(metadata_path, 'w') as f:
                json.dump(metadata_full, f, indent=2)
        
        synthetic_images.append((composite, metadata, filename))
        image_count += 1
        
        print(f"  ✓ [{image_count}/{len(backgrounds) * len(TRANSFORMATIONS)}] {filename}")

print(f"\n✓ Generated {len(synthetic_images)} synthetic images")
print(f"✓ Saved to: {OUTPUT_DIR}")
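
The saved `'bbox'` is `[x1, y1, x2, y2]` in pixels. Many detection training setups instead expect a normalized center format, as used by YOLO-style label files. The following is a minimal sketch of that conversion; `bbox_to_yolo` is a hypothetical helper, and `image_size` is assumed to be `(width, height)` like the `'background_size'` metadata entry.

```python
from typing import List, Tuple


def bbox_to_yolo(bbox: List[int],
                 image_size: Tuple[int, int]) -> Tuple[float, float, float, float]:
    """Convert [x1, y1, x2, y2] in pixels to normalized (cx, cy, w, h)."""
    x1, y1, x2, y2 = bbox
    img_w, img_h = image_size
    return ((x1 + x2) / 2 / img_w,   # box center x, as a fraction of image width
            (y1 + y2) / 2 / img_h,   # box center y, as a fraction of image height
            (x2 - x1) / img_w,       # box width fraction
            (y2 - y1) / img_h)       # box height fraction


print(bbox_to_yolo([100, 50, 300, 250], (800, 500)))
```

With the metadata written by the loop above, a call would look like `bbox_to_yolo(metadata['bbox'], tuple(metadata['background_size']))`.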

## 7. Visualize Results

In [None]:
# Display sample synthetic images
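n_samples = min(6, len(synthetic_images))
sample_indices = random.sample(range(len(synthetic_images)), n_samples)

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

# Hide every panel up front so unused ones stay blank when fewer than 6 images exist
for ax in axes:
    ax.axis('off')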

for idx, sample_idx in enumerate(sample_indices):
    composite, metadata, filename = synthetic_images[sample_idx]
    
    axes[idx].imshow(composite)
    
    # Draw bounding box
    x1, y1, x2, y2 = metadata['bbox']
    rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, 
                         fill=False, color='red', linewidth=2)
    axes[idx].add_patch(rect)
    
    axes[idx].set_title(filename, fontsize=10)
    axes[idx].axis('off')

plt.tight_layout()
plt.show()

## 8. (Optional) Test with Detection Pipeline

Test the generated synthetic images with the custom grounding pipeline to evaluate detection performance.

In [None]:
# Uncomment to test with detection pipeline

# from ai_vision_tool.pipelines.custom_grounding import CustomGroundingPipeline
# from ai_vision_tool.utils.visualization import save_results

# # Initialize pipeline
# print("Initializing detection pipeline...")
# pipeline = CustomGroundingPipeline(
#     use_visual_matching=True,
#     use_text_grounding=True
# )

# # Process reference (original pin)
# print("Processing reference image...")
# reference_features = pipeline.process_reference_image(
#     reference_image=Image.fromarray(pin_img),
#     reference_mask=Image.fromarray(pin_mask)
# )

# # Test on synthetic images
# TEXT_PROMPT = "custom pin"  # Modify based on your pin
# detection_results = []

# print(f"\nTesting detection on {len(synthetic_images)} synthetic images...")
# for idx, (composite, gt_metadata, filename) in enumerate(synthetic_images):
#     # Run detection
#     results = pipeline.detect_and_segment(
#         target_image=Image.fromarray(composite),
#         reference_features=reference_features,
#         text_prompt=TEXT_PROMPT,
#         fusion_strategy='multiply',
#         similarity_threshold=0.5,
#         detection_threshold=0.3
#     )
#     
#     n_detected = len(results['masks'])
#     detection_results.append({
#         'filename': filename,
#         'ground_truth': gt_metadata['bbox'],
#         'detected': n_detected,
#         'results': results
#     })
#     
#     status = "✓" if n_detected > 0 else "✗"
#     print(f"  {status} [{idx+1}/{len(synthetic_images)}] {filename}: {n_detected} detections")

# # Summary
# total = len(detection_results)
# detected = sum(1 for r in detection_results if r['detected'] > 0)
# print(f"\nDetection Summary:")
# print(f"  Total images: {total}")
# print(f"  Detected: {detected} ({100*detected/total:.1f}%)")
# print(f"  Missed: {total - detected} ({100*(total-detected)/total:.1f}%)")
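
The summary above only counts whether anything was detected. Since each synthetic image has a known ground-truth box, localization quality can also be measured with intersection over union (IoU). This is a minimal sketch; `bbox_iou` is a hypothetical helper, and how a predicted box is extracted from `results` depends on the pipeline's output format, which is not shown in this notebook.

```python
from typing import Sequence


def bbox_iou(box_a: Sequence[float], box_b: Sequence[float]) -> float:
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    # Overlap extent along each axis (zero when the boxes do not intersect)
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h

    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0


# Two 10x10 boxes overlapping in a 5x5 corner: 25 / 175
print(f"{bbox_iou([0, 0, 10, 10], [5, 5, 15, 15]):.3f}")  # 0.143
```

A common convention is to count a detection as correct when its IoU with the ground-truth box is at least 0.5, though the right threshold depends on the application.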

## 9. Advanced: Batch Generation with Random Parameters

In [None]:
def generate_random_transformations(n: int = 10) -> List[Tuple]:
    """
    Generate random transformation parameters.
    """
    transformations = []
    for _ in range(n):
        scale = random.uniform(0.5, 1.5)
        rotation = random.uniform(-30, 30)
        flip = random.choice([True, False])
        brightness = random.uniform(0.8, 1.2)
        contrast = random.uniform(0.8, 1.2)
        transformations.append((scale, rotation, flip, brightness, contrast))
    return transformations

# Example: Generate 20 images with random transformations
# random_transforms = generate_random_transformations(20)
# print(f"Generated {len(random_transforms)} random transformations")
# print("\nExample transformations:")
# for i, t in enumerate(random_transforms[:5]):
#     print(f"  {i+1}. scale={t[0]:.2f}, rot={t[1]:.1f}°, flip={t[2]}, bright={t[3]:.2f}, contrast={t[4]:.2f}")
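
Random parameters make runs hard to reproduce unless the generator is seeded. The sketch below is one possible seeded variant of the function above, using the same parameter ranges; the name `generate_seeded_transformations` and the `seed` parameter are assumptions for illustration.

```python
import random
from typing import List, Tuple


def generate_seeded_transformations(n: int = 10, seed: int = 0) -> List[Tuple]:
    """Generate n random transformation tuples that are reproducible for a given seed."""
    rng = random.Random(seed)  # local generator; the global random state is left untouched
    transformations = []
    for _ in range(n):
        transformations.append((
            rng.uniform(0.5, 1.5),      # scale
            rng.uniform(-30, 30),       # rotation in degrees
            rng.choice([True, False]),  # flip_horizontal
            rng.uniform(0.8, 1.2),      # brightness
            rng.uniform(0.8, 1.2),      # contrast
        ))
    return transformations


first = generate_seeded_transformations(3, seed=42)
second = generate_seeded_transformations(3, seed=42)
print(first == second)  # True: same seed, same parameters
```

Random pin positions in `composite_pin_on_background` come from the global `random` module, so calling `random.seed(...)` before the generation loop would be needed to make those reproducible as well.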

## Summary

This notebook provides tools for:

1. **Loading** pin images with masks
2. **Transforming** pins (scale, rotation, flip, brightness, contrast)
3. **Compositing** pins onto various backgrounds
4. **Saving** synthetic images with metadata
5. **Testing** with the detection pipeline (optional)

### Next Steps:

- Adjust transformation parameters for more variety
- Add more backgrounds for diverse scenarios
- Use generated images for training/testing
- Evaluate detection pipeline performance
- Create augmentation pipelines for model training

### Tips:

- Use high-quality pin images with clear edges
- Provide masks for best results (or ensure clean backgrounds)
- Match lighting conditions between pin and backgrounds
- Test different blend settings for realistic composites
- Save metadata for training object detection models