In [82]:
import numpy as np
import pandas as pd 
import os
import math
import urllib
import zipfile
import matplotlib.pyplot as plt 

import pyspark
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import TrainValidationSplit
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import countDistinct, col, lit
from pyspark.sql.types import StructType, StructField, IntegerType

# Build our Spark Session and Context
spark = pyspark.sql.SparkSession.builder.getOrCreate()
sc = spark.sparkContext

In [84]:
# read in dataframes
ratings_df = spark.read.csv('../data/movies/ratings.csv',header=True)
movies_df = spark.read.csv('../data/movies/movies.csv',header=True)

# cut out timestamp column
ratings_df = ratings_df.drop(ratings_df['timestamp'])

In [85]:
# convert to RDD
ratings_rdd = ratings_df.rdd
movies_rdd = movies_df.rdd

In [87]:
# train-test split
train_rdd, validation_rdd, test_rdd = ratings_rdd.randomSplit([6, 2, 2], seed=42)
validation_for_predict_rdd = validation_rdd.map(lambda x: (x[0], x[1]))
test_for_predict_rdd = test_rdd.map(lambda x: (x[0], x[1]))


In [91]:
# build the model
seed = 42
iterations = 10
regularization_parameter = 0.1
ranks = [4, 8, 12]
errors = [0, 0, 0]
err = 0
tolerance = 0.02

min_error = float('inf')
best_rank = -1
best_iteration = -1
for rank in ranks:
    model = ALS.train(training_RDD, rank, seed=seed, iterations=iterations,
                      lambda_=regularization_parameter)
    predictions = model.predictAll(validation_for_predict_RDD).map(lambda r: ((r[0], r[1]), r[2]))
    rates_and_preds = validation_RDD.map(lambda r: ((int(r[0]), int(r[1])), float(r[2]))).join(predictions)
    error = math.sqrt(rates_and_preds.map(lambda r: (r[1][0] - r[1][1])**2).mean())
    errors[err] = error
    err += 1
    print('For rank {} the RMSE is {}'.format(rank, error))
    if error < min_error:
        min_error = error
        best_rank = rank

print('The best model was trained with rank {}'.format(best_rank))

SyntaxError: invalid syntax (<ipython-input-91-f1dd84607756>, line 21)

In [16]:
# ratings.show(5), movies.show(5)