# Cohort-Specific FPGrowth Feature Importance Analysis

**Purpose:** Cohort-level frequent pattern mining for process mining and comparative analysis  
**Updated:** November 23, 2025  
**Hardware:** Optimized for EC2 (32 cores, 1TB RAM)  
**Output:** `s3://pgxdatalake/gold/fpgrowth/cohort/{item_type}/cohort_name={cohort}/...`

Key Features

‚úÖ **Three Item Types** - Drugs, ICD codes, CPT codes  
‚úÖ **Cohort-Specific Patterns** - Discovers patterns unique to each cohort  
‚úÖ **Parallel Processing** - Processes multiple cohorts simultaneously  
‚úÖ **BupaR Integration** - Outputs ready for process mining workflows  
‚úÖ **Comparative Analysis** - Compare patterns between OPIOID_ED and ED_NON_OPIOID cohorts

Methodology

For each combination of (cohort, age_band, event_year, item_type):
1. Extract items from cohort-specific data
2. Create patient-level transactions
3. Encode transactions into binary matrix
4. Run FP-Growth to find frequent itemsets
5. Generate association rules
6. Save results to S3 in organized structure

Key Differences from Global Analysis

| Aspect | Global FPGrowth | Cohort FPGrowth |
|--------|-----------------|-----------------|
| **Scope** | All patients (~5.7M) | Individual cohorts (~10K-100K) |
| **Purpose** | Universal ML features | Process mining patterns |
| **Support Threshold** | 0.01 (1%) | 0.05 (5%) |
| **Output** | `global/{item_type}/` | `cohort/{item_type}/cohort_name={c}/...` |
| **Use Case** | CatBoost consistency | BupaR pathway analysis |
| **Parallelization** | Sequential by item type | Parallel by cohort |

Expected Runtime (EC2: 32 cores, 1TB RAM)

- **Cohorts**: 2 (opioid_ed, ed_non_opioid)
- **Age bands √ó Years**: ~100 combinations per cohort
- **Item types**: 3 (drug_name, icd_code, cpt_code)
- **Total jobs**: ~600 combinations
- **Avg time per job**: ~1-2 minutes
- **Total runtime**: ~2-4 hours (with MAX_WORKERS=4)

S3 Output Structure

```
s3://pgxdatalake/gold/fpgrowth/cohort/
‚îú‚îÄ‚îÄ drug_name/
‚îÇ   ‚îú‚îÄ‚îÄ cohort_name=opioid_ed/
‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ age_band=65-74/event_year=2020/
‚îÇ   ‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ itemsets.json
‚îÇ   ‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ rules.json
‚îÇ   ‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ summary.json
‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ ...
‚îÇ   ‚îî‚îÄ‚îÄ cohort_name=ed_non_opioid/...
‚îú‚îÄ‚îÄ icd_code/
‚îÇ   ‚îî‚îÄ‚îÄ (same structure)
‚îî‚îÄ‚îÄ cpt_code/
    ‚îî‚îÄ‚îÄ (same structure)
```

---


## Environment Setup and Imports


In [1]:
import os
import sys
import json
import pandas as pd
import numpy as np
from datetime import datetime
import logging
from pathlib import Path
import psutil
import duckdb
import time
from concurrent.futures import ProcessPoolExecutor, as_completed

# MLxtend for FP-Growth
from mlxtend.frequent_patterns import fpgrowth, association_rules
from mlxtend.preprocessing import TransactionEncoder

# Project root
project_root = Path.cwd().parent if Path.cwd().name == '3_fpgrowth_analysis' else Path.cwd()
sys.path.insert(0, str(project_root))

# Project utilities
from helpers_1997_13.common_imports import s3_client, S3_BUCKET
from helpers_1997_13.duckdb_utils import get_duckdb_connection
from helpers_1997_13.s3_utils import save_to_s3_json, save_to_s3_parquet, get_cohort_parquet_path
from helpers_1997_13.fpgrowth_utils import run_fpgrowth_drug_token_with_fallback, convert_frozensets
from helpers_1997_13.visualization_utils import create_network_visualization
from helpers_1997_13.constants import AGE_BANDS, EVENT_YEARS

print(f"‚úì Project root: {project_root}")
print(f"‚úì All imports successful")
print(f"‚úì Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


‚úì Project root: /home/pgx3874/pgx-analysis
‚úì All imports successful
‚úì Timestamp: 2025-11-24 12:27:05


## EC2 Configuration

In [None]:
# =============================================================================
# EC2 CONFIGURATION (32 cores, 1TB RAM)
# =============================================================================

# FP-Growth parameters (higher threshold for cohort-specific patterns)
MIN_SUPPORT = 0.05       # 5% support (items must appear in 5% of patients within cohort)
MIN_CONFIDENCE = 0.5     # 50% confidence - only strong associations

# CPT-specific parameters (prevent memory exhaustion from millions of rules)
MIN_SUPPORT_CPT = 0.15   # 15% support for CPT codes (focuses on common patterns)
MIN_CONFIDENCE_CPT = 0.6 # 60% confidence for CPT (very strong associations only)

# Rule limits (focus on most important rules)
MAX_RULES_PER_COHORT = 1000  # Keep top 1000 rules by lift (practical limit)

# Itemset filtering (remove common/trivial itemsets)
MIN_ITEMSET_LIFT = 1.1  # Filter itemsets with lift < 1.1 (items are independent/not interesting)

# Target-focused rule mining (NEW!)
TARGET_FOCUSED = True  # Only generate rules that predict target outcomes
TARGET_ICD_CODES = ['F11.20', 'F11.21', 'F11.22', 'F11.23', 'F11.24', 'F11.25', 'F11.29']  # Opioid dependence codes
TARGET_HCG_LINES = [
    "P51 - ER Visits and Observation Care",
    "O11 - Emergency Room",
    "P33 - Urgent Care Visits"
]  # ED visits (HCG Line codes - matches phase2_event_processing.py)
TARGET_PREFIXES = ['TARGET_ICD:', 'TARGET_ED:']  # Prefixes for target items in transactions

# Item types to process
# NOTE: drug_name processed separately (pharmacy events)
#       icd_code + cpt_code combined as 'medical_codes' (both from medical events)
ITEM_TYPES = ['drug_name', 'medical_codes']  # Changed: combine ICD + CPT into medical_codes

# Processing parameters
MAX_WORKERS = 1  # Sequential processing to prevent memory issues

# DEBUG MODE (enable detailed logging)
DEBUG_MODE = True  # Set to True for DEBUG level logging, False for INFO level

# DRY RUN MODE (test with limited cohorts first)
DRY_RUN = True  # Set to False to process all cohorts
DRY_RUN_LIMIT = 5  # Number of cohort combinations to process in dry run
COHORTS_TO_PROCESS = ['opioid_ed', 'non_opioid_ed']  # Specify cohorts to process

# Paths
S3_OUTPUT_BASE = f"s3://{S3_BUCKET}/gold/fpgrowth/cohort"
LOCAL_DATA_PATH = Path("/mnt/nvme/cohorts")  # Instance storage (NVMe SSD for fast I/O)

# Setup logger with file output (prevents Jupyter rate limit issues)
logger = logging.getLogger('cohort_fpgrowth')

# Set logger level based on DEBUG_MODE
if DEBUG_MODE:
    logger.setLevel(logging.DEBUG)
    log_level_str = "DEBUG"
else:
    logger.setLevel(logging.INFO)
    log_level_str = "INFO"

logger.handlers.clear()  # Clear any existing handlers

# File handler - full logs to file (always captures DEBUG if DEBUG_MODE is on)
log_file = project_root / "3_fpgrowth_analysis" / "logs" / "cohort_fpgrowth_execution.log"
log_file.parent.mkdir(exist_ok=True)  # Create logs directory if it doesn't exist
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG if DEBUG_MODE else logging.INFO)  # Capture DEBUG in debug mode
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s'))
logger.addHandler(file_handler)

# Console handler - only major milestones (prevents Jupyter rate limit)
console_handler = logging.StreamHandler()
if DEBUG_MODE:
    console_handler.setLevel(logging.DEBUG)  # Show DEBUG in console too when in debug mode
else:
    console_handler.setLevel(logging.WARNING)  # Only warnings/errors to console
console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(console_handler)

