# DAIA - Data Exploration Notebook

This notebook helps you explore your dataset before training.

In [None]:
import sys
sys.path.append('../src')

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from utils import load_config
from data_loader import load_dataset, split_dataset

%matplotlib inline

## 1. Load Configuration

In [None]:
# Load the project configuration from '../config.yaml' and echo the
# data-related settings that the rest of this notebook relies on.
config = load_config('../config.yaml')
print("Configuration loaded successfully!")

# Alias the 'data' section so each lookup below is a single subscript.
data_cfg = config['data']
print(f"\nData directory: {data_cfg['data_dir']}")
print(f"Image size: {data_cfg['image_size']}")
print(f"Batch size: {data_cfg['batch_size']}")

## 2. Dataset Overview

In [None]:
# Load dataset
# os.path.join (instead of string concatenation) keeps an absolute
# `data_dir` intact while still resolving relative ones against the
# parent directory.
data_dir = os.path.join('..', config['data']['data_dir'])
image_paths, labels = load_dataset(data_dir, config)

# Statistics
# Labels are counted with 0 = real and 1 = AI-generated, matching the
# print-outs below.
total_images = len(image_paths)
real_count = labels.count(0)
ai_count = labels.count(1)

# Guard every division so an empty (or single-class-empty) dataset prints
# zeros instead of raising ZeroDivisionError.
real_pct = real_count / total_images * 100 if total_images else 0.0
ai_pct = ai_count / total_images * 100 if total_images else 0.0
majority_count = max(real_count, ai_count)
balance_ratio = min(real_count, ai_count) / majority_count if majority_count else 0.0

print(f"\nDataset Statistics:")
print(f"  Total images: {total_images}")
print(f"  Real images: {real_count} ({real_pct:.1f}%)")
print(f"  AI-generated: {ai_count} ({ai_pct:.1f}%)")
print(f"  Balance ratio: {balance_ratio:.2f}")

## 3. Visualize Class Distribution

In [None]:
# Bar chart of the per-class image counts computed in the overview cell.
classes = ['Real', 'AI-Generated']
counts = [real_count, ai_count]
colors = ['#2ecc71', '#e74c3c']

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(classes, counts, color=colors, alpha=0.7, edgecolor='black')
ax.set_ylabel('Number of Images', fontsize=12)
ax.set_title('Dataset Class Distribution', fontsize=14, fontweight='bold')
ax.grid(axis='y', alpha=0.3)

# Annotate each bar with its integer height, centred just above the top edge.
for rect in bars:
    bar_top = rect.get_height()
    bar_center = rect.get_x() + rect.get_width()/2.
    ax.text(
        bar_center,
        bar_top,
        f'{int(bar_top)}',
        ha='center',
        va='bottom',
        fontsize=12,
        fontweight='bold',
    )

plt.tight_layout()
plt.show()

# Check if balanced: warn when the gap between the two class counts exceeds
# 20% of all images.
# NOTE(review): the readiness checklist at the end of this notebook uses a
# different criterion (min/max ratio >= 0.8), so the two verdicts can
# disagree for the same dataset — confirm which threshold is intended.
if total_images == 0:
    # Avoid ZeroDivisionError; balance is undefined without any images.
    print("⚠️  Warning: Dataset is empty; class balance cannot be assessed.")
elif abs(real_count - ai_count) / total_images > 0.2:
    print("⚠️  Warning: Dataset is imbalanced! Consider balancing classes.")
else:
    print("✓ Dataset is well-balanced!")

## 4. Sample Images

In [None]:
# Show sample images from each class
import random

def show_samples(image_paths, labels, class_label, n_samples=5):
    """Display a random selection of images from one class in a single row.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-for-index with
            ``image_paths``.
        class_label: Label value to select. 0 is titled "Real"; any other
            value is titled "AI-Generated".
        n_samples: Maximum number of images to show. Capped at the number
            of images available for the class.

    Returns:
        None. Draws a matplotlib figure, or prints a notice when there is
        nothing to display.
    """
    class_name = "Real" if class_label == 0 else "AI-Generated"

    # Get images of specific class
    class_images = [path for path, label in zip(image_paths, labels) if label == class_label]

    # Never request more images than the class contains.
    n_samples = min(n_samples, len(class_images))
    if n_samples <= 0:
        # plt.subplots cannot create a grid with zero columns, so bail out
        # with a readable message instead of raising.
        print(f"No {class_name} images to display.")
        return

    # Random sample
    samples = random.sample(class_images, n_samples)

    # squeeze=False always yields a 2-D axes array, so the single-image case
    # needs no special handling.
    fig, axes = plt.subplots(1, n_samples, figsize=(15, 3), squeeze=False)
    fig.suptitle(f"Sample {class_name} Images", fontsize=14, fontweight='bold')

    for ax, img_path in zip(axes[0], samples):
        # The context manager releases the file handle as soon as the RGB
        # copy has been handed to matplotlib.
        with Image.open(img_path) as img:
            ax.imshow(img.convert('RGB'))
        ax.axis('off')
        ax.set_title(os.path.basename(img_path), fontsize=8)

    plt.tight_layout()
    plt.show()

