From 2a78b6693225975a6a1159b9d13c8a8210e6a4f1 Mon Sep 17 00:00:00 2001
From: vfbd
Date: Tue, 5 Jul 2022 15:28:58 -0400
Subject: [PATCH] Fix base OPT-125M and finetuned OPT models in Colab TPU
 instances

---
 tpu_mtj_backend.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 458a67bb8..da0511df2 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1225,13 +1225,14 @@ def callback(model_dict, f, **_):
             if utils.num_shards is not None:
                 utils.current_shard += 1
             for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
+                model_spec_key = max((k for k in model_spec.keys() if key.endswith(k)), key=len, default=None)
 
                 # Some model weights are used by transformers but not by MTJ.
                 # We have to materialize these weights anyways because
                 # transformers will throw a tantrum otherwise. To attain
                 # the least possible memory usage, we create them as meta
                 # tensors, which don't take up any actual CPU or TPU memory.
-                if key not in model_spec:
+                if model_spec_key is None:
                     model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
                     utils.bar.update(1)
                     continue
@@ -1246,7 +1247,7 @@ def callback(model_dict, f, **_):
                 if current_offset != model_dict[key].seek_offset:
                     f.read(model_dict[key].seek_offset - current_offset)
                     current_offset = model_dict[key].seek_offset
-                spec = model_spec[key]
+                spec = model_spec[model_spec_key]
                 transforms = set(spec.get("transforms", ()))
                 if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
                     error = f"Duplicate key {repr(key)}"
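
Note: the snippet below is a minimal, self-contained sketch of the longest-suffix key matching this patch introduces. The model_spec entries and checkpoint key names are hypothetical placeholders, not taken from tpu_mtj_backend.py. It illustrates why the old exact lookup (`if key not in model_spec`) misses weights in checkpoints whose key names carry an extra prefix (presumably the case for the affected OPT variants), while the suffix match resolves both the prefixed and unprefixed forms to the same spec entry.

# Minimal sketch of the longest-suffix key matching added by this patch.
# The spec entries and checkpoint key names below are hypothetical
# placeholders, not the real contents of model_spec.
model_spec = {
    "decoder.embed_tokens.weight": {"module": "embedding_shard"},
    "decoder.layers.0.self_attn.q_proj.weight": {"module": "layer_0_attn"},
}

def find_spec_key(key, model_spec):
    # Longest spec key that the checkpoint key ends with, or None if the
    # weight is needed by transformers but unused by MTJ.
    return max((k for k in model_spec if key.endswith(k)), key=len, default=None)

# A base checkpoint stores the key exactly as the spec expects it.
assert find_spec_key("decoder.embed_tokens.weight", model_spec) == "decoder.embed_tokens.weight"
# A prefixed key still resolves to the same spec entry, whereas the old
# exact-membership test would have treated it as unused by MTJ.
assert find_spec_key("model.decoder.embed_tokens.weight", model_spec) == "decoder.embed_tokens.weight"
# Keys with no matching suffix fall through to the meta-tensor branch.
assert find_spec_key("lm_head.weight", model_spec) is None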