In [None]:
# imports
import os
import rawpy
import cv2
from tqdm.notebook import tqdm, trange
import numpy as np

import matplotlib.pyplot as plt

# from astropy.nddata import CCDData
# from astropy.stats import mad_std
# import ccdproc as ccdp

# test data sourced from https://www.reddit.com/r/astrophotography/comments/9q7tum/andromeda_galaxy_raw_photos_to_experiment/

In [None]:
# global variables
# Notebooks have no real __file__, so the literal string '__file__' is resolved
# relative to the kernel's working directory; its parent is therefore that directory.
CURRENT_DIR = os.path.split(os.path.realpath('__file__'))[0]

In [None]:
def register_image(calibrated_image: np.ndarray[np.uint16], features: tuple, base_features: tuple, detector_type: str, match_percent_threshhold: float=0.8, lowe_ratio: float=0.6) -> np.ndarray[np.uint16]:
    """Registers (aligns) an image onto the base image using matched features

    Matched keypoint pairs are used to estimate a RANSAC homography, which is
    then used to warp ``calibrated_image`` into the base image's frame.

    Parameters
    ----------
    calibrated_image: numpy.ndarray[numpy.uint16]
        Calibrated image to align

    features: tuple
        (keypoints, descriptors) of features of interest for the current image

    base_features: tuple
        (keypoints, descriptors) of features of interest for the base image

    detector_type: str
        Detector that produced the descriptors. 'SIFT' descriptors are matched
        with L2 distance and filtered with Lowe's ratio test; any other detector
        (ORB, AKAZE) is treated as binary and matched with Hamming distance and
        cross-checking.

    match_percent_threshhold: float
        Fraction (0-1) of the best cross-checked matches kept for non-SIFT detectors

    lowe_ratio: float
        SIFT ratio-test threshold: a match is kept only when its distance is below
        ``lowe_ratio`` times the distance of the second-best candidate

    Returns
    -------
    registered_image: numpy.ndarray[numpy.uint16]
        Image warped into the base image's frame, same size as ``calibrated_image``

    Raises
    ------
    ValueError
        If either image has no descriptors, fewer than 4 matches survive
        filtering, or no homography can be estimated
    """
    # findHomography needs at least 4 point correspondences
    min_matches = 4

    keypoints, descriptors = features
    base_keypoints, base_descriptors = base_features
    if descriptors is None or base_descriptors is None:
        raise ValueError("Cannot register image: no feature descriptors were found")

    if detector_type == 'SIFT':
        # The ratio test needs the two nearest neighbours of every descriptor;
        # crossCheck=True only supports k=1, so it must be disabled here.
        matcher = cv2.BFMatcher(cv2.NORM_L2)
        knn_matches = matcher.knnMatch(descriptors, base_descriptors, k=2)
        matches = [pair[0] for pair in knn_matches
                   if len(pair) == 2 and pair[0].distance < lowe_ratio*pair[1].distance]
    else:
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = sorted(matcher.match(descriptors, base_descriptors), key=lambda match: match.distance)
        matches = matches[:int(len(matches)*match_percent_threshhold)]

    if len(matches) < min_matches:
        raise ValueError(f"Cannot register image: only {len(matches)} usable matches (need {min_matches})")

    # queryIdx indexes the current image's keypoints, trainIdx the base image's
    points = np.float32([keypoints[match.queryIdx].pt for match in matches]).reshape(-1, 1, 2)
    base_points = np.float32([base_keypoints[match.trainIdx].pt for match in matches]).reshape(-1, 1, 2)

    homography, _inlier_mask = cv2.findHomography(points, base_points, cv2.RANSAC)
    if homography is None:
        raise ValueError("Cannot register image: homography estimation failed")

    height, width = calibrated_image.shape[:2]
    registered_image = cv2.warpPerspective(calibrated_image, homography, (width, height))
    return registered_image

In [None]:
def _load_master_frame(frame_dir: str, averaging_func) -> np.ndarray:
    """Reads every RAW file in ``frame_dir`` and combines them pixel-wise with ``averaging_func``."""
    frames = []
    for file in sorted(os.listdir(frame_dir)):
        with rawpy.imread(f"{frame_dir}/{file}") as raw:
            # raw_image is only valid while the RawPy object is open, so keep a copy
            frames.append(raw.raw_image.copy())
    if not frames:
        raise ValueError(f"No calibration frames found in {frame_dir}")
    return averaging_func(np.asarray(frames), axis=0)


