In [40]:
"""
EMOTIC Dataset Augmentation Pipeline with Visualizations
Author: Enhanced for local Mac execution
Dataset: EMOTIC (EMOTions In Context)
"""


!pip install pandas==2.0.3 numpy==1.24.3 matplotlib==3.7.2 seaborn==0.12.2
!pip install pillow==10.0.0 scikit-learn tqdm
!pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cpu

Looking in indexes: https://download.pytorch.org/whl/cpu


In [41]:


# ==============================================================================
# IMPORTS AND SETUP
# ==============================================================================

import os
import pandas as pd
import numpy as np
from torchvision import transforms
from PIL import Image
import random
import logging
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
plt.ioff() 
import seaborn as sns
from collections import Counter
import warnings
from tqdm import tqdm
from pathlib import Path
import shutil
import math

# Suppress warnings for cleaner output.
# NOTE(review): this silences every warning category for the whole session,
# including deprecation and runtime warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

# Configure logging: INFO level, messages prefixed with timestamp and level name
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Set style for better visualizations
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
# ==============================================================================
# Emotion classes handled specially by the pipeline: they are dropped from the
# majority-class list in detect_minority_classes, and any image carrying one of
# them is skipped by filter_minority_annotations.
EXCLUDED_CLASSES = ["Engagement", "Happiness", "Anticipation"]
# Alternative setting kept for reference: exclude only "Engagement".
# EXCLUDED_CLASSES = ["Engagement"]



In [42]:
class DatasetConfig:
    """Configuration for EMOTIC dataset paths.

    All paths are relative to the current working directory. The three input
    paths must already exist; the two output directories are recreated by
    ``check_paths``.
    """
    # Dataset paths - Update these if your directory structure is different
    TRAIN_ANNOTATIONS = "archive/annots_arrs/annot_arrs_train.csv"
    VAL_ANNOTATIONS = "archive/annots_arrs/annot_arrs_val.csv"
    IMG_DIR = "archive/img_arrs/"
    OUTPUT_DIR = "archive/augmented_img_arrs/"
    VIZ_DIR = "visualizations/"

    @classmethod
    def check_paths(cls):
        """Verify the required inputs exist and reset the output directories.

        Every required path is checked and logged (found or missing). Only
        when all of them exist are ``OUTPUT_DIR`` and ``VIZ_DIR`` deleted and
        recreated empty. If any input is missing, the output directories are
        left untouched so that a misconfigured run cannot destroy the results
        of a previous one.

        Returns:
            bool: True if every required input path exists, False otherwise.
        """
        paths_to_check = [
            (cls.TRAIN_ANNOTATIONS, "Training annotations"),
            (cls.VAL_ANNOTATIONS, "Validation annotations"),
            (cls.IMG_DIR, "Image directory")
        ]

        all_exist = True
        for path, name in paths_to_check:
            if not os.path.exists(path):
                logging.error(f"{name} not found at: {path}")
                all_exist = False
            else:
                logging.info(f"✓ {name} found: {path}")

        if not all_exist:
            # Do not wipe previous outputs when the pipeline cannot run anyway.
            return False

        # Inputs are valid: always recreate these two directories fresh
        for dir_path in [cls.OUTPUT_DIR, cls.VIZ_DIR]:
            if os.path.exists(dir_path):
                shutil.rmtree(dir_path)
                logging.info(f"✗ Cleared existing directory: {dir_path}")
            os.makedirs(dir_path)
            logging.info(f"✓ Created directory: {dir_path}")

        return True

