Support for CSR matrix with record_set() #566

@harthur

It's awesome that Estimator.record_set() handles data formatting and uploading to S3. However, when using FactorizationMachines, the data is a SciPy sparse matrix rather than a NumPy array, e.g. <176304x550182 sparse matrix of type '<class 'numpy.float32'>' with 352608 stored elements in Compressed Sparse Row format>
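For a sense of scale, here is a rough back-of-the-envelope sketch of why calling .toarray() first is not a realistic workaround for a matrix this size. The shape and stored-element count come from the repr above; the 4-byte index width is an assumption (SciPy normally uses int32 indices when they fit).

```python
# Shape and stored-element count taken from the matrix repr above.
rows, cols, nnz = 176304, 550182, 352608

value_bytes = 4  # float32 values
index_bytes = 4  # assumption: int32 `indices` and `indptr` arrays

# A dense array stores every cell, zero or not.
dense_bytes = rows * cols * value_bytes

# CSR stores one value and one column index per stored element,
# plus a row-pointer array of length rows + 1.
csr_bytes = nnz * value_bytes + nnz * index_bytes + (rows + 1) * index_bytes

print("dense: %.1f GB" % (dense_bytes / 1e9))
print("CSR:   %.1f MB" % (csr_bytes / 1e6))
```

Under those assumptions the dense form needs roughly 388 GB, against about 3.5 MB for CSR, so the sparse matrix has to be serialized as-is.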

When trying to use the record_set() API, I get an error:

from sagemaker import FactorizationMachines, get_execution_role

fm = FactorizationMachines(role=get_execution_role(),
                           train_instance_count=2,
                           train_instance_type='ml.c5.2xlarge',
                           predictor_type='binary_classifier',
                           num_factors=64)

# X_train_encoded is the CSR matrix described above
records = fm.record_set(X_train_encoded)
fm.fit(records)

> TypeError: must be real number, not csr_matrix
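Until record_set() accepts sparse input, one possible workaround is to serialize and upload the matrix by hand. The sketch below is not the SDK's own method. It assumes the SDK version in use exposes sagemaker.amazon.common.write_spmatrix_to_sparse_tensor and sagemaker.amazon.amazon_estimator.RecordSet with the arguments shown, and that boto3 credentials are configured. The function names, bucket, prefix, and the data.pbr object name are all hypothetical.

```python
import io


def s3_uri(bucket, key):
    """Build the s3:// URI for an object (pure helper, makes no AWS calls)."""
    return "s3://{}/{}".format(bucket, key)


def upload_sparse_records(X, bucket, key, labels=None):
    """Serialize a scipy.sparse matrix as RecordIO-protobuf and upload it to S3.

    `bucket` and `key` are placeholders supplied by the caller.
    `labels`, if given, is assumed to be a 1-D array with one entry per row.
    """
    # Imported lazily so this file loads even without the AWS SDKs installed.
    import boto3
    from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor

    buf = io.BytesIO()
    # Writes sparse tensors directly; the matrix is never densified.
    write_spmatrix_to_sparse_tensor(buf, X, labels)
    buf.seek(0)
    boto3.resource("s3").Bucket(bucket).Object(key).upload_fileobj(buf)
    return s3_uri(bucket, key)


def sparse_record_set(X, bucket, prefix, labels=None, channel="train"):
    """Hand-built stand-in for fm.record_set(X) that accepts a sparse matrix."""
    from sagemaker.amazon.amazon_estimator import RecordSet

    # Hypothetical object name under the caller's prefix.
    upload_sparse_records(X, bucket, prefix + "/data.pbr", labels)
    return RecordSet(
        s3_data=s3_uri(bucket, prefix + "/"),
        num_records=X.shape[0],
        feature_dim=X.shape[1],
        s3_data_type="S3Prefix",  # point at the prefix instead of a manifest
        channel=channel,
    )
```

Usage would look like fm.fit(sparse_record_set(X_train_encoded, "my-bucket", "fm/train")), with both names being placeholders. This uploads a single object. If the estimator shards input by S3 key, as I believe it does, then with train_instance_count=2 the rows would probably need to be split into one object per instance.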
