# Example: Modular Analysis Framework

This notebook demonstrates the new modular scripts for FTIR/HIPS/Aethalometer analysis.

## Key Features:
- **PlotConfig**: Set defaults once, apply to all plots
- **FilterId matching**: Match by physical filter, not just date
- **Flexible layouts**: Individual, grid, or combined plots
- **Site selection**: Plot all sites or specific ones
- **Outlier exclusion**: Traceable, transparent flagging with before/after comparison
- **Data modes**: Choose 'all' data or 'common' (only samples with ALL measurements)

## 1. Setup and Imports

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

# Locate the notebook's directory and put ./scripts at the front of the import path.
# Notebooks have no __file__, so use the kernel's working directory explicitly
# (Jupyter starts the kernel in the notebook's folder by default).
notebook_dir = os.getcwd()
scripts_path = os.path.join(notebook_dir, 'scripts')
# Guard against inserting a duplicate entry when this cell is re-run
if scripts_path not in sys.path:
    sys.path.insert(0, scripts_path)

# Core imports
from config import SITES, MAC_VALUE
from data_matching import (
    load_aethalometer_data, 
    load_filter_data,
    add_base_filter_id,
    match_by_filter_id,
    match_aeth_filter_data,
    match_hips_with_smooth_raw
)
from flow_periods import (
    add_flow_period,
    has_before_after_data,
    print_flow_period_summary
)

# Outlier system imports
from outliers import (
    EXCLUDED_SAMPLES,
    MANUAL_OUTLIERS,
    apply_exclusion_flags,
    apply_threshold_flags,
    get_clean_data,
    print_exclusion_summary,
    identify_outlier_dates
)

# Plotting imports
from plotting import PlotConfig, crossplots, timeseries, distributions, comparisons

# Confirm the imports above succeeded and show where the scripts were loaded from
print("Imports successful!", f"Scripts loaded from: {scripts_path}", sep="\n")

In [None]:
# =============================================================================
# UTILITY: Before/After Outlier Comparison Plot
# =============================================================================
# This function is available throughout the notebook for any analysis

# Fixed marker colour per site, so a site is drawn in the same colour in every
# plot that looks it up (unknown sites fall back to a grey at the call sites).
SITE_COLORS = dict(
    Beijing='#1f77b4',
    Delhi='#ff7f0e',
    JPL='#2ca02c',
    Addis_Ababa='#d62728',
)

def _plot_fit(ax, x, y, line_style):
    """Fit y on x, draw the fitted line on ``ax`` and return {'r2', 'slope', 'n'}.

    Returns an empty dict (and draws nothing) when no fit is possible: fewer
    than three points, or all x values identical (scipy.stats.linregress
    raises ValueError in that case).
    """
    if len(x) <= 2 or np.unique(x).size < 2:
        return {}
    slope, intercept, r_value, _, _ = stats.linregress(x, y)
    x_line = np.linspace(x.min(), x.max(), 100)
    ax.plot(x_line, slope * x_line + intercept, line_style, lw=2,
            label=f'Fit: R²={r_value**2:.3f}, slope={slope:.2f}')
    return {'r2': r_value**2, 'slope': slope, 'n': len(x)}


def _add_one_to_one(ax):
    """Draw a dashed 1:1 reference line spanning the current x/y limits of ``ax``."""
    lo = min(ax.get_xlim()[0], ax.get_ylim()[0])
    hi = max(ax.get_xlim()[1], ax.get_ylim()[1])
    ax.plot([lo, hi], [lo, hi], 'k--', alpha=0.5, label='1:1')


