In [1]:
import argparse
import os
import sys
from typing import Optional
import glob
import git
import cv2
import matplotlib.pyplot as plt
import numpy as np
import rawpy
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from PySide6.QtCore import Slot, QObject, QSize
from PySide6.QtQuick import QQuickImageProvider
from PySide6.QtGui import QGuiApplication, QImage, QPixmap, QColor
from PySide6.QtWidgets import QFileDialog
from scipy.signal import find_peaks
from tqdm import tqdm
import imageio
import ipywidgets as widgets 

In [2]:
from PySide6.QtCore import Qt, QAbstractItemModel, QModelIndex

class TreeModel(QAbstractItemModel):
    """Qt item model that exposes a (possibly nested) ``dict`` as a tree.

    Each key becomes a row whose column 0 holds the key. When the value is not
    itself a ``dict``, ``str(value)`` is stored in column 1 (if the headers
    define a second column). Nested dicts become child rows. No editing
    methods (``setData``/``flags``) are implemented here.
    """
    def __init__(self, headers, data, parent=None):
        # headers: one entry per column; the root node's data doubles as the header row.
        super(TreeModel, self).__init__(parent)
        rootData = [header for header in headers]
        self.rootItem = TreeNode(rootData)
        indent = -1  # createData increments before use, so top-level keys get indent 0
        # Ancestor stack used by createData: the node new rows are attached to, and the
        # pseudo "indentation" (4 * nesting depth) at which that node was pushed.
        self.parents = [self.rootItem]
        self.indentations = [0]
        self.createData(data, indent)

    def createData(self, data, indent):
        """Recursively append the contents of ``data`` to the tree.

        Only ``data`` whose exact type is ``dict`` is expanded; anything else
        (including dict subclasses) ends the recursion. ``self.parents`` /
        ``self.indentations`` track the attachment point: entering a deeper
        level pushes the most recently inserted row, returning to a shallower
        level pops back to the matching ancestor.
        """
        if type(data) == dict:  # exact-type check: dict subclasses are treated as leaf values
            indent += 1
            position = 4 * indent  # pseudo-indentation of this nesting level
            for dict_keys, dict_values in data.items():
                if position > self.indentations[-1]:
                    # Deeper than the current parent: descend into the last row inserted
                    # under it (the row for the key whose value is this dict).
                    if self.parents[-1].childCount() > 0:
                        self.parents.append(self.parents[-1].child(self.parents[-1].childCount() - 1))
                        self.indentations.append(position)
                else:
                    # Same depth or shallower: pop back up to the ancestor for this depth.
                    while position < self.indentations[-1] and len(self.parents) > 0:
                        self.parents.pop()
                        self.indentations.pop()
                parent = self.parents[-1]
                # Append an empty row (one None per column), then fill in key and value.
                parent.insertChildren(parent.childCount(), 1, parent.columnCount())
                parent.child(parent.childCount() - 1).setData(0, dict_keys)
                if type(dict_values) != dict:
                    parent.child(parent.childCount() - 1).setData(1, str(dict_values))
                self.createData(dict_values, indent)

    def index(self, row, column, index=QModelIndex()):
        """Return the model index for ``(row, column)`` under the parent ``index``.

        An invalid ``index`` means the root. Returns an invalid QModelIndex when
        ``hasIndex`` rejects the position or no child exists there.
        """
        if not self.hasIndex(row, column, index):
            return QModelIndex()
        if not index.isValid():
            item = self.rootItem
        else:
            item = index.internalPointer()
        child = item.child(row)
        if child:
            return self.createIndex(row, column, child)
        return QModelIndex()

    def parent(self, index):
        """Return the index of ``index``'s parent row (invalid for top-level rows)."""
        if not index.isValid():
            return QModelIndex()
        item = index.internalPointer()
        if not item:
            return QModelIndex()
        parent = item.parentItem
        if parent == self.rootItem:
            return QModelIndex()
        else:
            return self.createIndex(parent.childNumber(), 0, parent)

    def rowCount(self, index=QModelIndex()):
        """Number of children of the node behind ``index`` (the root when invalid)."""
        if index.isValid():
            parent = index.internalPointer()
        else:
            parent = self.rootItem
        return parent.childCount()

    def columnCount(self, index=QModelIndex()):
        """Column count, fixed by the number of headers stored on the root node."""
        return self.rootItem.columnCount()

    def data(self, index, role=Qt.DisplayRole):
        """Return the display value for ``index``.

        Valid index with DisplayRole: the node's value in ``index.column()``.
        Valid index with any other role: None (implicit).
        Invalid index: ``self.rootItem.data(index.column())``. NOTE(review): an
        invalid QModelIndex reports column -1, so this appears to return the last
        header entry; likely unintended, confirm.
        """
        if index.isValid() and role == Qt.DisplayRole:
            return index.internalPointer().data(index.column())
        elif not index.isValid():
            return self.rootItem.data(index.column())

    def headerData(self, section, orientation, role=Qt.DisplayRole):
        """Horizontal DisplayRole headers come from the root node's data; otherwise None."""
        if orientation == Qt.Horizontal and role == Qt.DisplayRole:
            return self.rootItem.data(section)

