Skip to content

Commit

Permalink
Merge branch 'feature/denoising' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
schmelly committed May 2, 2024
2 parents 6c1af7f + ce16269 commit 3b9ebe6
Show file tree
Hide file tree
Showing 13 changed files with 117 additions and 41 deletions.
10 changes: 7 additions & 3 deletions graxpert/ai_model_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import shutil
import zipfile

import onnxruntime as ort
from appdirs import user_data_dir
from minio import Minio
from packaging import version
import onnxruntime as ort

try:
from graxpert.s3_secrets import endpoint, ro_access_key, ro_secret_key
Expand Down Expand Up @@ -166,7 +166,11 @@ def validate_local_version(ai_models_dir, local_version):
return os.path.isfile(os.path.join(ai_models_dir, local_version, "model.onnx"))


def get_execution_providers_ordered():
supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
def get_execution_providers_ordered(gpu_acceleration=True):

if gpu_acceleration:
supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
else:
supported_providers = ["CPUExecutionProvider"]

return [provider for provider in supported_providers if provider in ort.get_available_providers()]
Empty file.
41 changes: 28 additions & 13 deletions graxpert/application/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,15 @@ def initialize(self):
eventbus.add_listener(AppEvents.DENOISE_AI_VERSION_CHANGED, self.on_denoise_ai_version_changed)
eventbus.add_listener(AppEvents.SCALING_CHANGED, self.on_scaling_changed)
eventbus.add_listener(AppEvents.AI_BATCH_SIZE_CHANGED, self.on_ai_batch_size_changed)
eventbus.add_listener(AppEvents.AI_GPU_ACCELERATION_CHANGED, self.on_ai_gpu_acceleration_changed)

# event handling
def on_ai_batch_size_changed(self, event):
self.prefs.ai_batch_size = event["ai_batch_size"]

def on_ai_gpu_acceleration_changed(self, event):
self.prefs.ai_gpu_acceleration = event["ai_gpu_acceleration"]

def on_bge_ai_version_changed(self, event):
self.prefs.bge_ai_version = event["bge_ai_version"]

Expand Down Expand Up @@ -155,6 +159,7 @@ def on_calculate_request(self, event=None):
self.prefs.corr_type,
ai_model_path_from_version(bge_ai_models_dir, self.prefs.bge_ai_version),
progress,
self.prefs.ai_gpu_acceleration,
)
)

Expand Down Expand Up @@ -188,7 +193,7 @@ def on_calculate_request(self, event=None):
def on_change_saturation_request(self, event):
if self.images.get("Original") is None:
return

self.prefs.saturation = event["saturation"]

eventbus.emit(AppEvents.CHANGE_SATURATION_BEGIN)
Expand Down Expand Up @@ -323,7 +328,7 @@ def on_smoothing_changed(self, event):

def on_denoise_strength_changed(self, event):
self.prefs.denoise_strength = event["denoise_strength"]

def on_denoise_threshold_changed(self, event):
self.prefs.denoise_threshold = event["denoise_threshold"]

Expand All @@ -346,23 +351,33 @@ def on_denoise_request(self, event):

self.prefs.images_linked_option = True
ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.prefs.denoise_ai_version)
imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, batch_size=self.prefs.ai_batch_size, threshold=self.prefs.denoise_threshold, progress=progress)
imarray = denoise(
img_array_to_be_processed,
ai_model_path,
self.prefs.denoise_strength,
batch_size=self.prefs.ai_batch_size,
threshold=self.prefs.denoise_threshold,
progress=progress,
ai_gpu_acceleration=self.prefs.ai_gpu_acceleration,
)

denoised = AstroImage()
denoised.set_from_array(imarray)
if imarray is not None:

# Update fits header and metadata
background_mean = np.mean(self.images.get("Original").img_array)
denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)
denoised = AstroImage()
denoised.set_from_array(imarray)

denoised.copy_metadata(self.images.get("Original"))
# Update fits header and metadata
background_mean = np.mean(self.images.get("Original").img_array)
denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)