def plot_before_after(data_dict, x_col, y_col, xlabel, ylabel, title_prefix, 
                      outlier_col='is_any_outlier', sites=None):
    """
    Create side-by-side before/after plots showing outlier impact.
    
    For each site, the left panel shows every valid (finite x and y) point with
    flagged outliers highlighted and a regression over ALL valid points; the
    right panel shows the regression over retained (non-outlier) points only,
    with removed points faded. One figure is shown per site.
    
    Parameters:
    -----------
    data_dict : dict of site_name -> DataFrame with outlier flags
    x_col, y_col : column names for x and y axes (must be numeric-convertible)
    xlabel, ylabel : axis labels
    title_prefix : prefix for plot titles
    outlier_col : column name for the outlier mask (default 'is_any_outlier');
                  missing values in this column are treated as "not an outlier"
    sites : a single site name, a list of site names, or None for all sites
    
    Returns:
    --------
    dict of site_name -> {'before': stats, 'after': stats}. Each stats dict has
    keys 'r2', 'slope', 'n', or is empty when no fit was possible (fewer than
    3 points, or all x values identical). Sites that are absent, empty, or lack
    `outlier_col` are skipped and do not appear in the result.
    """
    # A bare site name would otherwise be iterated character by character
    if isinstance(sites, str):
        sites = [sites]
    sites_to_plot = sites if sites else list(data_dict.keys())
    
    results = {}
    for site_name in sites_to_plot:
        if site_name not in data_dict:
            continue
        df = data_dict[site_name]
        if len(df) == 0 or outlier_col not in df.columns:
            continue
            
        fig, axes = plt.subplots(1, 2, figsize=(14, 6))
        color = SITE_COLORS.get(site_name, '#333333')
        
        # Coerce to plain float/bool arrays: object or nullable columns (e.g. NaN
        # introduced by a merge) would break np.isfinite and the ~ operator.
        x_data = df[x_col].to_numpy(dtype=float, na_value=np.nan)
        y_data = df[y_col].to_numpy(dtype=float, na_value=np.nan)
        outlier_mask = df[outlier_col].to_numpy(dtype=bool, na_value=False)
        
        # Only finite pairs are plotted/fitted; a single inf would make the fit NaN
        valid = np.isfinite(x_data) & np.isfinite(y_data)
        clean_valid = valid & ~outlier_mask
        outlier_valid = valid & outlier_mask
        
        # ===== BEFORE (all data, outliers highlighted) =====
        ax = axes[0]
        ax.scatter(x_data[clean_valid], y_data[clean_valid], c=color, alpha=0.7, s=60, label='Data')
        if outlier_valid.any():
            ax.scatter(x_data[outlier_valid], y_data[outlier_valid], c='red', marker='X', 
                       s=200, linewidths=2, label=f'Outliers ({outlier_valid.sum()})')
        # Regression on ALL valid data (outliers included)
        stats_before = _plot_fit(ax, x_data[valid], y_data[valid], 'b-')
        _add_one_to_one(ax)
        ax.set_xlabel(xlabel); ax.set_ylabel(ylabel)
        ax.set_title(f'{title_prefix} - {site_name}\nBEFORE (n={valid.sum()})')
        ax.legend(loc='upper left', fontsize=9); ax.grid(True, alpha=0.3)
        
        # ===== AFTER (clean data, outliers faded) =====
        ax = axes[1]
        ax.scatter(x_data[clean_valid], y_data[clean_valid], c=color, alpha=0.7, s=60, label='Retained')
        if outlier_valid.any():
            ax.scatter(x_data[outlier_valid], y_data[outlier_valid], c='red', marker='X', 
                       s=100, alpha=0.3, label=f'Removed ({outlier_valid.sum()})')
        # Regression on CLEAN data only
        stats_after = _plot_fit(ax, x_data[clean_valid], y_data[clean_valid], 'g-')
        _add_one_to_one(ax)
        ax.set_xlabel(xlabel); ax.set_ylabel(ylabel)
        ax.set_title(f'{title_prefix} - {site_name}\nAFTER (n={clean_valid.sum()})')
        ax.legend(loc='upper left', fontsize=9); ax.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        results[site_name] = {'before': stats_before, 'after': stats_after}
    
    return results

print("Utility function plot_before_after() defined and ready to use!")

In [None]:
# =============================================================================
# UTILITY: Apply Outlier Flags to Matched Data
# =============================================================================

