In [57]:
# Debugging aid: list the names currently defined in the interactive namespace.
dir()

['In',
 'Out',
 'SQLContext',
 'SparkConf',
 'SparkContext',
 'SparkSession',
 '_',
 '_19',
 '_23',
 '_28',
 '_33',
 '_36',
 '_39',
 '_4',
 '_44',
 '_47',
 '_50',
 '_56',
 '__',
 '___',
 '__builtin__',
 '__builtins__',
 '__doc__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 '_dh',
 '_i',
 '_i1',
 '_i10',
 '_i11',
 '_i12',
 '_i13',
 '_i14',
 '_i15',
 '_i16',
 '_i17',
 '_i18',
 '_i19',
 '_i2',
 '_i20',
 '_i21',
 '_i22',
 '_i23',
 '_i24',
 '_i25',
 '_i26',
 '_i27',
 '_i28',
 '_i29',
 '_i3',
 '_i30',
 '_i31',
 '_i32',
 '_i33',
 '_i34',
 '_i35',
 '_i36',
 '_i37',
 '_i38',
 '_i39',
 '_i4',
 '_i40',
 '_i41',
 '_i42',
 '_i43',
 '_i44',
 '_i45',
 '_i46',
 '_i47',
 '_i48',
 '_i49',
 '_i5',
 '_i50',
 '_i51',
 '_i52',
 '_i53',
 '_i54',
 '_i55',
 '_i56',
 '_i57',
 '_i6',
 '_i7',
 '_i8',
 '_i9',
 '_ih',
 '_ii',
 '_iii',
 '_oh',
 '_pythonstartup',
 'atexit',
 'build_correlations',
 'combine_skus',
 'conf',
 'defaultdict',
 'exit',
 'get_ipython',
 'math',
 'os',
 'platform',
 'py4j',
 '

In [92]:
import time
from collections import defaultdict
from pyspark.sql import functions as sfunc
from pyspark.sql import types as stypes
import math
import sys

In [2]:
# Explicit schema for the training CSV: two string columns (fv, sku) and one
# float column (score). All three fields keep the default nullable setting.
schema = stypes.StructType([
    stypes.StructField("fv", stypes.StringType()),
    stypes.StructField("sku", stypes.StringType()),
    stypes.StructField("score", stypes.FloatType()),
])

In [3]:
# Read every gzipped file matching the glob as CSV, treating the first line as
# a header and applying the explicit schema instead of inferring types.
train_df = spark.read.csv(
    path='gs://lbanor/pyspark/train_query*.gz',
    schema=schema,
    header=True,
)

In [4]:
# Fetch the first three rows to the driver as a quick look at the parsed data.
tt = train_df.head(3)

[Row(fv='3383270414872112082', sku='MO578SHF77RTI', score=0.5),
 Row(fv='7143168022217708588', sku='DA923SHF54UJP', score=0.5),
 Row(fv='8844960186636261737', sku='LU621ACM67NYU', score=0.5)]

In [96]:
# Pull every row of train_df onto the driver as a Python list of Row objects.
# NOTE(review): this materialises the whole dataset in driver memory; it looks
# like an exploratory step only - confirm it is not needed downstream.
tt = train_df.collect()

In [98]:
# Inspect the first collected row.
tt[0]

Row(fv='3383270414872112082', sku='MO578SHF77RTI', score=0.5)

In [97]:
# Size in bytes of the list object itself. sys.getsizeof is shallow: it does
# not include the memory of the Row objects the list refers to.
sys.getsizeof(tt)

42915448

In [20]:
# Expose train_df to Spark SQL under the view name 'test1'.
train_df.createOrReplaceTempView('test1')

In [10]:
def build_correlations(row):
    """Pair every item of ``row`` with every item of ``row`` (itself included).

    Args:
        row: iterable of objects exposing ``sku`` and ``score`` attributes.

    Returns:
        A list with one dict per item, in input order. Each dict holds the
        item's ``sku`` and a ``corr`` list containing, for every item of
        ``row``, that item's ``sku`` and the product of the two scores.
    """
    correlations = []
    for anchor in row:
        pairs = []
        for other in row:
            # The anchor is paired with itself too, yielding its squared score.
            pairs.append({"sku": other.sku, "score": anchor.score * other.score})
        correlations.append({"sku": anchor.sku, "corr": pairs})
    return correlations
# Register build_correlations for Spark SQL as BUILD_CORRELATIONS.
# Declared return type:
#   array<struct<sku: string, corr: array<struct<sku: string, score: float>>>>
# with every struct field marked non-nullable.
sqlContext.udf.register(
    "BUILD_CORRELATIONS",
    build_correlations,
    stypes.ArrayType(
        stypes.StructType(fields=[
            stypes.StructField("sku", stypes.StringType(), False),
            stypes.StructField(
                "corr",
                stypes.ArrayType(
                    stypes.StructType(fields=[
                        stypes.StructField("sku", stypes.StringType(), False),
                        stypes.StructField("score", stypes.FloatType(), False),
                    ])
                ),
                False,
            ),
        ])
    ),
)

In [51]:
def combine_skus(ref_sku, row):
    """Sum the scores per SKU across nested lists and scale them by a reference norm.

    Args:
        ref_sku: the SKU whose accumulated score defines the norm.
        row: iterable of iterables; each inner element exposes ``sku`` and
            ``score`` attributes.

    Returns:
        A dict with:
          * ``norm``: square root of the summed scores of the entries whose
            ``sku`` equals ``ref_sku`` (0.0 when there are none).
          * ``corr``: one ``{"sku", "similarity"}`` dict per distinct SKU, in
            first-seen order, where ``similarity`` is that SKU's summed score
            divided by ``norm``. When ``norm`` is 0.0 every similarity is
            reported as 0.0 instead of raising ``ZeroDivisionError``.

    Raises:
        ValueError: if the summed reference score is negative (from
            ``math.sqrt``).
    """
    totals = defaultdict(float)
    ref_total = 0.0
    for inner_row in row:
        for entry in inner_row:
            totals[entry.sku] += entry.score
            if entry.sku == ref_sku:
                ref_total += entry.score
    ref_norm = math.sqrt(ref_total)

    if ref_norm == 0.0:
        # A zero norm would make the division below blow up and abort the whole
        # Spark job; emit neutral similarities for this SKU instead.
        corr = [{"sku": sku, "similarity": 0.0} for sku in totals]
    else:
        corr = [{"sku": sku, "similarity": total / ref_norm} for sku, total in totals.items()]
    return {"norm": ref_norm, "corr": corr}
# Register combine_skus for Spark SQL as COMBINE_SKUS.
# Declared return type:
#   struct<norm: float, corr: array<struct<sku: string, similarity: float>>>
# "norm" and the inner fields are non-nullable; "corr" keeps the default nullability.
sqlContext.udf.register(
    "COMBINE_SKUS",
    combine_skus,
    stypes.StructType(fields=[
        stypes.StructField("norm", stypes.FloatType(), False),
        stypes.StructField(
            "corr",
            stypes.ArrayType(
                stypes.StructType(fields=[
                    stypes.StructField("sku", stypes.StringType(), False),
                    stypes.StructField("similarity", stypes.FloatType(), False),
                ])
            ),
        ),
    ]),
)

In [85]:
# Spark SQL pipeline, read from the innermost SELECT outwards:
#   1. Group test1 by fv and collect each group's (sku, score) structs,
#      keeping only groups with more than 1 and fewer than 200 items.
#   2. Apply BUILD_CORRELATIONS to each group and explode the result, giving
#      one row per (sku, corr list) pair.
#   3. Group those rows by sku and feed all collected corr lists to COMBINE_SKUS.
query = """
SELECT
  data.sku sku,
  COMBINE_SKUS(data.sku, COLLECT_LIST(data.corr)) data
FROM(
  SELECT
  EXPLODE(BUILD_CORRELATIONS(data)) data
  FROM(
    SELECT
      fv,
      COLLECT_LIST(STRUCT(sku, score)) data
    FROM test1
    GROUP BY
      fv
    HAVING SIZE(data) > 1 AND SIZE(data) < 200
  )
)
GROUP BY
  data.sku
"""

In [81]:
# Build the DataFrame for the correlation query.
# NOTE(review): r1 is not cached here and is read by two later queries through
# the 'test2' view, so it is presumably recomputed each time - confirm whether
# persisting it would be worthwhile.
r1 = spark.sql(query)

In [82]:
# Expose the per-SKU result to Spark SQL under the view name 'test2'.
r1.createOrReplaceTempView('test2')

In [69]:
# Project each SKU of test2 together with the norm stored in its data struct.
query_extract_norms = """
SELECT
  sku,
  data.norm norm
FROM test2
"""

In [84]:
# Collect the (sku, norm) rows onto the driver as a plain dict keyed by sku,
# and print the elapsed wall-clock time in seconds.
t0 = time.time()
r2 = {e.sku: e.norm for e in spark.sql(query_extract_norms).collect()}
print(time.time() - t0)

1481.6083595752716


In [86]:
# Broadcast the sku -> norm dict so UDF code can read it via r2_broad.value.
r2_broad = sc.broadcast(r2)

In [87]:
def normalize_corrs(corrs):
    """Divide each entry's similarity by the broadcast norm of that entry's SKU.

    Args:
        corrs: iterable of objects exposing ``sku`` and ``similarity``
            attributes.

    Returns:
        A list of ``{"sku", "similarity"}`` dicts in input order, where
        ``similarity`` is the input similarity divided by
        ``r2_broad.value[sku]``. When that norm is zero (or None) the
        similarity is reported as 0.0 instead of raising.

    Raises:
        KeyError: if an entry's SKU is absent from the broadcast mapping.
    """
    # The broadcast mapping is the same for every entry; look it up once.
    norms = r2_broad.value
    normalized = []
    for entry in corrs:
        norm = norms[entry.sku]
        # Guard the division: a zero norm would otherwise raise
        # ZeroDivisionError and fail the whole Spark task.
        similarity = entry.similarity / norm if norm else 0.0
        normalized.append({"sku": entry.sku, "similarity": similarity})
    return normalized
# Register normalize_corrs for Spark SQL as NORMALIZE_CORRS.
# Declared return type: array<struct<sku: string, similarity: float>>,
# both struct fields non-nullable.
sqlContext.udf.register(
    "NORMALIZE_CORRS",
    normalize_corrs,
    stypes.ArrayType(
        stypes.StructType(fields=[
            stypes.StructField("sku", stypes.StringType(), False),
            stypes.StructField("similarity", stypes.FloatType(), False),
        ])
    ),
)

In [88]:
# For every SKU of test2, pass its corr list through NORMALIZE_CORRS.
final_query = """
select
sku,
NORMALIZE_CORRS(data.corr) corr
FROM test2
"""

In [90]:
# Build the DataFrame for the normalisation query.
final = spark.sql(final_query)

In [91]:
# Time how long it takes to fetch the first row of the final result and print
# the elapsed wall-clock seconds.
# NOTE(review): head(1) only asks for one row, so this is presumably not the
# cost of computing the full result - confirm before quoting the number.
t0 = time.time()
final.head(1)
print(time.time() - t0)

381.65184354782104
