Skip to content

Commit

Permalink
Added slider for denoising threshold in preparation for new AI model
Browse files Browse the repository at this point in the history
  • Loading branch information
Steffenhir committed Apr 29, 2024
1 parent 7cd06a5 commit 636873b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 12 deletions.
6 changes: 5 additions & 1 deletion graxpert/application/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def initialize(self):
eventbus.add_listener(AppEvents.CALCULATE_REQUEST, self.on_calculate_request)
# denoising
eventbus.add_listener(AppEvents.DENOISE_STRENGTH_CHANGED, self.on_denoise_strength_changed)
eventbus.add_listener(AppEvents.DENOISE_THRESHOLD_CHANGED, self.on_denoise_threshold_changed)
eventbus.add_listener(AppEvents.DENOISE_REQUEST, self.on_denoise_request)
# saving
eventbus.add_listener(AppEvents.SAVE_AS_CHANGED, self.on_save_as_changed)
Expand Down Expand Up @@ -322,6 +323,9 @@ def on_smoothing_changed(self, event):

def on_denoise_strength_changed(self, event):
self.prefs.denoise_strength = event["denoise_strength"]

def on_denoise_threshold_changed(self, event):
self.prefs.denoise_threshold = event["denoise_threshold"]

def on_denoise_request(self, event):
if self.images.get("Original") is None:
Expand All @@ -342,7 +346,7 @@ def on_denoise_request(self, event):

self.prefs.images_linked_option = True
ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.prefs.denoise_ai_version)
imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, batch_size=self.prefs.ai_batch_size, progress=progress)
imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, batch_size=self.prefs.ai_batch_size, threshold=self.prefs.denoise_threshold, progress=progress)

denoised = AstroImage()
denoised.set_from_array(imarray)
Expand Down
1 change: 1 addition & 0 deletions graxpert/application/app_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class AppEvents(Enum):
CALCULATE_ERROR = auto()
# denoising
DENOISE_STRENGTH_CHANGED = auto()
DENOISE_THRESHOLD_CHANGED = auto()
DENOISE_REQUEST = auto()
DENOISE_BEGIN = auto()
DENOISE_PROGRESS = auto()
Expand Down
29 changes: 19 additions & 10 deletions graxpert/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from graxpert.ui.ui_events import UiEvents


def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, progress=None):
def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, threshold=1.0, progress=None):

logging.info("Starting denoising")

Expand All @@ -26,10 +26,13 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
batch_size = 2 ** (batch_size).bit_length() // 2 # map batch_size to power of two

input = copy.deepcopy(image)

median = np.median(image[::4, ::4, :], axis=[0, 1])
mad = np.median(np.abs(image[::4, ::4, :] - median), axis=[0, 1])

global cached_denoised_image
if cached_denoised_image is not None:
return blend_images(input, cached_denoised_image, strength)
return blend_images(input, cached_denoised_image, strength, threshold, median, mad)

num_colors = image.shape[-1]
if num_colors == 1:
Expand All @@ -56,16 +59,18 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
image = np.concatenate((image, image[:, (w - offset) :, :]), axis=1)
image = np.concatenate((image[:, :offset, :], image), axis=1)

median = np.median(image[::4, ::4, :], axis=[0, 1])
mad = np.median(np.abs(image[::4, ::4, :] - median), axis=[0, 1])

output = copy.deepcopy(image)

providers = get_execution_providers_ordered()
session = ort.InferenceSession(ai_path, providers=providers)

logging.info(f"Available inference providers : {providers}")
logging.info(f"Used inference providers : {session.get_providers()}")

if "1.0.0" in ai_path or "1.1.0" in ai_path:
model_threshold = 1.0
else:
model_threshold = 100.0

last_progress = 0
for b in range(0, ith * itw + batch_size, batch_size):
Expand All @@ -87,7 +92,7 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
tile = image[x : x + window_size, y : y + window_size, :]
tile = (tile - median) / mad * 0.04
input_tile_copies.append(np.copy(tile))
tile = np.clip(tile, -1.0, 1.0)
tile = np.clip(tile, -model_threshold, model_threshold)

