In [1]:
import sys 
import os
import numpy as np
np.set_printoptions(threshold=sys.maxsize)
%matplotlib widget
import matplotlib as mpl
import matplotlib.pyplot as plt 
import matplotlib.colors as mcolors
from mpl_toolkits.mplot3d import Axes3D
import astropy.constants as c
import astropy.units as u
from astropy.io import fits
from IPython.display import display
import logging
from pathlib import Path
from datetime import datetime
import pyregion
from astropy.wcs import WCS
from astropy.visualization.wcsaxes import WCSAxes
from astropy.coordinates import SkyCoord
import importlib
import pickle
import pandas as pd
try:
    logging.getLogger('matplotlib').disabled = True
except:
    pass
import pandas as pd
pd.pandas.set_option('display.max_columns', None)
pd.pandas.set_option('display.max_rows', None)
import shutil
from astropy.table import QTable, Table
from matplotlib.patches import (Ellipse, Rectangle)
import matplotlib.patches as mpatches
import itertools
import shapely
import shapely.plotting
from shapely.geometry.point import Point
from shapely import affinity
from scipy.integrate import dblquad
from uncertainties import ufloat
from scipy.optimize import curve_fit
from astropy.visualization import (
    MinMaxInterval, 
    SqrtStretch,
    ImageNormalize,
    simple_norm
)
from astropy import visualization
import matplotlib.cm as cm
from pprint import pprint
import pylustrator
from astropy.stats import sigma_clipped_stats
from astropy.modeling import models
from astropy.convolution.kernels import CustomKernel
from astropy.stats import gaussian_fwhm_to_sigma
from photutils.utils._parameters import as_pair
from astropy.convolution import discretize_model

In [2]:
# Resolve the project root once. ``realpath('__file__')`` is given the *string*
# '__file__' (not the dunder), so it resolves against the current working
# directory and ROOT_DIR is the directory the kernel started in.
# The guard keeps this cell idempotent: after the ``os.chdir`` below, re-running
# it unguarded would set ROOT_DIR to the mirar checkout and resolve './mirar'
# relative to that checkout instead of the notebook directory.
if "ROOT_DIR" not in globals():
    ROOT_DIR = os.path.dirname(os.path.realpath('__file__'))
PROG_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'mirar'))
os.chdir(PROG_DIR)

