Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QOL Items - handle metadata issues more cleanly for SD models, Loras and embeddings #15632

10 changes: 6 additions & 4 deletions extensions-builtin/Lora/ui_edit_user_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ def is_non_comma_tagset(tags):
def build_tags(metadata):
tags = {}

for _, tags_dict in metadata.get("ss_tag_frequency", {}).items():
for tag, tag_count in tags_dict.items():
tag = tag.strip()
tags[tag] = tags.get(tag, 0) + int(tag_count)
ss_tag_frequency = metadata.get("ss_tag_frequency", {})
if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'):
for _, tags_dict in ss_tag_frequency.items():
for tag, tag_count in tags_dict.items():
tag = tag.strip()
tags[tag] = tags.get(tag, 0) + int(tag_count)

if tags and is_non_comma_tagset(tags):
new_tags = {}
Expand Down
22 changes: 13 additions & 9 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,17 +282,21 @@ def read_metadata_from_safetensors(filename):
json_start = file.read(2)

assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
json_data = json_start + file.read(metadata_len-2)
json_obj = json.loads(json_data)

res = {}
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception:
pass

try:
json_data = json_start + file.read(metadata_len-2)
json_obj = json.loads(json_data)
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception:
pass
except Exception:
errors.report(f"Error reading metadata from file: {filename}", exc_info=True)

return res

Expand Down
12 changes: 8 additions & 4 deletions modules/textual_inversion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,16 @@ def load_from_file(self, path, filename):
else:
return

embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
if data is not None:
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)

if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
self.skipped_embeddings[name] = embedding
else:
self.skipped_embeddings[name] = embedding
print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")


def load_from_dir(self, embdir):
if not os.path.isdir(embdir.path):
Expand Down