# Dataset Descriptives Analysis

This notebook loads multiple JSON ground-truth label datasets, computes summary statistics (unique labels, label counts, samples per label, labels per sample, missing labels) for each individual dataset and combined, and creates publication-quality plots plus a summary table.

**Inputs:** paths to JSON files containing lists of samples; each sample is expected to have a `labels` field which is a list of dictionaries with a `'label'` key.

**Outputs:**
- PDF plots: overall label distribution, dataset comparisons (counts and percentages), labels-per-sample histogram, label coverage.
- Summary CSV/table with combined and individual metrics.

**Main flow:**
1. `load_all_datasets`: loads and validates datasets.
2. `analyze_dataset`: safely extracts labels and computes metrics.
3. `create_publication_plots`: generates and saves the visualizations.
4. `create_summary_table`: builds the summary DataFrame.
5. `analyze_full_dataset`: orchestrates the steps and persists results.

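**Assumptions / notes:** label presence is deduplicated per sample (multiple identical labels in one sample count once), so every per-label count is a number of samples rather than raw label occurrences; malformed label entries are filtered out; and expected labels (the keys of `PLACEHOLDERS`) that never occur are reported as missing.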

In [None]:
import json
import os
from collections import Counter
from typing import Dict, List, Optional, Set, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Publication styling for every figure in this notebook: serif fonts at
# enlarged sizes, heavier strokes, black-on-white colours and a full box of
# spines. The settings are grouped by concern and applied in one update.
_FONT_RC = {
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Times', 'serif'],
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 11,
    'figure.titlesize': 18,
}
_STROKE_RC = {
    'lines.linewidth': 1.5,
    'axes.linewidth': 1.2,
    'grid.linewidth': 0.8,
    'xtick.major.width': 1.2,
    'ytick.major.width': 1.2,
}
_COLOR_RC = {
    **{key: 'black' for key in ('axes.edgecolor', 'text.color', 'axes.labelcolor',
                                'xtick.color', 'ytick.color')},
    **{key: 'white' for key in ('figure.facecolor', 'axes.facecolor', 'savefig.facecolor')},
}
# Keep all four spines so every plot is fully boxed.
_SPINE_RC = {f'axes.spines.{side}': True for side in ('left', 'bottom', 'top', 'right')}

plt.rcParams.update({**_FONT_RC, **_STROKE_RC, **_COLOR_RC, **_SPINE_RC})

# Define the expected labels (from your original code)
PLACEHOLDERS = {
    "TITEL"         :  ["TITEL"],
    "VORNAME"       :  ["VORNAME"],
    "NACHNAME"      :  ["NACHNAME"],
    "FIRMA"         :  ["FIRMA"],
    "TELEFONNUMMER" :  ["TELEFONNUMMER"],
    "EMAIL"         :  ["EMAIL"],
    "FAX"           :  ["FAX"],
    "STRASSE"       :  ["STRASSE"],
    "HAUSNUMMER"    :  ["HAUSNUMMER"],
    "POSTLEITZAHL"  :  ["POSTLEITZAHL","PLZ","ZIP"],
    "WOHNORT"       :  ["WOHNORT","ORT","CITY"],
    "ZÄHLERNUMMER"  :  ["ZÄHLERNUMMER","METER_ID"],
    "ZÄHLERSTAND"   :  ["ZÄHLERSTAND","METER_READING"],
    "VERTRAGSNUMMER":  ["VERTRAGSNUMMER","ANGEBOTSNUMMER", "KUNDENNUMMER"],
    "ZAHLUNG"       :  ["BETRAG","ZAHLUNG","AMOUNT"],
    "BANK"          :  ["BANK"],
    "IBAN"          :  ["IBAN"],
    "BIC"           :  ["BIC"],
    "DATUM"         :  ["DATUM","DATE"],
    "GESENDET_MIT"  :  ["GESENDET_MIT"],
    "LINK"          :  ["LINK"]
}

EXPECTED_LABELS = set(PLACEHOLDERS.keys())

def load_all_datasets(file_paths: List[str]) -> Tuple[List[Dict], Dict[str, List[Dict]]]:
    """Load several JSON sample files and concatenate them.

    Each file must contain a JSON array of samples. A file that cannot be
    read, is not valid JSON/UTF-8, or whose top-level value is not a list is
    reported and skipped, so one bad file does not abort the whole run.

    Args:
        file_paths: Paths to the JSON files, in the order they are combined.

    Returns:
        ``(all_data, individual_datasets)``. ``all_data`` is the concatenation
        of every successfully loaded list. ``individual_datasets`` maps each
        file's base name without its extension to that file's samples.
    """
    all_data: List[Dict] = []
    individual_datasets: Dict[str, List[Dict]] = {}

    for file_path in file_paths:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # OSError: missing/unreadable file.
            # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"❌ Error loading {file_path}: {e}")
            continue

        if not isinstance(data, list):
            # A dict or scalar would otherwise be silently "extended" into the
            # combined data key by key / raise on extend.
            print(f"❌ Error loading {file_path}: expected a JSON list of samples, "
                  f"got {type(data).__name__}")
            continue

        # basename/splitext handle the platform's path separator and strip only
        # the trailing extension (not '.json' occurring mid-name).
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        individual_datasets[dataset_name] = data
        all_data.extend(data)
        print(f"✓ Loaded {len(data)} samples from {file_path}")

    return all_data, individual_datasets

def analyze_dataset(data: List[Dict], expected_labels: Optional[Set[str]] = None) -> Dict:
    """Compute label statistics for a list of samples.

    Each sample should be a dict whose ``'labels'`` field is a list of dicts
    with a string ``'label'`` key. Labels are deduplicated within a sample, so
    every per-label count is a number of *samples*, not raw occurrences.

    Malformed input is skipped and tallied in ``'malformed_entries'`` instead
    of raising. This covers a sample that is not a dict, a ``'labels'`` value
    that is not a list, and a label entry that is not a dict with a string
    ``'label'``. A missing or null ``'labels'`` field simply means the sample
    has no labels.

    Args:
        data: Samples to analyse (``None`` is treated as an empty list).
        expected_labels: Labels the data is expected to contain; defaults to
            the module-level ``EXPECTED_LABELS``.

    Returns:
        A dict with these keys:

        * ``num_samples``
        * ``unique_labels`` (set)
        * ``label_counts`` (Counter: label -> number of samples containing it)
        * ``missing_labels`` (expected labels that never occur)
        * ``samples_per_label`` (plain dict with the same numbers as
          ``label_counts``)
        * ``labels_per_sample`` (distinct labels per sample, in input order)
        * ``malformed_entries`` (int)

        Empty input yields the same keys with zero/empty values, so callers
        can index the result unconditionally.
    """
    data = data or []
    expected = EXPECTED_LABELS if expected_labels is None else set(expected_labels)

    label_counts: Counter = Counter()
    labels_per_sample: List[int] = []
    malformed = 0

    for sample in data:
        present: Set[str] = set()
        if not isinstance(sample, dict):
            malformed += 1
        else:
            entries = sample.get('labels') or []
            if not isinstance(entries, list):
                malformed += 1
                entries = []
            for entry in entries:
                if isinstance(entry, dict) and isinstance(entry.get('label'), str):
                    present.add(entry['label'])
                else:
                    malformed += 1
        # Counting each sample's *set* (not its raw list) makes label_counts a
        # samples-per-label tally.
        label_counts.update(present)
        labels_per_sample.append(len(present))

    unique_labels = set(label_counts)
    return {
        'num_samples': len(data),
        'unique_labels': unique_labels,
        'label_counts': label_counts,
        'missing_labels': expected - unique_labels,
        'samples_per_label': dict(label_counts),
        'labels_per_sample': labels_per_sample,
        'malformed_entries': malformed,
    }

# Black-and-white fills used to tell datasets apart in the grouped bar charts.
_DATASET_HATCHES = ['', '///', '...', '+++', 'xxx', '|||', '---', '\\\\\\']
_DATASET_GRAYS = ['0.2', '0.4', '0.6', '0.8', '0.3', '0.5', '0.7', '0.9']


def _finalize_figure(fig, output_dir: str, filename: str) -> None:
    """Tighten *fig*'s layout, save it as a 300-dpi PDF in *output_dir*, show it, then close it."""
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, filename), bbox_inches='tight', dpi=300)
    plt.show()
    # Release the figure; otherwise non-interactive backends keep every figure open.
    plt.close(fig)


def _dataset_style(index: int) -> Tuple[str, str]:
    """Return the ``(facecolor, hatch)`` pair used for the *index*-th dataset."""
    hatch = _DATASET_HATCHES[index] if index < len(_DATASET_HATCHES) else ''
    return _DATASET_GRAYS[index % len(_DATASET_GRAYS)], hatch


def _plot_label_distribution(sorted_labels: List[Tuple[str, int]], output_dir: str) -> None:
    """Horizontal bar chart of samples per label (input order = top to bottom)."""
    labels, counts = zip(*sorted_labels)
    fig, ax = plt.subplots(figsize=(8, 6))

    y_pos = np.arange(len(labels))
    bars = ax.barh(y_pos, counts, color='black', alpha=0.8, height=0.6,
                   edgecolor='black', linewidth=0.8)

    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels, fontweight='normal')
    ax.set_xlabel('Number of Samples', fontweight='bold')
    ax.set_title('Label Distribution Across Full Dataset', fontweight='bold', pad=20)
    ax.grid(axis='x', alpha=0.3, linewidth=0.8)

    # Value annotations sit just past the end of each bar.
    gap = max(counts) * 0.02
    for bar in bars:
        width = bar.get_width()
        ax.text(width + gap, bar.get_y() + bar.get_height() / 2,
                f'{int(width)}', ha='left', va='center', fontsize=10, fontweight='bold')

    ax.margins(y=0.01)
    _finalize_figure(fig, output_dir, 'label_distribution.pdf')


def _plot_grouped_bars(individual_analyses: Dict[str, Dict], labels: List[str],
                       output_dir: str, filename: str, *, as_percentage: bool,
                       ylabel: str, title: str) -> None:
    """Grouped bar chart with one bar per dataset for each label in *labels*.

    With ``as_percentage`` each dataset's values are divided by that dataset's
    total over *labels* and expressed in percent.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(labels))
    n_datasets = len(individual_analyses)
    width = 0.7 / n_datasets

    for i, (dataset_name, analysis) in enumerate(individual_analyses.items()):
        # .get() tolerates an analysis without per-label data (e.g. an empty dataset).
        per_label = analysis.get('samples_per_label', {})
        values = [per_label.get(label, 0) for label in labels]
        if as_percentage:
            total = sum(values) or 1  # avoid dividing by zero for a label-free dataset
            values = [v / total * 100 for v in values]

        facecolor, hatch = _dataset_style(i)
        # Offsets centre the group of bars on each tick.
        offset = (i - n_datasets / 2 + 0.5) * width
        ax.bar(x + offset, values, width,
               label=dataset_name.replace('_', ' ').title(),
               facecolor=facecolor, edgecolor='black', linewidth=1.0,
               hatch=hatch, alpha=0.8)

    ax.set_xlabel('Label Type', fontweight='bold')
    ax.set_ylabel(ylabel, fontweight='bold')
    ax.set_title(title, fontweight='bold', pad=20)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontweight='normal')
    ax.legend(loc='upper right')
    ax.grid(axis='y', alpha=0.3, linewidth=0.8)
    _finalize_figure(fig, output_dir, filename)


def _plot_labels_per_sample(labels_per_sample: List[int], output_dir: str) -> None:
    """Histogram of distinct labels per sample with mean and median markers (non-empty input)."""
    from matplotlib.ticker import MaxNLocator

    fig, ax = plt.subplots(figsize=(7, 5))

    lo, hi = min(labels_per_sample), max(labels_per_sample)
    # Half-integer bin edges centre each bar on its integer count. With integer
    # edges the bar for k labels would span [k, k+1), so the mean/median lines
    # would be drawn at bar edges instead of through the bars.
    bins = np.arange(lo, hi + 2) - 0.5
    ax.hist(labels_per_sample, bins=bins, color='0.6',
            alpha=0.8, edgecolor='black', linewidth=1.0)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    ax.set_xlabel('Number of Labels per Sample', fontweight='bold')
    ax.set_ylabel('Frequency (Number of Samples)', fontweight='bold')
    ax.set_title('Distribution of Labels per Sample', fontweight='bold', pad=20)
    ax.grid(axis='y', alpha=0.3, linewidth=0.8)

    mean_labels = np.mean(labels_per_sample)
    median_labels = np.median(labels_per_sample)
    ax.axvline(mean_labels, color='black', linestyle='--', linewidth=2.0,
               label=f'Mean: {mean_labels:.1f}')
    ax.axvline(median_labels, color='black', linestyle=':', linewidth=2.0,
               label=f'Median: {median_labels:.1f}')
    ax.legend(loc='upper right')

    _finalize_figure(fig, output_dir, 'labels_per_sample.pdf')


def _coverage_style(coverage: float) -> Tuple[str, str]:
    """Map a coverage percentage to ``(fill gray, hatch)``; labels with 0% are hatched white."""
    if coverage == 0:
        return 'white', '///'
    if coverage < 25:
        return '0.8', ''
    if coverage < 50:
        return '0.6', ''
    if coverage < 75:
        return '0.4', ''
    return '0.2', ''


def _plot_label_coverage(combined_analysis: Dict, output_dir: str) -> None:
    """Horizontal bars: percentage of samples that contain each expected label."""
    num_samples = combined_analysis['num_samples'] or 1  # guard against 0 samples
    per_label = combined_analysis['samples_per_label']
    # Iterate expected labels in sorted order. Raw set order can differ between
    # runs, which would reshuffle labels with equal coverage (e.g. all 0% ones).
    coverage_data = sorted(
        ((label, per_label.get(label, 0) / num_samples * 100)
         for label in sorted(EXPECTED_LABELS)),
        key=lambda item: item[1], reverse=True)
    labels_cov, coverages = zip(*coverage_data)
    styles = [_coverage_style(cov) for cov in coverages]

    fig, ax = plt.subplots(figsize=(8, 7))
    y_pos = np.arange(len(labels_cov))
    bars = ax.barh(y_pos, coverages, color=[fill for fill, _ in styles], alpha=1.0,
                   height=0.6, edgecolor='black', linewidth=1.0)
    # Set hatches bar by bar. A list-valued ``hatch`` argument is not accepted
    # by every Matplotlib version.
    for bar, (_, hatch) in zip(bars, styles):
        bar.set_hatch(hatch)

    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels_cov, fontweight='normal')
    ax.set_xlabel('Coverage Percentage (%)', fontweight='bold')
    ax.set_title('Label Coverage Across Full Dataset', fontweight='bold', pad=20)
    ax.grid(axis='x', alpha=0.3, linewidth=0.8)
    ax.set_xlim(0, 100)

    for bar in bars:
        width = bar.get_width()
        if width > 0:
            ax.text(width + 2, bar.get_y() + bar.get_height() / 2,
                    f'{width:.1f}%', ha='left', va='center', fontsize=10, fontweight='bold')

    ax.margins(y=0.01)
    _finalize_figure(fig, output_dir, 'label_coverage.pdf')


def create_publication_plots(combined_analysis: Dict, individual_analyses: Dict[str, Dict],
                             output_dir: str = "./figures/"):
    """Create, save and display the publication-quality figures.

    Writes these PDFs into *output_dir* (created if needed):

    * ``label_distribution.pdf``: samples per label in the combined data.
    * ``dataset_comparison.pdf``: per-dataset sample counts for the 8 most
      frequent labels.
    * ``dataset_comparison_percentage.pdf``: per-dataset share of label
      instances, for every label.
    * ``labels_per_sample.pdf``: histogram of distinct labels per sample.
    * ``label_coverage.pdf``: percentage of samples containing each expected
      label.

    The two comparison charts are only drawn when there is more than one
    dataset. Charts that need at least one label (or sample) are skipped with a
    message instead of raising.

    Args:
        combined_analysis: ``analyze_dataset`` result for all samples together.
        individual_analyses: Dataset name -> ``analyze_dataset`` result.
        output_dir: Target directory; a trailing separator is optional.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Most frequent first; ties broken alphabetically for a reproducible order.
    sorted_labels = sorted(combined_analysis['samples_per_label'].items(),
                           key=lambda item: (-item[1], item[0]))

    if sorted_labels:
        _plot_label_distribution(sorted_labels, output_dir)
        if len(individual_analyses) > 1:
            ordered = [label for label, _ in sorted_labels]
            # Only the 8 most common labels keep the count chart readable.
            _plot_grouped_bars(individual_analyses, ordered[:8], output_dir,
                               'dataset_comparison.pdf', as_percentage=False,
                               ylabel='Number of Samples',
                               title='Label Distribution Comparison Across Datasets')
            _plot_grouped_bars(individual_analyses, ordered, output_dir,
                               'dataset_comparison_percentage.pdf', as_percentage=True,
                               ylabel='Percentage of Label Instances (%)',
                               title='Label Distribution Across Datasets (Percentage)')
    else:
        print("⚠ No labels found; skipping label distribution and dataset comparison plots.")

    labels_per_sample = combined_analysis['labels_per_sample']
    if labels_per_sample:
        _plot_labels_per_sample(labels_per_sample, output_dir)
    else:
        print("⚠ No samples found; skipping labels-per-sample histogram.")

    _plot_label_coverage(combined_analysis, output_dir)

    print(f"\n✓ All plots saved to {output_dir}")

def _summary_row(display_name: str, analysis: Dict) -> Dict:
    """Turn one ``analyze_dataset`` result into a summary-table row.

    Absent keys are treated as an empty dataset, since ``analyze_dataset``
    may return ``{}`` for empty input.
    """
    labels_per_sample = analysis.get('labels_per_sample', [])
    if 'missing_labels' in analysis:
        missing = analysis['missing_labels']
    else:
        # An analysis that saw no labels is missing every expected label.
        missing = EXPECTED_LABELS
    return {
        'Dataset': display_name,
        'Samples': analysis.get('num_samples', 0),
        'Unique Labels': len(analysis.get('unique_labels', ())),
        'Total Instances': sum(analysis.get('label_counts', {}).values()),
        # np.mean([]) emits a RuntimeWarning and returns NaN; produce the NaN explicitly.
        'Avg Labels/Sample': (float(np.mean(labels_per_sample))
                              if len(labels_per_sample) else float('nan')),
        'Missing Labels': len(missing),
    }


def create_summary_table(combined_analysis: Dict, individual_analyses: Dict[str, Dict]) -> pd.DataFrame:
    """Build a one-row-per-dataset summary of the label statistics.

    The first row describes the combined data (``'Combined'``). It is followed
    by one row per entry of *individual_analyses*, named by replacing
    underscores with spaces and title-casing.

    Args:
        combined_analysis: ``analyze_dataset`` result for all samples together.
        individual_analyses: Dataset name -> ``analyze_dataset`` result.

    Returns:
        DataFrame with these columns:

        * ``Dataset``
        * ``Samples``
        * ``Unique Labels``
        * ``Total Instances``
        * ``Avg Labels/Sample`` (rounded to 2 decimals; NaN for a dataset
          without samples)
        * ``Missing Labels``
    """
    rows = [_summary_row('Combined', combined_analysis)]
    rows.extend(_summary_row(name.replace('_', ' ').title(), analysis)
                for name, analysis in individual_analyses.items())

    df = pd.DataFrame(rows)
    df['Avg Labels/Sample'] = df['Avg Labels/Sample'].round(2)
    return df

# Main analysis function
def analyze_full_dataset(file_paths: List[str], output_dir: str = "./figures/"
                         ) -> Optional[Tuple[Dict, Dict[str, Dict], pd.DataFrame]]:
    """Run the complete descriptive-statistics workflow over several dataset files.

    The workflow:

    1. Loads every file.
    2. Analyses the combined samples and each file on its own.
    3. Writes the PDF plots and ``dataset_summary.csv`` into *output_dir*.
    4. Prints the summary table.

    Args:
        file_paths: JSON files to load (see ``load_all_datasets``).
        output_dir: Directory for figures and the CSV; created if missing. A
            trailing separator is optional.

    Returns:
        ``(combined_analysis, individual_analyses, summary_df)``, or ``None``
        when no samples could be loaded.
    """
    print("🔍 Loading and analyzing datasets...")
    print("-" * 60)

    combined_data, individual_datasets = load_all_datasets(file_paths)

    if not combined_data:
        print("❌ No data loaded. Please check file paths.")
        return None

    combined_analysis = analyze_dataset(combined_data)
    individual_analyses = {name: analyze_dataset(data)
                           for name, data in individual_datasets.items()}

    create_publication_plots(combined_analysis, individual_analyses, output_dir)

    summary_df = create_summary_table(combined_analysis, individual_analyses)

    print("\n📊 DATASET SUMMARY")
    print("=" * 60)
    print(summary_df.to_string(index=False))

    # os.path.join works whether or not output_dir ends with a separator;
    # plain concatenation would write "./figuresdataset_summary.csv".
    os.makedirs(output_dir, exist_ok=True)
    csv_path = os.path.join(output_dir, 'dataset_summary.csv')
    summary_df.to_csv(csv_path, index=False)
    print(f"\n✓ Summary table saved to {csv_path}")

    return combined_analysis, individual_analyses, summary_df

# Example usage
if __name__ == "__main__":
    # Dataset files to analyse; adjust the paths to your checkout.
    dataset_files = [
        "../../../data/original/ground_truth_split/train_norm.json",
        "../../../data/original/ground_truth_split/validation_norm.json",
        "../../../data/original/ground_truth_split/test_norm.json"
    ]

    # analyze_full_dataset returns None when nothing could be loaded; unpacking
    # that directly would raise a TypeError.
    results = analyze_full_dataset(
        dataset_files,
        output_dir="./publication_figures/"
    )
    if results is not None:
        combined_analysis, individual_analyses, summary_df = results
