## Setup and Functions

In [0]:
from pyspark.sql import Row,SparkSession
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.types import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import MinMaxScaler
from pyspark.mllib.evaluation import MulticlassMetrics, BinaryClassificationMetrics
from pyspark.ml.evaluation import BinaryClassificationEvaluator
import time
from pyspark.sql.functions import *
import numpy as np
import random
import datetime
import pandas as pd
import xgboost as xgb
import mlflow.xgboost
import math
import itertools

from sparkdl.xgboost import XgboostRegressor,XgboostClassifier
from sklearn.model_selection import RandomizedSearchCV

from pyspark.sql.types import DoubleType
from pyspark.sql.functions import lit, udf

In [0]:
# Azure Blob Storage location (container and account are created in https://portal.azure.com).
blob_container = "team06"
storage_account = "apatel"
# Databricks secret scope / key names (created locally with the Databricks CLI).
secret_scope = "team06"
secret_key = "team06"
# wasbs:// URL of the container; the parquet paths in this notebook are built on top of it.
blob_url = "wasbs://{}@{}.blob.core.windows.net".format(blob_container, storage_account)
mount_path = "/mnt/mids-w261"

In [0]:
# this is not used; it was for an alternate strategy we did not choose

def find_optimal_threshold_df(df, fold_num, search_center=0.6, search_bounds=0.2, granularity=5, times_to_zoom=4):
    """Find the F1-optimal probability threshold for one fold and append the re-thresholded prediction.

    The columns ``probability_<fold_num>``, ``prediction_<fold_num>`` and
    ``label_<fold_num>`` plus an ``_id`` key are read from ``df``.  Element 1 of
    the probability vector is compared against candidate thresholds.  The search
    evaluates ``granularity`` evenly spaced thresholds in
    ``[search_center - search_bounds, search_center + search_bounds]``, then
    re-centres on the best one and shrinks the bounds by a factor of 4; this is
    repeated ``times_to_zoom`` times.  Within a level the scan stops early as
    soon as the F1 score drops below the previous candidate's score.

    Args:
        df: Spark DataFrame holding the per-fold columns described above.
        fold_num: Fold suffix used to build the column names.
        search_center: Centre of the first search window.
        search_bounds: Half-width of the first search window.
        granularity: Number of thresholds evaluated per level.
        times_to_zoom: Number of refinement levels.

    Returns:
        ``df`` inner-joined on ``_id`` with a new ``rev_pred_<fold_num>`` column
        (1.0 where the probability is >= the selected threshold, else 0.0).
    """

    prob_name = "probability_" + str(fold_num)
    pred_name = "prediction_" + str(fold_num)
    rev_pred_name = "rev_pred_" + str(fold_num)
    label_name = "label_" + str(fold_num)

    def ith_(v, i):
        # Indexing a null or too-short vector raises TypeError/IndexError, not
        # ValueError; map all of them to a null probability.
        try:
            return float(v[i])
        except (ValueError, TypeError, IndexError):
            return None

    ith = udf(ith_, DoubleType())
    output = df.withColumn("del_prob", ith(prob_name, lit(1)))

    # Everything the search needs, extracted once and cached because it is
    # re-scanned for every candidate threshold.
    scored = output.select("_id", label_name, 'del_prob').withColumnRenamed(label_name, "label")

    def with_prediction(threshold):
        # A null probability compares as null and therefore falls through to 0.0.
        return scored.withColumn('prediction', when((col('del_prob') >= lit(threshold)), 1.0).otherwise(0.0))

    def f1_of(frame):
        # MulticlassMetrics expects (prediction, label) pairs, in that order.
        return MulticlassMetrics(frame.select('prediction', 'label').rdd).fMeasure(1.0, 1.0)

    # Falls back to the starting centre if no level is run or nothing scores above 0.
    best_thresh = search_center
    scored.cache()
    try:
        for _ in range(times_to_zoom):
            search_space = np.linspace(search_center - search_bounds, search_center + search_bounds, granularity)
            best_score = 0
            prior_score = 0
            level_best = None
            for threshold in search_space:
                f1_score = f1_of(with_prediction(threshold))
                print("threshold:", threshold, "f1 score:", f1_score)
                if f1_score > best_score:
                    best_score = f1_score
                    level_best = threshold
                elif f1_score < prior_score:
                    # Score started to fall: stop scanning this level.
                    break
                prior_score = f1_score
            if level_best is not None:
                best_thresh = level_best
            print("="*45)
            print("best score this level:", best_score, "at threshold", best_thresh)
            print("="*45)
            search_center = best_thresh
            search_bounds = search_bounds / 4

        # Re-label with the selected threshold and score that labelling.
        join_df = with_prediction(best_thresh)
        new_f1 = f1_of(join_df)
    finally:
        scored.unpersist()

    join_df = join_df.withColumnRenamed("prediction", rev_pred_name).select('_id', rev_pred_name)

    # F1 of the fold's original prediction column, for comparison.
    score_metrics = df.select(pred_name, label_name).withColumnRenamed(pred_name, "prediction").withColumnRenamed(label_name, "label")
    metrics = MulticlassMetrics(score_metrics.rdd)
    orig_f1 = metrics.fMeasure(1.0,1.0)
    print("="*45)
    print("For fold",fold_num)
    print("="*45)
    print("Original f1 score:",orig_f1)
    print("New f1 score:", new_f1)

    final_df = df.join(join_df,['_id'])

    return final_df

