In [0]:
from pyspark.ml.classification import GBTClassifier
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType

# Load the prepared training and validation sets (stored as Delta tables)
train_ready = spark.read.load("/FileStore/data/train_ready", format="delta")
val_ready = spark.read.load("/FileStore/data/val_ready", format="delta")

# Light balancing: keep every label == 1 row and undersample the remaining rows
minority_df = train_ready.where(col("label") == 1)
majority_df = train_ready.where(col("label") != 1)
majority_sampled = majority_df.sample(withReplacement=False, fraction=0.4, seed=42)
train_balanced = majority_sampled.union(minority_df)

# Show the per-label row counts of the balanced training set
print("\nDistribuição após balanceamento:")
label_counts = train_balanced.groupBy("label").count()
label_counts.show()

# Train a GBT classifier without class weights
gbt_params = {
    "labelCol": "label",
    "featuresCol": "features",
    "maxIter": 20,
    "maxDepth": 5,
    "seed": 42,
}
gbt = GBTClassifier(**gbt_params)
model = gbt.fit(train_balanced)

# Run inference on the validation set
val_preds = model.transform(val_ready)

# Apply a manually chosen decision threshold to the model's probability output
def apply_threshold(df, threshold):
    """Add an ``adjusted_prediction`` column from a probability cutoff.

    Args:
        df: DataFrame containing a ``probability`` vector column; the
            element at index 1 is the score compared against the cutoff.
        threshold: Numeric cutoff. Rows where ``probability[1] > threshold``
            get ``1.0``; all other rows get ``0.0``.

    Returns:
        ``df`` with an additional double column ``adjusted_prediction``.
    """
    # Function-scope import (requires Spark >= 3.0): converts the ML vector
    # into a native array column so the comparison is evaluated by Spark
    # itself instead of a row-by-row Python UDF.
    from pyspark.ml.functions import vector_to_array

    positive_prob = vector_to_array(col("probability"))[1]
    # Casting the boolean comparison to double yields 1.0 / 0.0.
    return df.withColumn(
        "adjusted_prediction",
        (positive_prob > threshold).cast(DoubleType()),
    )

# Known decision threshold (fixed here rather than searched)
best_threshold = 0.30
val_preds_adjusted = apply_threshold(val_preds, best_threshold)

# Final evaluation: MulticlassMetrics takes an RDD of (prediction, label) float pairs
prediction_and_label = val_preds_adjusted.select("adjusted_prediction", "label")
final_rdd = prediction_and_label.rdd.map(
    lambda row: (float(row["adjusted_prediction"]), float(row["label"]))
)
metrics = MulticlassMetrics(final_rdd)

print(f"\n✅ Avaliação com Threshold = {best_threshold:.2f}")
print("Confusion Matrix:")
confusion = metrics.confusionMatrix().toArray()
print(confusion)

# Report precision / recall / F1 for label 1.0
positive_label = 1.0
print("\n🎯 Métricas finais:")
print(f"Precision classe 1: {metrics.precision(positive_label):.4f}")
print(f"Recall classe 1:    {metrics.recall(positive_label):.4f}")
print(f"F1 classe 1:        {metrics.fMeasure(positive_label):.4f}")

# Persist the final model, replacing any previous save at this path
model_writer = model.write().overwrite()
model_writer.save("/FileStore/models/gbt_top10_no_weights")


Distribuição após balanceamento:
+-----+------+
|label| count|
+-----+------+
|  0.0|674488|
|  1.0|125441|
+-----+------+


🔍 Threshold Search (para F1 classe 1):
Threshold = 0.05 | F1 (classe 1): 0.0841
Threshold = 0.10 | F1 (classe 1): 0.0841
Threshold = 0.15 | F1 (classe 1): 0.0841
Threshold = 0.20 | F1 (classe 1): 0.0841
Threshold = 0.25 | F1 (classe 1): 0.0841
Threshold = 0.30 | F1 (classe 1): 0.0841
Threshold = 0.35 | F1 (classe 1): 0.0841
Threshold = 0.40 | F1 (classe 1): 0.0841
Threshold = 0.45 | F1 (classe 1): 0.0841
Threshold = 0.50 | F1 (classe 1): 0.0841
Threshold = 0.55 | F1 (classe 1): 0.0837
Threshold = 0.60 | F1 (classe 1): 0.0030
Threshold = 0.65 | F1 (classe 1): 0.0007
Threshold = 0.70 | F1 (classe 1): 0.0000
Threshold = 0.75 | F1 (classe 1): 0.0000
Threshold = 0.80 | F1 (classe 1): 0.0000
Threshold = 0.85 | F1 (classe 1): 0.0000
Threshold = 0.90 | F1 (classe 1): 0.0000

✅ Melhor Threshold Encontrado: 0.10 com F1 da classe 1 = 0.0841

Confusion Matrix (com melhor th