In [43]:
# -------------------------
# Visualization Functions
# -------------------------
def plot_class_distribution(df, discrete_labels, minority_classes, majority_classes, save_path="visualizations/"):
    """Save a two-panel bar chart of per-class label totals.

    The upper panel uses a linear y-axis and the lower panel a logarithmic
    one. Bars are red for classes present in ``minority_classes.index`` and
    green for every other class; a dashed line marks the median count.

    Args:
        df: Annotation table containing the label columns.
        discrete_labels: Column labels of the emotion classes in ``df``.
        minority_classes: Series whose index names the minority classes.
        majority_classes: Accepted for interface symmetry; not read here.
        save_path: Directory that receives ``class_distribution.png``.

    Returns:
        Series of per-class totals sorted in descending order.
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Per-class totals, largest first
    class_counts = df[discrete_labels].sum().sort_values(ascending=False)
    positions = range(len(class_counts))
    median_count = class_counts.median()

    bar_colors = []
    for class_name in class_counts.index:
        bar_colors.append('red' if class_name in minority_classes.index else 'green')

    fig, (linear_ax, log_ax) = plt.subplots(2, 1, figsize=(15, 10))

    # (axis, use log scale?, y-label, title) for each of the two panels
    panels = (
        (linear_ax, False, 'Frequency', 'Class Distribution (Linear Scale)'),
        (log_ax, True, 'Frequency (log scale)',
         'Class Distribution (Log Scale) - Red: Minority, Green: Majority'),
    )
    for axis, use_log, y_label, title in panels:
        axis.bar(positions, class_counts.values, color=bar_colors, alpha=0.7)
        if use_log:
            # Log scale keeps the small minority bars visible
            axis.set_yscale('log')
        axis.set_xticks(positions)
        axis.set_xticklabels(class_counts.index, rotation=45, ha='right')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.axhline(y=median_count, color='blue', linestyle='--', label='Median', alpha=0.5)
        axis.legend()

    plt.tight_layout()
    output_file = os.path.join(save_path, 'class_distribution.png')
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()

    logging.info(f"Class distribution plot saved to {save_path}")
    return class_counts


In [44]:
def plot_class_cooccurrence(df, discrete_labels, save_path="visualizations/"):
    """Save a heatmap of conditional emotion co-occurrence.

    The raw co-occurrence matrix is the label matrix transposed times itself.
    Each row is then divided by its own diagonal entry, so cell (row, col) is
    the fraction of samples carrying the row emotion that also carry the
    column emotion. Divisions that yield NaN are replaced by 0.

    Args:
        df: Annotation table containing the label columns.
        discrete_labels: Column labels of the emotion classes in ``df``.
        save_path: Directory that receives ``cooccurrence_heatmap.png``.

    Returns:
        The un-normalized co-occurrence matrix.
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Raw pairwise counts: entry (i, j) = number of samples with both labels
    label_matrix = df[discrete_labels].astype(int)
    cooccurrence = label_matrix.T.dot(label_matrix)

    # Row-normalize by the self-occurrence count on the diagonal
    with np.errstate(divide='ignore', invalid='ignore'):
        self_counts = np.diag(cooccurrence)[:, None]
        normalized_cooccurrence = np.nan_to_num(cooccurrence / self_counts)

    heatmap_options = dict(
        xticklabels=discrete_labels,
        yticklabels=discrete_labels,
        cmap='YlOrRd',
        vmin=0,
        vmax=1,
        square=True,
        cbar_kws={'label': 'Co-occurrence Probability'},
        fmt='.2f',
        linewidths=0.5,
    )

    plt.figure(figsize=(16, 14))
    sns.heatmap(normalized_cooccurrence, **heatmap_options)

    plt.title('Emotion Co-occurrence Heatmap\n(Probability of column emotion given row emotion)', fontsize=14)
    plt.xlabel('Co-occurring Emotion', fontsize=12)
    plt.ylabel('Primary Emotion', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()
    output_file = os.path.join(save_path, 'cooccurrence_heatmap.png')
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()

    logging.info(f"Co-occurrence heatmap saved to {save_path}")
    return cooccurrence


In [45]:
def plot_label_count_distribution(df, discrete_labels, save_path="visualizations/"):
    """Save a bar chart of how many labels each sample carries.

    Each bar is annotated with its sample count and its share of ``len(df)``;
    dashed vertical lines mark the mean and median number of labels.

    Args:
        df: Annotation table containing the label columns.
        discrete_labels: Column labels of the emotion classes in ``df``.
        save_path: Directory that receives ``label_count_distribution.png``.

    Returns:
        Counter mapping "number of labels" to "number of samples".
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Row-wise label totals and their histogram
    labels_per_sample = df[discrete_labels].sum(axis=1)
    count_distribution = Counter(labels_per_sample)

    label_totals = sorted(count_distribution)
    sample_totals = [count_distribution[total] for total in label_totals]

    fig, ax = plt.subplots(figsize=(12, 6))
    rects = ax.bar(label_totals, sample_totals, color='steelblue', alpha=0.8, edgecolor='black')

    # Annotate each bar with its count and percentage of all samples
    for rect, val in zip(rects, sample_totals):
        x_center = rect.get_x() + rect.get_width()/2.
        ax.text(x_center, rect.get_height(),
                f'{val}\n({val/len(df)*100:.1f}%)',
                ha='center', va='bottom', fontsize=9)

    # Reference lines for the central tendency
    mean_labels = labels_per_sample.mean()
    median_labels = labels_per_sample.median()
    ax.axvline(x=mean_labels, color='red', linestyle='--', label=f'Mean: {mean_labels:.2f}', alpha=0.7)
    ax.axvline(x=median_labels, color='green', linestyle='--', label=f'Median: {median_labels:.1f}', alpha=0.7)
    ax.legend()

    ax.set_xlabel('Number of Labels per Image', fontsize=12)
    ax.set_ylabel('Number of Images', fontsize=12)
    ax.set_title('Distribution of Multi-label Samples', fontsize=14)
    ax.set_xticks(label_totals)
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    output_file = os.path.join(save_path, 'label_count_distribution.png')
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()

    logging.info(f"Label count distribution saved to {save_path}")
    return count_distribution


In [46]:
def plot_minority_analysis(df, discrete_labels, minority_classes, majority_classes, exclude_classes, save_path="visualizations/"):
    """Analyze and visualize minority vs majority class distribution in samples.

    Every sample is assigned to at most one of three groups:

    * minority only - has a minority label and no (non-excluded) majority label
    * majority only - has a (non-excluded) majority label and no minority label
    * mixed         - has both

    Samples with neither kind of label (for example only excluded classes)
    fall into no group, so the three counts need not add up to ``len(df)``.

    Args:
        df: Annotation table containing the label columns.
        discrete_labels: Column labels of the emotion classes in ``df``.
        minority_classes: Series whose index names the minority classes.
        majority_classes: Series whose index names the majority classes.
        exclude_classes: Class names ignored when testing for majority labels.
        save_path: Directory that receives ``minority_analysis.png``.

    Returns:
        Tuple ``(minority_only, mixed, majority_only)`` of sample counts.
    """
    os.makedirs(save_path, exist_ok=True)

    label_index = pd.Index(discrete_labels)
    minority_indices = [label_index.get_loc(cls) for cls in minority_classes.index]
    majority_indices = [label_index.get_loc(cls) for cls in majority_classes.index if cls not in exclude_classes]

    # Vectorized categorization (replaces a Python-level loop over all rows)
    label_frame = df[label_index]
    has_minority = label_frame.iloc[:, minority_indices].eq(1).any(axis=1).astype(bool)
    has_majority = label_frame.iloc[:, majority_indices].eq(1).any(axis=1).astype(bool)

    minority_only = int((has_minority & ~has_majority).sum())
    majority_only = int((has_majority & ~has_minority).sum())
    mixed = int((has_minority & has_majority).sum())

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

    sizes = [minority_only, mixed, majority_only]
    group_names = ['Minority Only', 'Mixed (Minority + Majority)', 'Majority Only']
    colors = ['#ff6b6b', '#ffd93d', '#6bcf7f']
    explode = (0.05, 0.05, 0)

    # The pie percentages are relative to the categorized samples only, so the
    # absolute count shown next to each percentage must use the same base
    # (using len(df) here would inflate the counts whenever some samples
    # belong to no group).
    categorized_total = sum(sizes)
    total_samples = len(df)

    if categorized_total > 0:
        ax1.pie(sizes, labels=group_names, colors=colors,
                autopct=lambda pct: f'{pct:.1f}%\n({int(round(pct / 100 * categorized_total))})',
                explode=explode, startangle=90,
                textprops={'fontsize': 10})

        # Create donut effect
        centre_circle = plt.Circle((0, 0), 0.70, fc='white')
        ax1.add_artist(centre_circle)
    else:
        # A pie of all-zero wedges cannot be drawn; show a placeholder instead
        ax1.text(0.5, 0.5, 'No minority or majority samples',
                 ha='center', va='center', transform=ax1.transAxes)
        ax1.set_xticks([])
        ax1.set_yticks([])
    ax1.set_title('Sample Distribution Analysis', fontsize=14, pad=20)

    # Bar chart for better comparison; percentages here are shares of all samples
    ax2.bar(group_names, sizes, color=colors, alpha=0.8, edgecolor='black')
    for position, size in enumerate(sizes):
        share = size / total_samples * 100 if total_samples else 0.0
        ax2.text(position, size, f'{size}\n({share:.1f}%)',
                 ha='center', va='bottom', fontsize=10)

    ax2.set_ylabel('Number of Samples', fontsize=12)
    ax2.set_title('Sample Category Distribution', fontsize=14)
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'minority_analysis.png'), dpi=300, bbox_inches='tight')
    plt.close()

    logging.info(f"Minority analysis saved to {save_path}")
    logging.info(f"Minority only: {minority_only}, Mixed: {mixed}, Majority only: {majority_only}")

    return minority_only, mixed, majority_only


In [47]:
def plot_class_imbalance_ratio(minority_classes, majority_classes, save_path="visualizations/"):
    """Save a horizontal bar chart of each class's imbalance ratio.

    The ratio of a class is the largest class count divided by that class's
    own count. Bars are colored by severity: red above 10, orange above 5,
    yellow above 2, green otherwise.

    Args:
        minority_classes: Series of counts for the minority classes.
        majority_classes: Series of counts for the majority classes.
        save_path: Directory that receives ``imbalance_ratio.png``.

    Returns:
        Series of imbalance ratios sorted in descending order.
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    severe_color = '#d32f2f'
    moderate_color = '#ff9800'
    mild_color = '#ffc107'
    balanced_color = '#4caf50'

    def severity_color(ratio):
        """Map an imbalance ratio to its severity color."""
        if ratio > 10:
            return severe_color
        if ratio > 5:
            return moderate_color
        if ratio > 2:
            return mild_color
        return balanced_color

    # Ratio of the largest class to every class, most imbalanced first
    combined_counts = pd.concat([minority_classes, majority_classes])
    imbalance_ratios = (combined_counts.max() / combined_counts).sort_values(ascending=False)
    bar_colors = [severity_color(ratio) for ratio in imbalance_ratios.values]

    fig, ax = plt.subplots(figsize=(12, 10))

    rows = np.arange(len(imbalance_ratios))
    rects = ax.barh(rows, imbalance_ratios.values, color=bar_colors, alpha=0.8, edgecolor='black')

    # Print the ratio at the end of each bar
    for rect, ratio in zip(rects, imbalance_ratios.values):
        y_center = rect.get_y() + rect.get_height()/2.
        ax.text(rect.get_width(), y_center, f'{ratio:.1f}x',
                ha='left', va='center', fontsize=9, fontweight='bold')

    ax.set_yticks(rows)
    ax.set_yticklabels(imbalance_ratios.index, fontsize=10)
    ax.set_xlabel('Imbalance Ratio (Max Count / Class Count)', fontsize=12)
    ax.set_title('Class Imbalance Severity Analysis', fontsize=14)
    ax.grid(True, alpha=0.3, axis='x')

    # Legend explaining the severity colors
    from matplotlib.patches import Patch
    legend_spec = (
        (severe_color, 'Severe (>10x)'),
        (moderate_color, 'Moderate (5-10x)'),
        (mild_color, 'Mild (2-5x)'),
        (balanced_color, 'Balanced (<2x)'),
    )
    legend_elements = [Patch(facecolor=face, label=text) for face, text in legend_spec]
    ax.legend(handles=legend_elements, loc='lower right')

    # Dashed guides at the severity boundaries
    for boundary in (2, 5, 10):
        ax.axvline(x=boundary, color='gray', linestyle='--', alpha=0.5)

    plt.tight_layout()
    output_file = os.path.join(save_path, 'imbalance_ratio.png')
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()

    logging.info(f"Imbalance ratio plot saved to {save_path}")
    return imbalance_ratios


In [48]:

def generate_augmentation_report(class_counts, minority_classes, filtered_annotations, save_path="visualizations/"):
    """Generate a summary report of augmentation strategy.

    The report lists the number of classes, the number of minority classes,
    the number of selected images and, for every minority class, its count
    and its weight (sum of all minority counts divided by the class count).
    The text is written to ``augmentation_report.txt`` inside ``save_path``,
    printed to stdout and returned.

    Args:
        class_counts: Sized collection of all classes (only its length is used).
        minority_classes: Series mapping minority class name to count.
        filtered_annotations: Sized collection of the images selected for
            augmentation (only its length is used).
        save_path: Directory that receives the report file; created if missing.

    Returns:
        str: The full report text.
    """
    os.makedirs(save_path, exist_ok=True)

    # Loop-invariant: total number of minority labels across all classes
    total_minority = sum(minority_classes.values)

    report = []
    report.append("=" * 60)
    report.append("AUGMENTATION STRATEGY REPORT")
    report.append("=" * 60)
    report.append(f"\nTotal Classes: {len(class_counts)}")
    report.append(f"Minority Classes: {len(minority_classes)}")
    report.append(f"Images Selected for Augmentation: {len(filtered_annotations)}")
    report.append("\n" + "-" * 40)
    report.append("MINORITY CLASSES REQUIRING AUGMENTATION:")
    report.append("-" * 40)

    for cls, count in minority_classes.items():
        # A class with no samples gets an infinite weight instead of a division error
        weight = total_minority / count if count else float('inf')
        report.append(f"  {cls:20s}: {int(count):5d}  samples → {weight:.1f}x augmentation")

    report_text = "\n".join(report)

    # The report contains a non-ASCII arrow, so the encoding must be explicit:
    # the platform default (e.g. cp1252) may be unable to encode it.
    with open(os.path.join(save_path, 'augmentation_report.txt'), 'w', encoding='utf-8') as f:
        f.write(report_text)

    print(report_text)
    return report_text


# -------------------------
# Helper Functions
# -------------------------
def detect_minority_classes(csv_path, exclude_classes=EXCLUDED_CLASSES, threshold=2000):
    """Split the emotion classes of an annotation CSV into minority and majority.

    Emotion columns are taken by position (columns 8 through 33) and any
    known bounding-box / file-name column inside that range is discarded.
    A class is a minority class when its total count is below ``threshold``;
    all others are majority classes, from which ``exclude_classes`` are
    removed.

    Args:
        csv_path: Path of the annotation CSV file.
        exclude_classes: Class names dropped from the majority result.
        threshold: Count below which a class is considered a minority class.

    Returns:
        Tuple ``(minority_classes, majority_classes)`` of count Series.
    """
    df = pd.read_csv(csv_path)

    # Positional window that holds the emotion labels (end index is exclusive)
    first_label_col, last_label_col = 8, 34
    candidate_columns = df.columns[first_label_col:last_label_col].tolist()

    # Safety net: discard metadata columns should they appear in the window
    metadata_columns = ['X_min', 'Y_min', 'X_max', 'Y_max', 'Arr_name', 'Crop_name']
    emotion_labels = []
    for column in candidate_columns:
        if column not in metadata_columns:
            emotion_labels.append(column)

    print(f"Identified emotion columns ({len(emotion_labels)}): {emotion_labels}")

    # Per-class totals; values that cannot be parsed as numbers become NaN
    class_counts = df[emotion_labels].apply(pd.to_numeric, errors='coerce').sum()

    # Partition by the fixed count threshold
    is_minority = class_counts < threshold
    is_majority = class_counts >= threshold
    minority_classes = class_counts[is_minority]
    majority_classes = class_counts[is_majority].drop(exclude_classes, errors='ignore')

    logging.info(f"Total emotion labels found: {len(emotion_labels)}")
    logging.info(f"Threshold used: {threshold} instances")
    logging.info(f"Minority Classes (<{threshold}): {list(minority_classes.index)}")
    logging.info(f"Majority Classes (>={threshold}): {list(majority_classes.index)}")

    return minority_classes, majority_classes

def filter_minority_annotations(df, minority_classes, exclude_classes):
    """Select the rows that should be augmented.

    A row is kept when at least one minority-class label equals 1 and none of
    the ``exclude_classes`` labels equals 1. Label columns are taken by
    position (columns 8 through 33 of ``df``).

    The selection is done with a boolean mask, so the result keeps the
    original index, column order and dtypes, and an empty result still has
    all columns of ``df``.

    Args:
        df: Annotation table.
        minority_classes: Series whose index names the minority classes; every
            name must be one of the label columns (KeyError otherwise).
        exclude_classes: Class names that disqualify a row; names that are
            not label columns are ignored.

    Returns:
        DataFrame: Copy of the selected rows of ``df``.
    """
    discrete_labels = df.columns[8:34]
    minority_indices = [discrete_labels.get_loc(cls) for cls in minority_classes.index]
    exclude_indices = [discrete_labels.get_loc(cls) for cls in exclude_classes if cls in discrete_labels]

    label_frame = df.iloc[:, 8:34]
    is_minority = label_frame.iloc[:, minority_indices].eq(1).any(axis=1).astype(bool)
    overlaps_excluded = label_frame.iloc[:, exclude_indices].eq(1).any(axis=1).astype(bool)

    # Use a plain array mask so no index alignment is involved
    keep = (is_minority & ~overlaps_excluded).to_numpy()
    filtered_annotations = df.loc[keep].copy()

    logging.info(f"Filtered {len(filtered_annotations)} images for augmentation.")
    return filtered_annotations


In [None]:


# -------------------------
# Augmentation Class
# -------------------------
class Augmentation:
    """Create augmented copies of annotated image crops and save them as ``.npy`` files.

    Args:
        img_dir: Directory containing the source ``.npy`` image arrays.
        output_dir: Directory that receives the augmented arrays; created if missing.
        transform: Callable applied to a PIL image to produce one augmented
            image. Defaults to a random horizontal flip, a rotation of up to
            15 degrees, a color jitter and a random resized crop to 224x224.
    """

    def __init__(self, img_dir, output_dir, transform=None):
        self.img_dir = img_dir
        self.output_dir = output_dir
        self.transform = transform or transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),
        ])

        os.makedirs(output_dir, exist_ok=True)

    def augment_and_save(self, annotations, minority_classes, cap_percentile=70):
        """Augment every annotated image and record per-class augmentation counts.

        Each minority class gets the weight ``total minority count / class
        count``. An image is augmented ``int(min(w, cap))`` times, where ``w``
        is the largest weight among its minority labels (at least 1) and
        ``cap`` is the ``cap_percentile``-th percentile of all finite class
        weights (never below 1). Augmented arrays are written as
        ``aug_<i>_<Crop_name>`` into ``output_dir``. Images that are missing
        or fail to process are logged and skipped.

        This method writes only image arrays and a count report; it does not
        write annotations for the augmented files.

        Args:
            annotations: DataFrame with a ``Crop_name`` column and the label
                columns at positions 8 through 33.
            minority_classes: Series mapping minority class name to count.
            cap_percentile: Percentile (0-100) of the class weights used as
                the upper bound on augmentations per image.

        Returns:
            DataFrame with columns ``label`` and ``augmented_count``; also
            saved as ``augmentation_report.csv`` in ``DatasetConfig.VIZ_DIR``.
        """
        total_minority = sum(minority_classes.values)
        class_weights = {
            cls: (total_minority / freq if freq else float('inf'))
            for cls, freq in minority_classes.items()
        }

        # --- Dynamic cap based on the weight distribution ---
        # Ensures super-rare classes don't explode augmentation counts
        finite_weights = np.array(list(class_weights.values()), dtype=float)
        # Drop non-finite weights (a class with zero samples has an infinite weight)
        finite_weights = finite_weights[np.isfinite(finite_weights)]
        cap = np.percentile(finite_weights, cap_percentile) if len(finite_weights) else 1.0
        cap = max(1.0, cap)  # never below 1
        # ----------------------------------------------------

        # Label column names are the same for every row, so look them up once
        discrete_cols = annotations.columns[8:34]

        aug_counter = Counter()
        print("\nAugmenting images...")
        for _, row in tqdm(annotations.iterrows(), total=len(annotations), desc="Processing"):
            img_path = os.path.join(self.img_dir, row['Crop_name'])
            if not os.path.exists(img_path):
                logging.warning(f"Image not found: {img_path}")
                continue

            try:
                image = np.load(img_path)
                if len(image.shape) == 2:
                    # Grayscale: replicate the single channel into three
                    image = np.stack([image] * 3, axis=-1)
                elif image.shape[-1] != 3:
                    raise ValueError(f"Unexpected image shape: {image.shape}")

                pil_image = Image.fromarray(image.astype(np.uint8))

                # Names of the labels that are set on this row
                categories = [label for label, val in zip(discrete_cols, row.iloc[8:34]) if val == 1]

                # Largest minority weight present on the image, at least 1
                weight = 1
                for cat in categories:
                    if cat in class_weights:
                        weight = max(weight, class_weights[cat])
                num_augmentations = int(min(weight, cap))

                base_name = row['Crop_name'].replace('.npy', '')
                for i in range(num_augmentations):
                    augmented_image = self.transform(pil_image)
                    output_path = os.path.join(self.output_dir, f"aug_{i}_{base_name}.npy")
                    np.save(output_path, np.array(augmented_image))

                # Count augmented samples per class (only for classes present in this image)
                for cat in categories:
                    aug_counter[cat] += num_augmentations

            except Exception as e:
                # Best effort: one bad image must not abort the whole run
                logging.error(f"Error augmenting image {img_path}: {str(e)}")

        # Save the augmentation report (label, augmented_count)
        viz_dir = Path(DatasetConfig.VIZ_DIR)
        viz_dir.mkdir(parents=True, exist_ok=True)
        rep = pd.DataFrame({"label": list(aug_counter.keys()),
                            "augmented_count": list(aug_counter.values())})
        rep.to_csv(viz_dir / "augmentation_report.csv", index=False)
        return rep