def apply_all_outlier_flags(data_dict, aeth_col='ir_bcc', filter_col='hips_fabs', 
                            convert_to_ng=True, verbose=True):
    """
    Flag outliers in every site's DataFrame using both rule sets.
    
    Each DataFrame is copied, so the input frames are not modified by this
    function. The copy gets helper columns 'aeth_bc' and 'filter_ec' holding
    `aeth_col` and `filter_col` (multiplied by 1000 when convert_to_ng is True).
    It is then passed through apply_exclusion_flags() followed by
    apply_threshold_flags(). Finally, 'is_any_outlier' is set to the logical
    OR of 'is_excluded' and 'is_outlier'.
    
    Parameters:
    -----------
    data_dict : dict of site_name -> DataFrame
    aeth_col : aethalometer BC column name
    filter_col : filter measurement column name
    convert_to_ng : if True, scale by 1000 (ug/m3 -> ng/m3) for the threshold columns
    verbose : if True, print an exclusion summary for each site
    
    Returns:
    --------
    dict of site_name -> flagged copy (with is_excluded, is_outlier, is_any_outlier)
    """
    scale = 1000 if convert_to_ng else 1
    result = {}
    
    for site, frame in data_dict.items():
        flagged = frame.copy()
        # Columns read by the threshold rules
        flagged['aeth_bc'] = flagged[aeth_col] * scale
        flagged['filter_ec'] = flagged[filter_col] * scale
        
        # Date-based exclusions first, then threshold-based outliers
        flagged = apply_threshold_flags(apply_exclusion_flags(flagged, site), site)
        flagged['is_any_outlier'] = flagged['is_excluded'] | flagged['is_outlier']
        
        result[site] = flagged
        if verbose:
            print_exclusion_summary(flagged, site)
    
    return result


# =============================================================================
# UTILITY: Filter to Common Data (samples with ALL measurements)
# =============================================================================

def filter_common_samples(data_dict, required_cols, verbose=True):
    """
    Filter to only samples that have data across ALL specified columns.
    
    Use this to get the "common denominator" - only samples where ALL three
    measurement types (HIPS, FTIR, Aethalometer) have valid data.
    
    Parameters:
    -----------
    data_dict : dict of site_name -> DataFrame
    required_cols : column name or iterable of column names that must ALL have
                    non-null values, e.g. ['ir_bcc', 'hips_fabs', 'ftir_ec']
    verbose : if True, print a per-site summary (and any skipped sites)
    
    Returns:
    --------
    dict of site_name -> filtered copy (only rows where all required_cols are
    non-null). Sites whose DataFrame lacks any required column are omitted.
    """
    # Accept a single column name and any iterable (tuples, generators) by
    # materializing a list once; a bare string would otherwise be split into chars.
    if isinstance(required_cols, str):
        required_cols = [required_cols]
    required_cols = list(required_cols)
    
    filtered_data = {}
    
    for site_name, df in data_dict.items():
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            if verbose:
                print(f"{site_name}: Missing columns {missing_cols}, skipping")
            continue
        
        # Keep rows where ALL required columns have data
        mask = df[required_cols].notna().all(axis=1)
        df_common = df[mask].copy()
        
        filtered_data[site_name] = df_common
        
        if verbose:
            # An empty DataFrame would otherwise raise ZeroDivisionError here
            pct = len(df_common) / len(df) * 100 if len(df) else 0.0
            print(f"{site_name}: {len(df)} total -> {len(df_common)} common ({pct:.1f}%)")
    
    return filtered_data


def get_data_mode(data_dict, mode='all', required_cols=None, verbose=False):
    """
    Return `data_dict` filtered according to a named data mode.
    
    Modes:
    - 'all'    : no filtering; the very same dict object is returned (default)
    - 'common' : only rows where every column in `required_cols` is non-null,
                 delegated to filter_common_samples()
    
    Parameters:
    -----------
    data_dict : dict of site_name -> DataFrame
    mode : 'all' or 'common'
    required_cols : columns that must all be present in 'common' mode;
                    None means ['ir_bcc', 'hips_fabs', 'ftir_ec'] (Aeth/HIPS/FTIR)
    verbose : passed through to filter_common_samples() in 'common' mode
    
    Returns:
    --------
    dict of site_name -> DataFrame
    
    Raises:
    -------
    ValueError : if `mode` is neither 'all' nor 'common'
    """
    if mode == 'all':
        return data_dict
    if mode != 'common':
        raise ValueError(f"Unknown mode: {mode}. Use 'all' or 'common'")
    cols = required_cols if required_cols is not None else ['ir_bcc', 'hips_fabs', 'ftir_ec']
    return filter_common_samples(data_dict, cols, verbose=verbose)


print("Utility functions defined:")
print("  - apply_all_outlier_flags()")
print("  - filter_common_samples()")
print("  - get_data_mode()")

## 2. Configure Plot Defaults

Set these once near the top of the notebook. All subsequent plots use these settings unless a call overrides them (e.g. passing `layout='grid'` or `sites=[...]`, as in Section 5).

In [None]:
# Option 1: plot every site, one figure per site
PlotConfig.set(sites='all', layout='individual', figsize=(10, 8),
               show_stats=True, show_1to1=True)