print(f"‚úì Min Support (drug_name): {MIN_SUPPORT} (5%)")
print(f"‚úì Min Support (medical_codes): {MIN_SUPPORT_CPT} (15% - ICD+CPT combined)")
print(f"‚úì Min Confidence (drug_name): {MIN_CONFIDENCE} (50% - strong associations)")
print(f"‚úì Min Confidence (medical_codes): {MIN_CONFIDENCE_CPT} (60% - very strong associations)")
print(f"‚úì Max Rules per Cohort: {MAX_RULES_PER_COHORT:,} (top rules by lift)")
print(f"‚úì Min Itemset Lift: {MIN_ITEMSET_LIFT} (filters common/trivial itemsets)")
print(f"‚úì Item Types: {ITEM_TYPES}")
print(f"‚úì Max Workers: {MAX_WORKERS} (sequential - prevents OOM)")
if DRY_RUN:
    print(f"‚úì DRY RUN MODE: Processing only {DRY_RUN_LIMIT} cohort combinations")
    print(f"  ‚Üí Set DRY_RUN = False to process all cohorts")
else:
    print(f"‚úì FULL RUN MODE: Processing all cohorts")
print(f"‚úì Cohorts: {COHORTS_TO_PROCESS}")
print(f"‚úì S3 Output: {S3_OUTPUT_BASE}")
print(f"‚úì S3 Retry: 3 attempts with exponential backoff")
print(f"‚úì Local Data: {LOCAL_DATA_PATH}")
print(f"‚úì Local Data Exists: {LOCAL_DATA_PATH.exists()}")
print(f"‚úì Logging Level: {log_level_str}")
print(f"‚úì DEBUG Mode: {'ENABLED' if DEBUG_MODE else 'DISABLED'}")
print(f"‚úì Detailed logs ‚Üí {log_file}")
print(f"‚úì Log directory: {log_file.parent}")
print(f"‚úì Log file exists: {log_file.exists()}")
if DEBUG_MODE:
    print(f"‚úì Console output: DEBUG level (all messages)")
else:
    print(f"‚úì Console output: WARNING level only (check log file for progress)")
print("\nüéØ Quality Over Quantity Approach:")
print("  - High confidence thresholds (50-60%) = meaningful patterns only")
print("  - CPT uses 15% support (vs 5%) = focuses on common procedures")
print("  - Top 1,000 rules by lift = actionable insights, not exhaustive lists")
print("  - 2 parallel workers = stable memory usage")
print(f"\nüéØ TARGET-FOCUSED RULE MINING: {'ENABLED' if TARGET_FOCUSED else 'DISABLED'}")
if TARGET_FOCUSED:
    print(f"  - Target ICD codes: {TARGET_ICD_CODES}")
    print(f"  - Target HCG lines (ED visits): {TARGET_HCG_LINES}")
    print("  - Only generates rules that PREDICT target outcomes")
    print("  - Example: {Metoprolol, Gabapentin} ‚Üí {TARGET_ICD:OPIOID_DEPENDENCE}")
    print("  - Example: {99213: Office Visit, J0670: Morphine} ‚Üí {TARGET_ED:EMERGENCY_DEPT}")
    print("  ‚úÖ Drastically reduces rule count (only predictive patterns)")
    print("  ‚úÖ More actionable for BupaR (pathways to target)")
    print("  ‚úÖ Better for CatBoost (features that predict outcome)")


‚úì Min Support (drug/ICD): 0.05 (5%)
‚úì Min Support (CPT): 0.15 (15% - focuses on common patterns)
‚úì Min Confidence (drug/ICD): 0.5 (50% - strong associations)
‚úì Min Confidence (CPT): 0.6 (60% - very strong associations)
‚úì Max Rules per Cohort: 1,000 (top rules by lift)
‚úì Item Types: ['drug_name', 'icd_code', 'cpt_code']
‚úì Max Workers: 1 (sequential - prevents OOM)
‚úì DRY RUN MODE: Processing only 5 cohort combinations
  ‚Üí Set DRY_RUN = False to process all cohorts
‚úì Cohorts: ['opioid_ed', 'ed_non_opioid']
‚úì S3 Output: s3://pgxdatalake/gold/fpgrowth/cohort
‚úì S3 Retry: 3 attempts with exponential backoff
‚úì Local Data: /mnt/nvme/cohorts
‚úì Local Data Exists: True
‚úì Detailed logs ‚Üí /home/pgx3874/pgx-analysis/3_fpgrowth_analysis/cohort_fpgrowth_execution.log

üéØ Quality Over Quantity Approach:
  - High confidence thresholds (50-60%) = meaningful patterns only
  - CPT uses 15% support (vs 5%) = focuses on common procedures
  - Top 1,000 rules by lift = actionab

## Memory Monitoring

Helper function to log memory usage at critical points.


In [None]:
# Diagnostic: Check what cohorts actually exist
print("üîç Checking available cohorts...")
available_cohorts = set()
for cohort_dir in LOCAL_DATA_PATH.glob("cohort_name=*"):
    cohort_name = cohort_dir.name.replace("cohort_name=", "")
    available_cohorts.add(cohort_name)
    print(f"  Found: {cohort_name}")

print(f"\nüìã Summary:")
print(f"  Available: {sorted(available_cohorts)}")
print(f"  Requested: {COHORTS_TO_PROCESS}")
missing = set(COHORTS_TO_PROCESS) - available_cohorts
if missing:
    print(f"  ‚ö†Ô∏è  Missing: {missing}")
    # Check for case mismatches
    cohort_lower_map = {c.lower(): c for c in available_cohorts}
    corrected_cohorts = []
    for m in missing:
        m_lower = m.lower()
        if m_lower in cohort_lower_map:
            actual_name = cohort_lower_map[m_lower]
            print(f"  üí° Case mismatch: '{m}' ‚Üí '{actual_name}'")
            corrected_cohorts.append(actual_name)
    
    if corrected_cohorts:
        print(f"\n  üí° RECOMMENDATION: Update COHORTS_TO_PROCESS to:")
        final_list = [c for c in COHORTS_TO_PROCESS if c not in missing] + corrected_cohorts
        print(f"     COHORTS_TO_PROCESS = {final_list}")
else:
    print(f"  ‚úÖ All requested cohorts found!")


In [None]:
def log_memory(logger, stage="", show_details=False):
    """
    Comprehensive memory monitoring with detailed breakdown.
    
    Args:
        logger: Logger instance
        stage: Description of current pipeline stage
        show_details: If True, show process-level memory breakdown
    """
    try:
        mem = psutil.virtual_memory()
        mem_used_gb = mem.used / (1024**3)
        mem_total_gb = mem.total / (1024**3)
        mem_percent = mem.percent
        mem_avail_gb = mem.available / (1024**3)
        mem_cached_gb = mem.cached / (1024**3) if hasattr(mem, 'cached') else 0
        
        # Get process memory
        process = psutil.Process()
        process_mem = process.memory_info()
        process_rss_gb = process_mem.rss / (1024**3)
        process_vms_gb = process_mem.vms / (1024**3)
        
        logger.info(f"[MEMORY {stage}]")
        logger.info(f"  System: {mem_used_gb:.1f} GB / {mem_total_gb:.1f} GB ({mem_percent:.1f}%) | Available: {mem_avail_gb:.1f} GB")
        logger.info(f"  Process RSS: {process_rss_gb:.2f} GB | VMS: {process_vms_gb:.2f} GB")
        
        # Warning thresholds
        if mem_percent > 90:
            logger.error(f"üö® CRITICAL MEMORY: {mem_percent:.1f}% - OOM imminent!")
        elif mem_percent > 85:
            logger.warning(f"‚ö†Ô∏è  HIGH MEMORY USAGE: {mem_percent:.1f}% - May cause OOM!")
        elif mem_percent > 75:
            logger.warning(f"‚ö†Ô∏è  ELEVATED MEMORY: {mem_percent:.1f}% - Monitor closely")
        
        # Process-level details if requested
        if show_details:
            try:
                # Get top memory-consuming processes
                processes = []
                for p in psutil.process_iter(['pid', 'name', 'memory_info']):
                    try:
                        mem_info = p.info['memory_info']
                        if mem_info:
                            processes.append({
                                'pid': p.info['pid'],
                                'name': p.info['name'],
                                'rss': mem_info.rss / (1024**3)
                            })
                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        continue
                
                processes.sort(key=lambda x: x['rss'], reverse=True)
                logger.info(f"  Top 5 processes by memory:")
                for proc in processes[:5]:
                    logger.info(f"    {proc['name']} (PID {proc['pid']}): {proc['rss']:.2f} GB")
            except Exception as e:
                logger.debug(f"Could not get process details: {e}")
        
        return {
            'system_percent': mem_percent,
            'system_used_gb': mem_used_gb,
            'system_avail_gb': mem_avail_gb,
            'process_rss_gb': process_rss_gb,
            'process_vms_gb': process_vms_gb
        }
    except Exception as e:
        logger.error(f"Error getting memory info: {e}")
        return {'system_percent': 0.0}

print("‚úì Enhanced memory logging function defined")


