In [0]:
import random
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
import matplotlib.pyplot as plt
from pyspark.sql import functions as F
from pyspark.sql.functions import col
import mlflow
import mlflow.spark
import pyspark
import time
import pandas as pd



In [0]:
# --- Run configuration ------------------------------------------------------
# Pieces of the input table name.
model_name = 'campaign_clustering_turn_100'
version = '2_civ7'

# Table the clustering is fitted on (name is assembled from the pieces above).
source_table = f'sandbox.cg_inverness.{model_name}_data_v2_features_v2_PCA_v{version}'

# MLflow experiment path and the prefix used when naming parent runs.
experiment_name = '/Users/jak.marshall@2k.com/Kmeans_turn_100_PCA_v2_2'
run_title = 'turn_100_clustering_v2'

# Sweep defaults.
# NOTE: the training cell assigns k_max and random_seed_count again before
# using them, so the values below are only defaults.
random_seed_count = 3
k_max = 10

# Echo the resolved names so the run is self-describing in the cell output.
print(
    f'source_table:      {source_table}',
    f'experiment_name:   {experiment_name}',
    sep='\n',
)

source_table:      sandbox.cg_inverness.campaign_clustering_turn_100_data_v2_features_v2_PCA_v2_civ7
experiment_name:   /Users/jak.marshall@2k.com/Kmeans_turn_100_PCA_v2_2


In [0]:
# Read the table named by `source_table` and preview five rows as a pandas frame.
df = spark.read.table(source_table)
df.limit(5).toPandas()

Unnamed: 0,output,CAMPAIGN_ID
0,"[-4.628801639618975, -4.226169662353111, 0.961...",414a180b6904bbcf54c54a2bdccbf6ad
1,"[-4.6850543781962415, -4.499950198787214, 1.31...",11f0d5e476e68256f43c2621e400978f
2,"[-6.943745629867604, -7.449101701768911, 0.072...",4ed7e07c53537bde06a7ebcd2c172191
3,"[-5.164776566411206, -5.2226536737636975, 1.24...",45c3dedc39afb9eda415e3643b28a5b5
4,"[-5.152690825407146, -5.537162389281079, 1.156...",4353676d55133ebc418ae4f9c761e8ad


In [0]:
# Number of rows in the feature table (count() is a Spark action, so this runs a job).
df.count()

529282

In [0]:
def get_dtype(df,colname):
    """Return the dtype string of the column called ``colname`` in ``df``.

    ``df.dtypes`` is scanned as ``(name, dtype)`` pairs and the dtype of the
    first pair whose name equals ``colname`` is returned.

    Args:
        df: any object exposing ``dtypes`` as an iterable of ``(name, dtype)``
            pairs.
        colname: column name to look up (exact match).

    Returns:
        The dtype of the first matching column.

    Raises:
        IndexError: if no column is called ``colname``. ``IndexError`` is kept
            for backward compatibility; the message names the missing column.
    """
    # Stop at the first match instead of materialising every match in a list.
    for name, dtype in df.dtypes:
        if name == colname:
            return dtype
    raise IndexError(f'column {colname!r} not found in df.dtypes')

In [0]:
# Activate the experiment for this notebook; set_experiment hands back the
# Experiment object, so no separate lookup by name is needed.
experiment = mlflow.set_experiment(experiment_name)
print(experiment.experiment_id)

291042099400795


In [0]:
# Not executed: the SQL below is wrapped in a string literal, so this cell only
# evaluates to a string. It records the statement that creates the volume
# sandbox.cg_inverness.my_volume2, which the dfs_tmpdir path points into.
'''
%sql
CREATE VOLUME sandbox.cg_inverness.my_volume2;
'''

In [0]:
# Scratch directory handed to mlflow.spark.log_model (dfs_tmpdir=...) in the training cell.
dfs_tmpdir = "/Volumes/sandbox/cg_inverness/my_volume2/mlflow_tmp"

In [0]:
# Settings
# NOTE: k_max and random_seed_count are re-assigned here and therefore override
# the values set in the configuration cell for this sweep.
k_max=6
k_range = range(2, k_max+1)

