
Force dtype to int64 to ensure that we don't index with non-long tensor #258

@TobiasMadsenQiagen

Description


In the triplet data loaders (utils.py:load_triplet_data and utils.py:load_raw_triplet_data), the loaded data should be cast explicitly to int64 so that the torch tensors built from it are always of type long (torch.int64). Otherwise torch may raise the following error when predict is called, because a tensor used for indexing is not long:

line 186, in __call__
return self.emb[idx].to(self.device)
IndexError: tensors used as indices must be long, byte or bool tensors
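A minimal sketch of the proposed fix, assuming the triplets arrive as "head<TAB>rel<TAB>tail" lines of integer IDs. The helper name `load_triplets_int64` and the file layout are hypothetical and are not dglke's actual loader code; the point is only that each `np.asarray` call passes `dtype=np.int64` instead of letting NumPy choose.

```python
import io

import numpy as np


def load_triplets_int64(lines, sep="\t"):
    """Parse 'head<sep>rel<sep>tail' integer-ID lines into int64 arrays.

    Hypothetical helper for illustration; not dglke's loader.
    """
    heads, rels, tails = [], [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        h, r, t = line.split(sep)
        heads.append(int(h))
        rels.append(int(r))
        tails.append(int(t))
    # Explicit dtype: never let np.asarray pick a platform-dependent integer type.
    return (
        np.asarray(heads, dtype=np.int64),
        np.asarray(rels, dtype=np.int64),
        np.asarray(tails, dtype=np.int64),
    )


# Iterating a StringIO yields lines, just like an open text file.
data = io.StringIO("0\t0\t1\n1\t1\t2\n")
h, r, t = load_triplets_int64(data)
print(h.dtype, r.dtype, t.dtype)  # int64 int64 int64
```

`torch.from_numpy` preserves the array's dtype, so tensors created from these arrays are torch.int64 (long) and are valid indices. An array that has already been loaded can be fixed the same way with `arr.astype(np.int64)`.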

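np.asarray infers the dtype from its input. On the Windows system we tested on, the inferred type is int32 as long as all input integers fit in 32 bits (at most 2^31 - 1), because NumPy's default integer on Windows is 32-bit (this changed to 64-bit in NumPy 2.0). We did not observe the problem on macOS or Ubuntu, where the default integer is 64-bit.
We tested with dglke 0.1.2.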
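The inference behaviour described above can be checked directly. This is a small sketch: the dtype printed for `small` depends on the platform and NumPy version, while a value outside the int32 range, or an explicit `dtype=np.int64`, gives int64 everywhere.

```python
import numpy as np

small = [0, 1, 2**31 - 1]  # every value fits in int32
large = [0, 1, 2**31]      # one value needs 64 bits

# No dtype: NumPy uses its default integer, which is int32 on Windows
# with NumPy < 2.0 and int64 on 64-bit Linux/macOS.
inferred_small = np.asarray(small)
# A value outside the int32 range forces a 64-bit result on every platform.
inferred_large = np.asarray(large)
# An explicit dtype removes the platform dependence entirely.
forced = np.asarray(small, dtype=np.int64)

print(inferred_small.dtype, inferred_large.dtype, forced.dtype)
```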