self.images.set("Denoised", denoised)
denoised.copy_metadata(self.images.get("Original"))

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation)
self.images.set("Denoised", denoised)

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation)

eventbus.emit(AppEvents.DENOISE_SUCCESS)
eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Denoised"})
eventbus.emit(AppEvents.DENOISE_SUCCESS)
eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Denoised"})

except Exception as e:
logging.exception(e)
Expand Down
3 changes: 3 additions & 0 deletions graxpert/application/app_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,6 @@ class AppEvents(Enum):
LANGUAGE_CHANGED = auto()
SCALING_CHANGED = auto()
AI_BATCH_SIZE_CHANGED = auto()
AI_GPU_ACCELERATION_CHANGED = auto()
# process control
CANCEL_PROCESSING = auto()
4 changes: 2 additions & 2 deletions graxpert/background_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def gaussian_kernel(sigma=1.0, truncate=4.0): # follow simulate skimage.filters
return (ksize, ksize)


def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None):
def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None, ai_gpu_acceleration=True):

num_colors = in_imarray.shape[-1]

Expand Down Expand Up @@ -71,7 +71,7 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth
if progress is not None:
progress.update(8)

providers = get_execution_providers_ordered()
providers = get_execution_providers_ordered(ai_gpu_acceleration)
session = ort.InferenceSession(ai_path, providers=providers)

logging.info(f"Providers : {providers}")
Expand Down
31 changes: 22 additions & 9 deletions graxpert/cmdline_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from graxpert.ai_model_handling import ai_model_path_from_version, bge_ai_models_dir, denoise_ai_models_dir, download_version, latest_version, list_local_versions
from graxpert.astroimage import AstroImage
from graxpert.background_extraction import extract_background
from graxpert.denoising import denoise
from graxpert.preferences import Prefs, load_preferences, save_preferences
from graxpert.s3_secrets import bge_bucket_name, denoise_bucket_name
from graxpert.denoising import denoise

user_preferences_filename = os.path.join(user_config_dir(appname="GraXpert"), "preferences.json")

Expand Down Expand Up @@ -85,6 +85,8 @@ def execute(self):
preferences.corr_type = json_prefs["corr_type"]
if "ai_version" in json_prefs:
preferences.ai_version = json_prefs["ai_version"]
if "ai_gpu_acceleration" in json_prefs:
preferences.ai_gpu_acceleration = json_prefs["ai_gpu_acceleration"]

if preferences.interpol_type_option == "Kriging" or preferences.interpol_type_option == "RBF":
downscale_factor = 4
Expand All @@ -109,6 +111,12 @@ def execute(self):
else:
logging.info(f"Using stored correction type {preferences.corr_type}.")

if self.args.gpu_acceleration is not None:
preferences.ai_gpu_acceleration = True if self.args.gpu_acceleration == "true" else False
logging.info(f"Using user-supplied gpu acceleration setting {preferences.ai_gpu_acceleration}.")
else:
logging.info(f"Using stored gpu acceleration setting {preferences.ai_gpu_acceleration}.")

if preferences.interpol_type_option == "AI":
ai_model_path = ai_model_path_from_version(bge_ai_models_dir, self.get_ai_version(preferences))
else:
Expand Down Expand Up @@ -153,6 +161,7 @@ def execute(self):
preferences.spline_order,
preferences.corr_type,
ai_model_path,
ai_gpu_acceleration=preferences.ai_gpu_acceleration,
)
)

Expand Down Expand Up @@ -222,26 +231,34 @@ def execute(self):
preferences.denoise_strength = json_prefs["denoise_strength"]
if "ai_batch_size" in json_prefs:
preferences.ai_batch_size = json_prefs["ai_batch_size"]
if "ai_gpu_acceleration" in json_prefs:
preferences.ai_gpu_acceleration = json_prefs["ai_gpu_acceleration"]

except Exception as e:
logging.exception(e)
logging.shutdown()
sys.exit(1)
else:
preferences = Prefs()