# Show five samples per class: real images (label 0) first, then
# AI-generated images (label 1), each preceded by a banner line.
for banner, class_id in (("Real Images:", 0), ("\nAI-Generated Images:", 1)):
    print(banner)
    show_samples(image_paths, labels, class_label=class_id, n_samples=5)

## 5. Image Size Analysis

In [None]:
# Analyze image dimensions
MAX_SIZE_SAMPLES = 100  # upper bound on how many files are opened
# NOTE(review): the reference lines assume a 224x224 model input; the config
# also carries config['data']['image_size'] — confirm the two agree.
MODEL_INPUT_SIZE = 224

widths = []
heights = []

# Sample up to MAX_SIZE_SAMPLES random images
sample_paths = random.sample(image_paths, min(MAX_SIZE_SAMPLES, len(image_paths)))

for img_path in sample_paths:
    # Only the dimensions are read, so the file would otherwise stay open;
    # the context manager closes each handle promptly.
    with Image.open(img_path) as img:
        widths.append(img.width)
        heights.append(img.height)

if not sample_paths:
    # min()/max() raise on empty lists and the mean would be NaN, so skip
    # the statistics and the plot when there is nothing to measure.
    print("No images available for size analysis.")
else:
    print(f"Image Size Statistics (from {len(sample_paths)} samples):")
    print(f"  Width: {np.mean(widths):.0f} ± {np.std(widths):.0f} px")
    print(f"  Height: {np.mean(heights):.0f} ± {np.std(heights):.0f} px")
    # Min/Max pair the extreme width with the extreme height; the two
    # values need not come from the same image.
    print(f"  Min: {min(widths)}x{min(heights)}")
    print(f"  Max: {max(widths)}x{max(heights)}")

    # Scatter plot
    plt.figure(figsize=(8, 6))
    plt.scatter(widths, heights, alpha=0.5)
    plt.xlabel('Width (px)', fontsize=12)
    plt.ylabel('Height (px)', fontsize=12)
    plt.title('Image Dimensions Distribution', fontsize=14, fontweight='bold')
    plt.grid(alpha=0.3)
    plt.axhline(
        y=MODEL_INPUT_SIZE,
        color='r',
        linestyle='--',
        label=f'Model input size ({MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE})',
    )
    plt.axvline(x=MODEL_INPUT_SIZE, color='r', linestyle='--')
    plt.legend()
    plt.tight_layout()
    plt.show()

## 6. Train/Val/Test Split Preview

In [None]:
# Split dataset into train / validation / test subsets using the fractions
# and seed stored in the config. split_dataset returns six values: a
# (paths, labels) pair for each of the three subsets, in that order.
split_cfg = config['data']
(train_paths, train_labels,
 val_paths, val_labels,
 test_paths, test_labels) = split_dataset(
    image_paths,
    labels,
    train_split=split_cfg['train_split'],
    val_split=split_cfg['val_split'],
    test_split=split_cfg['test_split'],
    random_seed=config['seed'],
)

# Visualize split: grouped bars showing, for every subset, how many real
# (label 0) and AI-generated (label 1) images it received.
fig, ax = plt.subplots(figsize=(10, 6))

splits = ['Train', 'Validation', 'Test']
subset_labels = (train_labels, val_labels, test_labels)
real_counts = [subset.count(0) for subset in subset_labels]
ai_counts = [subset.count(1) for subset in subset_labels]

x = np.arange(len(splits))
width = 0.35
half_width = width/2