def filter_itemsets_by_lift(
    itemsets: pd.DataFrame,
    df_encoded: pd.DataFrame,
    min_lift: float,
    logger: logging.Logger
) -> pd.DataFrame:
    """
    Filter itemsets by lift to remove common/trivial itemsets.
    
    Lift measures how much more likely items are to appear together than by chance.
    Lift = 1.0 means items are independent (not interesting)
    Lift > 1.0 means positive association (interesting)
    Lift < 1.0 means negative association (also interesting, but we filter these out)
    
    Args:
        itemsets: DataFrame with 'itemsets' and 'support' columns
        df_encoded: Encoded transaction DataFrame (needed to calculate individual item supports)
        min_lift: Minimum lift threshold (e.g., 1.1 = 10% more likely than chance)
        logger: Logger instance
    
    Returns:
        Filtered DataFrame with only itemsets above min_lift threshold
    """
    if len(itemsets) == 0:
        return itemsets
    
    logger.info(f"Filtering {len(itemsets):,} itemsets by lift (min_lift={min_lift})...")
    
    # Calculate individual item supports (needed for lift calculation)
    item_supports = {}
    total_transactions = len(df_encoded)
    
    for col in df_encoded.columns:
        item_supports[col] = df_encoded[col].sum() / total_transactions
    
    # Calculate lift for each itemset
    def calculate_lift(row):
        itemset = row['itemsets']
        itemset_support = row['support']
        
        # For single-item itemsets, lift is undefined (or 1.0 by convention)
        if len(itemset) == 1:
            return 1.0  # Single items don't have lift
        
        # For multi-item itemsets: lift = itemset_support / (item1_support * item2_support * ...)
        expected_support = 1.0
        for item in itemset:
            if item in item_supports:
                expected_support *= item_supports[item]
            else:
                # Item not found in transactions (shouldn't happen, but handle gracefully)
                return 0.0
        
        if expected_support == 0:
            return 0.0
        
        lift = itemset_support / expected_support
        
        # CRITICAL: Cap extreme lift values to prevent numerical instability
        # With very small cohorts, expected_support can be extremely small, causing
        # lift values in the billions. Cap at reasonable maximum (e.g., 1000)
        MAX_LIFT = 1000.0
        if lift > MAX_LIFT:
            logger.debug(f"Capping extreme lift {lift:.2e} to {MAX_LIFT} for itemset {itemset}")
            return MAX_LIFT
        
        return lift
    
    itemsets['lift'] = itemsets.apply(calculate_lift, axis=1)
    
    # Filter by lift threshold
    original_count = len(itemsets)
    itemsets_filtered = itemsets[itemsets['lift'] >= min_lift].copy()
    filtered_count = len(itemsets_filtered)
    removed_count = original_count - filtered_count
    
    logger.info(f"  Original itemsets: {original_count:,}")
    logger.info(f"  Filtered itemsets: {filtered_count:,} (lift >= {min_lift})")
    logger.info(f"  Removed common/trivial: {removed_count:,} ({removed_count/original_count*100:.1f}%)")
    
    if filtered_count > 0:
        logger.info(f"  Lift range: {itemsets_filtered['lift'].min():.3f} - {itemsets_filtered['lift'].max():.3f}")
    
    return itemsets_filtered.drop(columns=['lift'])  # Remove lift column (not needed in output)

print("‚úì Itemset filtering by lift function defined")

# =============================================================================
# HOW LIFT FILTERING REMOVES COMMON ITEMSETS
# =============================================================================
"""
Lift measures how much MORE likely items are to appear together than by chance.

Formula: Lift = P(A and B) / (P(A) * P(B))

Interpretation:
- Lift = 1.0 ‚Üí Items are INDEPENDENT (appear together by chance)
- Lift > 1.0 ‚Üí Items have POSITIVE association (more likely together than chance)
- Lift < 1.0 ‚Üí Items have NEGATIVE association (less likely together than chance)

EXAMPLE: Why common itemsets get filtered out

Scenario: 100 transactions, 2 very common items
- Item A appears in 80 transactions (P(A) = 0.8)
- Item B appears in 70 transactions (P(B) = 0.7)
- They appear together in 56 transactions (P(A and B) = 0.56)

Expected if independent: P(A) * P(B) = 0.8 * 0.7 = 0.56
Actual: 0.56
Lift = 0.56 / 0.56 = 1.0 ‚Üí INDEPENDENT (filtered out!)

Even though {A, B} has high support (56%), it's just chance - not interesting.

EXAMPLE: Meaningful association

Scenario: 100 transactions
- Item C appears in 20 transactions (P(C) = 0.2)
- Item D appears in 15 transactions (P(D) = 0.15)
- They appear together in 10 transactions (P(C and D) = 0.10)

Expected if independent: P(C) * P(D) = 0.2 * 0.15 = 0.03
Actual: 0.10
Lift = 0.10 / 0.03 = 3.33 ‚Üí STRONG ASSOCIATION (kept!)

Even though {C, D} has lower support (10%), it's 3.3x more likely than chance.

WHY THIS MATTERS:
- Common items (like "aspirin" or "blood pressure check") appear in many transactions
- They often appear together just because they're both common
- Lift filtering removes these "common but independent" patterns
- Keeps only patterns where items are MORE associated than chance would predict
"""
print("‚úì Lift filtering explanation documented")


‚úì Memory logging function defined


## Step 1: Discover Available Cohorts

Scan local data to find all available cohort combinations.


In [4]:
def discover_cohorts(local_data_path, cohort_filter=None):
    """
    Discover all available cohort combinations from local data.
    """
    cohort_jobs = []
    
    for cohort_dir in local_data_path.glob("cohort_name=*"):
        cohort_name = cohort_dir.name.replace("cohort_name=", "")
        
        # Filter if specified
        if cohort_filter and cohort_name not in cohort_filter:
            continue
        
        for year_dir in cohort_dir.glob("event_year=*"):
            event_year = year_dir.name.replace("event_year=", "")
            
            for age_dir in year_dir.glob("age_band=*"):
                age_band = age_dir.name.replace("age_band=", "")
                
                # Check if cohort file exists
                cohort_file = age_dir / "cohort.parquet"
                if cohort_file.exists():
                    cohort_jobs.append({
                        'cohort': cohort_name,
                        'age_band': age_band,
                        'event_year': event_year,
                        'local_path': str(cohort_file)
                    })
    
    return cohort_jobs

# Discover available cohorts
cohort_jobs = discover_cohorts(LOCAL_DATA_PATH, cohort_filter=COHORTS_TO_PROCESS)

# Apply DRY_RUN limit if enabled
if DRY_RUN and len(cohort_jobs) > DRY_RUN_LIMIT:
    print(f"\n‚ö†Ô∏è  DRY RUN: Limiting from {len(cohort_jobs)} to {DRY_RUN_LIMIT} cohort combinations")
    cohort_jobs = cohort_jobs[:DRY_RUN_LIMIT]

print(f"\nüìä Discovered Cohorts:")
print(f"  Total combinations: {len(cohort_jobs)}")

# Group by cohort
cohort_counts = {}
for job in cohort_jobs:
    cohort_counts[job['cohort']] = cohort_counts.get(job['cohort'], 0) + 1

for cohort, count in cohort_counts.items():
    print(f"  {cohort}: {count} combinations")

print(f"\n  Sample jobs:")
for job in cohort_jobs[:5]:
    print(f"    {job['cohort']}/{job['age_band']}/{job['event_year']}")



‚ö†Ô∏è  DRY RUN: Limiting from 45 to 5 cohort combinations

üìä Discovered Cohorts:
  Total combinations: 5
  opioid_ed: 5 combinations

  Sample jobs:
    opioid_ed/13-24/2016
    opioid_ed/0-12/2016
    opioid_ed/25-44/2016
    opioid_ed/55-64/2016
    opioid_ed/45-54/2016


## Step 2: Define Cohort Processing Function

Create a function to process a single cohort for a specific item type with FP-Growth.