In [0]:
def find_optimal_threshold(df, search_center=0.6, search_bounds=0.2, granularity=5, times_to_zoom=4):
    """Find the probability threshold that maximises the positive-class F1 score.

    Element 1 of the ``probability`` vector column is compared against candidate
    thresholds and scored against ``label``.  The search evaluates
    ``granularity`` evenly spaced thresholds in
    ``[search_center - search_bounds, search_center + search_bounds]``, then
    re-centres on the best one and shrinks the bounds by a factor of 4; this is
    repeated ``times_to_zoom`` times.  Within a level the scan stops early as
    soon as the F1 score drops below the previous candidate's score.

    Args:
        df: Spark DataFrame with ``label`` and ``probability`` columns.
        search_center: Centre of the first search window.
        search_bounds: Half-width of the first search window.
        granularity: Number of thresholds evaluated per level.
        times_to_zoom: Number of refinement levels.

    Returns:
        The threshold selected at the last level.  If a level finds no threshold
        with an F1 score above 0, the previous centre is kept instead of a
        sentinel value.
    """

    def ith_(v, i):
        # Indexing a null or too-short vector raises TypeError/IndexError, not
        # ValueError; map all of them to a null probability.
        try:
            return float(v[i])
        except (ValueError, TypeError, IndexError):
            return None

    ith = udf(ith_, DoubleType())
    # Only label and positive-class probability are needed; cache them once since
    # they are re-scanned for every candidate threshold.
    scored = df.withColumn("del_prob", ith("probability", lit(1))).select('label', 'del_prob')

    # Fallbacks for times_to_zoom == 0 or when nothing scores above 0.
    best_thresh = search_center
    best_score = 0

    scored.cache()
    try:
        for _ in range(times_to_zoom):
            search_space = np.linspace(search_center - search_bounds, search_center + search_bounds, granularity)
            best_score = 0
            prior_score = 0
            level_best = None
            for threshold in search_space:
                # A null probability compares as null and falls through to 0.0.
                test_df = scored.withColumn('prediction', when((col('del_prob') >= lit(threshold)), 1.0).otherwise(0.0))
                # MulticlassMetrics expects (prediction, label) pairs, in that order.
                test_metrics = MulticlassMetrics(test_df.select('prediction', 'label').rdd)
                f1_score = test_metrics.fMeasure(1.0,1.0)
                print("threshold:",threshold,"f1 score:",f1_score)
                if f1_score > best_score:
                    best_score = f1_score
                    level_best = threshold
                elif f1_score < prior_score:
                    # Score started to fall: stop scanning this level.
                    break
                prior_score = f1_score
            if level_best is not None:
                best_thresh = level_best
            print("="*45)
            print("best score this level:", best_score, "at threshold", best_thresh)
            print("="*45)
            search_center = best_thresh
            search_bounds = search_bounds / 4
    finally:
        scored.unpersist()

    print("overall best threshold:", best_thresh, "with f1 score", best_score)
    return best_thresh

In [0]:
# Validation predictions of the five folds, trimmed to the feature, label,
# prediction and probability columns.
val_pred_columns = ["features", "label", "prediction", "probability"]
preds_1, preds_2, preds_3, preds_4, preds_5 = [
    spark.read.parquet(f'{blob_url}/xgboost_val_preds_4_9_fold_{fold}').select(*val_pred_columns)
    for fold in range(1, 6)
]