# Print the settings now in effect
PlotConfig.show()

In [None]:
# Option 2: Plot specific sites in a grid
# PlotConfig.set(
#     sites=['Beijing', 'JPL'],
#     layout='grid'
# )

# Option 3: Just one site
# PlotConfig.set(sites='JPL', layout='individual')

## 3. Load Data

In [None]:
# Load aethalometer data (all sites). Later cells index the result by site
# name, e.g. aethalometer_data[site_name].
aethalometer_data = load_aethalometer_data()

In [None]:
# Load filter data
filter_data = load_filter_data()

# Add base_filter_id for proper matching
filter_data = add_base_filter_id(filter_data)
print("\nAdded base_filter_id column")
# Show one FilterId -> base_filter_id pair as a sanity check; an empty table
# would make .iloc[0] raise IndexError.
if len(filter_data) > 0:
    print(f"Example: {filter_data['FilterId'].iloc[0]} -> {filter_data['base_filter_id'].iloc[0]}")
else:
    print("Warning: filter data is empty; nothing to match")

## 4. Match Data by FilterId (Recommended)

This ensures you're comparing measurements from the **same physical filter**.

In [None]:
# Match FTIR EC, HIPS and iron on the same physical filter (FilterId), per site
matched_by_filter = {}

for site_name, config in SITES.items():
    site_code = config['code']
    matched = match_by_filter_id(filter_data, site_code=site_code,
                                 params=['EC_ftir', 'HIPS_Fabs', 'ChemSpec_Iron_PM2.5'])

    if matched is None:
        print(f"{site_name}: No matched data")
        continue

    # Express HIPS absorption as BC-equivalent mass: Fabs / MAC
    matched['hips_fabs'] = matched['hips_fabs'] / MAC_VALUE
    matched_by_filter[site_name] = matched
    print(f"{site_name}: {len(matched)} filters with matched data")

## 5. Time Series Plots

In [None]:
# BC time series at the IR wavelength. Site selection and layout come from the
# PlotConfig defaults set in Section 2 (all sites, individual figures).
timeseries.bc(aethalometer_data, wavelength='IR')

In [None]:
# Same IR BC time series, overriding the PlotConfig layout for this call:
# all sites drawn in one grid.
timeseries.bc(aethalometer_data, wavelength='IR', layout='grid')

In [None]:
# Multi-wavelength BC time series, limited to the sites passed via `sites=`
# instead of the PlotConfig default.
timeseries.bc_multiwavelength(aethalometer_data, sites=['JPL', 'Beijing'])

In [None]:
# Flow ratio over time, with the layout overridden to a grid of sites
timeseries.flow_ratio(aethalometer_data, layout='grid')

## 6. Cross-Plots (Scatter)

In [None]:
# HIPS (Fabs / MAC, converted in Section 4) vs FTIR EC on the same physical
# filter; points should fall close to the 1:1 line if the two methods agree.
# `results` keeps whatever crossplots.scatter returns; it is not used below.
results = crossplots.scatter(
    matched_by_filter,
    x_col='ftir_ec',
    y_col='hips_fabs',
    xlabel='FTIR EC (µg/m³)',
    ylabel='HIPS Fabs / MAC (µg/m³)',
    title='FTIR EC vs HIPS (same filter)',
    layout='grid'
)

In [None]:
# Same FTIR vs HIPS comparison, with points coloured by iron concentration
# (ChemSpec_Iron_PM2.5 was requested in the Section 4 match).
crossplots.with_iron_gradient(
    matched_by_filter,
    x_col='ftir_ec',
    y_col='hips_fabs',
    xlabel='FTIR EC (µg/m³)',
    ylabel='HIPS Fabs / MAC (µg/m³)',
    title='FTIR vs HIPS (colored by Iron)',
    sites=['JPL']  # Just one site for this example
)

## 7. Match Aethalometer with HIPS (including smooth/raw info)

In [None]:
# Match HIPS with aethalometer data (by date) AND apply outlier flags immediately
hips_aeth_matched = {}

for site_name, config in SITES.items():
    if site_name not in aethalometer_data:
        # Report skipped sites instead of dropping them silently
        print(f"{site_name}: No aethalometer data, skipping")
        continue

    matched = match_hips_with_smooth_raw(
        site_name,
        aethalometer_data[site_name],
        filter_data,
        config['code']
    )

    if matched is None:
        # Same message as the FilterId matching loop in Section 4
        print(f"{site_name}: No matched data")
        continue

    hips_aeth_matched[site_name] = matched
    print(f"{site_name}: {len(matched)} matched pairs")