if self.args.denoise_strength is not None:
preferences.denoise_strength = self.args.denoise_strength
logging.info(f"Using user-supplied denoise strength value {preferences.denoise_strength}.")
else:
logging.info(f"Using stored denoise strength value {preferences.denoise_strength}.")

if self.args.ai_batch_size is not None:
preferences.ai_batch_size = self.args.ai_batch_size
logging.info(f"Using user-supplied batch size value {preferences.ai_batch_size}.")
else:
logging.info(f"Using stored batch size value {preferences.ai_batch_size}.")

if self.args.gpu_acceleration is not None:
preferences.ai_gpu_acceleration = True if self.args.gpu_acceleration == "true" else False
logging.info(f"Using user-supplied gpu acceleration setting {preferences.ai_gpu_acceleration}.")
else:
logging.info(f"Using stored gpu acceleration setting {preferences.ai_gpu_acceleration}.")

ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.get_ai_version(preferences))

logging.info(
Expand All @@ -254,12 +271,8 @@ def execute(self):
)

processed_Astro_Image.set_from_array(
denoise(
astro_Image.img_array,
ai_model_path,
preferences.denoise_strength,
batch_size=preferences.ai_batch_size
))
denoise(astro_Image.img_array, ai_model_path, preferences.denoise_strength, batch_size=preferences.ai_batch_size, ai_gpu_acceleration=preferences.ai_gpu_acceleration)
)
processed_Astro_Image.save(self.get_save_path(), self.get_output_file_format())

def get_ai_version(self, prefs):
Expand Down
23 changes: 18 additions & 5 deletions graxpert/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from graxpert.ui.ui_events import UiEvents


def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, threshold=1.0, progress=None):
def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, threshold=1.0, progress=None, ai_gpu_acceleration=True):

logging.info("Starting denoising")

Expand All @@ -26,7 +26,7 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
batch_size = 2 ** (batch_size).bit_length() // 2 # map batch_size to power of two

input = copy.deepcopy(image)

median = np.median(image[::4, ::4, :], axis=[0, 1])
mad = np.median(np.abs(image[::4, ::4, :] - median), axis=[0, 1])

Expand Down Expand Up @@ -61,20 +61,33 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,

output = copy.deepcopy(image)

providers = get_execution_providers_ordered()
providers = get_execution_providers_ordered(ai_gpu_acceleration)
session = ort.InferenceSession(ai_path, providers=providers)

logging.info(f"Available inference providers : {providers}")
logging.info(f"Used inference providers : {session.get_providers()}")

if "1.0.0" in ai_path or "1.1.0" in ai_path:
model_threshold = 1.0
else:
model_threshold = 10.0

cancel_flag = False

def cancel_listener(event):
nonlocal cancel_flag
cancel_flag = True

eventbus.add_listener(AppEvents.CANCEL_PROCESSING, cancel_listener)

last_progress = 0
for b in range(0, ith * itw + batch_size, batch_size):

if cancel_flag:
logging.info("Denoising cancelled")
eventbus.remove_listener(AppEvents.CANCEL_PROCESSING, cancel_listener)
return None

input_tiles = []
input_tile_copies = []
for t_idx in range(0, batch_size):
Expand Down Expand Up @@ -141,6 +154,7 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
cached_denoised_image = output
output = blend_images(input, output, strength, threshold, median, mad)

eventbus.remove_listener(AppEvents.CANCEL_PROCESSING, cancel_listener)
logging.info("Finished denoising")

