In [None]:
# Imports
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display  # type: ignore[import-not-found]

# Add src to path
# Put the parent of the current working directory at the front of sys.path so
# the `src` package imported just below can be resolved. This must run before
# the `src` imports.
# NOTE(review): assumes the kernel's working directory sits one level below
# the project root — confirm how the notebook is launched.
sys.path.insert(0, str(Path.cwd().parent))

from src.data.dataset import VideoDataset
from src.utils.io import load_yaml

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require
# matplotlib >= 3.6 — confirm against the project's pinned version.
plt.style.use('seaborn-v0_8-darkgrid')

## 1. Environment Setup

Load the configuration and set up paths.

In [None]:
# Read the default YAML configuration, located under `configs/` in the parent
# of the current working directory.
config_path = Path.cwd().parent.joinpath('configs', 'default.yaml')
config = load_yaml(config_path)

# Echo the settings worth eyeballing, one per line, in a fixed order:
# (caption, top-level section, key inside that section).
for caption, section, key in (
    ("Project", "project", "name"),
    ("Labels", "data", "labels"),
    ("Num classes", "data", "num_classes"),
):
    print(f"{caption}: {config[section][key]}")

## 2. Dataset Statistics

Load dataset and examine basic statistics.

In [None]:
# Check if processed data exists
# The processed-data location comes from the config and is resolved relative
# to the parent of the current working directory.
data_dir = Path.cwd().parent / config['data']['processed_dir']

# is_dir() rather than exists(): a regular file at this path would otherwise
# pass the check and then fail inside iterdir() below.
if not data_dir.is_dir():
    print(f"⚠️ Data directory not found: {data_dir}")
    print("Please run preprocessing first:")
    print("  1. Extract frames: python scripts/extract_frames.py --input data/raw --output data/interim/frames")
    print("  2. Create splits: python scripts/split_dataset.py --input data/interim/frames --output data/processed")
else:
    print(f"✓ Data directory found: {data_dir}")

    # List splits
    # iterdir() yields entries in arbitrary order; sort so the printed list is
    # identical on every run and on every filesystem.
    splits = sorted(d.name for d in data_dir.iterdir() if d.is_dir())
    print(f"Available splits: {splits}")

In [None]:
# Load datasets (if data exists)
# `datasets` maps split name -> VideoDataset and `labels` is the class-label
# collection from the config. Both are bound unconditionally so they always
# exist after this cell runs, and a re-run starts from an empty mapping
# instead of inheriting entries from an earlier execution.
datasets = {}
labels = config['data']['labels']

if data_dir.exists():
    num_frames = config['data']['preprocessing']['num_frames']

    for split in ['train', 'val', 'test']:
        # Splits without a directory on disk are skipped silently.
        if not (data_dir / split).exists():
            continue
        try:
            dataset = VideoDataset(
                data_dir=data_dir,
                labels=labels,
                split=split,
                num_frames=num_frames,
                transform=None,
            )
            datasets[split] = dataset
            print(f"{split:5s}: {len(dataset):4d} samples")
        except Exception as e:
            # Deliberately best-effort: one broken split must not stop the
            # others from loading. Include the exception type, since some
            # exceptions (e.g. KeyError) have an uninformative message alone.
            print(f"Error loading {split}: {type(e).__name__}: {e}")
else:
    print("Skipping dataset loading (no data available)")

## 3. Class Distribution

Analyze class balance across splits.