In [3]:
class Fits:
    """Quick-look wrapper around a FITS file.

    Per-HDU data and headers are cached in ``self.data`` and ``self.header``
    (dicts keyed by HDU index). HDU 0 is read on construction; call
    :meth:`read` to load further HDUs. :meth:`image` stores the axes it draws
    on in ``self.ax`` so that :meth:`mark_from_cat` can overlay sources on it.
    """

    def __init__(self, file, print_=True):
        """Read HDU 0 of ``file``.

        :param file: path to the FITS file.
        :param print_: when true, print the HDU summary (``HDUList.info()``).
        """
        self.file = file
        self.data = {}
        self.header = {}
        if print_:
            # ``with`` closes the handle even if info() raises.
            with fits.open(self.file) as hdul:
                hdul.info()
        self.read()
        # Silence matplotlib's logger; best effort only.
        try:
            logging.getLogger("matplotlib").disabled = True
        except Exception:
            pass

    def read(self, hdu=0):
        """Load data and header of HDU ``hdu`` into ``self.data`` / ``self.header``.

        The data is converted to ``float`` when possible. If the conversion
        fails (e.g. the HDU has no data), the raw data is stored instead and a
        message is printed.
        """
        with fits.open(self.file) as hdul:
            raw = hdul[hdu].data
            try:
                self.data[hdu] = raw.astype(float)
            except (AttributeError, TypeError, ValueError):
                print(rf"HDU {hdu} is not float")
                self.data[hdu] = raw
            self.header[hdu] = hdul[hdu].header

    def wcs_plot(self, hdu=0):
        """Create a new figure with WCS axes built from the header of HDU ``hdu``.

        :return: ``(fig, ax)``.
        """
        wcs = WCS(self.header[hdu])
        fig = plt.figure(clear=True)
        ax = plt.subplot(projection=wcs)
        fig.add_axes(ax)
        return fig, ax

    def create_copy(self):
        """Copy the FITS file (with metadata) to ``<file>.copy``."""
        shutil.copy2(self.file, f"{self.file}.copy")

    def _sidecar_path(self, suffix, tag=None):
        """Return ``<dir of file>/<file name>[_<tag>]<suffix>`` for outputs saved beside the FITS file."""
        stem = os.path.basename(self.file)
        if tag:
            stem = f"{stem}_{tag}"
        return os.path.join(os.path.dirname(self.file), stem + suffix)

    def image(
        self,
        hdu=0,
        column=False,
        title=None,
        scale=(5, 95),
        save=False,
        tag=None,
        wcs=False,
        median=True,
        meanstd=True,
    ):
        """Display HDU ``hdu`` as an image and return the axes.

        :param hdu: HDU index; must already be loaded via :meth:`read`.
        :param column: optional index/key applied to the data before plotting
            (``False``/``None`` means "use the whole array").
        :param title: plot title; defaults to ``<parent dir>/<file name>``.
        :param scale: low/high percentiles for the colour limits, used when
            ``meanstd`` is false.
        :param save: save the figure next to the FITS file as
            ``<name>[_tag].png`` (600 dpi).
        :param tag: optional suffix inserted into the saved file name.
        :param wcs: draw on WCS axes built from the HDU header.
        :param median: with percentile scaling, take percentiles of the
            ``axis=1`` nan-median instead of the raw pixels.
        :param meanstd: use ``mean - std`` .. ``mean + 10*std`` colour limits.
        :return: the axes (also stored as ``self.ax``), or ``None`` when the
            data cannot be reduced along ``axis=1``.
        """
        data = self.data[hdu]
        # Explicit sentinel check so that column=0 is honoured.
        if column is not False and column is not None:
            data = data[column]
        if wcs:
            fig, ax = self.wcs_plot(hdu)
        else:
            fig, ax = plt.subplots()
            ax.set_xlabel("x pixel")
            ax.set_ylabel("y pixel")
        self.ax = ax
        if not meanstd and scale:
            try:
                img = np.nanmedian(data, axis=1) if median else data
            except ValueError:
                # np.AxisError (bad axis for the data's rank) subclasses ValueError.
                print("invalid shape", data.shape)
                return
            self.scale_low, self.scale_high = np.percentile(img, scale)
            im = ax.imshow(
                data, cmap="magma", vmin=self.scale_low, vmax=self.scale_high
            )
        elif meanstd:
            mean, std = np.nanmean(data), np.nanstd(data)
            im = ax.imshow(
                data,
                interpolation="nearest",
                cmap="grey",
                vmin=mean - std,
                vmax=mean + 10 * std,
                origin="lower",
            )
        else:
            im = ax.imshow(data, cmap="magma")
        fig.colorbar(im, ax=ax, pad=0.005)
        if not title:
            title = f"{os.path.basename(os.path.dirname(self.file))}/{os.path.basename(self.file)}"
        ax.set_title(title)
        # Lay out before saving so the PNG matches what is displayed.
        fig.tight_layout()
        if save:
            fig.savefig(self._sidecar_path(".png", tag), dpi=600)
        return ax

    def image_with_reg(
        self, reg, hdu=0, wcsaxis=(0.1, 0.1, 0.8, 0.8), v=(0, 100), save=False, tag=None
    ):
        """Display HDU ``hdu`` on WCS axes with the regions in ``reg`` overlaid.

        :param reg: region file, parsed with ``pyregion.open``.
        :param hdu: HDU whose header (WCS, region conversion) and data are used.
        :param wcsaxis: unused; kept for backward compatibility.
        :param v: ``(vmin, vmax)`` colour limits; falsy for autoscaling.
        :param save: save as ``<name>[_tag].reg.png`` next to the FITS file.
        :param tag: optional suffix for the saved file name.
        :return: the axes.
        """
        regions = pyregion.open(reg).as_imagecoord(self.header[hdu])
        patch_list, artist_list = regions.get_mpl_patches_texts()
        fig, ax = self.wcs_plot(hdu)
        for patch in patch_list:
            patch.set_color("red")
            patch.set_facecolor("none")
            ax.add_patch(patch)
        for artist in artist_list:
            ax.add_artist(artist)

        # Plot the same HDU whose WCS was used (previously always HDU 0).
        data = self.data[hdu]
        if v:
            im = ax.imshow(data, origin="lower", vmin=v[0], vmax=v[1], cmap="magma")
        else:
            im = ax.imshow(data, origin="lower", cmap="magma")
        fig.colorbar(im, cmap="magma")
        if save:
            fig.savefig(self._sidecar_path(".reg.png", tag), dpi=600)
        return ax

    def bin_table(self, hdu=0, return_=True, save=False):
        """Convert HDU ``hdu`` data to a pandas DataFrame.

        :param save: also write it to ``<file>.csv`` next to the FITS file.
        :param return_: return the DataFrame (else ``None``).
        """
        df = pd.DataFrame(self.data[hdu])
        if save:
            df.to_csv(self._sidecar_path(".csv"), encoding="utf-8")
        return df if return_ else None

    def bin_table2(self, hdu=0):
        """Wrap HDU ``hdu`` data in an astropy ``QTable``."""
        return QTable(self.data[hdu])

    def mark_from_cat(self, keys, cat_file=None, cat_table=None, hdu=2, save=False, color='red'):
        """Overlay catalog sources as ellipses on ``self.ax`` (set by :meth:`image`).

        :param keys: maps 'x', 'y', 'a', 'b', 'angle' to catalog column names.
        :param cat_file: FITS catalog read from HDU ``hdu``; preferred over ``cat_table``.
        :param cat_table: in-memory QTable or DataFrame catalog.
        :param save: save the axes' figure as ``<cat_file>.png``, or
            ``<file>.cat.png`` beside this FITS file when only ``cat_table`` is given.
        :param color: colour for ellipses and labels.
        """
        x_offset = 2
        y_offset = 2
        if cat_file:
            catfits = Fits(cat_file)
            catfits.read(hdu)
            cat = catfits.bin_table2(hdu)
        elif isinstance(cat_table, (QTable, pd.DataFrame)):
            cat = cat_table
        else:
            return
        x_max = self.ax.get_xlim()[-1]
        # Column lookup is loop-invariant; do it once.
        try:
            has_number = 'NUMBER' in list(cat.columns)
        except Exception:
            print('could not access column names')
            has_number = False
        for i, (x, y, a, b, theta) in enumerate(
            zip(
                cat[keys["x"]],
                cat[keys["y"]],
                cat[keys["a"]],
                cat[keys["b"]],
                cat[keys["angle"]],
            )
        ):
            marker = Ellipse(xy=(x, y), height=a, width=b, angle=theta - 90, color=color, fill=None)  # theta-90 to rotate wrt. x
            self.ax.add_patch(marker)
            annotation = str(i)
            if has_number:
                try:
                    annotation = cat['NUMBER'][i]
                except Exception:
                    print('could not access column names')
            # Put the label on the left of sources near the right edge.
            if x + x_offset >= x_max - x_offset * 2:
                xy = (x - x_offset * 2, y + y_offset)
            else:
                xy = (x + x_offset, y + y_offset)
            self.ax.annotate(annotation, xy, color=color)
        if save:
            out_path = f"{cat_file}.png" if cat_file else self._sidecar_path(".cat.png")
            self.ax.get_figure().savefig(out_path, dpi=300)