class TreeNode(object):
    """One row of a tree model: a list of per-column values plus child rows."""

    def __init__(self, data, parent=None):
        self.parentItem = parent  # owning TreeNode; None for the root
        self.itemData = data      # one value per column
        self.children = []

    def child(self, row):
        """Return the child at ``row``, or None when ``row`` is out of range."""
        if 0 <= row < len(self.children):
            return self.children[row]
        return None

    def childCount(self):
        """Number of direct children."""
        return len(self.children)

    def childNumber(self):
        """Row of this node within its parent (0 for the root, as in Qt's tree-model example)."""
        if self.parentItem is not None:
            return self.parentItem.children.index(self)
        return 0

    def columnCount(self):
        """Number of columns stored in this node."""
        return len(self.itemData)

    def data(self, column):
        """Value stored in ``column``."""
        return self.itemData[column]

    def insertChildren(self, position, count, columns):
        """Insert ``count`` empty children (``columns`` None values each) at ``position``.

        Returns True on success, or False if ``position`` is outside
        ``[0, childCount()]``.
        """
        if position < 0 or position > len(self.children):
            return False
        for _ in range(count):
            self.children.insert(position, TreeNode([None] * columns, self))
        return True

    def parent(self):
        """The parent node (None for the root)."""
        return self.parentItem

    def setData(self, column, value):
        """Store ``value`` in ``column``; returns False if the column does not exist."""
        if column < 0 or column >= len(self.itemData):
            return False
        self.itemData[column] = value
        return True

In [3]:
def resample(
    image: np.ndarray, 
    dtype: type
) -> np.ndarray:
    """Convert an integer image to another integer dtype.

    If the source dtype's maximum is below the target's, the image is simply
    cast (no rescaling). Otherwise the full source range
    ``[iinfo(src).min, iinfo(src).max]`` is linearly mapped onto
    ``[0, iinfo(dtype).max]`` and truncated to ``dtype``.

    Parameters
    ----------
    image : numpy.ndarray
        Integer-typed image.
    dtype : type
        Target integer dtype (e.g. ``np.uint8``).

    Returns
    -------
    numpy.ndarray
        Image converted to ``dtype``.

    Raises
    ------
    ValueError
        If either dtype is not an integer type (raised by ``np.iinfo``).
    """
    source_info = np.iinfo(image.dtype)
    target_info = np.iinfo(dtype)
    if source_info.max < target_info.max:
        return image.astype(dtype)
    # Float arithmetic avoids integer overflow. Multiplying before dividing keeps exact
    # integer ratios exact (e.g. 257 * 255 / 65535 == 1.0), so truncation is stable.
    span = float(source_info.max) - float(source_info.min)
    scaled = (image.astype(np.float64) - source_info.min) * target_info.max / span
    return scaled.astype(dtype)

