# Model Feature Importance Filter

In [None]:
// Cohort identifier; interpolated into the model, feature-info and feature-importance S3 paths in the next cell.
val cohort = "cohort1"

In [None]:
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructType, StructField, ArrayType, StringType}
import ai.catboost.spark.{CatBoostClassificationModel, Pool}
import org.apache.hadoop.fs.{FileSystem, Path}


// Spark session used to load the per-model CatBoost classifiers and write their feature importances.
// Settings are applied to the builder in the order listed.
// NOTE(review): getOrCreate() returns an already-running session if the kernel created one; in that
// case static settings such as spark.jars.packages and executor/driver memory are not applied —
// confirm they are also supplied when the kernel/cluster is launched.
val sparkSettings = Seq(
  "spark.jars.packages" -> "ai.catboost:catboost-spark_3.5_2.12:1.2.7",
  "spark.executor.memory" -> "24g",
  "spark.executor.memoryOverhead" -> "4g",
  "spark.executor.cores" -> "4",
  "spark.driver.memory" -> "24g"
)

val spark = sparkSettings.
  foldLeft(SparkSession.builder().appName("Evaluate Multiple CatBoost Models For Feature Importance")) {
    case (builder, (key, value)) => builder.config(key, value)
  }.
  getOrCreate()


// S3 locations for this cohort: trained models, per-model feature-importance outputs,
// and the processed-dataset folder that holds the feature metadata JSON.
val modelBasePath = s"s3a://pgx-repository/ade-risk-model/Step5_Time_to_Event_Model/4_models/${cohort}/spark_model_"
val featureImportanceBasePath = s"s3a://pgx-repository/ade-risk-model/Step5_Time_to_Event_Model/5_feature_importances/$cohort/feature_importance_model_"
val s3_bucket = s"s3a://pgx-repository/ade-risk-model/Step5_Time_to_Event_Model/2_processed_datasets/${cohort}"
val drug_name_path = s"${s3_bucket}/feature_info_${cohort}.json"

// Schema of the feature-info JSON; only "names" is used below.
val schema = StructType(Seq(
  StructField("names", ArrayType(StringType), nullable = true),
  StructField("types", StringType, nullable = true),
  StructField("source_column", StringType, nullable = true)
))

val drug_names_json = spark.read.
  schema(schema).
  option("multiline", "true").
  json(drug_name_path)

// One row per entry of "names". Entries are expected to look like "<drug> | drug_name_index_<N>":
// the text before "|" becomes drug_name, and the text after it is reduced to the integer N (drug_index).
val drug_name_df = drug_names_json.
  select(explode(col("names")).alias("raw_name")).
  select(split(col("raw_name"), "\\|").alias("name_parts")).
  select(
    trim(col("name_parts").getItem(0)).alias("drug_name"),
    regexp_replace(trim(col("name_parts").getItem(1)), "drug_name_index_", "").cast("int").alias("drug_index")
  ).
  orderBy(col("drug_index"))

// Number of trained models per cohort: spark_model_1 .. spark_model_<numModels>.
val numModels = 10

// Hadoop FileSystem for the bucket, used to promote each single-partition CSV to its final name.
// Hoisted out of the loop because it does not depend on the model number.
val fs = FileSystem.get(new java.net.URI("s3a://pgx-repository"), spark.sparkContext.hadoopConfiguration)