# Column-name maps for mark_from_cat / get_pos: generic role -> catalog column.
# SExtractor catalogs.
SEX_SRC_KEYS = dict(
    x='X_IMAGE',
    y='Y_IMAGE',
    a='A_IMAGE',
    b='B_IMAGE',
    angle='THETA_IMAGE',
)

# mirar source tables (note: angle still uses the SExtractor column name).
MIRAR_SRC_KEYS = dict(
    x='xpos',
    y='ypos',
    a='aimage',
    b='bimage',
    angle='THETA_IMAGE',
)

# PSF-photometry result tables.
PSF_PHOT_SRC_KEYS = dict(
    x='x_fit',
    y='y_fit',
    a='a',
    b='b',
    angle='angle',
)

class MakeFits:
    """Write ``data`` to ``filename`` as a single-HDU FITS file carrying mirar header keys.

    All work happens on construction: build the HDU list, stamp the header,
    write it. Extra keyword arguments are forwarded to ``HDUList.writeto``.
    """

    def __init__(self, data, filename, obsclass="science", **kwargs):
        self.data = data
        self.filename = filename
        self.obsclass = obsclass

        self.hdul = fits.HDUList([fits.PrimaryHDU(data)])
        self.hdul.verify("fix")
        self.fix_header()
        self.save(**kwargs)

    def fix_header(self):
        """Set the mirar bookkeeping keys on the primary header."""
        header_updates = (
            (OBSCLASS_KEY, self.obsclass),
            # NOTE(review): the target is set to the obsclass — confirm intended.
            (TARGET_KEY, self.obsclass),
            (TIME_KEY, str(datetime.now())),
            (COADD_KEY, 1),
            (GAIN_KEY, 1),
            (PROC_HISTORY_KEY, ""),
            (PROC_FAIL_KEY, ""),
            (BASE_NAME_KEY, Path(self.filename).name),
        )
        primary_header = self.hdul[0].header
        for key, value in header_updates:
            primary_header[key] = value

    def save(self, **kwargs):
        """Write the HDU list to ``self.filename`` (kwargs go to ``writeto``)."""
        target = self.filename
        self.hdul.writeto(target, **kwargs)


