## Imports

In [0]:
from pyspark.sql.functions import col
import mlflow
import mlflow.spark
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator


## Load data

In [0]:
# Read the training table through the notebook's Spark session.
df = spark.table('msme_risk_analytics.gold_ml_training_data')

# Report the row count, then show summary statistics for the table.
print(f"Total records: {df.count()}")
summary_stats = df.describe()
summary_stats.show()

Total records: 14787
+-------+------------------+-----------------+-----------------+------------------+------------------+--------------------+------------------+------------------+
|summary|       loan_amount|           income|     Credit_Score|               LTV|             dtir1|loan_to_income_ratio|        risk_score|            Status|
+-------+------------------+-----------------+-----------------+------------------+------------------+--------------------+------------------+------------------+
|  count|             14787|            14787|            14787|             14787|             14787|               14787|             14787|             14787|
|   mean| 240031.4803543653|4676.173666058024|699.5558260634341| 84.14687049396446| 42.92743626158112|   56.53304232960376| 50.14005898412645|0.2792993845945763|
| stddev|140115.64712375688| 2721.29860235663|116.5768679802515|16.247318327348566|11.341874316869243|  52.034557850208415|7.8201396307381925|0.4486700926034405|
|    mi

## Train/test split

In [0]:
# 80/20 random split of df with a fixed seed.
# NOTE(review): randomSplit is re-evaluated for every action on train/test; it is
# only reproducible if df's partitioning and row order are stable between
# evaluations — confirm, or persist df upstream if the counts ever drift.
train, test = df.randomSplit(weights=[0.8, 0.2], seed=42)

# Each count() below triggers its own Spark job.
print(f"Train: {train.count()}, Test: {test.count()}")

Train: 11758, Test: 3029


## Feature vector

In [0]:
# Input columns packed, in this order, into the 'features' vector column.
# NOTE(review): confirm 'risk_score' is not derived from the label; if it is,
# keeping it here would leak the target into the features.
feature_cols = [
    'loan_amount',
    'income',
    'Credit_Score',
    'LTV',
    'dtir1',
    'loan_to_income_ratio',
    'risk_score',
]
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

# Apply the same assembler to both splits (train first, then test).
train_vec, test_vec = (assembler.transform(split) for split in (train, test))

## Evaluators

In [0]:
# Two evaluators against the 'Status' label column:
#   roc_eval -> area under the ROC curve
#   acc_eval -> accuracy
roc_eval = BinaryClassificationEvaluator(
    labelCol='Status',
    metricName='areaUnderROC',
)
acc_eval = MulticlassClassificationEvaluator(
    labelCol='Status',
    metricName='accuracy',
)

## Model 1 - Logistic Regression

In [0]:
# Select the MLflow experiment that the runs started below are recorded under.
# NOTE(review): the experiment path embeds a personal workspace user; consider
# moving it to a shared/configurable location before handing the notebook over.
mlflow.set_experiment("/Users/surendharbalakrishnan@outlook.com/msme_risk_ml")

with mlflow.start_run(run_name="Logistic_Regression"):
    # NOTE(review): maxIter=10 is well below Spark's documented default of 100,
    # so the optimiser may stop before converging — confirm the cap is intended.
    lr = LogisticRegression(featuresCol='features', labelCol='Status', maxIter=10)
    lr_model = lr.fit(train_vec)
    lr_pred = lr_model.transform(test_vec)

    # Score the predictions made on test_vec with both evaluators.
    roc = roc_eval.evaluate(lr_pred)
    acc = acc_eval.evaluate(lr_pred)

    mlflow.log_param("model_type", "LogisticRegression")
    # Log the iteration cap as well, read back from the estimator so the logged
    # value cannot drift from the one actually used. The Random Forest and GBT
    # runs record their main hyperparameter; without this the LR run logs none.
    mlflow.log_param("max_iter", lr.getMaxIter())
    mlflow.log_metric("roc_auc", roc)
    mlflow.log_metric("accuracy", acc)

    print(f"Logistic Regression - ROC: {roc:.4f}, Accuracy: {acc:.4f}")

2026/01/30 10:45:51 INFO mlflow.tracking.fluent: Experiment with name '/Users/surendharbalakrishnan@outlook.com/msme_risk_ml' does not exist. Creating a new experiment.


Logistic Regression - ROC: 0.5936, Accuracy: 0.7349


## Model 2 - Random Forest

In [0]:
with mlflow.start_run(run_name="Random_Forest"):
    # 50-tree random forest with a fixed seed, fitted on the assembled train split.
    rf = RandomForestClassifier(
        featuresCol='features',
        labelCol='Status',
        numTrees=50,
        seed=42,
    )
    rf_model = rf.fit(train_vec)
    rf_pred = rf_model.transform(test_vec)

    # Score the test-split predictions: ROC AUC first, then accuracy.
    roc = roc_eval.evaluate(rf_pred)
    acc = acc_eval.evaluate(rf_pred)

    # Record the run's parameters and metrics in one call each.
    mlflow.log_params({"model_type": "RandomForest", "num_trees": 50})
    mlflow.log_metrics({"roc_auc": roc, "accuracy": acc})

    print(f"Random Forest - ROC: {roc:.4f}, Accuracy: {acc:.4f}")

Random Forest - ROC: 0.6432, Accuracy: 0.7603


## Model 3 - GBT

In [0]:
with mlflow.start_run(run_name="GBT"):
    # Gradient-boosted trees capped at 20 iterations, with a fixed seed.
    gbt = GBTClassifier(
        featuresCol='features',
        labelCol='Status',
        maxIter=20,
        seed=42,
    )
    gbt_model = gbt.fit(train_vec)
    gbt_pred = gbt_model.transform(test_vec)

    # Score the test-split predictions: ROC AUC first, then accuracy.
    roc = roc_eval.evaluate(gbt_pred)
    acc = acc_eval.evaluate(gbt_pred)

    # Record the run's parameters and metrics in one call each.
    mlflow.log_params({"model_type": "GBT", "max_iter": 20})
    mlflow.log_metrics({"roc_auc": roc, "accuracy": acc})

    print(f"GBT - ROC: {roc:.4f}, Accuracy: {acc:.4f}")

GBT - ROC: 0.6955, Accuracy: 0.7613