# Preview fold 1.
display(preds_1)

features,label,prediction,probability
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 8, 9, 10, 11, 13, 15, 23, 24, 29, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 0.07, 229.0, 16.0, 0.11, 1.0, 1.0, 1.0, 1.0, 0.07, 0.11, 0.016721392451989205, 0.017058616124661325, 10272.0, 57.0, 10332.426267281106, 16000.0, 111.0, -17.0, 10187.728110599079, 10260.0, 1793.1946564885495, 0.064, 0.11, 5.75, 11.0, 10.0, 1.0, 82.0, 425.79, 564.19, 284.85, 5.75, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",0.0,0.0,"Map(vectorType -> dense, length -> 2, values -> List(0.8526660203933716, 0.1473339945077896))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 8, 9, 10, 11, 13, 15, 23, 29, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 0.14, 263.0, 19.0, 0.11, 1.0, 1.0, 1.0, 0.14, 0.11, 0.016721392451989205, 0.017058616124661325, 10278.0, 26.0, 10332.426267281106, 16000.0, 122.0, -28.0, 10187.728110599079, 10264.0, 1793.1946564885495, 0.167, 1.0, 11.0, 10.0, 1.0, 59.0, 301.3, 315.96, 202.32, 7.456603773584906, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",0.0,0.0,"Map(vectorType -> dense, length -> 2, values -> List(0.8603250980377197, 0.13967491686344147))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 7, 8, 9, 10, 11, 13, 15, 23, 27, 29, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 1.0, 0.06, 252.0, 18.0, 0.05, 1.0, 1.0, 1.0, 1.0, 0.06, 0.05, 0.016721392451989205, 0.017058616124661325, 10152.0, 41.0, 10332.426267281106, 16000.0, 161.0, 39.0, 10187.728110599079, 10139.0, 1793.1946564885495, 0.112, 1.0, 11.0, 10.0, 1.0, 165.0, 334.77, 341.53, 575.44, 2.921259842519685, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",0.0,0.0,"Map(vectorType -> dense, length -> 2, values -> List(0.8675660490989685, 0.1324339658021927))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 8, 9, 10, 11, 13, 15, 23, 29, 31, 32, 33, 35, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 0.12, 273.0, 18.0, 0.11, 1.0, 1.0, 1.0, 0.12, 0.11, 0.016721392451989205, 0.017058616124661325, 3.0, 10112.0, 72.0, 10332.426267281106, 16000.0, 128.0, 94.0, 10187.728110599079, 10098.0, 1793.1946564885495, 0.08, 1.0, 11.0, 10.0, 1.0, 117.0, 313.68, 366.54, 212.88, 8.592592592592593, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",0.0,0.0,"Map(vectorType -> dense, length -> 2, values -> List(0.8229661583900452, 0.17703382670879364))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 8, 9, 10, 11, 13, 15, 23, 29, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 0.08, 278.0, 18.0, 0.16, 1.0, 1.0, 1.0, 0.08, 0.16, 0.016721392451989205, 0.017058616124661325, 10100.0, 51.0, 10332.426267281106, 16000.0, 133.0, 6.0, 10187.728110599079, 10088.0, 1793.1946564885495, 0.158, 12.0, 1.0, 11.0, 10.0, 1.0, 73.0, 275.96, 359.61, 236.22, 8.226618705035971, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",0.0,0.0,"Map(vectorType -> dense, length -> 2, values -> List(0.833183228969574, 0.16681675612926483))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 8, 9, 10, 11, 13, 15, 23, 24, 25, 26, 29, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 1.0, 105.0, 0.15, 275.0, 18.0, 0.38, 1.0, 1.0, 1.0, 1.0, 1.0, 105.0, 0.15, 0.38, 0.016721392451989205, 0.017058616124661325, 10164.0, 72.0, 10332.426267281106, 16000.0, 94.0, -33.0, 10187.728110599079, 10152.0, 1793.1946564885495, 0.167, 1.0, 11.0, 10.0, 1.0, 342.84, 312.72, 155.56, 23.764492753623188, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",1.0,1.0,"Map(vectorType -> dense, length -> 2, values -> List(0.03099977970123291, 0.9690002202987671))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 8, 9, 10, 11, 13, 15, 23, 24, 29, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 0.07, 235.0, 17.0, 0.06, 1.0, 1.0, 1.0, 1.0, 0.07, 0.06, 0.016721392451989205, 0.017058616124661325, 10176.0, 72.0, 10332.426267281106, 16000.0, 150.0, -61.0, 10187.728110599079, 10162.0, 1793.1946564885495, 0.058, 0.06, 4.246445497630332, 11.0, 10.0, 1.0, 78.0, 414.11, 505.86, 306.23, 4.246445497630332, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",0.0,0.0,"Map(vectorType -> dense, length -> 2, values -> List(0.8345140218734741, 0.1654859483242035))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 13, 15, 23, 25, 26, 28, 29, 30, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 1.0, 48.0, 1.0, 68.0, 0.1, 275.0, 18.0, 0.14, 1.0, 1.0, 1.0, 1.0, 48.0, 68.0, 0.1, 1.0, 0.14, 0.016721392451989205, 0.017058616124661325, 10164.0, 26.0, 10332.426267281106, 16000.0, 156.0, -50.0, 10187.728110599079, 10152.0, 1793.1946564885495, 0.105, 1.0, 11.0, 10.0, 1061.0, 291.61, 347.53, 301.89, 11.032490974729242, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",1.0,1.0,"Map(vectorType -> dense, length -> 2, values -> List(0.15619045495986938, 0.8438095450401306))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 8, 9, 10, 11, 13, 15, 23, 25, 26, 29, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 1.0, 106.0, 0.21, 262.0, 18.0, 0.52, 1.0, 1.0, 1.0, 1.0, 106.0, 0.21, 0.52, 0.016721392451989205, 0.017058616124661325, 10074.0, 62.0, 10332.426267281106, 16000.0, 106.0, 72.0, 10187.728110599079, 10061.0, 1793.1946564885495, 0.141, 1.0, 725.0, 1.0, 11.0, 10.0, 1.0, 290.71, 356.0, 310.5, 43.95864661654135, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",1.0,1.0,"Map(vectorType -> dense, length -> 2, values -> List(0.02944082021713257, 0.9705591797828674))"
"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 4, 8, 9, 10, 11, 13, 15, 23, 26, 29, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84), values -> List(604.0, 21.0, 979.0, 3.0, 0.12, 268.0, 18.0, 0.18, 1.0, 1.0, 1.0, 3.0, 0.12, 0.18, 0.016721392451989205, 0.017058616124661325, 10092.0, 41.0, 10332.426267281106, 16000.0, 150.0, 33.0, 10187.728110599079, 10078.0, 1793.1946564885495, 0.125, 1.0, 117.0, 1.0, 11.0, 10.0, 1.0, 37.0, 318.88, 341.59, 199.61, 14.725563909774436, 0.2045052531119649, 17.565871638131753, 1.0, 7.0, 3.0, 0.2306252035167698, 20.615678931943993, 1.0, 20.0, 25.0, 0.42857142857142855, 128.5142857142857, 1.0, 33.0, 1.0))",0.0,0.0,"Map(vectorType -> dense, length -> 2, values -> List(0.6807273626327515, 0.31927260756492615))"


