
Factor out PyTorch device setting #503

Merged
merged 14 commits into SeldonIO:master on May 18, 2022

Conversation

kuutsav
Contributor

@kuutsav kuutsav commented May 13, 2022

Addresses #493

@kuutsav kuutsav changed the title from "implements #493" to "Factor out PyTorch device setting" May 13, 2022
@ascillitoe
Contributor

Hi @kuutsav, welcome, and thanks for contributing! 🙂

@kuutsav kuutsav requested a review from ascillitoe May 13, 2022 12:07
Contributor

@ascillitoe ascillitoe left a comment


Thanks for this @kuutsav. The overall strategy looks good to me (assuming you are OK with the get_torch_device function name and location @arnaudvl?).

There are a few issues that need addressing though:

  • The flake8 linting issues. See CONTRIBUTING.md for more guidance, but essentially if you run pip install -r requirements/dev.txt you can then run flake8 alibi_detect and mypy alibi_detect to check these both pass locally before pushing.
  • The new function can also be added to pytorch/learned_kernel.py and pytorch/classifier.py. (You've correctly left out pytorch/spot_the_diff.py, since the logic is different for that one.)
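The utility under discussion might look something like the following minimal sketch. The name get_torch_device and the accepted device strings are taken from this thread; the merged implementation in alibi_detect may differ in detail:

```python
import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


def get_torch_device(device: Optional[str] = None) -> torch.device:
    # Resolve a user-supplied device string ('cuda'/'gpu'/'cpu' or None)
    # into a torch.device, falling back to CPU when no GPU is available.
    if device is None or device.lower() in ('gpu', 'cuda'):
        torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if torch_device.type == 'cpu':
            logger.warning('No GPU detected, fall back on CPU.')
    else:
        torch_device = torch.device('cpu')
    return torch_device
```

Each detector's __init__ can then call this once instead of repeating the device-resolution logic inline.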

@arnaudvl
Contributor

Thanks for the contribution @kuutsav . Just left a small comment which is applicable to the various detectors.

@kuutsav
Contributor Author

kuutsav commented May 13, 2022

Thanks for the contribution @kuutsav . Just left a small comment which is applicable to the various detectors.

@arnaudvl I've addressed the comment along with the linting issues, and have also added the function to pytorch/learned_kernel.py and pytorch/classifier.py.

Do we want to print or log here?
print('No GPU detected, fall back on CPU.')

@kuutsav kuutsav requested a review from arnaudvl May 13, 2022 12:28
@ascillitoe
Contributor

ascillitoe commented May 13, 2022

Do we want to print or log here?
print('No GPU detected, fall back on CPU.')

Good question. IMO we should stick with logging.
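To illustrate the logging choice, a minimal sketch of the library/application split (the logger names below are illustrative, not necessarily the ones used in alibi_detect):

```python
import logging

# A library should log rather than print, so that applications control
# verbosity. Use a module-level logger, as is conventional:
logger = logging.getLogger('alibi_detect.utils.pytorch')

# Instead of print('No GPU detected, fall back on CPU.'):
logger.warning('No GPU detected, fall back on CPU.')

# An application can then silence or redirect the whole package, e.g.:
logging.getLogger('alibi_detect').setLevel(logging.ERROR)
```

With print there is no such opt-out, which is why logging is generally preferred inside libraries.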

@kuutsav kuutsav requested a review from ascillitoe May 13, 2022 12:47
Contributor

@ascillitoe ascillitoe left a comment


LGTM! Would be grateful if you could fix the 3 minor comment typos I've just spotted (these were already there).

I'll also leave @arnaudvl to review one more time.

Resolved review comments on:
  • alibi_detect/cd/pytorch/lsdd.py
  • alibi_detect/cd/pytorch/lsdd_online.py
  • alibi_detect/cd/pytorch/mmd_online.py
@ascillitoe ascillitoe linked an issue May 13, 2022 that may be closed by this pull request
@arnaudvl
Contributor

One more place where it makes sense to use this utility function is the alibi_detect.utils.pytorch.predict_batch functionality. Note that the input argument type hint would need to be updated as well. Besides my comments, looks good!

@kuutsav
Contributor Author

kuutsav commented May 14, 2022

One more place where it makes sense to use this utility function is the alibi_detect.utils.pytorch.predict_batch functionality. Note that the input argument type hint would need to be updated as well. Besides my comments, looks good!

Made this change as well. Had to ignore the types in a couple of places as mypy was complaining about the reassignment of device.

@ascillitoe
Contributor

Sorry @kuutsav, after looking at this a bit more (after the most recent mypy errors), I've realised the current changes to utils/pytorch/prediction.py are not quite right.

Changing device: torch.device to device: Optional[str] will raise more issues since device is type torch.device in predict_batch_transformer (which calls predict_batch), and in cd.preprocess.preprocess_drift (which calls both of these functions).

Instead of adding multiple (i.e. 5 or 6) type: ignore comments, I've instead tweaked get_device so that it can take in a torch.device and simply return it untouched, for compatibility with predict_batch. @arnaudvl what do you think? I'll wait to get your thoughts before reviewing again...
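The tweak described here could be sketched as follows (a hypothetical version; the actual get_device in alibi_detect may differ). Passing a torch.device straight through means callers such as predict_batch that already hold a torch.device need no type: ignore comments:

```python
import logging
from typing import Optional, Union

import torch

logger = logging.getLogger(__name__)


def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
    # If the caller already resolved a torch.device, return it untouched.
    if isinstance(device, torch.device):
        return device
    # Otherwise resolve a string ('cuda'/'gpu'/'cpu') or None as before,
    # falling back to CPU when no GPU is available.
    if device is None or device.lower() in ('gpu', 'cuda'):
        torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if torch_device.type == 'cpu':
            logger.warning('No GPU detected, fall back on CPU.')
        return torch_device
    return torch.device('cpu')
```

The widened Union type hint keeps mypy happy for both the detector __init__ call sites (strings) and the prediction utilities (torch.device objects).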

Contributor

@ascillitoe ascillitoe left a comment


LGTM, but will leave @arnaudvl to do a final pass since I'm now self-reviewing...

@ascillitoe
Contributor

Congrats (and thanks) on your first contribution @kuutsav! 🎉

@ascillitoe ascillitoe merged commit d5cf67b into SeldonIO:master May 18, 2022