In [None]:
# Reference: https://github.com/jadianes/spark-movie-lens

In [None]:
# 安裝 JDK / Hadoop / Spark / findspark

!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q http://mirror.its.dal.ca/apache/spark/spark-3.4.1/spark-3.4.1-bin-hadoop3.tgz
!tar xvf spark-3.4.1-bin-hadoop3.tgz
!pip install -q findspark
!pip install qdrant-client>=1.1.1

In [None]:
# 設定環境變數

import os
import math
import urllib
import zipfile

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.4.1-bin-hadoop3"
!update-alternatives --set java "/usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java"
!java -version

In [4]:
# Set up the Spark environment

# findspark.init() is called before the pyspark imports, as in the original order.
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.mllib.recommendation import ALS

# Local-mode session using the "local[*]" master.
session_builder = SparkSession.builder.master("local[*]")
spark = session_builder.getOrCreate()
sc = SparkContext.getOrCreate()

In [5]:
# Download and unpack the MovieLens dataset

# `import urllib` alone does not guarantee the `request` submodule is loaded,
# so import it explicitly before calling urllib.request.urlretrieve.
import urllib.request

complete_dataset_url = 'http://files.grouplens.org/datasets/movielens/ml-latest.zip'

datasets_path = os.path.join('.', 'datasets')
# Unlike `!mkdir datasets`, this does not error when the cell is re-run.
os.makedirs(datasets_path, exist_ok=True)
complete_dataset_path = os.path.join(datasets_path, 'ml-latest.zip')

# Save the archive to disk, then extract it into the datasets directory.
complete_f = urllib.request.urlretrieve(complete_dataset_url, complete_dataset_path)
with zipfile.ZipFile(complete_dataset_path, "r") as z:
    z.extractall(datasets_path)

In [6]:
# Data preprocessing

complete_ratings_file = os.path.join(datasets_path, 'ml-latest', 'ratings.csv')
complete_ratings_raw_data = sc.textFile(complete_ratings_file)
# The first line of the file is treated as the CSV header.
complete_ratings_raw_data_header = complete_ratings_raw_data.take(1)[0]

# Drop the header line, split each remaining line on commas, and keep the first
# three fields as an (int, int, float) tuple; the result is cached for reuse.
complete_ratings_data = (
    complete_ratings_raw_data
    .filter(lambda row: row != complete_ratings_raw_data_header)
    .map(lambda row: row.split(","))
    .map(lambda fields: (int(fields[0]), int(fields[1]), float(fields[2])))
    .cache()
)

In [7]:
# Train the model

complete_model = ALS.train(
    complete_ratings_data,
    rank=25,
    iterations=10,
    lambda_=0.1,  # lambda_ = regularization parameter
    seed=0,
)

In [8]:
# Fetch the movie features

# Pull the model's product features back to the driver as a Python list.
product_features_rdd = complete_model.productFeatures()
movie_feature_matrix = product_features_rdd.collect()

In [9]:
# Create the Qdrant connection

from qdrant_client import models, QdrantClient

# Connect to a Qdrant server at 127.0.0.1:6333.
qdrant_endpoint = {"host": "127.0.0.1", "port": 6333}
qdrant = QdrantClient(**qdrant_endpoint)
# Alternative clients left by the original author:
# qdrant = QdrantClient(":memory:")
# qdrant = QdrantClient(path="path/to/db")

In [None]:
# Create the collection

# size=25 is the same value as the rank=25 passed to ALS.train; cosine distance is used.
movie_vector_params = models.VectorParams(size=25, distance=models.Distance.COSINE)
qdrant.recreate_collection(
    collection_name="movies",
    vectors_config=movie_vector_params,
)

In [11]:
# Add the vectors

# Each collected entry is indexed as entry[0] (used as the record id) and
# entry[1] (converted to a plain list for the vector).
movies = [
    models.Record(id=entry[0], vector=list(entry[1]))
    for entry in movie_feature_matrix
]

qdrant.upload_records(collection_name="movies", records=movies)

In [14]:
# Test movie recommendations

# Ask for the 20 best matches, using point id 1 as the positive example.
positive_ids = [1]
hits = qdrant.recommend(collection_name='movies', positive=positive_ids, limit=20)

print(hits)

[ScoredPoint(id=3114, version=0, score=0.9936248587067097, payload={}, vector=None), ScoredPoint(id=78499, version=0, score=0.9836965262909908, payload={}, vector=None), ScoredPoint(id=2355, version=0, score=0.9764555050785272, payload={}, vector=None), ScoredPoint(id=6377, version=0, score=0.9732577357024096, payload={}, vector=None), ScoredPoint(id=4886, version=0, score=0.9718452197925664, payload={}, vector=None), ScoredPoint(id=233193, version=0, score=0.9718091556670364, payload={}, vector=None), ScoredPoint(id=8961, version=0, score=0.9715311518122335, payload={}, vector=None), ScoredPoint(id=588, version=0, score=0.9675419424025228, payload={}, vector=None), ScoredPoint(id=116497, version=0, score=0.9570402921063018, payload={}, vector=None), ScoredPoint(id=50872, version=0, score=0.9558306746290046, payload={}, vector=None), ScoredPoint(id=194316, version=0, score=0.9538840901651645, payload={}, vector=None), ScoredPoint(id=207872, version=0, score=0.9538840901651645, payload=