In [None]:
"""
Notebook 02: Class-Specific Augmentation for Imbalanced Data
=============================================================
Generates synthetic samples for minority classes ONLY, using the TRAIN split only (val/test are left untouched)
"""

import json
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

# ==========================================
# 0) Configuration
# ==========================================
# Seed every RNG this notebook draws from, so reruns produce an identical
# augmented dataset:
#   - numpy: picks which source image each synthetic sample starts from.
#   - random: stdlib RNG, seeded in case any transform step draws from it.
#   - torch: torchvision's random transforms (crop, flips, rotation, affine,
#     color jitter, perspective) sample their parameters from torch's RNG,
#     which numpy/random seeding does NOT affect.
RANDOM_SEED = 42
IMAGE_SIZE = 256  # side length (pixels) that augmented images are resized to
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print("="*70)
print("CLASS-SPECIFIC AUGMENTATION FOR IMBALANCED DATA")
print("="*70)

# ==========================================
# 1) Load Training Data
# ==========================================
print("\n[1/5] Loading training data...")

# Split manifests. Only train_df is augmented below; downstream code reads
# its 'label' and 'path' columns.
train_df = pd.read_csv("train_manifest.csv")
val_df = pd.read_csv("val_manifest.csv")
test_df = pd.read_csv("test_manifest.csv")

# JSON is UTF-8 by specification. Pin the encoding so class names containing
# non-ASCII characters decode identically on every platform, rather than
# depending on the locale's default codec.
# NOTE(review): presumably a list of label strings matching train_df['label'].
with open("classes.json", encoding="utf-8") as f:
    class_names = json.load(f)

print(f"✓ Train: {len(train_df)} images")
print(f"✓ Val:   {len(val_df)} images")
print(f"✓ Test:  {len(test_df)} images")

# ==========================================
# 2) Analyze Class Imbalance
# ==========================================
print("\n[2/5] Analyzing class distribution...")

# Per-label sample counts, computed on the TRAIN split only.
class_counts = train_df['label'].value_counts()
print("\nOriginal class distribution (TRAIN only):")
for cls in class_names:
    count = class_counts.get(cls, 0)
    print(f"  {cls:20s}: {count:4d} samples")

# Every class is topped up to the size of the largest class.
# Series.max() of an empty Series is NaN, not 0, so guard the empty manifest.
target_count = int(class_counts.max()) if not class_counts.empty else 0
print(f"\n✓ Target count per class: {target_count}")

# Number of synthetic samples each class needs to reach target_count.
# Step 4 builds new samples from a class's own training images. A class with
# zero training images has nothing to augment from, so it is skipped with a
# warning rather than crashing later in np.random.choice on an empty pool.
augmentation_needed = {}
for cls in class_names:
    current = int(class_counts.get(cls, 0))
    needed = max(0, target_count - current)
    if needed > 0 and current == 0:
        print(f"  ⚠ {cls:20s}: no training images, cannot augment; skipping")
        needed = 0
    augmentation_needed[cls] = needed
    if needed > 0:
        print(f"  {cls:20s}: needs {needed} augmented samples")

# ==========================================
# 3) Define Augmentation Transform
# ==========================================
print("\n[3/5] Defining augmentation strategy...")

# Ordered augmentation pipeline. Step 4 feeds it an RGB numpy array, which the
# first step turns into a PIL image. Every later step is a PIL-level transform,
# so the pipeline returns a PIL image.
full_size = (IMAGE_SIZE, IMAGE_SIZE)
augment_steps = [
    T.ToPILImage(),
    # Random 224px crop of the resized image, then scaled back up, so the
    # subject's framing shifts between samples.
    T.Resize(full_size),
    T.RandomCrop(224),
    T.Resize(full_size),  # Resize back
    # Geometric variation.
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(degrees=30),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    # Photometric variation.
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    T.RandomPerspective(distortion_scale=0.2, p=0.3),
]
augment_transform = T.Compose(augment_steps)

print("✓ Augmentation strategy defined")

# ==========================================
# 4) Generate Augmented Samples
# ==========================================
print("\n[4/5] Generating augmented samples...")

augmented_images = []
augmented_labels = []

total_to_generate = sum(augmentation_needed.values())
print(f"Total augmented samples to generate: {total_to_generate}")

if total_to_generate == 0:
    print("\n⚠ Dataset is already balanced! No augmentation needed.")
else:
    for class_name in class_names:
        needed = augmentation_needed[class_name]

        if needed == 0:
            continue

        print(f"\nProcessing {class_name}...")

        # Source pool: every training image of this class. Sources are drawn
        # with replacement, so a small pool can still yield many new samples.
        class_df = train_df[train_df['label'] == class_name]
        class_paths = class_df['path'].values
        if len(class_paths) == 0:
            raise ValueError(
                f"Cannot augment class {class_name!r}: it has no training images"
            )

        for _ in tqdm(range(needed), desc=f"Augmenting {class_name}"):
            source_path = np.random.choice(class_paths)

            # The context manager closes the file handle as soon as the pixels
            # are decoded, instead of relying on garbage collection across
            # potentially thousands of iterations.
            with Image.open(source_path) as img:
                img_array = np.array(img.convert('RGB'))

            # augment_transform yields a PIL image; convert back to uint8 pixels.
            aug_array = np.array(augment_transform(img_array), dtype=np.uint8)

            augmented_images.append(aug_array)
            augmented_labels.append(class_name)

    # Stack into one (N, H, W, 3) uint8 image array plus a parallel (N,)
    # label array.
    augmented_images = np.array(augmented_images)
    augmented_labels = np.array(augmented_labels)

    print(f"\n✓ Generated {len(augmented_images)} augmented images")
    print(f"  Shape: {augmented_images.shape}")
    print(f"  Memory: {augmented_images.nbytes / (1024**2):.1f} MB")

# ==========================================
# 5) Save Augmented Data
# ==========================================
print("\n[5/5] Saving augmented data...")

# Files are only written when step 4 actually produced samples.
# Maps each output file to the array it stores.
if total_to_generate > 0:
    saved_outputs = {
        "augmented_train_images.npy": augmented_images,
        "augmented_train_labels.npy": augmented_labels,
    }
    for out_path, payload in saved_outputs.items():
        np.save(out_path, payload)

    print("✓ Saved files:")
    for out_path in saved_outputs:
        print(f"  - {out_path}")
else:
    print("✓ No augmented data to save (dataset already balanced)")

# ==========================================
# Summary
# ==========================================
banner = "=" * 70
print("\n" + banner)
print("AUGMENTATION SUMMARY")
print(banner)

# Per-class breakdown: original training count + planned synthetic count.
print("\nClass distribution after augmentation:")
for cls in class_names:
    n_orig = class_counts.get(cls, 0)
    n_aug = augmentation_needed[cls]
    n_total = n_orig + n_aug
    print(f"  {cls:20s}: {n_orig:4d} (original) + {n_aug:4d} (augmented) = {n_total:4d} (total)")

n_train = len(train_df)
print(f"\nTotal training samples: {n_train} → {n_train + total_to_generate}")

print("\nNext step: Run Notebook 03 to combine and export to final NPY format")
print(banner)