def src_table(file, save=False, return_=True):
    """Unpickle ``file`` and return the DataFrame from its ``get_data()``.

    NOTE: ``pd.read_pickle`` can execute arbitrary code — only use trusted files.

    :param save: also write the table to ``<file>.csv`` in the same directory.
    :param return_: return the DataFrame (otherwise ``None``).
    """
    table = pd.read_pickle(file).get_data()
    if save:
        csv_path = os.path.join(os.path.dirname(file), f"{os.path.basename(file)}.csv")
        table.to_csv(csv_path, encoding="utf-8")
    return table if return_ else None


class Convert:
    """Hold a sky position given as RA/Dec in degrees as an astropy ``SkyCoord``."""

    def __init__(self, ra, dec):
        ra_deg = ra * u.deg
        dec_deg = dec * u.deg
        self.coords = SkyCoord(ra=ra_deg, dec=dec_deg)

    def get_coords(self):
        """Return the underlying ``SkyCoord``."""
        return self.coords

    def to_hms(self):
        """Return ``(ra.hms, dec.hms)``.

        NOTE(review): Dec is expressed in hours here; ``to_wcs`` uses ``dec.dms``.
        """
        ra_angle, dec_angle = self.coords.ra, self.coords.dec
        return (ra_angle.hms, dec_angle.hms)

    def to_wcs(self):
        """Return ``(ra.hms, dec.dms)`` — sexagesimal RA in hours, Dec in degrees."""
        ra_angle, dec_angle = self.coords.ra, self.coords.dec
        return (ra_angle.hms, dec_angle.dms)


def command(config, py="", pipeline=None, night=None):
    """Build the shell command that runs a mirar configuration.

    :param config: configuration name passed to ``-c``.
    :param py: suffix appended to ``python`` (e.g. ``"3.11"``) to select an interpreter.
    :param pipeline: pipeline name for ``-p``; defaults to the global ``PIPELINE``.
    :param night: night name for ``-n``; defaults to the global ``NIGHT``.
    :return: the command string.
    """
    # NOTE(review): PIPELINE / NIGHT are not defined in any visible cell; pass
    # them explicitly or define them before calling without arguments.
    if pipeline is None:
        pipeline = PIPELINE
    if night is None:
        night = NIGHT
    return f"python{py} -m mirar -p {pipeline} -n {night} -c {config} -m"


def run(config, *args, **kwargs):
    """Run the mirar command built by ``command(config, *args, **kwargs)`` in a shell.

    :return: the exit status from ``os.system``, so callers can detect a failed run.
    """
    return os.system(command(config, *args, **kwargs))

def get_prams():
    """Return the source of the most recent IPython input (``_ih[-1]``), i.e. the running cell."""
    input_history = _ih
    return input_history[-1]

def save_prams(path):
    """Write the text returned by ``get_prams()`` to ``path`` (overwriting it)."""
    with open(path, mode='w') as handle:
        handle.write(get_prams())

def test(py=""):
    """Run ``python<py> -m unittest discover tests/`` from the current directory."""
    unittest_cmd = rf"python{py} -m unittest discover tests/"
    os.system(unittest_cmd)

def plot_image(ax, data, cmap='grey', scale=True, colorbar=True, low_sigma=1, high_sigma=10):
    """Draw ``data`` on ``ax`` with ``imshow`` (origin lower, nearest interpolation).

    :param ax: matplotlib axes to draw on.
    :param data: 2-D image array; NaNs are ignored for the scaling statistics.
    :param cmap: colormap name.
    :param scale: clip the colour range to ``mean - low_sigma*std`` ..
        ``mean + high_sigma*std``; if false, matplotlib autoscales.
    :param colorbar: add a colorbar next to ``ax``.
    :param low_sigma: number of standard deviations below the mean for vmin.
    :param high_sigma: number of standard deviations above the mean for vmax.
    :return: the ``AxesImage`` returned by ``imshow``.
    """
    vmin = vmax = None
    if scale:
        # Statistics are only needed when clipping the colour range.
        mean, std = np.nanmean(data), np.nanstd(data)
        vmin = mean - low_sigma * std
        vmax = mean + high_sigma * std
    im = ax.imshow(
        data,
        interpolation="nearest",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        origin="lower",
    )
    if colorbar:
        ax.get_figure().colorbar(im, ax=ax, pad=0.005)
    return im

def peek(ax, coords, size):
    """Zoom ``ax`` to a square window of width ``2 * (size // 2)`` centred on ``coords`` = (x, y)."""
    half_width = size // 2
    cx, cy = coords[0], coords[1]
    ax.set_xlim([cx - half_width, cx + half_width])
    ax.set_ylim([cy - half_width, cy + half_width])
  