In [None]:
def process_single_cohort(job, item_type):
    """Process a single cohort for a specific item type with FP-Growth analysis."""
    cohort = job['cohort']
    age_band = job['age_band']
    event_year = job['event_year']
    local_path = job['local_path']
    
    # Create cohort-specific logger that inherits from main logger
    # Use the main logger's handlers so logs go to the same file
    cohort_logger = logging.getLogger(f"{cohort}_{age_band}_{event_year}_{item_type}")
    
    # Set level based on DEBUG_MODE (need to check if DEBUG_MODE is available in this scope)
    # Get DEBUG_MODE from globals or use INFO as default
    debug_mode = globals().get('DEBUG_MODE', False)
    cohort_logger.setLevel(logging.DEBUG if debug_mode else logging.INFO)
    
    # CRITICAL: Add handlers from main logger so logs are written to file
    # Without this, cohort_logger logs go nowhere!
    if not cohort_logger.handlers:
        # Get the main logger and copy its handlers
        main_logger = logging.getLogger('cohort_fpgrowth')
        for handler in main_logger.handlers:
            # For FileHandler, create new handler with same file
            if isinstance(handler, logging.FileHandler):
                new_handler = logging.FileHandler(handler.baseFilename)
            # For StreamHandler, create new handler with same stream
            elif isinstance(handler, logging.StreamHandler):
                new_handler = logging.StreamHandler(handler.stream)
            else:
                # Fallback: try to create same type
                new_handler = type(handler)(handler.baseFilename if hasattr(handler, 'baseFilename') else handler.stream)
            
            new_handler.setLevel(handler.level)
            new_handler.setFormatter(handler.formatter)
            cohort_logger.addHandler(new_handler)
    
    # Prevent propagation to avoid duplicate logs
    cohort_logger.propagate = False
    
    # Debug mode: log initial setup
    if debug_mode:
        cohort_logger.debug(f"Logger initialized for {cohort}/{age_band}/{event_year}/{item_type}")
        cohort_logger.debug(f"Logger level: {logging.getLevelName(cohort_logger.level)}")
        cohort_logger.debug(f"Handlers: {len(cohort_logger.handlers)}")
    
    try:
        start_time = time.time()
        cohort_logger.info(f"Processing {cohort}/{age_band}/{event_year} - {item_type}")
        mem_start = log_memory(cohort_logger, "START", show_details=True)
        
        # Extract items based on type + TARGET MARKERS (for target-focused rules)
        # Simple in-memory connection (no AWS needed for local parquet reads)
        cohort_logger.info("Creating DuckDB connection...")
        cohort_logger.debug(f"Local path: {local_path}")
        cohort_logger.debug(f"Item type: {item_type}")
        
        con = duckdb.connect(':memory:')
        con.sql("SET threads = 1")
        cohort_logger.debug("DuckDB connection created, threads set to 1")
        log_memory(cohort_logger, "After DuckDB connection")
        
        if item_type == 'drug_name':
            # Pharmacy events: drug names only
            query = f"""
            SELECT mi_person_key, drug_name as item
            FROM read_parquet('{local_path}')
            WHERE drug_name IS NOT NULL AND drug_name != '' AND event_type = 'pharmacy'
            """
        elif item_type == 'medical_codes':
            # Medical events: combine ICD codes (diagnoses) + CPT codes (procedures)
            # This creates richer transactions showing both diagnoses and procedures together
            query = f"""
            WITH all_icds AS (
                SELECT mi_person_key, primary_icd_diagnosis_code as code FROM read_parquet('{local_path}') 
                WHERE primary_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, two_icd_diagnosis_code as code FROM read_parquet('{local_path}') 
                WHERE two_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, three_icd_diagnosis_code as code FROM read_parquet('{local_path}') 
                WHERE three_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, four_icd_diagnosis_code as code FROM read_parquet('{local_path}') 
                WHERE four_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, five_icd_diagnosis_code as code FROM read_parquet('{local_path}') 
                WHERE five_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
            ),
            all_cpts AS (
                SELECT mi_person_key, procedure_code as code
                FROM read_parquet('{local_path}')
                WHERE procedure_code IS NOT NULL AND procedure_code != '' AND event_type = 'medical'
            ),
            combined_medical AS (
                SELECT mi_person_key, code as item FROM all_icds WHERE code != ''
                UNION ALL
                SELECT mi_person_key, code as item FROM all_cpts WHERE code != ''
            )
            SELECT mi_person_key, item FROM combined_medical
            """
        else:
            raise ValueError(f"Unknown item_type: {item_type}. Expected 'drug_name' or 'medical_codes'")
        
        cohort_logger.info(f"Executing query to extract {item_type}...")
        if debug_mode:
            cohort_logger.debug(f"Query: {query[:200]}...")  # Log first 200 chars of query
        
        df = con.execute(query).df()
        cohort_logger.info(f"Extracted {len(df):,} rows, {df['mi_person_key'].nunique():,} unique patients")
        
        if debug_mode:
            cohort_logger.debug(f"DataFrame shape: {df.shape}")
            cohort_logger.debug(f"Unique items: {df['item'].nunique() if 'item' in df.columns else 'N/A'}")
            cohort_logger.debug(f"Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")
        
        mem_after_extract = log_memory(cohort_logger, "After data extraction")
        
        # Check memory increase
        if mem_start['system_percent'] > 0:
            mem_delta = mem_after_extract['system_percent'] - mem_start['system_percent']
            if mem_delta > 10:
                cohort_logger.warning(f"‚ö†Ô∏è  Large memory increase during extraction: {mem_delta:.1f}%")
        
        # Add target markers if TARGET_FOCUSED mode is enabled
        if TARGET_FOCUSED:
            cohort_logger.info("Adding target markers...")
            
            # Get target information for each patient
            target_query = f"""
            SELECT DISTINCT 
                mi_person_key,
                primary_icd_diagnosis_code,
                hcg_line
            FROM read_parquet('{local_path}')
            WHERE mi_person_key IS NOT NULL
            """
