Skip to content

Commit

Permalink
revise als python version: add command line argument passing; add tim…
Browse files Browse the repository at this point in the history
…e measurement; can test with data_csv
  • Loading branch information
NanmiaoWu committed Aug 5, 2020
1 parent 08a4e87 commit 573a106
Showing 1 changed file with 46 additions and 24 deletions.
70 changes: 46 additions & 24 deletions examples/algorithms/als/als.py
Expand Up @@ -4,6 +4,13 @@
# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

import numpy as np
from numpy import genfromtxt
import argparse
import time


def slice_array(a, row_start, row_stop, col_start, col_stop):
return a[row_start:row_stop, col_start:col_stop]


def ALS(ratings, regularization, num_factors, iterations, alpha, enable_output):
Expand Down Expand Up @@ -70,27 +77,42 @@ def ALS(ratings, regularization, num_factors, iterations, alpha, enable_output):
return [X, Y]


# test example
ratings = np.zeros((10, 5))
ratings[0, 1] = 4
ratings[1, 0] = 1
ratings[1, 2] = 4
ratings[1, 4] = 5
ratings[2, 3] = 2
ratings[3, 1] = 8
ratings[4, 2] = 4
ratings[6, 4] = 2
ratings[7, 0] = 1
ratings[8, 3] = 5
ratings[9, 0] = 1
ratings[9, 3] = 2

regularization = 0.1
alpha = 40
iterations = 500
num_factors = 3
enable_output = False

result = ALS(ratings, regularization, num_factors, iterations, alpha, enable_output)
print("X = ", result[0])
print("Y = ", result[1])
if __name__ == '__main__':

parser = argparse.ArgumentParser(description='Non-distributed ALS')
parser.add_argument('-i', '--iterations', type=int, default=3,
help='number of iterations')
parser.add_argument('-f', '--num_factors', type=int, default=10,
help='number of factors')
parser.add_argument('--row_start', type=int, default=0,
help='row_start')
parser.add_argument('--row_stop', type=int, default=10,
help='row_sop')
parser.add_argument('--col_start', type=int, default=0,
help='col_start')
parser.add_argument('--col_stop', type=int, default=20,
help='col_sop')
parser.add_argument('-r', '--regularization', type=float, default=0.1,
help='regularization')
parser.add_argument('-a', '--alpha', type=float, default=40,
help='alpha')
parser.add_argument('-e', '--enable_output', type=bool, default=False,
help='enable_output')
parser.add_argument('--data_csv', required=True,
help='file name for reading data')
args = parser.parse_args()

t0 = time.perf_counter()

data_csv = genfromtxt(args.data_csv, delimiter=',')
ratings = slice_array(data_csv, args.row_start, args.row_stop,
args.col_start, args.col_stop)

result = ALS(ratings, args.regularization, args.num_factors,
args.iterations, args.alpha, args.enable_output)

t1 = time.perf_counter() - t0

print("X = ", result[0])
print("Y = ", result[1])
print("elapsed time is: ", t1)

0 comments on commit 573a106

Please sign in to comment.