diff --git a/CHANGELOG.md b/CHANGELOG.md index 933ed3ec2..3724bbffe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states. - Added MMLU downstream evaluation tasks. +- Added support for PyTorch v2.2. ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02 diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index 58bd4f736..7fe1d94c7 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -34,9 +34,13 @@ ShardedOptimStateDictConfig, ShardedStateDictConfig, ) -from torch.distributed.fsdp.flat_param import FlatParamHandle from torch.futures import Future +try: + from torch.distributed.fsdp.flat_param import FlatParamHandle # type: ignore +except ModuleNotFoundError: + from torch.distributed.fsdp._flat_param import FlatParamHandle # type: ignore + from .aliases import PathOrStr from .config import BaseConfig, ShardedCheckpointerType, TrainConfig from .optim import Optimizer, fix_optim_state_dict @@ -1124,6 +1128,8 @@ def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]: return [fsdp_model._handle] # type: ignore else: return [] + elif version.parse(torch.__version__) < version.parse("2.3.0"): + return fsdp_model._all_handles else: # Need to verify FSDP internals with newer versions. raise NotImplementedError diff --git a/pyproject.toml b/pyproject.toml index db9af8201..6ff1c2f91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ requires-python = ">=3.8" license = { file = "LICENSE" } dependencies = [ "numpy", - "torch>=2.0,<2.2", + "torch>=2.0,<2.3", "omegaconf", "rich", "boto3",