In [0]:
# Alternate (unused) strategy: add a re-thresholded prediction column for each fold.
# `fold_num` is only accepted by find_optimal_threshold_df, not find_optimal_threshold.
# NOTE(review): `preds_df` is not defined in the visible cells; it presumably is the
# frame holding the per-fold probability_/prediction_/label_ columns plus `_id` - confirm.
for i in range(1,6):
    preds_df = find_optimal_threshold_df(preds_df, fold_num=i, search_center=0.6, search_bounds=0.2, granularity=5, times_to_zoom=4)

In [0]:
# Tune one decision threshold per fold on that fold's validation predictions.
# NOTE(review): these are the `4_7` validation files while the other cells read `4_9`
# files - confirm this is the intended model version.
val_frames = []
val_thresholds = []
for fold in range(1, 6):
    fold_val = spark.read.parquet(f'{blob_url}/xgboost_val_preds_4_7_fold_{fold}')
    # Keep the frame cached only for the duration of the threshold search.
    fold_val.cache()
    val_thresholds.append(find_optimal_threshold(fold_val))
    fold_val.unpersist()
    val_frames.append(fold_val)

preds_1, preds_2, preds_3, preds_4, preds_5 = val_frames
preds_1_thresh, preds_2_thresh, preds_3_thresh, preds_4_thresh, preds_5_thresh = val_thresholds