def get_pos(table, obj, keys):
    """Return ``np.array([x, y])`` (floats) for the 1-based row number ``obj`` of ``table``.

    ``keys`` maps 'x' and 'y' to the table's position column names.
    """
    record = table[obj - 1]
    return np.array([float(record[keys[axis]]) for axis in ('x', 'y')])

def mark_from_cat(ax, keys, cat_file=None, cat_table=None, hdu=2, save=False, color='red', condition=True, bound=False, bounds=None, annotate=True):
    """Overlay catalog sources on ``ax`` as ellipses, optionally labelled.

    :param ax: matplotlib axes to draw on (pixel coordinates).
    :param keys: maps 'x', 'y', 'a', 'b', 'angle' to catalog column names
        (e.g. SEX_SRC_KEYS, MIRAR_SRC_KEYS, PSF_PHOT_SRC_KEYS).
    :param cat_file: FITS catalog read from HDU ``hdu``; preferred over ``cat_table``.
    :param cat_table: in-memory QTable or DataFrame catalog.
    :param hdu: catalog HDU inside ``cat_file``.
    :param save: save the axes' figure as ``<cat_file>.png`` (needs ``cat_file``).
    :param color: colour of ellipses and labels.
    :param condition: if falsy, nothing is drawn and ``[]`` is returned.
    :param bound: only draw sources inside ``bounds`` (inclusive).
    :param bounds: ``((xmin, xmax), (ymin, ymax))``; required when ``bound``.
    :param annotate: label each ellipse with its NUMBER value (or row index).
    :return: list of ``[x, y]`` for every drawn source, or ``None`` when no
        catalog was supplied.
    :raises ValueError: if ``bound`` is true but ``bounds`` is None.
    """
    x_offset = 2
    y_offset = 2
    if cat_file:
        catfits = Fits(cat_file)
        catfits.read(hdu)
        cat = catfits.bin_table2(hdu)
    elif isinstance(cat_table, (QTable, pd.DataFrame)):
        cat = cat_table
    else:
        return
    if bound and bounds is None:
        raise ValueError("bound=True requires bounds=((xmin, xmax), (ymin, ymax))")
    x_max = ax.get_xlim()[-1]

    # Loop-invariant lookups, done once.
    has_number = False
    if annotate:
        try:
            has_number = 'NUMBER' in list(cat.columns)
        except Exception:
            print('could not access column names')
    if bound:
        xmin, xmax = bounds[0][0], bounds[0][1]
        ymin, ymax = bounds[1][0], bounds[1][1]

    pos_all = []
    if condition:
        for i, (x, y, a, b, theta) in enumerate(
            zip(
                cat[keys["x"]],
                cat[keys["y"]],
                cat[keys["a"]],
                cat[keys["b"]],
                cat[keys["angle"]],
            )
        ):
            if bound and not (xmin <= x <= xmax and ymin <= y <= ymax):
                continue
            marker = Ellipse(xy=(x, y), height=a, width=b, angle=theta - 90, color=color, fill=None)  # theta-90 to rotate wrt. x
            ax.add_patch(marker)
            if annotate:
                annotation = str(i)
                if has_number:
                    try:
                        annotation = cat['NUMBER'][i]
                    except Exception:
                        print('could not access column names')
                # Put the label on the left of sources near the right edge.
                if x + x_offset >= x_max - x_offset * 2:
                    xy = (x - x_offset * 2, y + y_offset)
                else:
                    xy = (x + x_offset, y + y_offset)
                ax.annotate(annotation, xy, color=color)
            pos_all.append([x, y])
    if save:
        if cat_file:
            ax.get_figure().savefig(f"{cat_file}.png", dpi=300)
        else:
            print('save=True needs cat_file to build an output path; figure not saved')
    return pos_all
       
 
def none_mask(lst):
    """Replace every ``None`` in ``lst`` with ``np.nan`` in place and return ``lst``.

    Uses an identity check (``is None``): ``== None`` on an element such as a
    numpy array compares element-wise and raises on the truth test.
    """
    for i, value in enumerate(lst):
        if value is None:
            lst[i] = np.nan
    return lst

In [4]:
def test_path(name):
    """Return the absolute ``Path`` of ``name`` under TEST_DIR.

    ``name`` is either a single path component or a tuple of components.
    """
    parts = name if isinstance(name, tuple) else (name,)
    return Path(os.path.abspath(os.path.join(TEST_DIR, *parts)))
    
def make_test_dirs():
    """Create every directory in OUTPUT_DIRS; existing directories are left alone."""
    for out_dir in list(OUTPUT_DIRS.values()):
        os.makedirs(out_dir, exist_ok=True)
        
