Skip to content

Commit

Permalink
Merge branch 'feature/denoising' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
schmelly committed Mar 29, 2024
2 parents f3600a9 + 26e8247 commit bbee341
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 54 deletions.
4 changes: 2 additions & 2 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
"configurations": [
{
"name": "Run GraXpert",
"type": "python",
"type": "debugpy",
"request": "launch",
"module": "graxpert.main",
"justMyCode": true
"justMyCode": true,
}
]
}
36 changes: 20 additions & 16 deletions graxpert/application/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ def initialize(self):
eventbus.add_listener(AppEvents.BGE_AI_VERSION_CHANGED, self.on_bge_ai_version_changed)
eventbus.add_listener(AppEvents.DENOISE_AI_VERSION_CHANGED, self.on_denoise_ai_version_changed)
eventbus.add_listener(AppEvents.SCALING_CHANGED, self.on_scaling_changed)
eventbus.add_listener(AppEvents.AI_BATCH_SIZE_CHANGED, self.on_ai_batch_size_changed)

# event handling
def on_ai_batch_size_changed(self, event):
self.prefs.ai_batch_size = event["ai_batch_size"]

def on_bge_ai_version_changed(self, event):
self.prefs.bge_ai_version = event["bge_ai_version"]

Expand Down Expand Up @@ -133,9 +137,9 @@ def on_calculate_request(self, event=None):

try:
self.prefs.images_linked_option = False

img_array_to_be_processed = np.copy(self.images.get("Original").img_array)

background = AstroImage()
background.set_from_array(
extract_background(
Expand Down Expand Up @@ -166,7 +170,7 @@ def on_calculate_request(self, event=None):

self.images.set("Gradient-Corrected", gradient_corrected)
self.images.set("Background", background)

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option), self.prefs.saturation)

eventbus.emit(AppEvents.CALCULATE_SUCCESS)
Expand Down Expand Up @@ -312,10 +316,10 @@ def on_save_as_changed(self, event):

def on_smoothing_changed(self, event):
self.prefs.smoothing_option = event["smoothing_option"]

def on_denoise_strength_changed(self, event):
self.prefs.denoise_strength = event["denoise_strength"]

def on_denoise_request(self, event):
if self.images.get("Original") is None:
messagebox.showerror("Error", _("Please load your picture first."))
Expand All @@ -330,12 +334,12 @@ def on_denoise_request(self, event):

try:
img_array_to_be_processed = np.copy(self.images.get("Original").img_array)
if (self.images.get("Gradient-Corrected") is not None):
if self.images.get("Gradient-Corrected") is not None:
img_array_to_be_processed = np.copy(self.images.get("Gradient-Corrected").img_array)

self.prefs.images_linked_option = True
ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.prefs.denoise_ai_version)
imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, progress=progress)
imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, batch_size=self.prefs.ai_batch_size, progress=progress)

denoised = AstroImage()
denoised.set_from_array(imarray)
Expand All @@ -345,9 +349,9 @@ def on_denoise_request(self, event):
denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)

denoised.copy_metadata(self.images.get("Original"))

self.images.set("Denoised", denoised)

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation)

eventbus.emit(AppEvents.DENOISE_SUCCESS)
Expand Down Expand Up @@ -375,13 +379,13 @@ def on_save_request(self, event):
eventbus.emit(AppEvents.SAVE_BEGIN)

try:
if (self.images.get("Denoised") is not None):
if self.images.get("Denoised") is not None:
self.images.get("Denoised").save(dir, self.prefs.saveas_option)
elif (self.images.get("Gradient-Corrected") is not None):
elif self.images.get("Gradient-Corrected") is not None:
self.images.get("Gradient-Corrected").save(dir, self.prefs.saveas_option)
else:
self.images.get("Original").save(dir, self.prefs.saveas_option)

except Exception as e:
logging.exception(e)
eventbus.emit(AppEvents.SAVE_ERROR)
Expand Down Expand Up @@ -425,13 +429,13 @@ def on_save_stretched_request(self, event):
eventbus.emit(AppEvents.SAVE_BEGIN)

try:
if (self.images.get("Denoised") is not None):
if self.images.get("Denoised") is not None:
self.images.get("Denoised").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
elif (self.images.get("Gradient-Corrected") is not None):
elif self.images.get("Gradient-Corrected") is not None:
self.images.get("Gradient-Corrected").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
else:
self.images.get("Original").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))

except Exception as e:
eventbus.emit(AppEvents.SAVE_ERROR)
logging.exception(e)
Expand Down
1 change: 1 addition & 0 deletions graxpert/application/app_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,4 @@ class AppEvents(Enum):
CORRECTION_TYPE_CHANGED = auto()
LANGUAGE_CHANGED = auto()
SCALING_CHANGED = auto()
AI_BATCH_SIZE_CHANGED = auto()
52 changes: 33 additions & 19 deletions graxpert/background_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,39 @@
from concurrent.futures import wait
from multiprocessing import shared_memory

