In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import random
import os
from PIL import Image

# --- 1. CONFIGURATION AND UTILITIES ---

# Directory that holds one sub-directory per record, and the metadata table
# read by the metadata-analysis step.
DATA_DIR = Path('/kaggle/input/physionet-ecg-image-digitization/train') 
TRAIN_META = Path('/kaggle/input/physionet-ecg-image-digitization/train.csv')

# Names of the 12 ECG leads; they are looked up as column names of the signal
# table when iterating over leads.
LEADS = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
# Image file-name suffix -> human-readable description of that image variant.
# The suffix is the part after the dash in "<id>-<suffix>.png".
IMAGE_VARIANTS = {
    '0001': 'Original Color Image (Clean)',
    '0003': 'Printed & Scanned (Color)',
    '0004': 'Printed & Scanned (B/W)',
    '0005': 'Mobile Photo (Color Print)',
    '0006': 'Mobile Photo (Screen)',
    '0009': 'Mobile Photo (Stained/Soaked)',
    '0010': 'Mobile Photo (Extensive Damage)',
    '0011': 'Scan with Mold (Color)',
    '0012': 'Scan with Mold (B/W)',
}

def load_signal(id_: int) -> pd.DataFrame:
    """Read the ground-truth signal table of one record.

    The table is expected at ``DATA_DIR/<id>/<id>.csv``.

    Args:
        id_: Record identifier. It is converted to ``str`` before being
            joined onto the data directory.

    Returns:
        The parsed CSV, or an empty DataFrame (after printing a warning)
        when the file does not exist.
    """
    record = str(id_)
    csv_file = DATA_DIR.joinpath(record, record + ".csv")

    if csv_file.exists():
        return pd.read_csv(csv_file)

    print(f"Warning: Signal file not found at {csv_file}")
    return pd.DataFrame()

def load_image(id_: int, suffix: str) -> Image.Image | None:
    """Load one image variant of a record, fully decoded.

    The image is expected at ``DATA_DIR/<id>/<id>-<suffix>.png``. When that
    exact file is absent, any ``*-<suffix>.png`` inside the record directory
    is accepted instead (the first one in sorted order).

    Args:
        id_: Record identifier; converted to ``str`` to build the path.
        suffix: Variant suffix as it appears in the file name, e.g. ``'0001'``.

    Returns:
        The decoded image, or ``None`` when no matching file exists or the
        file cannot be opened or decoded (the error is printed in that case).
    """
    id_str = str(id_)
    record_dir = DATA_DIR / id_str
    img_path = record_dir / f"{id_str}-{suffix}.png"

    if not img_path.exists():
        # Fallback: accept any PNG in the folder that carries the suffix.
        # glob() yields entries in arbitrary order, so sort to make the
        # choice reproducible between runs.
        img_candidates = sorted(record_dir.glob(f"*-{suffix}.png"))
        if not img_candidates:
            return None
        img_path = img_candidates[0]

    try:
        # Image.open() is lazy: it only parses the header and keeps the file
        # open. Decoding here lets this handler report truncated or corrupt
        # files instead of failing later at display time, and the ``with``
        # block releases the file handle even when decoding raises. The
        # pixel data stays usable after the block because it is already
        # loaded into memory.
        with Image.open(img_path) as img:
            img.load()
            return img
    except Exception as e:
        # Best-effort loader: a bad file must not abort the whole analysis.
        print(f"Error loading image {img_path}: {e}")
        return None

def plot_ecg_signal(signal_df: pd.DataFrame, id_: int, fs: int):
    """Plot the ground-truth leads of one record on a 4x3 grid of panels.

    Every lead found in ``signal_df`` gets its own panel. The panel title
    reports how many non-missing samples the lead has and the duration they
    cover at ``fs``. Panels of leads that are absent from ``signal_df`` are
    hidden.

    Args:
        signal_df: Table with one column per lead and one row per sample.
            Samples a lead does not cover are assumed to be stored as NaN
            (TODO confirm against the dataset description).
        id_: Record identifier; only used in the figure title.
        fs: Sampling frequency in Hz. Must be positive.

    Returns:
        None. The figure is displayed with ``plt.show()``; nothing is drawn
        when ``signal_df`` is empty or ``fs`` is not positive.
    """
    if signal_df.empty:
        print("Empty signal data, skipping plot.")
        return
    if fs <= 0:
        # The time axis is sample index / fs, so it is undefined here.
        print(f"Invalid sampling frequency {fs}, skipping plot.")
        return

    # All columns of a DataFrame share one row count, so a single time
    # vector serves every lead.
    time = np.arange(len(signal_df)) / fs

    # Define a 4x3 grid for the 12 leads (filled row by row).
    fig, axes = plt.subplots(4, 3, figsize=(18, 12), sharex=False, sharey=False)
    lead_layout = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    for ax, lead in zip(axes.flatten(), lead_layout):
        if lead not in signal_df.columns:
            # Do not leave an empty framed panel for a lead we cannot draw.
            ax.axis('off')
            continue

        samples = signal_df[lead]
        # len(samples) is the table's row count and therefore identical for
        # every lead; count() ignores NaN and gives the lead's real length.
        valid_count = int(samples.count())
        lead_len_sec = valid_count / fs

        ax.plot(time, samples, color='#1f77b4', linewidth=1)
        ax.set_title(f'Lead {lead} ({lead_len_sec:.1f}s, {valid_count} samples)', fontsize=10)
        ax.set_ylabel('Amplitude (mV)', fontsize=8)
        ax.set_xlabel('Time (s)', fontsize=8)
        ax.grid(color='gray', linestyle=':', linewidth=0.5, alpha=0.5)

    fig.suptitle(f'Ground Truth 12-Lead ECG Signal (ID: {id_}, FS: {fs} Hz)', fontsize=16, y=1.02)
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.show()