// For each model: load it, pair its non-zero feature importances with drug names by feature index,
// and write the result (sorted by importance, descending) as a single CSV at
// featureImportanceBasePath<N>.csv.
for (modelNum <- 1 to numModels) {
  // Load the CatBoost model
  val modelPath = s"${modelBasePath}${modelNum}"
  val catBoostModel = CatBoostClassificationModel.load(modelPath)

  // Importance values; the position in this array is treated as the feature index below.
  val featureImportances = catBoostModel.getFeatureImportance()

  // (array_index, importance) rows, dropping features with zero importance.
  val importanceDf = spark.createDataFrame(featureImportances.zipWithIndex.map { case (importance, index) =>
    (index, importance.toDouble)
  }).toDF("array_index", "importance").filter(col("importance") =!= 0)

  // NOTE(review): assumes drug_name_index_<N> equals the feature's position in the model's
  // feature vector — confirm against the dataset preparation step.
  val joinedDf = drug_name_df.join(importanceDf, drug_name_df("drug_index") === importanceDf("array_index"), "inner").
    select("drug_name", "drug_index", "importance").
    orderBy(col("importance").desc)

  // The temp dir includes the cohort so runs for different cohorts do not overwrite each other's output.
  val tempOutputDir = s"s3a://pgx-repository/temp_model_${cohort}_${modelNum}"
  val tempDirPath = new Path(tempOutputDir)
  val finalOutputPath = s"${featureImportanceBasePath}${modelNum}.csv"

  // coalesce(1) makes Spark write a single part file, which is then renamed to finalOutputPath.
  joinedDf.coalesce(1).write.
    option("header", "true").
    mode("overwrite").
    csv(tempOutputDir)

  // Locate the part file written above
  val partFile = fs.listStatus(tempDirPath).
    find(_.getPath.getName.startsWith("part-")).
    map(_.getPath).
    getOrElse(throw new RuntimeException(s"Part file not found for model $modelNum"))

  val finalPath = new Path(finalOutputPath)
  if (fs.exists(finalPath)) fs.delete(finalPath, false)
  // FileSystem.rename signals failure through its return value, not an exception;
  // never report success for a move that did not happen.
  if (!fs.rename(partFile, finalPath)) {
    throw new RuntimeException(s"Failed to move $partFile to $finalOutputPath for model $modelNum")
  }

  // Remove the leftover Spark output directory (e.g. the _SUCCESS marker) so temp data does not accumulate.
  fs.delete(tempDirPath, true)

  println(s"Feature importance with drug names for $cohort, model $modelNum successfully written to $finalOutputPath")
}

## Consolidated Feature Importance for Dataset Filter

In [None]:
%%python
# Cohort folder name; selects the per-cohort feature-importance S3 prefix in the next cell.
# NOTE(review): set independently of the Scala `cohort` value above — keep the two in sync.
cohort = "cohort1"

In [None]:
%%python

import pandas as pd
import boto3
import io
from collections import defaultdict

# AWS S3 bucket and prefix (folder) details
bucket_name = "pgx-repository"
feature_importance_prefix = f"ade-risk-model/Step5_Time_to_Event_Model/5_feature_importances/{cohort}/"
# File name of the consolidated output; it is written under feature_importance_prefix.
consolidated_output_path = "consolidated_feature_importances.csv"
# Per-model files written by the Scala cell are named feature_importance_model_<N>.csv.
# Listing only this prefix keeps a previously written consolidated file out of the inputs.
model_file_prefix = feature_importance_prefix + "feature_importance_model_"

# Initialize S3 client
s3_client = boto3.client("s3")

# S3 keys of the per-model feature importance CSVs
csv_files = []

# drug_name -> importance values, one per model in which the feature had positive importance
feature_importance_dict = defaultdict(list)
try:
    # A single list_objects_v2 call returns at most 1000 keys; the paginator follows
    # continuation tokens so no files are silently missed.
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name, Prefix=model_file_prefix):
        for obj in page.get("Contents", []):
            if obj["Key"].endswith(".csv"):
                csv_files.append(obj["Key"])
except Exception as e:
    # Fail this cell with an exception instead of calling exit(), whose SystemExit
    # can shut down the notebook kernel rather than just stopping the cell.
    raise RuntimeError(f"Error accessing S3 bucket: {e}") from e

print(f"Found {len(csv_files)} model feature importance files in S3.")

