Umbriel
schmelly committed Apr 25, 2024
2 parents f124b54 + 7cd06a5 commit 9d3a417
Showing 5 changed files with 42 additions and 22 deletions.
19 changes: 11 additions & 8 deletions graxpert/ai_model_handling.py
@@ -28,7 +28,7 @@
try:
os.rename(ai_models_dir, bge_ai_models_dir)
except Exception as e:
logging.error(f"Renaming {ai_models_dir} to {bge_ai_models_dir} failed. {bge_ai_models_dir} will be newly created. Consider deleting obsolete {ai_models_dir}.")
logging.error(f"Renaming {ai_models_dir} to {bge_ai_models_dir} failed. {bge_ai_models_dir} will be newly created. Consider deleting obsolete {ai_models_dir} manually.")

os.makedirs(bge_ai_models_dir, exist_ok=True)

@@ -64,7 +64,11 @@ def list_remote_versions(bucket_name):

def list_local_versions(ai_models_dir):
try:
model_dirs = [{"path": os.path.join(ai_models_dir, f), "version": f} for f in os.listdir(ai_models_dir) if re.search(r"\d\.\d\.\d", f)] # match semantic version
model_dirs = [
{"path": os.path.join(ai_models_dir, f), "version": f}
for f in os.listdir(ai_models_dir)
if re.search(r"\d\.\d\.\d", f) and len(os.listdir(os.path.join(ai_models_dir, f))) > 0 # match semantic version
]
return model_dirs
except Exception as e:
logging.exception(e)
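
Note: the new guard in `list_local_versions` skips version folders that exist but are empty, e.g. left behind by an interrupted download. A quick self-contained check of the filter (paths here are hypothetical):

```python
# Demonstrates the empty-directory guard: a version folder with no files is skipped.
import os
import re
import tempfile

with tempfile.TemporaryDirectory() as models:
    os.makedirs(os.path.join(models, "1.0.0"))  # empty -> filtered out
    os.makedirs(os.path.join(models, "1.1.0"))
    open(os.path.join(models, "1.1.0", "model.onnx"), "w").close()

    found = [
        f
        for f in os.listdir(models)
        if re.search(r"\d\.\d\.\d", f) and len(os.listdir(os.path.join(models, f))) > 0
    ]
    print(found)  # ['1.1.0']
```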
@@ -122,11 +126,11 @@ def cleanup_orphaned_local_versions(orphaned_local_versions):
logging.exception(e)


-def download_version(ai_models_dir, bucket_name, remote_version, progress=None):
+def download_version(ai_models_dir, bucket_name, target_version, progress=None):
try:
remote_versions = list_remote_versions(bucket_name)
for r in remote_versions:
-        if remote_version == r["version"]:
+        if target_version == r["version"]:
remote_version = r
break

@@ -144,11 +148,11 @@ def download_version(ai_models_dir, bucket_name, remote_version, progress=None):

with zipfile.ZipFile(ai_model_zip, "r") as zip_ref:
zip_ref.extractall(ai_model_dir)

if not os.path.isfile(ai_model_file):
raise ValueError(f"Could not find ai 'model.onnx' file after extracting {ai_model_zip}")
os.remove(ai_model_zip)

except Exception as e:
# try to delete (rollback) ai_model_dir in case of errors
logging.exception(e)
@@ -163,7 +167,6 @@ def validate_local_version(ai_models_dir, local_version):


def get_execution_providers_ordered():
supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider",
"CPUExecutionProvider"]
supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]

return [provider for provider in supported_providers if provider in ort.get_available_providers()]
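
Note: the provider list is a preference order that gets intersected with what the local onnxruntime build actually offers. A minimal usage sketch (the model path is illustrative):

```python
# Pick the most preferred ONNX Runtime execution provider that is available,
# falling back to CPUExecutionProvider, which standard builds always include.
import onnxruntime as ort

supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
providers = [p for p in supported_providers if p in ort.get_available_providers()]

session = ort.InferenceSession("model.onnx", providers=providers)
print(session.get_providers())  # e.g. ['CPUExecutionProvider'] on a CPU-only machine
```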
30 changes: 21 additions & 9 deletions graxpert/denoising.py
@@ -6,20 +6,31 @@
import onnxruntime as ort

from graxpert.ai_model_handling import get_execution_providers_ordered
-from graxpert.application.eventbus import eventbus
from graxpert.application.app_events import AppEvents
+from graxpert.application.eventbus import eventbus
+from graxpert.ui.ui_events import UiEvents

-def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128, progress=None):

+def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, progress=None):

logging.info("Starting denoising")

+    if batch_size < 1:
+        logging.info(f"mapping batch_size of {batch_size} to 1")
+        batch_size = 1
+    elif batch_size > 32:
+        logging.info(f"mapping batch_size of {batch_size} to 32")
+        batch_size = 32
+    elif not (batch_size & (batch_size - 1) == 0): # check if batch_size is power of two
+        logging.info(f"mapping batch_size of {batch_size} to {2 ** (batch_size).bit_length() // 2}")
+        batch_size = 2 ** (batch_size).bit_length() // 2 # map batch_size to power of two

input = copy.deepcopy(image)

global cached_denoised_image
if cached_denoised_image is not None:
return blend_images(input, cached_denoised_image, strength)

num_colors = image.shape[-1]
if num_colors == 1:
image = np.array([image[:, :, 0], image[:, :, 0], image[:, :, 0]])
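
Note: the new clamp keeps batch_size in [1, 32] and rounds any other value down to the nearest power of two. Worked examples of the mapping:

```python
# Round a batch size down to the nearest power of two, as in the clamp above.
# Assumes the value has already been limited to the range 1..32.
for batch_size in [3, 5, 6, 12, 31]:
    if not (batch_size & (batch_size - 1) == 0):  # not already a power of two
        print(batch_size, "->", 2 ** batch_size.bit_length() // 2)
# 3 -> 2, 5 -> 4, 6 -> 4, 12 -> 8, 31 -> 16
```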
@@ -51,9 +62,7 @@ def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128,
output = copy.deepcopy(image)

providers = get_execution_providers_ordered()
-    ort_options = ort.SessionOptions()
-    ort_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
-    session = ort.InferenceSession(ai_path, providers=providers, sess_options=ort_options)
+    session = ort.InferenceSession(ai_path, providers=providers)

logging.info(f"Available inference providers : {providers}")
logging.info(f"Used inference providers : {session.get_providers()}")
@@ -84,7 +93,7 @@ def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128,

if not input_tiles:
continue

input_tiles = np.array(input_tiles)

output_tiles = []
@@ -131,15 +140,18 @@ def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128,

return output


def blend_images(original_image, denoised_image, strength):
-    blend = denoised_image * strength + original_image * (1-strength)
+    blend = denoised_image * strength + original_image * (1 - strength)
return np.clip(blend, 0, 1)


def reset_cached_denoised_image(event):
global cached_denoised_image
cached_denoised_image = None


cached_denoised_image = None
eventbus.add_listener(AppEvents.LOAD_IMAGE_REQUEST, reset_cached_denoised_image)
eventbus.add_listener(AppEvents.CALCULATE_REQUEST, reset_cached_denoised_image)
eventbus.add_listener(UiEvents.APPLY_CROP_REQUEST, reset_cached_denoised_image)
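
Note: with the cached result in place, changing only the strength re-runs just `blend_images`, a linear interpolation clipped to [0, 1]. A tiny numeric illustration (the arrays are made up):

```python
# blend = denoised * strength + original * (1 - strength), clipped to [0, 1]:
# strength 0.0 returns the original, 1.0 the fully denoised image.
import numpy as np

original = np.full((2, 2), 0.8)
denoised = np.full((2, 2), 0.2)

for strength in (0.0, 0.5, 1.0):
    blend = np.clip(denoised * strength + original * (1 - strength), 0, 1)
    print(strength, blend[0, 0])  # 0.0 0.8, 0.5 0.5, 1.0 0.2
```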
9 changes: 6 additions & 3 deletions graxpert/main.py
@@ -49,6 +49,9 @@ def collect_available_versions(ai_models_dir, bucket_name):
def bge_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")):
return version_type(bge_ai_models_dir, bge_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$"))

+def denoise_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")):
+    return version_type(denoise_ai_models_dir, denoise_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$"))


def version_type(ai_models_dir, bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")):

@@ -227,7 +230,7 @@ def main():
nargs="?",
required=False,
default=None,
-    type=bge_version_type,
+    type=denoise_version_type,
help='Version of the Denoising AI model, default: "latest"; available locally: [{}], available remotely: [{}]'.format(
", ".join(available_denoise_versions[0]), ", ".join(available_denoise_versions[1])
),
@@ -239,7 +242,7 @@ def main():
required=False,
default=None,
type=float,
help='Strength of the desired denoising effect, default: "1.0"',
help='Strength of the desired denoising effect, default: "0.5"',
)
denoise_parser.add_argument(
"-batch_size",
@@ -248,7 +251,7 @@ def main():
required=False,
default=None,
type=int,
-    help='Number of image tiles which Graxpert will denoise in parallel. Be careful: increasing this value might result in out-of-memory errors. Valid Range: 1..50, default: "3"',
+    help='Number of image tiles which Graxpert will denoise in parallel. Be careful: increasing this value might result in out-of-memory errors. Valid Range: 1..32, default: "4"',
)

if "-h" in sys.argv or "--help" in sys.argv:
2 changes: 1 addition & 1 deletion graxpert/preferences.py
@@ -40,7 +40,7 @@ class Prefs:
denoise_ai_version: AnyStr = None
graxpert_version: AnyStr = graxpert_version
denoise_strength: float = 0.5
-    ai_batch_size: int = 3
+    ai_batch_size: int = 4


def app_state_2_prefs(prefs: Prefs, app_state: AppState) -> Prefs:
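
Note: the stored default now matches the new `denoise()` signature and the UI option list. A minimal reconstruction of the fields in play (illustrative subset; the real dataclass also carries `denoise_ai_version`, `graxpert_version`, and more):

```python
# Illustrative subset of the Prefs dataclass; only the two denoise-related
# fields touched by this commit are shown, with their new defaults.
from dataclasses import dataclass

@dataclass
class Prefs:
    denoise_strength: float = 0.5
    ai_batch_size: int = 4

print(Prefs())  # Prefs(denoise_strength=0.5, ai_batch_size=4)
```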
4 changes: 3 additions & 1 deletion graxpert/ui/right_menu.py
@@ -187,6 +187,7 @@ def __init__(self, master, **kwargs):
self.denoise_ai_version.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_AI_VERSION_CHANGED, {"denoise_ai_version": self.denoise_ai_version.get()}))

# ai settings
+        self.ai_batch_size_options = ["1","2","4","8","16","32"]
self.ai_batch_size = tk.IntVar()
self.ai_batch_size.set(graxpert.prefs.ai_batch_size)
self.ai_batch_size.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.AI_BATCH_SIZE_CHANGED, {"ai_batch_size": self.ai_batch_size.get()}))
@@ -239,7 +240,8 @@ def lang_change(lang):
GraXpertOptionMenu(self, variable=self.denoise_ai_version, values=self.denoise_ai_options).grid(**self.default_grid())

# ai settings
-        ValueSlider(self, variable=self.ai_batch_size, variable_name=_("AI Batch Size"), min_value=1, max_value=50, precision=0).grid(**self.default_grid())
+        CTkLabel(self, text=_("AI inference batch size"), font=self.heading_font2).grid(column=0, row=self.nrow(), pady=pady, sticky=tk.N)
+        GraXpertOptionMenu(self, variable=self.ai_batch_size, values=self.ai_batch_size_options).grid(**self.default_grid())

def setup_layout(self):
self.columnconfigure(0, weight=1)
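
Note: the free-range slider (1..50) is replaced by a dropdown restricted to powers of two, matching the clamping in `denoise()`. A plain-Tk analogue of the new control (GraXpertOptionMenu and ValueSlider are project widgets; tk.OptionMenu stands in here; requires a desktop session):

```python
# Batch size picked from fixed powers of two; the IntVar trace stands in for
# the eventbus emission in the real UI code.
import tkinter as tk

root = tk.Tk()
ai_batch_size = tk.IntVar(value=4)  # default comes from Prefs.ai_batch_size
ai_batch_size.trace_add("write", lambda a, b, c: print("AI batch size:", ai_batch_size.get()))
tk.OptionMenu(root, ai_batch_size, *[1, 2, 4, 8, 16, 32]).pack(padx=20, pady=20)
root.mainloop()
```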