# Apply date-based and threshold-based outlier flags (adds is_excluded,
# is_outlier and is_any_outlier) to every matched site
print("\n" + "=" * 60)
print("APPLYING OUTLIER FLAGS")
print("=" * 60)
hips_aeth_matched = apply_all_outlier_flags(hips_aeth_matched)

In [None]:
# Example: Before/After outlier comparison (available for any analysis).
# Uses the is_any_outlier flag added above; plot_before_after returns per-site
# R²/slope/n before and after removal (the return value is not kept here).
plot_before_after(
    hips_aeth_matched,
    x_col='ir_bcc',
    y_col='hips_fabs',
    xlabel='Aethalometer IR BCc (µg/m³)',
    ylabel='HIPS Fabs / MAC (µg/m³)',
    title_prefix='HIPS vs Aethalometer',
    sites=['JPL']  # Just one site for demo; remove this line for all sites
)

In [None]:
# HIPS vs Aethalometer crossplot for all matched sites in a grid.
# NOTE(review): hips_aeth_matched still contains the flagged outlier rows;
# whether crossplots.scatter drops them is not visible here.
crossplots.scatter(
    hips_aeth_matched,
    x_col='ir_bcc',
    y_col='hips_fabs',
    xlabel='Aethalometer IR BCc (µg/m³)',
    ylabel='HIPS Fabs / MAC (µg/m³)',
    title='HIPS vs Aethalometer',
    layout='grid'
)

## 8. Distributions

In [None]:
# Distribution of IR BC across sites, as boxplots (site selection from PlotConfig)
distributions.bc_boxplot(aethalometer_data, wavelength='IR')

In [None]:
# Histogram of the smooth/raw difference (column smooth_raw_abs_pct) for each
# matched site, in a grid; `thresholds` are the cut-off values passed to the plot.
distributions.smooth_raw_histogram(
    hips_aeth_matched,
    col='smooth_raw_abs_pct',
    thresholds=[1, 2.5, 4, 5],
    layout='grid'
)

## 9. Flow Period Analysis

In [None]:
# Summarize which sites have data in both flow periods (before/after)
print_flow_period_summary()

In [None]:
# Tag each matched row with its flow period and report counts per site
for site_name, df in hips_aeth_matched.items():
    # Replacing values of existing keys while iterating is safe (the dict size is unchanged)
    hips_aeth_matched[site_name] = add_flow_period(df, site_name, date_col='date')
    # dropna=False so any rows whose flow_period is missing are counted, not hidden
    periods = hips_aeth_matched[site_name]['flow_period'].value_counts(dropna=False)
    print(f"{site_name}: {periods.to_dict()}")

In [None]:
# Flow period comparison (only for sites with before/after data).
# JPL is hard-coded here as the only site with data in both periods.
# NOTE(review): has_before_after_data() is imported above but not used; it
# could select the sites here instead of the hard-coded 'JPL' — confirm its
# return type before switching.
if 'JPL' in hips_aeth_matched:
    comparisons.flow_periods(
        {'JPL': hips_aeth_matched['JPL']},
        x_col='ir_bcc',
        y_col='hips_fabs',
        period_col='flow_period',
        xlabel='Aethalometer IR BCc (µg/m³)',
        ylabel='HIPS Fabs / MAC (µg/m³)'
    )

In [None]:
# Analyze the effect of smooth/raw-difference thresholds (values of
# smooth_raw_abs_pct) on the ir_bcc vs hips_fabs comparison
comparisons.threshold_analysis(
    hips_aeth_matched,
    x_col='ir_bcc',
    y_col='hips_fabs',
    threshold_col='smooth_raw_abs_pct',
    thresholds=[1, 2.5, 4, 5],
    sites=['JPL'],  # Just one site for clarity
    xlabel='Aethalometer IR BCc (µg/m³)',
    ylabel='HIPS Fabs / MAC (µg/m³)'
)

## 10. Data Mode: All vs Common Samples

Choose between two data modes:
- **'all'**: Use all available data (some samples may only have 2 of 3 measurements)
- **'common'**: Only samples with ALL three measurements (HIPS, FTIR, Aethalometer)

The "common" mode ensures all three instruments are compared on exactly the same set of samples. A site is skipped entirely (with a message) if any required column is missing from its matched DataFrame.