In [4]:
def requestImage(self, id: str, size: QSize, requestedSize: QSize) -> Optional[QImage]:
    """Load image number ``id`` from ``M31Andromeda/*.FITS`` as a QImage.

    Written with the ``QQuickImageProvider.requestImage`` signature; ``self``,
    ``size`` and ``requestedSize`` are not used here.

    Parameters
    ----------
    id : str
        Decimal index into the sorted list of matching files.

    Returns
    -------
    Optional[QImage]
        A 16-bit or 8-bit grayscale image (or RGB888 for 3-channel 8-bit data),
        or None if rawpy cannot decode the file.
    """
    files = sorted(glob.glob("M31Andromeda/*.FITS"))  # sorted: glob order is not stable
    index = int(id)
    assert 0 <= index < len(files), f"No image with id {id}."
    file = os.path.abspath(files[index])
    print(file)
    assert os.path.isfile(file), "RAW file doesn't exist."
    extension = os.path.splitext(file)[1].lower()
    if extension in (".fits", ".fit", ".fts"):
        raw_image = fits.getdata(file, ext=0)
    elif extension in (".tiff", ".tif", ".jpeg", ".jpg", ".png"):
        raw_image = imageio.imread(file)
    else:
        try:
            with rawpy.imread(file) as raw:
                # raw_image is a view into LibRaw memory that is freed on close; copy it out.
                raw_image = raw.raw_image.copy()
        except rawpy.LibRawError:
            print("Invalid file format.")
            return None
    print(raw_image.shape)
    print(raw_image.dtype)
    # QImage reads native-endian bytes from a C-contiguous buffer. FITS data may arrive
    # in big-endian byte order, so normalise byte order and layout first.
    raw_image = np.ascontiguousarray(raw_image, dtype=raw_image.dtype.newbyteorder("="))
    h, w = raw_image.shape[:2]
    if raw_image.ndim == 3 and raw_image.shape[2] == 3 and raw_image.dtype == np.uint8:
        fmt = QImage.Format.Format_RGB888
    elif raw_image.dtype == np.uint16:
        fmt = QImage.Format.Format_Grayscale16
    else:
        fmt = QImage.Format.Format_Grayscale8
    bytes_per_line = raw_image.strides[0]
    # The QImage only borrows raw_image's buffer; copy() gives it its own pixels so it
    # stays valid after raw_image is garbage collected.
    return QImage(raw_image.data, w, h, bytes_per_line, fmt).copy()

In [5]:
from enum import Enum
class Detector(Enum):
    """OpenCV keypoint detectors available for star registration.

    Each member's ``value`` is a detector instance. It is created once, when
    this class is defined, and shared by every caller that uses
    ``detector.value``.
    """

    ORB = cv2.ORB_create()

    SIFT = cv2.SIFT_create()

    AKAZE = cv2.AKAZE_create()

In [6]:
def set_params(
    demosaic_algorithm=None,
    half_size=False,
    four_color_rgb=False,
    dcb_iterations=0,
    dcb_enhance=False,
    fbdd_noise_reduction=rawpy.FBDDNoiseReductionMode.Off,
    noise_thr=None,
    median_filter_passes=0,
    use_camera_wb=False,
    use_auto_wb=False,
    user_wb=None,
    output_color=rawpy.ColorSpace.sRGB,
    output_bps=8,
    user_flip=None,
    user_black=None,
    user_sat=None,
    no_auto_bright=False,
    auto_bright_thr=None,
    adjust_maximum_thr=0.75,
    bright=1.0,
    highlight_mode=rawpy.HighlightMode.Clip,
    exp_shift=None,
    exp_preserve_highlights=0.0,
    no_auto_scale=False,
    gamma=None,
    chromatic_aberration=None,
    bad_pixels_path=None
) -> rawpy.Params:
    """Build a ``rawpy.Params`` from the given post-processing options.

    Every argument is forwarded unchanged to the keyword of the same name in
    ``rawpy.Params``; see the rawpy documentation for each option's meaning.

    Returns
    -------
    params : rawpy.Params
        Post-processing parameters, e.g. for ``RawPy.postprocess``.
    """
    return rawpy.Params(
        demosaic_algorithm=demosaic_algorithm,
        half_size=half_size,
        four_color_rgb=four_color_rgb,
        dcb_iterations=dcb_iterations,
        dcb_enhance=dcb_enhance,
        fbdd_noise_reduction=fbdd_noise_reduction,
        noise_thr=noise_thr,
        median_filter_passes=median_filter_passes,
        use_camera_wb=use_camera_wb,
        use_auto_wb=use_auto_wb,
        user_wb=user_wb,
        output_color=output_color,
        output_bps=output_bps,
        user_flip=user_flip,
        user_black=user_black,
        user_sat=user_sat,
        no_auto_bright=no_auto_bright,
        auto_bright_thr=auto_bright_thr,
        adjust_maximum_thr=adjust_maximum_thr,
        bright=bright,
        highlight_mode=highlight_mode,
        exp_shift=exp_shift,
        exp_preserve_highlights=exp_preserve_highlights,
        no_auto_scale=no_auto_scale,
        gamma=gamma,
        chromatic_aberration=chromatic_aberration,
        bad_pixels_path=bad_pixels_path
    )