return output
Expand All @@ -153,7 +167,6 @@ def blend_images(original_image, denoised_image, strength, threshold, median, ma
return np.clip(blend, 0, 1)



def reset_cached_denoised_image(event):
global cached_denoised_image
cached_denoised_image = None
Expand Down
2 changes: 2 additions & 0 deletions graxpert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def collect_available_versions(ai_models_dir, bucket_name):
def bge_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")):
return version_type(bge_ai_models_dir, bge_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$"))


def denoise_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")):
return version_type(denoise_ai_models_dir, denoise_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$"))

Expand Down Expand Up @@ -205,6 +206,7 @@ def main():
type=str,
help="Allows GraXpert commandline to run all extraction methods based on a preferences file that contains background grid points",
)
parser.add_argument("-gpu", "--gpu_acceleration", type=str, choices=["true", "false"], default=None, help="Set to 'false' in order to disable gpu acceleration during AI inference.")
parser.add_argument("-v", "--version", action="version", version=f"GraXpert version: {graxpert_version} release: {graxpert_release}")

bge_parser = argparse.ArgumentParser("GraXpert Background Extraction", parents=[parser], description="GraXpert, the astronomical background extraction tool")
Expand Down
1 change: 1 addition & 0 deletions graxpert/preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class Prefs:
denoise_strength: float = 0.5
denoise_threshold: float = 10.0
ai_batch_size: int = 4
ai_gpu_acceleration: bool = True


def app_state_2_prefs(prefs: Prefs, app_state: AppState) -> Prefs:
Expand Down
4 changes: 4 additions & 0 deletions graxpert/ui/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,22 @@ def on_calculate_end(self, event=None):

def on_denoise_begin(self, event=None):
self.dynamic_progress_frame.text.set(_("Denoising"))
self.dynamic_progress_frame.cancellable = True
self.show_progress_frame(True)

def on_denoise_progress(self, event=None):
self.dynamic_progress_frame.update_progress(event["progress"])

def on_denoise_success(self, event=None):
self.dynamic_progress_frame.cancellable = False
if not "Denoised" in self.display_options:
self.display_options.append("Denoised")
self.display_menu.grid_forget()
self.display_menu = CTkOptionMenu(self, variable=self.display_type, values=self.display_options)
self.display_menu.grid(column=0, row=0, sticky=tk.N)

def on_denoise_end(self, event=None):
self.dynamic_progress_frame.cancellable = False
self.dynamic_progress_frame.text.set("")
self.dynamic_progress_frame.variable.set(0.0)
self.show_progress_frame(False)
Expand Down Expand Up @@ -497,6 +500,7 @@ def show_loading_frame(self, show):

def show_progress_frame(self, show):
if show:
self.dynamic_progress_frame.place_children()
self.dynamic_progress_frame.grid(column=0, row=0, rowspan=2)
else:
self.dynamic_progress_frame.grid_forget()
Expand Down
6 changes: 3 additions & 3 deletions graxpert/ui/left_menu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import tkinter as tk

from customtkinter import StringVar, ThemeManager
from customtkinter import ThemeManager

import graxpert.ui.tooltip as tooltip
from graxpert.application.app import graxpert
Expand Down Expand Up @@ -224,7 +224,7 @@ def __init__(self, parent, **kwargs):
self.denoise_strength = tk.DoubleVar()
self.denoise_strength.set(graxpert.prefs.denoise_strength)
self.denoise_strength.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_STRENGTH_CHANGED, {"denoise_strength": self.denoise_strength.get()}))

self.denoise_threshold = tk.DoubleVar()
self.denoise_threshold.set(graxpert.prefs.denoise_threshold)
self.denoise_threshold.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_THRESHOLD_CHANGED, {"denoise_threshold": self.denoise_threshold.get()}))
Expand All @@ -251,7 +251,7 @@ def create_children(self):
self.sub_frame, width=default_label_width, variable_name=_("Denoise Strength"), variable=self.denoise_strength, min_value=0.0, max_value=1.0, precision=2
)
tooltip.Tooltip(self.denoise_strength_slider, text=tooltip.denoise_strength_text)

self.denoise_threshold_slider = ValueSlider(
self.sub_frame, width=default_label_width, variable_name=_("Denoise Threshold"), variable=self.denoise_threshold, min_value=0.1, max_value=10.0, precision=1
)
Expand Down
Loading

0 comments on commit 3b9ebe6

Please sign in to comment.