In [26]:
import findspark
findspark.init()

In [27]:
import pyspark
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

In [28]:
# Create the Spark session for this notebook, or reuse one that is already running.
spark = (
    SparkSession.builder
    .appName("concrec-recall")
    .getOrCreate()
)

In [29]:
# Load the ratings CSV (header row, inferred column types) and print its schema.
rating_df = spark.read.csv('../dataset/rating.csv', header=True, inferSchema=True)
rating_df.printSchema()
# Keep only ratings above 7; every later cell works on this filtered frame.
rating_df = rating_df.where('rating>7')

root
 |-- user_id: integer (nullable = true)
 |-- anime_id: integer (nullable = true)
 |-- rating: integer (nullable = true)



In [30]:
# One row per user: the ids (cast to string) of all anime left after the rating filter.
watch_seq_df = (
    rating_df
    .groupBy('user_id')
    .agg(collect_list(col('anime_id').cast('string')).alias('anime_ids'))
)

In [31]:
# Preview the per-user sequences (show() prints the first 20 rows).
watch_seq_df.show()

+-------+--------------------+
|user_id|           anime_ids|
+-------+--------------------+
|    148|[20, 81, 170, 263...|
|    463|[20, 24, 68, 102,...|
|    471| [1604, 6702, 10681]|
|    496|[227, 430, 481, 4...|
|    833|[1, 6, 19, 33, 45...|
|   1088|[226, 2923, 5262,...|
|   1238|[120, 121, 150, 2...|
|   1342|[6746, 6880, 9919...|
|   1580|[849, 853, 1535, ...|
|   1591|[48, 57, 66, 232,...|
|   1645|[20, 59, 101, 120...|
|   1829|[1, 226, 339, 467...|
|   1959|[19, 199, 226, 51...|
|   2122|[64, 65, 226, 153...|
|   2142|             [11061]|
|   2366|[33, 226, 356, 15...|
|   2659|[1, 30, 32, 47, 1...|
|   2866|[1, 30, 32, 33, 3...|
|   3175|[164, 199, 416, 4...|
|   3749|               [223]|
+-------+--------------------+
only showing top 20 rows



In [32]:
# Bring every per-user row to the driver; the pair counting below is plain Python.
watch_seq = watch_seq_df.collect()

In [33]:
# Keep only each user's list of anime ids.
# NOTE(review): this rebinds watch_seq from rows to plain lists, so the cell
# cannot be re-run without first re-running the collect() cell above.
watch_seq = [row['anime_ids'] for row in watch_seq]

In [34]:
from collections import defaultdict
# Nested counter: matrix[a][b] starts at 0 and is filled in by the pair loop
# in the next cell with how often items a and b share a user's sequence.
matrix = defaultdict(lambda: defaultdict(int))

In [35]:
# Count every unordered pair of distinct items within each user's sequence,
# writing both directions so the co-occurrence matrix stays symmetric.
for seq in watch_seq:
    for pos, a in enumerate(seq):
        for b in seq[pos + 1:]:
            if a != b:
                matrix[a][b] += 1
                matrix[b][a] += 1

In [36]:
def get_transfer_prob(vs):
    """Turn one vertex's edge weights into a transition distribution.

    Args:
        vs: mapping of neighbour id -> edge weight.

    Returns:
        A dict with two position-aligned lists: 'neighbours' (the keys of
        ``vs``) and 'probs' (each weight divided by the total weight).
        An empty mapping yields two empty lists.
    """
    # `sum` is shadowed by `from pyspark.sql.functions import *`, and
    # `__builtin__` is not a standard Python 3 name (IPython injects it), so
    # import the real module explicitly to stay runnable outside a notebook.
    import builtins

    neighbours = list(vs)
    total_weight = builtins.sum(vs.values())
    probs = [vs[k] / total_weight for k in neighbours]
    return {'neighbours': neighbours, 'probs': probs}

In [37]:
# For every item, the normalised transition distribution over its co-occurring items.
tranfer_probs = {
    item: get_transfer_prob(edge_weights)
    for item, edge_weights in matrix.items()
}

### Entrance

In [41]:
# `sum` is shadowed by the pyspark.sql.functions star import and `__builtin__`
# is only injected by IPython, so use the standard `builtins` module instead.
import builtins

# Every item with a transition table can start a walk; its start probability
# is its total edge weight divided by the total edge weight of the graph.
entrance_items = list(tranfer_probs)
neighbour_sum = {k: builtins.sum(matrix[k].values()) for k in entrance_items}
total_sum = builtins.sum(neighbour_sum.values())
entrance_probs = [neighbour_sum[e] / total_sum for e in entrance_items]
entrance_probs

[0.0018846217181866273,
 0.00022572821133446202,
 0.0004226465697159734,
 0.0008175823087249919,
 0.0004325200437646773,
 0.0006516073890056663,
 0.0004943070099757765,
 0.0004963680795544109,
 0.000835735513418749,
 0.0003178672068759576,
 0.0016106702800025709,
 0.00015067723910167481,
 0.0004510245488244565,
 0.0009216913002243417,
 0.0004985000537940772,
 0.003668747359647398,
 0.003348838420827633,
 0.000954073781211563,
 0.0025595277342195764,
 0.0002485595121867612,
 0.0030947692941847727,
 0.00032975662935536633,
 0.00015254493233204006,
 0.0007022655463738818,
 0.003094010292017817,
 0.0031013328097389383,
 0.0006371218890507456,
 0.0013932234107827308,
 0.0008337695205446803,
 0.0028094167088006873,
 0.00045553182920869423,
 0.0007414999959681031,
 0.0011318285434796117,
 0.002142990245639336,
 0.0013122970205017024,
 0.002382633496701325,
 0.001313505622678384,
 0.0011080448641120873,
 0.0019441606843490988,
 0.0013670176926531368,
 0.0012974924495721377,
 0.0009697227621952

### Deep Walk

In [42]:
import numpy as np

# Seeded so that "Restart & Run All" reproduces the same walks.
WALK_SEED = 42
rng = np.random.default_rng(WALK_SEED)


def one_random_walk(length, entrance_items, entrance_probs, transfer_probs, generator=None):
    """Sample one weighted random walk over the item graph.

    Args:
        length: maximum number of transitions to take; the returned path has
            at most ``length + 1`` vertices.
        entrance_items: candidate start vertices.
        entrance_probs: start probabilities, aligned with ``entrance_items``.
        transfer_probs: mapping vertex -> {'neighbours': [...], 'probs': [...]}
            with the two lists aligned by position.
        generator: optional ``numpy.random.Generator``; when omitted the
            module-level ``rng`` is used.

    Returns:
        The visited vertices as a list of strings. The walk stops early when
        the current vertex has no entry in ``transfer_probs`` or when its
        transition distribution cannot be sampled.
    """
    gen = rng if generator is None else generator

    # Draw a position rather than the value itself: this avoids converting the
    # whole vertex list to a numpy array on every step and keeps each vertex
    # in its original Python type for the dictionary lookups below.
    start_vertex = entrance_items[gen.choice(len(entrance_items), p=entrance_probs)]
    path = [str(start_vertex)]

    curr_vertex = start_vertex
    for _ in range(length):
        if curr_vertex not in transfer_probs:
            print(f"bad vertex {curr_vertex}")
            break

        neighbours = transfer_probs[curr_vertex]['neighbours']
        trans_prob = transfer_probs[curr_vertex]['probs']

        try:
            next_idx = gen.choice(len(neighbours), p=trans_prob)
        except (ValueError, TypeError) as err:
            # Best effort: an invalid distribution (empty, mismatched length,
            # not summing to 1) ends this walk instead of aborting the batch.
            print(curr_vertex)
            print(err)
            break

        curr_vertex = neighbours[next_idx]
        path.append(str(curr_vertex))

    return path

In [44]:
# Draw n independent walks of up to 20 transitions each; they form the
# training corpus for the embedding model.
n = 500
samples = [
    one_random_walk(20, entrance_items, entrance_probs, tranfer_probs)
    for _ in range(n)
]

In [45]:
# Wrap each walk in a one-column row so Spark sees a single array column 'anime_ids'.
walk_rows = [[walk] for walk in samples]
sample_df = spark.createDataFrame(walk_rows, ['anime_ids'])

In [46]:
# Preview the sampled walks (show() prints the first 20 rows).
sample_df.show()

+--------------------+
|           anime_ids|
+--------------------+
|[24439, 12859, 95...|
|[28621, 10568, 28...|
|[9253, 18893, 143...|
|[10372, 18095, 28...|
|[5356, 9253, 1981...|
|[4472, 4551, 3298...|
|[2129, 392, 4081,...|
|[8915, 1952, 6505...|
|[7103, 24151, 152...|
|[3446, 31637, 202...|
|[10714, 9367, 210...|
|[10083, 10620, 10...|
|[11433, 6956, 517...|
|[18671, 1817, 769...|
|[6408, 1918, 5177...|
|[7144, 9750, 2286...|
|[4334, 28497, 219...|
|[20159, 392, 659,...|
|[696, 323, 2251, ...|
|[20047, 20973, 16...|
+--------------------+
only showing top 20 rows



In [47]:
from pyspark.ml.feature import Word2Vec

# Word2Vec estimator: 5-dimensional vectors, 2 iterations, window size 15.
item2vec = Word2Vec(
    vectorSize=5,
    maxIter=2,
    windowSize=15,
)

In [49]:
# Read walks from 'anime_ids' and write the transformed vector to 'anime_ids_vec'.
item2vec.setInputCol('anime_ids').setOutputCol('anime_ids_vec')

# Train on the sampled walks and keep the per-item vectors as a DataFrame.
model = item2vec.fit(sample_df)
item_emb_df = model.getVectors()


In [51]:
# Collect the learned vectors to the driver and index them by item id (the
# Word2Vec "word") so the UDF below can look them up.
item_vec = model.getVectors().collect()
item_emb = {row.word: row.vector.toArray() for row in item_vec}


In [52]:
@udf(returnType='array<float>')
def build_user_emb(anime_seq):
    """Average the embeddings of the items in one user's sequence.

    Args:
        anime_seq: list of item ids (strings) for a single user.

    Returns:
        The element-wise mean of the available item vectors as a list of
        floats, or None (a null column value) when the sequence is empty/None
        or none of its items has a vector in ``item_emb``.
    """
    if not anime_seq:
        return None

    # Skip items that did not receive a vector from the embedding model.
    anime_embs = [
        item_emb[aid]
        for aid in anime_seq
        if aid in item_emb and len(item_emb[aid]) > 0
    ]
    if not anime_embs:
        # np.mean over an empty list returns a scalar NaN, which does not
        # match the declared array<float> return type; emit null instead.
        return None

    return np.mean(anime_embs, axis=0).tolist()

In [53]:
# Rebuild each user's item sequence and attach the averaged item embedding.
# The 'rating > 7' filter repeats the one applied when rating_df was loaded.
user_emb_df = (
    rating_df
    .where('rating > 7')
    .groupBy('user_id')
    .agg(collect_list(col('anime_id').cast('string')).alias('anime_ids'))
    .withColumn('user_emb', build_user_emb(col('anime_ids')))
)


In [54]:
# Preview the user embeddings (show() prints the first 20 rows).
user_emb_df.show()

+-------+--------------------+--------------------+
|user_id|           anime_ids|            user_emb|
+-------+--------------------+--------------------+
|    148|[20, 81, 170, 263...|[0.013078215, -0....|
|    463|[20, 24, 68, 102,...|[0.015662624, -0....|
|    471| [1604, 6702, 10681]|[0.040399462, -0....|
|    496|[227, 430, 481, 4...|[-0.016811699, -0...|
|    833|[1, 6, 19, 33, 45...|[0.0031106598, -0...|
|   1088|[226, 2923, 5262,...|[-0.0038642173, -...|
|   1238|[120, 121, 150, 2...|[-0.018002959, -0...|
|   1342|[6746, 6880, 9919...|[-0.030182596, -0...|
|   1580|[849, 853, 1535, ...|[-0.021111125, -0...|
|   1591|[48, 57, 66, 232,...|[-0.011426094, -0...|
|   1645|[20, 59, 101, 120...|[0.014206779, -0....|
|   1829|[1, 226, 339, 467...|[-0.006289381, -0...|
|   1959|[19, 199, 226, 51...|[-0.020051502, -0...|
|   2122|[64, 65, 226, 153...|[-0.0018397114, -...|
|   2142|             [11061]|[0.13350934, -0.2...|
|   2366|[33, 226, 356, 15...|[0.007979566, -0....|
|   2659|[1,

## LSH