
Add support for python 3.11 #3190

Merged · 70 commits merged into master from ddelange:cp311 on Oct 11, 2023

Commits
fa4e31e
Add support for python 3.11
ddelange May 2, 2023
e7e6bbc
Add missing comma
ddelange May 2, 2023
14e6fee
Deprecate track_grad_norm
ddelange May 2, 2023
dd96a19
Use 'auto' instead of None defaults
ddelange May 2, 2023
9c922b9
Revert devices='auto'
ddelange May 2, 2023
62c2008
Merge branch 'master' into cp311
gradientsky May 2, 2023
cf8d153
fixes to inference; updated timeseries to align with the torch version
gradientsky May 3, 2023
b2aef1a
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange May 3, 2023
28ca9de
Deprecate compute_on_step
ddelange May 3, 2023
0d4f8a4
Rename validation_epoch_end -> on_validation_epoch_end
ddelange May 4, 2023
18fbb5f
Remove property overwrite attempt
ddelange May 4, 2023
6e626af
Update k -> top_k
ddelange May 4, 2023
06b41f2
Bump torchmetrics
ddelange May 4, 2023
6a0387e
Bump pytorch-lightning
ddelange May 4, 2023
2e65ea8
Remove map.py
ddelange May 4, 2023
ef9fb68
Remove map.py leftover
ddelange May 4, 2023
cdbd78e
Omit classes key from log_dict call
ddelange May 5, 2023
4c94b71
Sync torch and torchvision versions with pytorch-lightning
ddelange May 5, 2023
ea9da42
Merge branch 'autogluon:master' into cp311
ddelange May 5, 2023
88d0664
Remove track_grad_norm references
ddelange May 6, 2023
93900e4
Merge branch 'master' into cp311
ddelange May 26, 2023
43e4e16
Fix catboost installation error for Github macos runners
ddelange May 26, 2023
7533e8a
Merge branch 'master' into cp311
ddelange Jun 9, 2023
3429b5c
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Jun 16, 2023
6f1935d
Remove catboost hotfix
ddelange Jun 16, 2023
ac8f128
Bump onnx to 0.15.x
ddelange Jun 19, 2023
8e7d838
Revert onnx version bump
ddelange Jun 19, 2023
8024a1b
Bump tensorrt for cp311 compatibility
ddelange Jun 19, 2023
cb98de7
Update datasets and evaluate
ddelange Jun 19, 2023
75c13d0
Fix typo on version range
ddelange Jun 19, 2023
650cc81
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Jun 22, 2023
22cee55
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Jun 25, 2023
9bd84af
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Jul 3, 2023
a4fabd7
Bump ray version
ddelange Jul 3, 2023
b94c248
Merge branch 'autogluon:master' into cp311
ddelange Jul 17, 2023
7a4eef0
Merge branch 'master' into cp311
Jul 17, 2023
060a6cd
Undo deletion from a merge commit
ddelange Jul 23, 2023
ca8d8ef
Unify torchmetrics version notation
ddelange Jul 23, 2023
0d5519e
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Jul 26, 2023
bee41e4
Merge branch 'cp311' of https://github.com/ddelange/autogluon into cp311
ddelange Jul 26, 2023
c38fe7e
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Jul 28, 2023
9084527
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Aug 1, 2023
f24b561
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Aug 3, 2023
f0cbf1c
Revert merge remnant
ddelange Aug 8, 2023
70d5745
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Aug 8, 2023
983415b
Revert lower bound bumps
ddelange Aug 8, 2023
15da1de
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Aug 10, 2023
4081cdb
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Aug 11, 2023
953d559
Merge branch 'autogluon:master' into cp311
ddelange Aug 16, 2023
ad04285
Merge branch 'master' into cp311
ddelange Aug 24, 2023
6f1d2fd
Update torch version in error message
ddelange Aug 24, 2023
ed3f969
Add cp311 specifier
ddelange Aug 24, 2023
07246b8
Lint
ddelange Aug 24, 2023
8555539
Merge branch 'autogluon:master' into cp311
ddelange Aug 30, 2023
73758c2
Merge branch 'master' into cp311
ddelange Sep 10, 2023
11652e9
Allow torchmetrics 1.1.*
ddelange Sep 11, 2023
0524102
Merge branch 'master' of https://github.com/awslabs/autogluon into cp311
ddelange Sep 18, 2023
7228201
Merge branch 'autogluon:master' into cp311
ddelange Sep 27, 2023
fba061c
Bump ray to 2.7.0
ddelange Oct 2, 2023
6deb782
test
yinweisu Oct 3, 2023
2e62c43
disable tensorrt
yinweisu Oct 3, 2023
5227c72
fix tests
yinweisu Oct 4, 2023
84592cc
isort
yinweisu Oct 4, 2023
4315d6e
lint
yinweisu Oct 4, 2023
5586f48
isort
yinweisu Oct 4, 2023
681709d
fix
yinweisu Oct 4, 2023
c30990d
fix catboost
yinweisu Oct 4, 2023
b82d137
fix
yinweisu Oct 5, 2023
97d9386
Merge branch 'master' into py_311
Oct 9, 2023
bd6b173
address comments
yinweisu Oct 10, 2023
14 changes: 8 additions & 6 deletions .github/workflow_scripts/env_setup.sh
@@ -16,21 +16,23 @@ function setup_build_contrib_env {
}

function setup_torch_gpu {
# Security-patched torch.
python3 -m pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu118 reinstall_torch
Collaborator:
So changes from this file actually won't be reflected in our CI, because we check the permissions of the PR submitter to keep people from modifying our setup scripts.

Currently there's no easy way to enable it to run. I'm thinking of allowing the script to run if we tag this PR as something like "safe to run" in the future. For now, we might need to give @ddelange write permission to our repo briefly and revoke the permission once the PR is ready. @Innixma @gradientsky Ideas?

Contributor Author:
Sounds like good practice to keep that closed off 💪

FWIW, I don't mind if a maintainer opens a new PR in favour of this one.

Collaborator:
After discussion with the team, I'll create a clone of this PR once it's ready, to test the changes. Once those changes are verified, we'll merge this PR and close the clone.

}

function setup_torch_cpu {
# Security-patched torch
python3 -m pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu
PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu reinstall_torch
}

function setup_torch_gpu_non_linux {
pip3 install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
setup_torch_gpu
}

function setup_torch_cpu_non_linux {
pip3 install torch==1.13.1 torchvision==0.14.1
setup_torch_cpu
}

function reinstall_torch {
pip3 install --force-reinstall torchvision~=0.15.1
Contributor Author (@ddelange, Jul 23, 2023):
Why pip3, by the way? Might it be that the autogluon Python setup runs in a virtual environment in which the symlinked pip is somehow not available? See also the Dockerfile snippet in my earlier comment.

Contributor:
I think pip3 is just legacy, from when we thought pip3 would be correct more often than pip (and in most cases they are identical). Now pip3 and pip are almost always the same, so it shouldn't matter if it is changed to pip.

}

function setup_hf_model_mirror {
12 changes: 6 additions & 6 deletions .github/workflows/platform_tests-command.yml
@@ -45,7 +45,7 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
python: ["3.8", "3.9", "3.10"]
python: ["3.8", "3.9", "3.10", "3.11"]
Collaborator:
To trigger a platform test, you should be able to do it by commenting on this PR with:
/platform_tests ref=ddelange:cp311

However, our platform tests are currently failing because of other issues. @gradientsky @Innixma We'll want to fix the platform tests to unblock this PR.

steps:
- name: Checkout repository for PR
if: (github.event_name == 'workflow_dispatch')
@@ -77,7 +77,7 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
python: ["3.8", "3.9", "3.10"]
python: ["3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout repository for PR
if: (github.event_name == 'workflow_dispatch')
@@ -109,7 +109,7 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
python: ["3.8", "3.9", "3.10"]
python: ["3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout repository for PR
if: (github.event_name == 'workflow_dispatch')
@@ -141,7 +141,7 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
python: ["3.8", "3.9", "3.10"]
python: ["3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout repository for PR
if: (github.event_name == 'workflow_dispatch')
@@ -174,7 +174,7 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
python: ["3.8", "3.9", "3.10"]
python: ["3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout repository for PR
if: (github.event_name == 'workflow_dispatch')
@@ -214,7 +214,7 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
python: ["3.8", "3.9", "3.10"]
python: ["3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout repository for PR
if: (github.event_name == 'workflow_dispatch')
2 changes: 1 addition & 1 deletion core/src/autogluon/core/_setup_utils.py
@@ -13,7 +13,7 @@
os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', '..')
)

PYTHON_REQUIRES = '>=3.8, <3.11'
PYTHON_REQUIRES = '>=3.8, <3.12'


# Only put packages here that would otherwise appear multiple times across different module's setup.py files.
2 changes: 1 addition & 1 deletion docs/install-cpu-pip.md
@@ -4,7 +4,7 @@ pip install -U setuptools wheel

# CPU version of pytorch has smaller footprint - see installation instructions in
# pytorch documentation - https://pytorch.org/get-started/locally/
pip install torch==1.13.1+cpu torchvision==0.14.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install torchvision~=0.15.1 --force-reinstall --extra-index-url https://download.pytorch.org/whl/cpu
Contributor Author:
I switched the docs to torch's new --extra-index-url notation.

torchvision etc. have an exact pin on a torch version, so this command is sufficient (and will reinstall the CUDA version if a user previously had the CPU version installed, and vice versa).


pip install autogluon
```
2 changes: 1 addition & 1 deletion docs/install-cpu-source.md
@@ -4,7 +4,7 @@ pip install -U setuptools wheel

# CPU version of pytorch has smaller footprint - see installation instructions in
# pytorch documentation - https://pytorch.org/get-started/locally/
pip install torch==1.13.1+cpu torchvision==0.14.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu
pip install torchvision~=0.15.1 --force-reinstall --extra-index-url https://download.pytorch.org/whl/cpu

git clone https://github.com/autogluon/autogluon
cd autogluon && ./full_install.sh
2 changes: 1 addition & 1 deletion docs/install-gpu-pip.md
@@ -3,7 +3,7 @@ pip install -U pip
pip install -U setuptools wheel

# Install the proper version of PyTorch following https://pytorch.org/get-started/locally/
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torchvision~=0.15.1 --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118

pip install autogluon
```
2 changes: 1 addition & 1 deletion docs/install-gpu-source.md
@@ -3,7 +3,7 @@ pip install -U pip
pip install -U setuptools wheel

# Install the proper version of PyTorch following https://pytorch.org/get-started/locally/
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torchvision~=0.15.1 --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118

git clone https://github.com/autogluon/autogluon
cd autogluon && ./full_install.sh
2 changes: 1 addition & 1 deletion docs/install-windows-gpu.md
@@ -11,7 +11,7 @@ conda activate myenv
4. Install the proper GPU PyTorch version by following the [PyTorch Install Documentation](https://pytorch.org/get-started/locally/) (Recommended). Alternatively, use the following command:

```bash
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torchvision~=0.15.1 --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
```

5. Sanity check that your installation is valid and can detect your GPU via testing in Python:
3 changes: 0 additions & 3 deletions examples/automm/object_detection/detection_eval.py
@@ -10,9 +10,6 @@
python detection_eval.py \
--test_path ./VOCdevkit/VOC2007/Annotations/test_cocoformat.json \
--checkpoint_name faster_rcnn_r50_fpn_1x_voc0712

Note that for now it's required to install nightly build torchmetrics.
This will be solved in next pr. (MeanAveragePrecision will be moved to AG temporarily.)
"""

import argparse
3 changes: 0 additions & 3 deletions examples/automm/object_detection/detection_train.py
@@ -23,9 +23,6 @@
--lr <learning_rate> \
--wd <weight_decay> \
--epochs <epochs>

Note that for now it's required to install nightly build torchmetrics.
This will be solved in next pr. (MeanAveragePrecision will be moved to AG temporarily.)
"""

import argparse
10 changes: 5 additions & 5 deletions multimodal/setup.py
@@ -33,13 +33,13 @@
"evaluate>=0.2.2,<0.4.0",
"accelerate>=0.9,<0.17",
"timm>=0.6.12,<0.7.0",
"torch>=1.9,<1.14",
"torchvision<0.15.0",
"torch>=1.11,<2.1",
Contributor:
This change makes users install torch 2.0 by default. Is the CI still testing torch 1.13.1? If so, there is some inconsistency between what we test and what users use.

Contributor Author:
Does this thread answer your question? #3190 (comment)

After this merges, I think there are no explicit tests that run with pytorch 1.x anymore.

"torchvision>=0.12.0,<0.16",
"fairscale>=0.4.5,<0.4.14",
"scikit-image>=0.19.1,<0.20.0",
"pytorch-lightning>=1.9.0,<1.10.0",
"scikit-image>=0.19.1,<0.21.0",
"pytorch-lightning>=2.0.0,<2.1",
"text-unidecode>=1.3,<1.4",
"torchmetrics>=0.11.0,<0.12.0",
"torchmetrics~=1.0.0rc0",
"transformers>=4.23.0,<4.27.0",
"nptyping>=1.4.4,<2.5.0",
"omegaconf>=2.1.1,<2.3.0",
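A side note on the pin notation used above: "~=" is PEP 440's compatible-release operator, so torchmetrics~=1.0.0rc0 means >=1.0.0rc0 together with ==1.0.*. A small illustrative check of why this excludes 1.1.x until the later "Allow torchmetrics 1.1.*" commit (the packaging import is an assumption for illustration, not part of this PR):

```python
# Illustrative check of the "~=" (compatible release) specifier.
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=1.0.0rc0", prereleases=True)  # i.e. >=1.0.0rc0, ==1.0.*
for version in ["1.0.0rc0", "1.0.3", "1.1.0"]:
    print(version, version in spec)
# 1.0.0rc0 True, 1.0.3 True, 1.1.0 False -> a later commit loosens the pin.
```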
@@ -16,7 +16,6 @@ optimization:
skip_final_val: False # Flag to skip the last validation
gradient_clip_val: 1
gradient_clip_algorithm: "norm"
track_grad_norm: -1 # Whether to check gradient norm. We can set it to 2 to check for gradient norm.
yinweisu marked this conversation as resolved.
log_every_n_steps: 10
val_metric: null
top_k: 3
7 changes: 3 additions & 4 deletions multimodal/src/autogluon/multimodal/matcher.py
@@ -864,15 +864,15 @@ def _fit(

if not hpo_mode:
if num_gpus <= 1:
strategy = None
strategy = "auto"
else:
strategy = config.env.strategy
else:
# we don't support running each trial in parallel without ray lightning
if use_ray_lightning:
strategy = hpo_kwargs.get("_ray_lightning_plugin")
else:
strategy = None
strategy = "auto"
num_gpus = min(num_gpus, 1)

config.env.num_gpus = num_gpus
@@ -886,7 +886,7 @@
log_filter = LogFilter(blacklist_msgs)
with apply_log_filter(log_filter):
trainer = pl.Trainer(
accelerator="gpu" if num_gpus > 0 else None,
accelerator="gpu" if num_gpus > 0 else "auto",
devices=get_available_devices(
num_gpus=num_gpus,
auto_select_gpus=config.env.auto_select_gpus,
@@ -910,7 +910,6 @@
log_every_n_steps=OmegaConf.select(config, "optimization.log_every_n_steps", default=10),
enable_progress_bar=enable_progress_bar,
fast_dev_run=config.env.fast_dev_run,
track_grad_norm=OmegaConf.select(config, "optimization.track_grad_norm", default=-1),
val_check_interval=config.optimization.val_check_interval,
check_val_every_n_epoch=config.optimization.check_val_every_n_epoch
if hasattr(config.optimization, "check_val_every_n_epoch")
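For context on the strategy and accelerator changes above: Lightning 2.0 no longer accepts None for these Trainer arguments and uses "auto" as the selection sentinel instead. A minimal sketch of the 2.x call (the argument values are illustrative, not AutoGluon's actual config):

```python
# pytorch-lightning 2.x sketch: "auto" replaces the 1.x None defaults.
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="auto",  # CUDA/MPS when available, else CPU (1.x allowed None)
    strategy="auto",     # single-device or DDP as appropriate (was strategy=None)
    devices="auto",      # all devices of the chosen accelerator
    max_epochs=1,
)
```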
@@ -230,12 +230,13 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
else:
self.evaluate(batch, "val")

def validation_epoch_end(self, validation_step_outputs):
def on_validation_epoch_end(self):
val_result = self.validation_metric.compute()
if self.use_loss:
self.log_dict({"val_direct_loss": val_result}, sync_dist=True)
else:
# TODO: add mAP/mAR_per_class
val_result.pop("classes", None) # introduced in torchmetrics v1.0.0
mAPs = {"val_" + k: v for k, v in val_result.items()}
mAPs["val_mAP"] = mAPs["val_map"]
self.log_dict(mAPs, sync_dist=True)
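For context on the hook rename: Lightning 2.0 removed validation_epoch_end(outputs), so any epoch-level aggregation state now lives on the module itself (above, the torchmetrics object already plays that role). A toy sketch of the migrated hooks, not AutoGluon's actual module:

```python
# Toy sketch of the Lightning 2.x validation hooks; the class and layer
# are illustrative only.
import torch
import pytorch_lightning as pl


class LitToy(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.val_losses = []  # module-held state replaces the old `outputs` arg

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.val_losses.append(torch.nn.functional.mse_loss(self.layer(x), y))

    def on_validation_epoch_end(self):  # 2.x: no `outputs` parameter
        self.log("val_loss", torch.stack(self.val_losses).mean(), sync_dist=True)
        self.val_losses.clear()
```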
@@ -239,13 +239,13 @@ def training_step(self, batch, batch_idx):

def on_validation_start(self) -> None:
if self.skip_final_val and self.trainer.should_stop:
self.trainer.val_dataloaders = [] # skip the final validation by setting val_dataloaders empty
self.log(
self.validation_metric_name,
self.validation_metric,
on_step=False,
on_epoch=True,
)
return None
Contributor:
@suzhoum Can you verify that the fast build mode still behaves well after this change?

Contributor Author:
More context in the commit message: 18fbb5f

return super().on_validation_start()

def validation_step(self, batch, batch_idx):
2 changes: 1 addition & 1 deletion multimodal/src/autogluon/multimodal/optimization/utils.py
@@ -11,6 +11,7 @@
from pytorch_metric_learning import distances, losses, miners
from torch import nn, optim
from torch.nn import functional as F
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from transformers import Adafactor
from transformers.trainer_pt_utils import get_parameter_names

@@ -62,7 +63,6 @@
ROOT_MEAN_SQUARED_ERROR,
SPEARMANR,
)
from ..utils.map import MeanAveragePrecision
from .losses import FocalLoss, MultiNegativesSoftmaxLoss, SoftTargetCrossEntropy
from .lr_scheduler import (
get_cosine_schedule_with_warmup,
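With the vendored map.py gone, the metric comes straight from torchmetrics. A minimal usage sketch (the dummy boxes and labels are illustrative only), assuming torchmetrics >= 1.0:

```python
# Hedged usage sketch of torchmetrics' MeanAveragePrecision.
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision()  # default iou_type="bbox", boxes in xyxy format
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([0]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 8.0, 52.0, 48.0]]),
    "labels": torch.tensor([0]),
}]
metric.update(preds, targets)
result = metric.compute()  # dict of tensors; v1.0 added a "classes" entry,
                           # which is why the PR pops it before log_dict
```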
7 changes: 3 additions & 4 deletions multimodal/src/autogluon/multimodal/predictor.py
@@ -1429,15 +1429,15 @@ def _fit(
reduce_bucket_size=config.env.deepspeed_allreduce_size,
)
else:
strategy = None
strategy = "auto"
else:
strategy = config.env.strategy
else:
# we don't support running each trial in parallel without ray lightning
if use_ray_lightning:
strategy = hpo_kwargs.get("_ray_lightning_plugin")
else:
strategy = None
strategy = "auto"
num_gpus = min(num_gpus, 1)

config.env.num_gpus = num_gpus
@@ -1451,7 +1451,7 @@
log_filter = LogFilter(blacklist_msgs)
with apply_log_filter(log_filter):
trainer = pl.Trainer(
accelerator="gpu" if num_gpus > 0 else None,
accelerator="gpu" if num_gpus > 0 else "auto",
devices=get_available_devices(
num_gpus=num_gpus,
auto_select_gpus=config.env.auto_select_gpus,
@@ -1475,7 +1475,6 @@
log_every_n_steps=OmegaConf.select(config, "optimization.log_every_n_steps", default=10),
enable_progress_bar=enable_progress_bar,
fast_dev_run=config.env.fast_dev_run,
track_grad_norm=OmegaConf.select(config, "optimization.track_grad_norm", default=-1),
Contributor:
BTW, as Lightning has removed track_grad_norm from the Trainer, do you know if Lightning has some equivalent design for it?

Contributor Author:
Ref Lightning-AI/pytorch-lightning#16745 — they suggest an explicit log_dict (sketched below, after this diff). Any thoughts on this?
val_check_interval=config.optimization.val_check_interval,
check_val_every_n_epoch=config.optimization.check_val_every_n_epoch
if hasattr(config.optimization, "check_val_every_n_epoch")
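A minimal sketch of the log_dict approach referenced in the thread above, assuming the grad_norm utility shipped with pytorch-lightning 2.x (the module class here is illustrative, not AutoGluon's):

```python
# Hedged sketch: replacing track_grad_norm=2 with an explicit log_dict call.
import pytorch_lightning as pl
from pytorch_lightning.utilities import grad_norm


class LitToy(pl.LightningModule):
    def on_before_optimizer_step(self, optimizer):
        # norm_type=2 mirrors the old track_grad_norm=2 behaviour
        self.log_dict(grad_norm(self, norm_type=2))
```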
1 change: 0 additions & 1 deletion multimodal/src/autogluon/multimodal/utils/__init__.py
@@ -46,7 +46,6 @@
from .inference import extract_from_output, infer_batch, predict, process_batch, use_realtime
from .load import CustomUnpickler, load_text_tokenizers
from .log import LogFilter, apply_log_filter, get_fit_complete_message, get_fit_start_message, make_exp_dir
from .map import MeanAveragePrecision
from .matcher import compute_semantic_similarity, convert_data_for_ranking, create_siamese_model, semantic_search
from .metric import compute_ranking_score, compute_score, get_minmax_mode, get_stopping_threshold, infer_metrics
from .misc import logits_to_prob, merge_bio_format, shopee_dataset, tensor_to_ndarray, visualize_ner
2 changes: 1 addition & 1 deletion multimodal/src/autogluon/multimodal/utils/inference.py
@@ -487,7 +487,7 @@ def predict(

if num_gpus <= 1:
# Force set strategy to "auto" if it's cpu-only or we have only one GPU.
strategy = None
strategy = "auto"

precision = infer_precision(num_gpus=num_gpus, precision=predictor._config.env.precision, cpu_only_warning=False)