In [7]:
def calibrate(
    image: np.ndarray,
    master_dark: Optional[np.ndarray] = None,
    master_flat: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Calibrates an image.

    The dark frame is subtracted first. The result is then divided by the flat
    frame normalised to unit mean, so flat fielding corrects relative
    (per-pixel) response without changing the overall signal level. The
    result is clipped to the ``uint16`` range and converted.

    Parameters
    ----------
    image : numpy.ndarray
        Uncalibrated image (not modified).
    master_dark : Optional[numpy.ndarray]
        Master dark calibration frame.
    master_flat : Optional[numpy.ndarray]
        Master flat calibration frame. Pixels whose flat response is not
        positive are left uncorrected. If the flat's mean is not positive,
        flat correction is skipped.

    Returns
    -------
    image : numpy.ndarray
        Calibrated ``uint16`` image.
    """
    # Work in float64 so negative and fractional intermediate values survive.
    calibrated = np.asarray(image, dtype=np.float64)
    if master_dark is not None:
        calibrated = calibrated - master_dark
    if master_flat is not None:
        flat = np.asarray(master_flat, dtype=np.float64)
        flat_mean = flat.mean()
        if flat_mean > 0:
            flat = flat / flat_mean
            # Avoid inf/NaN: divide by 1 where the flat has no positive response.
            calibrated = calibrated / np.where(flat > 0, flat, 1.0)
    # Clip both ends so the uint16 cast cannot wrap around.
    return np.clip(calibrated, 0, np.iinfo(np.uint16).max).astype(np.uint16)

In [8]:
def load(
    raw_file: str,
    params: Optional[rawpy.Params] = None
) -> Optional[np.ndarray]:
    """Loads and processes an astrophotography image at a specified file path.

    FITS files (``.fits``, ``.fit``, ``.fts``, any letter case) are read from
    the primary HDU. Every other file is decoded and post-processed with rawpy.

    Parameters
    ----------
    raw_file : str
        Path to the RAW or FITS file.
    params : Optional[rawpy.Params]
        Post-processing parameters for RAW files. ``None`` lets rawpy use its
        defaults. Ignored for FITS files.

    Returns
    -------
    image : Optional[numpy.ndarray]
        Loaded image, or ``None`` if rawpy cannot read the file.
    """
    extension = os.path.splitext(raw_file)[1].lower()
    try:
        if extension in (".fits", ".fit", ".fts"):
            with fits.open(raw_file) as hdul:
                image = hdul[0].data
        else:
            with rawpy.imread(raw_file) as raw:
                image = raw.postprocess(params)
    except rawpy.LibRawError:
        return None
    return image

In [9]:
def preview(
    raw_file: str
) -> Optional[QImage]:
    """Loads a preview of a RAW image as a QImage.

    Parameters
    ----------
    raw_file : str
        Path to RAW file.

    Returns
    -------
    raw_image : Optional[QImage]
        RGB888 preview image, or ``None`` if the file could not be loaded.
    """
    # Explicit None params: rawpy falls back to its default post-processing.
    image = load(raw_file, None)
    if image is None:
        return None
    # NOTE(review): assumes an (H, W, 3) uint8 array, as produced by rawpy's default 8-bit
    # RGB output; FITS input or non-default params would need a different QImage format.
    image = np.ascontiguousarray(image)
    h, w = image.shape[:2]  # numpy images are (rows, cols, channels)
    # The QImage only borrows the numpy buffer; copy() gives the result its own pixels.
    return QImage(image.data, w, h, image.strides[0], QImage.Format.Format_RGB888).copy()

In [10]:
def master_calibration(
    file_tree: dict[list[str]],
    params: Optional[rawpy.Params] = None,
) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Generates calibration frames from a specified image set.

    Each master frame is the per-pixel mean of the frames in its category that
    loaded successfully. Dark flats have the master bias removed. The master
    flat has the master bias and the (bias-corrected) master dark flat removed.

    Parameters
    ----------
    file_tree : dict[str, list[str]]
        File tree mapping "Bias", "Dark", "Dark Flat" and "Flat" to file
        paths. Missing categories are treated as empty.
    params : Optional[rawpy.Params]
        Post-processing parameters passed to ``load``. ``None`` uses rawpy's
        defaults.

    Returns
    -------
    master_dark : Optional[numpy.ndarray]
        Master dark calibration frame, or None if no dark frames were loaded.
    master_flat : Optional[numpy.ndarray]
        Master flat calibration frame, or None if no flat frames were loaded.
    """
    def _master_frame(category: str) -> Optional[np.ndarray]:
        # Mean of the category's successfully loaded frames; None if there are none.
        frames = [load(path, params) for path in file_tree.get(category, [])]
        frames = [frame for frame in frames if frame is not None]
        return np.mean(frames, axis=0) if frames else None

    master_bias = _master_frame("Bias")
    master_dark = _master_frame("Dark")

    # Explicit `is not None` checks: the truth value of a multi-element array is ambiguous.
    master_dark_flat = _master_frame("Dark Flat")
    if master_dark_flat is not None and master_bias is not None:
        master_dark_flat -= master_bias

    master_flat = _master_frame("Flat")
    if master_flat is not None:
        if master_bias is not None:
            master_flat -= master_bias
        if master_dark_flat is not None:
            master_flat -= master_dark_flat

    return master_dark, master_flat

In [11]:
def detect_features(
    image: np.ndarray, 
    detector: Detector,
    scale: bool = False,
    peak_height: int = 7 * 257,
) -> tuple[Optional[tuple[cv2.KeyPoint]], Optional[np.ndarray]]: 
    """Detect keypoints and compute descriptors on an 8-bit copy of ``image``.

    Parameters
    ----------
    image : numpy.ndarray
        Integer image to analyse; converted to ``uint8`` with ``resample``.
    detector : Detector
        Detector whose ``value`` performs ``detectAndCompute``.
    scale : bool
        Optional pre-pass (off by default). Images whose maximum is below 255
        are promoted to the 16-bit range. Every peak of the flattened image
        higher than ``peak_height`` is then saturated to the ``uint16`` maximum.
    peak_height : int
        Minimum peak height used when ``scale`` is enabled.

    Returns
    -------
    keypoints, descriptors
        Output of ``detectAndCompute``.
    """
    if scale:
        if np.max(image) < 255:
            # Promote before multiplying: 257 does not fit in uint8.
            image = image.astype(np.uint16) * 257
        peaks, _ = find_peaks(image.ravel(), height=peak_height)
        # Convert flat indices back to N-d coordinates (works for 2-D and H x W x C images).
        peak_coords = np.unravel_index(peaks, image.shape)
        image = image.astype(np.uint16)
        image[peak_coords] = np.iinfo(np.uint16).max
    keypoints, descriptors = detector.value.detectAndCompute(resample(image, dtype=np.uint8), None)
    return keypoints, descriptors

In [12]:
def register(
    image: np.ndarray,
    base_keypoints: tuple[cv2.KeyPoint],
    base_descriptors: np.ndarray,
    detector: Detector,
    match_threshhold: float = 0.8,
) -> np.ndarray:
    """Registers and aligns an image to a reference frame.

    Features are detected in ``image`` and matched against the reference
    features. A RANSAC homography then warps ``image`` onto the reference. If
    no homography can be estimated, the image is translated by the offset of
    the best match. If nothing matches, it is returned unchanged.

    Parameters
    ----------
    image : numpy.ndarray
        Calibrated image.
    base_keypoints : tuple[cv2.KeyPoint]
        Keypoints of the reference image.
    base_descriptors : numpy.ndarray
        Descriptors of the reference image.
    detector : Detector
        Feature detector (ORB, SIFT, or AKAZE) used for both images.
    match_threshhold : float
        Fraction of the best cross-checked matches to keep (ORB/AKAZE only).

    Returns
    -------
    image : numpy.ndarray
        Calibrated and registered image.
    """
    keypoints, descriptors = detect_features(image, detector)
    if descriptors is None or base_descriptors is None or len(keypoints) == 0:
        return image

    if detector == Detector.SIFT:
        # SIFT descriptors are float vectors: use L2 distance and Lowe's ratio test.
        # The ratio test needs two nearest neighbours, which crossCheck does not allow.
        matcher = cv2.BFMatcher(cv2.NORM_L2)
        candidates = matcher.knnMatch(descriptors, base_descriptors, k=2)
        matches = [
            pair[0] for pair in candidates
            if len(pair) == 2 and pair[0].distance < 0.6 * pair[1].distance
        ]
        matches.sort(key=lambda m: m.distance)
    else:
        # ORB/AKAZE (default) descriptors are binary strings, compared by Hamming distance.
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = sorted(matcher.match(descriptors, base_descriptors), key=lambda m: m.distance)
        matches = matches[: int(len(matches) * match_threshhold)]

    if not matches:
        return image

    src_points = np.float32([keypoints[m.queryIdx].pt for m in matches])
    dst_points = np.float32([base_keypoints[m.trainIdx].pt for m in matches])

    if len(matches) >= 4:  # a homography needs at least four correspondences
        try:
            homography, _ = cv2.findHomography(src_points, dst_points, cv2.RANSAC)
            if homography is not None:
                return cv2.warpPerspective(image, homography, (image.shape[1], image.shape[0]))
        except cv2.error:
            pass  # deliberate: fall through to the translation fallback below

    # Fallback: integer translation from the best (lowest-distance) match.
    dx, dy = (dst_points[0] - src_points[0]).astype(int)
    return np.roll(np.roll(image, dx, axis=1), dy, axis=0)

In [13]:
def display(
    image: np.ndarray,
    filename: Optional[str] = None,
    save: bool = False
) -> None:
    """Displays an image.

    The figure is titled with ``filename``, with underscores shown as spaces.
    It is drawn in matplotlib's ``dark_background`` style, scoped to this
    figure only.

    Parameters
    ----------
    image : numpy.ndarray
        Image to be displayed (converted to ``uint8`` with ``resample``).
    filename : Optional[str]
        Filename in which to save the resulting image; also used as the title.
    save : bool
        Determines if the resulting image is saved to ``filename``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``save`` is True but no ``filename`` is given.
    """
    if save and not filename:
        raise ValueError("A filename is required when save=True.")
    title = filename.replace("_", " ") if filename else ""
    # Apply the style before the figure is created so it takes effect, and scope it
    # so the global matplotlib style is left untouched.
    with plt.style.context("dark_background"):
        fig, ax = plt.subplots()
        ax.imshow(resample(image, dtype=np.uint8), cmap=(None if image.ndim == 3 else "gray"))
        ax.set_title(title)
        ax.axis("off")
        if save:
            # Save under the real filename, not the prettified title.
            fig.savefig(filename, bbox_inches="tight", pad_inches=0)
        plt.show()
    plt.close(fig)

In [14]:
def stack(
    file_tree: dict[list[str]],
    detector: Detector,
    params: rawpy.Params,
    reference_index: int = 0,
    filename: Optional[str] = None,
    save: bool = False,
    verbose: bool = True
) -> np.ndarray:
    """Stacks astrophotography images.

    The steps are:

    - Build master bias, dark, dark-flat and flat frames as per-pixel means.
    - Calibrate every light frame.
    - Register each light frame to the reference frame.
    - Average the registered frames.
    - Display the result, and optionally save it.

    Parameters
    ----------
    file_tree : dict[str, list[str]]
        Absolute paths to light, bias, dark, dark flat, and flat images, keyed by
        "Light", "Bias", "Dark", "Dark Flat" and "Flat". Missing keys count as empty.
    detector : Detector
        Feature detector (ORB, SIFT, or AKAZE) to identify prominent stars.
    params : rawpy.Params
        Post-processing parameters passed to ``load``.
    reference_index : int
        Index of the light frame that the others are registered to.
    filename : Optional[str]
        Filename in which to save the resulting image.
    save : bool
        Determines if the resulting image is automatically saved.
    verbose : bool
        Print progress messages.

    Returns
    -------
    stacked_image : numpy.ndarray
        Stacked ``uint16`` image.

    Raises
    ------
    ValueError
        If no light frame could be loaded.
    """
    def _load_all(category: str) -> list:
        # Successfully loaded frames of ``category`` (load returns None on failure).
        frames = (load(path, params) for path in file_tree.get(category, []))
        return [frame for frame in frames if frame is not None]

    def _master(category: str) -> Optional[np.ndarray]:
        # Per-pixel mean of the category's frames, or None when there are none.
        frames = _load_all(category)
        return np.mean(frames, axis=0) if frames else None

    master_bias = _master("Bias")
    if verbose and master_bias is not None:
        print("loaded bias frames")

    master_dark = _master("Dark")
    if verbose and master_dark is not None:
        print("loaded dark frames")

    # Explicit `is not None` checks: the truth value of a multi-element array is ambiguous.
    master_dark_flat = _master("Dark Flat")
    if master_dark_flat is not None and master_bias is not None:
        master_dark_flat -= master_bias
    if verbose and master_dark_flat is not None:
        print("loaded dark flat frames")

    master_flat = _master("Flat")
    if master_flat is not None:
        if master_bias is not None:
            master_flat -= master_bias
        if master_dark_flat is not None:
            master_flat -= master_dark_flat
    if verbose and master_flat is not None:
        print("loaded flat frames")

    light_frames = [calibrate(frame, master_dark, master_flat) for frame in _load_all("Light")]
    if not light_frames:
        raise ValueError("No light frames could be loaded.")
    if verbose:
        print("loaded light frames")

    base_keypoints, base_descriptors = detect_features(light_frames[reference_index], detector)
    registered = [
        frame if index == reference_index
        else register(frame, base_keypoints, base_descriptors, detector)
        for index, frame in enumerate(light_frames)
    ]
    if verbose:
        print("registered light frames")

    # Average instead of sum. A uint16 sum is promoted to a much wider integer dtype,
    # and display() -> resample() normalises by that dtype's maximum, leaving the image
    # nearly black. The mean stays within the calibrated uint16 range.
    stacked_image = np.mean(registered, axis=0).astype(np.uint16)
    display(stacked_image, filename, save)
    return stacked_image


In [15]:

@widgets.interact(
    demosaic_algorithm=(0,12),
    half_size=[True,False],
    four_color_rgb=[True,False],
    dcb_iterations=(0,10),
    dcb_enhance=[True,False],
    fbdd_noise_reduction=(0,2),
    noise_thr=(0.0,1.0),
    median_filter_passes=(0,10),
    use_camera_wb=[True,False],
    use_auto_wb=[True,False],
    r_wb=(0.0,1.0),
    g1_wb=(0.0,1.0),
    g2_wb=(0.0,1.0),
    b_wb=(0.0,1.0),
    output_color=(0,8),
    output_bps=[8,16],
    user_flip=[0,3,5,6],
    user_black=(0,255),
    user_sat=(0,255),
    no_auto_scale=[True,False],
    no_auto_bright=[True,False],
    auto_bright_thr=(0.0,1.0),
    adjust_maximum_thr=(0.0,1.0),
    bright=(0.0,1.0),
    highlight_mode=(0,2),
    exp_shift=(0.25,8.0),
    exp_preserve_highlights=(0.0,1.0),
    power=(0.0,5.0),
    slope=(0.0,5.0),
    red_scale=(0,1),
    blue_scale=(0,1),
    bad_pixels_path=[None]
)
def choose_params(
    demosaic_algorithm: int = 3,
    half_size: bool = False,
    four_color_rgb: bool = False,
    dcb_iterations: int = 0,
    dcb_enhance: bool = False,
    fbdd_noise_reduction: int = 0,
    noise_thr: float = None,
    median_filter_passes: int = 0,
    use_camera_wb: bool = True,
    use_auto_wb: bool = False,
    r_wb: float = 1.0,
    g1_wb: float = 1.0,
    g2_wb: float = 1.0,
    b_wb: float = 1.0,
    output_color: int = 1,
    output_bps: int = 16,
    user_flip: int = 0,
    user_black: int = None,
    user_sat: int = None,
    no_auto_scale: bool = False,
    no_auto_bright: bool = False,
    auto_bright_thr: float = None,
    adjust_maximum_thr: float = 0.75,
    bright: float = 1.0,
    highlight_mode: int = 1,
    exp_shift: float = 1.0,
    exp_preserve_highlights: float = 0.0,
    power: float = 2.222,
    slope: float = 4.5,
    red_scale: int = 1,
    blue_scale: int = 1,
    bad_pixels_path: str = None
):
    """Interactively choose rawpy post-processing options and re-run the stack.

    The widget values map onto ``rawpy.Params`` as follows:

    - The four ``*_wb`` values form ``user_wb``.
    - ``power``/``slope`` form ``gamma``.
    - ``red_scale``/``blue_scale`` form ``chromatic_aberration``.
    - Integer codes are converted to the corresponding rawpy enums.

    Every ``*.ARW`` file under ``<repo>/src/test_data/<category>/`` is then
    stacked with the ORB detector.
    """
    white_balance = [r_wb, g1_wb, g2_wb, b_wb]
    gamma_curve = (power, slope)
    aberration = (red_scale, blue_scale)

    params = rawpy.Params(
        demosaic_algorithm=rawpy.DemosaicAlgorithm(demosaic_algorithm),
        half_size=half_size,
        four_color_rgb=four_color_rgb,
        dcb_iterations=dcb_iterations,
        dcb_enhance=dcb_enhance,
        fbdd_noise_reduction=rawpy.FBDDNoiseReductionMode(fbdd_noise_reduction),
        noise_thr=noise_thr,
        median_filter_passes=median_filter_passes,
        use_camera_wb=use_camera_wb,
        use_auto_wb=use_auto_wb,
        user_wb=white_balance,
        output_color=rawpy.ColorSpace(output_color),
        output_bps=output_bps,
        user_flip=user_flip,
        user_black=user_black,
        user_sat=user_sat,
        no_auto_scale=no_auto_scale,
        no_auto_bright=no_auto_bright,
        auto_bright_thr=auto_bright_thr,
        adjust_maximum_thr=adjust_maximum_thr,
        bright=bright,
        highlight_mode=rawpy.HighlightMode(highlight_mode),
        exp_shift=exp_shift,
        exp_preserve_highlights=exp_preserve_highlights,
        gamma=gamma_curve,
        chromatic_aberration=aberration,
        bad_pixels_path=bad_pixels_path,
    )

    # Test data lives relative to the enclosing git checkout, whatever the working directory.
    project_root = git.Repo(".", search_parent_directories=True).working_tree_dir
    file_tree = {
        category: glob.glob(f"{project_root}/src/test_data/{category}/*.ARW")
        for category in ("Light", "Bias", "Dark", "Dark Flat", "Flat")
    }
    stack(file_tree, Detector.ORB, params)

interactive(children=(IntSlider(value=3, description='demosaic_algorithm', max=12), Dropdown(description='half…