In [None]:
# Compare data availability: ALL vs COMMON mode
print("=" * 70, "DATA MODE COMPARISON", "=" * 70, sep="\n")

# 'all' hands back hips_aeth_matched unchanged; 'common' keeps only rows where
# every required column is present and prints a per-site summary (verbose=True)
data_all = get_data_mode(hips_aeth_matched, mode='all')
print("\nCOMMON mode (requires: ir_bcc, hips_fabs, ftir_ec):")
data_common = get_data_mode(
    hips_aeth_matched,
    mode='common',
    required_cols=['ir_bcc', 'hips_fabs', 'ftir_ec'],
    verbose=True,
)

In [None]:
# Example: Before/After with COMMON data only
# This uses only samples that have ALL three measurements.
# plot_before_after silently skips a site that is absent from data_common,
# e.g. one that filter_common_samples dropped for missing columns.
print("Using COMMON data (samples with HIPS + FTIR + Aethalometer):")
plot_before_after(
    data_common,
    x_col='ir_bcc',
    y_col='hips_fabs',
    xlabel='Aethalometer IR BCc (µg/m³)',
    ylabel='HIPS Fabs / MAC (µg/m³)',
    title_prefix='HIPS vs Aethalometer (COMMON)',
    sites=['JPL']
)

In [None]:
# Side-by-side comparison: ALL vs COMMON data
def compare_data_modes(data_all, data_common, x_col, y_col, xlabel, ylabel, site='JPL'):
    """
    Compare regression stats between ALL and COMMON data modes for one site.
    
    Draws two side-by-side scatter panels (ALL left, COMMON right). Each panel
    uses only rows where both x_col and y_col are non-null, and adds a
    least-squares fit (when possible) and a 1:1 reference line.
    
    Parameters:
    -----------
    data_all, data_common : dict of site_name -> DataFrame (e.g. from get_data_mode)
    x_col, y_col : column names for x and y axes
    xlabel, ylabel : axis labels
    site : site to plot (default 'JPL')
    
    Returns None. Prints a message and returns early if `site` is missing
    from either dict.
    """
    if site not in data_all or site not in data_common:
        print(f"Site {site} not in both datasets")
        return
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    color = SITE_COLORS.get(site, '#333333')
    
    for ax, (mode, data) in zip(axes, [('ALL', data_all), ('COMMON', data_common)]):
        # Select both columns together and drop incomplete rows in one step.
        # Re-selecting by index labels would pull NaN rows back in whenever the
        # index has duplicate labels (e.g. several samples sharing a date).
        pair = data[site][[x_col, y_col]].dropna()
        x, y = pair[x_col], pair[y_col]
        
        ax.scatter(x, y, c=color, alpha=0.7, s=60)
        
        # linregress needs >2 points and raises ValueError if every x is identical
        if len(x) > 2 and x.nunique() > 1:
            slope, intercept, r_value, _, _ = stats.linregress(x, y)
            x_line = np.linspace(x.min(), x.max(), 100)
            ax.plot(x_line, slope * x_line + intercept, 'b-', lw=2,
                    label=f'R²={r_value**2:.3f}, slope={slope:.2f}')
        
        lims = [min(ax.get_xlim()[0], ax.get_ylim()[0]), 
                max(ax.get_xlim()[1], ax.get_ylim()[1])]
        ax.plot(lims, lims, 'k--', alpha=0.5, label='1:1')
        ax.set_xlabel(xlabel); ax.set_ylabel(ylabel)
        ax.set_title(f'{site} - {mode} data (n={len(x)})')
        ax.legend(loc='upper left'); ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Compare ALL vs COMMON for JPL: both panels plot the same columns, so any
# difference comes from which samples each mode keeps.
compare_data_modes(
    data_all, data_common,
    x_col='ir_bcc', y_col='hips_fabs',
    xlabel='Aethalometer IR BCc (µg/m³)',
    ylabel='HIPS Fabs / MAC (µg/m³)',
    site='JPL'
)

## 11. Outlier Utilities

Outlier flags were already applied during data matching (Section 7). Here are utilities for working with the flagged data.

In [None]:
# View what outlier rules are defined (in scripts/outliers.py).
# EXCLUDED_SAMPLES maps site -> a sized collection of exclusions;
# MANUAL_OUTLIERS maps site -> a rule config that must have a 'description' key
# (a config without it raises KeyError here).
print("DATE-BASED EXCLUSIONS:")
for site, exclusions in EXCLUDED_SAMPLES.items():
    print(f"  {site}: {len(exclusions)} exclusions")