import cv2
import numpy as np
import onnxruntime as ort
from astropy.stats import sigma_clipped_stats
from pykrige.ok import OrdinaryKriging
from scipy import interpolate, linalg
from skimage.filters import gaussian
from skimage.transform import resize

from graxpert.ai_model_handling import get_execution_providers_ordered
from graxpert.mp_logging import get_logging_queue, worker_configurer
from graxpert.parallel_processing import executor
from graxpert.radialbasisinterpolation import RadialBasisInterpolation
from graxpert.ai_model_handling import get_execution_providers_ordered


def gaussian_kernel(sigma=1.0, truncate=4.0): # follow simulate skimage.filters.gaussian defaults
ksize = round(sigma * truncate) - 1 if round(sigma * truncate) % 2 == 0 else round(sigma * truncate)
return (ksize, ksize)


def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None):

shm_imarray = shared_memory.SharedMemory(create=True, size=in_imarray.nbytes)
shm_background = shared_memory.SharedMemory(create=True, size=in_imarray.nbytes)
imarray = np.ndarray(in_imarray.shape, dtype=np.float32, buffer=shm_imarray.buf)
background = np.ndarray(in_imarray.shape, dtype=np.float32, buffer=shm_background.buf)
np.copyto(imarray, in_imarray)
num_colors = in_imarray.shape[-1]

num_colors = imarray.shape[-1]
shm_imarray = None
shm_background = None

if interpolation_type == "AI":
imarray = np.ndarray(in_imarray.shape, dtype=np.float32)
background = np.ndarray(in_imarray.shape, dtype=np.float32)
np.copyto(imarray, in_imarray)

# Shrink and pad to avoid artifacts on borders
padding = 8
imarray_shrink = resize(imarray, output_shape=(256 - 2 * padding, 256 - 2 * padding))
imarray_shrink = cv2.resize(imarray, dsize=(256 - 2 * padding, 256 - 2 * padding), interpolation=cv2.INTER_LINEAR)
imarray_shrink = np.pad(imarray_shrink, ((padding, padding), (padding, padding), (0, 0)), mode="edge")

median = []
Expand Down Expand Up @@ -77,7 +82,7 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth

if smoothing != 0:
sigma = smoothing * 20
background = gaussian(image=background, sigma=sigma, channel_axis=-1)
background = cv2.GaussianBlur(background, ksize=gaussian_kernel(sigma), sigmaX=sigma, sigmaY=sigma)

if progress is not None:
progress.update(8)
Expand All @@ -96,13 +101,20 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth
if progress is not None:
progress.update(8)

background = gaussian(background, sigma=3.0) # To simulate tensorflow method='gaussian'
background = resize(background, output_shape=(in_imarray.shape[0], in_imarray.shape[1]))
sigma = 3.0
background = cv2.GaussianBlur(background, ksize=gaussian_kernel(sigma), sigmaX=sigma, sigmaY=sigma)
background = cv2.resize(background, dsize=(in_imarray.shape[1], in_imarray.shape[0]), interpolation=cv2.INTER_LINEAR)

if progress is not None:
progress.update(8)

else:
shm_imarray = shared_memory.SharedMemory(create=True, size=in_imarray.nbytes)
shm_background = shared_memory.SharedMemory(create=True, size=in_imarray.nbytes)
imarray = np.ndarray(in_imarray.shape, dtype=np.float32, buffer=shm_imarray.buf)
background = np.ndarray(in_imarray.shape, dtype=np.float32, buffer=shm_background.buf)
np.copyto(imarray, in_imarray)

x_sub = np.array(background_points[:, 0], dtype=int)
y_sub = np.array(background_points[:, 1], dtype=int)

Expand Down Expand Up @@ -154,15 +166,17 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth
imarray[:, :, :] = imarray.clip(min=0.0, max=1.0)

in_imarray[:] = imarray[:]
background = np.copy(background)

if progress is not None:
progress.update(8)

shm_imarray.close()
shm_background.close()
shm_imarray.unlink()
shm_background.unlink()
if shm_imarray is not None:
shm_imarray.close()
shm_imarray.unlink()
if shm_background is not None:
background = np.copy(background)
shm_background.close()
shm_background.unlink()

return background

Expand Down Expand Up @@ -258,7 +272,7 @@ def interpol(shm_imarray_name, shm_background_name, c, x_sub, y_sub, shape, kind
return

if downscale_factor != 1:
result = resize(result, shape, preserve_range=True)
result = cv2.resize(src=result, dsize=(shape[1], shape[0]), interpolation=cv2.INTER_LINEAR)

background[:, :, c] = result
except Exception as e:
Expand Down
11 changes: 10 additions & 1 deletion graxpert/cmdline_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def execute(self):
preferences.ai_version = json_prefs["ai_version"]
if "denoise_strength" in json_prefs:
preferences.denoise_strength = json_prefs["denoise_strength"]
if "ai_batch_size" in json_prefs:
preferences.ai_batch_size = json_prefs["ai_batch_size"]

except Exception as e:
logging.exception(e)
Expand All @@ -233,6 +235,12 @@ def execute(self):
logging.info(f"Using user-supplied denoise strength value {preferences.denoise_strength}.")
else:
logging.info(f"Using stored denoise strength value {preferences.denoise_strength}.")

if self.args.ai_batch_size is not None:
preferences.ai_batch_size = self.args.ai_batch_size
logging.info(f"Using user-supplied batch size value {preferences.ai_batch_size}.")
else:
logging.info(f"Using stored batch size value {preferences.ai_batch_size}.")

ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.get_ai_version(preferences))

Expand All @@ -249,7 +257,8 @@ def execute(self):
denoise(
astro_Image.img_array,
ai_model_path,
preferences.denoise_strength
preferences.denoise_strength,
batch_size=preferences.ai_batch_size
))
processed_Astro_Image.save(self.get_save_path(), self.get_output_file_format())

Expand Down
74 changes: 59 additions & 15 deletions graxpert/denoising.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import copy
import logging
import time

import numpy as np
import onnxruntime as ort

from graxpert.ai_model_handling import get_execution_providers_ordered


def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None):
def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128, progress=None):

logging.info("Starting denoising")

input = copy.deepcopy(image)
num_colors = image.shape[-1]

if num_colors == 1:
image = np.array([image[:, :, 0], image[:, :, 0], image[:, :, 0]])
image = np.moveaxis(image, 0, -1)
Expand Down Expand Up @@ -42,33 +45,72 @@ def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None
output = copy.deepcopy(image)

providers = get_execution_providers_ordered()
session = ort.InferenceSession(ai_path, providers=providers)
ort_options = ort.SessionOptions()
ort_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
session = ort.InferenceSession(ai_path, providers=providers, sess_options=ort_options)

logging.info(f"Available inference providers : {providers}")
logging.info(f"Used inference providers : {session.get_providers()}")

logging.info(f"Providers : {providers}")
logging.info(f"Used providers : {session.get_providers()}")
last_progress = 0
for b in range(0, ith * itw + batch_size, batch_size):

input_tiles = []
input_tile_copies = []
for t_idx in range(0, batch_size):

index = b + t_idx
i = index % ith
j = index // ith

if i >= ith or j >= itw:
break

for i in range(ith):
for j in range(itw):
x = stride * i
y = stride * j

tile = image[x : x + window_size, y : y + window_size, :]
tile = (tile - median) / mad * 0.04
tile_copy = tile.copy()
input_tile_copies.append(np.copy(tile))
tile = np.clip(tile, -1.0, 1.0)

tile = np.expand_dims(tile, axis=0)
tile = np.array(session.run(None, {"gen_input_image": tile})[0][0])
input_tiles.append(tile)

if not input_tiles:
continue

input_tiles = np.array(input_tiles)

tile = np.where(tile_copy < 0.95, tile, tile_copy)
output_tiles = []
session_result = session.run(None, {"gen_input_image": input_tiles})[0]
for e in session_result:
output_tiles.append(e)

output_tiles = np.array(output_tiles)

for t_idx, tile in enumerate(output_tiles):

index = b + t_idx
i = index % ith
j = index // ith

if i >= ith or j >= itw:
break

x = stride * i
y = stride * j
tile = np.where(input_tile_copies[t_idx] < 0.95, tile, input_tile_copies[t_idx])
tile = tile / 0.04 * mad + median
tile = tile[offset : offset + stride, offset : offset + stride, :]
output[x + offset : stride * (i + 1) + offset, y + offset : stride * (j + 1) + offset, :] = tile

if progress is not None:
progress.update(int(100 / ith))
else:
logging.info(f"Progress: {int(i/ith*100)}%")
p = int(b / (ith * itw + batch_size) * 100)
if p > last_progress:
if progress is not None:
progress.update(p - last_progress)
else:
logging.info(f"Progress: {p}%")
last_progress = p

output = np.clip(output, 0, 1)
output = output[offset : H + offset, offset : W + offset, :]
Expand All @@ -78,4 +120,6 @@ def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None
output = np.array([output[:, :, 0]])
output = np.moveaxis(output, 0, -1)

logging.info("Finished denoising")

return output
Loading

0 comments on commit bbee341

Please sign in to comment.