In [None]:
from collections import namedtuple

import pandas as pd
import numpy as np
import scipy.stats as ss

from pyspark.sql import SparkSession, Row
import pyspark.sql.functions as spf
from pyspark.sql.types import *

import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Name given to the Spark application when the session is built.
app_name = "svorobiev"
# Location handed to spark.read.json when the event logs are loaded.
hdfs_data_path = "/user/mob202273/my_remote_dir/"

In [None]:
# Get the Spark session for this notebook (created under app_name if none exists yet).
spark = SparkSession.builder.appName(app_name).getOrCreate()

In [None]:
# Load the JSON records found at hdfs_data_path; no schema is supplied,
# so Spark derives it from the data.
data = spark.read.json(hdfs_data_path)

# Print the resulting column names and types for reference.
data.printSchema()

In [None]:
# Number of records per distinct value of the "message" column.
data.groupBy("message").count().show()

## Visualize sessions

In [None]:
# Key inside the "experiments" column that holds each user's treatment.
experiment = "MY_VS_CONTEXTUAL"

# Plain-Python record for one finished session; its field order mirrors the
# struct declared below.
Session = namedtuple("Session", ["timestamp", "tracks", "time"])

# Return type of the sessionizing UDF: an array with one non-nullable
# (timestamp, tracks, time) struct per session.
session_fields = [
    StructField("timestamp", LongType(), False),
    StructField("tracks", LongType(), False),
    StructField("time", FloatType(), False),
]
schema = ArrayType(StructType(session_fields))

def sessionize(tracks):
    """Split one user's ordered track events into finished sessions.

    Parameters
    ----------
    tracks : iterable
        Events supporting ``event["timestamp"]``, ``event["message"]`` and
        ``event["time"]``. They are consumed in the order given, so the
        caller is responsible for passing them already sorted.

    Returns
    -------
    list of Session
        One entry per event whose message is ``"last"``: the timestamp of
        that closing event, the number of events since the previous session
        ended, and their summed ``time`` as a float. Events that follow the
        final ``"last"`` belong to an unfinished session and are not
        reported.
    """
    sessions = []
    session_tracks = 0
    # Accumulate in a float: the UDF schema declares "time" as FloatType and
    # PySpark does not coerce a Python int into a float field, so an
    # all-integer session would otherwise not come back as a valid value.
    session_time = 0.0
    for track in tracks:
        session_tracks += 1
        session_time += track["time"]
        if track["message"] == "last":
            sessions.append(Session(timestamp=track["timestamp"], tracks=session_tracks, time=session_time))
            # Start counting the next session from scratch.
            session_tracks = 0
            session_time = 0.0
    return sessions
            
# Expose sessionize to Spark with the array-of-structs return type.
sessionize_udf = spf.udf(sessionize, schema)

# One row per (user, treatment) carrying that user's events as a sorted array.
# timestamp is the first struct field, so sort_array orders events by it first.
ordered_user_tracks = (
    data
    .groupBy("user", spf.col("experiments." + experiment).alias("treatment"))
    .agg(
        spf.sort_array(
            spf.collect_list(spf.struct("timestamp", "message", "time")).alias("track")
        ).alias("tracks")
    )
)

# One row per finished session, with the session struct flattened into columns.
session_rows = (
    ordered_user_tracks
    .select("treatment", spf.explode(sessionize_udf(spf.col("tracks"))).alias("session"))
    .select("treatment", "session.*")
)

# Bring the sessions to the driver, indexed and ordered by session timestamp.
sessions = session_rows.toPandas().set_index("timestamp").sort_index()

In [None]:
# Tracks per session over time, one line per treatment.
figure, ax = plt.subplots(figsize=(15, 5))
for treatment, treatment_sessions in sessions.groupby("treatment"):
    ax.plot(treatment_sessions.index, treatment_sessions["tracks"], label=treatment)
ax.set_xlabel("timestamp")
ax.set_ylabel("tracks per session")
ax.set_title("Tracks per session by treatment")
# The label= passed to plot() is only rendered once a legend is requested.
ax.legend(title="treatment");

In [None]:
# Summed "time" per session over time, one line per treatment.
figure, ax = plt.subplots(figsize=(15, 5))
for treatment, treatment_sessions in sessions.groupby("treatment"):
    ax.plot(treatment_sessions.index, treatment_sessions["time"], label=treatment)
ax.set_xlabel("timestamp")
ax.set_ylabel("session time")
ax.set_title("Session time by treatment")
# The label= passed to plot() is only rendered once a legend is requested.
ax.legend(title="treatment");

## Analyze experiment

In [None]:
experiment = "MY_VS_CONTEXTUAL"

# when() yields the user only for "last" rows and null otherwise, and count()
# skips nulls, so this counts the "last" rows in each group.
finished_sessions = spf.count(spf.when(spf.col("message") == "last", spf.col("user")))
# Number of rows in the group with a non-null user.
requests = spf.count("user")

# One row per (user, treatment) with that user's per-session averages.
user_level_data = (
    data
    .groupBy(
        spf.col("user"),
        spf.col("experiments." + experiment).alias("treatment"),
    )
    .agg(
        finished_sessions.alias("sessions"),
        (requests / finished_sessions).alias("mean_session_length"),
        (spf.sum("time") / finished_sessions).alias("mean_session_time"),
        (spf.sum("latency") / requests).alias("mean_request_time"),
    )
)

# Every column except the grouping keys is a metric.
key_columns = ("user", "treatment")
metrics = [name for name in user_level_data.columns if name not in key_columns]

# For each metric, in order: its mean, its variance and its non-null count,
# aliased as mean_<metric>, var_<metric> and n_<metric>.
summaries = (("mean_", spf.avg), ("var_", spf.variance), ("n_", spf.count))

metric_stats = []
for metric in metrics:
    for prefix, aggregate in summaries:
        metric_stats.append(aggregate(metric).alias(prefix + metric))
    
# Collapse the per-user rows into one row of summary statistics per treatment
# and bring the resulting rows to the driver as a list.
per_treatment_stats = user_level_data.groupBy("treatment").agg(*metric_stats)
treatment_level_data = per_treatment_stats.collect()

In [None]:
def color(value):
    """Return the CSS text colour for a cell: red below zero, green otherwise."""
    if value < 0:
        return 'color:red;'
    return 'color:green;'

def background(value):
    """Return CSS for a flag cell: white text on green if truthy, on red if not."""
    if value:
        return 'color:white;background-color:green'
    return 'color:white;background-color:red'
        

# Columns of the report, in display order.
report_columns = [
    "treatment",
    "metric",
    "effect",
    "upper",
    "lower",
    "control_mean",
    "treatment_mean",
    "significant",
]

# NOTE(review): `effects` is not defined in any cell visible in this notebook;
# confirm that the cell building it exists, otherwise this fails on a fresh kernel.
effects_report = pd.DataFrame(effects)[report_columns].sort_values("effect", ascending=False)

# Largest effect first; numeric cells coloured by sign, the flag cell by truth value.
(
    effects_report.style
    .applymap(color, subset=["effect", "upper", "lower"])
    .applymap(background, subset=["significant"])
)