In [0]:
from io import StringIO
from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry

In [0]:
mat1_rdd = sc.textFile("/FileStore/tables/matrix1_small.txt")
mat2_rdd = sc.textFile("/FileStore/tables/matrix2_small.txt")

In [0]:
# function that takes a list of lists and creates a list of matrix entry objects with row index, col index, and value
def create_matrix_entries(arr):
    mat_entries = []
    for i in range(len(arr)):
        for j in range(len(arr[i])):
            mat_entries.append(MatrixEntry(i, j, int(arr[i][j])))
            
    return mat_entries

In [0]:
# create list of matrix entries
mat1_entries = create_matrix_entries(mat1_rdd.map(lambda x: x.strip().split(' ')).collect())
mat2_entries = create_matrix_entries(mat2_rdd.map(lambda x: x.strip().split(' ')).collect())

In [0]:
# create coordinate matrix using list of matrix entries parallelized to an RDD
mat1 = CoordinateMatrix(sc.parallelize(mat1_entries))
mat2 = CoordinateMatrix(sc.parallelize(mat2_entries))

In [0]:
# create m and n matrices that have column as the key
m = mat1.entries.map(lambda x: (x.j, (x.i, x.value)))
n = mat2.entries.map(lambda x: (x.j, (x.i, x.value)))

In [0]:
# join m and n matrices and map values to be of format (j, (i, k), (v * w)) and then reduce by j and sum up the products of ((i, k), (v * w)) and map to a matrix entry object
product = m.join(n).map(lambda x: ((x[1][0][0], x[1][1][0]), (x[1][0][1] * x[1][1][1]))).reduceByKey(lambda x, y: x+y).map(lambda x: MatrixEntry(x[0][0], x[0][1], x[1]))

In [0]:
# create coordinate matrix with the product
final = CoordinateMatrix(product)

In [0]:
# create output using final product
output = [[0 for j in range(final.numCols())] for i in range(final.numRows())]
for me in final.entries.collect():
    output[me.i][me.j] = me.value

In [0]:
# outputting to file

with StringIO() as file:
        
    for i in output:
        out = ''
        for j in i:
            out += f'{j} '
        
        file.write(out.strip() +'\n')
#         print(out.strip() +'\n', end='')
        
    dbutils.fs.put("/FileStore/tables/Hw2_Q6_output.txt", file.getvalue(), True)

Wrote 140000 bytes.
