Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing imports for torch dist ckpt in export #9930

Open
This pull request wants to merge 1 commit into the base branch.
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import torch
import yaml
import zarr
from torch.distributed.checkpoint import FileSystemReader
from tensorrt_llm._utils import np_bfloat16
from torch.distributed.checkpoint import FileSystemReader, TensorStorageMetadata
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from transformers import AutoTokenizer, PreTrainedTokenizer

from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer
Expand Down Expand Up @@ -56,9 +58,11 @@ class TarFileSystemReader(FileSystemReader):
"""

def __init__(self, path: Union[Path, TarPath]) -> None:
"""No call to super().__init__ because it expects pure Path."""
self.path = path
self.storage_data = dict()
"""Makes sure that super().__init__ gets a pure path as expected."""
super_path = str(path) if isinstance(path, TarPath) else path
super().__init__(super_path)
if isinstance(path, TarPath):
self.path = path # overwrites path set in super().__init__ call


def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch_tensor=True):
Expand Down Expand Up @@ -228,6 +232,7 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
unpacked_checkpoint_dir = UnpackedNemoCheckpointDir(nemo_dir, load_checkpoints_to_cpu=True)

dist_ckpt_folder = nemo_dir / "model_weights"

if dist_ckpt_folder.exists():
model = load_sharded_metadata(dist_ckpt_folder)
nemo_model_config = unpacked_checkpoint_dir.model_config
Expand Down