random_seed_count=2
# Draw the KMeans seeds from a dedicated, seeded generator so that a fresh
# "Restart & Run All" reproduces the same sweep (change the generator seed to
# draw a different set). sample() also guarantees the seeds are distinct, so no
# (seed, k) combination is trained twice; the range matches randint(1000, 9999).
seed_rng = random.Random(0)
seeds = seed_rng.sample(range(1000, 10000), random_seed_count)

subset = df

# Loop-invariant pieces, built once instead of once per seed / per k.
train_df = subset[['output','CAMPAIGN_ID']]
evaluator = ClusteringEvaluator(predictionCol='prediction', featuresCol='output', metricName='silhouette')

# One record per (seed, k) over the WHOLE sweep. The per-seed lists below are
# reset for every seed, so without this the earlier seeds' scores are lost.
# Can be turned into a table with pd.DataFrame(sweep_results).
sweep_results = []

for random_seed in seeds:

    # Run
    run_name = f'{run_title}_{random_seed}'

    # Vars (per seed; after the loop they hold the values of the last seed)
    wssse_scores = []
    silhouette_scores = []
    child_run_ids = []
    run_ids = []

    tags = {"mlflow.runName": run_name}

    # Training Runs: one parent run per seed
    with mlflow.start_run(experiment_id=experiment.experiment_id, tags=tags) as run:
        run_id = run.info.run_id
        run_ids.append(run_id)

        # Child Runs: one nested run per k
        for k in k_range:
            with mlflow.start_run(nested=True) as child_run:
                child_run_id = child_run.info.run_id
                child_run_ids.append(child_run_id)

                # Train
                kmeans = KMeans().setK(k).setSeed(random_seed).setFeaturesCol("output")
                model = kmeans.fit(train_df)

                # Predict (on the same rows the model was fitted on)
                predictions = model.transform(train_df)

                # Eval
                silhouette = evaluator.evaluate(predictions)
                silhouette_scores.append(silhouette)

                wssse = model.summary.trainingCost
                wssse_scores.append(wssse)

                sweep_results.append({
                    "seed": random_seed,
                    "k": k,
                    "run_id": child_run_id,
                    "wssse": wssse,
                    "silhouette": silhouette,
                })

                # Log to MLflow
                mlflow.spark.log_model(
                    model,
                    "kmeans_model",
                    pip_requirements=['pyspark=='+pyspark.__version__],
                    input_example=None,
                    signature=None,
                    dfs_tmpdir=dfs_tmpdir
                )
                mlflow.log_param("k", k)
                mlflow.log_param("seed", random_seed)
                mlflow.set_tag("model_type", "kmeans")
                mlflow.set_tag("run_name", run_name) # Experiment visualization workaround (group by tag run_name)
                mlflow.log_metric("k", k) # Experiment visualization workaround (x-axis metric k)
                mlflow.log_metric("silhouette_score", silhouette)
                mlflow.log_metric("wsse_score", wssse)

                # Report progress in the cell output
                print(f'random_seed: {random_seed}')
                print(f"k:           {k}")
                print(f"id:          {child_run_id}")
                print(f"wssse:       {wssse}")
                print(f"silhouette:  {silhouette}")




random_seed: 3323
k:           2
id:          9818596ebbfb4866aa1b2c0e9dca1185
wssse:       9374482.868516607
silhouette:  0.5070642968052774




random_seed: 3323
k:           3
id:          8150c6ad4f8646d5922e8229521005d2
wssse:       8258285.131011623
silhouette:  0.34300663461986264




random_seed: 3323
k:           4
id:          ea9b4edd29b84630ab25b9f4d81b7583
wssse:       7830418.4924876345
silhouette:  0.21889368421688613




random_seed: 3323
k:           5
id:          9f12c896bcd84195bf6604de77bdf494
wssse:       7387479.516236255
silhouette:  0.16610670409505357




random_seed: 3323
k:           6
id:          159711db1e9d4d299eaa9f6290bb1d89
wssse:       7124307.758581748
silhouette:  0.14337796808878903




random_seed: 5953
k:           2
id:          d8960ced54434340821d7d1f8d564339
wssse:       9374475.231641702
silhouette:  0.5072290198809569




random_seed: 5953
k:           3
id:          04ed0d20df85460ba1d8a767f2882d19
wssse:       8258285.345539516
silhouette:  0.3429692183881445




