### 1. Import libraries and initiate the spark session

In [1]:
# import libraries
from pyspark.mllib.linalg import Matrices
from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import *

import numpy as np

# create the Spark session on YARN with Hive support
# (getOrCreate returns an already-running session when there is one)
spark = (
    SparkSession.builder
    .master('yarn')
    .appName('matrix-multiplication')
    .enableHiveSupport()
    .getOrCreate()
)


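To retain the original row indices after multiplying two matrices, we add dummy values to both matrices before multiplying. The guest feature matrix gets an extra first column that holds each row's own index, and the item feature matrix gets an extra first row and first column that are all zeros except for a 1 in the top-left corner. The first column of the resulting matrix then reproduces the original row indices, so every output row can be traced back to its guest.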

<img src="img/1_matrix_multiplication.png" alt="matrix multiplication" style="width: 1000px; height: 175px;"/>
<img src="img/2_matrix_multiplication_with_dummy_values.png" alt="matrix multiplication with dummy values" style="width: 1000px; height: 175px;"/>

### 2. Create the guest feature matrix

In [2]:
# create the guest feature matrix (an IndexedRowMatrix)
# first column in the matrix is composed of the dummy values
# corresponding to the actual row indices
num_guests = 1000000
guest_feature_vector_length = 100
guest_feature_seed = 42  # base seed; input partition p draws from seed + p


def _make_guest_rows(partition_index, guest_ids):
    """Yield one IndexedRow per guest id of a single input partition.

    Each row vector is the guest id itself (the dummy value) followed by
    `guest_feature_vector_length` random integer features drawn from [1, 10).

    The features come from a generator seeded per partition rather than from
    numpy's global state. The RDD is not cached, so it is re-evaluated by each
    action that touches it; with a fixed seed every evaluation (and every
    notebook run) produces the same feature values.
    """
    rng = np.random.RandomState(guest_feature_seed + partition_index)
    for guest_id in guest_ids:
        features = rng.randint(low=1, high=10, size=guest_feature_vector_length)
        # tolist() converts the numpy integers to plain Python ints
        yield IndexedRow(guest_id, [guest_id] + features.tolist())


guest_feature_rows = spark.sparkContext.parallelize(range(num_guests), 100) \
    .mapPartitionsWithIndex(_make_guest_rows)

guest_feature_matrix = IndexedRowMatrix(guest_feature_rows.repartition(200))
print(f'dimensions of guest feature matrix: {guest_feature_matrix.numRows()} x {guest_feature_matrix.numCols()}')


dimensions of guest feature matrix: 1000000 x 101


### 3. Create the item feature matrix

In [3]:
# create the item feature matrix (a local dense matrix)
# first column and first row in the matrix is composed of the dummy values
num_items = 1000
item_feature_vector_length = 100
item_feature_seed = 1234  # fixed seed so the item features are reproducible

# Matrices.dense takes its values in column-major order, so the data is laid
# out with one row of this array per column of the final matrix
item_rng = np.random.RandomState(item_feature_seed)
item_feature_columns = np.zeros(
    (num_items + 1, item_feature_vector_length + 1), dtype=np.int64
)

# dummy column (matrix column 0): a 1 in the dummy row, 0 everywhere else
item_feature_columns[0, 0] = 1

# one matrix column per item: 0 in the dummy row, then the item's features
item_feature_columns[1:, 1:] = item_rng.randint(
    low=1, high=10, size=(num_items, item_feature_vector_length)
)

# flat, column-major list of matrix values
item_feature_array = item_feature_columns.ravel().tolist()
item_feature_matrix = Matrices.dense(item_feature_vector_length + 1, num_items + 1, item_feature_array)

# report the dimensions of the matrix that was actually built
print(f'dimensions of item feature matrix: {item_feature_matrix.numRows} x {item_feature_matrix.numCols}')


dimensions of item feature matrix: 101 x 1001


### 4. Multiply guest features and item features to get the ratings

In [4]:
# guest-item ratings = (guest feature matrix) x (item feature matrix)
ratings_matrix = guest_feature_matrix.multiply(item_feature_matrix)

# query the distributed result for its dimensions before reporting them
ratings_row_count = ratings_matrix.numRows()
ratings_col_count = ratings_matrix.numCols()
print(f'dimensions of ratings matrix: {ratings_row_count} x {ratings_col_count}')


dimensions of ratings matrix: 1000000 x 1001


In [5]:
# turn every row of the ratings matrix into
# (matrix row index, dummy index, rating vector)
def _split_rating_row(indexed_row):
    """Return (row index, dummy index taken from column 0, remaining ratings)."""
    values = indexed_row.vector.toArray().tolist()
    return int(indexed_row.index), int(values[0]), values[1:]


ratings_rdd = ratings_matrix.rows.repartition(500).map(_split_rating_row)

# all three columns are nullable
schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        ('raw_guest_index', IntegerType()),
        ('dummy_guest_index', IntegerType()),
        ('guest_rating_array', ArrayType(DoubleType())),
    )
])

ratings = spark.createDataFrame(ratings_rdd, schema)

# show a sample of the rows whose matrix row index differs from the
# dummy index carried through the multiplication
index_mismatch = col('raw_guest_index') != col('dummy_guest_index')
ratings.filter(index_mismatch).sample(0.1).show(10)


+---------------+-----------------+--------------------+
|raw_guest_index|dummy_guest_index|  guest_rating_array|
+---------------+-----------------+--------------------+
|         418290|           178850|[2170.0, 2487.0, ...|
|         418294|           178854|[2049.0, 2349.0, ...|
|         461013|           671413|[2141.0, 2292.0, ...|
|         461014|           671414|[2364.0, 2602.0, ...|
|         461019|           671419|[2415.0, 2653.0, ...|
|         742063|           563233|[2192.0, 2628.0, ...|
|         742068|           563238|[2177.0, 2673.0, ...|
|         619911|           128971|[2563.0, 2807.0, ...|
|         898620|           229150|[2238.0, 2456.0, ...|
|         709719|           968309|[2389.0, 2627.0, ...|
+---------------+-----------------+--------------------+
only showing top 10 rows