# --- 2. MAIN EDA EXECUTION ---

def _run_metadata_analysis() -> pd.DataFrame | None:
    """Step 1: load ``TRAIN_META``, print summaries and run sanity checks.

    Returns:
        The metadata table (with an added ``expected_len`` column), or
        ``None`` when the file is missing or any part of the analysis fails.
    """
    print("=" * 60)
    print("STEP 1: METADATA ANALYSIS (train.csv)")
    print("=" * 60)

    try:
        if not TRAIN_META.exists():
            print(f"Error: Metadata file not found at {TRAIN_META}. Cannot proceed.")
            print("Please ensure 'train.csv' is in the current directory or adjust the TRAIN_META path.")
            return None

        train_meta = pd.read_csv(TRAIN_META)
        print(f"Loaded {len(train_meta)} records.")
        print("\n--- Basic Info ---")
        # info() writes its report to stdout itself and returns None, so
        # wrapping it in print() would emit a stray "None" line.
        train_meta.info()
        print("\n--- Descriptive Statistics ---")
        print(train_meta.describe())

        # A. Sampling Frequency (fs) Distribution
        plt.figure(figsize=(8, 5))
        sns.histplot(train_meta['fs'], bins=10, kde=True, color='purple')
        plt.title('Distribution of Sampling Frequencies (FS)')
        plt.xlabel('Sampling Frequency (Hz)')
        plt.ylabel('Count')
        plt.show()

        # B. Signal Length Verification: compare sig_len with 10 s worth of
        # samples at the record's sampling frequency.
        train_meta['expected_len'] = train_meta['fs'] * 10
        print("\n--- Signal Length Verification ---")
        verification_match = (train_meta['sig_len'] == train_meta['expected_len']).all()
        print(f"All sig_len match 10s * fs: {verification_match}")

        # C. Unique IDs
        duplicate_count = train_meta['id'].duplicated().sum()
        if duplicate_count:
            print(f"Warning: Found {duplicate_count} duplicate IDs.")
        else:
            print("No duplicate IDs found.")
    except Exception as e:
        # Notebook-level boundary: report the problem and stop the pipeline
        # instead of raising into the cell.
        print(f"An error occurred during metadata analysis: {e}")
        return None

    return train_meta


