In [None]:
%%configure -f
{
    "conf": {
        "spark.jars.packages": "ai.catboost:catboost-spark_3.5_2.12:1.2.7",
        "spark.jars.packages.resolve.transitive": "true",
        "spark.driver.memory": "24g",
        "spark.driver.extraJavaOptions": "--add-exports java.base/sun.net.util=ALL-UNNAMED",
        "spark.executor.memory": "24g",
        "spark.executor.memoryOverhead": "4g",
        "spark.executor.cores": "4",
        "spark.task.cpus": "4",
        "spark.executor.extraJavaOptions": "--add-exports java.base/sun.net.util=ALL-UNNAMED",
        "spark.yarn.am.memory": "4g",
        "spark.dynamicAllocation.enabled": "true",
        "spark.network.timeout": "1200s",
        "spark.rpc.askTimeout": "1200s"
    }
}

In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from pyspark.ml.linalg import Vectors
from pyspark.sql.types import StructType, StructField, DoubleType, StringType
import catboost_spark

In [None]:
# Cohort selector: this value is interpolated into the S3 input prefix and the
# S3 model output path used by the cells below.
# NOTE(review): presumably this cell carries the notebook "parameters" tag so
# the value can be overridden per run -- confirm.
cohort = "cohort1"

In [None]:
# Cohort-specific S3 prefix holding the processed datasets, plus the two
# split locations underneath it.
s3_bucket = f"s3://pgx-repository/ade-risk-model/Step5_Time_to_Event_Model/2_processed_datasets/{cohort}"
train_input_path = s3_bucket + "/train"
test_input_path = s3_bucket + "/test"

# Load both splits from Parquet (train first, then test). `spark` is not
# created in this notebook's cells; it is expected to be supplied by the kernel.
print("Reading train and test datasets...")
train_df, test_df = (
    spark.read.parquet(train_input_path),
    spark.read.parquet(test_input_path),
)
print("Train and test datasets successfully loaded.")

In [None]:
# Sanity check: print the schema of each loaded split under its own heading.
for heading, frame in (
    ("Train Dataframe Schema:", train_df),
    ("Test Dataframe Schema:", test_df),
):
    print(heading)
    frame.printSchema()

In [None]:
# Wrap each split in a CatBoost-Spark Pool. Only the "features" and "label"
# columns are passed on; any other columns in the DataFrames are not selected.
train_pool, test_pool = (
    catboost_spark.Pool(split_df.select("features", "label"))
    for split_df in (train_df, test_df)
)

In [None]:
# Train one CatBoost classifier per seed and persist every fitted model to S3
# (10 runs for stable feature selection).
seeds = [3, 19, 97, 11, 35, 90, 38, 74, 25, 974]

# Cohort-specific S3 prefix under which all models of this run are written.
# It does not change between iterations, so it is built once up front.
models_root = f"s3://pgx-repository/ade-risk-model/Step5_Time_to_Event_Model/4_models/{cohort}"

# enumerate(..., start=1) pairs each seed with its 1-based model number, so the
# number cannot drift out of step with the seed list (no manual counter).
for model_num, seed in enumerate(seeds, start=1):
    print(f"Training model {model_num} with seed {seed}...")

    # Only the random seed is set explicitly; every other hyper-parameter is
    # left at whatever the library defaults to.
    classifier = catboost_spark.CatBoostClassifier(randomSeed=seed)

    # NOTE(review): the test pool is passed as the evaluation dataset during
    # fitting. If CatBoost uses it to pick the best iteration, the saved model
    # has effectively seen the test data -- confirm this is intended.
    model = classifier.fit(train_pool, evalDatasets=[test_pool])

    # Destination of this run's Spark model, suffixed with the model number.
    spark_model_path = f"{models_root}/spark_model_{model_num}"

    # overwrite() replaces anything already stored at this path, so re-running
    # the cell refreshes the models instead of failing on existing output.
    model.write().overwrite().save(spark_model_path)

    print(f"Spark model {model_num} with metadata saved to: {spark_model_path}")