## Re-Labeling on new thresholds

In [0]:
def infer_new_labels(df, threshold, name):
    """Re-derive predictions from the probability column using a new threshold.

    Prints the positive-class F1 score of the existing ``prediction`` column and
    of the re-thresholded prediction, labelled with ``name``.

    Args:
        df: Spark DataFrame with ``features``, ``label``, ``prediction`` and
            ``probability`` columns.
        threshold: Rows whose probability element 1 is >= this value are
            predicted 1.0, all others (including null probabilities) 0.0.
        name: Label used in the printed report.

    Returns:
        DataFrame with columns ``features``, ``label`` and the new
        ``prediction``.  It is not cached on return.
    """
    def ith_(v, i):
        # Indexing a null or too-short vector raises TypeError/IndexError, not
        # ValueError; map all of them to a null probability.
        try:
            return float(v[i])
        except (ValueError, TypeError, IndexError):
            return None

    ith = udf(ith_, DoubleType())
    output = df.withColumn("del_prob",ith("probability", lit(1)))

    # MulticlassMetrics expects (prediction, label) pairs, in that order.
    metrics = MulticlassMetrics(df.select('prediction','label').rdd)
    orig_f1 = metrics.fMeasure(1.0,1.0)
    print("="*45)
    print("For",name)
    print("Original f1 score:",orig_f1)

    test_df = output.select('features','label','del_prob')
    test_df = test_df.withColumn('prediction', when((col('del_prob') >= lit(threshold)), 1.0).otherwise(0.0))
    test_df = test_df.select('features','label','prediction')
    test_df.cache()
    try:
        test_metrics = MulticlassMetrics(test_df.select('prediction','label').rdd)
        new_f1 = test_metrics.fMeasure(1.0,1.0)
    finally:
        # Release the cache even if the metric computation fails.
        test_df.unpersist()

    print("New f1 score:", new_f1)

    return test_df

### Apply new labels and save out files

In [0]:
# For every fold: load its test-set predictions, re-label them with the fold's
# tuned threshold and write the result back to blob storage.
fold_thresholds = (preds_1_thresh, preds_2_thresh, preds_3_thresh, preds_4_thresh, preds_5_thresh)
test_frames = []
relabeled_frames = []
for fold, fold_thresh in enumerate(fold_thresholds, start=1):
    CV_name = f'fold_{fold}'
    save_name = 'xgboost_reinf_4_9_'+CV_name

    fold_test = spark.read.parquet(f'{blob_url}/xgboost_test_preds_4_9_fold_{fold}')
    fold_test.cache()
    fold_reinf = infer_new_labels(df=fold_test, threshold=fold_thresh, name=CV_name)
    fold_test.unpersist()

    # No write mode is set, so the write fails if the target path already exists.
    fold_reinf.cache()
    fold_reinf.write.parquet(f"{blob_url}/{save_name}")
    fold_reinf.unpersist()

    test_frames.append(fold_test)
    relabeled_frames.append(fold_reinf)

preds_1, preds_2, preds_3, preds_4, preds_5 = test_frames
preds_1_reinf, preds_2_reinf, preds_3_reinf, preds_4_reinf, preds_5_reinf = relabeled_frames

# Voting and Test Set Performance

In [0]:
# After running inference we realized that the row order had become scrambled between the folds.
# Not all was lost: after much testing, we determined that only 24 rows had duplicate features,
# and only 18 of those also had duplicate labels.
# Reasoning that a fold's prediction is always the same for the same features, we hold fold 1
# in place and join the other folds to it based on features.

# Read the re-labeled test-set predictions of every fold.
preds_1_reinf, preds_2_reinf, preds_3_reinf, preds_4_reinf, preds_5_reinf = [
    spark.read.parquet(f'{blob_url}/xgboost_reinf_4_9_fold_{fold}')
    for fold in range(1, 6)
]

# Verify lengths: print the row count of each fold.
for fold_preds in (preds_1_reinf, preds_2_reinf, preds_3_reinf, preds_4_reinf, preds_5_reinf):
    print(fold_preds.count())