def move_to_raw(path):
    """Copy ``path`` into RAW_DIR under the same file name (the source is kept, despite the name)."""
    source = Path(path)
    shutil.copyfile(source, os.path.join(RAW_DIR, source.name))
    
def prepare(src_dir, src_basenames):
    """Copy into RAW_DIR every entry of ``src_dir`` whose name contains any of ``src_basenames``.

    Each matching file is copied once, even if it matches several basenames
    (the previous nested loop copied it once per match).
    """
    for file in os.listdir(src_dir):
        if any(basename in file for basename in src_basenames):
            move_to_raw(os.path.join(src_dir, file))
                
def save_params(path):
    """Write the source of the most recent IPython input (``_ih[-1]``) to ``path``, overwriting it."""
    with open(path, mode='w') as out_file:
        out_file.write(_ih[-1])
        
DATA_DIR = os.path.join(ROOT_DIR, 'SampleData')

TEST_ID = '0'
TEST_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'test_' + TEST_ID))
OUTPUT_DIRS = {
    'BKG': test_path('background'),
    'DET': test_path('detection'),
    'PSF_MODEL': test_path('psf_model'),
    'PSF_PHOT': test_path('psf_phot'),
    'APER_PHOT': test_path('aper_phot'),
    'LOG': test_path('log'),
    'CONF': test_path('config'),
    'RES': test_path('results'),
    'CAT': test_path('catalogs'),
}
make_test_dirs()

RAW_DIR = os.path.join(TEST_DIR, 'raw')
os.makedirs(RAW_DIR, exist_ok=True)

LOGGING = True
# Timestamp without ':' or spaces so the file name is valid on every filesystem.
LOG_FILE = os.path.abspath(
    os.path.join(OUTPUT_DIRS['LOG'], f"test{TEST_ID}_{datetime.now():%Y-%m-%dT%H-%M-%S-%f}")
)
# Defined unconditionally: later cells call logger.debug() even when LOGGING is False.
logger = logging.getLogger(__name__)
if LOGGING:
    # NOTE: basicConfig is a no-op if the root logger already has handlers
    # (e.g. on re-running this cell), so logs keep going to the first LOG_FILE.
    logging.basicConfig(filename=LOG_FILE, level=logging.DEBUG)
    logging.getLogger('matplotlib').disabled = True

# Placeholder values for the mirar environment in this local test setup —
# do not put real credentials here.
ENV = {
    'RAW_DATA_DIR': RAW_DIR,
    'OUTPUT_DATA_DIR': TEST_DIR,
    'REF_IMG_DIR': '',
    'USE_WINTER_CACHE': 'false',
    'FRITZ_TOKEN': 'test',
    'KOWALSKI_TOKEN': 'test',
    'DB_USER': 'postgres',
    'DB_PWD': '',
}
os.environ.update(ENV)

In [20]:
# mirar
from mirar.pipelines.wifes_autoguider.wifes_autoguider_pipeline import WifesAutoguiderPipeline
from mirar.processors.astromatic.sextractor.background_subtractor import (
    SextractorBkgSubtractor,
)
from mirar.processors.astromatic.sextractor.sextractor import Sextractor
from mirar.processors.utils import (
    CustomImageBatchModifier,
    HeaderAnnotator,
    ImageBatcher,
    ImageDebatcher,
    ImageLoader,
    ImageSaver,
    ImageSelector,
    MEFLoader,
)
from mirar.processors.utils.header_annotate import (
    HeaderEditor,
    # SextractorHeaderCorrector,
)
from mirar.data import (
    Image,
    Dataset, 
    ImageBatch,
    SourceBatch,
    SourceTable
)
from mirar.io import open_raw_image
from mirar.paths import (
    BASE_NAME_KEY,
    COADD_KEY,
    GAIN_KEY,
    LATEST_SAVE_KEY,
    LATEST_WEIGHT_SAVE_KEY,
    OBSCLASS_KEY,
    PROC_FAIL_KEY,
    PROC_HISTORY_KEY,
    RAW_IMG_KEY,
    SATURATE_KEY,
    TARGET_KEY,
    TIME_KEY,
    DIFF_IMG_KEY,
    REF_IMG_KEY,
    SCI_IMG_KEY,
    XPOS_KEY,
    YPOS_KEY,
    NORM_PSFEX_KEY,
    core_fields,
    get_output_dir
)
from mirar.processors.astromatic import PSFex, Scamp
# from mirar.processors.photometry.psf_photometry import SourcePSFPhotometry
# from mirar.processors.photometry.aperture_photometry import SourceAperturePhotometry
from mirar.processors.sources import (
    SourceWriter
)
# from mirar.processors.sources.source_detector import (
#     SourceGenerator
# )
from mirar.processors.base_processor import (
    BaseImageProcessor,
    BaseSourceProcessor,
    BaseSourceGenerator,
)
# from mirar.processors.photometry.base_photometry import (
#     BaseSourcePhotometry,
# )
from mirar.utils.pipeline_visualisation import flowify
from mirar.io import (
    open_fits,
    save_to_path
)
from mirar.processors.base_processor import PrerequisiteError
from mirar.processors.utils.image_selector import select_from_images
from mirar.processors.photcal import PhotCalibrator

