Skip to content

Commit

Permalink
Update conversion script.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Nov 30, 2023
1 parent 1b661bc commit 4eaaa3c
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions convert_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def rename_key(rename, name):
return name


def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}):
def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
Expand All @@ -34,13 +34,15 @@ def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename=
# for k, v in loaded.items():
# print(f'{k}\t{v.shape}\t{v.dtype}')

loaded = {rename_key(rename, k).lower(): v.contiguous()
for k, v in loaded.items()}
# For tensors to be contiguous
for k, v in loaded.items():
for transpose_name in transpose_names:
if transpose_name in k:
loaded[k] = v.transpose(0, 1)
loaded = {rename_key(rename, k).lower(): v.contiguous()
for k, v in loaded.items()}

loaded = {k: v.clone().half().contiguous() for k, v in loaded.items()}

for k, v in loaded.items():
print(f"{k}\t{v.shape}\t{v.dtype}")
Expand All @@ -57,6 +59,8 @@ def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename=


if __name__ == "__main__":
convert_file(args.input, args.output, ["lora_A"], {
"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"})
convert_file(args.input, args.output, ["lora_A"],
rename={"time_faaaa": "time_first", "time_maa": "time_mix",
"lora_A": "lora.0", "lora_B": "lora.1"},
transpose_names=["time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2"])
print(f"Saved to {args.output}")

1 comment on commit 4eaaa3c

@IgorAlexey
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line 62, in
convert_file(args.input, args.output, ["lora_A"],
TypeError: convert_file() got multiple values for argument 'rename'

Please sign in to comment.