In [0]:
# Start from fold 1 and join every other fold's prediction onto it by (features, label).
all_preds = preds_1_reinf.withColumnRenamed("prediction","prediction_1").select("features","prediction_1","label")

# After each join, keep at most two rows per (features, label) group so that
# duplicated keys do not multiply with every further join.
dedup_source = "(SELECT P.*, ROW_NUMBER() OVER(PARTITION BY features, label ORDER BY prediction_1 ASC) as rownum FROM Preds P) a WHERE a.rownum < 3"

vote_columns = ["prediction_1"]
remaining_folds = [(2, preds_2_reinf), (3, preds_3_reinf), (4, preds_4_reinf), (5, preds_5_reinf)]
for fold, fold_preds in remaining_folds:
    vote_col = f"prediction_{fold}"
    vote_columns.append(vote_col)

    fold_side = fold_preds.alias("R").select("features","prediction","label").withColumnRenamed("prediction", vote_col)
    all_preds = all_preds.alias("L").join(fold_side, ['features', 'label'])

    all_preds.createOrReplaceTempView("Preds")
    # The final projection lists label before features; the intermediate ones the other way round.
    key_columns = "label, features" if fold == 5 else "features, label"
    all_preds = spark.sql(f"SELECT {key_columns}, {', '.join(vote_columns)} FROM {dedup_source}")

In [0]:
# Verify the join: cache the combined frame, preview it and report its row count.
all_preds = all_preds.cache()

display(all_preds)
all_preds.count()