random_seed: 5953
k:           4
id:          405c3ab808154478bceb1e735a581550
wssse:       7825361.276919915
silhouette:  0.18445764141851634




random_seed: 5953
k:           5
id:          093f95b5638a435caf005cc547dccf63
wssse:       7381507.61448092
silhouette:  0.1754042758893006




random_seed: 5953
k:           6
id:          27235d684e2e40faaf36cd320967a981
wssse:       7060581.134973515
silhouette:  0.1831713705304888


In [0]:
# Per-campaign table with the named columns, previewed as five rows.
# NOTE(review): presumably the data the PCA features were derived from — confirm.
original_data = spark.table('sandbox.cg_inverness.campaign_clustering_turn_100_data_v2_civ7')
original_data.limit(5).toPandas()

Unnamed: 0,CAMPAIGN_ID,scouts,settlers,all_other_units,max_gold_balance,max_influence_balance,max_settlementcap,max_science,max_production,max_influence,max_happiness,max_gold,max_food,max_culture,independent_dispersed,attacks_made,units_lost,enemies_defeated,became_suzerain,trade_routes,settlements_founded,saves,modded_games,used_advanced_general_options,maptype,gamedifficulty,display_service,buildversion,mapsize,gamespeed
0,414a180b6904bbcf54c54a2bdccbf6ad,2.0,6.0,6.0,812.0,448.0,6.0,23.0,57.0,12.0,48.0,56.0,79.0,44.0,0.0,0.0,0.0,0.0,0.0,2.0,8.0,1.0,1.0,1.0,other,PRINCE,Steam,1.3.1,STANDARD,STANDARD
1,11f0d5e476e68256f43c2621e400978f,0.0,2.0,12.0,722.0,86.0,5.0,52.0,80.0,12.0,54.0,85.0,107.0,16.0,0.0,48.0,2.0,12.0,0.0,2.0,8.0,1.0,0.0,1.0,continents-voronoi.js,PRINCE,SEN,1.3.1,SMALL,STANDARD
2,4ed7e07c53537bde06a7ebcd2c172191,3.0,2.0,8.0,614.0,153.0,5.0,51.0,101.0,26.0,50.0,197.0,137.0,79.0,0.0,1.0,0.0,0.0,5.0,5.0,3.0,1.0,1.0,1.0,continents-voronoi.js,KING,Steam,1.3.1,SMALL,ONLINE
3,45c3dedc39afb9eda415e3643b28a5b5,1.0,3.0,12.0,623.0,513.0,6.0,47.0,78.0,22.0,42.0,42.0,94.0,58.0,1.0,12.0,2.0,5.0,3.0,2.0,3.0,1.0,0.0,1.0,continents-voronoi.js,DEITY,Steam,1.3.1,HUGE,STANDARD
4,4353676d55133ebc418ae4f9c761e8ad,3.0,3.0,4.0,451.0,390.0,7.0,35.0,56.0,25.0,53.0,46.0,186.0,35.0,1.0,8.0,1.0,3.0,1.0,1.0,4.0,1.0,0.0,1.0,continents-voronoi.js,PRINCE,Steam,1.3.1,HUGE,STANDARD


In [0]:
import os

# Use the same scratch directory for loading models as was passed to
# mlflow.spark.log_model (dfs_tmpdir), instead of a second hard-coded path that
# pointed at a different volume (my_volume vs. my_volume2).
os.environ["MLFLOW_DFS_TMP"] = dfs_tmpdir

In [0]:
# Reload the model logged by one child run of the sweep above; run
# 8150c6ad... is the seed 3323 / k=3 fit shown in the training output.
logged_model = 'runs:/8150c6ad4f8646d5922e8229521005d2/kmeans_model'
loaded_model = mlflow.spark.load_model(logged_model)

# Assign a cluster to every row of the feature table with the reloaded model.
first_split_predictions = loaded_model.transform(df.select('output', 'CAMPAIGN_ID'))

In [0]:
# Cluster sizes: number of rows assigned to each prediction label.
(
    first_split_predictions
    .groupBy('prediction')
    .count()
    .show()
)

+----------+------+
|prediction| count|
+----------+------+
|         1|336110|
|         2| 60697|
|         0|132475|
+----------+------+



