# Cohort-Specific FPGrowth Feature Importance Analysis

## Overview

This notebook performs **cohort-specific FPGrowth analysis** to discover frequent drug patterns within each cohort (mined separately per cohort, age band, and event year). The results are used for:

1. **BupaR Process Mining**: Cohort-specific treatment pathways and sequences
2. **Comparative Analysis**: Understand differences between OPIOID_ED and ED_NON_OPIOID cohorts
3. **Feature Filtering**: Pre-filter features for cohort-specific CatBoost models

## Key Differences from Global Analysis

| Aspect | Global FPGrowth | Cohort FPGrowth |
|--------|-----------------|-----------------|
| **Scope** | All patients | Individual cohorts |
| **Purpose** | Universal ML features | Process mining patterns |
| **Support Threshold** | 0.005 (lower) | 0.05 (higher) |
| **Output** | Single encoding map | Multiple cohort-specific results |
| **Use Case** | CatBoost consistency | BupaR pathway analysis |

## Key Outputs

- **Cohort Drug Patterns**: Frequent drug combinations (itemsets) for each cohort slice
- **Association Rules**: Cohort-specific prescribing patterns
- **Network Visualizations**: Visual representation of drug associations per cohort (planned; not yet generated by this notebook)
- **Feature Manifests**: Per-slice `summary.json` metadata (patient, itemset, and rule counts plus the thresholds used)

## S3 Output Structure

```
s3://pgxdatalake/gold/fpgrowth/cohort/
├── cohort_name=opioid_ed/
│   ├── age_band=65-74/
│   │   └── event_year=2020/
│   │       ├── itemsets.json
│   │       ├── rules.json         (only written when rules are found)
│   │       ├── summary.json
│   │       └── drug_network.html  (planned; not yet written by this notebook)
│   └── ...
└── cohort_name=ed_non_opioid/
    └── (same structure)
```

---


## Setup and Imports


In [None]:
import os
import sys
import json
import pandas as pd
import numpy as np
from datetime import datetime
import logging
from pathlib import Path
import time
from concurrent.futures import ProcessPoolExecutor, as_completed

# MLxtend for FP-Growth
from mlxtend.frequent_patterns import fpgrowth, association_rules
from mlxtend.preprocessing import TransactionEncoder

# Project root
project_root = Path.cwd().parent if Path.cwd().name == '3_fpgrowth_analysis' else Path.cwd()
sys.path.insert(0, str(project_root))

# Project utilities
from helpers_1997_13.common_imports import s3_client, S3_BUCKET
from helpers_1997_13.duckdb_utils import get_duckdb_connection
from helpers_1997_13.s3_utils import save_to_s3_json, save_to_s3_parquet, get_cohort_parquet_path
from helpers_1997_13.fpgrowth_utils import run_fpgrowth_drug_token_with_fallback, convert_frozensets
from helpers_1997_13.visualization_utils import create_network_visualization
from helpers_1997_13.constants import AGE_BANDS, EVENT_YEARS

