# Cohort-Specific FPGrowth Feature Importance Analysis

**Purpose:** Cohort-level frequent pattern mining for process mining and comparative analysis  
**Updated:** November 23, 2025  
**Hardware:** Optimized for EC2 (32 cores, 1TB RAM)  
**Output:** `s3://pgxdatalake/gold/fpgrowth/cohort/{item_type}/cohort_name={cohort}/...`

## Key Features

✅ **Three Item Types** - Drugs, ICD codes, CPT codes  
✅ **Cohort-Specific Patterns** - Mines frequent patterns separately within each cohort slice  
✅ **Parallel Processing** - Runs cohort slices × item types concurrently  
✅ **bupaR Integration** - Outputs ready for process mining workflows  
✅ **Comparative Analysis** - Compare patterns between OPIOID_ED and ED_NON_OPIOID cohorts

## Methodology

For each combination of (cohort, age_band, event_year, item_type):
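1. Extract items (drug names, ICD codes, or CPT codes) from the cohort's parquet file
2. Group items into one transaction (set of distinct items) per patient
3. One-hot encode the transactions into a boolean patient × item matrix
4. Run FP-Growth to find itemsets with support ≥ `MIN_SUPPORT`
5. Generate association rules with confidence ≥ `MIN_CONFIDENCE`
6. Save itemsets, rules, and a run summary to S3 under partitioned prefixes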
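Steps 2-4 can be illustrated on toy data without mlxtend. The sketch below is a minimal pure-Python illustration, assuming rows arrive as `(patient_id, item)` pairs like the `mi_person_key`/`item` columns the queries return; the drug names are hypothetical. It builds a sorted item vocabulary, as `TransactionEncoder` does, but mines itemsets by brute-force level-wise (Apriori-style) enumeration rather than FP-Growth. Both yield the same frequent itemsets for a given `min_support`; FP-Growth simply avoids candidate generation, which matters at cohort scale.

```python
from itertools import combinations


def build_transactions(rows):
    """Group (patient_id, item) rows into one sorted, de-duplicated basket per patient."""
    baskets = {}
    for patient_id, item in rows:
        baskets.setdefault(patient_id, set()).add(item)
    # Order baskets by patient id so the output is deterministic
    return [sorted(items) for _, items in sorted(baskets.items())]


def one_hot(transactions):
    """Boolean patient x item matrix over the sorted item vocabulary."""
    columns = sorted({item for t in transactions for item in t})
    matrix = [[item in basket for item in columns] for basket in map(set, transactions)]
    return columns, matrix


def frequent_itemsets(transactions, min_support):
    """Return {frozenset(items): support} for all itemsets with support >= min_support.

    Brute-force level-wise (Apriori-style) search: fine for toy data, not for cohorts.
    """
    n = len(transactions)
    baskets = [set(t) for t in transactions]
    candidates = [frozenset([i]) for i in sorted({i for b in baskets for i in b})]
    result, k = {}, 1
    while candidates:
        # Support = share of patients whose basket contains every item in the candidate
        level = {c: sum(c <= b for b in baskets) / n for c in candidates}
        level = {c: s for c, s in level.items() if s >= min_support}
        result.update(level)
        k += 1
        # Next candidates: size-k combinations whose (k-1)-subsets were all frequent
        survivors = sorted({i for c in level for i in c})
        candidates = [
            frozenset(combo)
            for combo in combinations(survivors, k)
            if all(frozenset(sub) in level for sub in combinations(combo, k - 1))
        ]
    return result


# Hypothetical toy events: (mi_person_key, item)
rows = [("p1", "drug_b"), ("p1", "drug_a"), ("p1", "drug_b"),
        ("p2", "drug_a"), ("p2", "drug_b"),
        ("p3", "drug_b"), ("p3", "drug_c"),
        ("p4", "drug_a")]
transactions = build_transactions(rows)
columns, matrix = one_hot(transactions)
itemsets = frequent_itemsets(transactions, min_support=0.5)
```

With four patients and `min_support=0.5`, an itemset must appear in at least two baskets, so `drug_c` (one patient) drops out while the `drug_a`/`drug_b` pair survives.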

## Key Differences from Global Analysis

| Aspect | Global FPGrowth | Cohort FPGrowth |
|--------|-----------------|-----------------|
| **Scope** | All patients (~5.7M) | Individual cohorts (~10K-100K) |
| **Purpose** | Universal ML features | Process mining patterns |
| **Support Threshold** | 0.01 (1%) | 0.05 (5%) |
| **Output** | `global/{item_type}/` | `cohort/{item_type}/cohort_name={c}/...` |
| **Use Case** | CatBoost consistency | bupaR pathway analysis |
| **Parallelization** | Sequential by item type | Parallel by cohort slice × item type |
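Because support is a fraction of patients (transactions), the same threshold means very different absolute counts at the two scales in the table. A small hypothetical helper, not part of the project utilities, makes the conversion explicit: an itemset passes when it appears in at least ⌈min_support × n⌉ patients.

```python
import math


def min_patient_count(min_support: float, n_patients: int) -> int:
    """Hypothetical helper: fewest patients an itemset must appear in to reach min_support."""
    # support = patients_with_itemset / n_patients; itemsets pass when support >= min_support.
    # round() absorbs float noise such as 7.000000000000001 before taking the ceiling.
    return math.ceil(round(min_support * n_patients, 9))


global_threshold = min_patient_count(0.01, 5_700_000)  # global run, ~5.7M patients
cohort_threshold = min_patient_count(0.05, 10_000)     # small cohort slice, ~10K patients
```

So 1% of ~5.7M patients is 57,000 patients, while 5% of a 10K-patient slice is only 500.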

## Expected Runtime (EC2: 32 cores, 1TB RAM)

- **Cohorts**: 2 (opioid_ed, ed_non_opioid)
- **Age bands × Years**: ~100 combinations per cohort
- **Item types**: 3 (drug_name, icd_code, cpt_code)
- **Total jobs**: ~600 combinations (2 cohorts × ~100 × 3 item types)
- **Avg time per job**: ~1-2 minutes
- **Total runtime**: ~2.5-5 hours (600 jobs × 1-2 min ÷ 4 workers, with MAX_WORKERS=4)
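The runtime figure follows from simple arithmetic. This sketch assumes every job takes the same time and the pool stays saturated, so jobs finish in ⌈jobs / workers⌉ equal waves; real jobs vary in size, so treat the result as a rough estimate. `estimate_wall_clock_hours` is a hypothetical helper.

```python
import math


def estimate_wall_clock_hours(n_jobs: int, minutes_per_job: float, workers: int) -> float:
    """Idealized estimate: equal-length jobs run in ceil(n_jobs / workers) full waves."""
    waves = math.ceil(n_jobs / workers)
    return waves * minutes_per_job / 60


n_jobs = 2 * 100 * 3  # cohorts x (age band, year) combinations x item types
low = estimate_wall_clock_hours(n_jobs, minutes_per_job=1, workers=4)
high = estimate_wall_clock_hours(n_jobs, minutes_per_job=2, workers=4)
```

Under the same assumptions, raising `MAX_WORKERS` to 8 halves the estimate (1.25-2.5 hours), provided per-job time does not grow from contention; DuckDB and NumPy may each use several threads per worker.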

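## S3 Output Structure

Every job that finds itemsets writes `itemsets.json` and `summary.json` under its partition prefix; `rules.json` is added only when at least one rule reaches `MIN_CONFIDENCE`.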

```
s3://pgxdatalake/gold/fpgrowth/cohort/
├── drug_name/
│   ├── cohort_name=opioid_ed/
│   │   ├── age_band=65-74/event_year=2020/
│   │   │   ├── itemsets.json
│   │   │   ├── rules.json
│   │   │   └── summary.json
│   │   └── ...
│   └── cohort_name=ed_non_opioid/...
├── icd_code/
│   └── (same structure)
└── cpt_code/
    └── (same structure)
```
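Output keys can be derived from the four partition values. The helpers below are hypothetical (the notebook builds the same string inline in `process_single_cohort`); `expected_outputs` reflects that `rules.json` only exists when rules were generated, which is handy when checking a finished run for missing files.

```python
def cohort_output_prefix(base, item_type, cohort, age_band, event_year):
    """Hypothetical helper: S3 prefix for one (item_type, cohort, age_band, event_year) job."""
    return (f"{base.rstrip('/')}/{item_type}/cohort_name={cohort}"
            f"/age_band={age_band}/event_year={event_year}")


def expected_outputs(prefix, has_rules):
    """Files a successful job should leave behind; rules.json only when rules exist."""
    if has_rules:
        names = ["itemsets.json", "rules.json", "summary.json"]
    else:
        names = ["itemsets.json", "summary.json"]
    return [f"{prefix}/{name}" for name in names]


prefix = cohort_output_prefix(
    "s3://pgxdatalake/gold/fpgrowth/cohort", "drug_name", "opioid_ed", "65-74", "2020"
)
```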

---


## Setup and Imports


In [None]:
import os
import sys
import json
import pandas as pd
import numpy as np
from datetime import datetime
import logging
from pathlib import Path
import time
from concurrent.futures import ProcessPoolExecutor, as_completed

# MLxtend for FP-Growth
from mlxtend.frequent_patterns import fpgrowth, association_rules
from mlxtend.preprocessing import TransactionEncoder

# Project root
project_root = Path.cwd().parent if Path.cwd().name == '3_fpgrowth_analysis' else Path.cwd()
sys.path.insert(0, str(project_root))

# Project utilities
from helpers_1997_13.common_imports import s3_client, S3_BUCKET
from helpers_1997_13.duckdb_utils import get_duckdb_connection
from helpers_1997_13.s3_utils import save_to_s3_json, save_to_s3_parquet, get_cohort_parquet_path
from helpers_1997_13.fpgrowth_utils import run_fpgrowth_drug_token_with_fallback, convert_frozensets
from helpers_1997_13.visualization_utils import create_network_visualization
from helpers_1997_13.constants import AGE_BANDS, EVENT_YEARS

print(f"✓ Project root: {project_root}")
print(f"✓ All imports successful")
print(f"✓ Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


In [None]:
# =============================================================================
# EC2 CONFIGURATION (32 cores, 1TB RAM)
# =============================================================================

# FP-Growth parameters (higher threshold for cohort-specific patterns)
MIN_SUPPORT = 0.05       # 5% support (items must appear in 5% of patients within cohort)
MIN_CONFIDENCE = 0.3     # 30% confidence for association rules

# Item types to process
ITEM_TYPES = ['drug_name', 'icd_code', 'cpt_code']

# Processing parameters
MAX_WORKERS = 4  # Parallel workers (adjust based on available cores)
COHORTS_TO_PROCESS = ['opioid_ed', 'ed_non_opioid']  # Specify cohorts to process

# Paths
S3_OUTPUT_BASE = f"s3://{S3_BUCKET}/gold/fpgrowth/cohort"
LOCAL_DATA_PATH = project_root / "data" / "gold" / "cohorts_F1120"

# Set up logger with file output (avoids hitting Jupyter's output rate limit)
logger = logging.getLogger('cohort_fpgrowth')
logger.setLevel(logging.INFO)
logger.handlers.clear()  # Clear any existing handlers

# File handler - full logs to file
log_file = project_root / "3_fpgrowth_analysis" / "cohort_fpgrowth_execution.log"
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

# Console handler - warnings and errors only (keeps notebook output small)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.WARNING)  # Only warnings/errors to console
console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(console_handler)

print(f"✓ Min Support: {MIN_SUPPORT}")
print(f"✓ Min Confidence: {MIN_CONFIDENCE}")
print(f"✓ Item Types: {ITEM_TYPES}")
print(f"✓ Max Workers: {MAX_WORKERS}")
print(f"✓ Cohorts: {COHORTS_TO_PROCESS}")
print(f"✓ S3 Output: {S3_OUTPUT_BASE}")
print(f"✓ Local Data: {LOCAL_DATA_PATH}")
print(f"✓ Local Data Exists: {LOCAL_DATA_PATH.exists()}")
print(f"✓ Detailed logs → {log_file}")
print(f"✓ Console output: WARNING level only (check log file for progress)")

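The split between file and console output relies on logger propagation: a logger whose name starts with `cohort_fpgrowth.` hands its records to the handlers attached to `cohort_fpgrowth`, while a logger with an unrelated name reports to the root logger instead. A small self-contained check, using a hypothetical `cohort_fpgrowth_demo` logger and an in-memory handler so the notebook's real logger is untouched:

```python
import logging


class ListHandler(logging.Handler):
    """Keeps formatted records in memory so routing can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(f"{record.name}: {record.getMessage()}")


parent = logging.getLogger("cohort_fpgrowth_demo")
parent.setLevel(logging.INFO)
parent.propagate = False  # keep the demo away from the root logger
collector = ListHandler()
parent.addHandler(collector)

# Dotted name -> child of "cohort_fpgrowth_demo": records reach the parent's handlers
logging.getLogger("cohort_fpgrowth_demo.opioid_ed_65-74_2020_drug_name").info("slice started")

# Unrelated name -> child of root: the parent's handlers never see it
logging.getLogger("opioid_ed_65-74_2020_drug_name").warning("not collected")
```

Worker processes started with the `fork` method inherit these handlers; with `spawn` or `forkserver` they start with a fresh logging configuration.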

## Step 1: Discover Available Cohorts

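Scan the local data directory for `cohort_name=<cohort>/event_year=<year>/age_band=<band>/cohort.parquet` files and turn each into a job. The local layout nests `event_year` above `age_band`, whereas the S3 output nests `age_band` above `event_year`.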
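Each discovered file path already carries its partition values, so a job can also be recovered from a single path. A minimal sketch, assuming the `cohort_name=`/`event_year=`/`age_band=` layout described above; `job_from_path` is a hypothetical helper that returns the same dict shape as `discover_cohorts` (all values as strings), or `None` when a partition key is missing.

```python
from pathlib import PurePosixPath


def job_from_path(path):
    """Hypothetical helper: build a job dict from a partitioned cohort.parquet path.

    Returns None if any of the three partition keys is missing.
    """
    # Collect Hive-style key=value directory names, e.g. "age_band=65-74"
    parts = dict(seg.split("=", 1) for seg in PurePosixPath(path).parts if "=" in seg)
    if not all(key in parts for key in ("cohort_name", "event_year", "age_band")):
        return None
    return {
        "cohort": parts["cohort_name"],
        "age_band": parts["age_band"],
        "event_year": parts["event_year"],  # kept as a string, as in discover_cohorts
        "local_path": str(path),
    }


example = "data/gold/cohorts_F1120/cohort_name=opioid_ed/event_year=2020/age_band=65-74/cohort.parquet"
job = job_from_path(example)
```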


In [None]:
def discover_cohorts(local_data_path, cohort_filter=None):
    """
    Discover all available cohort combinations from local data.

    Expects files at <local_data_path>/cohort_name=<c>/event_year=<y>/age_band=<a>/cohort.parquet
    and returns one job dict per file with keys: cohort, age_band, event_year, local_path.
    Directories are walked in sorted order so the job list is deterministic.
    """
    cohort_jobs = []
    
    for cohort_dir in sorted(local_data_path.glob("cohort_name=*")):
        cohort_name = cohort_dir.name.replace("cohort_name=", "")
        
        # Filter if specified
        if cohort_filter and cohort_name not in cohort_filter:
            continue
        
        for year_dir in sorted(cohort_dir.glob("event_year=*")):
            event_year = year_dir.name.replace("event_year=", "")
            
            for age_dir in sorted(year_dir.glob("age_band=*")):
                age_band = age_dir.name.replace("age_band=", "")
                
                # Only directories that contain a cohort.parquet become jobs
                cohort_file = age_dir / "cohort.parquet"
                if cohort_file.exists():
                    cohort_jobs.append({
                        'cohort': cohort_name,
                        'age_band': age_band,
                        'event_year': event_year,
                        'local_path': str(cohort_file)
                    })
    
    return cohort_jobs

# Discover available cohorts
cohort_jobs = discover_cohorts(LOCAL_DATA_PATH, cohort_filter=COHORTS_TO_PROCESS)

print(f"\n📊 Discovered Cohorts:")
print(f"  Total combinations: {len(cohort_jobs)}")

# Group by cohort
cohort_counts = {}
for job in cohort_jobs:
    cohort_counts[job['cohort']] = cohort_counts.get(job['cohort'], 0) + 1

for cohort, count in cohort_counts.items():
    print(f"  {cohort}: {count} combinations")

print(f"\n  Sample jobs:")
for job in cohort_jobs[:5]:
    print(f"    {job['cohort']}/{job['age_band']}/{job['event_year']}")


## Step 2: Define Cohort Processing Function

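Define a worker that runs the full pipeline for one (cohort, age_band, event_year) slice and one item type: query the slice with DuckDB, build per-patient transactions, one-hot encode them, run FP-Growth and rule generation, write the results to S3, and return a status tuple `(cohort, age_band, event_year, item_type, success, message)`.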
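The rule-generation step reduces to two ratios over itemset supports: for a rule A → C, confidence = supp(A ∪ C) / supp(A) and lift = confidence / supp(C). The sketch below is a hypothetical pure-Python stand-in for `association_rules(..., metric="confidence")`: it enumerates every antecedent/consequent split of each frequent itemset, keeps rules at or above `min_confidence`, and sorts by lift like the notebook. It assumes `supports` contains all frequent itemsets, so every subset's support is available (downward closure).

```python
from itertools import combinations


def rules_from_itemsets(supports, min_confidence):
    """Build rules A -> C from {frozenset: support}; keep confidence >= min_confidence.

    Assumes supports holds every frequent itemset, so all subsets are present.
    """
    rules = []
    for itemset, support_ac in supports.items():
        if len(itemset) < 2:
            continue  # a rule needs a non-empty antecedent and consequent
        for size in range(1, len(itemset)):
            for combo in combinations(sorted(itemset), size):
                antecedent = frozenset(combo)
                consequent = itemset - antecedent
                confidence = support_ac / supports[antecedent]
                if confidence >= min_confidence:
                    rules.append({
                        "antecedents": sorted(antecedent),
                        "consequents": sorted(consequent),
                        "support": support_ac,
                        "confidence": confidence,
                        # lift > 1: A and C co-occur more often than if independent
                        "lift": confidence / supports[consequent],
                    })
    # Highest lift first, matching the notebook's sort order
    return sorted(rules, key=lambda r: r["lift"], reverse=True)


# Hypothetical supports: 75% of patients have each drug, 50% have both
supports = {
    frozenset({"drug_a"}): 0.75,
    frozenset({"drug_b"}): 0.75,
    frozenset({"drug_a", "drug_b"}): 0.5,
}
rules = rules_from_itemsets(supports, min_confidence=0.3)
```

Here both rules clear the 30% confidence threshold (≈ 0.67) yet have lift ≈ 0.89 < 1: confidence alone can look high when the consequent is simply common, and lift corrects for the consequent's baseline frequency.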


In [None]:
def process_single_cohort(job, item_type):
    """Process a single cohort for a specific item type with FP-Growth analysis."""
    cohort = job['cohort']
    age_band = job['age_band']
    event_year = job['event_year']
    local_path = job['local_path']
    
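    # Child of the 'cohort_fpgrowth' logger so records propagate to its file handler
    # (forked worker processes inherit that handler; spawned ones do not)
    cohort_logger = logging.getLogger(f"cohort_fpgrowth.{cohort}_{age_band}_{event_year}_{item_type}")
    cohort_logger.setLevel(logging.INFO)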
    
    try:
        start_time = time.time()
        cohort_logger.info(f"Processing {cohort}/{age_band}/{event_year} - {item_type}")
        
        # Extract items based on type
        con = get_duckdb_connection(logger=cohort_logger)
        
        if item_type == 'drug_name':
            query = f"""
            SELECT mi_person_key, drug_name as item
            FROM read_parquet('{local_path}')
            WHERE drug_name IS NOT NULL AND drug_name != '' AND event_type = 'pharmacy'
            """
        elif item_type == 'icd_code':
            query = f"""
            WITH all_icds AS (
                SELECT mi_person_key, primary_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE primary_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, two_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE two_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, three_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE three_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, four_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE four_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
                UNION ALL
                SELECT mi_person_key, five_icd_diagnosis_code as icd FROM read_parquet('{local_path}') 
                WHERE five_icd_diagnosis_code IS NOT NULL AND event_type = 'medical'
            )
            SELECT mi_person_key, icd as item FROM all_icds WHERE icd != ''
            """
        elif item_type == 'cpt_code':
            query = f"""
            SELECT mi_person_key, procedure_code as item
            FROM read_parquet('{local_path}')
            WHERE procedure_code IS NOT NULL AND procedure_code != '' AND event_type = 'medical'
            """
        else:
            raise ValueError(f"Unknown item_type: {item_type}")
        
        df = con.execute(query).df()
        con.close()
        
        if df.empty:
            cohort_logger.warning(f"No {item_type} data for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No data")
        
        # Create transactions
        cohort_logger.info(f"Building transactions from {len(df)} rows...")
        transactions = (
            df.groupby('mi_person_key')['item']
            .apply(lambda x: sorted(set(x.tolist())))
            .tolist()
        )
        
        if not transactions:
            cohort_logger.warning(f"No valid transactions for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No transactions")
        
        # Encode transactions
        cohort_logger.info(f"Encoding {len(transactions)} transactions...")
        te = TransactionEncoder()
        te_ary = te.fit(transactions).transform(transactions)
        df_encoded = pd.DataFrame(te_ary, columns=te.columns_)
        
        # Run FP-Growth
        cohort_logger.info(f"Running FP-Growth (min_support={MIN_SUPPORT})...")
        itemsets = fpgrowth(df_encoded, min_support=MIN_SUPPORT, use_colnames=True)
        itemsets = itemsets.sort_values('support', ascending=False).reset_index(drop=True)
        
        if itemsets.empty:
            cohort_logger.warning(f"No itemsets found for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, item_type, False, "No itemsets")
        
        # Generate rules
        cohort_logger.info(f"Generating association rules...")
        rules = association_rules(itemsets, metric="confidence", min_threshold=MIN_CONFIDENCE)
        rules = rules.sort_values('lift', ascending=False).reset_index(drop=True)
        
        # Convert frozensets to sorted lists so the JSON output is deterministic
        itemsets_json = itemsets.copy()
        itemsets_json['itemsets'] = itemsets_json['itemsets'].apply(sorted)
        
        rules_json = rules.copy() if not rules.empty else pd.DataFrame()
        if not rules_json.empty:
            rules_json['antecedents'] = rules_json['antecedents'].apply(sorted)
            rules_json['consequents'] = rules_json['consequents'].apply(sorted)
        
        # Save to S3
        s3_base = f"{S3_OUTPUT_BASE}/{item_type}/cohort_name={cohort}/age_band={age_band}/event_year={event_year}"
        
        itemsets_path = f"{s3_base}/itemsets.json"
        save_to_s3_json(itemsets_json.to_dict(orient='records'), itemsets_path)
        
        if not rules_json.empty:
            rules_path = f"{s3_base}/rules.json"
            save_to_s3_json(rules_json.to_dict(orient='records'), rules_path)
        
        summary = {
            'timestamp': datetime.now().isoformat(),
            'cohort': cohort, 'age_band': age_band, 'event_year': event_year,
            'item_type': item_type,
            'total_patients': len(transactions),
            'total_itemsets': len(itemsets),
            'total_rules': len(rules),
            'min_support': MIN_SUPPORT,
            'min_confidence': MIN_CONFIDENCE
        }
        summary_path = f"{s3_base}/summary.json"
        save_to_s3_json(summary, summary_path)
        
        elapsed = time.time() - start_time
        cohort_logger.info(f"✓ Completed in {elapsed:.1f}s")
        
        return (cohort, age_band, event_year, item_type, True, f"{len(itemsets)} itemsets, {len(rules)} rules")
        
    except Exception as e:
        cohort_logger.error(f"Error: {e}")
        return (cohort, age_band, event_year, item_type, False, str(e))

print("✓ Cohort processing function defined")


## Step 3: Process Cohorts in Parallel

Submit every (cohort slice, item type) pair to a `ProcessPoolExecutor` and collect each job's status tuple as it completes.
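The fan-out reduces to a small pattern: submit every (job, item type) pair, map each future back to its parameters, and count outcomes as they complete. This sketch uses a hypothetical `fake_worker` and `ThreadPoolExecutor` only so it runs anywhere without pickling; the notebook uses `ProcessPoolExecutor` because FP-Growth is largely CPU-bound Python work that threads would serialize on the GIL.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_worker(job, item_type):
    """Hypothetical stand-in for process_single_cohort: 'fails' every cpt_code job."""
    ok = item_type != "cpt_code"
    return (job["cohort"], job["age_band"], job["event_year"], item_type,
            ok, "ok" if ok else "No itemsets")


def run_all(cohort_jobs, item_types, worker, max_workers=4):
    """Fan out every (job, item_type) pair and tally outcomes as futures complete."""
    all_jobs = [(job, it) for job in cohort_jobs for it in item_types]
    results, completed, failed = [], 0, 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its inputs so failures can be attributed
        future_to_params = {executor.submit(worker, job, it): (job, it) for job, it in all_jobs}
        for future in as_completed(future_to_params):
            job, item_type = future_to_params[future]
            try:
                *_, success, message = future.result()
            except Exception as exc:  # worker raised instead of returning a status tuple
                success, message = False, str(exc)
            if success:
                completed += 1
            else:
                failed += 1
            results.append({"cohort": job["cohort"], "age_band": job["age_band"],
                            "event_year": job["event_year"], "item_type": item_type,
                            "success": success, "message": message})
    return results, completed, failed


jobs = [
    {"cohort": "opioid_ed", "age_band": "65-74", "event_year": "2020"},
    {"cohort": "ed_non_opioid", "age_band": "65-74", "event_year": "2020"},
]
results, completed, failed = run_all(jobs, ["drug_name", "icd_code", "cpt_code"], fake_worker)
```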


In [None]:
print("="*80)
print("COHORT FPGROWTH ANALYSIS - START")
print("="*80)
print(f"Cohorts: {len(cohort_jobs)} combinations")
print(f"Item types: {ITEM_TYPES}")
print(f"Total jobs: {len(cohort_jobs) * len(ITEM_TYPES)}")
print(f"Max workers: {MAX_WORKERS}")
print(f"Detailed progress → Check log file")
print()

logger.info(f"\n{'='*80}")
logger.info(f"COHORT FPGROWTH ANALYSIS - START")
logger.info(f"{'='*80}")
logger.info(f"Cohorts: {len(cohort_jobs)} combinations")
logger.info(f"Item types: {ITEM_TYPES}")
logger.info(f"Total jobs: {len(cohort_jobs) * len(ITEM_TYPES)}")

start_time = time.time()
results = []
completed = 0
failed = 0

# Create all combinations of cohorts and item types
all_jobs = [(job, item_type) for job in cohort_jobs for item_type in ITEM_TYPES]
total_jobs = len(all_jobs)

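# Explicit 'fork' context: workers inherit the notebook's globals, functions, and logging
# handlers (Python 3.14 changed the Linux default start method to 'forkserver')
import multiprocessing
with ProcessPoolExecutor(max_workers=MAX_WORKERS, mp_context=multiprocessing.get_context("fork")) as executor: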
    # Submit all jobs
    future_to_params = {executor.submit(process_single_cohort, job, item_type): (job, item_type) 
                        for job, item_type in all_jobs}
    
    # Process results as they complete
    for future in as_completed(future_to_params):
        job, item_type = future_to_params[future]
        try:
            cohort, age_band, event_year, item_type, success, message = future.result()
            results.append({
                'cohort': cohort,
                'age_band': age_band,
                'event_year': event_year,
                'item_type': item_type,
                'success': success,
                'message': message
            })
            
            if success:
                completed += 1
                logger.info(f"[{completed + failed}/{total_jobs}] ✓ {cohort}/{age_band}/{event_year}/{item_type}: {message}")
            else:
                failed += 1
                logger.warning(f"[{completed + failed}/{total_jobs}] ✗ {cohort}/{age_band}/{event_year}/{item_type}: {message}")
                print(f"⚠ Failed: {cohort}/{age_band}/{event_year}/{item_type}")
            
            # Print progress every 10 finished jobs (successes + failures) and at the end
            done = completed + failed
            if done % 10 == 0 or done == total_jobs:
                print(f"Progress: {done}/{total_jobs} done ({done/total_jobs*100:.1f}%), {completed} succeeded, {failed} failed")
                
        except Exception as e:
            failed += 1
            logger.error(f"[{completed + failed}/{total_jobs}] ✗ {job['cohort']}/{job['age_band']}/{job['event_year']}/{item_type}: {e}")
            print(f"⚠ Error: {job['cohort']}/{job['age_band']}/{job['event_year']}/{item_type}")
            results.append({
                'cohort': job['cohort'],
                'age_band': job['age_band'],
                'event_year': job['event_year'],
                'item_type': item_type,
                'success': False,
                'message': str(e)
            })

elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"COHORT FPGROWTH ANALYSIS - COMPLETE")
print(f"{'='*80}")
print(f"  Total jobs: {total_jobs}")
print(f"  Successful: {completed}")
print(f"  Failed: {failed}")
print(f"  Success rate: {completed/total_jobs*100:.1f}%")
print(f"  Total time: {elapsed:.1f}s ({elapsed/60:.1f}min)")
print(f"  Avg time per job: {elapsed/total_jobs:.1f}s")


## Step 4: Analyze Results

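Tally outcomes by cohort and item type and list unsuccessful jobs. `success=False` covers both genuine errors and slices that simply yielded no data, no transactions, or no itemsets at `MIN_SUPPORT`, so check the `message` column before re-running anything.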
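A pure-Python sketch of the same tally, with a hypothetical `classify_failure` helper that separates empty slices from exceptions using the status messages the worker returns ("No data", "No transactions", "No itemsets"). The sample results below are made up.

```python
from collections import defaultdict

# Status messages the worker returns for slices that are empty rather than broken
EMPTY_SLICE_MESSAGES = {"No data", "No transactions", "No itemsets"}


def classify_failure(message):
    """Hypothetical helper: 'empty' for expected no-output slices, 'error' otherwise."""
    return "empty" if message in EMPTY_SLICE_MESSAGES else "error"


def success_summary(results, keys=("cohort", "item_type")):
    """Total, successful, and numeric success rate per group of result dicts."""
    counts = defaultdict(lambda: [0, 0])  # group -> [total, successful]
    for r in results:
        group = tuple(r[k] for k in keys)
        counts[group][0] += 1
        counts[group][1] += int(bool(r["success"]))
    return {g: {"total": t, "successful": s, "success_rate": s / t}
            for g, (t, s) in sorted(counts.items())}


results = [
    {"cohort": "opioid_ed", "item_type": "drug_name", "success": True, "message": "12 itemsets, 3 rules"},
    {"cohort": "opioid_ed", "item_type": "drug_name", "success": False, "message": "No itemsets"},
    {"cohort": "opioid_ed", "item_type": "cpt_code", "success": False, "message": "ValueError: simulated"},
]
summary = success_summary(results)
```

Keeping `success_rate` numeric (the notebook formats it as a string inside `agg`) keeps the column sortable; format it only when printing.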


In [None]:
# Convert results to DataFrame for analysis
results_df = pd.DataFrame(results)

print("\n📊 Results by Cohort and Item Type:")
summary = results_df.groupby(['cohort', 'item_type'])['success'].agg([
    ('total', 'count'), 
    ('successful', 'sum'),
    ('success_rate', lambda x: f"{x.mean()*100:.1f}%")
])
print(summary)

print("\n📊 Results by Item Type:")
item_summary = results_df.groupby('item_type')['success'].agg([
    ('total', 'count'), 
    ('successful', 'sum'),
    ('success_rate', lambda x: f"{x.mean()*100:.1f}%")
])
print(item_summary)

print("\n❌ Failed Jobs:")
failed_df = results_df[~results_df['success']]
if not failed_df.empty:
    print(failed_df[['cohort', 'age_band', 'event_year', 'item_type', 'message']].to_string())
else:
    print("  None! All jobs completed successfully.")

print("\n✓ Successful Jobs Sample:")
success_df = results_df[results_df['success']]
if not success_df.empty:
    print(success_df[['cohort', 'age_band', 'event_year', 'item_type', 'message']].head(15))
else:
    print("  No successful jobs.")


## Summary and Next Steps
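To compare cohorts, one option is to line up itemset supports from two `itemsets.json` files for the same item type, age band, and year. A minimal sketch, assuming the records carry the `support` and `itemsets` keys written by `process_single_cohort`; `compare_cohort_itemsets` and the sample records are hypothetical. Caveat: an itemset missing from one file only fell below `MIN_SUPPORT` there, so the 0.0 filled in is a floor, not an observed support.

```python
def compare_cohort_itemsets(records_a, records_b):
    """Hypothetical helper: align itemset supports from two itemsets.json record lists.

    Itemsets missing from one side are filled with 0.0 there; they only fell below
    MIN_SUPPORT, so that value is a floor rather than an observed support.
    """
    support_a = {frozenset(r["itemsets"]): r["support"] for r in records_a}
    support_b = {frozenset(r["itemsets"]): r["support"] for r in records_b}
    rows = []
    for itemset in support_a.keys() | support_b.keys():
        a = support_a.get(itemset, 0.0)
        b = support_b.get(itemset, 0.0)
        rows.append({"itemset": sorted(itemset), "support_a": a, "support_b": b, "diff": a - b})
    # Largest absolute difference first; ties broken by item names for a stable order
    return sorted(rows, key=lambda row: (-abs(row["diff"]), row["itemset"]))


# Hypothetical records for one item type / age band / year
opioid_ed = [{"support": 0.4, "itemsets": ["drug_a"]},
             {"support": 0.1, "itemsets": ["drug_a", "drug_b"]}]
ed_non_opioid = [{"support": 0.1, "itemsets": ["drug_a"]},
                 {"support": 0.2, "itemsets": ["drug_c"]}]
rows = compare_cohort_itemsets(opioid_ed, ed_non_opioid)
```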


In [None]:
print("="*80)
print("COHORT FPGROWTH ANALYSIS - SUMMARY")
print("="*80)

print(f"\n📊 Processing Statistics:")
print(f"  Cohort combinations: {len(cohort_jobs)}")
print(f"  Item types: {len(ITEM_TYPES)} (drug_name, icd_code, cpt_code)")
print(f"  Total jobs: {total_jobs}")
print(f"  Successfully processed: {completed}")
print(f"  Failed: {failed}")
print(f"  Success rate: {completed/total_jobs*100:.1f}%")
print(f"  Processing time: {elapsed:.1f}s ({elapsed/60:.1f}min)")

print(f"\n🔍 FP-Growth Configuration:")
print(f"  Min support: {MIN_SUPPORT} ({MIN_SUPPORT*100:.1f}%)")
print(f"  Min confidence: {MIN_CONFIDENCE} ({MIN_CONFIDENCE*100:.1f}%)")
print(f"  Parallel workers: {MAX_WORKERS}")

print(f"\n💾 Output Location:")
print(f"  S3 Base: {S3_OUTPUT_BASE}")
print(f"  Structure: <item_type>/cohort_name=<name>/age_band=<band>/event_year=<year>/")
print(f"  Item types:")
print(f"    - drug_name/ (pharmacy events)")
print(f"    - icd_code/ (diagnosis codes)")
print(f"    - cpt_code/ (procedure codes)")
print(f"  Files per cohort slice:")
print(f"    - itemsets.json (frequent itemsets)")
print(f"    - rules.json (association rules; written only if any rule passes MIN_CONFIDENCE)")
print(f"    - summary.json (metadata)")

print(f"\n🎯 Next Steps:")
print(f"  1. Load cohort-specific itemsets for bupaR process mining")
print(f"  2. Compare patterns between OPIOID_ED and ED_NON_OPIOID cohorts across item types")
print(f"  3. Use association rules for pathway analysis")
print(f"  4. Filter features for cohort-specific CatBoost models")
print(f"  5. Create network visualizations for cohort-specific patterns")

print(f"\n📝 Example Usage:")
print(f"  # Load drug patterns for a specific cohort")
print(f"  from helpers_1997_13.s3_utils import load_from_s3_json")
print(f"  itemsets = load_from_s3_json('{S3_OUTPUT_BASE}/drug_name/cohort_name=opioid_ed/age_band=65-74/event_year=2020/itemsets.json')")
print(f"  # Load ICD patterns for same cohort")
print(f"  icd_itemsets = load_from_s3_json('{S3_OUTPUT_BASE}/icd_code/cohort_name=opioid_ed/age_band=65-74/event_year=2020/itemsets.json')")

print(f"\n✓ Analysis complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("="*80)