In [0]:
# WARNING: this import rebinds the Python built-ins `min` and `max` to the Spark
# column functions for the rest of the notebook. It is kept because other cells
# may rely on these bare names.
from pyspark.sql.functions import median,mode,mean,min,max,percentile_approx

# Join on the column NAME rather than on an equality expression: the result
# then carries a single CAMPAIGN_ID column instead of two identically named
# ones, so later references to CAMPAIGN_ID are not ambiguous.
# NOTE(review): the pivot counts printed further down are larger than the
# cluster sizes shown for first_split_predictions, which suggests CAMPAIGN_ID is
# not unique on at least one side and this join multiplies rows — verify.
kmeans2 = original_data.join(first_split_predictions, on='CAMPAIGN_ID', how='inner')

In [0]:
# Per-cluster median, mean and quartiles for every int/double column of the
# original data. Every column name is printed; only numeric ones get a table.
# The loop variable is deliberately not named `col`, so it does not overwrite
# the `col` function imported from pyspark.sql.functions.
numeric_type_names = ('int', 'double')
for column_name in original_data.columns:
    print(column_name)
    if kmeans2.schema[column_name].dataType.simpleString() in numeric_type_names:
        kmeans2.groupby('prediction').agg(
            F.median(column_name),
            F.mean(column_name),
            F.percentile_approx(column_name, .25),
            F.percentile_approx(column_name, .75),
        ).show()

CAMPAIGN_ID
scouts
+----------+--------------+------------------+--------------------------------------+--------------------------------------+
|prediction|median(scouts)|       avg(scouts)|percentile_approx(scouts, 0.25, 10000)|percentile_approx(scouts, 0.75, 10000)|
+----------+--------------+------------------+--------------------------------------+--------------------------------------+
|         1|           2.0|3.5812684562268258|                                   1.0|                                   6.0|
|         2|          11.5| 9.342123193262685|                                  11.5|                                  11.5|
|         0|           4.0| 5.623298551627473|                                   2.0|                                  11.5|
+----------+--------------+------------------+--------------------------------------+--------------------------------------+

settlers
+----------+----------------+------------------+----------------------------------------+-------

In [0]:
# For every non-numeric column (except the key) print the value-by-cluster
# counts and, below them, each value's share within every cluster.
# The loop variable is deliberately not named `col`, so it does not overwrite
# the `col` function imported from pyspark.sql.functions.
for column_name in original_data.columns:
    is_numeric = kmeans2.schema[column_name].dataType.simpleString() in ['int', 'double']
    # Exact comparison: `x not in ('CAMPAIGN_ID')` is a substring test against a
    # plain string (the parentheses do not make a tuple), which would also skip
    # any column whose name is a substring of 'CAMPAIGN_ID' (e.g. 'ID').
    if is_numeric or column_name == 'CAMPAIGN_ID':
        continue

    print(column_name)
    counts=kmeans2.groupby(column_name).pivot('prediction').count().toPandas()
    print(counts)

    # First column holds the category values; the remaining ones are per-cluster
    # counts, normalised here so that each cluster column sums to 1.
    cluster_counts = counts.iloc[:,1:]
    print(cluster_counts.div(cluster_counts.sum(axis=0),axis='columns'))

maptype
                      maptype       0       1         2
0          terra-incognita.js    6846   11430    141265
1       continents-voronoi.js   74187  206823    635506
2             pangaea-plus.js   13772   32079   2526429
3                  shuffle.js   13474   26730     65641
4                       other  138449  206320    644943
5          pangaea-voronoi.js    9210   16468  14327441
6   shattered-seas-voronoi.js   30821   43534    311273
7          continents-plus.js   26121   30673    228135
8               continents.js    7485   25007    106577
9                  fractal.js   15915   17287    132341
10             archipelago.js   16875   14195     80736
           0         1         2
0   0.019385  0.018127  0.007357
1   0.210069  0.328006  0.033099
2   0.038997  0.050875  0.131583
3   0.038153  0.042392  0.003419
4   0.392035  0.327208  0.033590
5   0.026079  0.026117  0.746210
6   0.087273  0.069042  0.016212
7   0.073965  0.048645  0.011882
8   0.021195  0.039659 