# Confirm the environment is wired up: where imports resolve from, and when this run started.
print(f"✓ Project root: {project_root}")
print("✓ All imports successful")
print(f"✓ Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


In [None]:
# FP-Growth parameters (higher threshold for cohort-specific patterns)
MIN_SUPPORT = 0.05  # 5% support threshold
# NOTE(review): MIN_CONFIDENCE is recorded in each summary.json, but
# process_single_cohort does not pass it to run_fpgrowth_drug_token_with_fallback.
# Confirm whether the helper applies its own confidence threshold.
MIN_CONFIDENCE = 0.3  # 30% confidence threshold
TOP_K = 30  # Top K itemsets to extract
TIMEOUT_SECONDS = 300  # passed as `timeout_seconds` to the FP-Growth helper

# Processing parameters
MAX_WORKERS = 4  # Parallel workers for processing multiple cohorts
COHORTS_TO_PROCESS = ['opioid_ed', 'ed_non_opioid']  # Can specify specific cohorts

# S3 output path
S3_OUTPUT_BASE = f"s3://{S3_BUCKET}/gold/fpgrowth/cohort"

# Local data path
LOCAL_DATA_PATH = project_root / "data" / "gold" / "cohorts_F1120"

# Create logger; only attach a handler once so re-running this cell does not
# duplicate every log line.
logger = logging.getLogger('cohort_fpgrowth')
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

print(f"✓ Min Support: {MIN_SUPPORT}")
print(f"✓ Min Confidence: {MIN_CONFIDENCE}")
print(f"✓ Max Workers: {MAX_WORKERS}")
print(f"✓ Cohorts: {COHORTS_TO_PROCESS}")
print(f"✓ S3 Output: {S3_OUTPUT_BASE}")
print(f"✓ Local Data: {LOCAL_DATA_PATH}")
print(f"✓ Local Data Exists: {LOCAL_DATA_PATH.exists()}")


## Step 1: Discover Available Cohorts

Scan local data laid out as `cohort_name=<name>/event_year=<year>/age_band=<band>/cohort.parquet` to find all available cohort combinations.


In [None]:
def discover_cohorts(local_data_path, cohort_filter=None):
    """
    Discover all available cohort combinations from local data.

    Expects a Hive-style partition layout::

        <local_data_path>/cohort_name=<c>/event_year=<y>/age_band=<a>/cohort.parquet

    Partition directories without a ``cohort.parquet`` file are skipped.

    Args:
        local_data_path: Root directory to scan (``pathlib.Path`` or ``str``).
            A missing directory yields no jobs.
        cohort_filter: Optional cohort name, or collection of cohort names, to
            keep. ``None`` or an empty collection keeps every cohort found.

    Returns:
        List of job dicts with keys ``'cohort'``, ``'age_band'``,
        ``'event_year'`` (all strings taken from the directory names) and
        ``'local_path'`` (string path to the parquet file). The list is sorted
        by cohort, event year, then age band, so repeated runs are reproducible.
    """
    def _partition_value(path):
        # "key=value" -> "value"; split on the first '=' only so values may contain '='.
        return path.name.split("=", 1)[1]

    if not cohort_filter:
        allowed = None
    elif isinstance(cohort_filter, str):
        # A bare string is one cohort name, not a container to substring-match against.
        allowed = {cohort_filter}
    else:
        allowed = set(cohort_filter)

    root = Path(local_data_path)
    cohort_jobs = []

    # Path.glob yields entries in filesystem order; sort for a deterministic job list.
    for cohort_dir in sorted(root.glob("cohort_name=*")):
        cohort_name = _partition_value(cohort_dir)
        if allowed is not None and cohort_name not in allowed:
            continue

        for year_dir in sorted(cohort_dir.glob("event_year=*")):
            event_year = _partition_value(year_dir)

            for age_dir in sorted(year_dir.glob("age_band=*")):
                cohort_file = age_dir / "cohort.parquet"
                if not cohort_file.exists():
                    continue
                cohort_jobs.append({
                    'cohort': cohort_name,
                    'age_band': _partition_value(age_dir),
                    'event_year': event_year,
                    'local_path': str(cohort_file),
                })

    return cohort_jobs

from collections import Counter

# Discover available cohorts
cohort_jobs = discover_cohorts(LOCAL_DATA_PATH, cohort_filter=COHORTS_TO_PROCESS)

print("\n📊 Discovered Cohorts:")
print(f"  Total combinations: {len(cohort_jobs)}")

# Number of (event_year, age_band) slices per cohort, in first-seen order.
cohort_counts = Counter(job['cohort'] for job in cohort_jobs)
for cohort, count in cohort_counts.items():
    print(f"  {cohort}: {count} combinations")

print("\n  Sample jobs:")
for job in cohort_jobs[:5]:
    print(f"    {job['cohort']}/{job['age_band']}/{job['event_year']}")


## Step 2: Define Cohort Processing Function

Create a function to process a single cohort with FP-Growth.


In [None]:
def _build_transactions(df):
    """Group pharmacy rows into one sorted, de-duplicated drug-token list per patient.

    Each drug name is stripped, lower-cased and prefixed with ``drug_``. Null or
    blank names are dropped, and patients left with no tokens are omitted.
    """
    grouped = df.groupby("mi_person_key")["drug_name"].agg(
        lambda rows: sorted({
            f"drug_{str(d).strip().lower()}"
            for d in rows if pd.notnull(d) and str(d).strip()
        })
    )
    return [tokens for tokens in grouped.tolist() if len(tokens) > 0]


def _sets_to_sorted_lists(frame, columns):
    """Return a copy of *frame* whose set-valued *columns* are sorted lists (JSON-friendly)."""
    out = frame.copy()
    for col in columns:
        # frozensets are not JSON-serializable and iterate in arbitrary order;
        # sorting makes the written files reproducible.
        out[col] = out[col].apply(lambda items: sorted(items, key=str))
    return out


def process_single_cohort(job):
    """
    Run FP-Growth for one cohort/age_band/event_year slice and write the results to S3.

    Args:
        job: dict with keys ``'cohort'``, ``'age_band'``, ``'event_year'`` and
            ``'local_path'`` (path to the slice's parquet file), as produced by
            ``discover_cohorts``.

    Returns:
        Tuple ``(cohort, age_band, event_year, success, message)``. All failures,
        including unexpected exceptions, are reported through this tuple rather
        than raised, so a parallel driver can keep processing other slices.

    Side effects:
        Writes ``itemsets.json``, ``summary.json`` and (when rules exist)
        ``rules.json`` under
        ``S3_OUTPUT_BASE/cohort_name=<c>/age_band=<a>/event_year=<y>/``.
    """
    cohort = job['cohort']
    age_band = job['age_band']
    event_year = job['event_year']
    local_path = job['local_path']

    # One logger per slice so each log line identifies the job that produced it.
    cohort_logger = logging.getLogger(f"cohort_{cohort}_{age_band}_{event_year}")
    cohort_logger.setLevel(logging.INFO)

    try:
        start_time = time.time()
        cohort_logger.info(f"Processing {cohort}/{age_band}/{event_year}")

        # The path is embedded in the query as a SQL string literal; double any
        # single quotes so such paths cannot break the statement.
        sql_path = str(local_path).replace("'", "''")
        con = get_duckdb_connection(logger=cohort_logger)
        try:
            df = con.execute(f"""
                SELECT mi_person_key, drug_name
                FROM read_parquet('{sql_path}')
                WHERE drug_name IS NOT NULL 
                  AND drug_name != ''
                  AND event_type = 'PHARMACY'
            """).df()
        finally:
            # Release the connection even when the query raises.
            con.close()

        if df.empty:
            cohort_logger.warning(f"No pharmacy data for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, False, "No data")

        cohort_logger.info("Building transactions...")
        transactions = _build_transactions(df)

        if not transactions:
            cohort_logger.warning(f"No valid transactions for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, False, "No transactions")

        cohort_logger.info(f"Running FP-Growth with {len(transactions)} transactions...")

        # NOTE(review): MIN_CONFIDENCE is recorded in summary.json below but is not
        # passed here. Confirm whether the helper applies its own confidence threshold.
        itemsets, rules = run_fpgrowth_drug_token_with_fallback(
            transactions=transactions,
            min_support_threshold=MIN_SUPPORT,
            timeout_seconds=TIMEOUT_SECONDS,
            top_k=TOP_K,
            logger=cohort_logger
        )

        if itemsets is None or itemsets.empty:
            cohort_logger.warning(f"FP-Growth returned no itemsets for {cohort}/{age_band}/{event_year}")
            return (cohort, age_band, event_year, False, "No itemsets")

        n_rules = len(rules) if rules is not None else 0

        itemsets_json = _sets_to_sorted_lists(itemsets, ['itemsets'])
        rules_json = pd.DataFrame()
        if rules is not None and not rules.empty:
            rules_json = _sets_to_sorted_lists(rules, ['antecedents', 'consequents'])

        s3_base = f"{S3_OUTPUT_BASE}/cohort_name={cohort}/age_band={age_band}/event_year={event_year}"

        save_to_s3_json(itemsets_json.to_dict(orient='records'), f"{s3_base}/itemsets.json")

        # rules.json is only written when at least one rule was produced.
        if not rules_json.empty:
            save_to_s3_json(rules_json.to_dict(orient='records'), f"{s3_base}/rules.json")

        summary = {
            'timestamp': datetime.now().isoformat(),
            'cohort': cohort,
            'age_band': age_band,
            'event_year': event_year,
            'total_patients': len(transactions),
            'total_itemsets': len(itemsets),
            'total_rules': n_rules,
            'min_support': MIN_SUPPORT,
            'min_confidence': MIN_CONFIDENCE
        }
        save_to_s3_json(summary, f"{s3_base}/summary.json")

        elapsed = time.time() - start_time
        cohort_logger.info(f"✓ Completed {cohort}/{age_band}/{event_year} in {elapsed:.1f}s")

        return (cohort, age_band, event_year, True, f"{len(itemsets)} itemsets, {n_rules} rules")

    except Exception as e:
        # Report rather than raise so one bad slice does not abort the whole run;
        # .exception() also records the traceback for debugging.
        cohort_logger.exception(f"Error processing {cohort}/{age_band}/{event_year}: {e}")
        return (cohort, age_band, event_year, False, str(e))


print("✓ Cohort processing function defined")


## Step 3: Process Cohorts in Parallel

Run FP-Growth for all cohort combinations using parallel processing.


In [None]:
logger.info(f"Processing {len(cohort_jobs)} cohorts with {MAX_WORKERS} workers...")
start_time = time.time()

results = []    # one dict per job: cohort, age_band, event_year, success, message
completed = 0   # jobs that reported success
failed = 0      # jobs that reported failure or whose worker raised
total_jobs = len(cohort_jobs)

# NOTE(review): ProcessPoolExecutor pickles process_single_cohort by reference to this
# notebook's __main__ module. That works with the "fork" start method (Linux default).
# Under "spawn" (macOS/Windows default), workers may fail to find the function or the
# module-level constants it reads. Confirm on the target platform.
with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    future_to_job = {executor.submit(process_single_cohort, job): job for job in cohort_jobs}

    for future in as_completed(future_to_job):
        job = future_to_job[future]
        try:
            cohort, age_band, event_year, success, message = future.result()
        except Exception as e:
            # process_single_cohort returns its own failures as tuples, so reaching
            # here means the worker itself broke (e.g. pickling error, dead process).
            failed += 1
            logger.error(f"[{completed + failed}/{total_jobs}] ✗ {job['cohort']}/{job['age_band']}/{job['event_year']}: {e}")
            results.append({
                'cohort': job['cohort'],
                'age_band': job['age_band'],
                'event_year': job['event_year'],
                'success': False,
                'message': str(e)
            })
            continue

        results.append({
            'cohort': cohort,
            'age_band': age_band,
            'event_year': event_year,
            'success': success,
            'message': message
        })

        if success:
            completed += 1
            logger.info(f"[{completed + failed}/{total_jobs}] ✓ {cohort}/{age_band}/{event_year}: {message}")
        else:
            failed += 1
            logger.warning(f"[{completed + failed}/{total_jobs}] ✗ {cohort}/{age_band}/{event_year}: {message}")

elapsed = time.time() - start_time

# Guard the ratios: with no discovered jobs, total_jobs is 0.
success_rate = f"{completed / total_jobs * 100:.1f}%" if total_jobs else "n/a (no jobs)"
avg_time = f"{elapsed / total_jobs:.1f}s" if total_jobs else "n/a (no jobs)"

print("\n📊 Processing Complete:")
print(f"  Total jobs: {total_jobs}")
print(f"  Successful: {completed}")
print(f"  Failed: {failed}")
print(f"  Success rate: {success_rate}")
print(f"  Total time: {elapsed:.1f}s")
print(f"  Avg time per cohort: {avg_time}")


## Step 4: Analyze Results

Review processing results and identify any issues.


In [None]:
# Convert results to DataFrame for analysis
results_df = pd.DataFrame(results)

print("\nüìä Results by Cohort:")
print(results_df.groupby('cohort')['success'].agg(['count', 'sum', lambda x: f"{x.mean()*100:.1f}%"]).rename(columns={'sum': 'successful', '<lambda_0>': 'success_rate'}))

print("\n‚ùå Failed Jobs:")
failed_df = results_df[~results_df['success']]
if not failed_df.empty:
    print(failed_df[['cohort', 'age_band', 'event_year', 'message']])
else:
    print("  None! All jobs completed successfully.")

print("\n‚úì Successful Jobs Sample:")
success_df = results_df[results_df['success']]
if not success_df.empty:
    print(success_df[['cohort', 'age_band', 'event_year', 'message']].head(10))
else:
    print("  No successful jobs.")


## Summary and Next Steps


In [None]:
# Guard the ratio: with no discovered jobs, len(cohort_jobs) is 0.
total_jobs = len(cohort_jobs)
success_rate = f"{completed / total_jobs * 100:.1f}%" if total_jobs else "n/a (no jobs)"

print("=" * 80)
print("COHORT FPGROWTH ANALYSIS - SUMMARY")
print("=" * 80)

print("\n📊 Processing Statistics:")
print(f"  Total cohort combinations: {total_jobs}")
print(f"  Successfully processed: {completed}")
print(f"  Failed: {failed}")
print(f"  Success rate: {success_rate}")
print(f"  Processing time: {elapsed:.1f}s ({elapsed/60:.1f}min)")

print("\n🔍 FP-Growth Configuration:")
print(f"  Min support: {MIN_SUPPORT} ({MIN_SUPPORT*100:.1f}%)")
print(f"  Min confidence: {MIN_CONFIDENCE} ({MIN_CONFIDENCE*100:.1f}%)")
print(f"  Top K itemsets: {TOP_K}")
print(f"  Parallel workers: {MAX_WORKERS}")

print("\n💾 Output Location:")
print(f"  S3 Base: {S3_OUTPUT_BASE}")
print("  Structure: cohort_name=<name>/age_band=<band>/event_year=<year>/")
print("  Files per cohort:")
print("    - itemsets.json (frequent drug combinations)")
print("    - rules.json (association rules; only written when rules are found)")
print("    - summary.json (metadata)")

print("\n🎯 Next Steps:")
print("  1. Load cohort-specific itemsets for BupaR process mining")
print("  2. Compare patterns between OPIOID_ED and ED_NON_OPIOID cohorts")
print("  3. Use association rules for pathway analysis")
print("  4. Filter features for cohort-specific CatBoost models")
print("  5. Create network visualizations for cohort-specific patterns")

print("\n📝 Example Usage:")
print("  # Load results for a specific cohort")
print("  from helpers_1997_13.s3_utils import load_from_s3_json")
print(f"  itemsets = load_from_s3_json('{S3_OUTPUT_BASE}/cohort_name=opioid_ed/age_band=65-74/event_year=2020/itemsets.json')")

print(f"\n✓ Analysis complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 80)