df_targets = con.execute(target_query).df()
            cohort_logger.debug(f"Target query returned {len(df_targets):,} rows")
            
            # OPTIMIZED: Use vectorized operations instead of iterrows() (much faster for large datasets!)
            # iterrows() is extremely slow - can take minutes for 393K+ rows
            target_items_list = []
            
            # Check for opioid ICD codes (vectorized)
            if 'primary_icd_diagnosis_code' in df_targets.columns:
                # Create list of ICD code prefixes to match (without dots)
                icd_prefixes = [code.replace('.', '') for code in TARGET_ICD_CODES]
                
                # Vectorized check: check if any ICD code starts with any prefix
                icd_mask = df_targets['primary_icd_diagnosis_code'].notna()
                if icd_mask.any():
                    # For each row, check if ICD code starts with any prefix (vectorized)
                    icd_matches = df_targets.loc[icd_mask, 'primary_icd_diagnosis_code'].apply(
                        lambda x: any(str(x).startswith(prefix) for prefix in icd_prefixes) if pd.notna(x) else False
                    )
                    # Get patient IDs with opioid ICD codes
                    opioid_patients = df_targets.loc[icd_mask & icd_matches, 'mi_person_key']
                    if len(opioid_patients) > 0:
                        target_items_list.append(
                            pd.DataFrame({
                                'mi_person_key': opioid_patients,
                                'item': 'TARGET_ICD:OPIOID_DEPENDENCE'
                            })
                        )
                        cohort_logger.debug(f"Found {len(opioid_patients):,} patients with opioid ICD codes")
            
            # Check for ED visits (vectorized - much faster than iterrows!)
            if 'hcg_line' in df_targets.columns:
                ed_mask = df_targets['hcg_line'].isin(TARGET_HCG_LINES)
                ed_patients = df_targets.loc[ed_mask, 'mi_person_key']
                if len(ed_patients) > 0:
                    target_items_list.append(
                        pd.DataFrame({
                            'mi_person_key': ed_patients,
                            'item': 'TARGET_ED:EMERGENCY_DEPT'
                        })
                    )
                    cohort_logger.debug(f"Found {len(ed_patients):,} patients with ED visits")
            
            # Combine all target items
            if target_items_list:
                df_targets_items = pd.concat(target_items_list, ignore_index=True)
                df = pd.concat([df, df_targets_items], ignore_index=True)
                cohort_logger.info(f"Added {len(df_targets_items):,} target markers")
                mem_after_targets = log_memory(cohort_logger, "After target markers")
                
                # Check memory increase
                if mem_after_extract['system_percent'] > 0:
                    mem_delta = mem_after_targets['system_percent'] - mem_after_extract['system_percent']
                    if mem_delta > 5:
                        cohort_logger.warning(f"‚ö†Ô∏è  Memory increase during target markers: {mem_delta:.1f}%")
            else:
                cohort_logger.info("No target markers found (no opioid ICD codes or ED visits)")
                mem_after_targets = log_memory(cohort_logger, "After target markers (none found)")
        con.close()
        
        if df.empty:
            cohort_logger.warning(f"No {item_type} data for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No data")
        
        # Create transactions (group items by patient)
        cohort_logger.info(f"Building transactions from {len(df):,} rows...")
        mem_before_transactions = log_memory(cohort_logger, "Before transaction building")
        
        transactions = (
            df.groupby('mi_person_key')['item']
            .apply(lambda x: sorted(set(x.tolist())))
            .tolist()
        )
        
        mem_after_transactions = log_memory(cohort_logger, "After transaction building")
        cohort_logger.info(f"Created {len(transactions):,} transactions")
        
        # Check transaction sizes
        if transactions:
            transaction_sizes = [len(t) for t in transactions]
            cohort_logger.info(f"Transaction size stats: min={min(transaction_sizes)}, max={max(transaction_sizes)}, "
                             f"mean={np.mean(transaction_sizes):.1f}, median={np.median(transaction_sizes):.1f}")
            
            # Warn if very large transactions
            if max(transaction_sizes) > 1000:
                cohort_logger.warning(f"‚ö†Ô∏è  Very large transactions detected (max={max(transaction_sizes)}) - may cause memory issues")
        
        if not transactions:
            cohort_logger.warning(f"No valid transactions for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No transactions")
        
        # CRITICAL: For small cohorts, min_support needs to be higher to prevent explosion
        # With only 4 transactions, min_support=0.05 means itemset needs to appear in 1 transaction
        # This is way too low and generates millions of trivial itemsets
        MIN_TRANSACTIONS_FOR_STABLE_FPGROWTH = 10
        if len(transactions) < MIN_TRANSACTIONS_FOR_STABLE_FPGROWTH:
            cohort_logger.warning(f"‚ö†Ô∏è  Very small cohort: {len(transactions)} transactions")
            cohort_logger.warning(f"   FP-Growth may generate excessive itemsets. Consider skipping or using higher min_support.")
            # Increase min_support for small cohorts
            min_sup_original = MIN_SUPPORT_CPT if item_type == 'medical_codes' else MIN_SUPPORT
            # For <10 transactions, require itemset to appear in at least 2 transactions
            min_sup_adjusted = max(min_sup_original, 2.0 / len(transactions))
            if min_sup_adjusted > min_sup_original:
                cohort_logger.info(f"   Adjusted min_support from {min_sup_original} to {min_sup_adjusted:.3f} for small cohort")
                min_sup = min_sup_adjusted
            else:
                min_sup = min_sup_original
        else:
            min_sup = MIN_SUPPORT_CPT if item_type == 'medical_codes' else MIN_SUPPORT
        
        # Encode transactions
        cohort_logger.info(f"Encoding {len(transactions):,} transactions...")
        mem_before_encode = log_memory(cohort_logger, "Before encoding")
        
        if debug_mode:
            cohort_logger.debug(f"Sample transaction (first 3): {transactions[:3] if len(transactions) >= 3 else transactions}")
            cohort_logger.debug(f"Total unique items across all transactions: {len(set(item for t in transactions for item in t))}")
        
        te = TransactionEncoder()
        cohort_logger.debug("Fitting TransactionEncoder...")
        te_ary = te.fit(transactions).transform(transactions)
        cohort_logger.debug("Transforming transactions to binary matrix...")
        df_encoded = pd.DataFrame(te_ary, columns=te.columns_)
        
        mem_after_encode = log_memory(cohort_logger, "After encoding")
        cohort_logger.info(f"Encoded matrix shape: {df_encoded.shape} ({df_encoded.shape[0]*df_encoded.shape[1]:,} elements)")
        
        if debug_mode:
            cohort_logger.debug(f"Matrix memory usage: {df_encoded.memory_usage(deep=True).sum() / 1024**2:.2f} MB")
            cohort_logger.debug(f"Matrix sparsity: {(1 - df_encoded.sum().sum() / (df_encoded.shape[0] * df_encoded.shape[1])) * 100:.2f}%")
        
        # Check memory increase during encoding (this is often where OOM happens)
        if mem_before_encode['system_percent'] > 0:
            mem_delta = mem_after_encode['system_percent'] - mem_before_encode['system_percent']
            if mem_delta > 15:
                cohort_logger.error(f"üö® CRITICAL: Large memory spike during encoding: {mem_delta:.1f}%")
                cohort_logger.error(f"   Matrix size may be too large - consider filtering transactions")
            elif mem_delta > 10:
                cohort_logger.warning(f"‚ö†Ô∏è  Significant memory increase during encoding: {mem_delta:.1f}%")
        
        # Run FP-Growth (min_sup already set above with small-cohort adjustment)
        cohort_logger.info(f"Running FP-Growth (min_support={min_sup})...")
        mem_before_fpgrowth = log_memory(cohort_logger, "Before FP-Growth")
        
        if debug_mode:
            cohort_logger.debug(f"FP-Growth parameters: min_support={min_sup}, matrix_shape={df_encoded.shape}")
            cohort_logger.debug(f"Minimum support count: {int(min_sup * len(transactions))} transactions")
        
        itemsets = fpgrowth(df_encoded, min_support=min_sup, use_colnames=True)
        itemsets = itemsets.sort_values('support', ascending=False).reset_index(drop=True)
        
        mem_after_fpgrowth = log_memory(cohort_logger, "After FP-Growth")
        
        if debug_mode:
            cohort_logger.debug(f"FP-Growth completed, processing {len(itemsets)} itemsets")
        
        if itemsets.empty:
            cohort_logger.warning(f"No itemsets found for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No itemsets")
        
        # Warn if still too many itemsets (indicates min_support too low)
        if len(itemsets) > 100000:
            cohort_logger.warning(f"‚ö†Ô∏è  WARNING: {len(itemsets):,} itemsets is excessive!")
            cohort_logger.warning(f"   Consider increasing min_support or filtering transactions")
            cohort_logger.warning(f"   Current min_support: {min_sup}, Transactions: {len(transactions)}")
        
        cohort_logger.info(f"Found {len(itemsets):,} itemsets (before lift filtering)")
        
        # CRITICAL: Filter out common/trivial itemsets by lift BEFORE generating rules
        # This prevents memory issues from keeping millions of trivial itemsets
        # For small cohorts, lift becomes unreliable, so use higher threshold or skip
        if len(transactions) < MIN_TRANSACTIONS_FOR_STABLE_FPGROWTH:
            # For very small cohorts, lift is unreliable - use much higher threshold
            min_lift_adjusted = max(MIN_ITEMSET_LIFT, 2.0)  # Require at least 2x lift
            cohort_logger.info(f"Using adjusted min_lift={min_lift_adjusted} for small cohort (lift unreliable with <{MIN_TRANSACTIONS_FOR_STABLE_FPGROWTH} transactions)")
        else:
            min_lift_adjusted = MIN_ITEMSET_LIFT
        
        itemsets = filter_itemsets_by_lift(
            itemsets, 
            df_encoded, 
            min_lift_adjusted, 
            cohort_logger
        )
        mem_after_filtering = log_memory(cohort_logger, "After lift filtering")
        
        if itemsets.empty:
            cohort_logger.warning(f"No itemsets remaining after lift filtering for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No itemsets after filtering")
        
        cohort_logger.info(f"Found {len(itemsets):,} itemsets (after lift filtering)")
        
        # CRITICAL: Filter out single-item itemsets before rule generation
        # association_rules() requires multi-item itemsets (can't generate rules from single items)
        # Check itemset sizes more carefully
        itemset_sizes = itemsets['itemsets'].apply(len)
        single_item_count = (itemset_sizes == 1).sum()
        
        if debug_mode:
            cohort_logger.debug(f"Itemset size distribution: {itemset_sizes.value_counts().to_dict()}")
        
        if single_item_count > 0:
            cohort_logger.info(f"Filtering out {single_item_count:,} single-item itemsets (cannot generate rules)")
            itemsets = itemsets[itemset_sizes > 1].copy()
            cohort_logger.info(f"Remaining multi-item itemsets: {len(itemsets):,}")
        
        if itemsets.empty:
            cohort_logger.warning(f"No multi-item itemsets remaining for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No multi-item itemsets")
        
        # Double-check: verify all remaining itemsets are multi-item
        final_itemset_sizes = itemsets['itemsets'].apply(len)
        if (final_itemset_sizes == 1).any():
            cohort_logger.error(f"‚ö†Ô∏è  ERROR: Single-item itemsets still present after filtering!")
            cohort_logger.error(f"   This should not happen - removing them now")
            itemsets = itemsets[final_itemset_sizes > 1].copy()
            if itemsets.empty:
                cohort_logger.warning(f"No multi-item itemsets remaining after cleanup")
                return (cohort, age_band, event_year, item_type, False, "No multi-item itemsets after cleanup")
        
        # Check memory increase
        if mem_before_fpgrowth['system_percent'] > 0:
            mem_delta = mem_after_fpgrowth['system_percent'] - mem_before_fpgrowth['system_percent']
            if mem_delta > 10:
                cohort_logger.warning(f"‚ö†Ô∏è  Memory increase during FP-Growth: {mem_delta:.1f}%")
        
        # Log itemset statistics (after single-item filtering)
        if len(itemsets) > 0:
            final_itemset_sizes = itemsets['itemsets'].apply(len)
            cohort_logger.info(f"Itemset size stats: min={final_itemset_sizes.min()}, max={final_itemset_sizes.max()}, "
                             f"mean={final_itemset_sizes.mean():.1f}, median={final_itemset_sizes.median():.1f}")
            
            # Verify no single-item itemsets remain
            if (final_itemset_sizes == 1).any():
                cohort_logger.error(f"üö® CRITICAL: Single-item itemsets detected before rule generation!")
                cohort_logger.error(f"   This will cause association_rules() to fail")
        
        # Generate rules (with appropriate thresholds and limits)
        min_conf = MIN_CONFIDENCE_CPT if item_type == 'medical_codes' else MIN_CONFIDENCE
        cohort_logger.info(f"Generating rules (min_confidence={min_conf})...")
        mem_before_rules = log_memory(cohort_logger, "Before rule generation")
        
        if debug_mode:
            cohort_logger.debug(f"Rule generation parameters: min_confidence={min_conf}, itemsets={len(itemsets)}")
            cohort_logger.debug(f"Expected max rules (theoretical): {len(itemsets) * (len(itemsets) - 1) / 2:,}")
        
        # Final safety check: ensure no single-item itemsets
        itemset_lengths = itemsets['itemsets'].apply(len)
        if (itemset_lengths == 1).any():
            cohort_logger.error(f"üö® CRITICAL: Cannot generate rules - single-item itemsets present!")
            cohort_logger.error(f"   Single-item count: {(itemset_lengths == 1).sum()}")
            cohort_logger.error(f"   This is a bug - single-item itemsets should have been filtered")
            rules = pd.DataFrame()
            rules_control = pd.DataFrame()
        else:
            try:
                # Verify itemsets structure before calling association_rules
                if debug_mode:
                    cohort_logger.debug(f"Itemsets DataFrame columns: {itemsets.columns.tolist()}")
                    cohort_logger.debug(f"Itemsets shape: {itemsets.shape}")
                    if len(itemsets) > 0:
                        cohort_logger.debug(f"Sample itemset: {itemsets.iloc[0]['itemsets']}")
                        cohort_logger.debug(f"Sample itemset size: {len(itemsets.iloc[0]['itemsets'])}")
                
                # mlxtend association_rules() requires itemsets with at least 2 items
                # Double-check one more time before calling
                final_check = itemsets['itemsets'].apply(len)
                if (final_check == 1).any():
                    raise ValueError("Single-item itemsets detected - cannot generate association rules")
                
                # Try generating full rules with confidence/lift first
                # If this fails (mlxtend bug with certain itemset structures), fall back to support_only
                try:
                    all_rules = association_rules(itemsets, metric="confidence", min_threshold=min_conf)
                    mem_after_rules = log_memory(cohort_logger, "After rule generation")
                except Exception as e:
                    # mlxtend sometimes fails with "missing antecedent/consequent" error
                    # This can happen if itemsets structure doesn't match what mlxtend expects
                    # Fall back to support_only mode (but we lose confidence/lift metrics)
                    error_msg = str(e)
                    if "antecedent" in error_msg.lower() or "consequent" in error_msg.lower():
                        cohort_logger.warning(f"‚ö†Ô∏è  association_rules() failed with itemset structure issue")
                        cohort_logger.warning(f"   Error: {error_msg[:200]}")
                        cohort_logger.warning(f"   Falling back to support_only=True (will only have support metrics)")
                        cohort_logger.warning(f"   This is a workaround for mlxtend bug - investigating itemsets structure")
                        
                        # Log itemset details for debugging
                        if debug_mode:
                            cohort_logger.debug(f"Itemsets that caused error:")
                            for idx, row in itemsets.head(10).iterrows():
                                cohort_logger.debug(f"  [{idx}] {row['itemsets']} (size={len(row['itemsets'])}, support={row['support']:.4f})")
                        
                        # Fallback: use support_only (but we lose confidence/lift)
                        all_rules = association_rules(itemsets, metric="support", min_threshold=0.0, support_only=True)
                        mem_after_rules = log_memory(cohort_logger, "After rule generation (support_only fallback)")
                        cohort_logger.warning(f"‚ö†Ô∏è  Generated {len(all_rules):,} rules with support_only=True (no confidence/lift metrics)")
                    else:
                        # Re-raise if it's a different error
                        raise
                
                # Check if we're in support_only mode (no confidence/lift)
                support_only_mode = 'confidence' not in all_rules.columns if len(all_rules) > 0 else False
                
                if debug_mode:
                    cohort_logger.debug(f"Generated {len(all_rules):,} association rules")
                    if len(all_rules) > 0:
                        if support_only_mode:
                            cohort_logger.debug(f"‚ö†Ô∏è  Rules in support_only mode - no confidence/lift metrics available")
                            cohort_logger.debug(f"Rule support range: {all_rules['support'].min():.3f} - {all_rules['support'].max():.3f}")
                        else:
                            cohort_logger.debug(f"Rule confidence range: {all_rules['confidence'].min():.3f} - {all_rules['confidence'].max():.3f}")
                            cohort_logger.debug(f"Rule lift range: {all_rules['lift'].min():.3f} - {all_rules['lift'].max():.3f}")
                
                # Check memory increase (rule generation can be memory-intensive)
                if mem_before_rules['system_percent'] > 0:
                    mem_delta = mem_after_rules['system_percent'] - mem_before_rules['system_percent']
                    if mem_delta > 20:
                        cohort_logger.error(f"üö® CRITICAL: Very large memory spike during rule generation: {mem_delta:.1f}%")
                        cohort_logger.error(f"   Consider increasing min_confidence or filtering itemsets before rule generation")
                    elif mem_delta > 15:
                        cohort_logger.warning(f"‚ö†Ô∏è  Large memory increase during rule generation: {mem_delta:.1f}%")
                
                if len(all_rules) > 0:
                    # Split rules: target-predicting vs control (non-target)
                    if TARGET_FOCUSED:
                        # Target rules: consequent contains target marker
                        target_mask = all_rules['consequents'].apply(
                            lambda x: any(item.startswith(tuple(TARGET_PREFIXES)) for item in x)
                        )
                        rules_target = all_rules[target_mask].copy()
                        rules_control = all_rules[~target_mask].copy()
                        
                        cohort_logger.info(f"Split: {len(rules_target)} target rules, {len(rules_control)} control rules")
                        
                        # Limit both sets to top N (by lift if available, otherwise by support)
                        sort_col = 'lift' if 'lift' in all_rules.columns else 'support'
                        sort_ascending = False  # Descending (highest first)
                        
                        if len(rules_target) > 0:
                            rules_target = rules_target.sort_values(sort_col, ascending=sort_ascending)
                            if len(rules_target) > MAX_RULES_PER_COHORT:
                                cohort_logger.info(f"Keeping top {MAX_RULES_PER_COHORT} target rules (from {len(rules_target)})")
                                rules_target = rules_target.head(MAX_RULES_PER_COHORT)
                            rules_target = rules_target.reset_index(drop=True)
                        
                        if len(rules_control) > 0:
                            rules_control = rules_control.sort_values(sort_col, ascending=sort_ascending)
                            if len(rules_control) > MAX_RULES_PER_COHORT:
                                cohort_logger.info(f"Keeping top {MAX_RULES_PER_COHORT} control rules (from {len(rules_control)})")
                                rules_control = rules_control.head(MAX_RULES_PER_COHORT)
                            rules_control = rules_control.reset_index(drop=True)
                        
                        # Keep target rules as main 'rules' for backward compatibility
                        rules = rules_target
                    else:
                        # Not target-focused: all rules are kept
                        sort_col = 'lift' if 'lift' in all_rules.columns else 'support'
                        rules = all_rules.sort_values(sort_col, ascending=False).head(MAX_RULES_PER_COHORT).reset_index(drop=True)
                        rules_control = pd.DataFrame()
                    
                    cohort_logger.info(f"Final: {len(rules)} target rules, {len(rules_control)} control rules")
                    mem_after_filtering = log_memory(cohort_logger, "After rule filtering")
                else:
                    cohort_logger.info(f"No rules met confidence threshold of {min_conf}")
                    rules = pd.DataFrame()
                    rules_control = pd.DataFrame()
                    
            except MemoryError as e:
                cohort_logger.error(f"MemoryError during rule generation - skipping rules")
                rules = pd.DataFrame()
                rules_control = pd.DataFrame()
            except Exception as e:
                cohort_logger.error(f"Error generating rules: {e}")
                rules = pd.DataFrame()
                rules_control = pd.DataFrame()
        
        # Convert frozensets for JSON
        itemsets_json = itemsets.copy()
        itemsets_json['itemsets'] = itemsets_json['itemsets'].apply(list)
        
        # Prepare rules for saving (split target rules by type, plus control)
        rules_by_target = {}
        
        # Process target rules (split by ICD vs ED)
        if not rules.empty:
            rules_json = rules.copy()
            rules_json['antecedents'] = rules_json['antecedents'].apply(list)
            rules_json['consequents'] = rules_json['consequents'].apply(list)
            
            # Split target rules by outcome type
            rules_by_target['TARGET_ICD'] = rules_json[
                rules_json['consequents'].apply(lambda x: any('TARGET_ICD:' in str(item) for item in x))
            ]
            rules_by_target['TARGET_ED'] = rules_json[
                rules_json['consequents'].apply(lambda x: any('TARGET_ED:' in str(item) for item in x))
            ]
        
        # Process control rules (non-target patterns)
        if not rules_control.empty:
            rules_control_json = rules_control.copy()
            rules_control_json['antecedents'] = rules_control_json['antecedents'].apply(list)
            rules_control_json['consequents'] = rules_control_json['consequents'].apply(list)
            rules_by_target['CONTROL'] = rules_control_json
        
        if rules_by_target:
            cohort_logger.info(f"Prepared for S3: {len(rules_by_target.get('TARGET_ICD', pd.DataFrame()))} ICD, "
                             f"{len(rules_by_target.get('TARGET_ED', pd.DataFrame()))} ED, "
                             f"{len(rules_by_target.get('CONTROL', pd.DataFrame()))} control")
        
        # Save to S3 (with retry logic for reliability)
        s3_base = f"{S3_OUTPUT_BASE}/{item_type}/cohort_name={cohort}/age_band={age_band}/event_year={event_year}"
        
        cohort_logger.info(f"Saving results to S3...")
        max_retries = 3
        for attempt in range(max_retries):
            try:
                itemsets_path = f"{s3_base}/itemsets.json"
                save_to_s3_json(itemsets_json.to_dict(orient='records'), itemsets_path)
                
                # Save rules by target type (separate files)
                if rules_by_target:
                    for target_type, target_rules in rules_by_target.items():
                        if not target_rules.empty:
                            rules_path = f"{s3_base}/rules_{target_type}.json"
                            save_to_s3_json(target_rules.to_dict(orient='records'), rules_path)
                            cohort_logger.info(f"Saved {len(target_rules)} {target_type} rules")
                
                summary = {
                    'timestamp': datetime.now().isoformat(),
                    'cohort': cohort, 'age_band': age_band, 'event_year': event_year,
                    'item_type': item_type,
                    'total_patients': len(transactions),
                    'total_itemsets': len(itemsets),
                    'total_rules': len(rules),
                    'rules_by_target': {
                        'TARGET_ICD': len(rules_by_target.get('TARGET_ICD', pd.DataFrame())),
                        'TARGET_ED': len(rules_by_target.get('TARGET_ED', pd.DataFrame())),
                        'CONTROL': len(rules_by_target.get('CONTROL', pd.DataFrame()))
                    } if rules_by_target else {'TARGET_ICD': 0, 'TARGET_ED': 0, 'CONTROL': 0},
                    'min_support': min_sup,
                    'min_confidence': min_conf,
                    'max_rules_limit': MAX_RULES_PER_COHORT,
                    'rules_truncated': len(rules) == MAX_RULES_PER_COHORT,
                    'target_focused': TARGET_FOCUSED,
                    'target_icd_codes': TARGET_ICD_CODES if TARGET_FOCUSED else None,
                    'target_hcg_lines': TARGET_HCG_LINES if TARGET_FOCUSED else None
                }
                summary_path = f"{s3_base}/summary.json"
                save_to_s3_json(summary, summary_path)
                
                cohort_logger.info(f"‚úì Saved to S3 successfully")
                break  # Success - exit retry loop
                
            except Exception as s3_error:
                if attempt < max_retries - 1:
                    cohort_logger.warning(f"S3 upload attempt {attempt+1} failed: {s3_error}, retrying...")
                    time.sleep(2 ** attempt)  # Exponential backoff
                else:
                    cohort_logger.error(f"S3 upload failed after {max_retries} attempts: {s3_error}")
                    raise
        
        elapsed = time.time() - start_time
        log_memory(cohort_logger, "END")
        cohort_logger.info(f"‚úì Completed in {elapsed:.1f}s")
        
        return (cohort, age_band, event_year, item_type, True, f"{len(itemsets)} itemsets, {len(rules)} rules")
        
    except Exception as e:
        cohort_logger.error(f"Error: {e}")
        return (cohort, age_band, event_year, item_type, False, str(e))