# photutils
from photutils.background import Background2D
from photutils.detection import DAOStarFinder
from photutils.psf import extract_stars
from photutils.psf import EPSFBuilder

import astropy


In [35]:
# OBSCLASS header value that marks WiFeS guider acquisition frames
# (set by load_wifes_guider_fits).
ACQ_KEY = 'acq'

def default_select_acq(
    images: ImageBatch,
) -> ImageBatch:
    """
    Return the images in a batch tagged as acquisition frames, i.e. (via
    select_from_images) those whose OBSCLASS_KEY header value is ACQ_KEY.

    :param images: set of images
    :return: subset of acquisition images
    """
    return select_from_images(images, key=OBSCLASS_KEY, target_values=ACQ_KEY)

class PhotutilsBkgSubtractor(BaseImageProcessor):
    """Subtract a photutils ``Background2D`` model from each science image in a batch.

    The background median and RMS median are written to the header. The
    background products can optionally be saved as FITS (``save_bkg``) and
    the whole ``Background2D`` object pickled (``cache``).

    NOTE(review): ``default_select_science``, ``BGMED_KEY``, ``BGRMSMED_KEY``,
    ``BGPATH_KEY`` and ``dump_object`` (and ``Callable``) are not imported in any
    visible cell. They must currently come from kernel state; import them
    explicitly so Restart & Run All works — TODO confirm their source modules.
    """

    base_key = "photutilsbkgsubtractor"

    def __init__(
        self,
        box_size=500,
        output_sub_dir='background',
        # String annotation: not evaluated at definition time.
        select_science_images: "Callable[[ImageBatch], ImageBatch]" = default_select_science,
        save_bkg: bool = False,
        cache: bool = False,
    ):
        """
        :param box_size: ``box_size`` passed to ``Background2D``.
        :param output_sub_dir: sub-directory (under the night dir) for saved products.
        :param select_science_images: picks the images to process from a batch.
        :param save_bkg: save the background map as ``<name>.background.fits``.
        :param cache: pickle the ``Background2D`` object and save the
            background, RMS and mesh maps as FITS.
        """
        super().__init__()
        self.box_size = box_size
        self.cache = cache
        self.output_sub_dir = output_sub_dir
        self.save_bkg = save_bkg
        self.select_science_images = select_science_images

    @staticmethod
    def _swap_fits_suffix(name, replacement):
        """Replace only the last 'fits' in ``name`` (the extension) with ``replacement``."""
        head, sep, tail = name.rpartition('fits')
        if not sep:
            return name
        return head + replacement + tail

    def _apply_to_images(
        self,
        batch: ImageBatch,
    ) -> ImageBatch:
        science_images = self.select_science_images(batch)
        # Loop-invariant: the output directory does not depend on the image.
        output_dir = get_output_dir(self.output_sub_dir, self.night_sub_dir)

        for image in science_images:
            data = image.get_data()
            header = image.get_header()
            background = Background2D(data=data, box_size=self.box_size)

            header[BGMED_KEY] = background.background_median
            header[BGRMSMED_KEY] = background.background_rms_median

            image.set_data(data - background.background)

            base_name = image[BASE_NAME_KEY]
            save_images = []

            if self.save_bkg:
                save_images = ['background']
                bkg_image_name = self._swap_fits_suffix(base_name, 'background.fits')
                header[BGPATH_KEY] = str(output_dir.joinpath(bkg_image_name))

            if self.cache:
                # Superset of the save_bkg list; intentionally overrides it.
                save_images = ['background', 'background_rms', 'background_mesh']
                dump_object(
                    data=background,
                    path=output_dir.joinpath(self._swap_fits_suffix(base_name, 'background.pkl')),
                )

            for attr in save_images:
                save_to_path(
                    # getattr instead of eval: same attribute lookup, no code execution.
                    data=getattr(background, attr),
                    header=image.header,
                    path=output_dir.joinpath(self._swap_fits_suffix(base_name, attr + '.fits')),
                    overwrite=True,
                )

            image.set_header(header)

        return batch

