
Refactor torch device types out of od and into _types #829

Merged: 13 commits into SeldonIO:master on Jul 26, 2023

Conversation

@mauicv (Collaborator) commented on Jul 11, 2023:

What this is:

Defines `TorchDeviceType: TypeAlias = Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']]` in `_types.py` and refactors the typing of the `device` kwarg across the detectors.

Fixes #779 and #679. Also fixes #763.
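For illustration, here is a minimal sketch of the alias and a `get_device`-style resolver that consumes it (the resolver's behaviour below is an assumption for illustration only; the real `get_device` helper is discussed further down this thread):

```python
# Sketch only: mirrors the alias added in this PR; the resolver is a
# hypothetical illustration of how detectors can consume it.
from typing import Literal, Optional, Union

import torch
from typing_extensions import TypeAlias

# 'torch.device' is kept as a forward reference so torch stays an optional dependency.
TorchDeviceType: TypeAlias = Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']]


def get_device(device: TorchDeviceType = None) -> torch.device:
    """Resolve a user-supplied device spec to a concrete torch.device,
    defaulting to GPU when available and falling back to CPU (assumed behaviour)."""
    if isinstance(device, torch.device):
        return device
    if device is None or device.lower() in ('cuda', 'gpu'):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device('cpu')
```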

@codecov (bot) commented on Jul 11, 2023:

Codecov Report

Merging #829 (a3519f3) into master (d19cf09) will increase coverage by 0.08%.
The diff coverage is 94.84%.

Additional details and impacted files


```diff
@@            Coverage Diff             @@
##           master     #829      +/-   ##
==========================================
+ Coverage   81.90%   81.98%   +0.08%
==========================================
  Files         159      159
  Lines       10338    10375      +37
==========================================
+ Hits         8467     8506      +39
+ Misses       1871     1869       -2
```
| Files Changed | Coverage Δ |
|---|---|
| alibi_detect/saving/schemas.py | 97.96% <86.84%> (-0.82%) ⬇️ |
| alibi_detect/cd/classifier.py | 100.00% <100.00%> (ø) |
| alibi_detect/cd/context_aware.py | 97.50% <100.00%> (+0.06%) ⬆️ |
| alibi_detect/cd/keops/learned_kernel.py | 94.20% <100.00%> (+0.04%) ⬆️ |
| alibi_detect/cd/keops/mmd.py | 98.24% <100.00%> (+0.03%) ⬆️ |
| alibi_detect/cd/learned_kernel.py | 100.00% <100.00%> (ø) |
| alibi_detect/cd/lsdd.py | 97.14% <100.00%> (+0.08%) ⬆️ |
| alibi_detect/cd/lsdd_online.py | 93.75% <100.00%> (+0.13%) ⬆️ |
| alibi_detect/cd/mmd.py | 97.77% <100.00%> (+0.05%) ⬆️ |
| alibi_detect/cd/mmd_online.py | 94.44% <100.00%> (+0.10%) ⬆️ |
| ... and 31 more | |

@mauicv requested a review from @ascillitoe on Jul 11, 2023, 15:49.
```diff
- Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend.
+ Device type used. The default tries to use the GPU and falls back on CPU if needed.
+ Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
+ ``torch.device``. Only relevant for 'pytorch' backend.
```
@ascillitoe (Contributor) commented:

Just out of curiosity: if you update the `intersphinx_mapping` as in pytorch/pytorch#10400 and then reference `torch.device` via ``:py:class:`torch.device```, does it work?
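For reference, a minimal sketch of what that `conf.py` change might look like (the inventory URL is an assumption based on PyTorch's hosted docs, not taken from this repo's config):

```python
# docs/source/conf.py -- hypothetical addition, per pytorch/pytorch#10400
intersphinx_mapping = {
    'python': ('https://docs.python.org/3/', None),
    # assumed: PyTorch publishes an objects.inv at its stable docs root
    'torch': ('https://pytorch.org/docs/stable/', None),
}
```

With such a mapping in place, ``:py:class:`torch.device``` should resolve to the upstream PyTorch API docs.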

@mauicv (Collaborator, author) commented:

Can I leave this to explore in a separate issue? This PR already has a much wider scope than initially intended! 😅

```diff
@@ -37,3 +34,5 @@
 # type aliases, for use with mypy (must be FwdRef's if involving opt. deps.)
 OptimizerTF: TypeAlias = Union['tf.keras.optimizers.Optimizer', 'tf.keras.optimizers.legacy.Optimizer',
                                Type['tf.keras.optimizers.Optimizer'], Type['tf.keras.optimizers.legacy.Optimizer']]
+
+TorchDeviceType: TypeAlias = Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']]
```
@ascillitoe (Contributor) commented:

Re the forward reference `'torch.device'` in here: I can't think of a good fix at the moment, but just noting that this introduces lots of additional Sphinx warnings, and is not rendered "perfectly" in the docs (we've gone from 6 to 29 warnings, which makes me sad).

I suspect the forward ref would be resolved during docs compilation if we installed `alibi-detect[all]` on Read the Docs (#499), which is now allowed, but it seems wasteful...

@mauicv (Collaborator, author) commented on Jul 21, 2023:

😮‍💨 argh... I'll open an issue. This PR might need to be reined in! Or split into two!

```
device
    Device type used. The default tries to use the GPU and falls back on CPU if needed.
    Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
    ``torch.device``.
```
@ascillitoe (Contributor) commented on Jul 13, 2023:

Unnecessary indents?

@ascillitoe (Contributor) commented on Jul 13, 2023:

Docstring also seems slightly inaccurate? Maybe just something like "Torch device to be serialised."?


```
Returns
-------
a string with value ``'cuda'`` or ``'cpu'``.
```
@ascillitoe (Contributor) commented:

`str(torch.device('cuda:0'))` will return `'cuda:0'`, which makes the Returns docstring slightly incorrect, but could also break our save/load. I think save/load itself would work, as `'cuda:0'` will be resolved by `get_device` just fine. However, pydantic validation will fail since we have `Literal['cpu', 'gpu', 'cuda']`.

Possible solutions to me are:

1. Implement #679 (comment) ("Inconsistency in device kwarg between detectors and preprocess_drift function") properly, by implementing a custom pydantic validator to validate `'cuda:<int>'` strings (a rough sketch follows this list).
2. Relax the pydantic validation in `schemas.py` to `device: Optional[str] = None` for now.
3. Remove support for passing `torch.device` from this PR completely.
4. Do nothing, except throw a warning/error in `get_device` if a `torch.device` is passed with a device index, so the user knows they cannot serialise the detector when doing this...
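As referenced in option 1, a rough sketch of what such a custom validator could look like (pydantic v1 style, matching the `Literal` validation mentioned above; the model name and field here are illustrative stand-ins, not the repo's actual schema):

```python
# Hypothetical pydantic v1 validator for option 1; not the code in this PR.
import re
from typing import Optional

from pydantic import BaseModel, validator


class DetectorConfig(BaseModel):  # illustrative stand-in for the real schema
    device: Optional[str] = None

    @validator('device')
    def validate_device(cls, value: Optional[str]) -> Optional[str]:
        # Accept 'cpu', 'gpu', 'cuda', or 'cuda:<int>' (e.g. 'cuda:0').
        if value is None or re.fullmatch(r'cpu|gpu|cuda(:\d+)?', value):
            return value
        raise ValueError(f"{value!r} is not a valid device. "
                         "Expected 'cpu', 'gpu', 'cuda' or 'cuda:<int>'.")
```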

@mauicv (Collaborator, author) commented:

Shall we just format `str(torch.device('cuda:0'))` to remove the device index and raise a warning alerting the user to the change? Something like the sketch below.
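A rough sketch of the idea (assuming `save_device_config_pt` is the serialisation helper named later in this thread; its exact signature here is an assumption):

```python
# Sketch of the proposed fix: strip the device index before serialising
# and warn the user. Not the final code in this PR.
import warnings
from typing import Union

import torch


def save_device_config_pt(device: Union[str, torch.device]) -> str:
    device_str = str(device) if isinstance(device, torch.device) else device
    if ':' in device_str:
        device_type = device_str.split(':')[0]
        warnings.warn(
            f"Saving '{device_str}' as '{device_type}': the device index is "
            "dropped so the detector can be loaded on machines with a "
            "different GPU layout."
        )
        return device_type
    return device_str
```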

@ascillitoe (Contributor) commented:

I like this solution. It is simple, and prevents serialised detectors from being unloadable, e.g. if saved with `cuda:8` and loaded on a 4-GPU machine.

If we extend the pydantic validation to support the device index in the future, we could still save as `cuda`, and the user could manually add a device index in the `config.toml` if they desired.

@mauicv (Collaborator, author) commented on Jul 21, 2023:

Ah, so when saving, the detector first gets the config, then validates, and then replaces the values with their string representations... 🤔 I've added the pydantic validation as it seems like the best way of going about this.

I've kept it simple for now though: it just validates the device type from `str(device)`.

@ascillitoe (Contributor) commented:

Damn I forgot about the pre-saving validation...

```diff
@@ -188,6 +189,11 @@ def _save_detector_config(detector: ConfigurableDetector,
     if optimizer is not None:
         cfg['optimizer'] = _save_optimizer_config(optimizer)

+    # Serialize device
+    device = cfg.get('device')
+    if device is not None:
```
@ascillitoe (Contributor) commented on Jul 13, 2023:

Instead of the `_save_device_config` wrapper, isn't it easier just to do `cfg['device'] = save_device_config_pt(device)` here?

Granted, we do have a `_save_optimizer_config` wrapper, but that is a little different since we do have some sort of optimizer for both tensorflow and torch. Device is torch-only atm, so I'm not sure we need the wrapper...

@ascillitoe (Contributor) commented:

P.s. maybe `_save_device` would be more accurate than `_save_device_config`? `_save_optimizer_config` etc. are named `_config` since they do actually return a "config dict", whereas `_save_device_config` only returns a str.


```python
# if device is not None then we're using the pytorch backend
if device is not None:
    return save_device_config_pt(device)
```
@ascillitoe (Contributor) commented:

Isn't `if device is not None` unnecessary? Can you even arrive inside `_save_device_config` if `device` is `None`, since it's already checked here?

```diff
@@ -295,7 +295,7 @@ class PreprocessConfig(CustomBaseModel):
     Optional tokenizer for text drift. Either a string referencing a HuggingFace tokenizer model name, or a
     :class:`~alibi_detect.utils.schemas.TokenizerConfig`.
     """
-    device: Optional[Literal['cpu', 'cuda']] = None
+    device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
```
@mauicv (Collaborator, author) commented:

Note: agreed to format the device string to remove the device index prior to saving. See this comment.

@ascillitoe (Contributor) left a review:

A few minor comments, the main one regarding serialisation.

I'll do a final pass once tests are written. Regarding tests, I reckon we could get away with a single unit test saving with `save_device_config` and then running through `get_device`, parameterised with all the supported device types...
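A rough sketch of what that parameterised round-trip test might look like (import paths and exact behaviour are assumptions based on this discussion; the thread names `save_device_config` and `get_device` but not their modules):

```python
# Hypothetical round-trip test; import paths and assertions are assumptions
# based on the discussion above, not the repo's actual API.
import pytest
import torch

from alibi_detect.saving import save_device_config  # assumed import path
from alibi_detect.utils.pytorch import get_device   # assumed import path


@pytest.mark.parametrize('device', [
    'cpu', 'cuda', 'gpu',
    torch.device('cpu'), torch.device('cuda'), torch.device('cuda', 0),
])
def test_save_device_round_trip(device):
    serialised = save_device_config(device)
    # Per the agreed solution, the device index is stripped before saving,
    # so only bare device-type strings should ever be serialised.
    assert serialised in ('cpu', 'cuda', 'gpu')
    # The serialised string should resolve cleanly back to a torch.device.
    assert isinstance(get_device(serialised), torch.device)
```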

@mauicv requested a review from @ascillitoe on Jul 21, 2023, 16:30.
@ascillitoe (Contributor) left a review:

LGTM bar one minor nitpick

@mauicv merged commit c2f0a5a into SeldonIO:master on Jul 26, 2023.
16 checks passed.

Issues that may be closed by merging this pull request:

- Refactor device types to _types for outlier detectors
- Remove duplicated utils._types logic