input_tiles.append(tile)

Expand All @@ -114,7 +119,7 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,

x = stride * i
y = stride * j
tile = np.where(input_tile_copies[t_idx] < 0.95, tile, input_tile_copies[t_idx])
tile = np.where(input_tile_copies[t_idx] < model_threshold, tile, input_tile_copies[t_idx])
tile = tile / 0.04 * mad + median
tile = tile[offset : offset + stride, offset : offset + stride, :]
output[x + offset : stride * (i + 1) + offset, y + offset : stride * (j + 1) + offset, :] = tile
Expand All @@ -134,18 +139,21 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
output = np.moveaxis(output, 0, -1)

cached_denoised_image = output
output = blend_images(input, output, strength)
output = blend_images(input, output, strength, threshold, median, mad)

logging.info("Finished denoising")

return output


def blend_images(original_image, denoised_image, strength):
blend = denoised_image * strength + original_image * (1 - strength)
def blend_images(original_image, denoised_image, strength, threshold, median, mad):
threshold = threshold / 0.04 * mad + median
blend = np.where(original_image < threshold, denoised_image, original_image)
blend = blend * strength + original_image * (1 - strength)
return np.clip(blend, 0, 1)



def reset_cached_denoised_image(event):
global cached_denoised_image
cached_denoised_image = None
Expand All @@ -155,3 +163,4 @@ def reset_cached_denoised_image(event):
eventbus.add_listener(AppEvents.LOAD_IMAGE_REQUEST, reset_cached_denoised_image)
eventbus.add_listener(AppEvents.CALCULATE_REQUEST, reset_cached_denoised_image)
eventbus.add_listener(UiEvents.APPLY_CROP_REQUEST, reset_cached_denoised_image)
eventbus.add_listener(AppEvents.DENOISE_AI_VERSION_CHANGED, reset_cached_denoised_image)
1 change: 1 addition & 0 deletions graxpert/preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Prefs:
denoise_ai_version: AnyStr = None
graxpert_version: AnyStr = graxpert_version
denoise_strength: float = 0.5
denoise_threshold: float = 10.0
ai_batch_size: int = 4


Expand Down
12 changes: 11 additions & 1 deletion graxpert/ui/left_menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ def __init__(self, parent, **kwargs):
self.denoise_strength = tk.DoubleVar()
self.denoise_strength.set(graxpert.prefs.denoise_strength)
self.denoise_strength.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_STRENGTH_CHANGED, {"denoise_strength": self.denoise_strength.get()}))

self.denoise_threshold = tk.DoubleVar()
self.denoise_threshold.set(graxpert.prefs.denoise_threshold)
self.denoise_threshold.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_THRESHOLD_CHANGED, {"denoise_threshold": self.denoise_threshold.get()}))

self.create_children()
self.setup_layout()
Expand All @@ -247,6 +251,11 @@ def create_children(self):
self.sub_frame, width=default_label_width, variable_name=_("Denoise Strength"), variable=self.denoise_strength, min_value=0.0, max_value=1.0, precision=2
)
tooltip.Tooltip(self.denoise_strength_slider, text=tooltip.bg_tol_text)

self.denoise_threshold_slider = ValueSlider(
self.sub_frame, width=default_label_width, variable_name=_("Denoise Threshold"), variable=self.denoise_threshold, min_value=0.1, max_value=10.0, precision=1
)
tooltip.Tooltip(self.denoise_threshold_slider, text=tooltip.bg_tol_text)

def setup_layout(self):
super().setup_layout()
Expand All @@ -255,7 +264,8 @@ def place_children(self):
super().place_children()

self.denoise_strength_slider.grid(column=1, row=0, pady=pady, sticky=tk.EW)
self.denoise_button.grid(column=1, row=1, pady=pady, sticky=tk.EW)
self.denoise_threshold_slider.grid(column=1, row=1, pady=pady, sticky=tk.EW)
self.denoise_button.grid(column=1, row=2, pady=pady, sticky=tk.EW)

def toggle(self):
super().toggle()
Expand Down

0 comments on commit 636873b

Please sign in to comment.