# Fold every per-model CSV into feature_importance_dict. Files that cannot be read,
# or that lack the expected columns, are reported and skipped (best effort).
for s3_key in csv_files:
    print(f"Processing file: {s3_key}")
    try:
        payload = s3_client.get_object(Bucket=bucket_name, Key=s3_key)["Body"].read()
        model_df = pd.read_csv(io.BytesIO(payload))

        required_columns = ('drug_name', 'importance')
        if not all(name in model_df.columns for name in required_columns):
            print(f"Skipping {s3_key}: 'drug_name' or 'importance' column missing.")
            continue

        # Keep only rows whose importance is strictly positive.
        positive = model_df.loc[model_df['importance'] > 0, ['drug_name', 'importance']]
        for drug, score in zip(positive['drug_name'], positive['importance']):
            feature_importance_dict[drug].append(score)

        print(f"Processed {s3_key}: {len(positive)} important features found.")

    except Exception as e:
        print(f"Error reading {s3_key}: {e}")

# Build the consolidated DataFrame: one row per feature, with the number of models
# in which it had positive importance (num_matches) and its mean importance over
# those models only (models where it was absent do not contribute zeros).
consolidated_data = [
    {
        'drug_name': feature,
        'num_matches': len(importances),
        'mean_importance': sum(importances) / len(importances),
    }
    for feature, importances in feature_importance_dict.items()
]

# Name the columns explicitly so an empty result (no important features in any model)
# still has them; otherwise sort_values below raises KeyError.
consolidated_df = pd.DataFrame(consolidated_data, columns=['drug_name', 'num_matches', 'mean_importance'])

# Sort first by num_matches (descending) and then by mean_importance (descending)
consolidated_df = consolidated_df.sort_values(by=['num_matches', 'mean_importance'], ascending=[False, False])

# Persist the consolidated table next to the per-model files. An upload failure is
# reported, not raised, so the table is still printed below.
consolidated_key = feature_importance_prefix + consolidated_output_path
try:
    csv_text = io.StringIO()
    consolidated_df.to_csv(csv_text, index=False)
    s3_client.put_object(Bucket=bucket_name, Key=consolidated_key, Body=csv_text.getvalue())
    print(f"\nConsolidated feature importance saved to S3 at: {consolidated_key}")
except Exception as e:
    print(f"Error saving consolidated file to S3: {e}")

# Show the full ranked table
print("\nTop consolidated features across models:")
print(consolidated_df)

### Top Features

In [None]:
%%python
# Numeric cohort id; the next cell builds "cohort{cohort_num}" S3 keys from it.
# NOTE(review): must correspond to the `cohort` value used when the consolidated file was written.
cohort_num = "1"

In [None]:
%%python

import matplotlib
matplotlib.use('Agg')  # Set the backend before importing pyplot

import boto3
import io
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from kneed import KneeLocator
import plotly.express as px
import plotly.graph_objects as go

# Location of the consolidated feature-importance table produced by the previous step.
bucket_name = "pgx-repository"
s3_key = f"ade-risk-model/Step5_Time_to_Event_Model/5_feature_importances/cohort{cohort_num}/consolidated_feature_importances.csv"

s3_client = boto3.client("s3")

# Download and parse the table.
obj = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
raw_bytes = obj["Body"].read()
consolidated_df = pd.read_csv(io.BytesIO(raw_bytes))

# Clustering features: in how many models each drug mattered, and its mean importance.
# NOTE(review): the two columns are clustered unscaled, so the one with the larger
# numeric range dominates the KMeans distances — confirm this is intended.
clustering_data = consolidated_df.loc[:, ['num_matches', 'mean_importance']]

