# 📊 Data Exploration - EPRA Violence Classification

This notebook provides comprehensive data exploration for the EPRA violence classification project. We'll analyze the dataset structure, distribution, and characteristics to inform our modeling decisions.

## Objectives
1. **Load and inspect the dataset**
2. **Analyze class distribution and balance**
3. **Explore image characteristics (size, format, quality)**
4. **Identify potential data issues**
5. **Visualize sample images from each class**
6. **Statistical analysis of the dataset**

In [None]:
# Setup and imports
# Statement order in this cell matters: the sys.path change has to happen
# before the project imports further down.
import sys
from pathlib import Path


# Add src to path
# Appends "<parent of the current working directory>/src" to the import path.
# NOTE(review): presumably this is where the `epra_classifier` package lives
# and the notebook is started from a folder one level below the project root;
# the result depends on the kernel's working directory — confirm.
sys.path.append(str(Path.cwd().parent / "src"))

import warnings

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


# Installs a blanket "ignore" filter: every Python warning raised after this
# line (including any emitted by the imports below) is hidden for the session.
warnings.filterwarnings("ignore")

# EPRA imports
from epra_classifier.data.loaders import ViolenceDataset, get_dataset_statistics
from epra_classifier.utils.config import DataConfig


# Plotting configuration
# Global defaults for all figures created later in the notebook: matplotlib
# style sheet, seaborn colour palette, and default figure size.
plt.style.use("seaborn-v0_8")
sns.set_palette("husl")
plt.rcParams["figure.figsize"] = (12, 8)

print("✅ Setup completed successfully!")

## 📁 Dataset Configuration

Configure the paths and parameters for data exploration.

In [None]:
# Dataset configuration
# Both locations are relative paths, so they resolve against the kernel's
# working directory.
DATA_DIR = Path("../data")  # Adjust this path to your dataset
OUTPUT_DIR = Path("../outputs/data_exploration")

# Create the output folder (and any missing parents); a no-op if it exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Create data configuration
# Image size, batch size and per-channel normalisation values handed to the
# dataset loader.
config = DataConfig(
    image_size=(224, 224),
    batch_size=32,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Echo the effective settings, one per line.
print(
    f"📂 Data directory: {DATA_DIR}",
    f"📊 Output directory: {OUTPUT_DIR}",
    f"⚙️ Configuration: {config.image_size} images",
    sep="\n",
)

## 🔍 Dataset Loading and Basic Statistics

In [None]:
# Load dataset
# Build the dataset from DATA_DIR using the configuration defined above.
# If construction fails, a hint is printed and the exception is re-raised so
# that execution stops at this cell. Swallowing the error would let later
# cells run into a confusing NameError on `dataset`, or silently reuse a stale
# `dataset` object left in the kernel by a previous run.
try:
    dataset = ViolenceDataset(
        root_dir=DATA_DIR, config=config, classification_type="auto"
    )
except Exception as e:
    print(f"❌ Error loading dataset: {e}")
    print("Please check the DATA_DIR path and ensure the dataset exists.")
    raise
else:
    # Summary is printed only after a successful load, so a failure in these
    # accessors is not misreported as a loading error.
    print("✅ Dataset loaded successfully!")
    print(f"📊 Total samples: {len(dataset)}")
    print(f"🏷️ Number of classes: {len(dataset.classes)}")
    print(f"📋 Classes: {dataset.classes}")
    print(f"🔄 Classification type: {dataset.get_classification_type()}")

In [None]:
# Get comprehensive dataset statistics
stats = get_dataset_statistics(dataset)

# Report header followed by a 40-character rule.
print("📈 Dataset Statistics:")
print("=" * 40)

for key, value in stats.items():
    # Non-dict entries fit on a single "name: value" line.
    if not isinstance(value, dict):
        print(f"{key}: {value}")
        continue

    # Dict entries get a heading line, then one indented line per item.
    print(f"{key}:")
    for k, v in value.items():
        print(f"  {k}: {v}")

## 📊 Class Distribution Analysis

In [None]:
# Analyze class distribution
# Builds a per-class count/percentage table and reports how unbalanced the
# classes are (largest class size divided by smallest class size).
class_counts = dataset.get_class_counts()
total_samples = sum(class_counts.values())

# Fail early with a clear message: with no samples the percentage computation
# would divide by zero, and with no classes max()/min() below would raise on
# an empty sequence.
if total_samples == 0:
    raise ValueError(
        "No samples found in the dataset; cannot compute the class "
        "distribution. Check DATA_DIR and the dataset contents."
    )

# Ratio thresholds used to label the imbalance level below.
SEVERE_IMBALANCE_RATIO = 3
MODERATE_IMBALANCE_RATIO = 1.5

# Create distribution DataFrame (one row per class)
dist_df = pd.DataFrame(
    [
        {
            "Class": class_name,
            "Count": count,
            "Percentage": (count / total_samples) * 100,
        }
        for class_name, count in class_counts.items()
    ]
)

print("📊 Class Distribution:")
print(dist_df.to_string(index=False))

# Calculate imbalance ratio; an empty class makes the ratio infinite.
max_count = max(class_counts.values())
min_count = min(class_counts.values())
imbalance_ratio = max_count / min_count if min_count > 0 else float("inf")

print(f"\n⚖️ Imbalance ratio: {imbalance_ratio:.2f}")

# An infinite ratio is only actionable if the reader knows which classes
# have no samples, so name them explicitly.
empty_classes = [name for name, n in class_counts.items() if n == 0]
if empty_classes:
    print(f"⚠️ Classes with no samples: {empty_classes}")

if imbalance_ratio > SEVERE_IMBALANCE_RATIO:
    print("⚠️ Significant class imbalance detected!")
elif imbalance_ratio > MODERATE_IMBALANCE_RATIO:
    print("⚠️ Moderate class imbalance detected.")
else:
    print("✅ Classes are relatively balanced.")

In [None]:
# Visualize class distribution
# Two panels side by side: absolute counts as bars (left) and relative
# shares as a pie chart (right). The figure is also written to OUTPUT_DIR.
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
ax_count, ax_share = axes

# One colour per class, shared by both panels.
colors = sns.color_palette("husl", len(dist_df))
plot_path = OUTPUT_DIR / "class_distribution.png"

# Bar plot
bars = ax_count.bar(dist_df["Class"], dist_df["Count"], color=colors)
ax_count.set_title("Class Distribution (Count)", fontsize=14, fontweight="bold")
ax_count.set_xlabel("Class")
ax_count.set_ylabel("Number of Samples")
ax_count.tick_params(axis="x", rotation=45)

# Add value labels on bars, lifted by 1% of the tallest bar so the text
# does not touch the bar top.
label_offset = 0.01 * max(dist_df["Count"])
for bar, count in zip(bars, dist_df["Count"]):
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() + label_offset
    ax_count.text(
        label_x,
        label_y,
        str(count),
        ha="center",
        va="bottom",
        fontweight="bold",
    )

# Pie chart
wedges, texts, autotexts = ax_share.pie(
    dist_df["Count"],
    labels=dist_df["Class"],
    autopct="%1.1f%%",
    colors=colors,
    startangle=90,
)
ax_share.set_title("Class Distribution (Percentage)", fontsize=14, fontweight="bold")

plt.tight_layout()
fig.savefig(plot_path, dpi=300, bbox_inches="tight")
plt.show()

print(f"💾 Class distribution plot saved to {plot_path}")