print("‚úì Cohort processing function defined")


‚úì Cohort processing function defined


## Step 3: Process Cohorts in Parallel

Run FP-Growth for all cohort combinations using parallel processing.


In [None]:
import boto3

print("="*80)
print("COHORT FPGROWTH ANALYSIS - START")
print("="*80)
print(f"Cohorts: {len(cohort_jobs)} combinations")
print(f"Item types: {ITEM_TYPES}")
print(f"Total jobs: {len(cohort_jobs) * len(ITEM_TYPES)}")
print(f"Max workers: {MAX_WORKERS}")
print(f"Detailed progress ‚Üí Check log file")
print()

logger.info(f"\n{'='*80}")
logger.info(f"COHORT FPGROWTH ANALYSIS - START")
logger.info(f"{'='*80}")
logger.info(f"Cohorts: {len(cohort_jobs)} combinations")
logger.info(f"Item types: {ITEM_TYPES}")
logger.info(f"Total jobs: {len(cohort_jobs) * len(ITEM_TYPES)}")

# Helper function to check if cohort results exist in S3
def check_cohort_exists(item_type: str, cohort: str, age_band: str, event_year: str) -> bool:
    """Check if cohort results already exist in S3 (by checking for summary.json)."""
    s3 = boto3.client('s3')
    key = f"gold/fpgrowth/cohort/{item_type}/cohort_name={cohort}/age_band={age_band}/event_year={event_year}/summary.json"
    try:
        s3.head_object(Bucket='pgxdatalake', Key=key)
        return True
    except:
        return False

