# Sports Image Dataset Exploration

## EE4745 Neural Networks Final Project

This notebook provides a comprehensive exploration and analysis of the sports image dataset used for the classification tasks in this project.

### Objectives:
- Dataset overview and structure analysis
- Class distribution analysis
- Sample image visualization
- Data augmentation demonstration
- Dataset validation and quality assessment

---

## 1. Setup and Imports

Import necessary libraries and set up the environment for data exploration.

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
import torchvision.transforms as transforms
from collections import defaultdict, Counter
import warnings
warnings.filterwarnings('ignore')

# Add src directory to path
sys.path.append('../src')

# Import custom modules
from dataset.sports_dataset import SportsDataset, get_dataloaders
from training.utils import set_seed

# Plot styling plus a global seed so every later cell is reproducible
plt.style.use('default')
sns.set_palette('husl')
set_seed(42)

# Pandas / matplotlib display preferences
pd.set_option('display.max_columns', None)
plt.rcParams.update({
    'figure.dpi': 100,
    'savefig.dpi': 150,
    'font.size': 10,
})

# Report the environment that the rest of the notebook will run in
device_label = 'GPU' if torch.cuda.is_available() else 'CPU'
print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {device_label}")

## 2. Dataset Overview

Let's start by examining the basic structure of our sports image dataset.

In [None]:
# Dataset configuration
DATA_DIR = '../data'
IMAGE_SIZE = 32  # Start with 32x32 images
BATCH_SIZE = 32

# Inspect the on-disk layout only when the data directory is actually present
if os.path.exists(DATA_DIR):
    print(f"Data directory found: {DATA_DIR}")

    # Every sub-directory of DATA_DIR is reported as a split
    splits = [entry for entry in os.listdir(DATA_DIR)
              if os.path.isdir(os.path.join(DATA_DIR, entry))]
    print(f"Available splits: {splits}")

    # Class names are the sub-directories of the training split
    train_dir = os.path.join(DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        print(f"Training directory not found: {train_dir}")
    else:
        classes = sorted(
            entry for entry in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, entry))
        )
        print(f"\nClasses in dataset ({len(classes)}):")
        for position, cls in enumerate(classes):
            print(f"  {position}: {cls}")
else:
    print(f"Data directory {DATA_DIR} not found!")
    print("Please ensure the data symlink is properly configured.")

### Dataset Statistics

Let's analyze the size and distribution of our dataset across different splits and classes.

In [None]:
def analyze_dataset_structure(data_dir, splits=('train', 'valid'), classes=None):
    """Collect per-split image counts and sampled image properties.

    Walks ``<data_dir>/<split>/<class_name>`` for every requested split and
    class, counting files whose name ends in ``.jpg``/``.jpeg``/``.png``
    (case-insensitive). The first five such files of each class, in sorted
    order, are opened to record their size and format.

    Args:
        data_dir: Root directory of the dataset.
        splits: Split sub-directories to analyze. Splits that do not exist on
            disk are skipped. Defaults to ``('train', 'valid')``.
        classes: Class names to look for. Defaults to ``SportsDataset.CLASSES``.

    Returns:
        A dict with these keys:
        - ``'splits'``: maps each analyzed split name to a dict holding
          ``'total_images'``, ``'class_distribution'`` (class -> count),
          ``'image_sizes'`` (sampled ``(w, h)`` tuples) and ``'file_formats'``
          (a ``Counter`` of PIL format names).
        - ``'total_images'``: the sum over the analyzed splits.
        - ``'classes'``: the class names that were used.
    """
    if classes is None:
        classes = SportsDataset.CLASSES

    stats = {
        'splits': {},
        'total_images': 0,
        'classes': classes,
    }

    for split in splits:
        split_dir = os.path.join(data_dir, split)
        if not os.path.exists(split_dir):
            continue

        split_stats = {
            'total_images': 0,
            'class_distribution': {},
            'image_sizes': [],
            'file_formats': Counter(),
        }

        for class_name in classes:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                # Record missing classes explicitly so every split lists every class
                split_stats['class_distribution'][class_name] = 0
                continue

            # Sorted so the sampled files do not depend on filesystem listing order
            images = sorted(f for f in os.listdir(class_dir)
                            if f.lower().endswith(('.jpg', '.jpeg', '.png')))

            split_stats['class_distribution'][class_name] = len(images)
            split_stats['total_images'] += len(images)

            # Open only a handful of files per class to keep this pass cheap
            for img_name in images[:5]:
                img_path = os.path.join(class_dir, img_name)
                try:
                    with Image.open(img_path) as img:
                        split_stats['image_sizes'].append(img.size)
                        split_stats['file_formats'][img.format] += 1
                except Exception as e:
                    # Best-effort sampling: report unreadable files and keep going
                    print(f"Error loading {img_path}: {e}")

        stats['splits'][split] = split_stats
        stats['total_images'] += split_stats['total_images']

    return stats

# Run the structural analysis over DATA_DIR
print("Analyzing dataset structure...")
dataset_stats = analyze_dataset_structure(DATA_DIR)

# Summarize the results, split by split
rule = "=" * 50
print("\n" + rule)
print("DATASET STATISTICS")
print(rule)

print(f"Total images: {dataset_stats['total_images']:,}")
print(f"Number of classes: {len(dataset_stats['classes'])}")
print(f"Classes: {', '.join(dataset_stats['classes'])}")

for split_name, split_data in dataset_stats['splits'].items():
    split_total = split_data['total_images']
    print(f"\n{split_name.upper()} SET:")
    print(f"  Total images: {split_total:,}")
    print("  Images per class:")
    for class_name, count in split_data['class_distribution'].items():
        share = 0 if split_total <= 0 else (count / split_total) * 100
        print(f"    {class_name:12}: {count:4d} ({share:5.1f}%)")

    # Image format statistics
    format_counts = split_data['file_formats']
    if format_counts:
        print(f"  File formats: {dict(format_counts)}")

    # Image size statistics (only a few files per class were opened)
    sampled_sizes = split_data['image_sizes']
    if sampled_sizes:
        distinct_sizes = list(set(sampled_sizes))
        more = '...' if len(distinct_sizes) > 5 else ''
        print(f"  Sample image sizes: {distinct_sizes[:5]}{more}")

## 3. Class Distribution Analysis

Visualize the distribution of images across different sports classes.

In [None]:
# Create visualizations for class distribution
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Dataset Class Distribution Analysis', fontsize=16, fontweight='bold')

# Prepare data for visualization. Both count lists are built from the same
# class list, so they always have equal length and order even when one split
# is missing on disk (which would otherwise break the grouped bar plot).
train_dist = dataset_stats['splits'].get('train', {}).get('class_distribution', {})
valid_dist = dataset_stats['splits'].get('valid', {}).get('class_distribution', {})

classes = list(train_dist) or list(valid_dist)
train_counts = [train_dist.get(c, 0) for c in classes]
valid_counts = [valid_dist.get(c, 0) for c in classes]


def _plot_class_bars(ax, class_names, counts, title):
    """Draw a per-class bar chart of ``counts`` on ``ax`` with value labels."""
    positions = range(len(class_names))
    ax.bar(positions, counts, color=sns.color_palette('husl', len(class_names)))
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Sports Classes')
    ax.set_ylabel('Number of Images')
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.grid(True, alpha=0.3)

    # Lift each label by 1% of the tallest bar so it sits just above its bar
    offset = max(counts, default=0) * 0.01
    for i, v in enumerate(counts):
        ax.text(i, v + offset, str(v), ha='center', va='bottom', fontsize=9)


def _print_class_balance(label, counts):
    """Print mean, std, coefficient of variation and min/max ratio of ``counts``."""
    mean = np.mean(counts)
    std = np.std(counts)
    cv = std / mean if mean > 0 else 0
    largest = max(counts)

    print(f"{label}:")
    print(f"  Mean images per class: {mean:.1f}")
    print(f"  Standard deviation: {std:.1f}")
    print(f"  Coefficient of variation: {cv:.3f}")
    # An all-zero split has no meaningful min/max ratio (and would divide by 0)
    if largest > 0:
        print(f"  Balance ratio (min/max): {min(counts) / largest:.3f}")
    else:
        print("  Balance ratio (min/max): n/a (no images)")


# 1-2. Per-split bar plots
_plot_class_bars(axes[0, 0], classes, train_counts, 'Training Set - Class Distribution')
_plot_class_bars(axes[0, 1], classes, valid_counts, 'Validation Set - Class Distribution')

# 3. Combined comparison (grouped bar plot)
x = np.arange(len(classes))
width = 0.35

axes[1, 0].bar(x - width/2, train_counts, width, label='Training', alpha=0.8)
axes[1, 0].bar(x + width/2, valid_counts, width, label='Validation', alpha=0.8)
axes[1, 0].set_title('Training vs Validation - Class Distribution', fontweight='bold')
axes[1, 0].set_xlabel('Sports Classes')
axes[1, 0].set_ylabel('Number of Images')
axes[1, 0].set_xticks(x)
axes[1, 0].set_xticklabels(classes, rotation=45, ha='right')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 4. Distribution pie chart (training set); pie() cannot draw an all-zero series
if sum(train_counts) > 0:
    axes[1, 1].pie(train_counts, labels=classes, autopct='%1.1f%%', startangle=90)
else:
    axes[1, 1].text(0.5, 0.5, 'No training images found', ha='center', va='center',
                    transform=axes[1, 1].transAxes)
    axes[1, 1].axis('off')
axes[1, 1].set_title('Training Set - Class Distribution (Pie Chart)', fontweight='bold')

plt.tight_layout()
plt.show()

# Print class balance analysis
print("\nCLASS BALANCE ANALYSIS:")
print("-" * 30)

if train_dist:
    _print_class_balance("Training set", train_counts)

if valid_dist:
    print()
    _print_class_balance("Validation set", valid_counts)

## 4. Sample Image Visualization

Display randomly sampled training images from each sports class to get a feel for the visual characteristics of the dataset.

In [None]:
def display_sample_images(data_dir, classes, samples_per_class=3, image_size=(64, 64)):
    """Show a grid of randomly sampled training images, one row per class.

    Images are read from ``<data_dir>/train/<class_name>``. For each class, up
    to ``samples_per_class`` files ending in ``.jpg``/``.jpeg``/``.png`` are
    chosen with a fixed seed (42). Each one is converted to RGB, resized to
    ``image_size`` for display, and titled with its original dimensions.
    Files that fail to load are drawn as an "Error Loading Image" placeholder.

    Args:
        data_dir: Dataset root directory.
        classes: Class names (sub-directories of ``train``) to show, one row each.
        samples_per_class: Number of images (columns) per class.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``.
    """
    num_classes = len(classes)
    if num_classes == 0 or samples_per_class < 1:
        print("Nothing to display: no classes or samples requested.")
        return

    # squeeze=False always returns a 2-D axes array, so axes[row, col] is valid
    # for every grid shape, including a single class and/or a single column.
    fig, axes = plt.subplots(num_classes, samples_per_class,
                             figsize=(samples_per_class * 3, num_classes * 2.5),
                             squeeze=False)
    fig.suptitle('Sample Images from Each Sports Class', fontsize=16, fontweight='bold')

    for class_idx, class_name in enumerate(classes):
        row_axes = axes[class_idx]
        # Blank the whole row first. Cells that stay empty (missing class
        # directory, or fewer images than requested) then render without ticks.
        for ax in row_axes:
            ax.axis('off')

        class_dir = os.path.join(data_dir, 'train', class_name)
        if not os.path.exists(class_dir):
            continue

        # Sorted so the seeded sample does not depend on filesystem listing order
        images = sorted(f for f in os.listdir(class_dir)
                        if f.lower().endswith(('.jpg', '.jpeg', '.png')))
        if not images:
            continue

        # A private RandomState(42) draws exactly what np.random.seed(42) +
        # np.random.choice would, without resetting the notebook's global RNG.
        rng = np.random.RandomState(42)
        sampled_images = rng.choice(images, min(samples_per_class, len(images)),
                                    replace=False)

        for img_idx, img_name in enumerate(sampled_images):
            img_path = os.path.join(class_dir, img_name)
            ax = row_axes[img_idx]
            try:
                with Image.open(img_path) as img:
                    width, height = img.size
                    rgb = img if img.mode == 'RGB' else img.convert('RGB')
                    # resize() returns an in-memory copy, usable after the file closes
                    img_resized = rgb.resize(image_size, Image.Resampling.LANCZOS)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                ax.text(0.5, 0.5, 'Error\nLoading\nImage',
                        ha='center', va='center', transform=ax.transAxes)
                ax.set_title(f'{class_name} (Error)', fontsize=10)
                continue

            ax.imshow(img_resized)
            size_label = f'({width}x{height})'
            # Only the first column carries the class name
            title = f'{class_name}\n{size_label}' if img_idx == 0 else size_label
            ax.set_title(title, fontsize=10)

    plt.tight_layout()
    plt.show()

# Render four random training images for every class in the dataset
print("Displaying sample images from each class...")
display_sample_images(data_dir=DATA_DIR, classes=dataset_stats['classes'], samples_per_class=4)

### Image Properties Analysis

Let's analyze properties of the training images, including dimensions, aspect ratios, file sizes, file formats, and mean color. The analysis uses a random sample of up to 50 images per class.

In [None]:
def analyze_image_properties(data_dir, classes, max_samples=50):
    """Collect size, file-size, format and mean-color statistics for training images.

    For each class, up to ``max_samples`` files ending in
    ``.jpg``/``.jpeg``/``.png`` are sampled from
    ``<data_dir>/train/<class_name>`` with a fixed seed (42). Unreadable files
    are reported and skipped.

    Args:
        data_dir: Dataset root directory.
        classes: Class names (sub-directories of ``train``) to analyze.
        max_samples: Maximum number of images sampled per class.

    Returns:
        A dict with these entries:
        - ``'dimensions'``: list of ``(width, height)`` tuples.
        - ``'aspect_ratios'``: list of width / height values.
        - ``'file_sizes'``: list of file sizes in bytes.
        - ``'mean_colors'``: list of per-image mean RGB arrays.
        - ``'formats'``: a ``Counter`` of PIL format names.

        Every list has one entry per successfully analyzed image, in the same order.
    """
    properties = {
        'dimensions': [],
        'aspect_ratios': [],
        'file_sizes': [],
        'mean_colors': [],
        'formats': Counter()
    }

    print(f"Analyzing image properties (sampling up to {max_samples} images per class)...")

    for class_name in classes:
        class_dir = os.path.join(data_dir, 'train', class_name)
        if not os.path.exists(class_dir):
            continue

        # Sorted so the seeded sample does not depend on filesystem listing order
        images = sorted(f for f in os.listdir(class_dir)
                        if f.lower().endswith(('.jpg', '.jpeg', '.png')))
        if not images:
            continue

        # A private RandomState(42) draws exactly what np.random.seed(42) +
        # np.random.choice would, without resetting the notebook's global RNG.
        rng = np.random.RandomState(42)
        sample_count = min(max(max_samples, 0), len(images))
        sampled_images = rng.choice(images, sample_count, replace=False)

        for img_name in sampled_images:
            img_path = os.path.join(class_dir, img_name)

            try:
                with Image.open(img_path) as img:
                    width, height = img.size
                    fmt = img.format
                    rgb = img if img.mode == 'RGB' else img.convert('RGB')
                    # Averaging over height and width leaves one value per RGB channel
                    mean_color = np.asarray(rgb).mean(axis=(0, 1))
                file_size = os.path.getsize(img_path)
            except Exception as e:
                print(f"Error analyzing {img_path}: {e}")
                continue

            # Append only after every measurement succeeded, so all lists stay
            # index-aligned (one entry per successfully analyzed image).
            properties['dimensions'].append((width, height))
            properties['aspect_ratios'].append(width / height)
            properties['file_sizes'].append(file_size)
            properties['formats'][fmt] += 1
            properties['mean_colors'].append(mean_color)

    return properties

# Analyze image properties
img_properties = analyze_image_properties(DATA_DIR, dataset_stats['classes'])

# Visualize the analysis: five property plots plus a text summary panel
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Image Properties Analysis', fontsize=16, fontweight='bold')

# 1. Image dimensions scatter plot
if img_properties['dimensions']:
    widths, heights = zip(*img_properties['dimensions'])

    axes[0, 0].scatter(widths, heights, alpha=0.6, s=20)
    axes[0, 0].set_title('Image Dimensions Distribution', fontweight='bold')
    axes[0, 0].set_xlabel('Width (pixels)')
    axes[0, 0].set_ylabel('Height (pixels)')
    axes[0, 0].grid(True, alpha=0.3)

    # Points on the dashed diagonal are square images
    max_dim = max(max(widths), max(heights))
    axes[0, 0].plot([0, max_dim], [0, max_dim], 'r--', alpha=0.5, label='Square')
    axes[0, 0].legend()

# 2. Aspect ratio histogram
if img_properties['aspect_ratios']:
    axes[0, 1].hist(img_properties['aspect_ratios'], bins=30, alpha=0.7, edgecolor='black')
    axes[0, 1].axvline(1.0, color='red', linestyle='--', label='Square (1:1)')
    axes[0, 1].set_title('Aspect Ratio Distribution', fontweight='bold')
    axes[0, 1].set_xlabel('Aspect Ratio (Width/Height)')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

# 3. File size distribution
if img_properties['file_sizes']:
    file_sizes_kb = [size / 1024 for size in img_properties['file_sizes']]  # bytes -> KB
    axes[0, 2].hist(file_sizes_kb, bins=30, alpha=0.7, edgecolor='black')
    axes[0, 2].set_title('File Size Distribution', fontweight='bold')
    axes[0, 2].set_xlabel('File Size (KB)')
    axes[0, 2].set_ylabel('Frequency')
    axes[0, 2].grid(True, alpha=0.3)

# 4. Mean RGB color distribution (one overlaid histogram per channel)
if img_properties['mean_colors']:
    mean_colors = np.array(img_properties['mean_colors'])

    colors = ['red', 'green', 'blue']
    labels = ['Red', 'Green', 'Blue']

    for i, (color, label) in enumerate(zip(colors, labels)):
        axes[1, 0].hist(mean_colors[:, i], bins=20, alpha=0.7,
                        color=color, label=label, edgecolor='black')

    axes[1, 0].set_title('Mean RGB Channel Distribution', fontweight='bold')
    axes[1, 0].set_xlabel('Mean Pixel Value (0-255)')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

# 5. Format distribution pie chart
if img_properties['formats']:
    formats = list(img_properties['formats'].keys())
    counts = list(img_properties['formats'].values())

    axes[1, 1].pie(counts, labels=formats, autopct='%1.1f%%', startangle=90)
    axes[1, 1].set_title('File Format Distribution', fontweight='bold')

# 6. Summary statistics, rendered as text (mean ± standard deviation)
axes[1, 2].axis('off')
summary_text = "Dataset Statistics Summary\n" + "=" * 25 + "\n\n"

if img_properties['dimensions']:
    widths, heights = zip(*img_properties['dimensions'])
    summary_text += "Dimensions:\n"
    summary_text += f"  Width: {np.mean(widths):.1f}±{np.std(widths):.1f}px\n"
    summary_text += f"  Height: {np.mean(heights):.1f}±{np.std(heights):.1f}px\n\n"

if img_properties['aspect_ratios']:
    ratios = img_properties['aspect_ratios']
    summary_text += f"Aspect Ratio: {np.mean(ratios):.2f}±{np.std(ratios):.2f}\n\n"

if img_properties['file_sizes']:
    sizes_kb = [s / 1024 for s in img_properties['file_sizes']]
    summary_text += f"File Size: {np.mean(sizes_kb):.1f}±{np.std(sizes_kb):.1f} KB\n\n"

if img_properties['mean_colors']:
    mean_colors = np.array(img_properties['mean_colors'])
    summary_text += "Mean RGB:\n"
    summary_text += f"  R: {np.mean(mean_colors[:, 0]):.1f}±{np.std(mean_colors[:, 0]):.1f}\n"
    summary_text += f"  G: {np.mean(mean_colors[:, 1]):.1f}±{np.std(mean_colors[:, 1]):.1f}\n"
    summary_text += f"  B: {np.mean(mean_colors[:, 2]):.1f}±{np.std(mean_colors[:, 2]):.1f}\n"

axes[1, 2].text(0.1, 0.9, summary_text, transform=axes[1, 2].transAxes,
                fontsize=11, verticalalignment='top', fontfamily='monospace',
                bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

plt.tight_layout()
plt.show()

print(f"\nAnalyzed {len(img_properties['dimensions'])} images total.")

## 5. Data Augmentation Demonstration

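Demonstrate several common data augmentation techniques and their visual effects on a sample training image. The transforms that `SportsDataset` actually applies are compared in the next section.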

In [None]:
def demonstrate_augmentations(data_dir, class_name='basketball', image_size=64):
    """Visualize several torchvision augmentations applied to one training image.

    Loads the first file (in sorted order) ending in ``.jpg``/``.jpeg``/``.png``
    from ``<data_dir>/train/<class_name>``. It then plots a grid with one row
    per augmentation and three samples per row; sample ``col`` is produced
    after ``torch.manual_seed(42 + col)``.

    Args:
        data_dir: Dataset root directory.
        class_name: Class sub-directory to take the sample image from.
        image_size: Output side length (pixels) of every transformed image.

    Returns:
        None. Prints a message and returns early if the class directory is
        missing or contains no images.
    """
    class_dir = os.path.join(data_dir, 'train', class_name)
    if not os.path.isdir(class_dir):
        print(f"Class directory not found: {class_dir}")
        return

    # Sorted so the chosen sample image does not depend on filesystem order
    images = sorted(f for f in os.listdir(class_dir)
                    if f.lower().endswith(('.jpg', '.jpeg', '.png')))
    if not images:
        print(f"No images found in {class_dir}")
        return

    # Load inside a context manager so the file handle is closed immediately;
    # convert() returns an independent in-memory image.
    sample_image_path = os.path.join(class_dir, images[0])
    with Image.open(sample_image_path) as img:
        original_image = img.convert('RGB')

    # Define different augmentation transforms
    augmentations = {
        'Original': transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ]),
        'Horizontal Flip': transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.ToTensor()
        ]),
        'Rotation (±15°)': transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomRotation(degrees=15),
            transforms.ToTensor()
        ]),
        'Color Jitter': transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.ToTensor()
        ]),
        'Random Crop': transforms.Compose([
            # Upscale by 20% first so the crop has room to move
            transforms.Resize((int(image_size * 1.2), int(image_size * 1.2))),
            transforms.RandomCrop(image_size),
            transforms.ToTensor()
        ]),
        'Combined': transforms.Compose([
            transforms.Resize((int(image_size * 1.2), int(image_size * 1.2))),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor()
        ])
    }

    num_augs = len(augmentations)
    num_samples = 3  # Show 3 samples of each augmentation

    # squeeze=False keeps axes 2-D regardless of the number of augmentations
    fig, axes = plt.subplots(num_augs, num_samples,
                             figsize=(num_samples * 3, num_augs * 2.5), squeeze=False)
    fig.suptitle(f'Data Augmentation Examples\nOriginal Image: {class_name}',
                 fontsize=14, fontweight='bold')

    for row, (aug_name, transform) in enumerate(augmentations.items()):
        for col in range(num_samples):
            # Seed per column so each sample is reproducible across runs
            torch.manual_seed(42 + col)
            transformed = transform(original_image)

            # ToTensor yields (C, H, W); matplotlib wants (H, W, C) for RGB
            if transformed.shape[0] == 3:
                img_display = np.clip(transformed.permute(1, 2, 0).numpy(), 0, 1)
            else:
                img_display = transformed.squeeze().numpy()

            ax = axes[row, col]
            ax.imshow(img_display)
            ax.axis('off')

            # Label each row once, on its first column
            if col == 0:
                ax.set_title(aug_name, fontweight='bold', fontsize=11)

    plt.tight_layout()
    plt.show()

# Run the augmentation demo on the default sample class
print("Demonstrating data augmentation techniques...")
demonstrate_augmentations(data_dir=DATA_DIR)

### Training vs Validation Transforms Comparison

Compare the actual transforms used in the SportsDataset for training and validation.

In [None]:
# Build the real train/validation datasets so their configured transforms can be inspected
shared_kwargs = dict(root_dir=DATA_DIR, image_size=IMAGE_SIZE)
train_dataset = SportsDataset(split='train', augment=True, **shared_kwargs)
val_dataset = SportsDataset(split='valid', augment=False, **shared_kwargs)

for heading, ds in (("Training Dataset Transform:", train_dataset),
                    ("\nValidation Dataset Transform:", val_dataset)):
    print(heading)
    print(ds.transform)

# Show examples of actual training transforms
def show_training_samples(dataset, num_samples=6, title="Training Dataset Samples"):
    """Plot ``num_samples`` randomly chosen ``(image, label)`` pairs from ``dataset``.

    Samples are laid out on a grid of up to two rows. Indices are drawn with
    replacement after ``torch.manual_seed(42)``, so repeated calls show the
    same indices (and, for augmenting datasets, the same random transforms).

    Args:
        dataset: Indexable dataset returning ``(image_tensor, label)`` and
            exposing a ``CLASSES`` sequence that maps labels to names.
        num_samples: Number of samples to draw and display.
        title: Figure title.
    """
    if num_samples < 1 or len(dataset) == 0:
        print("Nothing to display: dataset is empty or num_samples < 1.")
        return

    # Enough columns for every sample. The previous layout hard-coded three
    # columns and only worked when num_samples == 6.
    nrows = min(2, num_samples)
    ncols = (num_samples + nrows - 1) // nrows
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 6), squeeze=False)
    fig.suptitle(title, fontsize=14, fontweight='bold')

    # Hide every cell up front; cells that receive an image are drawn below
    for ax in axes.flat:
        ax.axis('off')

    # Denormalization constants, hoisted out of the loop. Assumes the dataset
    # normalizes with the ImageNet mean/std; TODO confirm against SportsDataset.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Set seed for reproducible sampling (and reproducible augmentations)
    torch.manual_seed(42)

    for i in range(num_samples):
        idx = torch.randint(0, len(dataset), (1,)).item()
        image, label = dataset[idx]
        class_name = dataset.CLASSES[label]

        # Undo normalization and clamp into the displayable [0, 1] range
        denormalized = torch.clamp(image * std + mean, 0, 1)

        row, col = divmod(i, ncols)
        ax = axes[row, col]
        ax.imshow(denormalized.permute(1, 2, 0))
        ax.set_title(f'{class_name}', fontsize=10)

    plt.tight_layout()
    plt.show()

# Visual check of both splits, with and without augmentation
for ds, caption in (
    (train_dataset, "Training Dataset Samples (with augmentation)"),
    (val_dataset, "Validation Dataset Samples (no augmentation)"),
):
    show_training_samples(ds, title=caption)

# Dataset size information
n_train = len(train_dataset)
n_val = len(val_dataset)
print("\nDataset Sizes:")
print(f"Training set: {n_train:,} images")
print(f"Validation set: {n_val:,} images")
print(f"Total: {n_train + n_val:,} images")

# Per-class counts as reported by the dataset objects themselves
train_dist = train_dataset.get_class_distribution()
val_dist = val_dataset.get_class_distribution()

print("\nClass distribution comparison:")
print(f"{'Class':<12} {'Train':<8} {'Valid':<8} {'Ratio':<8}")
print("-" * 40)
for name in dataset_stats['classes']:
    n_tr = train_dist.get(name, 0)
    n_va = val_dist.get(name, 0)
    # Ratio is reported as inf when the class has no validation images
    ratio = float('inf') if n_va <= 0 else n_tr / n_va
    print(f"{name:<12} {n_tr:<8} {n_va:<8} {ratio:<8.1f}")

## 6. Dataset Validation and Quality Assessment

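Perform quality checks on a sample of images from the training and validation splits. The checks flag corrupted files, extreme sizes or aspect ratios, grayscale images, and very dark or very bright images.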

In [None]:
def validate_dataset_quality(data_dir, classes, sample_size=100):
    """Sample images from each split and class and flag potential quality issues.

    For each of the ``train`` and ``valid`` splits and each class, up to
    ``sample_size`` files ending in ``.jpg``/``.jpeg``/``.png`` are sampled
    with a fixed seed (42). Each sampled image is checked for:
    - a failed integrity check or an unreadable file;
    - width or height below 32 px or above 2000 px;
    - an aspect ratio outside [0.3, 3.0];
    - grayscale mode (``L``/``LA``);
    - a mean pixel value below 30 or above 225.

    Args:
        data_dir: Dataset root directory.
        classes: Class names (sub-directories of each split).
        sample_size: Maximum number of images checked per split and class.

    Returns:
        ``(issues, processed_count, total_images)``, where:
        - ``issues`` maps each issue type to a list of paths or ``(path, detail)``
          tuples. Corrupted images appear only under ``'corrupted_images'``.
        - ``processed_count`` is the number of sampled images analyzed without error.
        - ``total_images`` is the number of image files found.
    """
    print("Performing dataset quality validation...")

    issues = {
        'corrupted_images': [],
        'extremely_small_images': [],
        'extremely_large_images': [],
        'unusual_aspect_ratios': [],
        'grayscale_images': [],
        'very_dark_images': [],
        'very_bright_images': []
    }

    processed_count = 0
    total_images = 0

    for split in ['train', 'valid']:
        for class_name in classes:
            class_dir = os.path.join(data_dir, split, class_name)
            if not os.path.exists(class_dir):
                continue

            # Sorted so the seeded sample does not depend on filesystem listing order
            images = sorted(f for f in os.listdir(class_dir)
                            if f.lower().endswith(('.jpg', '.jpeg', '.png')))
            total_images += len(images)
            if not images:
                continue

            # A private RandomState(42) draws exactly what np.random.seed(42) +
            # np.random.choice would, without resetting the notebook's global RNG.
            rng = np.random.RandomState(42)
            sampled_images = rng.choice(images, min(sample_size, len(images)),
                                        replace=False)

            for img_name in sampled_images:
                img_path = os.path.join(class_dir, img_name)

                try:
                    # verify() checks integrity but leaves the image unusable,
                    # so the file is reopened separately for the pixel analysis.
                    with Image.open(img_path) as img:
                        img.verify()
                    with Image.open(img_path) as img:
                        width, height = img.size
                        mode = img.mode
                        rgb = img if mode == 'RGB' else img.convert('RGB')
                        mean_brightness = np.asarray(rgb).mean()
                except Exception as e:
                    # Any failure to read or decode marks the file as corrupted.
                    # Nothing else is recorded for it, so it is not double-counted.
                    issues['corrupted_images'].append((img_path, str(e)))
                    continue

                aspect_ratio = width / height

                if width < 32 or height < 32:
                    issues['extremely_small_images'].append((img_path, f"{width}x{height}"))

                if width > 2000 or height > 2000:
                    issues['extremely_large_images'].append((img_path, f"{width}x{height}"))

                if aspect_ratio < 0.3 or aspect_ratio > 3.0:
                    issues['unusual_aspect_ratios'].append((img_path, f"{aspect_ratio:.2f}"))

                if mode in ['L', 'LA']:  # Grayscale
                    issues['grayscale_images'].append(img_path)

                if mean_brightness < 30:  # Very dark
                    issues['very_dark_images'].append(
                        (img_path, f"brightness: {mean_brightness:.1f}")
                    )
                elif mean_brightness > 225:  # Very bright
                    issues['very_bright_images'].append(
                        (img_path, f"brightness: {mean_brightness:.1f}")
                    )

                processed_count += 1

    return issues, processed_count, total_images

# Perform quality validation
quality_issues, processed, total = validate_dataset_quality(DATA_DIR, dataset_stats['classes'])

print(f"\n{'='*60}")
print("DATASET QUALITY ASSESSMENT RESULTS")
print(f"{'='*60}")

print(f"\nProcessed {processed:,} images out of {total:,} total images")
# Guard against an empty dataset (no image files found at all)
if total > 0:
    print(f"Sample rate: {processed/total*100:.1f}%")
else:
    print("Sample rate: N/A (no images found)")

print("\nQuality Issues Found:")
print("-" * 30)


def _format_issue(item):
    """Render one issue entry, which is either a bare path or a ``(path, detail)`` tuple."""
    if isinstance(item, tuple):
        return f"  - {item[0]} ({item[1]})"
    return f"  - {item}"


for issue_type, issues in quality_issues.items():
    issue_name = issue_type.replace('_', ' ').title()
    print(f"{issue_name}: {len(issues)}")

    # List every entry when there are only a few; otherwise show the first 3
    if len(issues) > 5:
        print(f"  (showing first 3 out of {len(issues)})")
        shown = issues[:3]
    else:
        shown = issues
    for item in shown:
        print(_format_issue(item))

# Calculate overall quality score (undefined if no image could be analyzed)
total_issues = sum(len(issues) for issues in quality_issues.values())

if processed > 0:
    quality_score = max(0, 100 - (total_issues / processed * 100))
    print(f"\nOverall Dataset Quality Score: {quality_score:.1f}/100")

    if quality_score >= 90:
        print("✅ Excellent dataset quality!")
    elif quality_score >= 80:
        print("✅ Good dataset quality.")
    elif quality_score >= 70:
        print("⚠️  Acceptable dataset quality with minor issues.")
    else:
        print("❌ Dataset quality needs improvement.")
else:
    quality_score = None
    print("\nOverall Dataset Quality Score: N/A (no images could be processed)")

# Recommendations
print("\nRecommendations:")
recommendations = []

if quality_issues['corrupted_images']:
    recommendations.append("- Remove or fix corrupted images before training")

if quality_issues['extremely_small_images']:
    recommendations.append("- Consider upsampling or removing very small images")

if quality_issues['unusual_aspect_ratios']:
    recommendations.append("- Review images with unusual aspect ratios for relevance")

if quality_issues['very_dark_images'] or quality_issues['very_bright_images']:
    recommendations.append("- Consider brightness normalization during preprocessing")

if quality_issues['grayscale_images']:
    recommendations.append("- Ensure grayscale images are properly converted to RGB")

if not recommendations:
    recommendations.append("- Dataset appears to be in good condition for training")

for rec in recommendations:
    print(rec)

## 7. Data Loading Performance Test

Measure data-loading throughput for different combinations of image size, batch size, and number of worker processes.

In [None]:
import time
from torch.utils.data import DataLoader

def test_dataloader_performance(data_dir, image_sizes=(32, 64), batch_sizes=(16, 32, 64),
                                num_workers_list=(0, 2, 4), max_batches=10):
    """Time DataLoader batch loading for every combination of settings.

    For each image size, a shuffled and augmented ``SportsDataset`` training
    split is built once. It is then wrapped in a ``DataLoader`` for every
    (batch size, worker count) pair, and the first ``max_batches`` batches are
    timed. Failures are printed and the offending setting is skipped, so the
    rest of the grid still runs.

    Args:
        data_dir: Dataset root passed to ``SportsDataset`` as ``root_dir``.
        image_sizes: Iterable of image sizes to test.
        batch_sizes: Iterable of batch sizes to test.
        num_workers_list: Iterable of ``num_workers`` values to test.
        max_batches: Maximum number of batches timed per configuration.

    Returns:
        list[dict]: One entry per successful configuration, with keys
        ``image_size``, ``batch_size``, ``num_workers``, ``total_time``,
        ``time_per_batch`` (``inf`` when no batch was loaded) and
        ``batches_tested``.
    """
    print("Testing DataLoader Performance...")
    print("=" * 50)

    results = []
    # Pinned host memory only speeds up copies to a CUDA device; on a CPU-only
    # machine requesting it just triggers warnings.
    pin_memory = torch.cuda.is_available()

    for image_size in image_sizes:
        # The dataset depends only on the image size, so build it once per size
        # instead of once per (batch size, worker count) combination.
        try:
            dataset = SportsDataset(
                root_dir=data_dir,
                split='train',
                image_size=image_size,
                augment=True
            )
        except Exception as e:
            print(f"Error creating dataset (size:{image_size}): {e}")
            continue

        for batch_size in batch_sizes:
            for num_workers in num_workers_list:
                try:
                    dataloader = DataLoader(
                        dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=pin_memory
                    )

                    n_batches = min(max_batches, len(dataloader))

                    # Pull exactly n_batches batches. Breaking out of an
                    # enumerate() loop at idx >= max would fetch (and pay for)
                    # one extra, uncounted batch. Iterator creation (and so any
                    # worker start-up) is deliberately inside the timed region.
                    start_time = time.perf_counter()
                    batch_iter = iter(dataloader)
                    for _ in range(n_batches):
                        images, _labels = next(batch_iter)
                        # Ensure data is actually loaded
                        _ = images.mean()
                    total_time = time.perf_counter() - start_time
                    # Release the iterator (and its workers) outside the timed region.
                    del batch_iter

                    batch_count = n_batches
                    time_per_batch = total_time / batch_count if batch_count > 0 else float('inf')

                    results.append({
                        'image_size': image_size,
                        'batch_size': batch_size,
                        'num_workers': num_workers,
                        'total_time': total_time,
                        'time_per_batch': time_per_batch,
                        'batches_tested': batch_count
                    })

                    print(f"Size: {image_size:2d}, Batch: {batch_size:2d}, Workers: {num_workers}, "
                          f"Time/batch: {time_per_batch:.3f}s")

                except Exception as e:
                    # Best effort: report the failing configuration and keep going.
                    print(f"Error with config (size:{image_size}, batch:{batch_size}, workers:{num_workers}): {e}")

    return results

# Test different configurations
perf_results = test_dataloader_performance(DATA_DIR)

# Analyze results
if perf_results:
    # Convert to DataFrame for easier analysis
    df_results = pd.DataFrame(perf_results)

    # A larger batch holds more images, so raw time per batch always favours
    # the smallest batch size. Normalising by batch size gives the per-image
    # loading cost, which is comparable across batch sizes.
    df_results['time_per_image'] = df_results['time_per_batch'] / df_results['batch_size']

    print("\nPerformance Analysis:")
    print("-" * 30)

    # Best configuration overall (lowest loading cost per image)
    best_config = df_results.loc[df_results['time_per_image'].idxmin()]
    print("\nBest overall configuration:")
    print(f"  Image size: {int(best_config['image_size'])}")
    print(f"  Batch size: {int(best_config['batch_size'])}")
    print(f"  Num workers: {int(best_config['num_workers'])}")
    print(f"  Time per batch: {best_config['time_per_batch']:.3f}s")
    print(f"  Time per image: {best_config['time_per_image'] * 1000:.3f}ms")

    # Best for each image size (groupby iterates sizes in sorted order)
    print("\nBest configuration for each image size:")
    for size, size_data in df_results.groupby('image_size'):
        best_for_size = size_data.loc[size_data['time_per_image'].idxmin()]
        print(f"  {int(size)}x{int(size)}: batch={int(best_for_size['batch_size'])}, "
              f"workers={int(best_for_size['num_workers'])}, "
              f"time={best_for_size['time_per_batch']:.3f}s/batch "
              f"({best_for_size['time_per_image'] * 1000:.3f}ms/image)")

    # Visualize performance results (per-image cost in milliseconds)
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    for size, size_data in df_results.groupby('image_size'):
        label = f'{int(size)}x{int(size)}'

        # Performance by batch size: best worker count for each batch size
        batch_performance = size_data.groupby('batch_size')['time_per_image'].min() * 1000
        axes[0].plot(batch_performance.index, batch_performance.values,
                     marker='o', label=label)

        # Performance by number of workers: best batch size for each worker count
        worker_performance = size_data.groupby('num_workers')['time_per_image'].min() * 1000
        axes[1].plot(worker_performance.index, worker_performance.values,
                     marker='s', label=label)

    axes[0].set_title('Loading Time per Image vs Batch Size', fontweight='bold')
    axes[0].set_xlabel('Batch Size')
    axes[0].set_ylabel('Time per Image (ms)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    axes[1].set_title('Loading Time per Image vs Number of Workers', fontweight='bold')
    axes[1].set_xlabel('Number of Workers')
    axes[1].set_ylabel('Time per Image (ms)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 8. Summary and Recommendations

Summary of dataset analysis and recommendations for model training.

In [None]:
# Final report: condenses the results of the earlier sections (dataset_stats,
# img_properties, quality_issues/processed and perf_results) into one printout,
# followed by training recommendations.
print("\n" + "="*70)
print("DATASET EXPLORATION SUMMARY")
print("="*70)

splits_info = dataset_stats['splits']

# Dataset overview summary
print("\n📊 DATASET OVERVIEW:")
print(f"   • Total images: {dataset_stats['total_images']:,}")
print(f"   • Number of classes: {len(dataset_stats['classes'])}")
print(f"   • Classes: {', '.join(dataset_stats['classes'])}")

if 'train' in splits_info:
    train_total = splits_info['train']['total_images']
    print(f"   • Training images: {train_total:,}")

if 'valid' in splits_info:
    valid_total = splits_info['valid']['total_images']
    print(f"   • Validation images: {valid_total:,}")

    if 'train' in splits_info:
        split_ratio = train_total / valid_total if valid_total > 0 else float('inf')
        print(f"   • Train/Val split ratio: {split_ratio:.1f}:1")

# Class balance summary. balance_ratio stays None when there is nothing to
# compare (no train split, or every class count is 0), so min/max and std/mean
# never divide by zero. It is reused by the training-strategy section below.
print("\n⚖️  CLASS BALANCE:")
train_counts = []
if 'train' in splits_info:
    train_counts = list(splits_info['train']['class_distribution'].values())

balance_ratio = None
if train_counts and max(train_counts) > 0:
    balance_ratio = min(train_counts) / max(train_counts)
    cv = np.std(train_counts) / np.mean(train_counts)
    print(f"   • Balance ratio (min/max): {balance_ratio:.3f}")
    print(f"   • Coefficient of variation: {cv:.3f}")

    if balance_ratio > 0.8:
        print("   ✅ Well-balanced dataset")
    elif balance_ratio > 0.6:
        print("   ⚠️  Moderately imbalanced - consider class weighting")
    else:
        print("   ❌ Highly imbalanced - consider resampling techniques")
elif train_counts:
    print("   • No training images found; balance cannot be assessed")

# Image properties summary
print("\n🖼️  IMAGE PROPERTIES:")
if img_properties['dimensions']:
    widths, heights = zip(*img_properties['dimensions'])
    print(f"   • Average dimensions: {np.mean(widths):.0f}x{np.mean(heights):.0f} pixels")
    print(f"   • Size range: {min(widths)}x{min(heights)} to {max(widths)}x{max(heights)}")

if img_properties['aspect_ratios']:
    avg_ratio = np.mean(img_properties['aspect_ratios'])
    print(f"   • Average aspect ratio: {avg_ratio:.2f}")

if img_properties['file_sizes']:
    avg_size_kb = np.mean(img_properties['file_sizes']) / 1024
    print(f"   • Average file size: {avg_size_kb:.1f} KB")

# Quality assessment summary (the score is undefined when nothing was sampled)
print("\n🔍 QUALITY ASSESSMENT:")
total_issues = sum(len(issues) for issues in quality_issues.values())
if processed > 0:
    quality_score = max(0, 100 - (total_issues / processed * 100))
    print(f"   • Quality score: {quality_score:.1f}/100")
else:
    print("   • Quality score: n/a (no images were sampled)")
print(f"   • Issues found: {total_issues} out of {processed} sampled images")

major_issues = [
    ('corrupted_images', 'Corrupted images'),
    ('extremely_small_images', 'Very small images'),
    ('extremely_large_images', 'Very large images')
]

for issue_key, issue_name in major_issues:
    # .get(): a category that was never recorded simply has no issues.
    found = quality_issues.get(issue_key)
    if found:
        print(f"   ⚠️  {issue_name}: {len(found)}")

# Performance summary. Pick the loader settings with the lowest time per image:
# raw time per batch always favours the smallest batch, because that batch
# holds the fewest images. Computed once and reused for the hyperparameter advice.
best_config = None
if perf_results:
    df_results = pd.DataFrame(perf_results)
    time_per_image = df_results['time_per_batch'] / df_results['batch_size']
    best_config = df_results.loc[time_per_image.idxmin()]

print("\n⚡ PERFORMANCE OPTIMIZATION:")
if best_config is not None:
    print(f"   • Optimal batch size: {int(best_config['batch_size'])}")
    print(f"   • Optimal workers: {int(best_config['num_workers'])}")
    print(f"   • Loading time: {best_config['time_per_batch']:.3f}s per batch")
    print(f"   • Loading time per image: "
          f"{best_config['time_per_batch'] / best_config['batch_size'] * 1000:.3f}ms")

# Training recommendations
print("\n🎯 TRAINING RECOMMENDATIONS:")
print("\n   Data Preprocessing:")
print("   • Use data augmentation (rotation, flip, color jitter) ✅")
print("   • Apply normalization with ImageNet statistics ✅")
print("   • Consider image size: 32x32 for quick experiments, 64x64+ for better accuracy")

print("\n   Training Strategy:")
if 'train' in splits_info and 'valid' in splits_info and train_counts:
    if balance_ratio is not None and balance_ratio < 0.8:
        print("   • Consider class weighting due to imbalance")
    print("   • Use cross-entropy loss for multi-class classification")
    print("   • Monitor both accuracy and per-class F1 scores")

print("\n   Model Architecture:")
print("   • Start with SimpleCNN for baseline")
print("   • Try ResNet for better feature extraction")
print("   • Consider dropout for regularization")

print("\n   Hyperparameters:")
if best_config is not None:
    print(f"   • Batch size: {int(best_config['batch_size'])} (based on performance test)")
    print(f"   • Num workers: {int(best_config['num_workers'])} (based on performance test)")
print("   • Learning rate: Start with 1e-3, use scheduler")
print("   • Epochs: 50-100 with early stopping")

print("\n" + "="*70)
print("Dataset exploration complete! Ready for model training.")
print("="*70)