def find_optimal_clusters(data, max_k=10):
    """Choose the number of KMeans clusters with the elbow method.

    Fits KMeans for k = 1..max_k (capped at the number of rows in ``data``),
    shows a Plotly line chart of the within-cluster sum of squares (WSS) per k,
    and locates the elbow of that curve with ``kneed.KneeLocator``.

    Args:
        data: Samples to cluster, one row per sample (e.g. a DataFrame).
        max_k: Largest number of clusters to try. Defaults to 10.

    Returns:
        The k at the detected elbow.

    Raises:
        ValueError: If ``data`` is empty or no elbow point can be detected.
    """
    # KMeans raises when n_clusters exceeds the number of samples, so never try
    # more clusters than there are rows (e.g. a cohort with few important features).
    max_k = min(max_k, len(data))
    if max_k < 1:
        raise ValueError("Cannot search for clusters in an empty dataset.")
    k_values = list(range(1, max_k + 1))

    wss = []
    for k in k_values:
        kmeans = KMeans(n_clusters=k, n_init=25, random_state=42)
        kmeans.fit(data)
        wss.append(kmeans.inertia_)

    # Create a Plotly line plot
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=k_values,
        y=wss,
        mode='lines+markers',
        name='WSS'
    ))
    fig.update_layout(
        title="Elbow Method for Optimal k",
        xaxis_title="Number of Clusters (k)",
        yaxis_title="Within-Cluster Sum of Squares (WSS)"
    )
    fig.show()

    # Automatically detect the "elbow": WSS decreases with k and flattens out (convex).
    kneedle = KneeLocator(k_values, wss, curve="convex", direction="decreasing")
    optimal_k = kneedle.knee
    if optimal_k is None:
        raise ValueError("Unable to detect the elbow point. Please check the data.")
    print(f"Optimal number of clusters (k): {optimal_k}")
    return optimal_k

# Fit the final model at the elbow k and label every feature with its cluster id.
optimal_k = find_optimal_clusters(clustering_data)
kmeans = KMeans(n_clusters=optimal_k, n_init=25, random_state=42)
consolidated_df['Cluster'] = kmeans.fit_predict(clustering_data)

# Order by num_matches, then mean_importance, then Cluster — all descending.
sort_keys = ['num_matches', 'mean_importance', 'Cluster']
cluster_df = consolidated_df.sort_values(by=sort_keys, ascending=[False] * len(sort_keys))

print("Clustered DataFrame:")
print(cluster_df)

# Serialize to CSV in memory and upload it to the feature-filter folder.
csv_buffer = io.StringIO()
cluster_df.to_csv(csv_buffer, index=False)
s3_client.put_object(
    Bucket=bucket_name,
    Key=f"ade-risk-model/Step5_Time_to_Event_Model/6_feature_filters/cohort{cohort_num}_clustered_features.csv",
    Body=csv_buffer.getvalue(),
)
print("Clustered data uploaded to S3 successfully.")

# Plot clusters using Plotly Express. KMeans cluster ids are nominal labels, so they
# are converted to strings: Plotly then uses discrete colors with a legend instead of
# a continuous color scale that would suggest an ordering between clusters.
# assign() returns a copy, leaving cluster_df (already uploaded as CSV) untouched.
plot_df = cluster_df.assign(Cluster=cluster_df['Cluster'].astype(str))
fig = px.scatter(
    plot_df,
    x='num_matches',
    y='mean_importance',
    color='Cluster',
    text='drug_name',
    title=f"Cohort {cohort_num}: Clusters of Features Based on Matches and Mean Importance",
    labels={"num_matches": "Number of Matches", "mean_importance": "Mean Importance"}
)
fig.update_traces(textposition='top center')

# Render the interactive plot to an in-memory HTML document
html_buffer = io.StringIO()
fig.write_html(html_buffer)

# Upload the Plotly HTML plot to S3 (same bucket as the input table)
s3_client.put_object(
    Bucket=bucket_name,
    Key=f"ade-risk-model/Step5_Time_to_Event_Model/6_feature_filters/cohort{cohort_num}_clustered_features_plot.html",
    Body=html_buffer.getvalue(),
    ContentType="text/html"
)
print("Interactive Plotly plot uploaded to S3 successfully.")