label,features,prediction_1,prediction_2,prediction_3,prediction_4,prediction_5
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 22, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 49, 50, 51, 52, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(1554.0, 20.0, 487.0, 1.0, 159.0, 1.0, 156.0, 1.0, 0.33, 401.0, 199.0, 0.21, 1.0, 1.0, 1.0, 1.0, 159.0, 1.0, 156.0, 0.33, 1.0, 0.21, 0.009816910506378085, 0.01025791928499623, 118.0, 10100.0, 82.0, 22000.0, 16093.0, -83.0, 10102.0, 10091.0, 1524.0, 0.111, 1.0, 26.0, 21.0, 1.0, 619.0, 253.07, 347.06, 251.83, 13.962666666666667, 0.20807924111123, 11.206206982632574, 3.0, 10.0, 0.19555506837302347, 13.98635279932038, 52.0, 125.0, 0.08695652173913043, 6.6521739130434785, 3091.0, 4412.0))",0.0,0.0,0.0,0.0,0.0
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 81, 83, 84), values -> List(1212.0, 728.0, 607.0, 1.0, 24.0, 1.0, 17.0, 1.0, 0.17, 65.0, 684.0, 0.18, 1.0, 1.0, 1.0, 1.0, 1.0, 24.0, 1.0, 17.0, 0.17, 1.0, 0.18, 0.0034340306679597715, 0.004106753071779026, 10159.0, 51.0, 22000.0, 16000.0, -50.0, -100.0, 10170.035815268615, 9881.0, 936.1568998109641, 0.071, 1.0, 72.0, 61.0, 1.0, 488.0, 364.77, 244.85, 344.29, 16.705128205128204, 0.16533909524420967, 11.721165724304962, 8.0, 9.0, 0.13805686483199026, 10.893720541280219, 226.0, 230.0, 0.45454545454545453, 3391.0, 5258.0))",0.0,0.0,0.0,0.0,0.0
1.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 22, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 50, 51, 52, 54, 55, 56, 57, 58, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(1999.0, 146.0, 1135.0, 1.0, 24.0, 1.0, 18.0, 1.0, 0.17, 241.0, 481.0, 0.14, 1.0, 1.0, 1.0, 1.0, 24.0, 1.0, 18.0, 0.17, 1.0, 0.14, 0.00930196273177531, 0.008631014346010992, 10173.0, 36.0, 7620.0, 16093.0, -122.0, 10173.0, 10116.0, 5791.0, 0.054, 1.5, 4.0, 28.0, 29.0, 1.0, 397.0, 329.88, 335.07, 373.94, 9.17117117117117, 0.20807924111123, 11.206206982632574, 3.0, 10.0, 0.18663221228016044, 11.848426411601357, 65.0, 188.0, 0.1391509433962264, 9.617924528301886, 2270.0, 3279.0))",0.0,0.0,0.0,0.0,0.0
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(364.0, 36.0, 76.0, 1.0, 103.0, 1.0, 106.0, 1.0, 0.36, 324.0, 38.0, 0.31, 1.0, 1.0, 1.0, 103.0, 1.0, 106.0, 0.36, 1.0, 0.31, 0.008701554732167022, 0.009101861508585643, 10026.0, 36.0, 366.0, 12875.0, 256.0, 228.0, 10027.0, 10016.0, 244.0, 0.273, 1.333, 3.0, 31.0, 26.0, 1.0, 939.0, 228.69, 282.13, 181.05, 25.75811209439528, 0.16280382878038288, 12.261440632394063, 14.0, 11.0, 0.14269574059428314, 10.450193951719204, 230.0, 268.0, 0.12149532710280374, 5.841121495327103, 2734.0, 4785.0))",1.0,1.0,1.0,1.0,1.0
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(587.0, 146.0, 599.0, 1.0, 21.0, 1.0, 23.0, 1.0, 0.24, 300.0, 220.0, 0.4, 1.0, 1.0, 1.0, 21.0, 1.0, 23.0, 0.24, 1.0, 0.4, 0.00948494656296053, 0.008076466597018312, 10148.0, 7620.0, 16093.0, 228.0, 206.0, 10149.0, 10093.0, 4267.0, 0.123, 0.2, 51.6, 5.0, 26.0, 29.0, 1.0, 437.0, 219.98, 302.39, 260.44, 62.15613382899628, 0.2398833534977601, 14.220697917069211, 5.0, 14.0, 0.22719161699381146, 14.414062796613386, 53.0, 213.0, 0.22287968441814596, 13.303747534516765, 1434.0, 3169.0))",0.0,0.0,0.0,0.0,0.0
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 51, 52, 53, 54, 55, 56, 57, 58, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(479.0, 650.0, 96.0, 1.0, 95.0, 1.0, 89.0, 1.0, 0.36, 59.0, 376.0, 0.25, 1.0, 1.0, 1.0, 1.0, 95.0, 1.0, 89.0, 0.36, 1.0, 0.25, 0.0034362139165875897, 0.003369378573457475, 10162.0, 22000.0, 16093.0, 194.0, 172.0, 10169.0, 9940.0, 2814.7330729166665, 0.084, 0.333, 3.0, 67.0, 73.0, 1.0, 436.0, 255.22, 317.53, 347.05, 22.545454545454547, 0.20844771293625383, 11.74154436237994, 3.0, 14.0, 0.16388358778625955, 12.013120229007633, 172.0, 218.0, 0.14427860696517414, 9.124378109452737, 2354.0, 3841.0))",0.0,0.0,0.0,0.0,0.0
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(1083.0, 650.0, 5431.0, 1.0, 20.0, 1.0, 16.0, 1.0, 0.22, 53.0, 653.0, 0.23, 1.0, 1.0, 1.0, 1.0, 20.0, 1.0, 16.0, 0.22, 1.0, 0.23, 0.0034362139165875897, 0.003369378573457475, 10220.0, 22000.0, 16093.0, 161.0, 133.0, 10227.0, 9996.0, 7620.0, 0.056, 2.0, 67.0, 73.0, 1.0, 510.0, 979.97, 296.98, 290.42, 14.5, 0.2020475085979365, 19.965618918126317, 1.0, 5.0, 2.0, 0.16388358778625955, 12.013120229007633, 172.0, 218.0, 0.2, 7.672727272727273, 1316.0, 4339.0))",0.0,0.0,0.0,0.0,0.0
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(1888.0, 26.0, 4227.0, 1.0, 59.0, 1.0, 73.0, 1.0, 0.1, 235.0, 314.0, 0.47, 1.0, 1.0, 1.0, 1.0, 59.0, 1.0, 73.0, 0.1, 1.0, 0.47, 0.008255730076970672, 0.008484659861292751, 10198.798345398138, 853.0, 14484.0, 178.0, 167.0, 10166.0, 10162.0, 213.0, 0.345, 1.0, 33.0, 27.0, 1.0, 441.0, 490.7, 266.75, 277.57, 32.86695278969957, 0.13976920811430713, 10.67751162437862, 16.0, 16.0, 0.17988804276564627, 12.854909476019971, 130.0, 203.0, 0.04411764705882353, 1.2058823529411764, 3282.0, 5227.0))",0.0,0.0,0.0,0.0,0.0
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 22, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(192.0, 236.0, 1026.0, 1.0, 25.0, 1.0, 15.0, 1.0, 0.1, 22.0, 1060.0, 0.16, 1.0, 1.0, 1.0, 1.0, 25.0, 1.0, 15.0, 0.1, 1.0, 0.16, 0.0016240220751936492, 0.001479771782829008, 10155.0, 15.0, 1250.0, 16093.0, 67.0, 67.0, 10156.0, 10073.0, 122.0, 0.112, 6.0, 144.0, 155.0, 1.0, 963.0, 526.58, 412.0, 514.11, 8.894736842105264, 0.12154643832964007, 7.731726942191433, 16.0, 16.0, 0.13340623291416073, 11.787862219792236, 232.0, 190.0, 0.08925869894099848, 9.703479576399396, 3069.0, 3233.0))",0.0,0.0,0.0,0.0,0.0
0.0,"Map(vectorType -> sparse, length -> 85, indices -> List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 22, 25, 26, 27, 28, 29, 30, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 78, 79, 80, 81, 83, 84), values -> List(390.0, 896.0, 435.0, 1.0, 76.0, 1.0, 65.0, 1.0, 0.22, 132.0, 189.0, 0.19, 1.0, 1.0, 1.0, 1.0, 76.0, 1.0, 65.0, 0.22, 1.0, 0.19, 0.005605913353046195, 0.006235004123687768, 10314.0, 15.0, 22000.0, 14484.0, -94.0, -128.0, 10305.0, 9981.0, 2275.280155642023, 0.118, 1.0, 45.0, 44.0, 1.0, 527.0, 263.66, 288.42, 242.9, 18.585365853658537, 0.15813352609157041, 14.09230863874748, 15.0, 11.0, 0.17840490797546013, 14.714437627811861, 112.0, 152.0, 0.0847457627118644, 9.0, 3241.0, 3908.0))",0.0,0.0,0.0,0.0,0.0