# Real bars sit half a bar-width left of each tick, AI-generated bars half a
# bar-width right, so each pair is centred on its tick.
bars1 = ax.bar(x - half_width, real_counts, width, label='Real', color='#2ecc71', alpha=0.7)
bars2 = ax.bar(x + half_width, ai_counts, width, label='AI-Generated', color='#e74c3c', alpha=0.7)

ax.set_ylabel('Number of Images', fontsize=12)
ax.set_title('Train/Val/Test Split Distribution', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(splits)
ax.legend()
ax.grid(axis='y', alpha=0.3)


def add_labels(bars):
    """Write each bar's integer height just above its top edge.

    Draws on the module-level ``ax`` created earlier in this cell.
    """
    for rect in bars:
        top = rect.get_height()
        center_x = rect.get_x() + rect.get_width()/2.
        ax.text(center_x, top, f'{int(top)}', ha='center', va='bottom', fontsize=10)


# Add value labels
for bar_group in (bars1, bars2):
    add_labels(bar_group)

plt.tight_layout()
plt.show()

## 7. Data Augmentation Preview

In [None]:
from data_loader import get_transforms
import albumentations as A

# Pick one random image and load it as an RGB array.
sample_path = random.choice(image_paths)
image = np.array(Image.open(sample_path).convert('RGB'))

# Build the training-time augmentation pipeline from the config.
transform = get_transforms(config, is_training=True)

# 2x3 grid: the first panel shows the original, the other five show
# independently augmented versions of the same image.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()
original_ax, *augmented_axes = axes

# Original
original_ax.imshow(image)
original_ax.set_title('Original', fontsize=12, fontweight='bold')
original_ax.axis('off')

# Augmented versions
# NOTE(review): imshow assumes the transform returns an image-like HxWxC
# array; if get_transforms normalizes or converts to a tensor, this preview
# may not render as intended — confirm against data_loader.
for i, aug_ax in enumerate(augmented_axes, start=1):
    augmented = transform(image=image)['image']
    aug_ax.imshow(augmented)
    aug_ax.set_title(f'Augmented {i}', fontsize=12)
    aug_ax.axis('off')

plt.suptitle('Data Augmentation Examples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# NOTE(review): this list is hard-coded text, not derived from `transform`;
# verify it still matches what get_transforms actually applies.
print("Augmentations applied:")
for description in (
    "  - Random horizontal flip",
    "  - Random rotation",
    "  - Color jittering (brightness, contrast, saturation)",
    "  - Gaussian noise",
    "  - Random blur",
):
    print(description)

## 8. Ready to Train?

If all checks pass, you're ready to train!

In [None]:
# Summarize the earlier cells as a checklist. Relies on `total_images`,
# `real_count`, `ai_count` (dataset overview) and `widths`, `heights`
# (image size analysis) already being defined.
print("\n" + "="*60)
print("Dataset Readiness Checklist")
print("="*60)

checks = []

# Check 1: Sufficient data
if total_images >= 500:
    print("✓ Sufficient images (>= 500)")
    checks.append(True)
else:
    print(f"✗ Not enough images ({total_images} < 500)")
    checks.append(False)

# Check 2: Balanced classes
# Guard the division: when both classes are empty the ratio is reported as
# 0.0 (imbalanced) instead of raising ZeroDivisionError.
larger_class = max(real_count, ai_count)
balance = min(real_count, ai_count) / larger_class if larger_class else 0.0
if balance >= 0.8:
    print("✓ Classes are balanced (ratio >= 0.8)")
    checks.append(True)
else:
    print(f"✗ Classes imbalanced (ratio: {balance:.2f})")
    checks.append(False)

# Check 3: Image quality
# `widths + heights` concatenates the two lists, so the average is taken
# over every sampled dimension. An empty sample would give NaN, so it is
# reported separately.
sampled_dims = widths + heights
avg_size = np.mean(sampled_dims) if sampled_dims else 0.0
if not sampled_dims:
    print("⚠️  No image dimensions were sampled; resolution not checked")
    checks.append(True)  # Advisory only, same as the small-image case
elif avg_size >= 224:
    print("✓ Image resolution adequate")
    checks.append(True)
else:
    print(f"⚠️  Images are small (avg: {avg_size:.0f}px)")
    checks.append(True)  # Still passable

print("="*60)

if all(checks):
    print("\n🚀 Dataset is ready! You can start training:")
    print("   python src/train.py")
else:
    print("\n⚠️  Please address the issues above before training.")
    print("   You can still try training, but results may be suboptimal.")