def get_calibration_frames(base_dir: str, use_median: bool=False) -> tuple:
    """Generates master calibration frames from a specified image set

    Parameters
    ----------
    base_dir: str
        Base directory containing the image set; bias frames are read from
        ``<base_dir>/BIAS`` and dark frames from ``<base_dir>/DARKS``

    use_median: bool
        If True, frames are combined with ``numpy.median``, otherwise with ``numpy.mean``

    Returns
    -------
    master_bias_frame: numpy.ndarray or None
        Pixel-wise average of the bias frames (float array), or None if a
        frame could not be read

    master_dark_frame: numpy.ndarray or None
        Pixel-wise average of the dark frames (float array), or None if a
        frame could not be read

    Raises
    ------
    ValueError
        If the BIAS or DARKS directory contains no files
    """
    averaging_func = np.median if use_median else np.mean

    try:
        master_bias_frame = _load_master_frame(f"{base_dir}/BIAS", averaging_func)
        master_dark_frame = _load_master_frame(f"{base_dir}/DARKS", averaging_func)
    except rawpy.LibRawError:
        return None, None
    return master_bias_frame, master_dark_frame

In [None]:
def scale_to_8bit(image: np.ndarray[np.uint16]) -> np.ndarray[np.uint8]:
    """Maps a 16-bit image onto the 8-bit range

    Parameters
    ----------
    image: numpy.ndarray[numpy.uint16]
        Image with values in 0-65535

    Returns
    -------
    scaled_image: numpy.ndarray[numpy.uint8]
        Image with values in 0-255 (fractional parts are truncated)
    """
    # 65535 / 255 == 257, so this divisor maps the full 16-bit range onto 0-255
    divisor = 257
    return np.true_divide(image, divisor).astype(np.uint8)

In [None]:
def _create_detector(detector_type: str):
    """Returns an OpenCV feature detector: ORB or SIFT by name, AKAZE for anything else."""
    if detector_type == 'ORB':
        return cv2.ORB_create()
    if detector_type == 'SIFT':
        return cv2.SIFT_create()
    return cv2.AKAZE_create()


def calibrate_image(image_file_path: str, master_bias_frame: np.ndarray[np.uint16], master_dark_frame: np.ndarray[np.uint16], detector_type: str) -> tuple:
    """Loads and calibrates an image at a specified file path, then extracts its features

    Parameters
    ----------
    image_file_path: str
        Full path to specified RAW image

    master_bias_frame: numpy.ndarray
        Master bias calibration frame. NOTE(review): not applied here; only the
        master dark is subtracted (presumably the darks already contain the
        bias signal). Confirm.

    master_dark_frame: numpy.ndarray or None
        Master dark calibration frame subtracted from the raw sensor data.
        When None (e.g. the calibration frames could not be read), dark
        subtraction is skipped and libraw's own black level is used instead.

    detector_type: str
        'ORB' or 'SIFT'; any other value selects AKAZE

    Returns
    -------
    calibrated_image: numpy.ndarray[numpy.uint16] or None
        Demosaiced 16-bit sRGB image (linear gamma), or None if the file could not be read

    keypoints: tuple or None
        Keypoints from features of interest for calibrated image

    descriptors: numpy.ndarray or None
        Descriptors from features of interest for calibrated image
    """
    dark_subtracted = master_dark_frame is not None
    try:
        with rawpy.imread(image_file_path) as raw:
            if dark_subtracted:
                # Subtract in float so negative residuals are clipped to 0 instead of wrapping around
                calibrated_raw = np.clip(raw.raw_image.astype(np.float64) - master_dark_frame, 0, None)
                # raw_image is a writable view, so postprocess() will demosaic the calibrated data
                np.copyto(raw.raw_image, calibrated_raw.astype(np.uint16))

            params = rawpy.Params(gamma=(1,1),
                                  no_auto_scale=False,
                                  no_auto_bright=True,
                                  output_bps=16,
                                  use_camera_wb=True,
                                  use_auto_wb=False,
                                  user_wb=None,
                                  output_color=rawpy.ColorSpace.sRGB,
                                  demosaic_algorithm=rawpy.DemosaicAlgorithm.AHD,
                                  fbdd_noise_reduction=rawpy.FBDDNoiseReductionMode.Full,
                                  dcb_enhance=False,
                                  dcb_iterations=0,
                                  half_size=False,
                                  median_filter_passes=0,
                                  # After dark subtraction the pedestal is already gone; without it,
                                  # let libraw apply the camera's black level (None = default).
                                  user_black=0 if dark_subtracted else None)
            calibrated_image = raw.postprocess(params)
    except rawpy.LibRawError:
        return None, None, None

    detector = _create_detector(detector_type)
    keypoints, descriptors = detector.detectAndCompute(scale_to_8bit(calibrated_image), None)
    return calibrated_image, keypoints, descriptors

In [None]:
def show(image_stack: np.ndarray[np.uint16], title: str="") -> None:
    """Shows image(s) on a dark background

    Parameters
    ----------
    image_stack: numpy.ndarray[numpy.uint16]
        A 4-D array is treated as a stack of images drawn side by side;
        any other array is drawn as a single image (presumably (H, W, 3))

    title: str
        Title of the plot (figure-level title for a stack)

    Returns
    -------
    None
    """
    is_stack = image_stack.ndim == 4
    images = list(image_stack) if is_stack else [image_stack]
    if not images:
        return
    num_of_plots = len(images)
    # Single images keep matplotlib's default figure size
    figsize = (5*num_of_plots, 5) if is_stack else None

    # The style must be active before the figure/text are created to affect them,
    # and a context avoids permanently changing the global matplotlib style.
    with plt.style.context('dark_background'):
        # squeeze=False keeps a 2-D axes array even for a single subplot
        fig, axes = plt.subplots(ncols=num_of_plots, figsize=figsize, squeeze=False)
        for ax, image in zip(axes[0], images):
            ax.imshow(scale_to_8bit(image))
            ax.axis("off")
        if is_stack:
            fig.suptitle(title)
        else:
            axes[0][0].set_title(title)
        plt.show()
    plt.close(fig)

In [None]:
def _stack_with_detector(light_paths: list, master_bias_frame, master_dark_frame, detector_type: str):
    """Calibrates, registers and averages the light frames using one feature detector.

    Returns the mean image as uint16, or None if no light frame could be read.
    """
    accumulator = None
    base_features = None
    frame_count = 0
    for path in tqdm(light_paths, desc=f"{detector_type} stacked image"):
        calibrated_image, keypoints, descriptors = calibrate_image(image_file_path=path,
                                                                   master_bias_frame=master_bias_frame,
                                                                   master_dark_frame=master_dark_frame,
                                                                   detector_type=detector_type)
        if calibrated_image is None:
            tqdm.write(f"Skipping unreadable frame: {path}")
            continue
        if accumulator is None:
            # The first readable frame is the registration reference
            base_features = (keypoints, descriptors)
            # Accumulate in float: summing uint16 frames would overflow and wrap around
            accumulator = calibrated_image.astype(np.float64)
        else:
            registered_image = register_image(calibrated_image,
                                              features=(keypoints, descriptors),
                                              base_features=base_features,
                                              detector_type=detector_type)
            accumulator += registered_image
        frame_count += 1

    if accumulator is None:
        return None
    mean_image = accumulator/frame_count
    return np.clip(mean_image, 0, np.iinfo(np.uint16).max).astype(np.uint16)


def stack_images(image_set: str, detector_types: tuple=('SIFT', 'AKAZE', 'ORB')) -> list:
    """Calibrates, registers and mean-stacks the light frames of an image set

    The set is read from ``<CURRENT_DIR>/../data/<image_set>``, which must contain
    BIAS, DARKS and LIGHTS subdirectories. One stacked image is produced (and
    shown) per detector type.

    Parameters
    ----------
    image_set: str
        Name of image set to be processed

    detector_types: tuple
        Feature detectors to stack with ('SIFT', 'AKAZE', 'ORB')

    Returns
    -------
    stacked_images: list[numpy.ndarray[numpy.uint16]] or None
        One mean-stacked image per detector, or None if the image set does not exist
    """
    base_dir = f"{CURRENT_DIR}/../data/{image_set}"
    if not os.path.isdir(base_dir):
        print("Not a valid image set")
        return None

    master_bias_frame, master_dark_frame = get_calibration_frames(base_dir)
    light_dir = f"{base_dir}/LIGHTS"
    # os.listdir order is arbitrary; sort once so the reference frame is deterministic
    light_paths = [f"{light_dir}/{file}" for file in sorted(os.listdir(light_dir))]

    images = []
    for detector_type in detector_types:
        stacked_image = _stack_with_detector(light_paths, master_bias_frame, master_dark_frame, detector_type)
        if stacked_image is None:
            print(f"No readable light frames in {light_dir}")
            continue
        show(stacked_image, title=f"{detector_type} stacked image")
        images.append(stacked_image)
    return images



In [None]:
# Run the full pipeline on the ANDROMEDA_TEST set: the result is a list with one
# stacked image per detector type (or None if the set directory is missing).
stacked_image = stack_images(image_set='ANDROMEDA_TEST')