Fix base OPT-125M and finetuned OPT models in Colab TPU instances
vfbd committed Jul 5, 2022
1 parent c94f875 commit 2a78b66
Showing 1 changed file with 3 additions and 2 deletions.
tpu_mtj_backend.py
@@ -1225,13 +1225,14 @@ def callback(model_dict, f, **_):
             if utils.num_shards is not None:
                 utils.current_shard += 1
             for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
+                model_spec_key = max((k for k in model_spec.keys() if key.endswith(k)), key=len, default=None)
 
                 # Some model weights are used by transformers but not by MTJ.
                 # We have to materialize these weights anyways because
                 # transformers will throw a tantrum otherwise. To attain
                 # the least possible memory usage, we create them as meta
                 # tensors, which don't take up any actual CPU or TPU memory.
-                if key not in model_spec:
+                if model_spec_key is None:
                     model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
                     utils.bar.update(1)
                     continue
@@ -1246,7 +1247,7 @@ def callback(model_dict, f, **_):
                 if current_offset != model_dict[key].seek_offset:
                     f.read(model_dict[key].seek_offset - current_offset)
                     current_offset = model_dict[key].seek_offset
-                spec = model_spec[key]
+                spec = model_spec[model_spec_key]
                 transforms = set(spec.get("transforms", ()))
                 if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
                     error = f"Duplicate key {repr(key)}"
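The comment about meta tensors refers to a stock PyTorch feature rather than anything repository-specific: a tensor created on the "meta" device records shape and dtype but allocates no storage, which is why weights that transformers insists on seeing but MTJ never reads can be materialized at effectively zero cost. A quick illustration (the dimensions are roughly those of an OPT-125M embedding matrix):

import torch

# Tensors on the "meta" device carry shape/dtype metadata but no backing
# storage, so even a large "weight" consumes no real CPU or TPU memory.
w = torch.empty((50272, 768), dtype=torch.float16, device="meta")
print(w.shape)    # torch.Size([50272, 768])
print(w.is_meta)  # True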

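One piece of surrounding context worth noting: the f.read(...) in the second hunk is the lazy loader's way of skipping serialized data it doesn't need. Keys are visited in ascending seek_offset order, so each checkpoint file is consumed front to back, and "seeking" forward means reading and discarding the gap. A rough sketch of that pattern, with hypothetical names:

def skip_to(f, current_offset, target_offset):
    # Forward-only "seek" over a stream: read and discard the bytes between
    # the current position and the start of the next tensor's data.
    if target_offset > current_offset:
        f.read(target_offset - current_offset)
        current_offset = target_offset
    return current_offset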