print("\nTHRESHOLD-BASED RULES:")
for site, config in MANUAL_OUTLIERS.items():
    print(f"  {site}: {config['description']}")

In [None]:
# Extract clean data for each site
clean_data = {}

for site_name, df in hips_aeth_matched.items():
    # get_clean_data() removes both is_excluded and is_outlier flagged rows
    clean_df = get_clean_data(df)
    clean_data[site_name] = clean_df

    n_removed = len(df) - len(clean_df)
    # An empty site DataFrame would otherwise raise ZeroDivisionError
    pct_removed = 100 * n_removed / len(df) if len(df) else 0.0

    print(f"{site_name}:")
    print(f"  Original: {len(df)} points")
    print(f"  After exclusions: {len(clean_df)} points")
    print(f"  Removed: {n_removed} ({pct_removed:.1f}%)\n")

## 12. Quick Reference: Available Functions

### PlotConfig
```python
PlotConfig.set(sites='all', layout='grid')  # Set defaults
PlotConfig.show()                            # Show current settings
PlotConfig.reset()                           # Reset to defaults
```

### Data Modes (defined in this notebook, Section 1)
```python
# Get all available data (default)
data = get_data_mode(matched_data, mode='all')

# Get only samples with ALL measurements (common denominator)
data = get_data_mode(matched_data, mode='common', 
                     required_cols=['ir_bcc', 'hips_fabs', 'ftir_ec'])

# Or use filter_common_samples directly
data = filter_common_samples(matched_data, ['ir_bcc', 'hips_fabs', 'ftir_ec'])
```

### crossplots
```python
crossplots.scatter(data, x_col, y_col)      # Generic scatter
crossplots.bc_vs_ec(data)                   # Preset: Aeth BC vs Filter EC
crossplots.hips_vs_ftir(data)               # Preset: HIPS vs FTIR
crossplots.with_iron_gradient(data, ...)   # Color by iron concentration
```

### timeseries
```python
timeseries.bc(data, wavelength='IR')        # BC time series
timeseries.bc_multiwavelength(data)         # All wavelengths
timeseries.flow_ratio(data)                 # Flow ratio over time
timeseries.data_completeness(data)          # Data availability
```

### distributions
```python
distributions.bc_boxplot(data)              # BC distribution
distributions.smooth_raw_histogram(data)    # Smooth/raw difference
distributions.correlation_matrix(data, cols) # Correlation heatmap
```

### comparisons
```python
comparisons.before_after_outliers(data)     # Outlier removal impact
comparisons.threshold_analysis(data, ...)   # Test different thresholds
comparisons.flow_periods(data)              # Before/after flow fix
```

### data_matching
```python
add_base_filter_id(filter_data)             # Strip -N suffix
match_by_filter_id(data, site_code, params) # Match by physical filter
match_hips_with_smooth_raw(...)             # HIPS + aethalometer
```

### flow_periods
```python
add_flow_period(df, site_name)              # Add flow_period column
has_before_after_data(site_name)            # Check data availability
print_flow_period_summary()                 # Show all sites status
```

### outliers
```python
# Apply flags (adds is_excluded, is_outlier columns)
apply_exclusion_flags(df, site_name)        # Date-based exclusions
apply_threshold_flags(df, site_name)        # Threshold-based outliers
apply_all_outlier_flags(data_dict)          # Apply both to a dict of DataFrames (defined in this notebook)

# Extract data
get_clean_data(df)                          # Returns non-outlier rows only

# Utilities
print_exclusion_summary(df, site_name)      # Show exclusion counts
identify_outlier_dates(site, df, criteria)  # Find dates to add to registry
```

### Typical Workflow
```python
# 1. Load and match data
hips_aeth_matched = {...}

# 2. Apply outlier flags
hips_aeth_matched = apply_all_outlier_flags(hips_aeth_matched)

# 3. Choose data mode
data_all = get_data_mode(hips_aeth_matched, mode='all')
data_common = get_data_mode(hips_aeth_matched, mode='common')

# 4. Plot with before/after comparison
plot_before_after(data_common, x_col, y_col, ...)

# 5. Get clean data for analysis
clean_data = {site: get_clean_data(df) for site, df in hips_aeth_matched.items()}
```