In [0]:
# Per-fold vote weights; they are summed left to right in this order.
vote_weights = [
    ("prediction_1", 0.15),
    ("prediction_2", 0.15),
    ("prediction_3", 0.2),
    ("prediction_4", 0.3),
    ("prediction_5", 0.2),
]

first_col, first_weight = vote_weights[0]
weighted_vote = lit(first_weight)*col(first_col)
for vote_col, weight in vote_weights[1:]:
    weighted_vote = weighted_vote + lit(weight)*col(vote_col)

all_preds_f = all_preds.withColumn("weighted_pred", weighted_vote)

# The ensemble predicts 1.0 when the weighted vote reaches 0.5.
majority = (col("weighted_pred") >= lit(0.5)).cast('double')
all_preds_f = all_preds_f.withColumn("prediction", majority).cache()

test_df = all_preds_f.select('prediction','label').cache()

In [0]:
# Persist the weighted-vote predictions, replacing any previous output at this path.
weighted_preds_path = f"{blob_url}/xgboost_all_folds_weighted_preds_4_10"
all_preds_f.write.mode("overwrite").parquet(weighted_preds_path)
# To reload a saved run instead (note: this path is the `4_9` output, not `4_10`):
# all_preds_f = spark.read.parquet(f'{blob_url}/xgboost_all_folds_weighted_preds_4_9')

In [0]:
# Positive-class F1 score of each fold's re-labeled predictions.
relabeled_folds = (preds_1_reinf, preds_2_reinf, preds_3_reinf, preds_4_reinf, preds_5_reinf)
for fold, fold_preds in enumerate(relabeled_folds, start=1):
    test_df = fold_preds.select('prediction','label')
    test_metrics = MulticlassMetrics(test_df.rdd)
    f1_score = test_metrics.fMeasure(1.0,1.0)

    print(f"Fold {fold} test set (2019) f1 score:", f1_score)


# Positive-class F1 score of the weighted vote across all folds.
test_df = all_preds_f.select('prediction','label')

test_metrics = MulticlassMetrics(test_df.rdd)
f1_score = test_metrics.fMeasure(1.0,1.0)

print("All Folds weighted vote test set (2019) f1 score:", f1_score)