-
Notifications
You must be signed in to change notification settings - Fork 308
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5389909
commit 5bf0451
Showing
1 changed file
with
15 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,50 @@ | ||
# This helper script scans folders for wildcards and embeddings and writes them | ||
# to a temporary file to expose it to the javascript side | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
# The path to the folder containing the wildcards and embeddings | ||
FILE_DIR = os.path.dirname(os.path.realpath("__file__")) | ||
WILDCARD_PATH = os.path.join(FILE_DIR, 'scripts/wildcards') | ||
EMB_PATH = os.path.join(FILE_DIR, 'embeddings') | ||
FILE_DIR = Path().absolute() | ||
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards') | ||
EMB_PATH = FILE_DIR.joinpath('embeddings') | ||
# The path to the temporary file | ||
TEMP_PATH = os.path.join(FILE_DIR, 'tags/temp') | ||
TEMP_PATH = FILE_DIR.joinpath('tags/temp') | ||
|
||
|
||
def get_wildcards(): | ||
"""Returns a list of all wildcards""" | ||
return filter(lambda f: f.endswith(".txt"), os.listdir(WILDCARD_PATH)) | ||
"""Returns a list of all wildcards. Works on nested folders.""" | ||
wildcard_files = list(WILDCARD_PATH.rglob("*.txt")) | ||
resolved = [str(w.relative_to(WILDCARD_PATH)) for w in wildcard_files] | ||
return resolved | ||
|
||
|
||
def get_embeddings(): | ||
"""Returns a list of all embeddings""" | ||
return filter(lambda f: f.endswith(".bin") or f.endswith(".pt"), os.listdir(EMB_PATH)) | ||
return [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt"}] | ||
|
||
|
||
def write_to_temp_file(name, data): | ||
"""Writes the given data to a temporary file""" | ||
with open(os.path.join(TEMP_PATH, name), 'w', encoding="utf-8") as f: | ||
with open(TEMP_PATH.joinpath(name), 'w', encoding="utf-8") as f: | ||
f.write(('\n'.join(data))) | ||
|
||
|
||
# Check if the temp path exists and create it if not | ||
if not os.path.exists(TEMP_PATH): | ||
os.makedirs(TEMP_PATH) | ||
if not TEMP_PATH.exists(): | ||
TEMP_PATH.mkdir(parents=True, exist_ok=True) | ||
# Set up files to ensure the script doesn't fail to load them | ||
# even if no wildcards or embeddings are found | ||
write_to_temp_file('wc.txt', []) | ||
write_to_temp_file('emb.txt', []) | ||
|
||
# Write wildcards to wc.txt if found | ||
if os.path.exists(WILDCARD_PATH): | ||
if WILDCARD_PATH.exists(): | ||
wildcards = get_wildcards() | ||
if wildcards: | ||
write_to_temp_file('wc.txt', wildcards) | ||
|
||
# Write embeddings to emb.txt if found | ||
if os.path.exists(EMB_PATH): | ||
if EMB_PATH.exists(): | ||
embeddings = get_embeddings() | ||
if embeddings: | ||
write_to_temp_file('emb.txt', embeddings) |