Skip to content

Commit

Permalink
Fix new code import
Browse files Browse the repository at this point in the history
Signed-off-by: Mikołaj Błaż <mblaz@nvidia.com>
  • Loading branch information
mikolajblaz committed Jun 19, 2024
1 parent dbe8423 commit f8aade0
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions nemo/utils/callbacks/dist_ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
from megatron.core.dist_checkpointing.mapping import ShardedBase
from megatron.core.dist_checkpointing.serialization import (
get_default_load_sharded_strategy,
get_default_save_sharded_strategy,
)
from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
from megatron.core.dist_checkpointing.strategies import tensorstore
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
Expand All @@ -46,6 +43,13 @@
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.core.parallel_state import get_data_parallel_group

try:
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
except ImportError:
from megatron.core.dist_checkpointing.serialization import (
_verify_checkpoint_and_load_strategy as get_default_load_sharded_strategy,
)

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError) as e:
Expand Down

0 comments on commit f8aade0

Please sign in to comment.