In [1]:
from amex_default_prediction.utils import spark_session
from pathlib import Path

# Spark session obtained from the project's helper.
spark = spark_session()

# Root of the intermediate artifacts; relative, so it resolves against the
# kernel's working directory.
intermediate_root = Path("../data/intermediate")
# Directory of one specific logistic run. No cell visible in this section
# reads it.
model_path = intermediate_root.joinpath(
    "models/logistic/20220710212120-0.4.0-57daea5"
)

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# Session-wide setting: skip files Spark cannot read instead of failing the
# whole scan.
spark.sql("set spark.sql.files.ignoreCorruptFiles=true")

# Read every part file that sits directly inside a `metadata` directory two
# levels below `models/`, as a single JSON-backed DataFrame.
df = spark.read.format("json").load(
    f"{intermediate_root}/models/*/*/metadata/part-*"
)

In [15]:
# Preview up to three metadata records, one field per line, with values cut
# off at 80 characters.
df.show(n=3, truncate=80, vertical=True)

-RECORD 0--------------------------------------------------------------------------------------------
 avgMetrics       | null                                                                             
 class            | org.apache.spark.ml.feature.StopWordsRemover                                     
 defaultParamMap  | {false, null, en_US, null, SparkTorchModel_17a0a3b6d587__output, null, [i, me... 
 paramMap         | {null, null, null, null, [120,156,108,187,217,206,171,64,219,166,215,27,81,21... 
 persistSubModels | null                                                                             
 sparkVersion     | 3.3.0                                                                            
 stdMetrics       | null                                                                             
 timestamp        | 1657609223020                                                                    
 uid              | SparkTorchModel_17a0a3b6d587                                  

In [21]:
# Print the schema Spark inferred for the metadata files read into `df`.
df.printSchema()

root
 |-- avgMetrics: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- class: string (nullable = true)
 |-- defaultParamMap: struct (nullable = true)
 |    |-- caseSensitive: boolean (nullable = true)
 |    |-- foldCol: string (nullable = true)
 |    |-- locale: string (nullable = true)
 |    |-- numFolds: long (nullable = true)
 |    |-- outputCol: string (nullable = true)
 |    |-- seed: long (nullable = true)
 |    |-- stopWords: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- paramMap: struct (nullable = true)
 |    |-- estimatorParamMaps: array (nullable = true)
 |    |    |-- element: array (containsNull = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- isJson: boolean (nullable = true)
 |    |    |    |    |-- name: string (nullable = true)
 |    |    |    |    |-- parent: string (nullable = true)
 |    |    |    |    |-- value: double (nullable = true)
 |    |-- foldCol: strin

In [50]:
import pyspark.sql.functions as F

# Flatten the cross-validation grid stored in `paramMap.estimatorParamMaps`
# (array of arrays of structs) into one row per (model, version, grid point,
# parameter). Rows whose `estimatorParamMaps` is null are dropped by `explode`.
(
    df.withColumn("filename_parts", F.split(F.input_file_name(), "/"))
    .select(
        # The files were read from .../models/<model>/<version>/metadata/part-*,
        # so take the path components relative to the END of the path. Absolute
        # positions (e.g. [7], [8]) only work when the repository lives at one
        # particular directory depth.
        F.element_at("filename_parts", -4).alias("model"),
        F.element_at("filename_parts", -3).alias("version"),
        "paramMap.*",
    )
    # First explode: one row per grid point (each is an array of parameters).
    .withColumn("estimatorElements", F.explode("estimatorParamMaps"))
    # Second explode: one row per individual parameter of that grid point.
    .withColumn("params", F.explode("estimatorElements"))
    .select("model", "version", F.expr("params.*"))
).show(n=10)

+--------+--------------------+------+---------------+--------------------+-----+
|   model|             version|isJson|           name|              parent|value|
+--------+--------------------+------+---------------+--------------------+-----+
|logistic|20220710213020-0....|  true|       regParam|LogisticRegressio...|  0.1|
|logistic|20220710213020-0....|  true|elasticNetParam|LogisticRegressio...|  0.0|
|logistic|20220710213020-0....|  true|       regParam|LogisticRegressio...|  0.1|
|logistic|20220710213020-0....|  true|elasticNetParam|LogisticRegressio...|  0.5|
|logistic|20220710213020-0....|  true|       regParam|LogisticRegressio...|  0.1|
|logistic|20220710213020-0....|  true|elasticNetParam|LogisticRegressio...|  1.0|
|logistic|20220710213020-0....|  true|       regParam|LogisticRegressio...|  1.0|
|logistic|20220710213020-0....|  true|elasticNetParam|LogisticRegressio...|  0.0|
|logistic|20220710213020-0....|  true|       regParam|LogisticRegressio...|  1.0|
|logistic|202207

In [60]:
# Highest value found in `avgMetrics` for every saved CrossValidatorModel,
# listed newest version first.
# NOTE(review): taking the max assumes a larger metric is better - confirm
# against the evaluator used during training.
(
    df.where('class="pyspark.ml.tuning.CrossValidatorModel"')
    .withColumn("filename_parts", F.split(F.input_file_name(), "/"))
    # One row per cross-validation grid point's average metric.
    .withColumn("scores", F.explode("avgMetrics"))
    .groupby(
        # The files were read from .../models/<model>/<version>/metadata/part-*,
        # so take the path components relative to the END of the path. Absolute
        # positions (e.g. [7], [8]) only work when the repository lives at one
        # particular directory depth.
        F.element_at("filename_parts", -4).alias("model"),
        F.element_at("filename_parts", -3).alias("version"),
    )
    .agg(F.max("scores").alias("bestScore"))
    .orderBy(F.desc("version"))
).show(truncate=False)

+-----------------+-----------------------------+------------------+
|model            |version                      |bestScore         |
+-----------------+-----------------------------+------------------+
|gbt              |20220718014521-0.12.0-2d69426|0.7669987879564415|
|logistic         |20220714061632-0.12.0-2d69426|0.7596409224596953|
|logistic         |20220712070515-0.9.0-d8456e7 |0.7593114523235225|
|gbt-with-aft     |20220711061500-0.9.0-b94a9aa |0.7669615270823718|
|gbt              |20220711060731-0.9.0-b94a9aa |0.7668335629482486|
|gbt-with-aft     |20220711060036-0.9.0-b94a9aa |0.7674870606428792|
|logistic-with-aft|20220711055817-0.9.0-b94a9aa |0.759325690605417 |
|logistic-with-aft|20220711052503-0.8.0-6bbdfec |0.759325690605417 |
|logistic-with-aft|20220711052109-0.8.0-6bbdfec |0.7578148048246048|
|logistic-with-aft|20220711050708-0.8.0-6bbdfec |0.759325690605417 |
|logistic-with-aft|20220711050352-0.8.0-6bbdfec |0.759325690605417 |
|logistic-with-aft|20220711050213-