In [None]:
if data_dir.exists() and len(datasets) > 0:
    # Collect class counts: one entry per loaded split.
    # NOTE(review): assumes get_class_counts() returns a mapping of
    # class -> integer count (the plotting code below already relies on
    # .keys()/.values()) — confirm against VideoDataset.
    class_counts = {
        split_name: split_dataset.get_class_counts()
        for split_name, split_dataset in datasets.items()
    }

    # Create DataFrame (rows = splits, columns = classes).
    # A class that is absent from a split would appear as NaN and turn the
    # whole table into floats; report it as a count of 0 instead.
    df = pd.DataFrame(class_counts).T.fillna(0).astype(int)
    display(df)

    # Plot one bar chart per split.
    # squeeze=False always returns a 2-D array of axes, so the single-split
    # case needs no special handling.
    fig, axes = plt.subplots(
        1, len(datasets), figsize=(5 * len(datasets), 4), squeeze=False
    )

    for ax, (split, counts) in zip(axes[0], class_counts.items()):
        # Pass plain lists rather than dict views to matplotlib.
        ax.bar(list(counts.keys()), list(counts.values()))
        ax.set_title(f'{split.capitalize()} Split')
        ax.set_xlabel('Class')
        ax.set_ylabel('Count')
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()
else:
    print("No data to visualize")

## 4. Sample Visualization

Visualize random samples from the dataset.

In [None]:
if data_dir.exists() and 'train' in datasets:
    dataset = datasets['train']

    # Get random samples.
    # A locally seeded generator makes the same samples appear on every
    # Restart & Run All without touching NumPy's global random state.
    SAMPLE_SEED = 42
    num_samples = min(6, len(dataset))
    rng = np.random.default_rng(SAMPLE_SEED)
    indices = rng.choice(len(dataset), size=num_samples, replace=False)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for ax, idx in zip(axes, indices):
        # Convert the NumPy integer to a plain int before indexing the dataset
        # and formatting the title.
        idx = int(idx)
        video, label = dataset[idx]

        # Get middle frame; the first axis of `video` is indexed as the frame
        # axis, and permute(1, 2, 0) moves the frame's leading axis last so
        # imshow receives a channels-last image.
        # NOTE(review): imshow clips float RGB data outside [0, 1]; confirm
        # the value range the dataset returns when transform=None.
        mid_frame = video.shape[0] // 2
        frame = video[mid_frame].permute(1, 2, 0).numpy()

        # Display
        ax.imshow(frame)
        ax.set_title(f'Sample {idx} | Class: {labels[label]}')
        ax.axis('off')

    # With fewer than 6 samples the remaining panels would show empty framed
    # axes; hide them.
    for unused_ax in axes[num_samples:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("No data to visualize")

## 5. Frame Statistics

Analyze video properties.

In [None]:
# The extra length check sends an empty train split to the "No data" branch
# instead of failing with IndexError on shapes[0].
if data_dir.exists() and 'train' in datasets and len(datasets['train']) > 0:
    dataset = datasets['train']

    # Sample a few videos and check properties
    sample_size = min(20, len(dataset))

    shapes = []
    first_video = None
    for i in range(sample_size):
        video, _ = dataset[i]
        if first_video is None:
            # Keep sample 0 for the pixel statistics below, so it is not
            # loaded a second time.
            first_video = video
        # Store plain tuples so shapes print compactly and compare by value.
        shapes.append(tuple(video.shape))

    print(f"Sampled {sample_size} videos:")
    print(f"  Video shape (T, C, H, W): {shapes[0]}")
    print(f"  Consistent shapes: {len(set(shapes)) == 1}")

    # Pixel value range (first sample only).
    # Cast to float first: mean() and std() are not defined for integer
    # tensors, which the dataset may return when transform=None.
    pixels = first_video.float()
    print("\nPixel value range:")
    print(f"  Min: {pixels.min().item():.3f}")
    print(f"  Max: {pixels.max().item():.3f}")
    print(f"  Mean: {pixels.mean().item():.3f}")
    print(f"  Std: {pixels.std().item():.3f}")
else:
    print("No data to analyze")

## 6. Next Steps

After exploring the data:

1. **Address class imbalance** (if present):
   - Use class weights in loss function
   - Apply oversampling or undersampling
   - Consider focal loss

2. **Prepare for training**:
   - Review augmentation strategy in config
   - Select appropriate model architecture
   - Set training hyperparameters

3. **Start training**:
   ```bash
   python -m src.cli train --config configs/default.yaml
   ```

4. **Monitor training**:
   - Check logs in `logs/`
   - Monitor training curves
   - Watch for overfitting