# Feature Importance with Monte Carlo Cross-Validation

**Purpose:** Calculate scaled feature importance using CatBoost and Random Forest  
**Method:** Normalized feature importance scaled by MC-CV Recall scores  
**Based on:** [R Example](https://github.com/Jerome3590/phts/blob/main/graft-loss/feature_importance/replicate_20_features_MC_CV.R)  
**Updated:** November 2025  
**Hardware:** Optimized for EC2 (32 cores, 1TB RAM)  
**Validation:** Proper evaluation on unseen test data

## Key Features

‚úÖ **Monte Carlo Cross-Validation** ‚Äì up to 1000 random train/test splits (100-split runs used for faster iteration)  
‚úÖ **Stratified Sampling** - Maintains target distribution  
‚úÖ **Parallel Processing** - Fast execution with furrr/future (‚âà30 workers)  
‚úÖ **95% Confidence Intervals** - Narrow, precise estimates (tighter with more splits)  
‚úÖ **Multiple Models** - CatBoost (R) and Random Forest (R)  

## Methodology

This notebook implements the feature selection methodology:

1. Load cohort data from parquet files (same as FP-Growth notebook)
2. Create patient-level features (one-hot encoding of items)
3. For each model type:
   - Create 100‚Äì1000 stratified train/test splits
   - Train model on training set
   - Evaluate Recall on unseen test set
   - Extract feature importance
   - Aggregate results across splits
4. Normalize and scale feature importance by MC-CV Recall
5. Aggregate across models
6. Extract top features

## Expected Runtime

- **100 splits (current default):**
  - Local (4 cores): ~2‚Äì4 hours
  - Workstation (16 cores): ~1‚Äì2 hours
  - EC2 (32 cores, 1TB RAM): ~1‚Äì2 hours ‚úÖ **RECOMMENDED FOR DEVELOPMENT**
- **1000 splits (extended / publication-level):**
  - Local (4 cores): 8‚Äì12+ hours
  - Workstation (16 cores): ~8‚Äì16 hours
  - EC2 (32 cores, 1TB RAM): ~10‚Äì20 hours ‚úÖ **RECOMMENDED FOR FINAL RESULTS**


## 1. Setup and Configuration

Load required packages and configure parallel processing.

**üìñ Documentation:** See [Feature Importance README](/docs/README_feature_importance.md) for detailed documentation, usage examples, and troubleshooting.


In [1]:
# Check R version
R.version.string

# Load required packages
suppressPackageStartupMessages({
  library(here)
  library(dplyr)
  library(readr)
  library(tidyr)
  library(tibble)
  library(purrr)
  library(catboost)
  library(randomForest)
  library(rsample)    # For MC-CV
  library(furrr)      # For parallel processing
  library(future)     # For parallel backend
  library(progressr)  # For progress bars
  library(duckdb)     # For loading parquet files
  library(DBI)        # Database interface for DuckDB
})

cat("‚úì All packages loaded successfully\n")

# ============================================================
# SOURCE HELPER FUNCTIONS (Modular R Scripts)
# ============================================================
helpers_dir <- here("helpers_13_1997")
if (!dir.exists(helpers_dir)) {
  stop(sprintf("Helpers directory not found: %s", helpers_dir))
}

# Source constants first (needed by other scripts)
source(file.path(helpers_dir, "constants.R"))

# Source all helper scripts
source(file.path(helpers_dir, "logging_utils.R"))
source(file.path(helpers_dir, "metrics.R"))
source(file.path(helpers_dir, "model_helpers.R"))
source(file.path(helpers_dir, "mc_cv_helpers.R"))
source(file.path(helpers_dir, "run_cohort_analysis.R"))

cat("‚úì All helper functions loaded from R scripts\n")
cat(sprintf("‚úì Age bands loaded from constants: %s\n", paste(AGE_BANDS, collapse = ", ")))
cat(sprintf("‚úì Cohorts loaded from constants: %s\n", paste(COHORT_NAMES, collapse = ", ")))


‚úì All packages loaded successfully
‚úì All helper functions loaded from R scripts
‚úì Age bands loaded from constants: 0-12, 13-24, 25-44, 45-54, 55-64, 65-74, 75-84, 85-94, 95-114
‚úì Cohorts loaded from constants: opioid_ed, non_opioid_ed


In [None]:
# ============================================================
# DEBUG/TEST MODE - Quick testing before full run
# ============================================================
# Set DEBUG_MODE = TRUE for quick testing (5 splits, ~2-5 min)
# Set DEBUG_MODE = FALSE for full analysis (100 splits, ~1-2 hours on EC2)

DEBUG_MODE <- FALSE  # Change to TRUE for quick test

if (DEBUG_MODE) {
  cat("\n")
  cat("‚ïî‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïó\n")
  cat("‚ïë                    üîç DEBUG MODE ENABLED                       ‚ïë\n")
  cat("‚ïö‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïù\n")
  cat("\n")
  cat("Quick test configuration:\n")
  cat("  ‚Ä¢ MC-CV Splits: 5 (instead of 100)\n")
  cat("  ‚Ä¢ Expected time: 2-5 minutes\n")
  cat("  ‚Ä¢ Purpose: Verify everything works before full run\n")
  cat("\n")
  cat("To run full analysis, set DEBUG_MODE = FALSE\n")
  cat("\n")
}

# Configuration
# NOTE: Target definition differs by cohort (already baked into cohort data via is_target_case):
#   - "opioid_ed": Target cases = patients with F1120/opioid ICD codes (any of 10 ICD columns)
#   - "non_opioid_ed": Target cases = patients with HCG ED visits (P51/O11/P33) WITHOUT opioid codes
# Controls (is_target_case=0) are sampled to maintain 5:1 ratio in both cohorts

# Configuration
# NOTE: AGE_BANDS and COHORT_NAMES are loaded from constants.R (sourced in Cell 2)
# To override, uncomment and modify:
# AGE_BANDS <- c("25-44")  # Single age-band for testing
# COHORT_NAMES <- c("opioid_ed")  # Single cohort for testing

EVENT_YEAR <- 2016          # Change as needed

# For backward compatibility, set AGE_BAND to first value (used in single-cohort execution)
AGE_BAND <- if (length(AGE_BANDS) == 1) AGE_BANDS else AGE_BANDS[1]

# For backward compatibility, also set COHORT_NAME (will be overridden in function)
COHORT_NAME <- COHORT_NAMES[1]

N_SPLITS <- if (DEBUG_MODE) 5 else 200  # MC-CV splits (5 for debug, 100 for development, 1000 for production)
TEST_SIZE <- 0.2             # Test set proportion (20%)
TRAIN_PROP <- 1 - TEST_SIZE  # Training proportion (80%)

# Scaling metric for feature importance
# Options: "recall" (default) or "logloss"
# - Recall: Higher is better (0-1), good for imbalanced classes, focuses on finding positives
# - LogLoss: Lower is better, measures probability calibration, penalizes overconfident errors
#   (will be inverted: 1/logloss for scaling, so higher = better)
SCALING_METRIC <- "recall"  # Change to "logloss" if preferred

# Model parameters
MODEL_PARAMS <- list(
  catboost = list(
    iterations = 100,
    learning_rate = 0.1,
    depth = 6,
    verbose = 0L,  # Turn off CatBoost logging (0L = integer 0)
    random_seed = 42
  ),
  random_forest = list(
    ntree = 100,
    mtry = NULL,  # Will be set to sqrt(n_features)
    nodesize = 1,
    maxnodes = NULL
  )
)

# Set up parallel processing
# EC2 optimization: Use 30 out of 32 cores (leave 2 for system)
N_WORKERS <- as.integer(Sys.getenv("N_WORKERS", "0"))
if (N_WORKERS < 1) {
  # Auto-detect: use all cores minus 2 for system
  total_cores <- parallel::detectCores()
  N_WORKERS <- max(1, total_cores - 2)
  cat(sprintf("Auto-detected %d cores, using %d workers\n", total_cores, N_WORKERS))
}
cat(sprintf("Setting up parallel processing with %d workers...\n", N_WORKERS))

# Increase future.globals.maxSize for large data objects
# With 1TB RAM on EC2, we can handle large transfers
# Note: After refactoring, we use lightweight split_indices instead of large mc_splits
# Still need space for X_all, y_all, and data_full matrices
options(future.globals.maxSize = 97 * 1024^3)  # 97 GB limit
cat("Set future.globals.maxSize to 97 GB\n")

# Set up parallel processing plan
# Use N_WORKERS for MC-CV within each cohort
plan(multisession, workers = N_WORKERS)

# Output directory
output_dir <- here("outputs")
dir.create(output_dir, showWarnings = FALSE, recursive = TRUE)

# Clean up any incomplete/checkpoint files from previous runs
# Only keep files that match the expected pattern and are complete
if (dir.exists(output_dir)) {
  existing_files <- list.files(output_dir, pattern = "feature_importance.*\\.csv$", full.names = TRUE)
  if (length(existing_files) > 0) {
    cat(sprintf("Found %d existing output files in %s\n", length(existing_files), output_dir))
    cat("Note: These will be overwritten if processing runs. Only S3 files are used for idempotency checks.\n")
    # Optionally remove incomplete files (uncomment to enable):
    # incomplete_patterns <- c("_partial", "_checkpoint", "_tmp", "_incomplete")
    # incomplete_files <- existing_files[grepl(paste(incomplete_patterns, collapse = "|"), existing_files)]
    # if (length(incomplete_files) > 0) {
    #   cat(sprintf("Removing %d incomplete/checkpoint files...\n", length(incomplete_files)))
    #   file.remove(incomplete_files)
    # }
  }
}

cat("Output directory:", output_dir, "\n")
cat(sprintf("MC-CV Configuration: %d splits, %.0f/%.0f train/test split\n", 
            N_SPLITS, TRAIN_PROP * 100, TEST_SIZE * 100))
cat(sprintf("Cohorts to process: %s\n", paste(COHORT_NAMES, collapse = ", ")))
cat(sprintf("Running %d cohort(s) in parallel\n", length(COHORT_NAMES)))


Auto-detected 32 cores, using 30 workers
Setting up parallel processing with 30 workers...
Set future.globals.maxSize to 97 GB
Output directory: /home/pgx3874/pgx-analysis/3_feature_importance/outputs 
MC-CV Configuration: 200 splits, 80/20 train/test split
Cohorts to process: opioid_ed, non_opioid_ed
Running 2 cohort(s) in parallel


In [None]:
# ============================================================
# CHECK FOR COHORT FILES EXISTENCE
# ============================================================
# Function to check if cohort file exists before processing
check_cohort_file_exists <- function(cohort_name, age_band, event_year) {
  # Determine local data path
  local_data_path <- Sys.getenv("LOCAL_DATA_PATH", "/mnt/nvme/cohorts")
  if (!dir.exists(local_data_path)) {
    local_data_path <- Sys.getenv("LOCAL_DATA_PATH", "C:/Projects/pgx-analysis/data/gold/cohorts_F1120")
  }
  
  parquet_file <- file.path(local_data_path, 
                            paste0("cohort_name=", cohort_name),
                            paste0("event_year=", event_year),
                            paste0("age_band=", age_band),
                            "cohort.parquet")
  
  return(file.exists(parquet_file))
}

# ============================================================
# CHECK FOR ALREADY PROCESSED COMBINATIONS (IDEMPOTENCY)
# ============================================================
# Function to check if results already exist (S3 only - final destination)
# Note: Only checks S3, not local output_dir which may contain checkpoints or incomplete files
check_results_exist <- function(cohort_name, age_band, event_year) {
  # Check S3 (final destination) - don't check local output_dir which may have checkpoints
  s3_base <- "s3://pgxdatalake/gold/feature_importance"
  s3_path <- sprintf("%s/cohort_name=%s/age_band=%s/event_year=%d/%s_%s_%d_feature_importance_aggregated.csv",
                     s3_base, cohort_name, age_band, event_year,
                     cohort_name, age_band, event_year)
  
  # Use AWS CLI to check if file exists
  aws_cmd <- Sys.which("aws")
  if (aws_cmd == "") {
    # Try common AWS CLI paths
    aws_paths <- c(
      "/usr/local/bin/aws",
      "/usr/bin/aws",
      "/home/ec2-user/.local/bin/aws",
      "C:/Program Files/Amazon/AWSCLIV2/aws.exe"
    )
    for (path in aws_paths) {
      if (file.exists(path)) {
        aws_cmd <- path
        break
      }
    }
  }
  
  if (aws_cmd != "" && file.exists(aws_cmd)) {
    # Extract bucket and key from s3:// path
    s3_parts <- gsub("^s3://", "", s3_path)
    parts <- strsplit(s3_parts, "/", fixed = TRUE)[[1]]
    bucket <- parts[1]
    key <- paste(parts[-1], collapse = "/")
    
    # Check if object exists in S3 using s3api head-object (more reliable)
    # Suppress warnings for "file not found" (status 1 is expected when file doesn't exist)
    result <- tryCatch({
      # Use head-object: returns 0 if file exists, non-zero if not found
      # Suppress stderr to avoid warnings
      exit_code <- suppressWarnings(
        system2(aws_cmd, 
                c("s3api", "head-object", 
                  "--bucket", bucket,
                  "--key", key),
                stdout = FALSE,  # Don't need stdout
                stderr = FALSE,  # Suppress stderr to avoid warnings
                wait = TRUE)
      )
      
      # system2 returns exit code directly when wait=TRUE
      # 0 = success (file exists), non-zero = file doesn't exist or error
      exit_code == 0
    }, error = function(e) {
      # If head-object fails (e.g., not available), try ls as fallback
      tryCatch({
        # Suppress all warnings for ls command
        result_ls <- suppressWarnings(
          system2(aws_cmd, c("s3", "ls", s3_path), 
                 stdout = TRUE, stderr = FALSE)  # Suppress stderr
        )
        # If we get output, file exists
        !is.null(result_ls) && length(result_ls) > 0
      }, error = function(e2) {
        return(FALSE)
      })
    })
    
    if (is.logical(result) && result) {
      return(TRUE)
    }
  }
  
  return(FALSE)
}

# ============================================================
# RUN COHORTS AND AGE-BANDS IN PARALLEL
# ============================================================
# Create all combinations of cohort x age-band
combinations <- expand.grid(
  cohort = COHORT_NAMES,
  age_band = AGE_BANDS,
  stringsAsFactors = FALSE
)

# Check which combinations already exist
cat("\nChecking for already processed combinations...\n")
combinations$already_exists <- mapply(check_results_exist, 
                                       combinations$cohort, 
                                       combinations$age_band, 
                                       EVENT_YEAR)

# Check which cohort files exist
cat("Checking for cohort files...\n")
combinations$file_exists <- mapply(check_cohort_file_exists,
                                    combinations$cohort,
                                    combinations$age_band,
                                    EVENT_YEAR)

# Filter to only process new combinations that have cohort files
combinations_to_process <- combinations[!combinations$already_exists & combinations$file_exists, ]
combinations_skipped <- combinations[combinations$already_exists, ]
combinations_missing_files <- combinations[!combinations$already_exists & !combinations$file_exists, ]

cat(sprintf("Total combinations: %d\n", nrow(combinations)))
cat(sprintf("Already processed: %d\n", sum(combinations$already_exists)))
cat(sprintf("To process: %d\n", nrow(combinations_to_process)))

if (nrow(combinations_skipped) > 0) {
  cat("\nSkipping (already processed):\n")
  for (i in 1:nrow(combinations_skipped)) {
    cat(sprintf("  - %s / %s\n", combinations_skipped$cohort[i], combinations_skipped$age_band[i]))
  }
}

if (nrow(combinations_to_process) == 0) {
  cat("\n‚úì All combinations already processed! Nothing to do.\n")
  stop("All combinations already processed. Exiting.")
}

total_tasks <- nrow(combinations_to_process)
cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
cat("Running feature importance analysis\n")
cat(paste(rep("=", 80), collapse=""), "\n", sep="")
cat(sprintf("Cohorts: %s\n", paste(COHORT_NAMES, collapse = ", ")))
cat(sprintf("Age Bands: %s\n", paste(AGE_BANDS, collapse = ", ")))
cat(sprintf("Event Year: %d\n", EVENT_YEAR))
cat(sprintf("Total combinations: %d (cohorts √ó age-bands)\n", total_tasks))
cat(sprintf("Parallel workers: %d\n", min(total_tasks, parallel::detectCores() - 2)))
cat(sprintf("MC-CV workers per task: %d\n", N_WORKERS))
cat(paste(rep("=", 80), collapse=""), "\n\n", sep="")

# Use a separate plan for task-level parallelism
# Calculate worker allocation for nested parallelism:
# - Task-level: Run multiple cohort/age-band combinations in parallel
# - MC-CV level: Within each task, run MC-CV splits in parallel
# We need to divide available cores between these two levels

# Check available cores (respects mc.cores and other limits)
available_cores <- tryCatch({
  if (requireNamespace("parallelly", quietly = TRUE)) {
    parallelly::availableCores()
  } else {
    parallel::detectCores()
  }
}, error = function(e) {
  parallel::detectCores()
})

# Reserve 2 cores for system
usable_cores <- max(2, available_cores - 2)

# Calculate task-level workers (how many tasks run in parallel)
# Use min of: number of tasks, or reasonable limit based on cores
# For many tasks, we want to run multiple in parallel but not all at once
max_task_workers <- min(total_tasks, max(2, floor(usable_cores / 4)))  # Each task needs MC-CV workers too
task_workers <- min(total_tasks, max_task_workers)

# Calculate MC-CV workers per task
# Divide remaining cores among tasks: (usable_cores - task_workers) / task_workers
# But ensure each task gets at least 1 worker for MC-CV
mc_cv_workers_per_task <- max(1, floor((usable_cores - task_workers) / task_workers))

# Cap MC-CV workers to reasonable maximum (don't want too many per task)
mc_cv_workers_per_task <- min(mc_cv_workers_per_task, 10)  # Max 10 workers per task for MC-CV

cat(sprintf("\nNested Parallelism Configuration:\n"))
cat(sprintf("  Available cores: %d\n", available_cores))
cat(sprintf("  Usable cores (reserved 2 for system): %d\n", usable_cores))
cat(sprintf("  Task-level workers: %d (running %d tasks in parallel)\n", task_workers, min(total_tasks, task_workers)))
cat(sprintf("  MC-CV workers per task: %d\n", mc_cv_workers_per_task))
cat(sprintf("  Total worker allocation: %d (task-level) + %d √ó %d (MC-CV) = %d workers\n",
            task_workers, task_workers, mc_cv_workers_per_task,
            task_workers + (task_workers * mc_cv_workers_per_task)))
cat(sprintf("  (Note: Workers are shared/reused, so actual usage is ~%d concurrent)\n", usable_cores))

# Save current plan first
current_plan <- plan("list")
# Set task-level plan
cat(sprintf("\nSetting task-level plan to %d workers\n", task_workers))
plan(tweak(multisession, workers = task_workers))

# Update N_WORKERS to be used by each task for MC-CV
# This will be passed to run_cohort_analysis
N_WORKERS_MC_CV <- mc_cv_workers_per_task
cat(sprintf("Each task will use %d workers for MC-CV\n", N_WORKERS_MC_CV))
# Verify plan is set correctly
plan_info <- plan("list")
workers_info <- if (length(plan_info) > 0 && "workers" %in% names(plan_info[[1]])) {
  as.character(plan_info[[1]]$workers)
} else {
  "unknown"
}
cat(sprintf("Current plan: %s with %s workers\n", 
            if (length(plan_info) > 0) class(plan_info[[1]])[1] else "unknown",
            workers_info))
cat("\n")

# Run all combinations in parallel
start_time <- Sys.time()
cat(sprintf("Starting parallel execution of %d tasks...\n", total_tasks))
cat(sprintf("Task-level plan: %d workers\n", task_workers))
cat(sprintf("MC-CV plan per task: %d workers\n", if (exists("N_WORKERS_MC_CV")) N_WORKERS_MC_CV else N_WORKERS))
cat("All tasks should start simultaneously...\n\n")

task_results <- future_map(
  1:total_tasks,
  function(i) {
    cohort_name <- combinations_to_process$cohort[i]
    age_band <- combinations_to_process$age_band[i]
    task_id <- sprintf("%s_%s", cohort_name, age_band)
    cat(sprintf("[%s] Starting analysis at %s\n", task_id, format(Sys.time(), "%H:%M:%S")))
    
    # Capture parameter values explicitly to avoid scoping issues in parallel workers
    # This ensures values are captured from the calling environment, not looked up in worker environment
    event_year_val <- EVENT_YEAR
    n_splits_val <- N_SPLITS
    train_prop_val <- TRAIN_PROP
    # Use MC-CV specific worker count (calculated above for nested parallelism)
    n_workers_val <- if (exists("N_WORKERS_MC_CV")) N_WORKERS_MC_CV else N_WORKERS
    scaling_metric_val <- SCALING_METRIC
    model_params_val <- MODEL_PARAMS
    debug_mode_val <- DEBUG_MODE
    
    # Validate critical parameters before calling
    if (is.null(n_splits_val) || is.na(n_splits_val) || !is.numeric(n_splits_val) || n_splits_val <= 0) {
      stop(sprintf("[%s] Invalid N_SPLITS value: %s (type: %s). Ensure Cell 3 was executed.", 
                   task_id, 
                   if (is.null(n_splits_val)) "NULL" else if (is.na(n_splits_val)) "NA" else as.character(n_splits_val),
                   typeof(n_splits_val)))
    }
    
    # Pass parameters explicitly to ensure they're available in worker environment
    result <- run_cohort_analysis(
      cohort_name = cohort_name,
      age_band = age_band,
      event_year = event_year_val,
      n_splits = n_splits_val,
      train_prop = train_prop_val,
      n_workers = n_workers_val,
      scaling_metric = scaling_metric_val,
      model_params = model_params_val,
      debug_mode = debug_mode_val
    )
    cat(sprintf("[%s] Completed analysis at %s\n", task_id, format(Sys.time(), "%H:%M:%S")))
    return(result)
  },
  .options = furrr_options(
    seed = NULL,  # Don't set seed for task-level parallelism
    globals = c("combinations_to_process", "EVENT_YEAR", "N_SPLITS", "TRAIN_PROP", 
                "TEST_SIZE", "SCALING_METRIC", "DEBUG_MODE", "N_WORKERS", "N_WORKERS_MC_CV", "MODEL_PARAMS",
                "output_dir", "run_cohort_analysis",
                # Note: setup_r_logging creates file connections which can't be serialized
                # The function will be called inside run_cohort_analysis, so it needs to be available
                "setup_r_logging", "save_logs_to_s3_r", "check_memory_usage_r",
                "train_catboost_r", "train_random_forest_r",
                "predict_catboost_r", "predict_random_forest_r",
                "predict_proba_catboost_r", "predict_proba_random_forest_r",
                "get_importance_catboost_r", "get_importance_random_forest_r",
                "calculate_recall", "calculate_logloss",
                "run_mc_cv_method", "mc_cv", "future_map", "furrr_options", "progressor",
                # Ensure future package functions are available
                "future", "plan", "multisession"),
    # Ensure required packages are available in worker environments
    packages = c("rsample", "furrr", "progressr", "catboost", "randomForest",
                 "dplyr", "tibble", "purrr", "tidyr", "readr", "here", "duckdb", "DBI")
  )
)
end_time <- Sys.time()

# Restore original plan (use the MC-CV worker count, not task-level)
plan(multisession, workers = if (exists("N_WORKERS_MC_CV")) N_WORKERS_MC_CV else N_WORKERS)

# Print summary
cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
cat("TASK ANALYSIS SUMMARY\n")
cat(paste(rep("=", 80), collapse=""), "\n", sep="")
cat(sprintf("Total time: %.1f minutes\n", as.numeric(difftime(end_time, start_time, units = "mins"))))
cat(sprintf("Total tasks completed: %d\n", length(task_results)))
cat(sprintf("Total tasks skipped (already processed): %d\n", nrow(combinations_skipped)))

# Group results by cohort (including skipped ones)
cohort_summary <- list()
for (result in task_results) {
  cohort <- result$cohort
  if (!cohort %in% names(cohort_summary)) {
    cohort_summary[[cohort]] <- list()
  }
  cohort_summary[[cohort]][[length(cohort_summary[[cohort]]) + 1]] <- result
}

# Add skipped combinations to summary
for (i in 1:nrow(combinations_skipped)) {
  cohort <- combinations_skipped$cohort[i]
  age_band <- combinations_skipped$age_band[i]
  if (!cohort %in% names(cohort_summary)) {
    cohort_summary[[cohort]] <- list()
  }
  cohort_summary[[cohort]][[length(cohort_summary[[cohort]]) + 1]] <- list(
    cohort = cohort,
    age_band = age_band,
    status = "skipped",
    note = "Already processed"
  )
}

for (cohort_name in names(cohort_summary)) {
  cat(sprintf("\nCohort: %s\n", cohort_name))
  cat(sprintf("  Age-bands processed: %d\n", length(cohort_summary[[cohort_name]])))
  for (result in cohort_summary[[cohort_name]]) {
    age_band <- if ("age_band" %in% names(result)) result$age_band else "unknown"
    cat(sprintf("  - Age-band %s: ", age_band))
    if (result$status == "success") {
      cat(sprintf("‚úì Success (Features: %d)\n", nrow(result$aggregated)))
    } else if (result$status == "skipped") {
      cat(sprintf("‚äò Skipped (already processed)\n"))
    } else {
      cat(sprintf("‚úó Error: %s\n", result$error))
    }
  }
}

# ============================================================
# CHECK IF ALL COMBINATIONS ARE COMPLETE FOR AGGREGATION
# ============================================================
cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
cat("CHECKING COMPLETENESS FOR AGGREGATION\n")
cat(paste(rep("=", 80), collapse=""), "\n", sep="")

# Function to check if all required combinations exist
check_all_combinations_complete <- function(cohort_names, age_bands, event_year) {
  all_combinations <- expand.grid(
    cohort = cohort_names,
    age_band = age_bands,
    stringsAsFactors = FALSE
  )
  
  complete_combinations <- 0
  missing_combinations <- list()
  
  for (i in 1:nrow(all_combinations)) {
    cohort <- all_combinations$cohort[i]
    age_band <- all_combinations$age_band[i]
    
    if (check_results_exist(cohort, age_band, event_year)) {
      complete_combinations <- complete_combinations + 1
    } else {
      missing_combinations[[length(missing_combinations) + 1]] <- 
        list(cohort = cohort, age_band = age_band)
    }
  }
  
  return(list(
    total = nrow(all_combinations),
    complete = complete_combinations,
    missing = missing_combinations,
    all_complete = (complete_combinations == nrow(all_combinations))
  ))
}

# Check completeness for each cohort
completeness_status <- list()
for (cohort_name in COHORT_NAMES) {
  status <- check_all_combinations_complete(c(cohort_name), AGE_BANDS, EVENT_YEAR)
  completeness_status[[cohort_name]] <- status
  
  cat(sprintf("\nCohort: %s\n", cohort_name))
  cat(sprintf("  Required age-bands: %d\n", status$total))
  cat(sprintf("  Complete: %d\n", status$complete))
  cat(sprintf("  Missing: %d\n", length(status$missing)))
  
  if (length(status$missing) > 0) {
    cat("  Missing combinations:\n")
    for (missing in status$missing) {
      cat(sprintf("    - %s / %s\n", missing$cohort, missing$age_band))
    }
  }
  
  if (status$all_complete) {
    cat(sprintf("  ‚úì All age-bands complete - ready for aggregation\n"))
  } else {
    cat(sprintf("  ‚ö† Not all age-bands complete - aggregation will be skipped\n"))
  }
}

# Determine if we can proceed with cross-age-band aggregation
all_cohorts_complete <- all(sapply(completeness_status, function(x) x$all_complete))

if (all_cohorts_complete) {
  cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
  cat("‚úì ALL COMBINATIONS COMPLETE\n")
  cat("  Ready for cross-age-band aggregation and visualizations\n")
  cat(paste(rep("=", 80), collapse=""), "\n", sep="")
  
  # Set flag for downstream aggregation steps
  AGGREGATION_READY <- TRUE
} else {
  cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
  cat("‚ö† NOT ALL COMBINATIONS COMPLETE\n")
  cat("  Cross-age-band aggregation will be skipped\n")
  cat("  Run again after all combinations are processed\n")
  cat(paste(rep("=", 80), collapse=""), "\n", sep="")
  
  AGGREGATION_READY <- FALSE
}

cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
cat("‚úì All cohort analyses complete\n")
cat(paste(rep("=", 80), collapse=""), "\n\n", sep="")



Checking for already processed combinations...
Checking for cohort files...
Total combinations: 18
Already processed: 0
To process: 18

Running feature importance analysis
Cohorts: opioid_ed, non_opioid_ed
Age Bands: 0-12, 13-24, 25-44, 45-54, 55-64, 65-74, 75-84, 85-94, 95-114
Event Year: 2016
Total combinations: 18 (cohorts √ó age-bands)
Parallel workers: 18
MC-CV workers per task: 30


Nested Parallelism Configuration:
  Available cores: 32
  Usable cores (reserved 2 for system): 30
  Task-level workers: 7 (running 7 tasks in parallel)
  MC-CV workers per task: 3
  Total worker allocation: 7 (task-level) + 7 √ó 3 (MC-CV) = 28 workers
  (Note: Workers are shared/reused, so actual usage is ~30 concurrent)

Setting task-level plan to 7 workers
Each task will use 3 workers for MC-CV
Current plan: tweaked with unknown workers

Starting parallel execution of 18 tasks...
Task-level plan: 7 workers
MC-CV plan per task: 3 workers
All tasks should start simultaneously...



‚Äú'package:rsample' may not be available when loading‚Äù
‚Äú'package:furrr' may not be available when loading‚Äù
‚Äú'package:progressr' may not be available when loading‚Äù
‚Äú'package:future' may not be available when loading‚Äù


## 2. Parallel Execution

**This notebook is configured for parallel execution** of multiple cohort √ó age-band combinations.

**Parallel Processing**: Runs all combinations defined in `COHORT_NAMES` √ó `AGE_BANDS` in parallel. Each task processes one cohort/age-band combination and uses the `run_cohort_analysis()` function.

**Single Cohort Execution**: If you need to run a single cohort/age-band combination instead of parallel execution, see the [Single Cohort Execution section](/docs/README_feature_importance.md#single-cohort-execution-optional) in the Feature Importance README for a complete example.


## 3. Cross-Age-Band Aggregation and Visualizations

**Note:** This step only runs when all cohort √ó age-band combinations are complete.

Creates:
- Cross-age-band heatmaps showing feature importance across age bands
- Aggregated visualizations comparing cohorts


# Cleanup

Close parallel processing.


In [None]:
# ============================================================
# CROSS-AGE-BAND AGGREGATION AND VISUALIZATIONS
# ============================================================
# Only run if all combinations are complete
if (exists("AGGREGATION_READY") && AGGREGATION_READY) {
  cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
  cat("CREATING CROSS-AGE-BAND VISUALIZATIONS\n")
  cat(paste(rep("=", 80), collapse=""), "\n", sep="")
  
  # Source visualization scripts
  if (!exists("helpers_dir")) {
    helpers_dir <- here("helpers_13_1997")
  }
  source(file.path(helpers_dir, "create_cross_ageband_heatmap.R"))
  
  # Create cross-age-band heatmaps for each cohort
  for (cohort_name in COHORT_NAMES) {
    cat(sprintf("\nCreating cross-age-band heatmap for: %s\n", cohort_name))
    
    tryCatch({
      heatmap_file <- create_ageband_heatmap(
        cohort_name = cohort_name,
        event_year = EVENT_YEAR,
        age_bands = AGE_BANDS,
        output_dir = output_dir,
        s3_upload = TRUE,
        top_n = 50
      )
      cat(sprintf("‚úì Heatmap created: %s\n", heatmap_file))
    }, error = function(e) {
      cat(sprintf("‚úó Error creating heatmap for %s: %s\n", cohort_name, e$message))
    })
  }
  
  cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
  cat("‚úì Cross-age-band visualizations complete\n")
  cat(paste(rep("=", 80), collapse=""), "\n", sep="")
  
} else {
  cat("\n", paste(rep("=", 80), collapse=""), "\n", sep="")
  cat("SKIPPING CROSS-AGE-BAND AGGREGATION\n")
  cat(paste(rep("=", 80), collapse=""), "\n", sep="")
  cat("Reason: Not all cohort √ó age-band combinations are complete\n")
  cat("  - Complete all combinations first\n")
  cat("  - Then re-run this cell to create cross-age-band visualizations\n")
  cat(paste(rep("=", 80), collapse=""), "\n", sep="")
}


In [None]:
# Close parallel processing
plan(sequential)

cat("\n========================================\n")
cat("Analysis Complete!\n")
cat("========================================\n")
cat(sprintf("Local output directory: %s\n", output_dir))
cat(sprintf("S3 output location: s3://pgxdatalake/gold/feature_importance/cohort_name=%s/age_band=%s/event_year=%d/\n",
            COHORT_NAME, AGE_BAND, EVENT_YEAR))
cat(sprintf("MC-CV splits: %d\n", N_SPLITS))
cat(sprintf("Train/Test ratio: %.0f/%.0f\n", TRAIN_PROP * 100, TEST_SIZE * 100))
cat("\nResults show scaled feature importance with MC-CV Recall scores\n")
cat("based on", N_SPLITS, "independent train/test splits.\n")


# Sync Results and Code to S3

Sync output files and code (notebook + R script) to S3 bucket. 
- Outputs: CSV results files
- Code: Notebook and R script for reproducibility

In [None]:
# Sync outputs and code to S3
# On EC2, we're in the feature_importance directory  
s3_bucket <- "s3://pgx-repository/pgx-analysis/3_feature_importance/"

# Find AWS CLI (check common locations - EC2 typically has it in /usr/local/bin or /usr/bin)
aws_cmd <- Sys.which("aws")
if (aws_cmd == "") {
  # Try common EC2 installation paths
  aws_paths <- c(
    "/usr/local/bin/aws",
    "/usr/bin/aws",
    "/home/ec2-user/.local/bin/aws"
  )
  aws_cmd <- NULL
  for (path in aws_paths) {
    if (file.exists(path)) {
      aws_cmd <- path
      break
    }
  }
  if (is.null(aws_cmd)) {
    stop("AWS CLI not found. Please install AWS CLI or ensure it's in your PATH.")
  }
}

cat("Syncing outputs and code to S3...\n")
cat("Source: feature_importance/ directory\n")
cat("Destination:", s3_bucket, "\n")
cat("AWS CLI:", aws_cmd, "\n\n")

# Get current directory (should be feature_importance)
current_dir <- getwd()
if (!grepl("feature_importance", current_dir)) {
  warning("Current directory doesn't appear to be feature_importance. Double-check sync destination.")
}

# Sync feature_importance directory (includes outputs/ and code files)
# Explicitly include notebook, R scripts, README files, and outputs directory
# Exclude temporary files, checkpoints, and unnecessary directories
# Note: --delete flag removed for safety (won't delete files in S3 that don't exist locally)
# Include patterns are processed before exclude patterns, then exclude everything else
sync_cmd <- sprintf(
  '"%s" s3 sync "%s" %s --include "*.ipynb" --include "*.R" --include "README*.md" --include "outputs/**" --exclude "*checkpoint*" --exclude "*.tmp" --exclude "*.ipynb_checkpoints/*" --exclude "*.RData" --exclude "*.Rhistory" --exclude ".Rproj.user/*" --exclude "catboost_info/*" --exclude "*.log" --exclude "*"',
  aws_cmd,
  current_dir,
  s3_bucket
)

cat("Running:", sync_cmd, "\n\n")
result <- system(sync_cmd)

if (result == 0) {
  cat("‚úì Successfully synced outputs and code to S3\n")
  cat("  - Outputs:", file.path(output_dir), "\n")
  cat("  - Code: *.ipynb, *.R, README*.md\n")
} else {
  warning(sprintf("S3 sync returned exit code %d. Check AWS credentials and permissions.", result))
}

# ============================================================
# SAVE LOGS TO S3 (aligned with 2_create_cohort)
# ============================================================
cat("\n========================================\n")
cat("Saving logs to S3...\n")
cat("========================================\n")

# Close log file connection
if (exists("log_setup") && !is.null(log_setup$log_connection)) {
  if (isOpen(log_setup$log_connection)) {
    close(log_setup$log_connection)
  }
}

# Save logs to S3
if (exists("logger") && exists("log_file_path")) {
  tryCatch({
    s3_path <- save_logs_to_s3_r(log_file_path, COHORT_NAME, AGE_BAND, EVENT_YEAR, logger)
    if (!is.null(s3_path)) {
      logger$info("‚úì Analysis completed successfully. Logs saved to S3.")
    }
  }, error = function(e) {
    cat(sprintf("Warning: Could not save logs to S3: %s\n", e$message))
    cat(sprintf("Log file saved locally: %s\n", log_file_path))
  })
} else {
  cat("Warning: Logger not initialized. Logs not saved to S3.\n")
}


# Shutdown EC2

In [None]:

# Shutdown EC2 instance after analysis completes
# Set SHUTDOWN_EC2 = TRUE to enable, FALSE to disable
SHUTDOWN_EC2 <- TRUE  # Change to TRUE to enable auto-shutdown

if (SHUTDOWN_EC2) {
  cat("\n========================================\n")
  cat("Shutting down EC2 instance...\n")
  cat("========================================\n")
  
  # Get instance ID from EC2 metadata service
  instance_id <- tryCatch({
    system("curl -s http://169.254.169.254/latest/meta-data/instance-id", intern = TRUE)
  }, error = function(e) {
    cat("Warning: Could not retrieve instance ID from metadata service.\n")
    cat("If running on EC2, check that metadata service is accessible.\n")
    return(NULL)
  })
  
  if (!is.null(instance_id) && length(instance_id) > 0 && nchar(instance_id[1]) > 0) {
    instance_id <- instance_id[1]
    cat(sprintf("Instance ID: %s\n", instance_id))
    
    # Find AWS CLI
    aws_cmd <- Sys.which("aws")
    if (aws_cmd == "") {
      aws_paths <- c(
        "/usr/local/bin/aws",
        "/usr/bin/aws",
        "/home/ec2-user/.local/bin/aws"
      )
      aws_cmd <- NULL
      for (path in aws_paths) {
        if (file.exists(path)) {
          aws_cmd <- path
          break
        }
      }
    }
    
    if (!is.null(aws_cmd) && aws_cmd != "") {
      # Stop the instance (use terminate-instances for permanent deletion)
      shutdown_cmd <- sprintf(
        '"%s" ec2 stop-instances --instance-ids %s',
        aws_cmd,
        instance_id
      )
      
      cat("Running:", shutdown_cmd, "\n")
      result <- system(shutdown_cmd)
      
      if (result == 0) {
        cat("‚úì EC2 instance stop command sent successfully\n")
        cat("Instance will stop in a few moments.\n")
        cat("Note: This is a STOP (not terminate), so you can restart it later.\n")
      } else {
        warning(sprintf("EC2 stop command returned exit code %d. Check AWS credentials and permissions.", result))
      }
    } else {
      cat("Warning: AWS CLI not found. Cannot shutdown instance.\n")
      cat("Install AWS CLI or ensure it's in your PATH.\n")
    }
  } else {
    cat("Warning: Could not determine instance ID. Skipping shutdown.\n")
    cat("If you want to shutdown manually, use:\n")
    cat("  aws ec2 stop-instances --instance-ids <your-instance-id>\n")
  }
} else {
  cat("\n========================================\n")
  cat("EC2 Auto-Shutdown: DISABLED\n")
  cat("========================================\n")
  cat("To enable auto-shutdown, set SHUTDOWN_EC2 = TRUE in this cell.\n")
  cat("Instance will continue running.\n")
}
