# Cohort-Specific FPGrowth Feature Importance Analysis

**Purpose:** Cohort-level frequent pattern mining for process mining and comparative analysis  
**Updated:** November 23, 2025  
**Hardware:** Optimized for EC2 (32 cores, 1TB RAM)  
**Output:** `s3://pgxdatalake/gold/fpgrowth/cohort/{item_type}/cohort_name={cohort}/...`

## Key Features

✅ **Three Item Types** - Drugs, ICD codes, CPT codes  
✅ **Cohort-Specific Patterns** - Discovers patterns unique to each cohort  
✅ **Parallel Processing** - Can process multiple cohorts simultaneously (configurable via `MAX_WORKERS`)  
✅ **BupaR Integration** - Outputs ready for process mining workflows  
✅ **Comparative Analysis** - Compare patterns between the `opioid_ed` and `ed_non_opioid` cohorts

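## Methodology

For each combination of (cohort, age_band, event_year, item_type):
1. Extract items from the cohort-specific data
2. Create patient-level transactions (one set of items per patient)
3. Encode the transactions into a binary patient × item matrix
4. Run FP-Growth to find frequent itemsets
5. Generate association rules (in target-focused mode, split into target-predicting and control rules)
6. Save itemsets, rules, and a summary to S3 under a partitioned prefix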
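Steps 2 to 5 can be illustrated on toy data without mlxtend. The sketch below groups (patient, item) rows into per-patient transactions, counts itemset support by brute-force enumeration (a simplified stand-in for FP-Growth, which finds the same frequent itemsets without enumerating every candidate), and derives single-consequent rules with confidence and lift. The function names, patient IDs, and items are hypothetical.

```python
from collections import defaultdict
from itertools import combinations


def build_transactions(rows):
    """Group (patient_id, item) rows into one de-duplicated item set per patient."""
    by_patient = defaultdict(set)
    for patient_id, item in rows:
        by_patient[patient_id].add(item)
    return [frozenset(items) for items in by_patient.values()]


def frequent_itemsets(transactions, min_support, max_len=2):
    """Brute-force support counting for itemsets up to max_len (NOT FP-Growth)."""
    n = len(transactions)
    counts = defaultdict(int)
    for t in transactions:
        for k in range(1, max_len + 1):
            for combo in combinations(sorted(t), k):
                counts[frozenset(combo)] += 1
    return {s: c / n for s, c in counts.items() if c / n >= min_support}


def rules_from(itemsets, min_confidence):
    """Rules X -> {y} with a single-item consequent (mlxtend also emits larger consequents)."""
    rules = []
    for itemset, support in itemsets.items():
        if len(itemset) < 2:
            continue
        for y in itemset:
            antecedent = itemset - {y}
            # Subsets of a frequent itemset are frequent, so these lookups always succeed
            confidence = support / itemsets[antecedent]
            lift = confidence / itemsets[frozenset([y])]
            if confidence >= min_confidence:
                rules.append((antecedent, frozenset([y]), support, confidence, lift))
    return rules


rows = [
    ("p1", "gabapentin"), ("p1", "TARGET_ED:EMERGENCY_DEPT"),
    ("p2", "gabapentin"), ("p2", "TARGET_ED:EMERGENCY_DEPT"), ("p2", "metoprolol"),
    ("p3", "metoprolol"),
    ("p4", "gabapentin"),
]
transactions = build_transactions(rows)
itemsets = frequent_itemsets(transactions, min_support=0.5)
for antecedent, consequent, support, confidence, lift in rules_from(itemsets, min_confidence=0.5):
    print(sorted(antecedent), "->", sorted(consequent), f"conf={confidence:.2f}", f"lift={lift:.2f}")
```

On real cohorts the enumeration above grows combinatorially, which is why the notebook relies on FP-Growth instead.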

## Key Differences from Global Analysis

| Aspect | Global FPGrowth | Cohort FPGrowth |
|--------|-----------------|-----------------|
| **Scope** | All patients (~5.7M) | Individual cohorts (~10K-100K) |
| **Purpose** | Universal ML features | Process mining patterns |
| **Support Threshold** | 0.01 (1%) | 0.05 (5%); 0.15 (15%) for CPT codes |
| **Output** | `global/{item_type}/` | `cohort/{item_type}/cohort_name={c}/...` |
| **Use Case** | CatBoost consistency | BupaR pathway analysis |
| **Parallelization** | Sequential by item type | Parallel by cohort (`MAX_WORKERS`) |

## Expected Runtime (EC2: 32 cores, 1TB RAM)

- **Cohorts**: 2 (opioid_ed, ed_non_opioid)
- **Age bands × Years**: ~100 combinations per cohort
- **Item types**: 3 (drug_name, icd_code, cpt_code)
- **Total jobs**: ~600 combinations (2 × ~100 × 3)
- **Avg time per job**: ~1-2 minutes
- **Total runtime**: ~2.5-5 hours with MAX_WORKERS=4; ~10-20 hours with MAX_WORKERS=1 (the default in the configuration below)
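The runtime bullets follow from back-of-the-envelope arithmetic. A minimal sketch, assuming every job takes the same time and every worker stays busy (real partitions vary widely in size, so treat the result as a rough estimate); `estimate_runtime_hours` is a hypothetical helper:

```python
def estimate_runtime_hours(n_cohorts, combos_per_cohort, n_item_types,
                           minutes_per_job, max_workers):
    """Return (total_jobs, wall-clock hours) under equal job times and full utilization."""
    total_jobs = n_cohorts * combos_per_cohort * n_item_types
    return total_jobs, total_jobs * minutes_per_job / max_workers / 60


for workers in (1, 4):
    for minutes in (1, 2):
        jobs, hours = estimate_runtime_hours(2, 100, 3, minutes, workers)
        print(f"workers={workers} minutes/job={minutes}: {jobs} jobs, ~{hours:.1f} h")
```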

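## S3 Output Structure

```
s3://pgxdatalake/gold/fpgrowth/cohort/
├── drug_name/
│   ├── cohort_name=opioid_ed/
│   │   ├── age_band=65-74/event_year=2020/
│   │   │   ├── itemsets.json
│   │   │   ├── rules_TARGET_ICD.json
│   │   │   ├── rules_TARGET_ED.json
│   │   │   ├── rules_CONTROL.json
│   │   │   └── summary.json
│   │   └── ...
│   └── cohort_name=ed_non_opioid/...
├── icd_code/
│   └── (same structure)
└── cpt_code/
    └── (same structure)
```

Rule files are written only when non-empty. With `TARGET_FOCUSED = True`, `rules_TARGET_ICD.json` and `rules_TARGET_ED.json` hold rules whose consequent contains the corresponding target marker, and `rules_CONTROL.json` holds the remaining (non-target) rules.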
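The output prefix is a Hive-style `key=value` layout, so it can be built and parsed with plain string handling. A minimal sketch mirroring the f-string used when saving results; `output_prefix` and `parse_partitions` are hypothetical helpers, not project utilities:

```python
def output_prefix(bucket, item_type, cohort, age_band, event_year):
    """Build the partitioned output prefix in the order the notebook writes it."""
    return (f"s3://{bucket}/gold/fpgrowth/cohort/{item_type}/"
            f"cohort_name={cohort}/age_band={age_band}/event_year={event_year}")


def parse_partitions(uri):
    """Extract key=value partition segments from an S3 URI or key."""
    return dict(part.split("=", 1) for part in uri.split("/") if "=" in part)


prefix = output_prefix("pgxdatalake", "drug_name", "opioid_ed", "65-74", "2020")
print(prefix)
print(parse_partitions(prefix + "/summary.json"))
```

Parsing keys this way is handy when listing existing outputs, for example to find partitions that are missing a `summary.json`.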

---


## Environment Setup and Imports


In [1]:
import os
import sys
import json
import pandas as pd
import numpy as np
from datetime import datetime
import logging
from pathlib import Path
import psutil
import duckdb
import time
from concurrent.futures import ProcessPoolExecutor, as_completed

# MLxtend for FP-Growth
from mlxtend.frequent_patterns import fpgrowth, association_rules
from mlxtend.preprocessing import TransactionEncoder

# Project root
project_root = Path.cwd().parent if Path.cwd().name == '3_fpgrowth_analysis' else Path.cwd()
sys.path.insert(0, str(project_root))

# Project utilities
from helpers_1997_13.common_imports import s3_client, S3_BUCKET
from helpers_1997_13.duckdb_utils import get_duckdb_connection
from helpers_1997_13.s3_utils import save_to_s3_json, save_to_s3_parquet, get_cohort_parquet_path
from helpers_1997_13.fpgrowth_utils import run_fpgrowth_drug_token_with_fallback, convert_frozensets
from helpers_1997_13.visualization_utils import create_network_visualization
from helpers_1997_13.constants import AGE_BANDS, EVENT_YEARS

print(f"✓ Project root: {project_root}")
print(f"✓ All imports successful")
print(f"✓ Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


✓ Project root: /home/pgx3874/pgx-analysis
✓ All imports successful
✓ Timestamp: 2025-11-24 00:28:54


## EC2 Configuration

In [None]:
# =============================================================================
# EC2 CONFIGURATION (32 cores, 1TB RAM)
# =============================================================================

# FP-Growth parameters (higher threshold for cohort-specific patterns)
MIN_SUPPORT = 0.05       # 5% support (items must appear in 5% of patients within cohort)
MIN_CONFIDENCE = 0.5     # 50% confidence - only strong associations

# CPT-specific parameters (prevent memory exhaustion from millions of rules)
MIN_SUPPORT_CPT = 0.15   # 15% support for CPT codes (focuses on common patterns)
MIN_CONFIDENCE_CPT = 0.6 # 60% confidence for CPT (very strong associations only)

# Rule limits (focus on most important rules)
MAX_RULES_PER_COHORT = 1000  # Keep top 1000 rules by lift (practical limit)

# Target-focused rule mining
TARGET_FOCUSED = True  # Split rules into target rules (consequent is a TARGET_ marker) and control rules
TARGET_ICD_CODES = ['F11.20', 'F11.21', 'F11.22', 'F11.23', 'F11.24', 'F11.25', 'F11.29']  # Opioid dependence codes
TARGET_HCG_LINES = [
    "P51 - ER Visits and Observation Care",
    "O11 - Emergency Room",
    "P33 - Urgent Care Visits"
]  # ED visits (HCG Line codes - matches phase2_event_processing.py)
TARGET_PREFIXES = ['TARGET_ICD:', 'TARGET_ED:']  # Prefixes for target items in transactions

# Item types to process
ITEM_TYPES = ['drug_name', 'icd_code', 'cpt_code']

# Processing parameters
MAX_WORKERS = 1  # Sequential processing to prevent memory issues

# DRY RUN MODE (test with limited cohorts first)
DRY_RUN = True  # Set to False to process all cohorts
DRY_RUN_LIMIT = 5  # Number of cohort combinations to process in dry run
COHORTS_TO_PROCESS = ['opioid_ed', 'ed_non_opioid']  # Specify cohorts to process

# Paths
S3_OUTPUT_BASE = f"s3://{S3_BUCKET}/gold/fpgrowth/cohort"
LOCAL_DATA_PATH = Path("/mnt/nvme/cohorts")  # Instance storage (NVMe SSD for fast I/O)

# Setup logger with file output (prevents Jupyter rate limit issues)
logger = logging.getLogger('cohort_fpgrowth')
logger.setLevel(logging.INFO)
logger.handlers.clear()  # Clear any existing handlers

# File handler - full logs to file
log_file = project_root / "3_fpgrowth_analysis" / "cohort_fpgrowth_execution.log"
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

# Console handler - only major milestones (prevents Jupyter rate limit)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.WARNING)  # Only warnings/errors to console
console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(console_handler)

print(f"✓ Min Support (drug/ICD): {MIN_SUPPORT} (5%)")
print(f"✓ Min Support (CPT): {MIN_SUPPORT_CPT} (15% - focuses on common patterns)")
print(f"✓ Min Confidence (drug/ICD): {MIN_CONFIDENCE} (50% - strong associations)")
print(f"✓ Min Confidence (CPT): {MIN_CONFIDENCE_CPT} (60% - very strong associations)")
print(f"✓ Max Rules per Cohort: {MAX_RULES_PER_COHORT:,} (top rules by lift)")
print(f"✓ Item Types: {ITEM_TYPES}")
print(f"✓ Max Workers: {MAX_WORKERS} ({'sequential - prevents OOM' if MAX_WORKERS == 1 else 'parallel'})")
if DRY_RUN:
    print(f"✓ DRY RUN MODE: Processing only {DRY_RUN_LIMIT} cohort combinations")
    print(f"  → Set DRY_RUN = False to process all cohorts")
else:
    print(f"✓ FULL RUN MODE: Processing all cohorts")
print(f"✓ Cohorts: {COHORTS_TO_PROCESS}")
print(f"✓ S3 Output: {S3_OUTPUT_BASE}")
print(f"✓ S3 Retry: 3 attempts with exponential backoff")
print(f"✓ Local Data: {LOCAL_DATA_PATH}")
print(f"✓ Local Data Exists: {LOCAL_DATA_PATH.exists()}")
print(f"✓ Detailed logs → {log_file}")
print(f"✓ Console output: WARNING level only (check log file for progress)")
print("\n🎯 Quality Over Quantity Approach:")
print("  - High confidence thresholds (50-60%) = meaningful patterns only")
print("  - CPT uses 15% support (vs 5%) = focuses on common procedures")
print(f"  - Top {MAX_RULES_PER_COHORT:,} rules by lift = actionable insights, not exhaustive lists")
print(f"  - {MAX_WORKERS} worker(s) = stable memory usage")
print(f"\n🎯 TARGET-FOCUSED RULE MINING: {'ENABLED' if TARGET_FOCUSED else 'DISABLED'}")
if TARGET_FOCUSED:
    print(f"  - Target ICD codes: {TARGET_ICD_CODES}")
    print(f"  - Target HCG lines (ED visits): {TARGET_HCG_LINES}")
    print("  - Rules whose consequent is a target marker are kept as target rules;")
    print("    all other rules are kept separately as control rules")
    print("  - Example: {Metoprolol, Gabapentin} → {TARGET_ICD:OPIOID_DEPENDENCE}")
    print("  - Example: {99213: Office Visit, J0670: Morphine} → {TARGET_ED:EMERGENCY_DEPT}")
    print("  ✅ Separates predictive patterns from background co-occurrence")
    print("  ✅ More actionable for BupaR (pathways to target)")
    print("  ✅ Better for CatBoost (features that predict outcome)")


✓ Min Support: 0.05
✓ Min Confidence: 0.3
✓ Item Types: ['drug_name', 'icd_code', 'cpt_code']
✓ Max Workers: 4
✓ Cohorts: ['opioid_ed', 'ed_non_opioid']
✓ S3 Output: s3://pgxdatalake/gold/fpgrowth/cohort
✓ Local Data: /home/pgx3874/pgx-analysis/data
✓ Local Data Exists: True
✓ Detailed logs → /home/pgx3874/pgx-analysis/3_fpgrowth_analysis/cohort_fpgrowth_execution.log
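Support thresholds are fractions of the patients in one (cohort, age_band, event_year) partition, so the absolute bar scales with partition size. A small sketch, assuming the threshold values set in the configuration cell; `THRESHOLDS` and `min_patients` are hypothetical names. mlxtend compares floating-point support values, so a pattern sitting exactly on the boundary may behave slightly differently.

```python
import math
from fractions import Fraction

# (min_support, min_confidence) copied from the configuration cell; CPT is stricter
THRESHOLDS = {
    "drug_name": (0.05, 0.5),
    "icd_code": (0.05, 0.5),
    "cpt_code": (0.15, 0.6),
}


def min_patients(n_patients, item_type):
    """Smallest number of patients an itemset must appear in to reach min_support."""
    min_support, _ = THRESHOLDS[item_type]
    # Fraction(str(...)) avoids binary floating-point error at exact boundaries
    return math.ceil(Fraction(str(min_support)) * n_patients)


for n in (10_000, 100_000):
    print(n, {t: min_patients(n, t) for t in THRESHOLDS})
```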


## Memory Monitoring

Helper function to log memory usage at critical points.
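The `log_memory` helper that follows reports system-wide usage through `psutil.virtual_memory()`, which mixes every worker and any other process on the instance. To attribute memory to one job, a worker could additionally log its own peak resident set size. A minimal Unix-only sketch using the standard-library `resource` module (an optional addition, not part of the original pipeline; `peak_rss_mb` is a hypothetical name):

```python
import resource
import sys


def peak_rss_mb():
    """Peak resident set size of the current process, in MB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # Linux reports ru_maxrss in kilobytes; macOS reports it in bytes
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


print(f"Peak RSS of this process: {peak_rss_mb():.1f} MB")
```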


In [None]:
def log_memory(logger, stage=""):
    """Log current memory usage."""
    try:
        mem = psutil.virtual_memory()
        mem_used_gb = mem.used / (1024**3)
        mem_total_gb = mem.total / (1024**3)
        mem_percent = mem.percent
        mem_avail_gb = mem.available / (1024**3)
        
        logger.info(f"[MEMORY {stage}] Used: {mem_used_gb:.1f} GB / {mem_total_gb:.1f} GB ({mem_percent:.1f}%) | Available: {mem_avail_gb:.1f} GB")
        
        # Warning if memory usage is high
        if mem_percent > 85:
            logger.warning(f"⚠️  HIGH MEMORY USAGE: {mem_percent:.1f}% - May cause OOM!")
        
        return mem_percent
    except Exception as e:
        logger.error(f"Error getting memory info: {e}")
        return 0.0

print("✓ Memory logging function defined")


## Step 1: Discover Available Cohorts

Scan local data to find all available cohort combinations.
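`discover_cohorts` expects a Hive-style local layout in which `event_year` comes before `age_band` (the reverse of the S3 output order). The nested loops can be collapsed into a single glob pattern; this sketch (with a hypothetical `discover_with_glob` name) builds a throwaway directory tree with empty placeholder files to show the expected layout:

```python
import tempfile
from pathlib import Path


def discover_with_glob(root, cohort_filter=None):
    """Single-glob equivalent of the nested loops in discover_cohorts (sketch)."""
    jobs = []
    pattern = "cohort_name=*/event_year=*/age_band=*/cohort.parquet"
    for f in sorted(Path(root).glob(pattern)):
        # Directory names between root and the file are key=value partitions
        fields = dict(p.split("=", 1) for p in f.relative_to(root).parts[:-1])
        if cohort_filter and fields["cohort_name"] not in cohort_filter:
            continue
        jobs.append({"cohort": fields["cohort_name"], "age_band": fields["age_band"],
                     "event_year": fields["event_year"], "local_path": str(f)})
    return jobs


with tempfile.TemporaryDirectory() as tmp:
    for cohort, year, band in [("opioid_ed", "2016", "13-24"), ("ed_non_opioid", "2016", "25-44")]:
        d = Path(tmp) / f"cohort_name={cohort}" / f"event_year={year}" / f"age_band={band}"
        d.mkdir(parents=True)
        (d / "cohort.parquet").touch()
    jobs = discover_with_glob(tmp, cohort_filter=["opioid_ed"])
    print(jobs)
```

Sorting the matches makes the `DRY_RUN` subset reproducible; unsorted glob order depends on the filesystem.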


In [4]:
def discover_cohorts(local_data_path, cohort_filter=None):
    """
    Discover all available cohort combinations from local data.
    """
    cohort_jobs = []
    
    for cohort_dir in local_data_path.glob("cohort_name=*"):
        cohort_name = cohort_dir.name.replace("cohort_name=", "")
        
        # Filter if specified
        if cohort_filter and cohort_name not in cohort_filter:
            continue
        
        for year_dir in cohort_dir.glob("event_year=*"):
            event_year = year_dir.name.replace("event_year=", "")
            
            for age_dir in year_dir.glob("age_band=*"):
                age_band = age_dir.name.replace("age_band=", "")
                
                # Check if cohort file exists
                cohort_file = age_dir / "cohort.parquet"
                if cohort_file.exists():
                    cohort_jobs.append({
                        'cohort': cohort_name,
                        'age_band': age_band,
                        'event_year': event_year,
                        'local_path': str(cohort_file)
                    })
    
    return cohort_jobs

# Discover available cohorts
cohort_jobs = discover_cohorts(LOCAL_DATA_PATH, cohort_filter=COHORTS_TO_PROCESS)

# Apply DRY_RUN limit if enabled
if DRY_RUN and len(cohort_jobs) > DRY_RUN_LIMIT:
    print(f"\n⚠️  DRY RUN: Limiting from {len(cohort_jobs)} to {DRY_RUN_LIMIT} cohort combinations")
    cohort_jobs = cohort_jobs[:DRY_RUN_LIMIT]

print(f"\n📊 Discovered Cohorts:")
print(f"  Total combinations: {len(cohort_jobs)}")

# Group by cohort
cohort_counts = {}
for job in cohort_jobs:
    cohort_counts[job['cohort']] = cohort_counts.get(job['cohort'], 0) + 1

for cohort, count in cohort_counts.items():
    print(f"  {cohort}: {count} combinations")

print(f"\n  Sample jobs:")
for job in cohort_jobs[:5]:
    print(f"    {job['cohort']}/{job['age_band']}/{job['event_year']}")



📊 Discovered Cohorts:
  Total combinations: 45
  opioid_ed: 45 combinations

  Sample jobs:
    opioid_ed/13-24/2016
    opioid_ed/25-44/2016
    opioid_ed/45-54/2016
    opioid_ed/55-64/2016
    opioid_ed/65-74/2016


## Step 2: Define Cohort Processing Function

Create a function to process a single cohort for a specific item type with FP-Growth.
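Before encoding, the function merges each patient's items with synthetic target markers derived from the primary ICD code and the HCG line. A pure-Python sketch of that merge, assuming ICD codes may be stored with or without dots (both sides are compared in dotless form); `target_markers` and `build_patient_baskets` are hypothetical names:

```python
from collections import defaultdict

# Values copied from the configuration cell
TARGET_ICD_CODES = ['F11.20', 'F11.21', 'F11.22', 'F11.23', 'F11.24', 'F11.25', 'F11.29']
TARGET_HCG_LINES = {"P51 - ER Visits and Observation Care", "O11 - Emergency Room",
                    "P33 - Urgent Care Visits"}


def target_markers(primary_icd, hcg_line):
    """Return the target items implied by one claim row."""
    markers = []
    if primary_icd:
        code = primary_icd.replace(".", "")
        if any(code.startswith(t.replace(".", "")) for t in TARGET_ICD_CODES):
            markers.append("TARGET_ICD:OPIOID_DEPENDENCE")
    if hcg_line in TARGET_HCG_LINES:
        markers.append("TARGET_ED:EMERGENCY_DEPT")
    return markers


def build_patient_baskets(item_rows, claim_rows):
    """item_rows: (patient, item); claim_rows: (patient, primary_icd, hcg_line)."""
    baskets = defaultdict(set)
    for patient, item in item_rows:
        baskets[patient].add(item)
    for patient, icd, hcg in claim_rows:
        baskets[patient].update(target_markers(icd, hcg))
    # Sorted, de-duplicated lists, like groupby(...).apply(lambda x: sorted(set(...)))
    return {p: sorted(items) for p, items in baskets.items()}


baskets = build_patient_baskets(
    [("p1", "gabapentin"), ("p2", "metoprolol")],
    [("p1", "F1120", "O11 - Emergency Room"), ("p2", "F11.29", None),
     ("p3", "I10", "P33 - Urgent Care Visits")],
)
print(baskets)
```

As in the notebook, a patient with target rows but no items of the current type still gets a transaction made only of target markers.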


In [None]:
def process_single_cohort(job, item_type):
    """Process a single cohort for a specific item type with FP-Growth analysis."""
    cohort = job['cohort']
    age_band = job['age_band']
    event_year = job['event_year']
    local_path = job['local_path']
    
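    # Child of the 'cohort_fpgrowth' logger so records propagate to its file handler;
    # a separate top-level logger would have no handlers and its INFO lines would not reach the log file
    cohort_logger = logging.getLogger(f"cohort_fpgrowth.{cohort}_{age_band}_{event_year}_{item_type}")
    cohort_logger.setLevel(logging.INFO)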
    
    try:
        start_time = time.time()
        cohort_logger.info(f"Processing {cohort}/{age_band}/{event_year} - {item_type}")
        log_memory(cohort_logger, "START")
        
        # Extract items based on type + TARGET MARKERS (for target-focused rules)
        # Simple in-memory connection (no AWS needed for local parquet reads)
        con = duckdb.connect(':memory:')
        con.sql("SET threads = 1")
        
        if item_type == 'drug_name':
            query = f"""
            SELECT mi_person_key, drug_name as item
            FROM read_parquet('{local_path}')
            WHERE drug_name IS NOT NULL AND drug_name != '' AND event_type = 'pharmacy'
            """
        elif item_type == 'icd_code':
            # For ICD codes: extract all diagnosis codes + mark target opioid codes
            query = f"""
            WITH all_icds AS (
                SELECT mi_person_key, primary_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE primary_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, two_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE two_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, three_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE three_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, four_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE four_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, five_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE five_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
            )
            SELECT mi_person_key, icd as item FROM all_icds WHERE icd != ''
            """
        elif item_type == 'cpt_code':
            query = f"""
            SELECT mi_person_key, procedure_code as item
            FROM read_parquet('{local_path}')
            WHERE procedure_code IS NOT NULL AND procedure_code != '' AND event_type = 'medical'
            """
        else:
            raise ValueError(f"Unknown item_type: {item_type}")
        
        df = con.execute(query).df()
        log_memory(cohort_logger, "After data extraction")
        
        # Add target markers if TARGET_FOCUSED mode is enabled
        if TARGET_FOCUSED:
            cohort_logger.info("Adding target markers...")
            
            # Get target information for each patient
            target_query = f"""
            SELECT DISTINCT 
                mi_person_key,
                primary_icd_diagnosis_code,
                hcg_line
            FROM read_parquet('{local_path}')
            WHERE mi_person_key IS NOT NULL
            """
            df_targets = con.execute(target_query).df()
            
            # Create target items for each patient
            target_items = []
            for _, row in df_targets.iterrows():
                patient_id = row['mi_person_key']
                # Check for opioid ICD codes (compare dotless forms so 'F11.20' and 'F1120' both match)
                if pd.notna(row['primary_icd_diagnosis_code']) and any(
                    str(row['primary_icd_diagnosis_code']).replace('.', '').startswith(code.replace('.', ''))
                    for code in TARGET_ICD_CODES
                ):
                    target_items.append({'mi_person_key': patient_id, 'item': 'TARGET_ICD:OPIOID_DEPENDENCE'})
                
                # Check for ED visits (HCG Line - correct field!)
                if pd.notna(row['hcg_line']) and row['hcg_line'] in TARGET_HCG_LINES:
                    target_items.append({'mi_person_key': patient_id, 'item': 'TARGET_ED:EMERGENCY_DEPT'})
            
            if target_items:
                df_targets_items = pd.DataFrame(target_items)
                df = pd.concat([df, df_targets_items], ignore_index=True)
                cohort_logger.info(f"Added {len(target_items)} target markers")
                log_memory(cohort_logger, "After target markers")
        
        con.close()
        
        if df.empty:
            cohort_logger.warning(f"No {item_type} data for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No data")
        
        # Create transactions (group items by patient)
        cohort_logger.info(f"Building transactions from {len(df)} rows...")
        transactions = (
            df.groupby('mi_person_key')['item']
            .apply(lambda x: sorted(set(x.tolist())))
            .tolist()
        )
        
        if not transactions:
            cohort_logger.warning(f"No valid transactions for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No transactions")
        
        # Encode transactions
        cohort_logger.info(f"Encoding {len(transactions)} transactions...")
        te = TransactionEncoder()
        te_ary = te.fit(transactions).transform(transactions)
        df_encoded = pd.DataFrame(te_ary, columns=te.columns_)
        log_memory(cohort_logger, "After encoding")
        
        # Run FP-Growth (use higher support for CPT to prevent rule explosion)
        min_sup = MIN_SUPPORT_CPT if item_type == 'cpt_code' else MIN_SUPPORT
        cohort_logger.info(f"Running FP-Growth (min_support={min_sup})...")
        itemsets = fpgrowth(df_encoded, min_support=min_sup, use_colnames=True)
        itemsets = itemsets.sort_values('support', ascending=False).reset_index(drop=True)
        
        if itemsets.empty:
            cohort_logger.warning(f"No itemsets found for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No itemsets")
        
        cohort_logger.info(f"Found {len(itemsets)} itemsets")
        log_memory(cohort_logger, "After FP-Growth")
        
        # Generate rules (with appropriate thresholds and limits)
        min_conf = MIN_CONFIDENCE_CPT if item_type == 'cpt_code' else MIN_CONFIDENCE
        cohort_logger.info(f"Generating rules (min_confidence={min_conf})...")
        
        try:
            all_rules = association_rules(itemsets, metric="confidence", min_threshold=min_conf)
            
            if len(all_rules) > 0:
                # Split rules: target-predicting vs control (non-target)
                if TARGET_FOCUSED:
                    # Target rules: consequent contains target marker
                    target_mask = all_rules['consequents'].apply(
                        lambda x: any(item.startswith(tuple(TARGET_PREFIXES)) for item in x)
                    )
                    rules_target = all_rules[target_mask].copy()
                    rules_control = all_rules[~target_mask].copy()
                    
                    cohort_logger.info(f"Split: {len(rules_target)} target rules, {len(rules_control)} control rules")
                    
                    # Limit both sets to top N by lift
                    if len(rules_target) > 0:
                        rules_target = rules_target.sort_values('lift', ascending=False)
                        if len(rules_target) > MAX_RULES_PER_COHORT:
                            cohort_logger.info(f"Keeping top {MAX_RULES_PER_COHORT} target rules (from {len(rules_target)})")
                            rules_target = rules_target.head(MAX_RULES_PER_COHORT)
                        rules_target = rules_target.reset_index(drop=True)
                    
                    if len(rules_control) > 0:
                        rules_control = rules_control.sort_values('lift', ascending=False)
                        if len(rules_control) > MAX_RULES_PER_COHORT:
                            cohort_logger.info(f"Keeping top {MAX_RULES_PER_COHORT} control rules (from {len(rules_control)})")
                            rules_control = rules_control.head(MAX_RULES_PER_COHORT)
                        rules_control = rules_control.reset_index(drop=True)
                    
                    # Keep target rules as main 'rules' for backward compatibility
                    rules = rules_target
                else:
                    # Not target-focused: all rules are kept
                    rules = all_rules.sort_values('lift', ascending=False).head(MAX_RULES_PER_COHORT).reset_index(drop=True)
                    rules_control = pd.DataFrame()
                
                cohort_logger.info(f"Final: {len(rules)} target rules, {len(rules_control)} control rules")
                log_memory(cohort_logger, "After rule generation")
            else:
                cohort_logger.info(f"No rules met confidence threshold of {min_conf}")
                rules = pd.DataFrame()
                rules_control = pd.DataFrame()
                
        except MemoryError as e:
            cohort_logger.error(f"MemoryError during rule generation - skipping rules")
            rules = pd.DataFrame()
            rules_control = pd.DataFrame()
        except Exception as e:
            cohort_logger.error(f"Error generating rules: {e}")
            rules = pd.DataFrame()
            rules_control = pd.DataFrame()
        
        # Convert frozensets for JSON
        itemsets_json = itemsets.copy()
        itemsets_json['itemsets'] = itemsets_json['itemsets'].apply(list)
        
        # Prepare rules for saving (split target rules by type, plus control)
        rules_by_target = {}
        
        # Process target rules (split by ICD vs ED)
        if not rules.empty:
            rules_json = rules.copy()
            rules_json['antecedents'] = rules_json['antecedents'].apply(list)
            rules_json['consequents'] = rules_json['consequents'].apply(list)
            
            if TARGET_FOCUSED:
                # Split target rules by outcome type
                rules_by_target['TARGET_ICD'] = rules_json[
                    rules_json['consequents'].apply(lambda x: any('TARGET_ICD:' in str(item) for item in x))
                ]
                rules_by_target['TARGET_ED'] = rules_json[
                    rules_json['consequents'].apply(lambda x: any('TARGET_ED:' in str(item) for item in x))
                ]
            else:
                # No target markers exist in this mode; save all rules in a single file
                rules_by_target['ALL'] = rules_json
        
        # Process control rules (non-target patterns)
        if not rules_control.empty:
            rules_control_json = rules_control.copy()
            rules_control_json['antecedents'] = rules_control_json['antecedents'].apply(list)
            rules_control_json['consequents'] = rules_control_json['consequents'].apply(list)
            rules_by_target['CONTROL'] = rules_control_json
        
        if rules_by_target:
            cohort_logger.info(f"Prepared for S3: {len(rules_by_target.get('TARGET_ICD', pd.DataFrame()))} ICD, "
                             f"{len(rules_by_target.get('TARGET_ED', pd.DataFrame()))} ED, "
                             f"{len(rules_by_target.get('CONTROL', pd.DataFrame()))} control")
        
        # Save to S3 (with retry logic for reliability)
        s3_base = f"{S3_OUTPUT_BASE}/{item_type}/cohort_name={cohort}/age_band={age_band}/event_year={event_year}"
        
        cohort_logger.info(f"Saving results to S3...")
        max_retries = 3
        for attempt in range(max_retries):
            try:
                itemsets_path = f"{s3_base}/itemsets.json"
                save_to_s3_json(itemsets_json.to_dict(orient='records'), itemsets_path)
                
                # Save rules by target type (separate files)
                if rules_by_target:
                    for target_type, target_rules in rules_by_target.items():
                        if not target_rules.empty:
                            rules_path = f"{s3_base}/rules_{target_type}.json"
                            save_to_s3_json(target_rules.to_dict(orient='records'), rules_path)
                            cohort_logger.info(f"Saved {len(target_rules)} {target_type} rules")
                
                summary = {
                    'timestamp': datetime.now().isoformat(),
                    'cohort': cohort, 'age_band': age_band, 'event_year': event_year,
                    'item_type': item_type,
                    'total_patients': len(transactions),
                    'total_itemsets': len(itemsets),
                    'total_rules': len(rules),
                    'rules_by_target': {
                        'TARGET_ICD': len(rules_by_target.get('TARGET_ICD', pd.DataFrame())),
                        'TARGET_ED': len(rules_by_target.get('TARGET_ED', pd.DataFrame())),
                        'CONTROL': len(rules_by_target.get('CONTROL', pd.DataFrame()))
                    } if rules_by_target else {'TARGET_ICD': 0, 'TARGET_ED': 0, 'CONTROL': 0},
                    'min_support': min_sup,
                    'min_confidence': min_conf,
                    'max_rules_limit': MAX_RULES_PER_COHORT,
                    'rules_truncated': len(rules) == MAX_RULES_PER_COHORT,
                    'target_focused': TARGET_FOCUSED,
                    'target_icd_codes': TARGET_ICD_CODES if TARGET_FOCUSED else None,
                    'target_hcg_lines': TARGET_HCG_LINES if TARGET_FOCUSED else None
                }
                summary_path = f"{s3_base}/summary.json"
                save_to_s3_json(summary, summary_path)
                
                cohort_logger.info(f"✓ Saved to S3 successfully")
                break  # Success - exit retry loop
                
            except Exception as s3_error:
                if attempt < max_retries - 1:
                    cohort_logger.warning(f"S3 upload attempt {attempt+1} failed: {s3_error}, retrying...")
                    time.sleep(2 ** attempt)  # Exponential backoff
                else:
                    cohort_logger.error(f"S3 upload failed after {max_retries} attempts: {s3_error}")
                    raise
        
        elapsed = time.time() - start_time
        log_memory(cohort_logger, "END")
        cohort_logger.info(f"✓ Completed in {elapsed:.1f}s")
        
        return (cohort, age_band, event_year, item_type, True, f"{len(itemsets)} itemsets, {len(rules)} rules")
        
    except Exception as e:
        cohort_logger.error(f"Error: {e}")
        return (cohort, age_band, event_year, item_type, False, str(e))

print("✓ Cohort processing function defined")


✓ Cohort processing function defined
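The rule post-processing in `process_single_cohort` (target mask on consequents, sort by lift, keep the top `MAX_RULES_PER_COHORT`, then bucket by outcome) can be sketched with plain dicts instead of the mlxtend DataFrame. The function names are hypothetical; as in the notebook, a rule whose consequent holds both markers lands in both outcome buckets.

```python
def split_and_limit(rules, target_prefixes=("TARGET_ICD:", "TARGET_ED:"), max_rules=1000):
    """Split rules into (target, control) by consequent; keep the top max_rules of each by lift.

    Each rule is a dict with 'antecedents', 'consequents' (iterables of items) and 'lift'.
    """
    def is_target(rule):
        return any(item.startswith(target_prefixes) for item in rule["consequents"])

    def top_by_lift(rs):
        return sorted(rs, key=lambda r: r["lift"], reverse=True)[:max_rules]

    target = [r for r in rules if is_target(r)]
    control = [r for r in rules if not is_target(r)]
    return top_by_lift(target), top_by_lift(control)


def by_outcome(target_rules):
    """Group target rules into the buckets written as rules_TARGET_ICD / rules_TARGET_ED."""
    buckets = {"TARGET_ICD": [], "TARGET_ED": []}
    for r in target_rules:
        for key in buckets:
            if any(item.startswith(key + ":") for item in r["consequents"]):
                buckets[key].append(r)
    return buckets


rules = [
    {"antecedents": ["gabapentin"], "consequents": ["TARGET_ED:EMERGENCY_DEPT"], "lift": 1.3},
    {"antecedents": ["metoprolol"], "consequents": ["TARGET_ICD:OPIOID_DEPENDENCE"], "lift": 2.0},
    {"antecedents": ["metoprolol"], "consequents": ["gabapentin"], "lift": 1.1},
    {"antecedents": ["gabapentin"], "consequents": ["metoprolol"], "lift": 1.1},
]
target, control = split_and_limit(rules)
print({k: len(v) for k, v in by_outcome(target).items()}, len(control))
```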


## Step 3: Process Cohorts in Parallel

Run FP-Growth for all cohort combinations using a process pool of `MAX_WORKERS` workers, skipping jobs whose `summary.json` already exists in S3.
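The submit / `as_completed` pattern can be exercised without S3 or DuckDB. This sketch swaps in a `ThreadPoolExecutor` and a hypothetical `fake_job` stand-in that returns the same 6-tuple shape as `process_single_cohort`, so it runs anywhere; `ProcessPoolExecutor` has the same interface but requires picklable, importable callables. The tallies mirror the completed/failed counters in the cell below.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_job(job, item_type):
    """Stand-in for process_single_cohort; pretends CPT jobs find no itemsets."""
    ok = item_type != "cpt_code"
    return (job["cohort"], job["age_band"], job["event_year"], item_type,
            ok, "ok" if ok else "No itemsets")


def run_all(jobs, item_types, max_workers=2):
    pending = [(job, t) for job in jobs for t in item_types]
    results, completed, failed = [], 0, 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(fake_job, job, t): (job, t) for job, t in pending}
        for future in as_completed(futures):
            cohort, age_band, event_year, item_type, success, message = future.result()
            results.append({"cohort": cohort, "item_type": item_type,
                            "success": success, "message": message})
            if success:
                completed += 1
            else:
                failed += 1
            # Denominator is the number of submitted jobs (skipped jobs are never submitted)
            print(f"[{completed + failed}/{len(pending)}] {cohort}/{age_band}/{event_year}/{item_type}: {message}")
    return results, completed, failed


jobs = [{"cohort": "opioid_ed", "age_band": "13-24", "event_year": "2016"}]
results, completed, failed = run_all(jobs, ["drug_name", "icd_code", "cpt_code"])
```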


In [None]:
print("="*80)
print("COHORT FPGROWTH ANALYSIS - START")
print("="*80)
print(f"Cohorts: {len(cohort_jobs)} combinations")
print(f"Item types: {ITEM_TYPES}")
print(f"Total jobs: {len(cohort_jobs) * len(ITEM_TYPES)}")
print(f"Max workers: {MAX_WORKERS}")
print(f"Detailed progress → Check log file")
print()

logger.info(f"\n{'='*80}")
logger.info(f"COHORT FPGROWTH ANALYSIS - START")
logger.info(f"{'='*80}")
logger.info(f"Cohorts: {len(cohort_jobs)} combinations")
logger.info(f"Item types: {ITEM_TYPES}")
logger.info(f"Total jobs: {len(cohort_jobs) * len(ITEM_TYPES)}")

# Helper function to check if cohort results exist in S3
import boto3
from botocore.exceptions import ClientError

_s3 = boto3.client('s3')  # one client reused for every existence check

def check_cohort_exists(item_type: str, cohort: str, age_band: str, event_year: str) -> bool:
    """Check if cohort results already exist in S3 (by checking for summary.json)."""
    key = f"gold/fpgrowth/cohort/{item_type}/cohort_name={cohort}/age_band={age_band}/event_year={event_year}/summary.json"
    try:
        _s3.head_object(Bucket=S3_BUCKET, Key=key)
        return True
    except ClientError:
        # Missing object (404) or no read access (403): treat as not yet computed
        return False

start_time = time.time()
results = []
completed = 0
failed = 0
skipped = 0

# Create all combinations of cohorts and item types
all_jobs_initial = [(job, item_type) for job in cohort_jobs for item_type in ITEM_TYPES]

# Filter out already-completed jobs
print("\nChecking for existing results in S3...")
all_jobs = []
for job, item_type in all_jobs_initial:
    if check_cohort_exists(item_type, job['cohort'], job['age_band'], job['event_year']):
        logger.info(f"Skipping {job['cohort']}/{job['age_band']}/{job['event_year']}/{item_type} - already exists")
        skipped += 1
        results.append({
            'cohort': job['cohort'],
            'age_band': job['age_band'],
            'event_year': job['event_year'],
            'item_type': item_type,
            'success': True,
            'message': 'Already exists in S3 (skipped)'
        })
    else:
        all_jobs.append((job, item_type))

total_jobs = len(all_jobs_initial)
print(f"Total jobs: {total_jobs}")
print(f"Already completed: {skipped}")
print(f"To process: {len(all_jobs)}")
print()

with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Submit all jobs
    future_to_params = {executor.submit(process_single_cohort, job, item_type): (job, item_type) 
                        for job, item_type in all_jobs}
    
    # Process results as they complete
    for future in as_completed(future_to_params):
        job, item_type = future_to_params[future]
        try:
            cohort, age_band, event_year, item_type, success, message = future.result()
            results.append({
                'cohort': cohort,
                'age_band': age_band,
                'event_year': event_year,
                'item_type': item_type,
                'success': success,
                'message': message
            })
            
            if success:
                completed += 1
                logger.info(f"[{completed + failed}/{len(all_jobs)}] ✓ {cohort}/{age_band}/{event_year}/{item_type}: {message}")
                # Print every 10 successes or milestones
                if completed % 10 == 0 or (completed + failed) == len(all_jobs):
                    print(f"Progress: {completed}/{len(all_jobs)} completed ({completed/len(all_jobs)*100:.1f}%), {failed} failed, {skipped} skipped")
            else:
                failed += 1
                logger.warning(f"[{completed + failed}/{len(all_jobs)}] ✗ {cohort}/{age_band}/{event_year}/{item_type}: {message}")
                print(f"⚠ Failed: {cohort}/{age_band}/{event_year}/{item_type}")
                
        except Exception as e:
            failed += 1
            logger.error(f"[{completed + failed}/{len(all_jobs)}] ✗ {job['cohort']}/{job['age_band']}/{job['event_year']}/{item_type}: {e}")
            print(f"⚠ Error: {job['cohort']}/{job['age_band']}/{job['event_year']}/{item_type}")
            results.append({
                'cohort': job['cohort'],
                'age_band': job['age_band'],
                'event_year': job['event_year'],
                'item_type': item_type,
                'success': False,
                'message': str(e)
            })

elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"COHORT FPGROWTH ANALYSIS - COMPLETE")
print(f"{'='*80}")
print(f"  Total jobs: {total_jobs}")
print(f"  Already in S3 (skipped): {skipped}")
print(f"  Newly processed: {completed}")
print(f"  Failed: {failed}")
print(f"  Overall success rate: {(skipped + completed)/total_jobs*100:.1f}%")
print(f"  Total time: {elapsed:.1f}s ({elapsed/60:.1f}min)")
if len(all_jobs) > 0:
    print(f"  Avg time per new job: {elapsed/len(all_jobs):.1f}s")
else:
    print(f"  (No new jobs processed - all results already in S3)")


COHORT FPGROWTH ANALYSIS - START
Cohorts: 45 combinations
Item types: ['drug_name', 'icd_code', 'cpt_code']
Total jobs: 135
Max workers: 4
Detailed progress → Check log file

Progress: 10/135 completed (7.4%), 0 failed
Progress: 20/135 completed (14.8%), 0 failed
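The progress loop above collects one result record per job as futures finish. The underlying pattern is sketched below with `concurrent.futures.as_completed`. This is a minimal sketch under stated assumptions: `run_job`, `run_all`, and the job dictionaries are hypothetical stand-ins for the notebook's FP-Growth worker, and a thread pool is used only to keep the example self-contained (the notebook's executor setup is not shown in this section).

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_job(job):
    """Hypothetical stand-in for the per-partition FP-Growth job.

    Returns (success, message); raises to simulate an unexpected crash.
    """
    if job["event_year"] == 2099:
        raise RuntimeError("simulated crash")
    if job["age_band"] == "0-12":
        return False, "too few transactions"
    return True, "ok"


def run_all(jobs, max_workers=4):
    results, completed, failed = [], 0, 0
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to the job dict that produced it
        futures = {pool.submit(run_job, job): job for job in jobs}
        for future in as_completed(futures):
            job = futures[future]
            try:
                success, message = future.result()
            except Exception as e:  # the job raised: record it and keep going
                success, message = False, str(e)
            if success:
                completed += 1
            else:
                failed += 1
            # Identifiers always come from the job dict, never from loop leftovers
            results.append({**job, "success": success, "message": message})
            print(f"[{completed + failed}/{len(jobs)}] {job['cohort']}/{job['item_type']}: {message}")
    return results, completed, failed


jobs = [
    {"cohort": "opioid_ed", "age_band": "65-74", "event_year": 2020, "item_type": "drug_name"},
    {"cohort": "opioid_ed", "age_band": "0-12", "event_year": 2020, "item_type": "icd_code"},
    {"cohort": "ed_non_opioid", "age_band": "65-74", "event_year": 2099, "item_type": "cpt_code"},
]
results, completed, failed = run_all(jobs)
print(f"{completed} completed, {failed} failed, {len(results)} recorded")
```

Because every identifier is read from `futures[future]`, a job that raises is still recorded under its own cohort and item type, and the `[n/N]` counter uses one denominator for successes and failures alike.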


## Step 4: Analyze Results

Review processing results and identify any issues.


In [None]:
# Convert results to DataFrame for analysis
results_df = pd.DataFrame(results)

print("\n📊 Results by Cohort and Item Type:")
summary = results_df.groupby(['cohort', 'item_type'])['success'].agg([
    ('total', 'count'), 
    ('successful', 'sum'),
    ('success_rate', lambda x: f"{x.mean()*100:.1f}%")
])
print(summary)

print("\n📊 Results by Item Type:")
item_summary = results_df.groupby('item_type')['success'].agg([
    ('total', 'count'), 
    ('successful', 'sum'),
    ('success_rate', lambda x: f"{x.mean()*100:.1f}%")
])
print(item_summary)

print("\n❌ Failed Jobs:")
failed_df = results_df[~results_df['success']]
if not failed_df.empty:
    print(failed_df[['cohort', 'age_band', 'event_year', 'item_type', 'message']].to_string())
else:
    print("  None! All jobs completed successfully.")

print("\n✓ Successful Jobs Sample:")
success_df = results_df[results_df['success']]
if not success_df.empty:
    print(success_df[['cohort', 'age_band', 'event_year', 'item_type', 'message']].head(15))
else:
    print("  No successful jobs.")
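The cell above stores `success_rate` as a formatted string, so the column cannot be sorted, filtered, or plotted. A minimal alternative, sketched here on a few synthetic records that carry only the `cohort`, `item_type`, and `success` fields of `results`, keeps the rate numeric and formats it only at print time. It also pivots the rates into a cohort × item-type grid, which makes gaps (combinations with no jobs) visible as `NaN`.

```python
import pandas as pd

# Synthetic records shaped like the notebook's `results` list (subset of fields)
results = [
    {"cohort": "opioid_ed", "item_type": "drug_name", "success": True},
    {"cohort": "opioid_ed", "item_type": "drug_name", "success": False},
    {"cohort": "opioid_ed", "item_type": "icd_code", "success": True},
    {"cohort": "ed_non_opioid", "item_type": "drug_name", "success": True},
]
results_df = pd.DataFrame(results)

# Named aggregation keeps success_rate as a float (fraction of successful jobs)
summary = (
    results_df.groupby(["cohort", "item_type"])["success"]
    .agg(total="count", successful="sum", success_rate="mean")
    .reset_index()
)
# Format as a percentage only when printing
print(summary.to_string(index=False, formatters={"success_rate": "{:.1%}".format}))

# Cohort x item_type grid of success rates (NaN where no jobs ran)
grid = summary.pivot(index="cohort", columns="item_type", values="success_rate")
print(grid)
```

Note that `pd.DataFrame([])` has no `success` column, so both this sketch and the cell above raise `KeyError` when `results` is empty; guard with `if results:` if every job may have been skipped.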


## Summary and Next Steps


In [None]:
print("="*80)
print("COHORT FPGROWTH ANALYSIS - SUMMARY")
print("="*80)

print(f"\n📊 Processing Statistics:")
print(f"  Cohort combinations: {len(cohort_jobs)}")
print(f"  Item types: {len(ITEM_TYPES)} ({', '.join(ITEM_TYPES)})")
print(f"  Total jobs: {total_jobs}")
print(f"  Already in S3 (skipped): {skipped}")
print(f"  Newly processed: {completed}")
print(f"  Failed: {failed}")
# Jobs already present in S3 count as successful, matching the completion report
print(f"  Overall success rate: {(skipped + completed)/total_jobs*100:.1f}%")
print(f"  Processing time: {elapsed:.1f}s ({elapsed/60:.1f}min)")

print(f"\n🔍 FP-Growth Configuration:")
print(f"  Min support: {MIN_SUPPORT} ({MIN_SUPPORT*100:.1f}%)")
print(f"  Min confidence: {MIN_CONFIDENCE} ({MIN_CONFIDENCE*100:.1f}%)")
print(f"  Parallel workers: {MAX_WORKERS}")

print(f"\n💾 Output Location:")
print(f"  S3 Base: {S3_OUTPUT_BASE}")
print(f"  Structure: <item_type>/cohort_name=<name>/age_band=<band>/event_year=<year>/")
print(f"  Item types:")
print(f"    - drug_name/ (pharmacy events)")
print(f"    - icd_code/ (diagnosis codes)")
print(f"    - cpt_code/ (procedure codes)")
print(f"  Files per partition (cohort/age_band/event_year):")
print(f"    - itemsets.json (frequent itemsets)")
print(f"    - rules.json (association rules)")
print(f"    - summary.json (metadata)")

print(f"\n🎯 Next Steps:")
print(f"  1. Load cohort-specific itemsets for BupaR process mining")
print(f"  2. Compare patterns between OPIOID_ED and ED_NON_OPIOID cohorts across item types")
print(f"  3. Use association rules for pathway analysis")
print(f"  4. Filter features for cohort-specific CatBoost models")
print(f"  5. Create network visualizations for cohort-specific patterns")

print(f"\n📝 Example Usage:")
print(f"  # Load drug patterns for a specific cohort")
print(f"  from helpers_1997_13.s3_utils import load_from_s3_json")
print(f"  itemsets = load_from_s3_json('{S3_OUTPUT_BASE}/drug_name/cohort_name=opioid_ed/age_band=65-74/event_year=2020/itemsets.json')")
print(f"  # Load ICD patterns for same cohort")
print(f"  icd_itemsets = load_from_s3_json('{S3_OUTPUT_BASE}/icd_code/cohort_name=opioid_ed/age_band=65-74/event_year=2020/itemsets.json')")

print(f"\n✓ Analysis complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("="*80)
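The S3 layout printed above uses Hive-style `key=value` partitions, so paths can be built and parsed mechanically, which helps when looping over cohorts for the comparison step. The sketch below does both and adds a simple support-difference comparison between two cohorts. The helper names are hypothetical, and the schema of `itemsets.json` is not shown in this notebook: the sketch assumes each file loads as a list of records with an `itemsets` list and a `support` float, so check the real files before relying on it.

```python
def partition_prefix(base, item_type, cohort, age_band, event_year):
    """Build <base>/<item_type>/cohort_name=.../age_band=.../event_year=.../"""
    base = base.rstrip("/")
    return (f"{base}/{item_type}/cohort_name={cohort}"
            f"/age_band={age_band}/event_year={event_year}/")


def parse_partition(key):
    """Recover partition values (as strings) from any key under that layout."""
    return dict(part.split("=", 1) for part in key.split("/") if "=" in part)


def compare_support(itemsets_a, itemsets_b):
    """Rank itemsets by support difference (a - b), largest absolute gap first.

    ASSUMED schema: [{"itemsets": [...], "support": float}, ...].
    An itemset missing from one cohort only means its support fell below
    MIN_SUPPORT there; it is treated as 0.0 here as a simplification.
    """
    sup_a = {frozenset(r["itemsets"]): r["support"] for r in itemsets_a}
    sup_b = {frozenset(r["itemsets"]): r["support"] for r in itemsets_b}
    keys = sup_a.keys() | sup_b.keys()
    diffs = {k: sup_a.get(k, 0.0) - sup_b.get(k, 0.0) for k in keys}
    return sorted(diffs.items(), key=lambda kv: abs(kv[1]), reverse=True)


base = "s3://pgxdatalake/gold/fpgrowth/cohort"
prefix = partition_prefix(base, "drug_name", "opioid_ed", "65-74", 2020)
print(prefix + "itemsets.json")
print(parse_partition(prefix + "itemsets.json"))

# Toy itemsets standing in for loaded JSON (hypothetical values)
opioid = [{"itemsets": ["drug_a"], "support": 0.40},
          {"itemsets": ["drug_a", "drug_b"], "support": 0.12}]
non_opioid = [{"itemsets": ["drug_a"], "support": 0.35}]
ranked = compare_support(opioid, non_opioid)
for items, diff in ranked:
    print(sorted(items), f"{diff:+.2f}")
```

Partition values come back as strings (`event_year` is `"2020"`, not `2020`), so cast them before numeric comparisons.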


## Step 5: Auto-Shutdown EC2 Instance (Optional)

Set `SHUTDOWN_EC2 = True` to automatically stop the EC2 instance after analysis completes.

**Note:** This is a **STOP** (not terminate), so you can restart the instance later.

In [None]:
# =============================================================================
# EC2 AUTO-SHUTDOWN (OPTIONAL)
# =============================================================================
# Set SHUTDOWN_EC2 = True to enable, False to disable
SHUTDOWN_EC2 = False  # Change to True to enable auto-shutdown

if SHUTDOWN_EC2:
    print("\n" + "="*80)
    print("Shutting down EC2 instance...")
    print("="*80)
    
    import shutil
    import subprocess
    from pathlib import Path

    import requests

    IMDS_BASE = "http://169.254.169.254/latest"

    # Get instance ID from the EC2 instance metadata service (IMDS).
    # Request an IMDSv2 session token first: instances that require IMDSv2
    # reject token-less requests with HTTP 401. Fall back to IMDSv1 otherwise.
    try:
        token_resp = requests.put(
            f"{IMDS_BASE}/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "60"},
            timeout=2
        )
        imds_headers = (
            {"X-aws-ec2-metadata-token": token_resp.text}
            if token_resp.status_code == 200 else {}
        )
        response = requests.get(
            f"{IMDS_BASE}/meta-data/instance-id",
            headers=imds_headers,
            timeout=2
        )
        if response.status_code == 200:
            instance_id = response.text.strip()
            print(f"Instance ID: {instance_id}")
            
            # Find the AWS CLI on PATH, then in common install locations
            aws_cmd = shutil.which("aws")
            if not aws_cmd:
                for path in ["/usr/local/bin/aws", "/usr/bin/aws",
                             "/home/ec2-user/.local/bin/aws",
                             "/home/ubuntu/.local/bin/aws",
                             "/home/pgx3874/.local/bin/aws"]:
                    if Path(path).exists():
                        aws_cmd = path
                        break
            
            if aws_cmd:
                # Stop the instance (use terminate-instances for permanent deletion)
                shutdown_cmd = [aws_cmd, "ec2", "stop-instances", "--instance-ids", instance_id]
                
                print(f"Running: {' '.join(shutdown_cmd)}")
                result = subprocess.run(shutdown_cmd, capture_output=True, text=True)
                
                if result.returncode == 0:
                    print("✓ EC2 instance stop command sent successfully")
                    print("Instance will stop in a few moments.")
                    print("Note: This is a STOP (not terminate), so you can restart it later.")
                    if result.stdout:
                        print(f"\nAWS Response:\n{result.stdout}")
                else:
                    print(f"✗ EC2 stop command failed with exit code {result.returncode}")
                    if result.stderr:
                        print(f"Error: {result.stderr}")
                    print("Check AWS credentials and IAM permissions (ec2:StopInstances).")
            else:
                print("✗ AWS CLI not found. Cannot shut down instance.")
                print("Install the AWS CLI or ensure it is on your PATH.")
                print(f"Manual shutdown: aws ec2 stop-instances --instance-ids {instance_id}")
        else:
            print(f"✗ Metadata service returned status code {response.status_code}")
            print("Could not retrieve instance ID.")
    
    except requests.exceptions.RequestException as e:
        print("✗ Could not retrieve instance ID from metadata service.")
        print(f"Error: {e}")
        print("If running on EC2, check that the metadata service is accessible.")
        print("\nManual shutdown command:")
        print("  aws ec2 stop-instances --instance-ids <your-instance-id>")
    
    except Exception as e:
        print(f"✗ Unexpected error during shutdown: {e}")

else:
    print("\n" + "="*80)
    print("EC2 Auto-Shutdown: DISABLED")
    print("="*80)
    print("To enable auto-shutdown, set SHUTDOWN_EC2 = True in this cell.")
    print("Instance will continue running.")
    print("\nTo manually stop this instance later:")
    print("  aws ec2 stop-instances --instance-ids $(ec2-metadata --instance-id | cut -d ' ' -f 2)")
    print("Or use AWS Console: EC2 > Instances > Select instance > Instance State > Stop")
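If the AWS CLI is not installed, the same stop can be issued from Python through boto3's EC2 client. This is a minimal sketch, assuming boto3 is installed and credentials with `ec2:StopInstances` permission are available; `stop_instance` is a hypothetical helper, and the client is injectable so the function can be exercised offline with a stub. The region comes from the standard AWS configuration unless `region_name` is passed.

```python
def stop_instance(instance_id, ec2_client=None, region_name=None):
    """Stop (not terminate) one EC2 instance; returns (previous, current) state names.

    HYPOTHETICAL helper: ec2_client defaults to boto3.client("ec2").
    """
    if ec2_client is None:
        import boto3  # imported lazily so the helper can be tested without boto3
        ec2_client = boto3.client("ec2", region_name=region_name)
    response = ec2_client.stop_instances(InstanceIds=[instance_id])
    # One entry per requested instance, e.g. running -> stopping
    change = response["StoppingInstances"][0]
    return change["PreviousState"]["Name"], change["CurrentState"]["Name"]


# Usage on the instance itself (instance_id as retrieved from IMDS above):
# print(stop_instance(instance_id))
```

The boto3 route avoids the search for the `aws` executable, at the cost of requiring boto3 in the notebook environment.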
