In [18]:
from collections import namedtuple

import pandas as pd
from pyspark.sql import SparkSession
import pyspark.sql.functions as spf
from pyspark.sql.types import *

In [2]:
# Spark SQL entry point; getOrCreate() returns an already-running session if there is one.
spark = (
    SparkSession.builder
    .appName("anokhin")
    .getOrCreate()
)

In [3]:
# Load both weekly dumps at once: the "{1,2}" glob alternation matches week1/ and week2/.
data = spark.read.format("json").load("/user/anokhin/week{1,2}/")
data.printSchema()

root
 |-- experiments: struct (nullable = true)
 |    |-- AA: string (nullable = true)
 |    |-- PERSONALIZED: string (nullable = true)
 |-- latency: double (nullable = true)
 |-- message: string (nullable = true)
 |-- recommendation: long (nullable = true)
 |-- time: double (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- track: long (nullable = true)
 |-- user: long (nullable = true)



In [13]:
# How many events of each message type there are (the pair builder below treats "last" as a session end).
message_counts = data.groupBy("message").count()
message_counts.show()

+-------+------+
|message| count|
+-------+------+
|   next|916001|
|   last|150001|
+-------+------+



## Build per-user (start, track, time) pairs and save them for Torch

In [38]:
# One output record: the first track of a session, a later track of that session,
# and that later record's `time` value.
Pair = namedtuple("Pair", ("start", "track", "time"))

# Return type for the collect_pairs UDF: array<struct<start: bigint, track: bigint, time: float>>.
# NOTE(review): `time` is a double in the source data and is narrowed to 32-bit float here —
# presumably intentional for Torch (float32); confirm.
schema = ArrayType(
    StructType()
    .add("start", LongType(), False)
    .add("track", LongType(), False)
    .add("time", FloatType(), False)
)

def collect_pairs(tracks):
    """Turn one user's ordered event history into (start, track, time) pairs.

    A session begins at the first record and at every record that directly
    follows a record whose ``message`` is ``"last"``. Every subsequent record
    in the same session yields one ``Pair`` linking the session's first track
    (``start``) to that record's ``track`` and ``time``.

    Args:
        tracks: sequence of records indexable by ``"track"``, ``"time"`` and
            ``"message"`` (Spark Row structs when called through the UDF),
            expected to be in chronological order.

    Returns:
        list of ``Pair``; empty when ``tracks`` is empty/None or has only one record.
    """
    if not tracks:
        # Guard: indexing tracks[0] below would raise IndexError on an empty history.
        return []

    pairs = []
    start_track = tracks[0]["track"]
    # Walk consecutive (previous, current) records; zip stops at the shorter input.
    for prev, current in zip(tracks, tracks[1:]):
        if prev["message"] == "last":
            # prev closed a session, so current opens the next one and is not paired.
            start_track = current["track"]
        else:
            pairs.append(Pair(start=start_track, track=current["track"], time=current["time"]))
    return pairs
            
collect_pairs_udf = spf.udf(collect_pairs, schema)

# Per user, gather all events as structs and sort the array. Structs compare field by
# field from the left, so the history is ordered by timestamp first.
user_history = data.groupBy("user").agg(
    spf.sort_array(
        spf.collect_list(
            spf.struct(
                spf.col("timestamp"),
                spf.col("track"),
                spf.col("time"),
                spf.col("message"),
            )
        )
    ).alias("history")
)

# One row per (user, pair), with the pair struct flattened into start/track/time columns.
user_pairs = (
    user_history
    .select("user", spf.explode(collect_pairs_udf("history")).alias("pair"))
    .select("user", "pair.*")
)

user_pairs.toPandas().to_pickle("pairs.pkl")

In [39]:
# Spot-check the file written above (the pickle is produced by this notebook, so loading it is safe).
saved_pairs = pd.read_pickle("pairs.pkl")
saved_pairs.head()

Unnamed: 0,user,start,track,time
0,26,11761,7287,0.78
1,26,11761,22875,0.24
2,26,11761,27161,0.5
3,26,11761,5475,0.0
4,26,11761,21522,0.87
