In [None]:
%%configure -f
{
    "conf": {
        "spark.broadcast.compress": "true", 
        "spark.jars.packages": "ai.catboost:catboost-spark_3.5_2.12:1.2.7",
        "spark.jars.packages.resolve.transitive": "true",
        "spark.executor.memory": "180g",
        "spark.executor.cores": "1",   
        "spark.executorEnv.CATBOOST_WORKER_INIT_TIMEOUT": "3600s",
        "spark.executor.extraJavaOptions": "--add-exports java.base/sun.net.util=ALL-UNNAMED",
        "spark.executor.memoryOverhead": "8g",
        "spark.driver.extraJavaOptions": "--add-exports java.base/sun.net.util=ALL-UNNAMED",
        "spark.driver.memory": "45g",          
        "spark.dynamicAllocation.enabled": "true",
        "spark.dynamicAllocation.minExecutors": "2",
        "spark.dynamicAllocation.maxExecutors": "103",     
        "spark.network.timeout": "1200s",  
        "spark.rpc.askTimeout": "1200s", 
        "spark.rpc.message.maxSize": "512",
        "spark.sql.broadcastTimeout": "1200s",
        "spark.sql.session.timeout": "1200s",
        "spark.sql.shuffle.partitions": "103",
        "spark.sql.autoBroadcastJoinThreshold": "-1",
        "spark.shuffle.service.enabled": "true",
        "spark.task.cpus": "1",  
        "spark.yarn.am.memory": "45g"
    }
}

In [None]:
import catboost_spark
import pyspark
from pyspark import StorageLevel
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col, when
from pyspark.sql.types import StructType, StructField, DoubleType, StringType

In [None]:
# Parameters cell: presumably tagged "parameters" so a notebook runner (e.g. papermill)
# can inject a different value. The S3 input paths are derived from `cohort`.
cohort = "cohort6"

In [None]:
# Processed train/test splits for the selected cohort sit side by side under one S3 prefix.
s3_bucket = f"s3://pgx-repository/ade-risk-model/Step5_Time_to_Event_Model/2_processed_datasets/{cohort}"
train_input_path, test_input_path = (f"{s3_bucket}/{split}" for split in ("train", "test"))

# Load both splits from parquet (train first, then test).
print("Reading train and test datasets...")
train_df, test_df = (spark.read.parquet(path) for path in (train_input_path, test_input_path))

print("Train and test datasets successfully loaded.")

In [None]:
# Print each split's schema so column names and types can be checked before training.
for schema_title, split_df in (
    ("Train Dataframe Schema:", train_df),
    ("Test Dataframe Schema:", test_df),
):
    print(schema_title)
    split_df.printSchema()

In [None]:
# Step 1: Compute per-person features with a single aggregation.
# NOTE(review): `df` is not defined anywhere in this notebook, so this cell only works
# if `df` was left in the kernel by a cell that no longer exists. It fails under
# Restart & Run All. Confirm which frame it should be (presumably the processed
# training split) and bind it explicitly in a visible cell.
N_PARTITIONS = 103  # mirrors spark.sql.shuffle.partitions in the %%configure cell

_required_cols = {"mi_person_key", "drug_name", "features", "label"}
_missing_cols = _required_cols - set(df.columns)
assert not _missing_cols, f"`df` is missing required columns: {sorted(_missing_cols)}"

# One groupBy (a single shuffle) yields both per-person counts.
person_stats_df = df.groupBy("mi_person_key").agg(
    F.countDistinct("drug_name").alias("polypharmacy"),  # distinct drugs per person
    F.count("*").alias("activity_count"),                # rows per person
)

# Narrow views kept under their original names for any later cell that uses them.
polypharmacy_df = person_stats_df.select("mi_person_key", "polypharmacy")
activity_count_df = person_stats_df.select("mi_person_key", "activity_count")

# Attach both counts to every row with a single join instead of two.
# Rows whose mi_person_key is null never match an equi-join key, so their counts stay null.
enhanced_df = df.join(person_stats_df, on="mi_person_key", how="left")

# Step 2: Create Polypharmacy and Activity Count Bins.
# `when` branches are evaluated in order, so each upper bound implicitly excludes the
# previous range. Null counts fall through to "high".
# NOTE(review): confirm that null-key rows are meant to be binned as "high".
enhanced_df = enhanced_df.withColumn(
    "polypharmacy_bin",
    F.when(F.col("polypharmacy") <= 5, "low")
     .when(F.col("polypharmacy") <= 10, "medium")
     .otherwise("high"),
).withColumn(
    "activity_count_bin",
    F.when(F.col("activity_count") <= 50, "low")
     .when(F.col("activity_count") <= 200, "medium")
     .otherwise("high"),
)

# Step 3: Add a Hash Partition Column in [0, N_PARTITIONS).
# pmod instead of abs(hash) % n: hash() returns a 32-bit int, and abs() of its minimum
# value overflows (stays negative, or raises under ANSI mode), giving a negative bucket.
enhanced_df = enhanced_df.withColumn(
    "hash_partition", F.pmod(F.hash("mi_person_key"), F.lit(N_PARTITIONS))
)

# Step 4: Repartition the Data.
# hash_partition and polypharmacy_bin are both determined by mi_person_key, so all rows
# for one person land in the same partition.
partitioned_df = enhanced_df.repartition(
    N_PARTITIONS, "hash_partition", "polypharmacy_bin", "mi_person_key"
)

# Step 5: Create CatBoost Pool.
# NOTE(review): only features/label go into the Pool. The count/bin columns above affect
# row placement, not model inputs. The following cell reassigns `train_pool` from
# train_df, so this Pool is discarded when both cells run.
train_pool = catboost_spark.Pool(partitioned_df.select("features", "label"))


In [None]:
# CatBoost Pool objects
# Keep only the model inputs and cache them, so later actions reuse the cached rows
# instead of re-reading the parquet data from S3.
# Note: this rebinds train_df/test_df; the unpersist cell releases these cached frames.
train_df = train_df.select("features", "label").persist(StorageLevel.MEMORY_AND_DISK)
test_df = test_df.select("features", "label").persist(StorageLevel.MEMORY_AND_DISK)

# persist() is lazy: nothing is cached until an action runs. count() populates the
# cache now and doubles as a check that neither split is empty.
train_rows = train_df.count()
test_rows = test_df.count()
print(f"Cached rows -- train: {train_rows:,}, test: {test_rows:,}")
assert train_rows > 0, "train split is empty"
assert test_rows > 0, "test split is empty"

# Create the CatBoost Pool objects on top of the cached frames.
train_pool = catboost_spark.Pool(train_df)
test_pool = catboost_spark.Pool(test_df)

# storageLevel shows the requested level only; the counts above are what filled the cache.
print(train_df.storageLevel)
print(test_df.storageLevel)

In [None]:
# Drop the cached train/test frames from executor memory/disk (non-blocking, the PySpark default).
# NOTE(review): the Pools were built from these frames; run this only once they are no longer needed.
train_df.unpersist(blocking=False)
test_df.unpersist(blocking=False)