In [1]:
import os
import sys
import re
import random
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql import Row
from pyspark.sql.functions import *
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d


import pyspark.mllib.recommendation as recommendation

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import pandas as pd
import numpy as np
from time import time

# SQL entry point built on top of the SparkContext `sc`.
# NOTE(review): `sc` is never created in the visible cells — presumably it is
# injected by the PySpark kernel; confirm before running on a fresh kernel.
sqlContext = SQLContext(sc)
# Root location of the Netflix dataset.
base = "/user/user26/netflix_dataset/"

# Folder holding the per-movie rating files (read below via the 'mv_*' glob).
training_folder = base + 'training_set/'

TRAINING DATASET FILE DESCRIPTION    
================================================================================    
                                                                                                                                   
The file "training_set.tar" is a tar of a directory containing 17770 files, one per movie.  The first line of each file contains the movie id followed by a colon.  Each subsequent line in the file corresponds to a rating from a customer and its date in the following format:

CustomerID,Rating,Date                                                          

- MovieIDs range from 1 to 17770 sequentially.    
- CustomerIDs range from 1 to 2649429, with gaps. There are 480189 users.
- Ratings are on a five star (integral) scale from 1 to 5.
- Dates have the format YYYY-MM-DD.

In [2]:
# Read every file matching 'mv_*' in the training folder as an RDD of text lines.
rawMovieRatings = sc.textFile(training_folder + 'mv_*')

# Module-level state shared with xtractFields: the movie id taken from the most
# recently seen header line.  -1 means "no header seen yet"; records carrying
# that value are dropped by the filter applied after the map.
# NOTE(review): this only tags ratings correctly if a file's header line is
# processed before its rating lines by the same Python process — confirm this
# holds for the way Spark splits these input files.
movieID = -1

def xtractFields(s):
    """Parse one line of a per-movie training file.

    A file starts with a header line ``<movieID>:`` and continues with rating
    lines of the form ``CustomerID,Rating,Date``.  A header updates the
    module-level ``movieID`` so that the rating lines following it can be
    tagged with that movie.

    Parameters
    ----------
    s : str
        One raw text line.

    Returns
    -------
    tuple
        ``(movieID, customerID, rating, date)`` for a rating line (the three
        parsed fields are returned as strings),
        ``(-1, -1, -1, -1)`` for a header line, and
        ``(-1, -1, -1, -3)`` for a blank or otherwise malformed line.
    """
    global movieID

    # Strip surrounding whitespace (e.g. a trailing '\r') so it cannot leak
    # into the movie id or the date field, then split on the comma separator.
    fields = s.strip().split(',')

    if len(fields) == 1:
        header = fields[0]
        # Only "<id>:" counts as a header.  A blank or malformed single-field
        # line must NOT overwrite the current movie id, otherwise every rating
        # that follows would be tagged with garbage.
        if len(header) > 1 and header.endswith(':'):
            movieID = header[:-1]
            return (-1, -1, -1, -1)
        return (-1, -1, -1, -3)

    if len(fields) == 3:
        # Rating line: tag it with the movie id from the latest header.
        return (movieID, fields[0], fields[1], fields[2])

    # Any other number of comma-separated fields is malformed.
    return (-1, -1, -1, -3)

# Parse every line, then drop the records whose first field is the -1 sentinel
# (header lines, malformed lines, and ratings seen before any header).
movieRatingsRDD = (
    rawMovieRatings
    .map(xtractFields)
    .filter(lambda record: record[0] != -1)
)

Record layout of `movieRatingsRDD` (every field is a string): Movie | UserID | Rating | Date

In [3]:
# Peek at ten randomly chosen parsed records (sampling without replacement).
movieRatingsRDD.takeSample(withReplacement=False, num=10)


[('215', '108846', '5', '2004-11-02'),
 ('357', '2049750', '4', '2004-04-05'),
 ('334', '2377061', '3', '2005-12-01'),
 ('1862', '1830717', '3', '2004-03-04'),
 ('706', '363769', '5', '2005-06-27'),
 ('1561', '532817', '2', '2005-03-15'),
 ('378', '52153', '3', '2005-10-24'),
 ('199', '1593840', '5', '2005-03-19'),
 ('1054', '2119233', '3', '2002-07-17'),
 ('571', '1334087', '5', '2005-10-03')]

In [15]:
# Build the ALS input as Rating(user, product, rating).  The parsed records are
# (movie, user, rating, date) tuples whose fields are all strings, so cast them
# explicitly rather than relying on an implicit conversion inside Spark.
allData = (
    movieRatingsRDD
    .map(lambda r: recommendation.Rating(int(r[1]), int(r[0]), float(r[2])))
    .repartition(36)
    .cache()
)

In [16]:
# Train the ALS matrix-factorisation model on all ratings.
# A fixed seed makes the random factor initialisation — and therefore the
# predictions shown further down — reproducible across runs.
model = recommendation.ALS.train(allData, rank=50, blocks=36, seed=42)

In [17]:
# Score every (user, product) pair present in allData.
# Note: allData is the same RDD the model was trained on, so these are
# in-sample predictions.
predictData = model.predictAll(
    allData.map(lambda rating: (rating.user, rating.product))
)

In [18]:
# Show the first ten predicted ratings.
predictData.take(num=10)

[Rating(user=2535624, product=1962, rating=4.8207341047002465),
 Rating(user=2535624, product=313, rating=3.895793790054712),
 Rating(user=2535624, product=1145, rating=4.706000361166939),
 Rating(user=2535624, product=353, rating=3.925963342614346),
 Rating(user=2535624, product=760, rating=4.977842819197679),
 Rating(user=2535624, product=78, rating=3.9957500272736004),
 Rating(user=2535624, product=483, rating=3.9187013687188617),
 Rating(user=2535624, product=1798, rating=3.845024550094224),
 Rating(user=1063980, product=1307, rating=2.9265856977282994),
 Rating(user=1063980, product=700, rating=4.925513359182486)]

In [21]:
def RMSE(allData, predictData, model):
    """Pair each predicted rating with the corresponding rating in ``allData``.

    Despite its name, this function does not reduce anything to an error
    value: it returns the joined records so they can be inspected.

    Parameters
    ----------
    allData
        Collection of records exposing ``user``, ``product`` and ``rating``.
    predictData
        Collection of records with the same three attributes.
    model
        Unused; kept so the existing call signature still works.

    Returns
    -------
    The result of joining both collections on ``(user, product)``; each
    element has the form ``((user, product), (predictData rating, allData rating))``.
    """
    def keyed(record):
        # Re-key a record by its (user, product) pair so the two sides can be joined.
        return ((record.user, record.product), record.rating)

    predictedByPair = predictData.map(keyed)
    actualByPair = allData.map(keyed)
    return predictedByPair.join(actualByPair)

In [22]:
# Inspect ten joined ((user, product), (predicted, actual)) records.
RMSE(allData, predictData, model).take(num=10)

[((751367, 83), (4.510286220894525, 5.0)),
 ((910339, 1295), (3.9635258508666174, 5.0)),
 ((1800047, 571), (4.03469739260353, 4.0)),
 ((320901, 1865), (2.202844081534235, 2.0)),
 ((635321, 189), (3.8381724343893664, 4.0)),
 ((1446632, 1798), (4.743015706815126, 5.0)),
 ((1489880, 1542), (4.275209589175548, 5.0)),
 ((178738, 1832), (3.9407140423373095, 4.0)),
 ((309400, 990), (3.3456385352191873, 4.0)),
 ((327808, 1470), (2.714020580562192, 3.0))]