Commit 593cddd: Umbriel

Steffenhir committed May 3, 2024
2 parents 9d3a417 + e767742
Showing 14 changed files with 132 additions and 46 deletions.
10 changes: 7 additions & 3 deletions graxpert/ai_model_handling.py
@@ -4,10 +4,10 @@
import shutil
import zipfile

import onnxruntime as ort
from appdirs import user_data_dir
from minio import Minio
from packaging import version
import onnxruntime as ort

try:
from graxpert.s3_secrets import endpoint, ro_access_key, ro_secret_key
@@ -166,7 +166,11 @@ def validate_local_version(ai_models_dir, local_version):
return os.path.isfile(os.path.join(ai_models_dir, local_version, "model.onnx"))


def get_execution_providers_ordered():
supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
def get_execution_providers_ordered(gpu_acceleration=True):

if gpu_acceleration:
supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
else:
supported_providers = ["CPUExecutionProvider"]

return [provider for provider in supported_providers if provider in ort.get_available_providers()]
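
For orientation, a minimal sketch of what the new parameter does, assuming GraXpert is importable as a package; the exact output depends on the onnxruntime build that is installed:

from graxpert.ai_model_handling import get_execution_providers_ordered

# Default (gpu_acceleration=True): DirectML, CoreML and CUDA are preferred, with
# CPU kept as the final fallback. The list is filtered against
# ort.get_available_providers(), so only providers present in this onnxruntime
# build are returned.
print(get_execution_providers_ordered())

# gpu_acceleration=False offers only the CPU provider, forcing inference onto
# the CPU even when a GPU backend is available.
print(get_execution_providers_ordered(gpu_acceleration=False))  # ['CPUExecutionProvider']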
Empty file.
38 changes: 26 additions & 12 deletions graxpert/application/app.py
@@ -85,11 +85,15 @@ def initialize(self):
eventbus.add_listener(AppEvents.DENOISE_AI_VERSION_CHANGED, self.on_denoise_ai_version_changed)
eventbus.add_listener(AppEvents.SCALING_CHANGED, self.on_scaling_changed)
eventbus.add_listener(AppEvents.AI_BATCH_SIZE_CHANGED, self.on_ai_batch_size_changed)
eventbus.add_listener(AppEvents.AI_GPU_ACCELERATION_CHANGED, self.on_ai_gpu_acceleration_changed)

# event handling
def on_ai_batch_size_changed(self, event):
self.prefs.ai_batch_size = event["ai_batch_size"]

def on_ai_gpu_acceleration_changed(self, event):
self.prefs.ai_gpu_acceleration = event["ai_gpu_acceleration"]

def on_bge_ai_version_changed(self, event):
self.prefs.bge_ai_version = event["bge_ai_version"]

@@ -154,6 +158,7 @@ def on_calculate_request(self, event=None):
self.prefs.corr_type,
ai_model_path_from_version(bge_ai_models_dir, self.prefs.bge_ai_version),
progress,
self.prefs.ai_gpu_acceleration,
)
)

@@ -187,7 +192,7 @@ def on_change_saturation_request(self, event):
def on_change_saturation_request(self, event):
if self.images.get("Original") is None:
return

self.prefs.saturation = event["saturation"]

eventbus.emit(AppEvents.CHANGE_SATURATION_BEGIN)
@@ -342,23 +347,32 @@ def on_denoise_request(self, event):

self.prefs.images_linked_option = True
ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.prefs.denoise_ai_version)
imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, batch_size=self.prefs.ai_batch_size, progress=progress)
imarray = denoise(
img_array_to_be_processed,
ai_model_path,
self.prefs.denoise_strength,
batch_size=self.prefs.ai_batch_size,
progress=progress,
ai_gpu_acceleration=self.prefs.ai_gpu_acceleration,
)

denoised = AstroImage()
denoised.set_from_array(imarray)
if imarray is not None:

# Update fits header and metadata
background_mean = np.mean(self.images.get("Original").img_array)
denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)
denoised = AstroImage()
denoised.set_from_array(imarray)

denoised.copy_metadata(self.images.get("Original"))
# Update fits header and metadata
background_mean = np.mean(self.images.get("Original").img_array)
denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)

self.images.set("Denoised", denoised)
denoised.copy_metadata(self.images.get("Original"))

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation)
self.images.set("Denoised", denoised)

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation)

eventbus.emit(AppEvents.DENOISE_SUCCESS)
eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Denoised"})
eventbus.emit(AppEvents.DENOISE_SUCCESS)
eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Denoised"})

except Exception as e:
logging.exception(e)
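For context, the reshuffling above follows a new contract in denoise() (graxpert/denoising.py below): it returns None when the run is cancelled, so the caller only wraps and stores a result when an array actually came back. A minimal sketch of the guard, with the heavy lifting elided:

imarray = denoise(
    img_array_to_be_processed,
    ai_model_path,
    self.prefs.denoise_strength,
    batch_size=self.prefs.ai_batch_size,
    progress=progress,
    ai_gpu_acceleration=self.prefs.ai_gpu_acceleration,
)
if imarray is not None:
    # Normal path: build the AstroImage, update the FITS header, copy metadata,
    # stretch and emit DENOISE_SUCCESS.
    ...
# When imarray is None the run was cancelled via AppEvents.CANCEL_PROCESSING;
# nothing is stored and no "Denoised" display option is added.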
4 changes: 4 additions & 0 deletions graxpert/application/app_events.py
@@ -45,6 +45,7 @@ class AppEvents(Enum):
CALCULATE_ERROR = auto()
# denoising
DENOISE_STRENGTH_CHANGED = auto()
DENOISE_THRESHOLD_CHANGED = auto()
DENOISE_REQUEST = auto()
DENOISE_BEGIN = auto()
DENOISE_PROGRESS = auto()
@@ -77,3 +78,6 @@ class AppEvents(Enum):
LANGUAGE_CHANGED = auto()
SCALING_CHANGED = auto()
AI_BATCH_SIZE_CHANGED = auto()
AI_GPU_ACCELERATION_CHANGED = auto()
# process control
CANCEL_PROCESSING = auto()
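
The consumer of the new CANCEL_PROCESSING event lives in graxpert/denoising.py below; the emitting side (for example, a cancel control on the progress frame) is not part of the hunks shown here. A plausible minimal pairing, using only the eventbus calls that appear elsewhere in this diff (the import path and the cancel_denoise helper are assumptions):

from graxpert.application.app_events import AppEvents
from graxpert.application.eventbus import eventbus  # import path assumed

def cancel_denoise():
    # Hypothetical producer side: publish the cancellation request. The denoise
    # loop below polls a flag set by its CANCEL_PROCESSING listener and returns
    # None once the current batch finishes.
    eventbus.emit(AppEvents.CANCEL_PROCESSING)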
4 changes: 2 additions & 2 deletions graxpert/background_extraction.py
@@ -24,7 +24,7 @@ def gaussian_kernel(sigma=1.0, truncate=4.0): # follow simulate skimage.filters
return (ksize, ksize)


def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None):
def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None, ai_gpu_acceleration=True):

num_colors = in_imarray.shape[-1]

@@ -71,7 +71,7 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth
if progress is not None:
progress.update(8)

providers = get_execution_providers_ordered()
providers = get_execution_providers_ordered(ai_gpu_acceleration)
session = ort.InferenceSession(ai_path, providers=providers)

logging.info(f"Providers : {providers}")
31 changes: 22 additions & 9 deletions graxpert/cmdline_tools.py
@@ -10,9 +10,9 @@
from graxpert.ai_model_handling import ai_model_path_from_version, bge_ai_models_dir, denoise_ai_models_dir, download_version, latest_version, list_local_versions
from graxpert.astroimage import AstroImage
from graxpert.background_extraction import extract_background
from graxpert.denoising import denoise
from graxpert.preferences import Prefs, load_preferences, save_preferences
from graxpert.s3_secrets import bge_bucket_name, denoise_bucket_name
from graxpert.denoising import denoise

user_preferences_filename = os.path.join(user_config_dir(appname="GraXpert"), "preferences.json")

@@ -85,6 +85,8 @@ def execute(self):
preferences.corr_type = json_prefs["corr_type"]
if "ai_version" in json_prefs:
preferences.ai_version = json_prefs["ai_version"]
if "ai_gpu_acceleration" in json_prefs:
preferences.ai_gpu_acceleration = json_prefs["ai_gpu_acceleration"]

if preferences.interpol_type_option == "Kriging" or preferences.interpol_type_option == "RBF":
downscale_factor = 4
@@ -109,6 +111,12 @@
else:
logging.info(f"Using stored correction type {preferences.corr_type}.")

if self.args.gpu_acceleration is not None:
preferences.ai_gpu_acceleration = True if self.args.gpu_acceleration == "true" else False
logging.info(f"Using user-supplied gpu acceleration setting {preferences.ai_gpu_acceleration}.")
else:
logging.info(f"Using stored gpu acceleration setting {preferences.ai_gpu_acceleration}.")

if preferences.interpol_type_option == "AI":
ai_model_path = ai_model_path_from_version(bge_ai_models_dir, self.get_ai_version(preferences))
else:
@@ -153,6 +161,7 @@ def execute(self):
preferences.spline_order,
preferences.corr_type,
ai_model_path,
ai_gpu_acceleration=preferences.ai_gpu_acceleration,
)
)

@@ -222,26 +231,34 @@ def execute(self):
preferences.denoise_strength = json_prefs["denoise_strength"]
if "ai_batch_size" in json_prefs:
preferences.ai_batch_size = json_prefs["ai_batch_size"]
if "ai_gpu_acceleration" in json_prefs:
preferences.ai_gpu_acceleration = json_prefs["ai_gpu_acceleration"]

except Exception as e:
logging.exception(e)
logging.shutdown()
sys.exit(1)
else:
preferences = Prefs()

if self.args.denoise_strength is not None:
preferences.denoise_strength = self.args.denoise_strength
logging.info(f"Using user-supplied denoise strength value {preferences.denoise_strength}.")
else:
logging.info(f"Using stored denoise strength value {preferences.denoise_strength}.")

if self.args.ai_batch_size is not None:
preferences.ai_batch_size = self.args.ai_batch_size
logging.info(f"Using user-supplied batch size value {preferences.ai_batch_size}.")
else:
logging.info(f"Using stored batch size value {preferences.ai_batch_size}.")

if self.args.gpu_acceleration is not None:
preferences.ai_gpu_acceleration = True if self.args.gpu_acceleration == "true" else False
logging.info(f"Using user-supplied gpu acceleration setting {preferences.ai_gpu_acceleration}.")
else:
logging.info(f"Using stored gpu acceleration setting {preferences.ai_gpu_acceleration}.")

ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.get_ai_version(preferences))

logging.info(
@@ -254,12 +271,8 @@
)

processed_Astro_Image.set_from_array(
denoise(
astro_Image.img_array,
ai_model_path,
preferences.denoise_strength,
batch_size=preferences.ai_batch_size
))
denoise(astro_Image.img_array, ai_model_path, preferences.denoise_strength, batch_size=preferences.ai_batch_size, ai_gpu_acceleration=preferences.ai_gpu_acceleration)
)
processed_Astro_Image.save(self.get_save_path(), self.get_output_file_format())

def get_ai_version(self, prefs):
44 changes: 33 additions & 11 deletions graxpert/denoising.py
@@ -11,7 +11,7 @@
from graxpert.ui.ui_events import UiEvents


def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, progress=None):
def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, progress=None, ai_gpu_acceleration=True):

logging.info("Starting denoising")

@@ -27,9 +27,17 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,

input = copy.deepcopy(image)

median = np.median(image[::4, ::4, :], axis=[0, 1])
mad = np.median(np.abs(image[::4, ::4, :] - median), axis=[0, 1])

if "1.0.0" in ai_path or "1.1.0" in ai_path:
model_threshold = 1.0
else:
model_threshold = 10.0

global cached_denoised_image
if cached_denoised_image is not None:
return blend_images(input, cached_denoised_image, strength)
return blend_images(input, cached_denoised_image, strength, model_threshold, median, mad)

num_colors = image.shape[-1]
if num_colors == 1:
@@ -56,20 +64,30 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
image = np.concatenate((image, image[:, (w - offset) :, :]), axis=1)
image = np.concatenate((image[:, :offset, :], image), axis=1)

median = np.median(image[::4, ::4, :], axis=[0, 1])
mad = np.median(np.abs(image[::4, ::4, :] - median), axis=[0, 1])

output = copy.deepcopy(image)

providers = get_execution_providers_ordered()
providers = get_execution_providers_ordered(ai_gpu_acceleration)
session = ort.InferenceSession(ai_path, providers=providers)

logging.info(f"Available inference providers : {providers}")
logging.info(f"Used inference providers : {session.get_providers()}")

cancel_flag = False

def cancel_listener(event):
nonlocal cancel_flag
cancel_flag = True

eventbus.add_listener(AppEvents.CANCEL_PROCESSING, cancel_listener)

last_progress = 0
for b in range(0, ith * itw + batch_size, batch_size):

if cancel_flag:
logging.info("Denoising cancelled")
eventbus.remove_listener(AppEvents.CANCEL_PROCESSING, cancel_listener)
return None

input_tiles = []
input_tile_copies = []
for t_idx in range(0, batch_size):
@@ -87,7 +105,7 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
tile = image[x : x + window_size, y : y + window_size, :]
tile = (tile - median) / mad * 0.04
input_tile_copies.append(np.copy(tile))
tile = np.clip(tile, -1.0, 1.0)
tile = np.clip(tile, -model_threshold, model_threshold)

input_tiles.append(tile)

@@ -114,7 +132,7 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,

x = stride * i
y = stride * j
tile = np.where(input_tile_copies[t_idx] < 0.95, tile, input_tile_copies[t_idx])
tile = np.where(input_tile_copies[t_idx] < model_threshold, tile, input_tile_copies[t_idx])
tile = tile / 0.04 * mad + median
tile = tile[offset : offset + stride, offset : offset + stride, :]
output[x + offset : stride * (i + 1) + offset, y + offset : stride * (j + 1) + offset, :] = tile
@@ -134,15 +152,18 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128,
output = np.moveaxis(output, 0, -1)

cached_denoised_image = output
output = blend_images(input, output, strength)
output = blend_images(input, output, strength, model_threshold, median, mad)

eventbus.remove_listener(AppEvents.CANCEL_PROCESSING, cancel_listener)
logging.info("Finished denoising")

return output


def blend_images(original_image, denoised_image, strength):
blend = denoised_image * strength + original_image * (1 - strength)
def blend_images(original_image, denoised_image, strength, threshold, median, mad):
threshold = threshold / 0.04 * mad + median
blend = np.where(original_image < threshold, denoised_image, original_image)
blend = blend * strength + original_image * (1 - strength)
return np.clip(blend, 0, 1)
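
Why blend_images rescales the threshold: tiles handed to the model are normalized per channel as (tile - median) / mad * 0.04, so a threshold defined in model space (1.0 for the 1.0.0/1.1.0 models, 10.0 otherwise) maps back to threshold / 0.04 * mad + median in image space. A small sketch of the round trip, with made-up per-channel statistics:

import numpy as np

median, mad = 0.12, 0.004     # illustrative statistics for one channel
model_threshold = 1.0         # value used for the 1.0.0 / 1.1.0 models

def to_model_space(x):
    return (x - median) / mad * 0.04

def to_image_space(x):
    return x / 0.04 * mad + median

image_threshold = to_image_space(model_threshold)    # 1.0 / 0.04 * 0.004 + 0.12 = 0.22
assert np.isclose(to_model_space(image_threshold), model_threshold)

# blend_images() keeps pixels at or above the image-space threshold unchanged
# (np.where picks the original), so highlights that were clipped before
# inference are not replaced by the denoised prediction.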


@@ -155,3 +176,4 @@ def reset_cached_denoised_image(event):
eventbus.add_listener(AppEvents.LOAD_IMAGE_REQUEST, reset_cached_denoised_image)
eventbus.add_listener(AppEvents.CALCULATE_REQUEST, reset_cached_denoised_image)
eventbus.add_listener(UiEvents.APPLY_CROP_REQUEST, reset_cached_denoised_image)
eventbus.add_listener(AppEvents.DENOISE_AI_VERSION_CHANGED, reset_cached_denoised_image)
2 changes: 2 additions & 0 deletions graxpert/main.py
@@ -49,6 +49,7 @@ def collect_available_versions(ai_models_dir, bucket_name):
def bge_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")):
return version_type(bge_ai_models_dir, bge_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$"))


def denoise_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")):
return version_type(denoise_ai_models_dir, denoise_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$"))

@@ -205,6 +206,7 @@ def main():
type=str,
help="Allows GraXpert commandline to run all extraction methods based on a preferences file that contains background grid points",
)
parser.add_argument("-gpu", "--gpu_acceleration", type=str, choices=["true", "false"], default=None, help="Set to 'false' in order to disable gpu acceleration during AI inference.")
parser.add_argument("-v", "--version", action="version", version=f"GraXpert version: {graxpert_version} release: {graxpert_release}")

bge_parser = argparse.ArgumentParser("GraXpert Background Extraction", parents=[parser], description="GraXpert, the astronomical background extraction tool")
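A hedged usage example for the new switch; everything except -gpu/--gpu_acceleration is a placeholder for whatever GraXpert command line you already run. Omitting the flag keeps the value stored in the preferences file (GPU acceleration enabled by default):

# force CPU-only ONNX Runtime inference for this run
graxpert <your image and usual options> -gpu false

# long form, equivalent
graxpert <your image and usual options> --gpu_acceleration false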
1 change: 1 addition & 0 deletions graxpert/preferences.py
@@ -41,6 +41,7 @@ class Prefs:
graxpert_version: AnyStr = graxpert_version
denoise_strength: float = 0.5
ai_batch_size: int = 4
ai_gpu_acceleration: bool = True


def app_state_2_prefs(prefs: Prefs, app_state: AppState) -> Prefs:
4 changes: 4 additions & 0 deletions graxpert/ui/canvas.py
@@ -183,19 +183,22 @@ def on_calculate_end(self, event=None):

def on_denoise_begin(self, event=None):
self.dynamic_progress_frame.text.set(_("Denoising"))
self.dynamic_progress_frame.cancellable = True
self.show_progress_frame(True)

def on_denoise_progress(self, event=None):
self.dynamic_progress_frame.update_progress(event["progress"])

def on_denoise_success(self, event=None):
self.dynamic_progress_frame.cancellable = False
if not "Denoised" in self.display_options:
self.display_options.append("Denoised")
self.display_menu.grid_forget()
self.display_menu = CTkOptionMenu(self, variable=self.display_type, values=self.display_options)
self.display_menu.grid(column=0, row=0, sticky=tk.N)

def on_denoise_end(self, event=None):
self.dynamic_progress_frame.cancellable = False
self.dynamic_progress_frame.text.set("")
self.dynamic_progress_frame.variable.set(0.0)
self.show_progress_frame(False)
@@ -497,6 +500,7 @@ def show_loading_frame(self, show):

def show_progress_frame(self, show):
if show:
self.dynamic_progress_frame.place_children()
self.dynamic_progress_frame.grid(column=0, row=0, rowspan=2)
else:
self.dynamic_progress_frame.grid_forget()