def load_wifes_guider_fits(
    path: str | Path
) -> tuple[np.ndarray, astropy.io.fits.Header]:
    """Open a WiFeS guider FITS file and add the header keys the mirar pipeline reads.

    Marks the frame as acquisition (``OBSCLASS_KEY = ACQ_KEY``), copies
    ``OBJECT`` into ``TARGET_KEY``, and initialises coadd/gain to 1 and the
    processing-history/failure keys to empty strings.

    :param path: path to the raw guider FITS file.
    :return: ``(data, header)`` as returned by ``open_fits``, with the header updated.
    :raises KeyError: if the header has no ``OBJECT`` card.
    """
    data, header = open_fits(path)
    header[OBSCLASS_KEY] = ACQ_KEY
    header[TARGET_KEY] = header['OBJECT']
    header[COADD_KEY] = 1
    header[GAIN_KEY] = 1
    # NOTE(review): literal 'CALSTEPS'; MakeFits uses PROC_HISTORY_KEY for the
    # same purpose — confirm they are the same key.
    header['CALSTEPS'] = ''
    header[PROC_FAIL_KEY] = ''
    return data, header

def load_wifes_guider_image(path: str | Path) -> Image:
    """Load ``path`` as a mirar ``Image`` via ``open_raw_image``, reading it with load_wifes_guider_fits."""
    fits_reader = load_wifes_guider_fits
    return open_raw_image(path, fits_reader)

# Stage 1: load the raw guider frames copied into RAW_DIR.
load = [
    ImageLoader(
        input_sub_dir=RAW_DIR,
        input_img_dir=TEST_DIR,
        load_image=load_wifes_guider_image,
    ),
]

# Processor list for the "test_config" configuration (currently just loading).
test_config = [*itertools.chain(load)]

pipeline = WifesAutoguiderPipeline(night=f"test_{TEST_ID}")
pipeline.night_sub_dir = TEST_DIR
pipeline.add_configuration(
    configuration_name="test_config",
    configuration=test_config,
)

# Record this cell's source (the processing parameters) next to the outputs.
save_params(Path(TEST_DIR) / 'all.param')

In [36]:
# getLogger returns the same logger as ``logger`` when it exists, and still works
# if the logging cell skipped defining it.
logging.getLogger(__name__).debug(f"\n\n{datetime.now()}\n\n")

# Stage all sample FITS frames into RAW_DIR.
for file in os.listdir(DATA_DIR):
    if file.endswith('.fits'):
        move_to_raw(os.path.join(DATA_DIR, file))

configuration = "test_config"
# Look the processor list up by name instead of eval()-ing a string.
processors = globals()[configuration]
flowify(processor_list=processors, output_path=Path(TEST_DIR).joinpath(configuration))
pipeline.reduce_images(Dataset([ImageBatch()]), catch_all_errors=False, selected_configurations=configuration)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 4/4 [00:00<00:00, 269.27it/s]


(<mirar.data.base_data.Dataset at 0x2a3e216d0>,
 <mirar.errors.error_stack.ErrorStack at 0x29d639f90>)

In [26]:
# Build the path from DATA_DIR instead of a user-specific absolute path.
# NOTE(review): assumes DATA_DIR is the same SampleData directory as
# /Users/astqx/Desktop/WiFeS/WiFeS_seeing/SampleData on the author's machine.
file = Fits(os.path.join(DATA_DIR, 'OBK-530784-WiFeS-Acq--UT20231008T093800-5.fits'))
file.header[0]

Filename: /Users/astqx/Desktop/WiFeS/WiFeS_seeing/SampleData/OBK-530784-WiFeS-Acq--UT20231008T093800-5.fits
No.    Name      Ver    Type      Cards   Dimensions   Format
  0  PRIMARY       1 PrimaryHDU      60   (1072, 1027)   int16 (rescales to uint16)   


SIMPLE  =                    T / file does conform to FITS standard             
BITPIX  =                   16 / number of bits per data pixel                  
NAXIS   =                    2 / number of data axes                            
NAXIS1  =                 1072 / length of data axis 1                          
NAXIS2  =                 1027 / length of data axis 2                          
EXTEND  =                    T / FITS dataset may contain extensions            
COMMENT   FITS (Flexible Image Transport System) format is defined in 'Astronomy
COMMENT   and Astrophysics', volume 376, page 359; bibcode: 2001A&A...376..359H 
BZERO   =                32768 / offset data range to that of unsigned short    
BSCALE  =                    1 / default scaling factor                         
OBSBLKID=               530784 / Observation Block ID                           
PROPID  =              2370179 / Proposal ID                                    
DATE-OBS= '2023-10-08T09:38: