Skip to content

Commit

Permalink
Merge branch 'michael-r-feature/cli-allow-custom-preferences' into de…
Browse files Browse the repository at this point in the history
…velop
  • Loading branch information
schmelly committed Dec 30, 2023
2 parents 5db3e26 + e9a69f3 commit 07e9bab
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 50 deletions.
120 changes: 76 additions & 44 deletions graxpert/CommandLineTool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
import logging
import os
import sys

import numpy as np
from appdirs import user_config_dir

from graxpert.ai_model_handling import (ai_model_path_from_version,
download_version, latest_version,
list_local_versions)
from graxpert.ai_model_handling import ai_model_path_from_version, download_version, latest_version, list_local_versions
from graxpert.astroimage import AstroImage
from graxpert.background_extraction import extract_background
from graxpert.preferences import load_preferences, save_preferences
from graxpert.preferences import Prefs, load_preferences, save_preferences

user_preferences_filename = os.path.join(user_config_dir(appname="GraXpert"), "preferences.json")


class CommandLineTool:
Expand All @@ -19,98 +21,128 @@ def __init__(self, args):
def execute(self):
astro_Image = AstroImage(do_update_display=False)
astro_Image.set_from_file(self.args.filename, None, None)

processed_Astro_Image = AstroImage(do_update_display=False)
background_Astro_Image = AstroImage(do_update_display=False)

processed_Astro_Image.fits_header = astro_Image.fits_header
background_Astro_Image.fits_header = astro_Image.fits_header

ai_version = self.get_ai_version()

downscale_factor = 1

if self.args.preferences_file:
preferences = Prefs()
preferences.interpol_type_option = "AI"
try:
preferences_file = os.path.abspath(self.args.preferences_file)
if os.path.isfile(preferences_file):
with open(preferences_file, "r") as f:
json_prefs = json.load(f)
preferences.background_points = json_prefs["background_points"]
preferences.sample_size = json_prefs["sample_size"]
preferences.spline_order = json_prefs["spline_order"]
preferences.RBF_kernel = json_prefs["RBF_kernel"]
preferences.interpol_type_option = json_prefs["interpol_type_option"]
preferences.ai_version = json_prefs["ai_version"]

if preferences.interpol_type_option == "Kriging" or preferences.interpol_type_option == "RBF":
downscale_factor = 4

except Exception as e:
logging.exception(e)
logging.shutdown()
sys.exit(1)
else:
preferences = load_preferences(user_preferences_filename)
preferences.interpol_type_option = "AI"

if self.args.smoothing:
preferences.smoothing_option = self.args.smoothing
logging.info(f"Using user-supplied smoothing value {preferences.smoothing_option}.")

if self.args.correction:
preferences.corr_type = self.args.correction
logging.info(f"Using user-supplied correction type {preferences.corr_type}.")

if preferences.interpol_type_option == "AI":
ai_model_path = ai_model_path_from_version(self.get_ai_version(preferences))
else:
ai_model_path = None

background_Astro_Image.set_from_array(
extract_background(
astro_Image.img_array,
[],
"AI",
self.args.smoothing,
1,
50,
"RBF",
0,
self.args.correction,
ai_model_path_from_version(ai_version),
np.array(preferences.background_points),
preferences.interpol_type_option,
preferences.smoothing_option,
downscale_factor,
preferences.sample_size,
preferences.RBF_kernel,
preferences.spline_order,
preferences.corr_type,
ai_model_path,
)
)

processed_Astro_Image.set_from_array(astro_Image.img_array)

processed_Astro_Image.save(self.get_save_path(), self.get_output_file_format())
if (self.args.bg):
if self.args.bg:
background_Astro_Image.save(self.get_background_save_path(), self.get_output_file_format())


def get_ai_version(self):
prefs_filename = os.path.join(
user_config_dir(appname="GraXpert"), "preferences.json"
)
prefs = load_preferences(prefs_filename)
def get_ai_version(self, prefs):
user_preferences = load_preferences(user_preferences_filename)

ai_version = None
if self.args.ai_version:
ai_version = self.args.ai_version
logging.info(f"Using user-supplied AI version {ai_version}.")
else:
ai_version = prefs.ai_version

if ai_version is None:
ai_version = latest_version()

logging.info(
"using AI version {}. you can change this by providing the argument '-ai_version'".format(
ai_version
)
)
logging.info(f"Using AI version {ai_version}. You can overwrite this by providing the argument '-ai_version'")

if not ai_version in [v["version"] for v in list_local_versions()]:
try:
logging.info(
"AI version {} not found locally, downloading...".format(ai_version)
)
logging.info(f"AI version {ai_version} not found locally, downloading...")
download_version(ai_version)
logging.info("download successful".format(ai_version))
logging.info("download successful")
except Exception as e:
logging.exception(e)
logging.shutdown()
sys.exit(1)

prefs.ai_version = ai_version
save_preferences(prefs_filename, prefs)
user_preferences.ai_version = ai_version
save_preferences(user_preferences_filename, user_preferences)

return ai_version

def get_output_file_ending(self):
file_ending = os.path.splitext(self.args.filename)[-1]

if file_ending.lower() == ".xisf":
return ".xisf"
else:
return ".fits"

def get_output_file_format(self):
output_file_ending = self.get_output_file_ending()
if (output_file_ending) == ".xisf":
return "32 bit XISF"
else:
return "32 bit Fits"

def get_save_path(self):
if (self.args.output is not None):
if self.args.output is not None:
base_path = os.path.dirname(self.args.filename)
output_file_name = self.args.output + self.get_output_file_ending()
return os.path.join(base_path, output_file_name)

else:
return os.path.splitext(self.args.filename)[0] + "_GraXpert" + self.get_output_file_ending()

def get_background_save_path(self):
save_path = self.get_save_path()
return os.path.splitext(save_path)[0] + "_background" + self.get_output_file_ending()
22 changes: 16 additions & 6 deletions graxpert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

from packaging import version

from graxpert.version import release as graxpert_release, version as graxpert_version
from graxpert.ai_model_handling import list_local_versions, list_remote_versions
from graxpert.mp_logging import configure_logging
from graxpert.version import release as graxpert_release
from graxpert.version import version as graxpert_version

available_local_versions = []
available_remote_versions = []
Expand Down Expand Up @@ -177,22 +178,31 @@ def main():
)
parser.add_argument("-correction", "--correction", nargs="?", required=False, default="Subtraction", choices=["Subtraction", "Division"], type=str, help="Subtraction or Division")
parser.add_argument("-smoothing", "--smoothing", nargs="?", required=False, default=0.0, type=float, help="Strength of smoothing between 0 and 1")
parser.add_argument(
"-preferences_file",
"--preferences_file",
nargs="?",
required=False,
default="",
type=str,
help="Allows GraXpert commandline to run all extraction methods based on a preferences file that contains background grid points",
)
parser.add_argument("-output", "--output", nargs="?", required=False, type=str, help="Filename of the processed image")
parser.add_argument("-bg", "--bg", required=False, action="store_true", help="Also save the background model")
parser.add_argument("-cli", "--cli", required=False, action="store_true", help="Has to be added when using the command line integration of GraXpert")
parser.add_argument('-v', '--version', action='version', version="GraXpert version: " + graxpert_version + " release: " + graxpert_release)

parser.add_argument("-v", "--version", action="version", version=f"GraXpert version: {graxpert_version} release: {graxpert_release}")

args = parser.parse_args()
if (args.cli):

if args.cli:
from graxpert.CommandLineTool import CommandLineTool

clt = CommandLineTool(args)
clt.execute()
logging.shutdown()
else:
ui_main()

else:
ui_main()

Expand Down

0 comments on commit 07e9bab

Please sign in to comment.