In [50]:
def plot_before_after_counts(df_train, discrete_labels):
    """Save a grouped bar chart of per-class counts before and after augmentation.

    "Before" is the column sum of ``df_train[discrete_labels]``. "After" adds
    the ``augmented_count`` read from ``augmentation_report.csv`` in
    ``DatasetConfig.VIZ_DIR``; when that file does not exist, the augmented
    count is 0 for every class.

    Args:
        df_train: Training annotation table containing the label columns.
        discrete_labels: Column labels of the emotion classes in ``df_train``.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from pathlib import Path

    viz_root = Path(DatasetConfig.VIZ_DIR)
    report_csv = viz_root / "augmentation_report.csv"

    orig_counts = df_train[discrete_labels].sum().astype(int).rename("original_count").to_frame()

    # Augmented counts come from the saved report, or default to zero
    if not report_csv.exists():
        aug_counts = pd.Series(0, index=orig_counts.index, name="augmented_count")
    else:
        report = pd.read_csv(report_csv)
        aug_counts = report.set_index("label")["augmented_count"].astype(int).rename("augmented_count")

    summary = orig_counts.join(aug_counts, how="left").fillna(0).astype(int)
    summary["post_count"] = summary["original_count"] + summary["augmented_count"]
    ranked = summary.sort_values("post_count", ascending=False)

    positions = np.arange(len(ranked))
    bar_width = 0.45

    fig, ax = plt.subplots(figsize=(12,6))
    ax.bar(positions - bar_width/2, ranked["original_count"].values, bar_width, label="Before")
    ax.bar(positions + bar_width/2, ranked["post_count"].values, bar_width, label="After (Original + Augmented)")
    ax.set_xticks(positions)
    ax.set_xticklabels(ranked.index, rotation=80, ha="right")
    ax.set_ylabel("Count")
    ax.set_title("Per-class counts: Before vs After Augmentation")
    ax.legend()
    fig.tight_layout()

    viz_root.mkdir(parents=True, exist_ok=True)
    fig.savefig(viz_root / "per_class_counts_before_vs_after.png", dpi=200)
    plt.close()  # Release the figure's memory
    print("Plot saved to:", viz_root / "per_class_counts_before_vs_after.png")

In [51]:

# -------------------------
# Main Function
# -------------------------
def main():
    """Run the augmentation pipeline end to end.

    Steps: verify the dataset paths, load the training annotations, detect
    minority/majority classes, write the five analysis plots, select the
    minority-class images, augment them and write the before/after plot.
    Returns early, after printing guidance, when ``DatasetConfig.check_paths``
    reports a missing path.
    """
    banner = "=" * 60

    print("\n" + banner)
    print("EMOTIC DATASET AUGMENTATION PIPELINE")
    print(banner)
    print("\n[Setup] Checking dataset paths...")

    # Nothing can run without the annotation files and the image directory
    if not DatasetConfig.check_paths():
        guidance = (
            "\n❌ Error: Some required paths are missing.",
            "Please ensure the 'archive' folder is in the current directory with:",
            "  - archive/annots_arrs/annot_arrs_train.csv",
            "  - archive/annots_arrs/annot_arrs_val.csv",
            "  - archive/img_arrs/",
        )
        for line in guidance:
            print(line)
        return

    train_csv = DatasetConfig.TRAIN_ANNOTATIONS
    source_dir = DatasetConfig.IMG_DIR
    target_dir = DatasetConfig.OUTPUT_DIR
    plots_dir = DatasetConfig.VIZ_DIR

    # Training annotations; the emotion labels sit in columns 8 through 33
    annotations = pd.read_csv(train_csv)
    discrete_labels = annotations.columns[8:34]

    print(f"\n[Data] Loaded {len(annotations)} training samples")
    print(f"[Data] Found {len(discrete_labels)} emotion categories")

    print("\n[Step 1] Detecting minority and majority classes...")
    minority_classes, majority_classes = detect_minority_classes(train_csv, exclude_classes=EXCLUDED_CLASSES)

    # The plotting helpers are called for their saved figures; return values are not needed here
    print("\n[Visualization 1/5] Generating class distribution plots...")
    plot_class_distribution(annotations, discrete_labels, minority_classes, majority_classes, plots_dir)

    print("\n[Visualization 2/5] Generating co-occurrence heatmap...")
    plot_class_cooccurrence(annotations, discrete_labels, plots_dir)

    print("\n[Visualization 3/5] Analyzing multi-label distribution...")
    plot_label_count_distribution(annotations, discrete_labels, plots_dir)

    print("\n[Step 2] Filtering annotations for minority class images...")
    filtered_annotations = filter_minority_annotations(annotations, minority_classes, EXCLUDED_CLASSES)

    print("\n[Visualization 4/5] Analyzing minority/majority sample distribution...")
    plot_minority_analysis(
        annotations, discrete_labels, minority_classes, majority_classes, EXCLUDED_CLASSES, plots_dir
    )

    print("\n[Visualization 5/5] Computing class imbalance ratios...")
    plot_class_imbalance_ratio(minority_classes, majority_classes, plots_dir)

    print("\n[Step 3] Starting augmentation process...")
    augmenter = Augmentation(source_dir, target_dir)
    # Writes the augmented arrays and the per-class count report
    augmenter.augment_and_save(filtered_annotations, minority_classes)

    # The before/after chart reads the report written by the step above
    plot_before_after_counts(annotations, list(discrete_labels))

    print("\n" + banner)
    print("AUGMENTATION COMPLETE")
    print(banner)
    print(f"✓ Augmented {len(filtered_annotations)} images")
    print(f"✓ Visualizations saved to: {plots_dir}")
    print(f"✓ Augmented images saved to: {target_dir}")
    print(banner)

In [52]:
# -------------------------
# Execution
# -------------------------
if __name__ == "__main__":
    # Run the full pipeline: path checks, visualizations, augmentation.
    # NOTE: DatasetConfig.check_paths (called by main) deletes and recreates
    # the output and visualization directories.
    main()

2025-08-18 16:58:16,896 - INFO - ✓ Training annotations found: archive/annots_arrs/annot_arrs_train.csv
2025-08-18 16:58:16,896 - INFO - ✓ Validation annotations found: archive/annots_arrs/annot_arrs_val.csv
2025-08-18 16:58:16,897 - INFO - ✓ Image directory found: archive/img_arrs/



EMOTIC DATASET AUGMENTATION PIPELINE

[Setup] Checking dataset paths...


2025-08-18 16:58:30,987 - INFO - ✗ Cleared existing directory: archive/augmented_img_arrs/
2025-08-18 16:58:30,988 - INFO - ✓ Created directory: archive/augmented_img_arrs/
2025-08-18 16:58:30,989 - INFO - ✗ Cleared existing directory: visualizations/
2025-08-18 16:58:30,989 - INFO - ✓ Created directory: visualizations/
2025-08-18 16:58:31,094 - INFO - Total emotion labels found: 26
2025-08-18 16:58:31,095 - INFO - Threshold used: 2000 instances
2025-08-18 16:58:31,095 - INFO - Minority Classes (<2000): ['Peace', 'Affection', 'Esteem', 'Surprise', 'Sympathy', 'Doubt/Confusion', 'Disconnection', 'Fatigue', 'Embarrassment', 'Yearning', 'Disapproval', 'Aversion', 'Annoyance', 'Anger', 'Sensitivity', 'Sadness', 'Disquietment', 'Fear', 'Pain', 'Suffering']
2025-08-18 16:58:31,095 - INFO - Majority Classes (>=2000): ['Confidence', 'Pleasure', 'Excitement']



[Data] Loaded 24639 training samples
[Data] Found 26 emotion categories

[Step 1] Detecting minority and majority classes...
Identified emotion columns (26): ['Peace', 'Affection', 'Esteem', 'Anticipation', 'Engagement', 'Confidence', 'Happiness', 'Pleasure', 'Excitement', 'Surprise', 'Sympathy', 'Doubt/Confusion', 'Disconnection', 'Fatigue', 'Embarrassment', 'Yearning', 'Disapproval', 'Aversion', 'Annoyance', 'Anger', 'Sensitivity', 'Sadness', 'Disquietment', 'Fear', 'Pain', 'Suffering']

[Visualization 1/5] Generating class distribution plots...


2025-08-18 16:58:32,168 - INFO - Class distribution plot saved to visualizations/



[Visualization 2/5] Generating co-occurrence heatmap...


2025-08-18 16:58:33,222 - INFO - Co-occurrence heatmap saved to visualizations/



[Visualization 3/5] Analyzing multi-label distribution...


2025-08-18 16:58:33,621 - INFO - Label count distribution saved to visualizations/



[Step 2] Filtering annotations for minority class images...


2025-08-18 16:58:36,923 - INFO - Filtered 4098 images for augmentation.



[Visualization 4/5] Analyzing minority/majority sample distribution...


2025-08-18 16:58:40,731 - INFO - Minority analysis saved to visualizations/
2025-08-18 16:58:40,731 - INFO - Minority only: 6396, Mixed: 2477, Majority only: 6341



[Visualization 5/5] Computing class imbalance ratios...


2025-08-18 16:58:41,394 - INFO - Imbalance ratio plot saved to visualizations/



[Step 3] Starting augmentation process...

Augmenting images...


Processing: 100%|██████████| 4098/4098 [01:59<00:00, 34.35it/s]


Plot saved to: visualizations/per_class_counts_before_vs_after.png

AUGMENTATION COMPLETE
✓ Augmented 4098 images
✓ Visualizations saved to: visualizations/
✓ Augmented images saved to: archive/augmented_img_arrs/