start_time = time.time()
results = []
completed = 0
failed = 0
skipped = 0

# Create all combinations of cohorts and item types
all_jobs_initial = [(job, item_type) for job in cohort_jobs for item_type in ITEM_TYPES]

# Filter out already-completed jobs
print("\nChecking for existing results in S3...")
all_jobs = []
for job, item_type in all_jobs_initial:
    if check_cohort_exists(item_type, job['cohort'], job['age_band'], job['event_year']):
        logger.info(f"Skipping {job['cohort']}/{job['age_band']}/{job['event_year']}/{item_type} - already exists")
        skipped += 1
        results.append({
            'cohort': job['cohort'],
            'age_band': job['age_band'],
            'event_year': job['event_year'],
            'item_type': item_type,
            'success': True,
            'message': 'Already exists in S3 (skipped)'
        })
    else:
        all_jobs.append((job, item_type))

total_jobs = len(all_jobs_initial)
print(f"Total jobs: {total_jobs}")
print(f"Already completed: {skipped}")
print(f"To process: {len(all_jobs)}")
print()

with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Submit all jobs
    future_to_params = {executor.submit(process_single_cohort, job, item_type): (job, item_type) 
                        for job, item_type in all_jobs}
    
    # Process results as they complete
    for future in as_completed(future_to_params):
        job, item_type = future_to_params[future]
        try:
            cohort, age_band, event_year, item_type, success, message = future.result()
            results.append({
                'cohort': cohort,
                'age_band': age_band,
                'event_year': event_year,
                'item_type': item_type,
                'success': success,
                'message': message
            })
            
            if success:
                completed += 1
                logger.info(f"[{completed + failed}/{len(all_jobs)}] ‚úì {cohort}/{age_band}/{event_year}/{item_type}: {message}")
                # Print every 10 successes or milestones
                if completed % 10 == 0 or (completed + failed) == len(all_jobs):
                    print(f"Progress: {completed}/{len(all_jobs)} completed ({completed/len(all_jobs)*100:.1f}%), {failed} failed, {skipped} skipped")
            else:
                failed += 1
                logger.warning(f"[{completed + failed}/{total_jobs}] ‚úó {cohort}/{age_band}/{event_year}/{item_type}: {message}")
                print(f"‚ö† Failed: {cohort}/{age_band}/{event_year}/{item_type}")
                
        except Exception as e:
            failed += 1
            logger.error(f"[{completed + failed}/{total_jobs}] ‚úó {job['cohort']}/{job['age_band']}/{job['event_year']}/{item_type}: {e}")
            print(f"‚ö† Error: {job['cohort']}/{job['age_band']}/{job['event_year']}/{item_type}")
            results.append({
                'cohort': job['cohort'],
                'age_band': job['age_band'],
                'event_year': job['event_year'],
                'item_type': item_type,
                'success': False,
                'message': str(e)
            })

elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"COHORT FPGROWTH ANALYSIS - COMPLETE")
print(f"{'='*80}")
print(f"  Total jobs: {total_jobs}")
print(f"  Already in S3 (skipped): {skipped}")
print(f"  Newly processed: {completed}")
print(f"  Failed: {failed}")
print(f"  Overall success rate: {(skipped + completed)/total_jobs*100:.1f}%")
print(f"  Total time: {elapsed:.1f}s ({elapsed/60:.1f}min)")
if len(all_jobs) > 0:
    print(f"  Avg time per new job: {elapsed/len(all_jobs):.1f}s")
else:
    print(f"  (No new jobs processed - all results already in S3)")


COHORT FPGROWTH ANALYSIS - START
Cohorts: 5 combinations
Item types: ['drug_name', 'icd_code', 'cpt_code']
Total jobs: 15
Max workers: 1
Detailed progress ‚Üí Check log file


Checking for existing results in S3...
Total jobs: 15
Already completed: 3
To process: 12



## Step 4: Analyze Results

Review processing results and identify any issues.


In [None]:
# Convert results to DataFrame for analysis
results_df = pd.DataFrame(results)

print("\nüìä Results by Cohort and Item Type:")
summary = results_df.groupby(['cohort', 'item_type'])['success'].agg([
    ('total', 'count'), 
    ('successful', 'sum'),
    ('success_rate', lambda x: f"{x.mean()*100:.1f}%")
])
print(summary)

print("\nüìä Results by Item Type:")
item_summary = results_df.groupby('item_type')['success'].agg([
    ('total', 'count'), 
    ('successful', 'sum'),
    ('success_rate', lambda x: f"{x.mean()*100:.1f}%")
])
print(item_summary)

print("\n‚ùå Failed Jobs:")
failed_df = results_df[~results_df['success']]
if not failed_df.empty:
    print(failed_df[['cohort', 'age_band', 'event_year', 'item_type', 'message']].to_string())
else:
    print("  None! All jobs completed successfully.")

print("\n‚úì Successful Jobs Sample:")
success_df = results_df[results_df['success']]
if not success_df.empty:
    print(success_df[['cohort', 'age_band', 'event_year', 'item_type', 'message']].head(15))
else:
    print("  No successful jobs.")


## Summary and Next Steps


In [None]:
print("="*80)
print("COHORT FPGROWTH ANALYSIS - SUMMARY")
print("="*80)

print(f"\nüìä Processing Statistics:")
print(f"  Cohort combinations: {len(cohort_jobs)}")
print(f"  Item types: {len(ITEM_TYPES)} (drug_name, icd_code, cpt_code)")
print(f"  Total jobs: {total_jobs}")
print(f"  Successfully processed: {completed}")
print(f"  Failed: {failed}")
print(f"  Success rate: {completed/total_jobs*100:.1f}%")
print(f"  Processing time: {elapsed:.1f}s ({elapsed/60:.1f}min)")

print(f"\nüîç FP-Growth Configuration:")
print(f"  Min support: {MIN_SUPPORT} ({MIN_SUPPORT*100:.1f}%)")
print(f"  Min confidence: {MIN_CONFIDENCE} ({MIN_CONFIDENCE*100:.1f}%)")
print(f"  Parallel workers: {MAX_WORKERS}")

print(f"\nüíæ Output Location:")
print(f"  S3 Base: {S3_OUTPUT_BASE}")
print(f"  Structure: <item_type>/cohort_name=<name>/age_band=<band>/event_year=<year>/")
print(f"  Item types:")
print(f"    - drug_name/ (pharmacy events)")
print(f"    - icd_code/ (diagnosis codes)")
print(f"    - cpt_code/ (procedure codes)")
print(f"  Files per cohort:")
print(f"    - itemsets.json (frequent itemsets)")
print(f"    - rules.json (association rules)")
print(f"    - summary.json (metadata)")

print(f"\nüéØ Next Steps:")
print(f"  1. Load cohort-specific itemsets for BupaR process mining")
print(f"  2. Compare patterns between OPIOID_ED and ED_NON_OPIOID cohorts across item types")
print(f"  3. Use association rules for pathway analysis")
print(f"  4. Filter features for cohort-specific CatBoost models")
print(f"  5. Create network visualizations for cohort-specific patterns")

print(f"\nüìù Example Usage:")
print(f"  # Load drug patterns for a specific cohort")
print(f"  from helpers_1997_13.s3_utils import load_from_s3_json")
print(f"  itemsets = load_from_s3_json('{S3_OUTPUT_BASE}/drug_name/cohort_name=opioid_ed/age_band=65-74/event_year=2020/itemsets.json')")
print(f"  # Load ICD patterns for same cohort")
print(f"  icd_itemsets = load_from_s3_json('{S3_OUTPUT_BASE}/icd_code/cohort_name=opioid_ed/age_band=65-74/event_year=2020/itemsets.json')")

print(f"\n‚úì Analysis complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("="*80)


## 5. Auto-Shutdown EC2 Instance (Optional)

Set `SHUTDOWN_EC2 = True` to automatically stop the EC2 instance after analysis completes.

**Note:** This is a **STOP** (not terminate), so you can restart the instance later.

In [None]:
# =============================================================================
# EC2 AUTO-SHUTDOWN (OPTIONAL)
# =============================================================================
# Set SHUTDOWN_EC2 = True to enable, False to disable
SHUTDOWN_EC2 = False  # Change to True to enable auto-shutdown

if SHUTDOWN_EC2:
    print("\n" + "="*80)
    print("Shutting down EC2 instance...")
    print("="*80)
    
    import subprocess
    import requests
    import shutil
    
    # Get instance ID from EC2 metadata service
    try:
        response = requests.get(
            "http://169.254.169.254/latest/meta-data/instance-id",
            timeout=2
        )
        if response.status_code == 200:
            instance_id = response.text.strip()
            print(f"Instance ID: {instance_id}")
            
            # Find AWS CLI
            aws_cmd = shutil.which("aws")
            if not aws_cmd:
                # Try common paths
                for path in ["/usr/local/bin/aws", "/usr/bin/aws", 
                           "/home/ec2-user/.local/bin/aws", 
                           "/home/ubuntu/.local/bin/aws",
                           "/home/pgx3874/.local/bin/aws"]:
                    if Path(path).exists():
                        aws_cmd = path
                        break
            
            if aws_cmd:
                # Stop the instance (use terminate-instances for permanent deletion)
                shutdown_cmd = [aws_cmd, "ec2", "stop-instances", "--instance-ids", instance_id]
                
                print(f"Running: {' '.join(shutdown_cmd)}")
                result = subprocess.run(shutdown_cmd, capture_output=True, text=True)
                
                if result.returncode == 0:
                    print("‚úì EC2 instance stop command sent successfully")
                    print("Instance will stop in a few moments.")
                    print("Note: This is a STOP (not terminate), so you can restart it later.")
                    if result.stdout:
                        print(f"\nAWS Response:\n{result.stdout}")
                else:
                    print(f"‚úó EC2 stop command failed with exit code {result.returncode}")
                    if result.stderr:
                        print(f"Error: {result.stderr}")
                    print("Check AWS credentials and IAM permissions.")
            else:
                print("‚úó AWS CLI not found. Cannot shutdown instance.")
                print("Install AWS CLI or ensure it's in your PATH.")
                print("Manual shutdown: aws ec2 stop-instances --instance-ids " + instance_id)
        else:
            print(f"‚úó Metadata service returned status code {response.status_code}")
            print("Could not retrieve instance ID.")
    
    except requests.exceptions.RequestException as e:
        print("‚úó Could not retrieve instance ID from metadata service.")
        print(f"Error: {e}")
        print("If running on EC2, check that metadata service is accessible.")
        print("\nManual shutdown command:")
        print("  aws ec2 stop-instances --instance-ids <your-instance-id>")
    
    except Exception as e:
        print(f"‚úó Unexpected error during shutdown: {e}")

else:
    print("\n" + "="*80)
    print("EC2 Auto-Shutdown: DISABLED")
    print("="*80)
    print("To enable auto-shutdown, set SHUTDOWN_EC2 = True in this cell.")
    print("Instance will continue running.")
    print("\nTo manually stop this instance later:")
    print("  aws ec2 stop-instances --instance-ids $(ec2-metadata --instance-id | cut -d ' ' -f 2)")
    print("Or use AWS Console: EC2 > Instances > Select instance > Instance State > Stop")