def _run_signal_analysis(train_meta: pd.DataFrame, sample_size: int) -> list[int] | None:
    """Step 2: plot a random sample of records and summarise their leads.

    Args:
        train_meta: Metadata table with at least the ``id`` and ``fs`` columns.
        sample_size: Maximum number of records to sample.

    Returns:
        The sampled record IDs (possibly empty), or ``None`` when there is
        no record directory that also has a metadata row.
    """
    print("\n\n" + "=" * 60)
    print("STEP 2: GROUND TRUTH SIGNAL ANALYSIS")
    print("=" * 60)

    # Only numerically named sub-directories count as records (this also
    # copes with partial downloads). A missing DATA_DIR is treated like an
    # empty one rather than letting iterdir() raise.
    available_ids = []
    if DATA_DIR.is_dir():
        available_ids = [int(p.name) for p in DATA_DIR.iterdir() if p.is_dir() and p.name.isdigit()]

    if not available_ids:
        print(f"Error: No data directories found in {DATA_DIR}. Cannot perform signal/image analysis.")
        return None

    # A directory without a metadata row has no known sampling frequency;
    # drop it before sampling so the lookup below cannot fail.
    known_ids = set(train_meta['id'])
    usable_ids = [id_ for id_ in available_ids if id_ in known_ids]
    ignored = len(available_ids) - len(usable_ids)
    if ignored:
        print(f"Warning: Ignoring {ignored} data directories without a row in train.csv.")
    if not usable_ids:
        print("Error: None of the data directories match an ID in train.csv. Cannot perform signal/image analysis.")
        return None

    # Sample a few IDs for plotting
    sample_ids = random.sample(usable_ids, min(sample_size, len(usable_ids)))

    amplitude_frames = []

    for id_ in sample_ids:
        # First matching row wins if the ID is duplicated in the metadata.
        fs = train_meta.loc[train_meta['id'] == id_, 'fs'].iloc[0]

        signal_df = load_signal(id_)
        if signal_df.empty:
            continue

        print(f"\nProcessing ID {id_}: FS={fs}, Length={len(signal_df)} samples.")

        # A. Plot the signal
        plot_ecg_signal(signal_df, id_, fs)

        # B. Lead-specific length check
        print(f"Lead lengths at FS={fs}:")
        present_leads = [lead for lead in LEADS if lead in signal_df.columns]
        for lead in present_leads:
            # len() of a column equals the table's row count for every
            # lead, so count the non-missing samples instead.
            length = int(signal_df[lead].count())
            duration = length / fs
            expected_duration = 10.0 if lead == 'II' else 2.5
            print(f"  {lead}: {length} samples ({duration:.1f}s). Expected: {expected_duration:.1f}s.")

        # Long format (one row per sample) for the combined amplitude plot.
        if present_leads:
            amplitude_frames.append(
                signal_df[present_leads].melt(var_name='Lead', value_name='Value')
            )

    # C. Combined Amplitude Distribution
    if amplitude_frames:
        combined_df = pd.concat(amplitude_frames, ignore_index=True)
        plt.figure(figsize=(12, 6))
        sns.boxplot(x='Lead', y='Value', data=combined_df, palette='Spectral')
        plt.title('Amplitude (mV) Distribution Across 12 Leads (Sampled)')
        plt.ylabel('Amplitude (mV)')
        plt.xlabel('ECG Lead')
        plt.show()

    return sample_ids


def _run_image_comparison(sample_ids: list[int]) -> None:
    """Step 3: show every image variant of the first sampled record."""
    print("\n\n" + "=" * 60)
    print("STEP 3: IMAGE DEGRADATION ANALYSIS")
    print("=" * 60)

    if not sample_ids:
        return

    # Use the first sampled ID for consistency
    id_ = sample_ids[0]
    print(f"Visualizing all image variants for sample ID {id_}")

    fig, axes = plt.subplots(4, 3, figsize=(15, 20))
    axes = axes.flatten()

    # zip() stops at the shorter sequence, so extra variants beyond the
    # number of panels are skipped.
    for ax, (suffix, title) in zip(axes, IMAGE_VARIANTS.items()):
        img = load_image(id_, suffix)
        if img is not None:
            ax.imshow(img, aspect='auto')
            ax.set_title(f'({suffix}) {title}', fontsize=10)
        else:
            ax.set_title(f'({suffix}) Image Missing', fontsize=10)
        ax.axis('off')

    # Hide unused subplots
    for ax in axes[len(IMAGE_VARIANTS):]:
        fig.delaxes(ax)

    fig.suptitle(f'Image Degradation Variants (ID: {id_})', fontsize=16, y=1.01)
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.show()

    print("\nKey Visual Observations:")
    print(" - Observe grid line visibility, color degradation, rotation, and artifacts.")
    print(" - The original image (0001) is the cleanest reference for the signal location.")
    print(" - Types like 0005 (photo) often have perspective distortion and poor contrast.")
    print(" - Types 0009/0010/0011/0012 show severe artifacts (stains, mold, damage).")


def run_eda(sample_size_meta=5, sample_size_signal=5):
    """Execute the full EDA pipeline: metadata, sampled signals, image variants.

    Args:
        sample_size_meta: Currently unused; kept so existing calls keep
            working.
        sample_size_signal: Maximum number of records to sample for the
            signal plots and the per-lead summary.

    Returns:
        None. All results are printed or shown as figures. The pipeline
        stops early when the metadata cannot be analysed or when no usable
        record directory is found.
    """
    sns.set_theme(style="whitegrid")

    train_meta = _run_metadata_analysis()
    if train_meta is None:
        return

    sample_ids = _run_signal_analysis(train_meta, sample_size_signal)
    if sample_ids is None:
        return

    _run_image_comparison(sample_ids)

In [None]:
# Notebook entry point: run every EDA step with the default sample sizes.
run_eda()