In [1]:
import warnings
from copy import copy
from timeit import default_timer
from typing import Dict, List, Optional, Callable, Tuple, Union, TypeVar, Any
import numpy as np
from numpy.typing import ArrayLike, NDArray
from astropy.utils.exceptions import AstropyDeprecationWarning, AstropyUserWarning
import astropy.units as u
from astropy.table import Table
from astropy.nddata import NDData
from astropy.modeling import fitting
from astropy.wcs.utils import proj_plane_pixel_area
import multiprocessing

# Image file to process; the literal "FILENAME" appears to be a placeholder to fill in before running.
FILE = "FILENAME"
# Local input directory for the 2018rw test data. NOTE(review): hard-coded, machine-specific
# absolute path; the (misspelled) name is kept as-is because other cells may reference it.
LOCAL_ARCHIEVE = "/Users/kimphan/Desktop/flows_test/photometry/2018rw"
# Output directory for this test run (presumably; also a hard-coded absolute path).
FOLDER_OUTPUT = "/Users/kimphan/Desktop/flows_test/photometry_test_output"
    

# ### Import magnitudes.py ###

In [2]:
# Define the dependencies in order: first utilities.py, then filters.py, target.py and zeropoint.py, then magnitudes.py

# utilities.py 

In [3]:
"""
Utility functions
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
import hashlib
import multiprocessing
import logging
import sys
from argparse import Namespace
from typing import Optional

from scipy.interpolate import interp2d


def get_filehash(fname):
    """
    Return the lower-case hexadecimal SHA-1 digest of the file at ``fname``.

    The file is read in binary mode in 64 KiB chunks, so large files are not loaded
    into memory at once.

    Raises:
        Exception: if the digest is not 40 characters long (a defensive sanity check).
    """
    chunk_size = 65536
    digest = hashlib.sha1()
    with open(fname, 'rb') as handle:
        # read() returns b'' at end of file, which terminates the iterator.
        for chunk in iter(lambda: handle.read(chunk_size), b''):
            digest.update(chunk)

    sha1sum = digest.hexdigest().lower()
    if len(sha1sum) == 40:
        return sha1sum
    raise Exception("Invalid file hash")


def has_file_handler(logger):
    """Return True if at least one handler of ``logger`` is exactly a logging.FileHandler."""
    # Exact type comparison on purpose: FileHandler subclasses are not counted.
    return any(type(handler) is logging.FileHandler for handler in logger.handlers)


def has_stream_handler(logger):
    """Return True if at least one handler of ``logger`` is exactly a logging.StreamHandler."""
    # Exact type comparison: a FileHandler (a StreamHandler subclass) does not count.
    return any(type(handler) is logging.StreamHandler for handler in logger.handlers)


def remove_file_handlers(logger):
    """
    Detach and close every handler of ``logger`` whose exact type is logging.FileHandler.

    Other handler types, including FileHandler subclasses, are left attached.
    """
    # Iterate over a snapshot: removeHandler() mutates logger.handlers, and removing items
    # from a list while iterating it skips the element following each removed one.
    for handler in list(logger.handlers):
        if type(handler) is logging.FileHandler:
            logger.removeHandler(handler)
            # Release the underlying file descriptor; a detached handler is otherwise leaked.
            handler.close()


def create_logger(worker_name=None, log_level: Optional[int] = None, log_file: str = None):
    """
    Configure and return the multiprocessing package logger.

    Parameters:
        worker_name: if given, assigned as the name of the current process.
        log_level: if given, set as the logger's level.
        log_file: if given and no plain FileHandler is attached yet, a FileHandler
            opened with mode='w' and level INFO is added for this path.

    Returns:
        logging.Logger: the logger returned by multiprocessing.get_logger().

    A stdout StreamHandler is attached only when the logger has none yet, so repeated
    calls do not produce duplicated console output.
    """
    if worker_name is not None:
        multiprocessing.current_process().name = worker_name

    mp_logger = multiprocessing.get_logger()
    if log_level is not None:
        mp_logger.setLevel(log_level)

    # Console output: add once, avoiding duplicated messages.
    if not has_stream_handler(mp_logger):
        console = logging.StreamHandler(sys.stdout)
        console.setFormatter(
            logging.Formatter('[%(asctime)s| %(levelname)s| %(processName)s | %(module)s] %(message)s'))
        mp_logger.addHandler(console)

    # Optional log file: likewise only added once.
    if log_file is not None and not has_file_handler(mp_logger):
        to_file = logging.FileHandler(log_file, mode='w')
        to_file.setFormatter(
            logging.Formatter('%(asctime)s - %(levelname)s -%(module)s - %(message)s', "%Y-%m-%d %H:%M:%S"))
        to_file.setLevel(logging.INFO)
        mp_logger.addHandler(to_file)

    return mp_logger

def create_warning_logger(log_file: str):
    """
    Route Python warnings to a log file and return the 'py.warnings' logger.

    Enables logging.captureWarnings(True). If the 'py.warnings' logger has no plain
    FileHandler yet, one writing to ``log_file`` (mode='w', level INFO) is attached.

    Parameters:
        log_file: path of the warnings log file.

    Returns:
        logging.Logger: the 'py.warnings' logger.
    """
    logging.captureWarnings(True)
    logger_warn = logging.getLogger('py.warnings')

    # Build the handler only when it will actually be attached. Constructing a
    # FileHandler with mode='w' opens (and truncates) log_file immediately, so creating
    # it unconditionally wiped the file and leaked an open descriptor on repeat calls.
    if not has_file_handler(logger_warn):
        file_formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s -%(module)s - %(message)s', "%Y-%m-%d %H:%M:%S")
        file_handler = logging.FileHandler(log_file, mode='w')
        file_handler.setFormatter(file_formatter)
        file_handler.setLevel(logging.INFO)
        logger_warn.addHandler(file_handler)

    return logger_warn


def parse_log_level(args: Namespace):
    """
    Translate command-line verbosity flags into a logging level.

    ``args.quiet`` takes precedence over ``args.debug``; with neither set, INFO is used.
    """
    if args.quiet:
        return logging.WARNING
    if args.debug:
        return logging.DEBUG
    return logging.INFO

# filters.py 

In [4]:
### filters.py ###
from typing import Optional
#from .utilities import create_logger
# Module-level logger: the shared multiprocessing logger configured by create_logger().
logger = create_logger()
# Flows filter name -> name of the catalog column holding reference magnitudes in that
# band; get_reference_filter() translates a photometric filter through this mapping.
FILTERS = {
    'up': 'u_mag',
    'gp': 'g_mag',
    'rp': 'r_mag',
    'ip': 'i_mag',
    'zp': 'z_mag',
    'B': 'B_mag',
    'V': 'V_mag',
    'J': 'J_mag',
    'H': 'H_mag',
    'K': 'K_mag',
}

# Filter whose catalog column get_reference_filter() falls back to for unknown filters.
FALLBACK_FILTER = 'gp'


def get_reference_filter(photfilter: str) -> str:
    """
    Return the catalog column name for a photometric filter.

    Parameters:
        photfilter (str): flows filter name, i.e. a key of FILTERS.

    Returns:
        str: the matching catalog column. Unknown filters log a warning and fall back to
        the column of FALLBACK_FILTER.
    """
    column = FILTERS.get(photfilter, None)
    if column is not None:
        return column

    logger.warning(f"Could not find filter {photfilter} in catalogs. "
                   f"Using default {FALLBACK_FILTER} filter.")
    return FILTERS[FALLBACK_FILTER]

def clean_value(value: str) -> str:
    """
    Normalise a filter label: remove spaces, '-', '.' and '_' and lower-case the rest.
    """
    separators = str.maketrans('', '', ' -._')
    return value.translate(separators).lower()

# Filter labels as they may appear in image headers -> flows filter names.
# NOTE(review): some targets ('R', 'Y', and "R_Open" -> "r", "I_Open" -> "i") are not keys of
# FILTERS; confirm this is intended.
COMMON_FILTERS = {
    'B': 'B', 'V': 'V', 'R': 'R', 'g': 'gp', 'r': 'rp', 
    'i': 'ip', 'u': 'up', 'z': 'zp',
    'Ks': 'K', 'Hs': 'H', 'Js': 'J',
    'Bessel-B': 'B', 'Bessel-V': 'V', 'Bessell-V': 'V', 'SDSS-U': 'up',
    'SDSS-G': 'gp', 'SDSS-R': 'rp', 'SDSS-I': 'ip', 'SDSS-Z': 'zp',
    'PS1-u': 'up', 'PS1-g': 'gp', 'PS1-r': 'rp', 'PS1-i': 'ip', 'PS1-z': 'zp',
    'PS2-u': 'up', 'PS2-g': 'gp', 'PS2-r': 'rp', 'PS2-i': 'ip', 'PS2-z': 'zp',
    'PS-u': 'up', 'PS-g': 'gp', 'PS-r': 'rp', 'PS-i': 'ip', 'PS-z': 'zp',
    'Yc': 'Y', 'Jc': 'J', 'Hc': 'H', 'Kc': 'K',
    'Yo': 'Y', 'Jo': 'J', 'Ho': 'H', 'Ko': 'K',
    "J_Open": "J", "H_Open": "H", "K_Open": "K",
    "B_Open": "B", "V_Open": "V", "R_Open": "r", "I_Open": "i", "Y_Open": "Y",
    "g_Open": "gp", "r_Open": "rp", "i_Open": "ip", "z_Open": "zp", 'u_Open': 'up',
}


# COMMON_FILTERS re-keyed by clean_value(key) (separators removed, lower-cased).
# Keys that collide after cleaning keep the value that comes LAST in COMMON_FILTERS:
# 'R'/'r' both become 'r' -> 'rp', "R_Open"/"r_Open" both become 'ropen' -> 'rp', and
# "I_Open"/"i_Open" both become 'iopen' -> 'ip'.
COMMON_FILTERS_LOWER = {clean_value(key): value for key, value in COMMON_FILTERS.items()}
              
                  
def match_header_to_filter(header_dict: dict[str,str]) -> str:
    """
    Extract the flows filter name from an image header.

    The FILTER keyword is tried first. If it is missing, a placeholder ("", "NONE",
    "Clear"), or cannot be matched, every key containing "FILT" (case-insensitive) is
    tried in iteration order.

    Parameters:
        header_dict: mapping of header keywords to values.

    Returns:
        str: the flows filter name found by match_filter_to_flows.

    Raises:
        ValueError: if no header value could be matched to a flows filter.
    """
    bad_keys = ["", "NONE", "Clear"]

    def _match(value) -> Optional[str]:
        # Header values are not necessarily strings (FITS cards can hold ints, floats or
        # bools, e.g. a filter-wheel position). match_filter_to_flows calls str methods,
        # so non-string values are skipped instead of aborting the search with AttributeError.
        if not isinstance(value, str) or value in bad_keys:
            return None
        return match_filter_to_flows(value)

    filt = _match(header_dict.get("FILTER"))
    if filt is not None:
        return filt

    for key, value in header_dict.items():
        if "FILT" in str(key).upper():
            filt = _match(value)
            if filt is not None:
                return filt

    raise ValueError("Could not determine filter from header. Add FILTER keyword with a flows filter to header.")
                  

def match_filter_to_flows(header_filter: str) -> Optional[str]:
    """
    Map a header filter label to a flows filter name.

    Resolution order:
      1. ``header_filter`` is already a key of FILTERS -> returned unchanged.
      2. The lower-cased label is split on spaces and on dots; whichever split yields more
         tokens is used (the space split wins ties). A token equal to a FILTERS key
         (case-insensitively) is returned, upper-cased if the lower-case form is not itself a key.
      3. A token whose clean_value() form is a COMMON_FILTERS alias is translated through
         COMMON_FILTERS_LOWER.

    Returns None when nothing matches.
    """
    if header_filter in FILTERS:
        return header_filter

    lowered = header_filter.lower()
    by_space = lowered.split(' ')
    by_dot = lowered.split('.')
    tokens = by_space if len(by_space) >= len(by_dot) else by_dot

    # Pass 1: direct flows filter names in any letter case.
    flows_names = {str(name).lower() for name in FILTERS.keys()}
    for token in tokens:
        if token in flows_names:
            return token if FILTERS.get(token) is not None else token.upper()

    # Pass 2: known header aliases, compared with separators stripped.
    alias_names = {clean_value(str(name)) for name in COMMON_FILTERS.keys()}
    for token in tokens:
        cleaned = clean_value(token)
        if cleaned in alias_names:
            return COMMON_FILTERS_LOWER.get(cleaned)
    return None

# import target.py

In [5]:
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from numpy.typing import NDArray
from tendrils import api


@dataclass
class Target:
    """
    A photometry target: sky position plus optional Flows metadata and pixel position.

    ``ra``/``dec`` are in degrees (they are passed to SkyCoord with ``unit='deg'``). When
    ``coords`` is not supplied, it is built from them in the ICRS frame.
    ``pixel_column``/``pixel_row`` are filled in by ``calc_pixels``.
    """
    ra: float
    dec: float
    name: Optional[str] = None
    id: Optional[int] = None  # Target id from Flows database
    photfilter: Optional[str] = None  # Defined if target is associated with an image.
    coords: Optional[SkyCoord] = None
    pixel_column: Optional[int] = None
    pixel_row: Optional[int] = None

    def __post_init__(self):
        # Build sky coordinates from ra/dec unless the caller supplied them.
        if self.coords is None:
            self.coords = SkyCoord(ra=self.ra, dec=self.dec, unit='deg', frame='icrs')

    def calc_pixels(self, wcs: WCS, origin: int = 0) -> None:
        """
        Compute and store the target's pixel position on an image with the given WCS.

        Parameters:
            wcs: WCS of the image.
            origin: pixel origin passed to ``WCS.all_world2pix``. The default 0 gives
                0-based coordinates. This matches ``References.get_xy``, which computes the
                reference-star positions that this target is stacked onto by
                ``References.add_target``. A 1-based origin here would shift the target by
                one pixel on both axes relative to those stars. Pass ``origin=1``
                explicitly for 1-based (FITS-style) coordinates.
        """
        pixels = np.array(wcs.all_world2pix(self.ra, self.dec, origin)).T
        self._add_pixel_coordinates(pixel_pos=pixels)

    def _add_pixel_coordinates(self, pixel_column: Optional[int] = None, pixel_row: Optional[int] = None,
                               pixel_pos: Optional[NDArray] = None) -> None:
        """
        Add pixel coordinates to target.

        An explicit ``pixel_column``/``pixel_row`` pair takes precedence. If either is
        missing, ``pixel_pos`` is unpacked as (column, row).

        Raises:
            ValueError: if no complete pair and no ``pixel_pos`` are given.
        """
        if pixel_column is None or pixel_row is None:
            if pixel_pos is None:
                raise ValueError('Either pixel_column, pixel_row or pixel_pos must be provided.')
            pixel_column, pixel_row = pixel_pos

        self.pixel_column = pixel_column
        self.pixel_row = pixel_row

    def output_dict(self, starid: Optional[int] = 0) -> Dict:
        """
        Return target as output dictionary. starid = -1 means difference image.

        The keys are 'starid', 'ra', 'decl', 'pixel_column', 'pixel_row', and
        'used_for_epsf' (always False).
        """
        return {'starid': starid, 'ra': self.ra, 'decl': self.dec, 'pixel_column': self.pixel_column,
                'pixel_row': self.pixel_row, 'used_for_epsf': False}

    @classmethod
    def from_dict(cls, d: Dict) -> 'Target':
        """
        Create target from a dictionary with keys 'ra', 'decl', 'target_name', 'targetid'
        and 'photfilter'.
        """
        return cls(ra=d['ra'], dec=d['decl'], name=d['target_name'], id=d['targetid'], photfilter=d['photfilter'],)

    @classmethod
    def from_fid(cls, fid: int, datafile: Optional[Dict] = None) -> 'Target':
        """
        Create target from fileid.

        ``datafile`` is fetched from the API when not given (or when it is empty). The
        target record is merged with it, and datafile values win on duplicate keys.

        Raises:
            ValueError: if no datafile is found for ``fid``.
        """
        datafile = datafile or api.get_datafile(fid)
        if datafile is None:
            raise ValueError(f'No datafile found for fid={fid}')
        d = api.get_target(datafile['target_name']) | datafile
        return cls.from_dict(d)

    @classmethod
    def from_tid(cls, target_id: int) -> 'Target':
        """
        Create target from target id (no photometric filter is set).
        """
        target_pars = api.get_target(target_id)
        return cls(
            ra=target_pars['ra'], dec=target_pars['decl'],
            name=target_pars['target_name'], id=target_pars['targetid'])

# import zeropoint.py

In [6]:
"""
Provides functions for computing the zeropoint and its error.
Uses bootstrapping with sigma clipping as outlier rejection,
where the sigma is determined by the Chauvenet criteria. Also
allows for arbitrary outlier and fitting functions.

.. codeauthor:: Emir Karamehmetoglu <emir.k@phys.au.dk>
"""
from typing import List, Optional, Dict, Union, Callable

import numpy as np
from astropy.stats import bootstrap
from astropy.modeling import fitting
from numpy.typing import ArrayLike
from scipy.special import erfcinv


# Calculate sigma for sigma clipping using Chauvenet
def sigma_from_Chauvenet(Nsamples):
    '''Return the clipping threshold (in sigma) given by Chauvenet's criterion.

    For ``Nsamples`` points this is sqrt(2) * erfcinv(1 / (2 * Nsamples)).
    '''
    tail_probability = 1. / (2 * Nsamples)
    return erfcinv(tail_probability) * (2.) ** (1 / 2)


def bootstrap_outlier(x: ArrayLike, y: ArrayLike, yerr: ArrayLike, n: int = 500, model='None',
                      fitter: Union[Callable, str] = None, outlier='None', outlier_kwargs: Optional[Dict] = None,
                      summary: Union[Callable, str] = 'median', error: Union[Callable, str] = 'bootstrap',
                      parnames: Optional[List] = None, return_vals: bool = True):
    """
    Bootstrap a weighted model fit with outlier rejection and summarise its parameters.

    x = catalog mag, y = instrumental mag, yerr = instrumental error.
    Each of ``n`` resamples (drawn with replacement) is fitted with
    ``FittingWithOutlierRemoval(fitter(), outlier, **outlier_kwargs)`` using weights
    1 / yerr**2.

    Parameters:
        model: astropy model to fit, e.g. Linear1D.
        fitter: fitter class; LinearLSQFitter when None.
        outlier: outlier function, e.g. sigma_clip.
        outlier_kwargs: passed to ``outlier``; defaults to {'sigma': 3}.
        summary: summary statistic, 'median' -> np.nanmedian.
        error: spread statistic, 'bootstrap' -> np.nanstd.
        parnames: model parameter names to collect; defaults to ['intercept'].
        return_vals: True -> list of summaries (one per distinct parname);
            False -> dict with ``name`` and ``name + '_error'`` entries.
    """
    if summary == 'median':
        summary = np.nanmedian
    if error == 'bootstrap':
        error = np.nanstd
    if parnames is None:
        parnames = ['intercept']
    if outlier_kwargs is None:
        outlier_kwargs = {'sigma': 3}

    # Resampled row indices (one row per bootstrap), sorted within each row and cast to int.
    resamples = bootstrap(np.arange(len(x)), bootnum=n)
    resamples.sort()
    resamples = resamples.astype(int)

    if fitter is None:
        fitter = fitting.LinearLSQFitter
    clipped_fitter = fitting.FittingWithOutlierRemoval(fitter(), outlier, **outlier_kwargs)

    samples = {name: np.ones(len(resamples), dtype=np.float64) for name in parnames}
    for i, idx in enumerate(resamples):
        fitted, _ = clipped_fitter(model, x[idx], y[idx], weights=(1.0 / yerr[idx]) ** 2)
        fitted_names = np.array(fitted.param_names)
        for name in parnames:
            samples[name][i] = fitted.parameters[fitted_names == name][0]

    if return_vals:
        return [summary(samples[name]) for name in samples]

    result = {}
    for name in parnames:
        result[name] = summary(samples[name])
        result[name + '_error'] = error(samples[name])
    return result

# magnitudes.py

In [7]:
from typing import Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np
from astropy.modeling import models, fitting
from astropy.stats import sigma_clip
from astropy.table import Table
from bottleneck import nansum

#from filters import get_reference_filter
#from target import Target
#from utilities import create_logger
#from zeropoint import sigma_from_Chauvenet, bootstrap_outlier

logger = create_logger()
#logger = logging.getLogger(__name__)


def instrumental_mag(tab: Table, target: Target, make_fig: bool = False) -> Tuple[Table, Optional[plt.Figure],
                                                                                  Optional[plt.Axes]]:
    """
    Calibrate PSF photometry against catalog magnitudes with a fixed-slope zero-point fit.

    Rows with ``starid <= 0`` are the target and are never used for calibration. Every
    other row with finite instrumental and catalog magnitudes is a calibration star.
    The zero point is estimated twice:
      * a weighted least-squares fit with 3-sigma clipping (logged and stored as
        meta 'zp_error_weights' / used for 'zp_diff'), and
      * a bootstrap with Chauvenet-based sigma clipping (at least 1.5 sigma), which is
        the value applied to the magnitudes.

    Parameters:
        tab: photometry table with 'starid', 'flux_psf', 'flux_psf_error' and the catalog
            column returned by get_reference_filter(target.photfilter). Modified in place.
        target: provides ``photfilter`` (catalog band) and ``name`` (for error messages).
        make_fig: if True, also return a catalog vs. instrumental magnitude plot.

    Returns:
        (tab, fig, ax) with 'mag' and 'mag_error' columns and zp meta-data added to ``tab``;
        fig and ax are None when make_fig is False.

    Raises:
        RuntimeError: if a target flux is not finite, if there are no calibration stars,
            or if the magnitude of the first table row is not finite.
    """
    target_rows = tab['starid'] <= 0

    # Check that we got valid flux photometry:
    if np.any(~np.isfinite(tab[target_rows]['flux_psf'])) or np.any(~np.isfinite(tab[target_rows]['flux_psf_error'])):
        raise RuntimeError(f"Target:{target.name} flux is undefined.")

    # Convert PSF fluxes to magnitudes:
    mag_inst = -2.5 * np.log10(tab['flux_psf'])
    mag_inst_err = (2.5 / np.log(10)) * (tab['flux_psf_error'] / tab['flux_psf'])

    # Corresponding magnitudes in catalog:
    mag_catalog = tab[get_reference_filter(target.photfilter)]

    # Mask out things that should not be used in calibration:
    use_for_calibration = np.ones_like(mag_catalog, dtype='bool')
    use_for_calibration[target_rows] = False  # Do not use target for calibration
    use_for_calibration[~np.isfinite(mag_inst) | ~np.isfinite(mag_catalog)] = False

    # Fail before slicing and weighting empty arrays.
    if not np.any(use_for_calibration):
        raise RuntimeError("No calibration stars")

    # Just creating some short-hands:
    x = mag_catalog[use_for_calibration]
    y = mag_inst[use_for_calibration]
    yerr = mag_inst_err[use_for_calibration]
    weights = 1.0 / yerr ** 2

    # Fit linear function with fixed slope, using sigma-clipping:
    model = models.Linear1D(slope=1, fixed={'slope': True})
    fitter = fitting.FittingWithOutlierRemoval(fitting.LinearLSQFitter(), sigma_clip, sigma=3.0)
    best_fit, sigma_clipped = fitter(model, x, y, weights=weights)

    # Extract zero-point and estimate its error using a single weighted fit:
    # I don't know why there is not an error-estimate attached directly to the Parameter?
    zp = -1 * best_fit.intercept.value  # Negative, because that is the way zeropoints are usually defined

    weights[sigma_clipped] = 0  # Trick to make following expression simpler
    n_weights = len(weights.nonzero()[0])
    if n_weights > 1:
        zp_error = np.sqrt(n_weights * nansum(weights * (y - best_fit(x)) ** 2) / nansum(weights) / (n_weights - 1))
    else:
        # np.NaN was removed in NumPy 2.0; np.nan works on every NumPy version.
        zp_error = np.nan
    logger.info('Leastsquare ZP = %.3f, ZP_error = %.3f', zp, zp_error)

    # Determine sigma clipping sigma according to Chauvenet method
    # But don't allow less than sigma = sigmamin, setting to 1.5 for now.
    # Should maybe be 2?
    sigmamin = 1.5
    sig_chauv = sigma_from_Chauvenet(len(x))
    sig_chauv = sig_chauv if sig_chauv >= sigmamin else sigmamin

    # Extract zero point and error using bootstrap method
    nboot = 1000
    logger.info('Running bootstrap with sigma = %.2f and n = %d', sig_chauv, nboot)
    pars = bootstrap_outlier(x, y, yerr, n=nboot, model=model, fitter=fitting.LinearLSQFitter, outlier=sigma_clip,
                             outlier_kwargs={'sigma': sig_chauv}, summary='median', error='bootstrap',
                             return_vals=False)

    zp_bs = pars['intercept'] * -1.0
    zp_error_bs = pars['intercept_error']

    logger.info('Bootstrapped ZP = %.3f, ZP_error = %.3f', zp_bs, zp_error_bs)

    # Check that difference is not large
    zp_diff = 0.4
    if np.abs(zp_bs - zp) >= zp_diff:
        logger.warning("Bootstrap and weighted LSQ ZPs differ by %.2f, "
                       "which is more than the allowed %.2f mag.", np.abs(zp_bs - zp), zp_diff)

    # Add calibrated magnitudes to the photometry table:
    tab['mag'] = mag_inst + zp_bs
    tab['mag_error'] = np.sqrt(mag_inst_err ** 2 + zp_error_bs ** 2)

    # Check that we got valid magnitude photometry.
    # NOTE(review): only row 0 is checked, while the flux check above covers all target rows.
    if not np.isfinite(tab[0]['mag']) or not np.isfinite(tab[0]['mag_error']):
        raise RuntimeError(f"Target:{target.name} magnitude is undefined.")

    # Update Meta-data:
    tab.meta['zp'] = zp_bs
    tab.meta['zp_error'] = zp_error_bs
    tab.meta['zp_diff'] = np.abs(zp_bs - zp)
    tab.meta['zp_error_weights'] = zp_error

    # Plot:
    if make_fig:
        mag_fig, mag_ax = plt.subplots(1, 1)
        mag_ax.errorbar(x, y, yerr=yerr, fmt='k.')
        mag_ax.scatter(x[sigma_clipped], y[sigma_clipped], marker='x', c='r')
        mag_ax.plot(x, best_fit(x), color='g', linewidth=3)
        mag_ax.set_xlabel('Catalog magnitude')
        mag_ax.set_ylabel('Instrumental magnitude')

        return tab, mag_fig, mag_ax
    return tab, None, None


### import result_model###

# import image.py

In [8]:
from __future__ import annotations
from enum import Enum
import numpy as np
from numpy.typing import NDArray
from dataclasses import dataclass
import warnings
from typing import Union
from astropy.time import Time
from astropy.wcs import WCS, FITSFixedWarning
from typing import Tuple,  Dict, Any, Optional, TypeGuard
#from .utilities import create_logger
logger = create_logger()

@dataclass
class InstrumentDefaults:
    """
    Per-instrument starting values for source fitting, in arcseconds.

    Fields:
        radius: fitting radius (default 10).
        fwhm: initial FWHM guess (default 6.0).
        fwhm_min: smallest accepted FWHM (default 3.5).
        fwhm_max: largest accepted FWHM (default 18.0).

    The defaults coincide with those of force_reject_g2d.
    """
    radius: float = 10
    fwhm: float = 6.0  # initial guess
    fwhm_min: float = 3.5
    fwhm_max: float = 18.0


@dataclass
class FlowsImage:
    """
    An image with its header, derived metadata, and a boolean bad-pixel mask.

    ``mask`` is True for pixels to exclude. On construction, the WCS is always built from
    ``header`` (any ``wcs`` argument is replaced). The mask is then extended with non-finite
    pixels and, when ``peakmax`` is set, with pixels at or above ``peakmax``. A mask passed
    in is updated in place.
    """
    image: np.ndarray
    header: Dict
    mask: Optional[np.ndarray] = None
    peakmax: Optional[float] = None
    exptime: Optional[float] = None
    instrument_defaults: Optional[InstrumentDefaults] = None
    site: Optional[Dict[str, Any]] = None
    obstime: Optional[Time] = None
    photfilter: Optional[str] = None
    wcs: Optional[WCS] = None
    fwhm: Optional[float] = None
    fid: Optional[int] = None  # FileID of this image
    template_fid: Optional[int] = None  # Template file ID if exists in same band.

    clean: Optional[np.ma.MaskedArray] = None
    subclean: Optional[np.ma.MaskedArray] = None
    error: Optional[np.ma.MaskedArray] = None

    def __post_init__(self) -> None:
        self.shape = self.image.shape
        self.wcs = self.create_wcs()
        # Create mask
        self.initialize_mask()

    def initialize_mask(self) -> None:
        """(Re)apply the finite and non-linearity masking to the current mask."""
        self.update_mask(self.mask)

    def check_finite(self) -> None:
        """Mask pixels whose value is NaN or infinite."""
        if self.ensure_mask(self.mask):
            self.mask |= ~np.isfinite(self.image)

    def mask_non_linear(self) -> None:
        """Mask pixels at or above ``peakmax``; does nothing when peakmax is None."""
        if self.peakmax is None:
            return
        if self.ensure_mask(self.mask):
            self.mask |= self.image >= self.peakmax

    def ensure_mask(self, mask: Optional[np.ndarray]) -> TypeGuard[NDArray[np.bool_]]:
        """
        Create an all-False mask on ``self.mask`` when ``mask`` is None.

        Always returns True; the TypeGuard return type narrows ``self.mask`` for type checkers.
        """
        if mask is None:
            self.mask = np.zeros_like(self.image, dtype='bool')
        return True

    def update_mask(self, mask) -> None:
        """Replace the mask, then add non-finite and non-linear pixels to it."""
        self.mask = mask
        self.check_finite()
        self.mask_non_linear()

    def create_wcs(self) -> WCS:
        """Build the WCS from the header, silencing astropy's FITSFixedWarning."""
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=FITSFixedWarning)
            return WCS(header=self.header, relax=True)

    def create_masked_image(self) -> None:
        """Warning: this is destructive and will overwrite image data setting masked values to NaN"""
        # np.NaN was removed in NumPy 2.0; np.nan works on every NumPy version.
        self.image[self.mask] = np.nan
        self.clean = np.ma.masked_array(data=self.image, mask=self.mask, copy=False)

    def set_edge_rows_to_value(self, y: Optional[Tuple[int, ...]] = None,
                               value: Union[int, float, np.float64] = 0) -> None:
        """Set every row index in ``y`` to ``value``; does nothing when ``y`` is None."""
        if y is None:
            # Previously `pass`, which fell through to iterating None (TypeError).
            return
        for row in y:
            self.image[row] = value

    def set_edge_columns_to_value(self, x: Optional[Tuple[int, ...]] = None,
                                  value: Union[int, float, np.float64] = 0) -> None:
        """Set every column index in ``x`` to ``value``; does nothing when ``x`` is None."""
        if x is None:
            # Previously `pass`, which fell through to iterating None (TypeError).
            return
        for col in x:
            self.image[:, col] = value

    @staticmethod
    def get_edge_mask(img: np.ndarray, value: Union[int, float, np.float64] = 0):
        """
        Create boolean mask of given value near edge of image.

        Parameters:
            img (ndarray): image with values for masking.
            value (float): Value to detect near edge. Default=0.

        Returns:
            ndarray: Pixel mask with given values on the edge of image.

        .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
        """

        mask1 = (img == value)
        mask = np.zeros_like(img, dtype='bool')

        # Mask entire rows and columns which are only the value:
        mask[np.all(mask1, axis=1), :] = True
        mask[:, np.all(mask1, axis=0)] = True

        # Detect "uneven" edges column-wise in image:
        a = np.argmin(mask1, axis=0)
        b = np.argmin(np.flipud(mask1), axis=0)
        for col in range(img.shape[1]):
            if mask1[0, col]:
                mask[:a[col], col] = True
            if mask1[-1, col]:
                mask[-b[col]:, col] = True

        # Detect "uneven" edges row-wise in image:
        a = np.argmin(mask1, axis=1)
        b = np.argmin(np.fliplr(mask1), axis=1)
        for row in range(img.shape[0]):
            if mask1[row, 0]:
                mask[row, :a[row]] = True
            if mask1[row, -1]:
                mask[row, -b[row]:] = True

        return mask

    def apply_edge_mask(self, y: Optional[Tuple[int, ...]] = None, x: Optional[Tuple[int, ...]] = None,
                        apply_existing_mask_first: bool = False):
        """
        Masks given rows and columns of image but will replace the current mask! Set apply_existing_mask_first to True
        if the current mask should be kept.
        :param y: Tuple[int] of rows to mask
        :param x: Tuple[int] of columns to mask
        :param apply_existing_mask_first: Whether to apply the existing mask to image first, before overwriting mask.
        :return: None
        """
        if y is None and x is None:
            logger.debug("(y,x) was None when applying edge mask. Edge was not actually masked.")

        if apply_existing_mask_first:
            self.create_masked_image()

        if y is not None:
            self.set_edge_rows_to_value(y=y)

        if x is not None:
            self.set_edge_columns_to_value(x=x)

        self.mask = self.get_edge_mask(self.image)
        self.create_masked_image()

class ImageType(Enum):
    """Which kind of image is being processed."""

    raw = 'raw'    # the image itself (presumably not template-subtracted)
    diff = 'diff'  # a difference image

# result_model.py 

In [9]:
from astropy.table import Table
import astropy.units as u
#from .image import FlowsImage
#from .utilities import create_logger
logger = create_logger()
class ResultsTable(Table):
    """
    Astropy Table holding the final photometry results (reference stars plus target rows).
    """

    def add_column_descriptions(self):
        """Attach descriptions and units to the result columns (they must already exist)."""
        # Descriptions of columns:
        self['used_for_epsf'].description = 'Was object used for building ePSF?'
        self['mag'].description = 'Measured magnitude'
        self['mag'].unit = u.mag
        self['mag_error'].description = 'Error on measured magnitude'
        self['mag_error'].unit = u.mag
        self['flux_aperture'].description = 'Measured flux using aperture photometry'
        self['flux_aperture'].unit = u.count / u.second
        self['flux_aperture_error'].description = 'Error on measured flux using aperture photometry'
        self['flux_aperture_error'].unit = u.count / u.second
        self['flux_psf'].description = 'Measured flux using PSF photometry'
        self['flux_psf'].unit = u.count / u.second
        self['flux_psf_error'].description = 'Error on measured flux using PSF photometry'
        self['flux_psf_error'].unit = u.count / u.second
        self['pixel_column'].description = 'Location on image pixel columns'
        self['pixel_column'].unit = u.pixel
        self['pixel_row'].description = 'Location on image pixel rows'
        self['pixel_row'].unit = u.pixel
        self['pixel_column_psf_fit'].description = 'Measured location on image pixel columns from PSF photometry'
        self['pixel_column_psf_fit'].unit = u.pixel
        self['pixel_column_psf_fit_error'].description = 'Error on measured location on image pixel columns from PSF ' \
                                                         'photometry'
        self['pixel_column_psf_fit_error'].unit = u.pixel
        self['pixel_row_psf_fit'].description = 'Measured location on image pixel rows from PSF photometry'
        self['pixel_row_psf_fit'].unit = u.pixel
        self['pixel_row_psf_fit_error'].description = 'Error on measured location on image pixel rows from PSF ' \
                                                      'photometry'
        self['pixel_row_psf_fit_error'].unit = u.pixel

    def add_metadata(self, tab):
        """Not implemented; the commented code sketches the intended meta-data keys."""
        raise NotImplementedError()
        # # Meta-data:
        # tab.meta['fileid'] = fileid
        # tab.meta['target_name'] = target_name
        # tab.meta['version'] = __version__
        # tab.meta['template'] = None if datafile.get('template') is None else datafile['template']['fileid']
        # tab.meta['diffimg'] = None if datafile.get('diffimg') is None else datafile['diffimg']['fileid']
        # tab.meta['photfilter'] = photfilter
        # tab.meta['fwhm'] = fwhm * u.pixel
        # tab.meta['pixel_scale'] = pixel_scale * u.arcsec / u.pixel
        # tab.meta['seeing'] = (fwhm * pixel_scale) * u.arcsec
        # tab.meta['obstime-bmjd'] = float(image.obstime.mjd)
        # tab.meta['zp'] = zp_bs
        # tab.meta['zp_error'] = zp_error_bs
        # tab.meta['zp_diff'] = np.abs(zp_bs - zp)
        # tab.meta['zp_error_weights'] = zp_error
        # tab.meta['head_wcs'] = head_wcs  # TODO: Are these really useful?
        # tab.meta['used_wcs'] = used_wcs  # TODO: Are these really useful?

    @classmethod
    def make_results_table(cls, ref_table: Table, apphot_tbl: Table, psfphot_tbl: Table, image: FlowsImage):
        """
        Combine the reference table with aperture and PSF photometry results.

        Fluxes and their errors are divided by ``image.exptime``. PSF-fit positions and
        their uncertainties are copied from ``psfphot_tbl``. If it has no 'flux_unc'
        column, a 4% error is assumed and added to it.
        """
        results_table = cls(ref_table)
        # NOTE(review): appending a row makes results_table longer than apphot_tbl by two in
        # this branch, and add_row(0) passes a scalar; confirm the intended condition.
        if len(ref_table) - len(apphot_tbl) == 1:
            results_table.add_row(0)

        psfphot_tbl = ResultsTable.verify_uncertainty_column(psfphot_tbl)

        results_table['flux_aperture'] = apphot_tbl['flux_aperture'] / image.exptime
        results_table['flux_aperture_error'] = apphot_tbl['flux_aperture_error'] / image.exptime
        results_table['flux_psf'] = psfphot_tbl['flux_fit'] / image.exptime
        results_table['flux_psf_error'] = psfphot_tbl['flux_unc'] / image.exptime
        results_table['pixel_column_psf_fit'] = psfphot_tbl['x_fit']
        results_table['pixel_row_psf_fit'] = psfphot_tbl['y_fit']
        results_table['pixel_column_psf_fit_error'] = psfphot_tbl['x_0_unc']
        results_table['pixel_row_psf_fit_error'] = psfphot_tbl['y_0_unc']

        return results_table

    @staticmethod
    def verify_uncertainty_column(tab):
        """
        Ensure ``tab`` has a 'flux_unc' column, and return ``tab``.

        When the column is missing, 4% of 'flux_fit' is used as the uncertainty and added
        to ``tab`` in place, with a warning.
        """
        if "flux_unc" in tab.colnames:
            return tab
        tab['flux_unc'] = tab['flux_fit'] * 0.04  # Assume 4% errors
        logger.warning("Flux uncertainty not found from PSF fit, assuming 4% error.")
        # Return the table on this path too. make_results_table reassigns the result, so
        # returning None here broke the following psfphot_tbl['flux_fit'] lookup.
        return tab

In [10]:
#from .magnitudes import instrumental_mag (DONE ABOVE)
#from .result_model import ResultsTable (DONE ABOVE)

warnings.simplefilter('ignore', category=AstropyDeprecationWarning)
from photutils import CircularAperture, CircularAnnulus, aperture_photometry  # noqa: E402
from photutils.psf import EPSFFitter, BasicPSFPhotometry, DAOGroup, extract_stars  # noqa: E402
from photutils.background import MedianBackground  # noqa: E402
import photutils  # noqa: E402

#from .reference_cleaning import References, ReferenceCleaner, InitGuess  # noqa: E402
#from .plots import plt, plot_image  # noqa: E402
#from .version import get_version  # noqa: E402
#from .image import FlowsImage  # noqa: E402
#from .coordinatematch import correct_wcs  # noqa: E402
#from .epsfbuilder import FlowsEPSFBuilder, verify_epsf  # noqa: E402
#from .fileio import DirectoryProtocol, IOManager  # noqa: E402
#from .target import Target  # noqa: E402
#from .background import FlowsBackground  # noqa: E402
#from .utilities import create_logger  # noqa: E402

### The commented-out modules above are defined directly in this notebook's cells instead of being imported ###

# reference_cleaning.py

In [11]:
"""
Clean bad source extraction, find and correct WCS.

.. codeauthor:: Emir Karamehmetoglu <emir.k@phys.au.dk>
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
from typing import Dict, Optional, TypeVar, Tuple, Union
from dataclasses import dataclass
import warnings

import numpy as np
from numpy.typing import ArrayLike, NDArray
import astroalign as aa
from astropy.coordinates import SkyCoord
from astropy import wcs
from astropy.stats import sigma_clip, gaussian_fwhm_to_sigma
from astropy.modeling import models, fitting
from astropy.time import Time
from astropy.utils.exceptions import ErfaWarning
import astropy.units as u
from astropy.table import Table
from copy import deepcopy
from bottleneck import nanmedian, nansum, nanmean, replace
from scipy.spatial import KDTree
import pandas as pd  # TODO: Convert to pure numpy implementation
import sep

#from .image import FlowsImage
#from .target import Target
#from .utilities import create_logger
logger = create_logger()

RefTable = TypeVar('RefTable', Dict, ArrayLike, Table)


class MinStarError(RuntimeError):
    """RuntimeError subclass; by its name, presumably raised when too few stars are available (TODO confirm against callers)."""


@dataclass
class References:
    """
    A set of reference sources: a table plus optional sky coordinates, selection mask,
    and pixel positions.

    ``table`` must support column access by name (e.g. 'ra', 'decl', 'pm_ra', 'pm_dec')
    and ``insert_row``. It is an astropy Table where this file constructs References.
    """
    table: RefTable
    coords: Optional[SkyCoord] = None
    mask: Optional[np.ndarray] = None  # positive mask ie True where we want it.
    xy: Optional[RefTable] = None

    def replace_nans_pm(self) -> None:
        """Set NaN proper motions in the 'pm_ra' and 'pm_dec' columns to 0, in place."""
        # np.NaN was removed in NumPy 2.0; np.nan works on every NumPy version.
        replace(self.table['pm_ra'], np.nan, 0.)
        replace(self.table['pm_dec'], np.nan, 0.)

    def make_sky_coords(self, reference_time: float = 2015.5) -> None:
        """
        Build ``coords`` (ICRS) from the 'ra'/'decl' columns (degrees) and proper motions,
        at epoch ``reference_time`` (decimal year). NaN proper motions are zeroed first.
        """
        self.replace_nans_pm()
        self.coords = SkyCoord(ra=self.table['ra'], dec=self.table['decl'], pm_ra_cosdec=self.table['pm_ra'],
                               pm_dec=self.table['pm_dec'], unit='deg', frame='icrs',
                               obstime=Time(reference_time, format='decimalyear'))

    def propagate(self, obstime: Time) -> None:
        """
        Move ``coords`` to ``obstime`` using their proper motions (ErfaWarning silenced).

        Raises:
            AttributeError: if ``coords`` has not been built yet.
        """
        if self.coords is None:
            raise AttributeError("References.coords is not defined.")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ErfaWarning)
            self.coords = self.coords.apply_space_motion(new_obstime=obstime)

    def copy(self) -> 'References':
        """Return a shallow copy: a new References sharing table, coords, mask and xy."""
        # Previously passed self.__dataclass_fields__ (the dataclass Field metadata) as
        # `table`, so the copy did not hold this object's data.
        return References(table=self.table, coords=self.coords, mask=self.mask, xy=self.xy)

    @property
    def masked(self, mask: Optional[np.ndarray] = None) -> "References":
        """
        A copy with every non-None field indexed by ``self.mask``.

        Because this is a property, ``mask`` can never actually be passed; ``self.mask``
        must already be set.

        Raises:
            AttributeError: if no mask is defined.
        """
        if self.mask is None:
            if mask is None:
                raise AttributeError("No mask defined.")
            self.mask = mask

        masked_refs = self.copy()
        for name in self.__dataclass_fields__:
            if getattr(self, name) is not None:
                setattr(masked_refs, name, getattr(self, name)[self.mask])
        return masked_refs

    def get_xy(self, img_wcs: wcs.WCS) -> None:
        """get pixel coordinates of reference stars (0-based, origin=0)"""
        self.xy = img_wcs.all_world2pix(list(zip(self.coords.ra.deg, self.coords.dec.deg)), 0)

    def make_pixel_columns(self, column_name: str ='pixel_column', row_name: str = 'pixel_row') -> None:
        """Write the two components of ``xy`` into table columns ``column_name`` and ``row_name``."""
        self.table[column_name], self.table[row_name] = list(map(np.array, zip(*self.xy)))

    def _prepend_row(self, row: dict) -> None:
        """Insert ``row`` as the first row of the table."""
        self.table.insert_row(0, row)

    def add_target(self, target: Target, starid: int = 0) -> None:
        """
        Prepend the target to the table and, when its pixel position is known, to ``xy``.
        """
        self._prepend_row(target.output_dict(starid=starid))
        # Compare against None: a truthiness test silently skipped targets at pixel 0.
        if target.pixel_row is not None and target.pixel_column is not None:
            self.xy = np.vstack(((target.pixel_column, target.pixel_row), self.xy))


def use_sep(image: FlowsImage, tries: int = 5, thresh: float = 5.):
    """
    Extract sources from ``image`` with SEP and wrap them as ``References``.

    The background is estimated once (respecting ``image.mask``) and subtracted before
    extraction. If ``sep.extract`` raises, extraction is retried up to ``tries`` more times.
    The threshold is raised to ``(thresh + 3) * 2`` before each retry.

    Parameters:
        image: Image providing the pixel data (``image.image``) and ``image.mask``.
        tries: Number of retries allowed after a failed extraction.
        thresh: Initial detection threshold. Because ``err`` is given, SEP treats it as a
            multiple of the global background RMS.

    Returns:
        References: ``table`` holds the SEP catalogue and ``xy`` its ``x``/``y`` columns.
        ``pixel_column``/``pixel_row`` columns are added to the table.

    Raises:
        Exception: The last ``sep.extract`` error once all retries are used up.
    """
    # The background does not depend on the threshold, so compute it once for all attempts.
    sep_background = sep.Background(image.image, mask=image.mask)
    while True:
        try:
            objects = sep.extract(image.image - sep_background, thresh=thresh, err=sep_background.globalrms,
                                  mask=image.mask, deblend_cont=0.1, minarea=9, clean_param=2.0)
        except Exception:
            # KeyboardInterrupt is not an Exception subclass, so it still propagates immediately.
            if tries <= 0:
                raise
            logger.warning("SEP failed, trying again...")
            tries -= 1
            thresh = (thresh + 3) * 2  # raise the threshold so fewer pixels are extracted
        else:
            break

    sep_references = References(table=Table(objects))
    sep_references.xy = sep_references.table[['x', 'y']]
    sep_references.make_pixel_columns()
    return sep_references

def force_reject_g2d(xarray: ArrayLike, yarray: ArrayLike, image: Union[NDArray, np.ma.MaskedArray],
                     rsq_min: float = 0.5, radius: float = 10, fwhm_guess: float = 6.0, fwhm_min: float = 3.5,
                     fwhm_max: float = 18.0) -> Tuple[np.ma.MaskedArray, ...]:
    """
    Fit a 2D Gaussian around each (x, y) position and flag stars that fit poorly.

    Each cutout is ``2 * radius + 1`` pixels wide, clipped at the image border. It is
    background-subtracted with the median of its edge pixels and normalised to its maximum before
    fitting. Every returned array has one entry per input position, in input order, so indices stay
    aligned with the caller's catalogue.

    :param xarray: x-coordinates (columns) of the stars
    :type xarray: ArrayLike
    :param yarray: y-coordinates (rows) of the stars
    :type yarray: ArrayLike
    :param image: the image to be processed
    :type image: Union[NDArray, np.ma.MaskedArray]
    :param rsq_min: The minimum r-squared value for a star to be considered good
    :type rsq_min: float
    :param radius: The radius of the box around the star to fit, defaults to 10. Used in slicing, so it
        must be an integer value.
    :type radius: float (optional)
    :param fwhm_guess: The initial guess for the FWHM of the star
    :type fwhm_guess: float
    :param fwhm_min: The minimum FWHM allowed for a star
    :type fwhm_min: float
    :param fwhm_max: The maximum FWHM allowed for a star to be considered good
    :type fwhm_max: float
    :return: masked_fwhms, masked_xys, mask, masked_rsqs
        - masked_fwhms: (N, 2) fitted FWHM per star (both columns hold the same value), masked where not finite.
        - masked_xys: (N, 2) fitted (x, y) image positions, masked where the fit failed or R^2 was rejected.
        - mask: (N,) boolean array, True where rsq_min <= R^2 < 1.
        - masked_rsqs: (N,) R^2 of each fit, masked where not finite.
    """
    # Set up 2D Gaussian model for fitting to reference stars:
    g2d = models.Gaussian2D(amplitude=1.0, x_mean=radius, y_mean=radius, x_stddev=fwhm_guess * gaussian_fwhm_to_sigma)
    g2d.amplitude.bounds = (0.1, 2.0)
    g2d.x_mean.bounds = (0.5 * radius, 1.5 * radius)
    g2d.y_mean.bounds = (0.5 * radius, 1.5 * radius)
    g2d.x_stddev.bounds = (fwhm_min * gaussian_fwhm_to_sigma, fwhm_max * gaussian_fwhm_to_sigma)
    g2d.y_stddev.tied = lambda model: model.x_stddev
    g2d.theta.fixed = True
    gfitter = fitting.LevMarLSQFitter()

    # Stars reject
    N = len(xarray)
    fwhms = np.full((N, 2), np.NaN)
    xys = np.full((N, 2), np.NaN)
    rsqs = np.full(N, np.NaN)
    for i, (x, y) in enumerate(zip(xarray, yarray)):
        x = int(np.round(x))
        y = int(np.round(y))
        xmin = max(x - radius, 0)
        xmax = min(x + radius + 1, image.shape[1])
        ymin = max(y - radius, 0)
        ymax = min(y + radius + 1, image.shape[0])

        curr_star = deepcopy(image[ymin:ymax, xmin:xmax])

        edge = np.zeros_like(curr_star, dtype='bool')
        edge[(0, -1), :] = True
        edge[:, (0, -1)] = True
        curr_star -= nanmedian(curr_star[edge])
        curr_star /= np.nanmax(curr_star)

        ypos, xpos = np.indices(curr_star.shape)
        nan_filter = np.isfinite(curr_star)
        if len(curr_star[nan_filter]) < 3:  # Not enough pixels to fit
            logger.debug(f"Not enough pixels to fit star, curr_star[nan_filter]:{curr_star[nan_filter]}")
            continue  # rsqs/fwhms/xys already hold NaN for this star

        # Star position inside the cutout. This equals ``radius`` unless the cutout was clipped
        # at the left/bottom image border; only then is the shared model's centre guess and
        # bounds moved onto the star.
        x_in = x - xmin
        y_in = y - ymin
        model = g2d
        if x_in != radius or y_in != radius:
            model = g2d.copy()
            model.x_mean = x_in
            model.y_mean = y_in
            model.x_mean.bounds = (x_in - 0.5 * radius, x_in + 0.5 * radius)
            model.y_mean.bounds = (y_in - 0.5 * radius, y_in + 0.5 * radius)

        gfit = gfitter(model, x=xpos[nan_filter], y=ypos[nan_filter], z=curr_star[nan_filter])

        # Center: the fit is in cutout coordinates, whose origin in the image is (xmin, ymin).
        # (x - radius, y - radius) is only correct for unclipped cutouts.
        xys[i] = np.array([gfit.x_mean + xmin, gfit.y_mean + ymin], dtype='float64')

        # Calculate rsq
        sstot = nansum((curr_star - nanmean(curr_star)) ** 2)
        sserr = nansum(gfitter.fit_info['fvec'] ** 2)
        rsqs[i] = 0 if sstot == 0 else 1.0 - (sserr / sstot)

        # FWHM
        fwhms[i] = gfit.x_fwhm

    # Reject Rsq < rsq_min (and non-finite Rsq) as a plain boolean array.
    rsq_ok = np.isfinite(rsqs) & (rsqs >= rsq_min) & (rsqs < 1.0)

    # Rejected stars are masked rather than removed, so masked_xys stays the same length
    # (and index-aligned) as masked_rsqs, masked_fwhms and the input positions.
    masked_xys = np.ma.masked_array(xys, ~np.isfinite(xys))
    masked_xys[~rsq_ok] = np.ma.masked
    masked_rsqs = np.ma.masked_array(rsqs, ~np.isfinite(rsqs))
    masked_fwhms = np.ma.masked_array(fwhms, ~np.isfinite(fwhms))

    return masked_fwhms, masked_xys, rsq_ok, masked_rsqs


# --------------------------------------------------------------------------------------------------
def clean_with_rsq_and_get_fwhm(masked_fwhms, masked_rsqs, references, min_fwhm_references=2, min_references=6,
                                rsq_min=0.15):
    """
    Clean references and obtain fwhm using RSQ values.

    R^2 cuts are tried from just below 0.95 down to ``rsq_min`` in steps of 0.15.
      - The first cut with at least ``min_fwhm_references`` stars with a finite FWHM gives the FWHM
        (sigma-clipped mean).
      - The first cut with at least ``min_references`` such stars selects the references.
      - If no cut reaches ``min_references``, the requirement is lowered by 2 and the cuts are retried.
      - Failing that, the R^2 cut is lowered below ``rsq_min`` in steps of 0.07, and the requirement
        by 1 per step, until some references are selected.

    Stars are counted per row of ``masked_fwhms``. A 2-column FWHM array (as returned by
    ``force_reject_g2d``) therefore counts each star once, not twice.

    Parameters:
        masked_fwhms (np.ma.maskedarray): array of fwhms, one row per star
        masked_rsqs (np.ma.maskedarray): array of rsq values
        references (astropy.table.Table): table of reference stars
        min_fwhm_references: (Default 2) min stars to get a fwhm
        min_references: (Default 6) min stars to aim for when cutting by R2
        rsq_min: (Default 0.15) min rsq value

    Returns:
        tuple: (fwhm, references), the FWHM estimate and the selected rows of ``references``.

    Raises:
        RuntimeError: If no FWHM could be estimated.
        MinStarError: If fewer than 2 references remain.

    .. codeauthor:: Emir Karamehmetoglu <emir.k@phys.au.dk>
    """

    def _rsq_selection(threshold):
        # Plain boolean row selection; masked (invalid) R^2 values are never selected.
        return np.ma.filled((masked_rsqs >= threshold) & (masked_rsqs < 1.0), False)

    def _count_stars(selection):
        # Count stars (rows) with a finite FWHM, not array elements.
        finite = np.ma.filled(np.isfinite(masked_fwhms[selection]), False)
        if finite.ndim > 1:
            finite = finite.all(axis=tuple(range(1, finite.ndim)))
        return int(np.count_nonzero(finite))

    min_references_now = min_references
    rsqvals = np.arange(rsq_min, 0.95, 0.15)[::-1]
    fwhm_found = False
    min_references_achieved = False
    fwhm = np.nan
    # Clean based on R^2 Value
    while not min_references_achieved:
        for rsqval in rsqvals:
            mask = _rsq_selection(rsqval)
            nreferences = _count_stars(mask)
            if nreferences >= min_fwhm_references:
                _fwhms_cut_ = np.nanmean(sigma_clip(masked_fwhms[mask], maxiters=100, sigma=2.0))
                if not fwhm_found:
                    fwhm = _fwhms_cut_
                    fwhm_found = True
            if nreferences >= min_references_now:
                references = references[mask]
                min_references_achieved = True
                break
        if min_references_achieved:
            break
        min_references_now = min_references_now - 2
        if (min_references_now < 2) and fwhm_found:
            break
        elif not fwhm_found:
            raise RuntimeError("Could not estimate FWHM")

    if np.isnan(fwhm):
        raise RuntimeError("Could not estimate FWHM")

    # if minimum references not found, then take what we can get with even a weaker cut.
    # TODO: Is this right, or should we grab rsq_min (or even weaker?)
    # Terminates: min_references_now eventually drops to <= 0, which any selection satisfies.
    min_references_now = min_references - 2
    while not min_references_achieved:
        mask = _rsq_selection(rsq_min)
        nreferences = _count_stars(mask)
        if nreferences >= min_references_now:
            references = references[mask]
            min_references_achieved = True
        rsq_min = rsq_min - 0.07
        min_references_now = min_references_now - 1

    # Check len of references as this is a destructive cleaning.
    if len(references) < 2:
        raise MinStarError(f"{len(references)} References remaining; could not estimate fwhm.")
    return fwhm, references


# --------------------------------------------------------------------------------------------------
def mkposxy(posx, posy):
    '''
    Pair x and y positions into a float64 array of [x, y] rows for astroalign.

    Pairs are formed with zip, so the result is truncated to the shorter input.
    '''
    pairs = list(zip(posx, posy))
    return np.array(pairs, dtype="float64")


# --------------------------------------------------------------------------------------------------
def try_transform(source, target, pixeltol=2, nnearest=5, max_stars=50):
    """
    Run ``aa.find_transform`` and return its matched control points.

    NOTE: this sets the module-level astroalign settings ``PIXEL_TOL`` and
    ``NUM_NEAREST_NEIGHBORS``, and leaves them changed after returning.

    Returns:
        tuple: (source control points, target control points) from ``aa.find_transform``.
    """
    aa.PIXEL_TOL = pixeltol
    aa.NUM_NEAREST_NEIGHBORS = nnearest
    _, matched = aa.find_transform(source, target, max_control_points=max_stars)
    sourcestars, targetstars = matched
    return sourcestars, targetstars


# --------------------------------------------------------------------------------------------------
def try_astroalign(source, target, pixeltol=2, nnearest=5, max_stars_n=50):
    """
    Match two (N, 2) point lists with astroalign and return the indices of the matched rows.

    ``source_ind[k]`` and ``target_ind[k]`` refer to the same star. The order of the control-point
    pairs returned by astroalign is kept, so both arrays can be used together (e.g. for a WCS fit).

    Returns:
        tuple: (source_ind, target_ind, success). On ``aa.MaxIterError`` the result is
        ``('None', 'None', False)``.
    """

    def _rows_of(points, control_points):
        # astroalign returns the matched control points themselves (rows of the input). Map each
        # one back to the index of its (nearest, normally identical) input row without sorting,
        # so pair k stays pair k.
        pts = np.asarray(points, dtype='float64').reshape(-1, 2)
        ctrl = np.asarray(control_points, dtype='float64').reshape(-1, 2)
        dist2 = ((ctrl[:, None, :] - pts[None, :, :]) ** 2).sum(axis=-1)
        return dist2.argmin(axis=1)

    # Get indexes of matched stars
    try:
        source_stars, target_stars = try_transform(source, target, pixeltol=pixeltol, nnearest=nnearest,
                                                   max_stars=max_stars_n)
    except aa.MaxIterError:
        return 'None', 'None', False

    return _rows_of(source, source_stars), _rows_of(target, target_stars), True


# --------------------------------------------------------------------------------------------------
def min_to_max_astroalign(source, target, fwhm=5, fwhm_min=1, fwhm_max=4, knn_min=5, knn_max=20, max_stars=100,
                          min_matches=3):
    """
    Try to find matches using astroalign asterisms by stepping through some parameters.

    Tries every combination of control-point count, pixel tolerance (``fwhm * fwhm_min`` to
    ``fwhm * fwhm_max``) and number of nearest neighbours (``knn_min`` to ``min(knn_max, nstars)``).
    Returns the first result with at least ``min_matches`` matches.

    Returns:
        tuple: (source_ind, target_ind, True) on success, otherwise ``('None', 'None', False)``.
    """
    # Set max_control_points par based on number of stars (in either list) and max_stars.
    nstars = max(len(source), len(target))
    if max_stars >= nstars:
        # Step down from "all stars" to progressively smaller control-point sets.
        if nstars > 6:
            max_stars_list = (nstars, 5, 3)
        elif nstars > 3:
            max_stars_list = (nstars, 3)
        elif nstars == 3:
            max_stars_list = (nstars,)
        else:
            # Fewer than three points cannot form an asterism; nothing to try.
            max_stars_list = ()
    elif max_stars > 60:
        max_stars_list = (max_stars, 50, 4, 3)
    else:
        max_stars_list = (max_stars, 6, 4, 3)

    pixeltols = np.linspace(int(fwhm * fwhm_min), int(fwhm * fwhm_max), 4, dtype=int)
    nearest_neighbors = np.linspace(knn_min, min(knn_max, nstars), 4, dtype=int)

    for max_stars_n in max_stars_list:
        for pixeltol in pixeltols:
            for nnearest in nearest_neighbors:
                source_ind, target_ind, success = try_astroalign(source, target, pixeltol=pixeltol, nnearest=nnearest,
                                                                 max_stars_n=max_stars_n)
                if success and len(source_ind) >= min_matches:
                    return source_ind, target_ind, True
    return 'None', 'None', False


# --------------------------------------------------------------------------------------------------
def kdtree(source, target, fwhm=5, fwhm_max=4, min_matches=3):
    '''
    Use a KD-tree to match each source point to its nearest target within ``fwhm_max * fwhm``.

    Parameters:
        source, target: (N, 2) and (M, 2) arrays of pixel positions.
        fwhm, fwhm_max: the match radius is ``fwhm * fwhm_max`` (inclusive).
        min_matches: minimum number of matches for success.

    Returns:
        tuple: (sources, targets, success). ``sources[k]`` and ``targets[k]`` are the indices of a
        matched pair, and success is ``len(sources) >= min_matches``.
    '''
    from scipy.spatial import KDTree

    source_pts = np.asarray(source, dtype='float64').reshape(-1, 2)
    target_pts = np.asarray(target, dtype='float64').reshape(-1, 2)

    if len(source_pts) == 0 or len(target_pts) == 0:
        empty = np.array([], dtype=int)
        return empty, empty.copy(), 0 >= min_matches

    # query() returns the *nearest* neighbour. query_ball_tree() returned all neighbours in
    # arbitrary order, so taking its first element did not pick the closest star. nextafter
    # keeps the radius inclusive. Unmatched points get an infinite distance.
    radius = np.nextafter(float(fwhm * fwhm_max), np.inf)
    distances, nearest = KDTree(target_pts).query(source_pts, k=1, distance_upper_bound=radius)

    found = np.isfinite(distances)
    sources = np.flatnonzero(found).astype(int)
    targets = np.asarray(nearest)[found].astype(int)

    # Return indexes of matches
    return sources, targets, len(sources) >= min_matches


# --------------------------------------------------------------------------------------------------
def get_new_wcs(extracted_ind, extracted_stars, clean_references, ref_ind, obstime, rakey='ra_obs', deckey='decl_obs'):
    """
    Fit a WCS mapping matched extracted pixel positions onto reference sky positions.

    ``extracted_stars[extracted_ind]`` (x, y) rows and the reference rows ``ref_ind`` (columns
    ``rakey``/``deckey``) are paired element by element.
    """
    matched_xy = extracted_stars[extracted_ind]
    pixel_points = (matched_xy[:, 0], matched_xy[:, 1])
    sky_points = SkyCoord(ra=clean_references[rakey][ref_ind], dec=clean_references[deckey][ref_ind],
                          frame='icrs', obstime=obstime)
    return wcs.utils.fit_wcs_from_points(pixel_points, sky_points)


def make_rsq_mask(masked_rsqs: np.ma.MaskedArray) -> np.ndarray:
    """
    Return the row indices of the unmasked R^2 values, ordered from highest to lowest R^2.

    Masked/missing values are dropped by pandas. Despite the name, the result is an index array,
    not a boolean mask.
    """
    rsq_frame = pd.DataFrame(masked_rsqs, columns=['rsq'])
    ranked = rsq_frame.sort_values('rsq', ascending=False)
    return ranked.dropna().index.values


def get_clean_references(reference_table: RefTable, masked_rsqs: np.ma.MaskedArray, min_references_ideal: int = 6,
                         min_references_abs: int = 3, rsq_min: float = 0.15, rsq_ideal: float = 0.5,
                         keep_max: int = 100, rescue_bad: bool = True) -> Tuple[RefTable, np.ndarray]:
    """
    Select reference stars by Gaussian-fit R^2 and order them from best to worst fit.

    Three increasingly permissive cuts are tried. Each also requires R^2 < 1 and an unmasked value.
      1. ``R^2 >= rsq_ideal``: accepted if at least ``min_references_ideal`` stars pass; keeps up to ``keep_max``.
      2. ``R^2 >= rsq_min``: accepted if at least ``min_references_abs`` stars pass; keeps up to
         ``min_references_ideal``.
      3. ``R^2 >= 0.02`` (only if ``rescue_bad``): accepted if at least 2 stars pass; keeps up to
         ``min_references_ideal``.

    ``masked_rsqs`` is not modified: every cut starts from its original mask. Previously the mask
    was overwritten by the first cut, so the fallback cuts could never add stars.

    Returns:
        tuple: (selected rows of ``reference_table``, their row indices), best R^2 first.

    Raises:
        MinStarError: If no cut yields enough stars.
    """
    rsq_values = np.ma.getdata(masked_rsqs)
    invalid = np.ma.getmaskarray(masked_rsqs)  # original mask, never written to

    def _ranked(threshold):
        # Indices of valid R^2 values in [threshold, 1), best first.
        in_range = (rsq_values >= threshold) & (rsq_values < 1.0)
        candidates = np.ma.masked_array(rsq_values, mask=invalid | ~in_range)
        return make_rsq_mask(candidates)

    # Greedy first try
    rsq_mask_index = _ranked(rsq_ideal)[:keep_max]
    if len(rsq_mask_index) >= min_references_ideal:
        return reference_table[rsq_mask_index], rsq_mask_index

    # Desperate second try
    rsq_mask_index = _ranked(rsq_min)[:min_references_ideal]
    if len(rsq_mask_index) >= min_references_abs:
        return reference_table[rsq_mask_index], rsq_mask_index
    if not rescue_bad:
        raise MinStarError(f'Less than {min_references_abs} clean stars and rescue_bad = False')

    # Extremely desperate last ditch attempt i.e. "rescue bad"
    rsq_mask_index = _ranked(0.02)[:min_references_ideal]
    if len(rsq_mask_index) < 2:
        raise MinStarError('Less than 2 clean stars.')
    return reference_table[rsq_mask_index], rsq_mask_index  # Return if len >= 2


class ReferenceCleaner:
    """
    Clean the reference stars of an image by fitting 2D Gaussians and ranking them by R^2.

    Attributes:
        image: The image being processed.
        references: The ``References`` to clean.
        rsq_min: Minimum R^2 used for the Gaussian rejection and the FWHM estimate.
        min_references_ideal: Number of references aimed for in the final R^2-ordered selection.
        min_references_abs: Minimum number of references accepted before falling back to weaker cuts.
        gaussian_xys: Gaussian-fitted pixel positions of the references last returned by
            ``clean_references`` (None until then).
    """

    def __init__(self, image: FlowsImage, references: References, rsq_min: float = 0.3,
                 min_references_ideal: int = 6, min_references_abs: int = 3):
        self.image = image
        self.references = references
        self.rsq_min = rsq_min
        self.min_references_ideal = min_references_ideal
        self.min_references_abs = min_references_abs
        self.gaussian_xys: Optional[np.ndarray] = None  # gaussian pixel positions

    def _clean_extracted_stars(self, x: Optional[ArrayLike] = None,
                               y: Optional[ArrayLike] = None) -> Tuple[np.ma.MaskedArray, ...]:
        """
        Fit 2D Gaussians to the stars at pixel positions (x, y) with ``force_reject_g2d``.

        The fit radius and FWHM guess/limits come from ``image.instrument_defaults``. The fit uses
        ``image.subclean`` if set, otherwise ``image.clean``. ``x``/``y`` default to the
        ``pixel_column``/``pixel_row`` columns of ``self.references.table``.

        :return: Tuple of masked_fwhms, masked_ref_xys, rsq_mask, masked_rsqs
        """
        # use instrument_defaults for initial guess of FWHM
        radius = self.image.instrument_defaults.radius
        fwhm_guess = self.image.instrument_defaults.fwhm
        fwhm_min = self.image.instrument_defaults.fwhm_min
        fwhm_max = self.image.instrument_defaults.fwhm_max
        useimage = self.image.subclean if self.image.subclean is not None else self.image.clean

        # Clean the references
        x = x if x is not None else self.references.table['pixel_column']
        y = y if y is not None else self.references.table['pixel_row']
        return force_reject_g2d(x, y, useimage, radius=radius, fwhm_guess=fwhm_guess, rsq_min=self.rsq_min,
                                fwhm_max=fwhm_max, fwhm_min=fwhm_min)

    def set_gaussian_xys(self, masked_ref_xys: np.ma.MaskedArray, old_references: RefTable,
                         new_references: RefTable) -> None:
        """
        Store in ``gaussian_xys`` the fitted position of every star in ``new_references``.

        Stars are looked up by ``starid`` in ``old_references``, the table whose rows
        ``masked_ref_xys`` is aligned with. The first matching row is used.
        """
        xy = [tuple(masked_ref_xys[old_references['starid'] == ref['starid']].data[0]) for ref in new_references]
        self.gaussian_xys = np.array(xy)

    def clean_references(self, references: References = None) -> Tuple[References, float]:
        """
        Clean ``references`` (default: ``self.references``) and estimate the image FWHM.

        Also stores the Gaussian-fitted positions of the returned references in ``gaussian_xys``.

        Returns:
            tuple: (References ordered by decreasing R^2, FWHM estimate in pixels).
        """
        if references is None:
            references = self.references

        # Clean the references
        masked_fwhms, masked_ref_xys, rsq_mask, masked_rsqs = self._clean_extracted_stars(
            references.table['pixel_column'],
            references.table['pixel_row'])

        # Use R^2 to more robustly determine initial FWHM guess.
        # This cleaning is good when we have FEW references.
        fwhm, _ = clean_with_rsq_and_get_fwhm(
            masked_fwhms, masked_rsqs, references.table, min_fwhm_references=2,
            min_references=self.min_references_abs, rsq_min=self.rsq_min)
        logger.info('Initial FWHM guess is %f pixels', fwhm)

        # Final clean of wcs corrected references
        logger.info("Number of references before final cleaning: %d", len(references.table))
        logger.debug('Masked R^2 values: %s', masked_rsqs[rsq_mask].data)
        # Get references cleaned and ordered by R^2, honouring this cleaner's reference-count
        # settings (previously stored but never passed on).
        ordered_cleaned_references, order_index = get_clean_references(
            references.table, masked_rsqs, min_references_ideal=self.min_references_ideal,
            min_references_abs=self.min_references_abs, rsq_ideal=0.8)
        ordered_coords = None if references.coords is None else references.coords[order_index]
        ordered_xy = None if references.xy is None else references.xy[order_index]
        ordered_cleaned_references = References(table=ordered_cleaned_references, coords=ordered_coords, xy=ordered_xy)
        logger.info("Number of references after final cleaning: %d", len(ordered_cleaned_references.table))

        # Save Gaussian XY positions before returning
        self.set_gaussian_xys(masked_ref_xys, references.table, ordered_cleaned_references.table)
        return ordered_cleaned_references, fwhm

    def make_sep_clean_references(self) -> References:
        """
        Make a clean reference catalog using SExtractor (via ``use_sep``).

        Sources found by SEP are kept only if their Gaussian fit passes the R^2 cut
        (``self.rsq_min``).
        """
        image = self.image
        # Get the SExtractor references from the image
        sep_references = use_sep(image)

        # Clean extracted stars
        _, _, sep_mask, _ = self._clean_extracted_stars(
            sep_references.table['x'], sep_references.table['y'])

        sep_references.mask = sep_mask
        return sep_references.masked

    def mask_edge_and_target(self, target_coords: SkyCoord, hsize: int = 10,
                             target_distance_lim: u.quantity.Quantity = 10 * u.arcsec) -> References:
        """
        Remove references that are too close to the target or to the image edge.

        A reference is kept if it is more than ``target_distance_lim`` from ``target_coords``, and
        its pixel position lies strictly more than ``hsize`` pixels inside the image. Note that this
        sets ``self.references.mask`` as a side effect.

        Returns:
            References: The masked copy of ``self.references``.
        """
        image_shape = self.image.shape
        table = self.references.table

        # Make mask
        mask = (target_coords.separation(self.references.coords) > target_distance_lim) & (
            table['pixel_column'] > hsize) & (table['pixel_column'] < (image_shape[1] - 1 - hsize)) & (
            table['pixel_row'] > hsize) & (table['pixel_row'] < (image_shape[0] - 1 - hsize))
        self.references.mask = mask

        # Make new clean references
        return self.references.masked


@dataclass(frozen=True)
class InitGuess:
    """
    Initial (x_0, y_0) position guesses built from ``clean_references.xy``.

    ``target_row`` and the optional ``diff_row`` are the rows holding the target and the
    difference-image entry. All rows after both of them are treated as reference stars.
    """
    clean_references: References
    target_row: int = 0
    diff_row: Optional[int] = None

    @property
    def init_guess_full(self) -> Table:
        positions = self.clean_references.xy
        return Table(positions, names=['x_0', 'y_0'])

    @property
    def init_guess_target(self) -> Table:
        full = self.init_guess_full
        return Table(full[self.target_row])

    @property
    def init_guess_diff(self) -> Table:
        if self.diff_row is not None:
            return Table(self.init_guess_full[self.diff_row])
        raise ValueError('`diff_row` is None, I cannot calculate the initial guesses for the difference image row.')

    @property
    def init_guess_references(self) -> Table:
        if self.diff_row is None:
            first_reference = self.target_row + 1
        else:
            first_reference = max(self.diff_row, self.target_row) + 1
        return self.init_guess_full[first_reference:]


# plots.py

In [12]:
"""
Plotting utilities.

.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
import logging
import copy
import numpy as np
from bottleneck import allnan
import matplotlib
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import astropy.visualization as viz
#from .utilities import create_logger

# NOTE(review): this rebinds the notebook-global ``logger`` used by functions defined in earlier cells.
logger = create_logger()

# Render off-screen (non-GUI backend) so plotting also works on headless cluster nodes.
plt.switch_backend('Agg')

# Fonts used in plots.
# TODO: Use stylesheets instead of overwriting defaults here
matplotlib.rcParams.update({
    'font.family': 'serif',
    'text.usetex': False,
    'mathtext.fontset': 'dejavuserif',
})


# --------------------------------------------------------------------------------------------------
def plots_interactive(backend=('Qt5Agg', 'MacOSX', 'Qt4Agg', 'Qt5Cairo', 'TkAgg')):
    """
    Change plotting to using an interactive backend.

    Parameters:
        backend (str or list): Backend to change to. If not provided, will try different
            interactive backends and use the first one that works.

    If none of the candidates can be activated, a warning is logged and the backend is not switched.

    .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
    """

    logger = logging.getLogger(__name__)
    logger.debug("Valid interactive backends: %s", matplotlib.rcsetup.interactive_bk)

    if isinstance(backend, str):
        backend = [backend]

    for bckend in backend:
        if bckend not in matplotlib.rcsetup.interactive_bk:
            logger.warning("Interactive backend '%s' is not found", bckend)
            continue

        # Try to change the backend, and move on to the next
        # candidate if it didn't work:
        try:
            plt.switch_backend(bckend)
        except (ModuleNotFoundError, ImportError) as exc:
            logger.debug("Could not switch to backend '%s': %s", bckend, exc)
        else:
            break
    else:
        # No candidate worked; report it instead of silently staying on the current backend.
        logger.warning("Could not switch to any interactive backend; still using '%s'",
                       matplotlib.get_backend())


# --------------------------------------------------------------------------------------------------
def plots_noninteractive():
    """
    Switch plotting to the non-interactive 'Agg' backend, e.g. for use on a cluster.

    .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
    """
    non_gui_backend = 'Agg'
    plt.switch_backend(non_gui_backend)


# --------------------------------------------------------------------------------------------------
def plot_image(image, ax=None, scale='log', cmap=None, origin='lower', xlabel=None, ylabel=None, cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)', cbar_ticks=None, cbar_ticklabels=None, cbar_pad=None, cbar_size='4%',
               title=None, percentile=95.0, vmin=None, vmax=None, offset_axes=None, color_bad='k', **kwargs):
    """
    Utility function to plot a 2D image.

    Parameters:
        image (2d array): Image data.
        ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
            Default (None) is to use current active axes.
        scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
            Normalization used to stretch the colormap.
            Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
            and ``'squared'``.
            Can also be a :py:class:`astropy.visualization.ImageNormalize` object.
            Default is ``'log'``.
        origin (str, optional): The origin of the coordinate system.
        xlabel (str, optional): Label for the x-axis.
        ylabel (str, optional): Label for the y-axis.
        cbar (string, optional): Location of color bar.
            Choises are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
            Default is not to create colorbar.
        clabel (str, optional): Label for the color bar.
        cbar_size (float, optional): Fractional size of colorbar compared to axes. Default='4%'.
        cbar_pad (float, optional): Padding between axes and colorbar.
        title (str or None, optional): Title for the plot.
        percentile (float, optional): The fraction of pixels to keep in color-trim.
            If single float given, the same fraction of pixels is eliminated from both ends.
            If tuple of two floats is given, the two are used as the percentiles.
            Default=95.
        cmap (matplotlib colormap, optional): Colormap to use. Default is the ``Blues`` colormap.
        vmin (float, optional): Lower limit to use for colormap.
        vmax (float, optional): Upper limit to use for colormap.
        color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
        kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.

    Returns:
        :py:class:`matplotlib.image.AxesImage`: Image from returned
            by :py:func:`matplotlib.pyplot.imshow`.

    .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
    """

    logger = logging.getLogger(__name__)

    # Backward compatible settings:
    make_cbar = kwargs.pop('make_cbar', None)
    # noinspection PyUnreachableCode
    if make_cbar:
        raise FutureWarning("'make_cbar' is deprecated. Use 'cbar' instead.")
        if not cbar:
            cbar = make_cbar

    # Special treatment for boolean arrays:
    if isinstance(image, np.ndarray) and image.dtype == 'bool':
        if vmin is None: vmin = 0
        if vmax is None: vmax = 1
        if cbar_ticks is None: cbar_ticks = [0, 1]
        if cbar_ticklabels is None: cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(percentile[0], percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh', 'squared'):
        if scale == 'log':
            stretch = viz.LogStretch()
        elif scale == 'linear':
            stretch = viz.LinearStretch()
        elif scale == 'sqrt':
            stretch = v

# version.py

In [13]:
"""
Get version identification from git

If the script is located within an active git repository,
git-describe is used to get the version information.
If this is not a git repository, then it is reasonable to
assume that the version is not being incremented and the
version returned will be the release version as read from
the VERSION file, which holds the version information.

The file VERSION needs to be changed manually. This should be done
before running git tag (set to the same as the version in the tag).

Inspired by
https://github.com/aebrahim/python-git-version

.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""

from subprocess import check_output, CalledProcessError
from os import path, name, devnull, environ, listdir

__all__ = ("get_version",)

# Directory used as the working directory for git calls and as the anchor for the VERSION file.
# NOTE(review): FILE is set to the placeholder "FILENAME" in the first cell, so abspath() resolves it
# against the current working directory; CURRENT_DIRECTORY is then effectively the cwd.
CURRENT_DIRECTORY = path.dirname(path.abspath(FILE))
# The VERSION file is looked up one directory above CURRENT_DIRECTORY.
VERSION_FILE = path.join(CURRENT_DIRECTORY, '..', 'VERSION')

# Find the "git" command to run depending on the OS:
GIT_COMMAND = "git"
if name == "nt":
    def find_git_on_windows():
        """
        Find the path to the git executable on Windows.

        Returns "git" if git is on the PATH, or if no installation could be found. Otherwise returns
        the path of the first git.exe found in the usual install locations.
        """
        # First see if git is in the PATH. ``where`` exits non-zero when it is not
        # (CalledProcessError); if ``where`` itself cannot be run (OSError), fall back to the
        # location search instead of failing at import time.
        try:
            check_output(["where", "/Q", "git"])
            return "git"
        except (CalledProcessError, OSError):
            pass
        # There are several locations git.exe may be hiding
        possible_locations = []
        # look in program files for msysgit
        if "PROGRAMFILES(X86)" in environ:
            possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES(X86)"])
        if "PROGRAMFILES" in environ:
            possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES"])
        # look for the github version of git
        if "LOCALAPPDATA" in environ:
            github_dir = "%s/GitHub" % environ["LOCALAPPDATA"]
            if path.isdir(github_dir):
                for subdir in listdir(github_dir):
                    if not subdir.startswith("PortableGit"):
                        continue
                    possible_locations.append("%s/%s/bin/git.exe" % (github_dir, subdir))
        for possible_location in possible_locations:
            if path.isfile(possible_location):
                return possible_location
        # git was not found
        return "git"

    GIT_COMMAND = find_git_on_windows()


def call_git_describe(abbrev=7):
    """
    Return the output of ``git describe --tags --abbrev=<abbrev>`` as a stripped string.

    git runs in CURRENT_DIRECTORY with stderr discarded. Returns None if git cannot be run
    or the command fails.
    """
    command = [GIT_COMMAND, "describe", "--tags", "--abbrev=%d" % abbrev]
    try:
        with open(devnull, "w") as sink:
            raw = check_output(command, cwd=CURRENT_DIRECTORY, stderr=sink)
    except (OSError, CalledProcessError):
        return None
    return raw.decode("ascii").strip()


def call_git_getbranch():
    """
    Return the short name of the checked-out branch (``git symbolic-ref --short HEAD``).

    Returns None if git cannot be run or the command fails.
    """
    command = [GIT_COMMAND, "symbolic-ref", "--short", "HEAD"]
    try:
        with open(devnull, "w") as sink:
            raw = check_output(command, cwd=CURRENT_DIRECTORY, stderr=sink)
    except (OSError, CalledProcessError):
        return None
    return raw.decode("ascii").strip()


def format_git_describe(git_str, pep440=False):
    """
    Format the result of calling 'git describe' as a python version.

    ``<tag>-<N>-g<hash>`` becomes ``<tag>.post<N>+git<hash>``, or just ``<tag>.post<N>`` when
    ``pep440`` is True. The pattern is matched from the right, so tags that themselves contain
    '-' are handled. Any other string (e.g. when exactly at a tag) is returned unchanged, and
    None is passed through.
    """
    import re

    if git_str is None:
        return None
    match = re.fullmatch(r"(?P<tag>.+)-(?P<count>\d+)-g(?P<sha>[0-9a-fA-F]+)", git_str)
    if match is None:  # currently at a tag
        return git_str
    version = "%s.post%s" % (match.group("tag"), match.group("count"))
    if pep440:  # does not allow git hash afterwards
        return version
    return "%s+git%s" % (version, match.group("sha"))


def read_release_version():
    """Return the stripped contents of VERSION_FILE, or None if it is empty or cannot be read."""
    try:
        with open(VERSION_FILE, "r") as version_fh:
            contents = version_fh.read()
    except IOError:
        return None
    version = str(contents.strip())
    return version if version else None

def update_release_version():
    """
    Write the current PEP 440 version (from ``get_version(pep440=True)``) to VERSION_FILE.

    The version is determined before the file is opened, so an existing VERSION file is not
    truncated when no version can be found. Previously the file was emptied and then
    ``write(None)`` raised a TypeError.

    Raises:
        RuntimeError: If no version could be determined.
    """
    version = get_version(pep440=True)
    if version is None:
        raise RuntimeError("Could not determine the version; VERSION file left unchanged.")
    with open(VERSION_FILE, "w") as outfile:
        outfile.write(version)

def get_version(pep440=False, include_branch=True):
    """
    Tracks the version number.

    If the script is located within an active git repository, ``git describe`` is used to get
    the version information. Otherwise it is assumed that the version is not being incremented,
    and the release version is read from the VERSION file.

    The file VERSION needs to be changed manually. This should be done before running
    git tag (set to the same as the version in the tag).

    Parameters:
        pep440 (bool): When True, return a version string suitable for a release as defined
            by PEP 440. When False, the git hash (if available) is appended to the version string.
        include_branch (bool): When True (default), prefix a git-derived version with the
            current branch name as ``<branch>-<version>``, if the branch can be determined.

    Returns:
        string: Version string, or None if neither git nor the VERSION file provides one.
    """
    git_version = format_git_describe(call_git_describe(), pep440=pep440)
    if git_version is None:
        # Not a git repository (or git unavailable): fall back to the VERSION file.
        return read_release_version()

    branch = call_git_getbranch() if include_branch else None
    if branch is None:
        return git_version
    return branch + '-' + git_version

if __name__ == "__main__":
    # Print the detected version. In a notebook cell __name__ is "__main__", so this also runs there.
    print(get_version())

test1-v1.0.0.post30+git657a5c4


# wcs.py

In [14]:
"""
WCS tools.

.. codeauthor:: Simon Holmbo <simonholmbo@phys.au.dk>
"""
from copy import deepcopy
import numpy as np
import astropy.wcs
from scipy.optimize import minimize
from scipy.spatial.transform import Rotation


class WCS2:
    """
    Manipulate WCS solution.

    Initialize
    ----------
    wcs = WCS2(x, y, ra, dec, scale, mirror, angle)
    wcs = WCS2.from_matrix(x, y, ra, dec, matrix)
    wcs = WCS2.from_points(list(zip(x, y)), list(zip(ra, dec)))
    wcs = WCS2.from_astropy_wcs(astropy.wcs.WCS())

    x, y is the reference pixel (stored as the astropy CRPIX)
    ra, dec and angle should be in degrees
    scale should be in arcsec/pixel
    mirror selects the parity: mirror=False multiplies the rotation by diag(-1, 1)
    matrix should be the PC or CD matrix, i.e.
        sky offset [deg] (R.A. offset times cos(dec)) = matrix @ pixel offset

    Assignments are validated: 0 <= ra < 360, -90 <= dec <= 90, scale > 0,
    mirror must be a bool and -180 < angle <= 180.

    NOTE: astropy_wcs only sets CRPIX, CRVAL and PC (no CTYPE), so the
    returned astropy WCS is a linear transformation.

    Examples
    --------
    Adjust x, y offset:
    wcs.x += delta_x
    wcs.y += delta_y

    Get scale and angle:
    print(wcs.scale, wcs.angle)

    Change an astropy.wcs.WCS (wcs) angle
    wcs = WCS2.from_astropy_wcs(wcs)(angle=new_angle).astropy_wcs

    Adjust solution with points
    wcs.adjust_with_points(list(zip(x, y)), list(zip(ra, dec)))
    """

    # ----------------------------------------------------------------------------------------------
    def __init__(self, x, y, ra, dec, scale, mirror, angle):
        self.x, self.y = x, y
        self.ra, self.dec = ra, dec
        self.scale = scale
        self.mirror = mirror
        self.angle = angle

    # ----------------------------------------------------------------------------------------------
    @classmethod
    def from_matrix(cls, x, y, ra, dec, matrix):
        """Initiate the class with a 2x2 PC/CD matrix (degrees per pixel)."""

        assert np.shape(matrix) == (2, 2), 'Matrix must be 2x2'

        scale, mirror, angle = cls._decompose_matrix(matrix)

        return cls(x, y, ra, dec, scale, mirror, angle)

    # ----------------------------------------------------------------------------------------------
    @classmethod
    def from_points(cls, xy, rd):
        """Initiate the class with at least three pairs of pixel + sky coordinates."""

        assert np.shape(xy) == np.shape(rd) == (len(xy), 2) and len(
            xy) > 2, 'Arguments must be lists of at least 3 sets of coordinates'

        xy, rd = np.array(xy), np.array(rd)

        x, y, ra, dec, matrix = cls._solve_from_points(xy, rd)
        scale, mirror, angle = cls._decompose_matrix(matrix)

        return cls(x, y, ra, dec, scale, mirror, angle)

    # ----------------------------------------------------------------------------------------------
    @classmethod
    def from_astropy_wcs(cls, astropy_wcs):
        """Initiate the class with an astropy.wcs.WCS object."""

        if not isinstance(astropy_wcs, astropy.wcs.WCS):
            raise ValueError('Must be astropy.wcs.WCS')

        (x, y), (ra, dec) = astropy_wcs.wcs.crpix, astropy_wcs.wcs.crval
        scale, mirror, angle = cls._decompose_matrix(astropy_wcs.pixel_scale_matrix)

        return cls(x, y, ra, dec, scale, mirror, angle)

    # ----------------------------------------------------------------------------------------------
    def adjust_with_points(self, xy, rd):
        """
        Adjust the WCS with pixel + sky coordinates.

        If one set is given the change will be a simple offset.
        If two sets are given the offset, angle and scale will be derived
        (the mirror/parity is kept).
        And if more sets are given a completely new solution will be found.

        Parameters:
            xy: Sequence of (x, y) pixel coordinates.
            rd: Sequence of (ra, dec) sky coordinates in degrees, same length as xy.
        """

        assert np.shape(xy) == np.shape(rd) == (len(xy), 2), 'Arguments must be lists of sets of coordinates'

        xy, rd = np.array(xy), np.array(rd)

        self.x, self.y = xy.mean(axis=0)
        self.ra, self.dec = rd.mean(axis=0)

        A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0)
        b[:, 0] *= np.cos(np.deg2rad(rd[:, 1]))

        if len(xy) == 2:

            M = np.diag([[-1, 1][self.mirror], 1])
            # Compare in arcsec: in degrees the residuals (and gradients) are so small
            # that the optimizer would stop immediately at its starting point.
            b_arcsec = b * 60 * 60

            def R(t):
                # Rotation matrix for an angle t in degrees (same unit as self.angle).
                t = np.deg2rad(t)
                return np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])

            def chi2(p):
                angle, scale = p
                return np.power(A.dot(scale * R(angle).dot(M).T) - b_arcsec, 2).sum()

            angle, scale = minimize(chi2, [self.angle, self.scale]).x

            # (-s) * R(t) == s * R(t + 180): keep the scale positive.
            if scale < 0:
                angle, scale = angle + 180, -scale

            self.angle, self.scale = self._wrap_angle(angle), scale

        elif len(xy) > 2:
            matrix = np.linalg.lstsq(A, b, rcond=None)[0].T
            self.scale, self.mirror, self.angle = self._decompose_matrix(matrix)

    # ----------------------------------------------------------------------------------------------
    @property
    def matrix(self):
        """The 2x2 PC/CD matrix in degrees per pixel: (scale / 3600) * R(angle) @ diag(+-1, 1)."""

        scale = self.scale / 60 / 60
        mirror = np.diag([[-1, 1][self.mirror], 1])
        angle = np.deg2rad(self.angle)

        matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])

        return scale * matrix @ mirror

    # ----------------------------------------------------------------------------------------------
    @property
    def astropy_wcs(self):
        """A new astropy.wcs.WCS with CRPIX, CRVAL and PC set from this solution (no CTYPE)."""
        wcs = astropy.wcs.WCS()
        wcs.wcs.crpix = self.x, self.y
        wcs.wcs.crval = self.ra, self.dec
        wcs.wcs.pc = self.matrix
        return wcs

    # ----------------------------------------------------------------------------------------------
    @staticmethod
    def _solve_from_points(xy, rd):
        """Least-squares linear solution; returns the mean pixel/sky position and the matrix."""

        (x, y), (ra, dec) = xy.mean(axis=0), rd.mean(axis=0)

        A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0)
        b[:, 0] *= np.cos(np.deg2rad(rd[:, 1]))

        # Solves A @ X = b, so sky offset = X.T @ pixel offset.
        matrix = np.linalg.lstsq(A, b, rcond=None)[0].T

        return x, y, ra, dec, matrix

    # ----------------------------------------------------------------------------------------------
    @staticmethod
    def _wrap_angle(angle):
        """Map an angle in degrees onto (-180, 180], the range accepted by __setattr__."""
        angle = float(angle) % 360.0
        # A rounding result of exactly 360.0 is also folded back (to 0.0).
        return angle - 360.0 if angle > 180.0 else angle

    # ----------------------------------------------------------------------------------------------
    @staticmethod
    def _decompose_matrix(matrix):
        """Split a 2x2 matrix [deg/pixel] into (scale [arcsec/pixel], mirror, angle [deg])."""

        # For a scaled rotation/reflection sum(matrix**2) == 2 * scale**2.
        scale = np.sqrt(np.power(matrix, 2).sum() / 2) * 60 * 60

        # Decide parity from the sign pattern of the dominant pair of elements:
        # rotations have equal-sign diagonals and opposite-sign off-diagonals.
        if np.argmax(np.power(matrix[0], 2)):
            mirror = True if np.sign(matrix[0, 1]) != np.sign(matrix[1, 0]) else False
        else:
            mirror = True if np.sign(matrix[0, 0]) == np.sign(matrix[1, 1]) else False

        matrix = matrix if mirror else matrix.dot(np.diag([-1, 1]))

        matrix3d = np.eye(3)
        matrix3d[:2, :2] = matrix / (scale / 60 / 60)
        angle = Rotation.from_matrix(matrix3d).as_euler('xyz', degrees=True)[2]

        # as_euler may return exactly -180, which __setattr__ rejects.
        return scale, mirror, WCS2._wrap_angle(angle)

    # ----------------------------------------------------------------------------------------------
    def __setattr__(self, name, value):

        if name == 'ra' and (value < 0 or value >= 360):
            raise ValueError("0 <= R.A. < 360")

        elif name == 'dec' and (value < -90 or value > 90):
            raise ValueError("-90 <= Dec. <= 90")

        elif name == 'scale' and value <= 0:
            raise ValueError("Scale > 0")

        elif name == 'mirror' and not isinstance(value, bool):
            raise ValueError('mirror must be boolean')

        elif name == 'angle' and (value <= -180 or value > 180):
            raise ValueError("-180 < Angle <= 180")

        super().__setattr__(name, value)

    # ----------------------------------------------------------------------------------------------
    def __call__(self, **kwargs):
        """Return a (deep) copy, optionally with some of the attributes changed."""

        keys = ('x', 'y', 'ra', 'dec', 'scale', 'mirror', 'angle')

        if not all(k in keys for k in kwargs):
            raise ValueError('unknown argument(s)')

        obj = deepcopy(self)
        for k, v in kwargs.items():
            obj.__setattr__(k, v)
        return obj

    # ----------------------------------------------------------------------------------------------
    def __repr__(self):
        # Express the solution relative to pixel (0, 0).
        ra, dec = self.astropy_wcs.wcs_pix2world([(0, 0)], 0)[0]
        return f'WCS2(0, 0, {ra:.4f}, {dec:.4f}, {self.scale:.2f}, {self.mirror}, {self.angle:.2f})'


# coordinatematch.py

In [15]:
"""
Match two sets of coordinates.

.. codeauthor:: Simon Holmbo <simonholmbo@phys.au.dk>
"""
import numpy as np
import time
from itertools import count, islice, chain, product, zip_longest
from astropy.coordinates.angle_utilities import angular_separation
from astropy.coordinates import SkyCoord
import astropy.wcs
from scipy.spatial import cKDTree as KDTree
from networkx import Graph, connected_components
#from .wcs import WCS2
#from flows.utilities import create_logger
#from flows.image import FlowsImage
#from flows import reference_cleaning as refclean
#from flows.target import Target


# Module-level logger for this cell (used by correct_wcs below).
logger = create_logger()


def correct_wcs(image: FlowsImage, references: refclean.References, target: Target,
                timeout: float = np.inf) -> FlowsImage:
    """
    Correct the WCS of an image by matching detected sources to reference stars.

    Sources are cleaned with refclean.ReferenceCleaner (Source Extractor based),
    matched against the reference star coordinates with CoordinateMatch and, when
    a match is found, a new WCS is fitted with astropy.wcs.utils.fit_wcs_from_points.
    If no match is found (timeout or candidates exhausted), a warning is logged and
    the image WCS is left unchanged.

    Parameters:
        image: Image whose ``wcs`` attribute is replaced on success.
        references: Reference stars; ``references.coords`` is used.
        target: Target; its ``coords`` order the references by separation.
        timeout (float): Seconds given to the coordinate matcher (default: no limit).

    Returns:
        The same image object.
    """
    # Imported explicitly: numpy does not guarantee np.lib.recfunctions exists as an attribute.
    from numpy.lib import recfunctions

    # Start pre-cleaning
    sep_cleaner = refclean.ReferenceCleaner(image, references, rsq_min=0.3)

    # Use Source Extractor to make clean references
    sep_references_clean = sep_cleaner.make_sep_clean_references()
    xy = recfunctions.structured_to_unstructured(np.array(sep_references_clean.xy))

    # Find WCS
    logger.info("Finding new WCS solution...")
    logger.debug('Head WCS: %s', WCS2.from_astropy_wcs(image.wcs))

    # Prefer sources close to the image centre and references close to the target.
    image_centre = np.array(image.shape[::-1]) / 2
    cm = CoordinateMatch(
        xy=xy,
        rd=np.array(list(zip(references.coords.ra.deg, references.coords.dec.deg))),
        xy_order=np.argsort(np.power(xy - image_centre, 2).sum(axis=1)),
        rd_order=np.argsort(target.coords.separation(references.coords)),
        xy_nmax=100, rd_nmax=100, maximum_angle_distance=0.002)

    try:
        # minimum_matches=5, ratio_superiority=1.5
        i_xy, i_rd = map(np.array, zip(*cm(5, 1.5, timeout=timeout)))
    except TimeoutError:
        logger.warning('TimeoutError: No new WCS solution found')
    except StopIteration:
        logger.warning('StopIterationError: No new WCS solution found')
    else:
        logger.info('Found new WCS')
        image.wcs = astropy.wcs.utils.fit_wcs_from_points(np.array(list(zip(*cm.xy[i_xy]))),
                                                          SkyCoord(*map(list, zip(*cm.rd[i_rd])), unit='deg'))

    logger.debug('Used WCS: %s', WCS2.from_astropy_wcs(image.wcs))
    return image


# noinspection PyArgumentList
class CoordinateMatch(object):
    """
    Match pixel coordinates (xy) to sky coordinates (rd) using triangles.

    Triangles of points are formed in both coordinate sets and matched on their
    sorted interior angles. Every matched triangle pair yields a WCS candidate
    (ra, dec of ``bounds.xy``, log(scale), angle). Candidates that agree within
    ``distance_factor`` in normalized parameter space are linked in a graph, and
    the largest connected group provides the point correspondences.

    Parameters:
        xy: (N, 2) pixel coordinates.
        rd: (M, 2) sky coordinates (ra, dec) in degrees.
        xy_order, rd_order: Optional index orders; points earlier in the order are used first.
        xy_nmax, rd_nmax: Maximum number of points used from each set (None: all points).
        n_triangle_packages: Number of triangle packages taken per round.
        triangle_package_size: Number of triangles per package.
        maximum_angle_distance: Maximum distance [radians] between triangle angle vectors to match.
        distance_factor: Matching radius in normalized parameter space.
    """

    def __init__(self, xy, rd, xy_order=None, rd_order=None, xy_nmax=None, rd_nmax=None, n_triangle_packages=10,
                 triangle_package_size=10000, maximum_angle_distance=0.001, distance_factor=1):

        self.xy, self.rd = np.array(xy), np.array(rd)

        # Centred copies used for triangle geometry; R.A. offsets are scaled by cos(dec)
        # so both sky axes are in comparable angular units.
        self._xy = self.xy - np.mean(self.xy, axis=0)
        self._rd = self.rd - np.mean(self.rd, axis=0)
        self._rd[:, 0] *= np.cos(np.deg2rad(self.rd[:, 1]))

        # xy_nmax / rd_nmax default to None, meaning "use every point".
        xy_n = len(self.xy) if xy_nmax is None else min(xy_nmax, len(self.xy))
        rd_n = len(self.rd) if rd_nmax is None else min(rd_nmax, len(self.rd))

        self.i_xy = xy_order[:xy_n] if xy_order is not None else np.arange(xy_n)
        self.i_rd = rd_order[:rd_n] if rd_order is not None else np.arange(rd_n)

        self.n_triangle_packages = n_triangle_packages
        self.triangle_package_size = triangle_package_size

        self.maximum_angle_distance = maximum_angle_distance
        self.distance_factor = distance_factor

        self.triangle_package_generator = self._sorted_triangle_packages()

        # State accumulated across calls of __call__.
        self.i_xy_triangles = list()
        self.i_rd_triangles = list()
        self.parameters = None
        self.neighbours = Graph()

        # Fresh classes per instance used as simple attribute namespaces.
        self.normalizations = type('Normalizations', (object,), dict(ra=0.0001, dec=0.0001, scale=0.002, angle=0.002))

        self.bounds = type('Bounds', (object,),
                           dict(xy=self.xy.mean(axis=0), rd=None, radius=None, scale=None, angle=None))

    # ----------------------------------------------------------------------------------------------
    def set_normalizations(self, ra=None, dec=None, scale=None, angle=None):
        """
        Set normalization factors in the (ra, dec, scale, angle) space.

        Defaults are:
            ra = 0.0001 degrees
            dec = 0.0001 degrees
            scale = 0.002 log(arcsec/pixel)
            angle = 0.002 radians

        Raises:
            RuntimeError: If matching has already started.
            ValueError: If a given factor is not positive.
        """

        if self.parameters is not None:
            raise RuntimeError("can't change normalization after matching is started")

        for name, value in (('ra', ra), ('dec', dec), ('scale', scale), ('angle', angle)):
            if value is not None and not value > 0:
                raise ValueError(f'{name} normalization must be positive')

        self.normalizations.ra = ra if ra is not None else self.normalizations.ra
        self.normalizations.dec = dec if dec is not None else self.normalizations.dec
        self.normalizations.scale = scale if scale is not None else self.normalizations.scale
        self.normalizations.angle = angle if angle is not None else self.normalizations.angle

    # ----------------------------------------------------------------------------------------------
    def set_bounds(self, x=None, y=None, ra=None, dec=None, radius=None, scale=None, angle=None):
        """
        Set bounds for what are valid results.

        Set x, y, ra, dec and radius to specify that the x, y coordinates must be no
        further than the radius [degrees] away from the ra, dec coordinates.
        Set (lower, upper) bounds on the scale [log(arcsec/pixel)] and/or the angle
        [radians] if those are known, possibly from previous observations with the
        same system.

        Raises:
            RuntimeError: If matching has already started.
            ValueError: If the position bounds are only partially given or any value is out of range.
        """

        if self.parameters is not None:
            raise RuntimeError("can't change bounds after matching is started")

        position = [x, y, ra, dec, radius]
        if position.count(None) == 0:
            if not 0 <= ra < 360:
                raise ValueError('0 <= ra < 360')
            if not -90 <= dec <= 90:
                raise ValueError('-90 <= dec <= 90')
            if not 0 < radius:
                raise ValueError('radius > 0')

            self.bounds.xy = x, y
            self.bounds.rd = ra, dec
            self.bounds.radius = radius

        elif position.count(None) < 5:
            raise ValueError('x, y, ra, dec and radius must all be specified')

        # Scale bounds are in log space, so negative values are valid.
        if scale is not None and not scale[0] < scale[1]:
            raise ValueError('scale bounds must satisfy scale[0] < scale[1]')
        if angle is not None and not -np.pi <= angle[0] < angle[1] <= np.pi:
            raise ValueError('angle bounds must satisfy -pi <= angle[0] < angle[1] <= pi')

        self.bounds.scale = scale if scale is not None else self.bounds.scale
        self.bounds.angle = angle if angle is not None else self.bounds.angle

    # ----------------------------------------------------------------------------------------------
    def _sorted_triangles(self, pool):
        """Yield (a, b, c) from pool with positions a < b < c, ordered by c, then b, then a."""
        for i_c, c in enumerate(pool):
            for i_b, b in enumerate(pool[:i_c]):
                for a in pool[:i_b]:
                    yield a, b, c

    # ----------------------------------------------------------------------------------------------
    def _sorted_product_pairs(self, p, q):
        """Yield all pairs (p[i], q[j]) ordered by i + j (ties keep product order)."""
        for i_p, i_q in sorted(product(range(len(p)), range(len(q))), key=sum):
            yield p[i_p], q[i_q]

    # ----------------------------------------------------------------------------------------------
    def _sorted_triangle_packages(self):
        """
        Yield (xy triangles, rd triangles) index-array pairs to compare.

        Triangles are grouped in packages of triangle_package_size. Each round takes
        n_triangle_packages new packages per side and pairs them with all earlier
        packages of the other side (alternating), then with each other.
        """

        i_xy_triangle_generator = self._sorted_triangles(self.i_xy)
        i_rd_triangle_generator = self._sorted_triangles(self.i_rd)

        i_xy_triangle_slice_generator = (tuple(islice(i_xy_triangle_generator, self.triangle_package_size)) for _ in
                                         count())
        i_rd_triangle_slice_generator = (list(islice(i_rd_triangle_generator, self.triangle_package_size)) for _ in
                                         count())

        for n in count(step=self.n_triangle_packages):

            i_xy_triangle_slice = tuple(filter(None, islice(i_xy_triangle_slice_generator, self.n_triangle_packages)))
            i_rd_triangle_slice = tuple(filter(None, islice(i_rd_triangle_slice_generator, self.n_triangle_packages)))

            if not len(i_xy_triangle_slice) and not len(i_rd_triangle_slice):
                return

            # Regenerate the n packages already handed out (the "cumulative" history).
            i_xy_triangle_generator2 = self._sorted_triangles(self.i_xy)
            i_rd_triangle_generator2 = self._sorted_triangles(self.i_rd)

            i_xy_triangle_cum = filter(None,
                                       (tuple(islice(i_xy_triangle_generator2, self.triangle_package_size)) for _ in
                                        range(n)))
            i_rd_triangle_cum = filter(None,
                                       (tuple(islice(i_rd_triangle_generator2, self.triangle_package_size)) for _ in
                                        range(n)))

            for i_xy_triangles, i_rd_triangles in chain(filter(None, chain(*zip_longest(  # alternating chain
                    product(i_xy_triangle_slice, i_rd_triangle_cum), product(i_xy_triangle_cum, i_rd_triangle_slice)))),
                                                        self._sorted_product_pairs(i_xy_triangle_slice,
                                                                                   i_rd_triangle_slice)):
                yield np.array(i_xy_triangles), np.array(i_rd_triangles)

    # ----------------------------------------------------------------------------------------------
    def _get_triangle_angles(self, triangles):
        """Return the interior angles [radians] of (n, 3, 2) triangles; angle k is opposite side k."""

        sidelengths = np.sqrt(np.power(triangles[:, (1, 0, 0)] - triangles[:, (2, 2, 1)], 2).sum(axis=2))

        # law of cosines
        angles = np.power(sidelengths[:, ((1, 2), (0, 2), (0, 1))], 2).sum(axis=2)
        angles -= np.power(sidelengths[:, (0, 1, 2)], 2)
        angles /= 2 * sidelengths[:, ((1, 2), (0, 2), (0, 1))].prod(axis=2)

        # Rounding can push the cosine marginally outside [-1, 1] (e.g. near-collinear points).
        return np.arccos(np.clip(angles, -1, 1))

    # ----------------------------------------------------------------------------------------------
    def _solve_for_matrices(self, xy_triangles, rd_triangles):
        """Least-squares matrices M per triangle pair such that rd offset = M @ xy offset."""

        n = len(xy_triangles)

        A = xy_triangles - np.mean(xy_triangles, axis=1).reshape(n, 1, 2)
        b = rd_triangles - np.mean(rd_triangles, axis=1).reshape(n, 1, 2)

        matrices = [np.linalg.lstsq(Ai, bi, rcond=None)[0].T for Ai, bi in zip(A, b)]

        return np.array(matrices)

    # ----------------------------------------------------------------------------------------------
    def _extract_parameters(self, xy_triangles, rd_triangles, matrices):
        """Return (ra, dec of bounds.xy [deg], log(scale), angle [rad]) for each triangle pair."""

        parameters = []
        for xy_com, rd_com, matrix in zip(xy_triangles.mean(axis=1), rd_triangles.mean(axis=1), matrices):
            # com -> center-of-mass

            cos_dec = np.cos(np.deg2rad(rd_com[1]))
            # sky offset = matrix @ pixel offset (same convention as _solve_for_matrices)
            coordinates = matrix.dot(self.bounds.xy - xy_com)
            coordinates = coordinates / np.array([cos_dec, 1]) + rd_com

            wcs = WCS2.from_matrix(*xy_com, *rd_com, matrix)

            parameters.append((*coordinates, np.log(wcs.scale), np.deg2rad(wcs.angle)))

        return parameters

    # ----------------------------------------------------------------------------------------------
    def _get_bounds_mask(self, parameters):
        """Boolean mask of the parameter rows that satisfy all configured bounds."""

        i = np.ones(len(parameters), dtype=bool)
        parameters = np.array(parameters)

        if self.bounds.radius is not None:
            i *= angular_separation(*np.deg2rad(self.bounds.rd),
                                    *zip(*np.deg2rad(parameters[:, (0, 1)]))) <= np.deg2rad(self.bounds.radius)

        if self.bounds.scale is not None:
            i *= self.bounds.scale[0] <= parameters[:, 2]
            i *= parameters[:, 2] <= self.bounds.scale[1]

        if self.bounds.angle is not None:
            i *= self.bounds.angle[0] <= parameters[:, 3]
            i *= parameters[:, 3] <= self.bounds.angle[1]

        return i

    # ----------------------------------------------------------------------------------------------
    def __call__(self, minimum_matches=4, ratio_superiority=1, timeout=60):
        """
        Start the algorithm.

        Can be run multiple times with different arguments to relax the
        restrictions; state from earlier runs is kept.

        Parameters:
            minimum_matches (int): Minimum number of matched points required.
            ratio_superiority (float): Required ratio of points between the largest and
                the second-largest group of agreeing solutions.
            timeout (float): Seconds to search before giving up.

        Returns:
            list: (i_xy, i_rd) index pairs into self.xy and self.rd.

        Raises:
            TimeoutError: If no solution was found within timeout.
            StopIteration: If all triangle combinations were exhausted.

        Example
        --------
        cm = CoordinateMatch(xy, rd)

        lkwargs = [
            dict(minimum_matches=20, ratio_superiority=5, timeout=10),
            dict(timeout=60),
        ]

        for i, kwargs in enumerate(lkwargs):
            try:
                i_xy, i_rd = map(np.array, zip(*cm(**kwargs)))
            except TimeoutError:
                continue
            except StopIteration:
                print('Failed, no more stars.')
                break
            else:
                print('Success with kwargs[%d].' % i)
                break
        else:
            print('Failed, timeout.')
        """

        self.parameters = list() if self.parameters is None else self.parameters

        t0 = time.monotonic()

        while time.monotonic() - t0 < timeout:

            # get triangles and derive angles

            i_xy_triangles, i_rd_triangles = next(self.triangle_package_generator)

            xy_angles = self._get_triangle_angles(self._xy[i_xy_triangles])
            rd_angles = self._get_triangle_angles(self._rd[i_rd_triangles])

            # sort triangle vertices based on angles, so matched triangles have matching vertex order

            i = np.argsort(xy_angles, axis=1)
            i_xy_triangles = np.take_along_axis(i_xy_triangles, i, axis=1)
            xy_angles = np.take_along_axis(xy_angles, i, axis=1)

            i = np.argsort(rd_angles, axis=1)
            i_rd_triangles = np.take_along_axis(i_rd_triangles, i, axis=1)
            rd_angles = np.take_along_axis(rd_angles, i, axis=1)

            # match triangles
            matches = KDTree(xy_angles).query_ball_tree(KDTree(rd_angles), r=self.maximum_angle_distance)
            matches = np.array([(_i_xy, _i_rd) for _i_xy, _li_rd in enumerate(matches) for _i_rd in _li_rd])

            if not len(matches):
                continue

            i_xy_triangles = list(i_xy_triangles[matches[:, 0]])
            i_rd_triangles = list(i_rd_triangles[matches[:, 1]])

            # get parameters of wcs solutions
            matrices = self._solve_for_matrices(self._xy[np.array(i_xy_triangles)], self._rd[np.array(i_rd_triangles)])

            parameters = self._extract_parameters(self.xy[np.array(i_xy_triangles)], self.rd[np.array(i_rd_triangles)],
                                                  matrices)

            # apply bounds if any
            if any(bound is not None for bound in (self.bounds.radius, self.bounds.scale, self.bounds.angle)):
                mask = self._get_bounds_mask(parameters)

                i_xy_triangles = np.array(i_xy_triangles)[mask].tolist()
                i_rd_triangles = np.array(i_rd_triangles)[mask].tolist()
                parameters = np.array(parameters)[mask].tolist()

                # Every candidate may have been rejected; an empty batch cannot be normalized/matched.
                if not len(parameters):
                    continue

            # normalize parameters
            normalization = [getattr(self.normalizations, v) for v in ('ra', 'dec', 'scale', 'angle')]
            normalization[0] *= np.cos(np.deg2rad(self.rd[:, 1].mean(axis=0)))
            parameters = list(parameters / np.array(normalization))

            # match parameters (new candidates against all stored + new ones)
            neighbours = KDTree(parameters).query_ball_tree(KDTree(self.parameters + parameters),
                                                            r=self.distance_factor)
            neighbours = np.array([(i, j) for i, lj in enumerate(neighbours, len(self.parameters)) for j in lj])
            # keep each edge once and drop self-matches (j < i)
            neighbours = list(neighbours[(np.diff(neighbours, axis=1) < 0).flatten()])

            if not len(neighbours):
                continue

            self.i_xy_triangles += i_xy_triangles
            self.i_rd_triangles += i_rd_triangles
            self.parameters += parameters
            self.neighbours.add_edges_from(neighbours)

            # get largest neighborhood
            communities = list(connected_components(self.neighbours))
            c1 = np.array(list(max(communities, key=len)))
            i = np.unique(np.array(self.i_xy_triangles)[c1].flatten(), return_index=True)[1]

            if ratio_superiority > 1 and len(communities) > 1:
                communities.remove(set(c1))
                c2 = np.array(list(max(communities, key=len)))
                _i = np.unique(np.array(self.i_xy_triangles)[c2].flatten())

                if len(i) / len(_i) < ratio_superiority:
                    continue

            if len(i) >= minimum_matches:
                break

        else:
            raise TimeoutError

        i_xy = np.array(self.i_xy_triangles)[c1].flatten()[i]
        i_rd = np.array(self.i_rd_triangles)[c1].flatten()[i]

        return list(zip(i_xy, i_rd))

# epsfbuilder.py

In [16]:
"""
Photutils hack for EPSF building

.. codeauthor:: Simon Holmbo <simonholmbo@phys.au.dk>
"""
import time
import numpy as np
from scipy.interpolate import griddata, UnivariateSpline
import photutils.psf
from typing import List, Tuple
#from flows.utilities import create_logger
# Module-level logger for this cell (used by verify_epsf below).
logger = create_logger()

class FlowsEPSFBuilder(photutils.psf.EPSFBuilder):
    """
    photutils EPSFBuilder subclass overriding a few of its private hooks.

    Differences from the base class, as implemented here:
      * the initial ePSF gets ``origin = None`` and the ePSF sampling grid
        (in original-pixel units relative to the ePSF origin) is cached in
        ``self._epsf_xy_grid``;
      * residuals are resampled onto that grid with scipy's ``griddata``;
      * ``__call__`` attaches a ``fit_info`` dict (iterations, max iterations, runtime).

    NOTE(review): relies on private photutils internals (``_create_initial_epsf``,
    ``_resample_residual``, ``self._epsf``, ``star._xidx_centered`` ...). Re-check
    when upgrading photutils.
    """
    def _create_initial_epsf(self, stars):
        """Create the initial ePSF via photutils and cache its sampling grid."""
        epsf = super()._create_initial_epsf(stars)
        # NOTE(review): presumably clearing origin makes photutils fall back to its
        # default origin for x_origin/y_origin read below; order matters - TODO confirm.
        epsf.origin = None

        # Index grids of the ePSF array; shape is reversed so that X varies along columns
        # (assumes epsf.shape is (ny, nx) like the data array).
        X, Y = np.meshgrid(*map(np.arange, epsf.shape[::-1]))

        # Convert oversampled indices to offsets from the ePSF origin in original pixels.
        X = X / epsf.oversampling[0] - epsf.x_origin
        Y = Y / epsf.oversampling[1] - epsf.y_origin

        self._epsf_xy_grid = X, Y

        return epsf

    def _resample_residual(self, star, epsf):
        """
        Return (star - ePSF) sampled on the cached ePSF grid.

        The star's normalized values at its centred pixel positions are interpolated
        onto ``self._epsf_xy_grid`` with ``griddata`` (scipy defaults: linear
        interpolation, NaN outside the convex hull of the star's pixels), then the
        current ePSF data is subtracted.
        """
        # Disabled alternative: nearest-neighbour resampling with a cKDTree.
        # max_dist = .5 / np.sqrt(np.sum(np.power(epsf.oversampling, 2)))

        # star_points = list(zip(star._xidx_centered, star._yidx_centered))
        # epsf_points = list(zip(*map(np.ravel, self._epsf_xy_grid)))

        # star_tree = cKDTree(star_points)
        # dd, ii = star_tree.query(epsf_points, distance_upper_bound=max_dist)
        # mask = np.isfinite(dd)

        # star_data = np.full_like(epsf.data, np.nan)
        # star_data.ravel()[mask] = star._data_values_normalized[ii[mask]]

        star_points = list(zip(star._xidx_centered, star._yidx_centered))
        star_data = griddata(star_points, star._data_values_normalized, self._epsf_xy_grid)

        return star_data - epsf._data

    def __call__(self, *args, **kwargs):
        """Build the ePSF via photutils and attach timing/iteration info as ``epsf.fit_info``."""
        t0 = time.time()

        epsf, stars = super().__call__(*args, **kwargs)

        # NOTE(review): n_iter assumes photutils keeps one entry per iteration in self._epsf - confirm.
        epsf.fit_info = dict(n_iter=len(self._epsf), max_iters=self.maxiters, time=time.time() - t0, )

        return epsf, stars


def verify_epsf(epsf: photutils.psf.EPSFModel) -> Tuple[bool, List[float]]:
    """
    Check that an ePSF is a single, gaussian-like peak and measure its FWHM.

    For each axis, the ePSF data is summed along that axis. A cubic interpolating
    spline of (profile - peak / 2) is built, and its roots are the half-maximum
    crossings. An axis passes when there are exactly two crossings and the peak
    lies between them.

    Parameters:
        epsf: ePSF model; only its ``data`` array is used.

    Returns:
        tuple: (ok, fwhms). ok is False if any axis failed the check. fwhms holds
        the FWHM, in pixels of ``epsf.data``, for each axis that passed, so it may
        contain fewer than two values.
    """
    fwhms = []
    all_ok = True
    for axis in (0, 1):
        # Collapse the PSF along this axis:
        profile = epsf.data.sum(axis=axis)
        peak_index = profile.argmax()
        half_max = profile[peak_index] / 2

        # roots() is only implemented for cubic splines; s=0 interpolates the samples exactly.
        spline = UnivariateSpline(np.arange(0, len(profile)), profile - half_max, k=3, s=0, ext=3)
        crossings = spline.roots()

        single_peak = len(crossings) == 2 and crossings[0] <= peak_index <= crossings[1]
        if not single_peak:
            logger.error(f"EPSF is not a single gaussian-like peak along axis {axis}")
            all_ok = False
            continue

        fwhms.append(crossings[1] - crossings[0])
    return all_ok, fwhms

# fileio.py

In [17]:
import os
from pathlib import Path
from typing import Optional, Protocol, Dict, TypeVar, Union
from configparser import ConfigParser
from bottleneck import allnan
from tendrils import api, utils
#from .load_image import load_image
#from .utilities import create_logger
#from .target import Target
#from .image import FlowsImage
#from . import reference_cleaning as refclean
#from .filters import get_reference_filter
# Module-level logger for the file I/O helpers in this cell.
logger = create_logger()

# TypeVar bound to dict; presumably meant for datafile records (not used in the code visible here).
DataFileType = TypeVar("DataFileType", bound=dict)

class DirectoryProtocol(Protocol):
    """
    Structural interface for the directory helpers passed to IOManager.

    Implementations resolve input image paths and output file locations.
    Every member here only raises NotImplementedError.
    """
    archive_local: str
    output_folder: str

    @classmethod
    def from_fid(cls, fid: int) -> 'DirectoryProtocol':
        """Build an instance for the given file id."""
        raise NotImplementedError

    def set_output_dirs(self, target_name: str, fileid: int, create: bool = True) -> None:
        """Set (and optionally create) the output folder for a target and file id."""
        raise NotImplementedError

    def image_path(self, image_path: str) -> str:
        """Return the full path of an input image."""
        raise NotImplementedError

    def save_as(self, filename: str) -> str:
        """Return the path of an output file with the given name."""
        raise NotImplementedError

    @property
    def photometry_path(self) -> str:
        """Path of the photometry output file."""
        raise NotImplementedError

    @property
    def log_path(self) -> str:
        """Path of the log file."""
        raise NotImplementedError

class Directories:
    """
    Class for creating input and output directories, given a configparser instance.

    Overwrite archive_local to read data from elsewhere. While it is None, the
    archive is taken from the ``[photometry] archive_local`` config option.
    set_output_dirs always recomputes output_folder, from the
    ``output_folder_root`` argument or the ``[photometry] output`` option.
    """
    # Notebook-specific defaults; upstream both attributes default to None.
    archive_local = LOCAL_ARCHIEVE
    output_folder = FOLDER_OUTPUT

    def __init__(self, config: Optional[ConfigParser] = None):
        self.config = config if config else utils.load_config()

    def set_output_dirs(self, target_name: str, fileid: int, create: bool = True,
                        output_folder_root: Optional[str] = None) -> None:
        """
        Set the output folder for a target and file id, optionally creating it.

        The function is meant to be called from within a context where a
        target_name and fileid are defined, so that the output_folder
        can be created appropriately.

        Parameters:
            target_name (str): Target name.
            fileid (int): The fileid of the file being processed.
            create (bool): Whether to create the output_folder if it doesn't exist.
            output_folder_root (str): Overwrite the root directory for output.
        """
        # Only consult the config when no archive was declared manually.
        if self.archive_local is None:
            self.archive_local = self._set_archive()

        self.output_folder = self._set_output(target_name, fileid, output_folder_root)

        if not create:
            return
        os.makedirs(self.output_folder, exist_ok=True)
        logger.info("Placing output in '%s'", self.output_folder)

    def _set_archive(self) -> Optional[str]:
        """Read the archive path from the config; raise FileNotFoundError if it is not a directory."""
        path = self.config.get('photometry', 'archive_local', fallback=None)
        is_missing = path is not None and not os.path.isdir(path)
        if is_missing:
            raise FileNotFoundError("ARCHIVE is not available: " + path)
        logger.info("Using data from: {}.".format(path))
        return path

    def _set_output(self, target_name: str, fileid: int, output_folder_root: Optional[str] = None) -> str:
        """
        Directory for output: <root>/<target_name>/<fileid:05d>.

        The root defaults to the config option, or the current directory
        if the config is invalid or empty.
        """
        if output_folder_root is None:
            output_folder_root = self.config.get('photometry', 'output', fallback='.')
        return os.path.join(output_folder_root, target_name, f'{fileid:05d}')

    def image_path(self, image_path: str) -> str:
        """Return the path of an input image inside the archive."""
        return os.path.join(self.archive_local, image_path)

    @property
    def photometry_path(self) -> str:
        """Path of the photometry table in the output folder."""
        return self.save_as('photometry.ecsv')

    def save_as(self, filename: str) -> str:
        """Return the path of ``filename`` inside the output folder."""
        return os.path.join(self.output_folder, filename)

    @property
    def log_path(self) -> str:
        """Path of the log file in the output folder."""
        return self.save_as("photometry.log")

    @classmethod
    def from_fid(cls, fid: int, config: Optional[ConfigParser] = None, create: bool = True,
                 datafile: Optional[Dict] = None) -> DirectoryProtocol:
        """Build an instance for file id ``fid``, fetching the datafile from the API when not given."""
        instance = cls(config)
        if not datafile:
            datafile = api.get_datafile(fid)
        instance.set_output_dirs(datafile['target_name'], fid, create)
        return instance


class DirectoriesDuringTest:
    """
    Directory class in testing config.

    Images are read from ``input_dir`` and output goes to
    ``output_dir/<target_name>/<fileid:05d>``.
    """
    archive_local = None
    output_folder = None

    def __init__(self, input_dir: Union[str, Path], output_dir: Union[str, Path]):
        self.input_dir = input_dir
        self.output_dir = output_dir

    def set_output_dirs(self, target_name: str, fileid: int, create: bool = True) -> None:
        """Set ``output_folder`` for a target and file id, and optionally create it."""
        self.output_folder = os.path.join(self.output_dir, target_name, f'{fileid:05d}')
        if create:
            os.makedirs(self.output_folder, exist_ok=True)

    def image_path(self, image_path: str) -> str:
        """
        Return the path of ``image_path`` inside ``input_dir``.

        The parts are joined with the path separator; plain string concatenation
        would drop the separator and fail for ``Path`` inputs. A leading separator
        is stripped, so archive-style paths such as '/target/file.fits' stay
        inside input_dir.
        """
        return os.path.join(self.input_dir, str(image_path).lstrip('/\\'))

    @property
    def photometry_path(self) -> str:
        """Path of the photometry table in the output folder."""
        return os.path.join(self.output_folder, 'photometry.ecsv')

    def save_as(self, filename: str) -> str:
        """Return the path of ``filename`` inside the output folder."""
        return os.path.join(self.output_folder, filename)

    @property
    def log_path(self) -> str:
        """Path of the log file in the output folder."""
        return os.path.join(self.output_folder, "photometry.log")

    @classmethod
    def from_fid(cls, fid: int, input_dir: Union[str, Path] = './test/input',
                 output_dir: Union[str, Path] = './test/output', datafile: Optional[Dict] = None) -> 'DirectoryProtocol':
        """Build an instance for file id ``fid``, fetching the datafile from the API when not given."""
        instance = cls(input_dir, output_dir)
        datafile = datafile or api.get_datafile(fid)
        instance.set_output_dirs(datafile['target_name'], fid)
        return instance


class IOManager:
    """
    Implement a runner to shuffle data.

    Loads the science image, the optional difference image and the reference catalogue for one
    datafile, resolving image locations through ``directories.image_path``.
    """

    def __init__(self, target: Target,
                 directories: DirectoryProtocol,
                 datafile: Dict):
        self.target = target
        self.directories = directories
        self.datafile = datafile
        # Snapshot of the directory attributes at construction time.
        self.output_folder = directories.output_folder
        self.archive_local = directories.archive_local
        # Set by load_diff_image.
        self.diff_image_exists = False

    def _load_image(self, image_path: str) -> FlowsImage:
        """
        Load an image from a file.
        """
        full_path = self.directories.image_path(image_path)
        return load_image(full_path, target_coord=self.target.coords)

    def load_science_image(self, image_path: Optional[str] = None) -> FlowsImage:
        """Load the science image (``datafile['path']`` unless overridden) and tag it with file ids."""
        if not image_path:
            image_path = self.datafile['path']
        science = self._load_image(image_path)
        logger.info("Load image '%s'", self.directories.image_path(image_path))
        science.fid = self.datafile['fileid']
        template = self.datafile.get('template')
        science.template_fid = template['fileid'] if template is not None else None
        return science

    def get_filter(self):
        """Reference-catalogue filter matching the target's photometric filter."""
        return get_reference_filter(self.target.photfilter)

    def load_references(self, catalog: Optional[Dict] = None) -> refclean.References:
        """
        Return the reference stars sorted by the reference filter.

        Uses ``catalog['references']`` when a catalog is given, otherwise fetches it from the API.
        Raises ValueError when every reference value in that filter is NaN.
        """
        filter_name = self.get_filter()
        if catalog is None:
            catalog = api.get_catalog(self.target.name)
        ref_table = catalog['references']
        ref_table.sort(filter_name)
        # Check that there actually are reference stars in that filter:
        if allnan(ref_table[filter_name]):
            raise ValueError("No reference stars found in current photfilter.")
        return refclean.References(table=ref_table)

    def load_diff_image(self) -> Optional[FlowsImage]:
        """
        Load the difference image described by ``datafile['diffimg']``, if it has a path.

        Updates ``diff_image_exists``; returns None (with a warning when the entry exists but has
        no path) if no difference image can be loaded.
        """
        diff_entry = self.datafile.get('diffimg', None)
        diff_path = diff_entry.get('path', None) if diff_entry else None
        self.diff_image_exists = diff_path is not None
        if not self.diff_image_exists:
            if diff_entry:
                logger.warning("Diff image present but without path, skipping diff image photometry")
            return None
        diff = self._load_image(diff_path)
        logger.info("Load difference image '%s'", self.directories.image_path(diff_path))
        diff.fid = diff_entry['fileid']
        return diff

    @classmethod
    def from_fid(cls, fid: int, directories: Optional[DirectoryProtocol] = None,
                 create_directories: bool = True, datafile: Optional[Dict] = None) -> 'IOManager':
        """
        Create an IOManager from a fileid.

        The datafile and directories are fetched/built only when not supplied.
        """
        if not datafile:
            datafile = api.get_datafile(fid)
        target = Target.from_fid(fid=fid, datafile=datafile)
        if not directories:
            directories = Directories.from_fid(fid=fid, create=create_directories, datafile=datafile)
        return cls(target=target, directories=directories, datafile=datafile)


def del_dir(target: Union[Path, str],
            only_if_empty: bool = False,
            delete_parent_if_file: bool = False) -> None:
    """
    Delete a given directory and its subdirectories.
    Ex: If filename is the path to a file: `del_dir(Path(filename).parent)`,
    or equivalently `del_dir(filename, delete_parent_if_file=True)`.

    :param target: The directory to delete (``~`` is expanded).
    :param only_if_empty: Raise RuntimeError if any file (or symlink) is found in the tree.
        The check runs before anything is deleted, so nothing is removed when it raises.
    :param delete_parent_if_file: If ``target`` is a file, delete its parent directory instead.
        When False, a file ``target`` is logged and left untouched.
    :raises RuntimeError: If ``only_if_empty`` is True and the tree is not empty.
    """
    target = Path(target).expanduser()
    if not target.exists():
        logger.warning("Not deleted: Directory '%s' does not exist", target)
        return
    if not target.is_dir():
        if not delete_parent_if_file:
            logger.warning("Not deleted: '%s' is not a directory and delete_parent_if_file was False", target)
            return
        target = target.parent

    # Reverse-sorted so every entry is handled before its parent directory.
    entries = sorted(target.glob('**/*'), reverse=True)

    if only_if_empty:
        for p in entries:
            if p.is_symlink() or not p.is_dir():
                raise RuntimeError(f'{p.parent} is not empty!')

    for p in entries:
        if p.is_symlink():
            # Remove the link itself; never chmod through it or treat a linked directory as ours.
            p.unlink()
        elif not p.exists():
            continue
        elif p.is_dir():
            p.chmod(0o777)  # keep write/execute bits; the directory is empty by now
            p.rmdir()
        else:
            p.chmod(0o666)  # clear read-only flags (e.g. on Windows) before unlinking
            p.unlink()
    target.rmdir()


# background.py

In [18]:
from typing import Optional, Callable

import numpy as np
from astropy.stats import SigmaClip
from numpy.typing import ArrayLike
from photutils import Background2D, SExtractorBackground
from photutils.utils import calc_total_error

class FlowsBackground:
    """
    2D background model for an image.

    ``estimate_background`` fills ``background`` and ``background_rms``; ``background_subtract``
    and ``error`` use them.
    """

    def __init__(self, background_estimator: Callable[..., Background2D] = Background2D):
        # Honour the estimator passed in (it was previously ignored in favour of Background2D).
        # It is called with Background2D's signature in estimate_background.
        self.background_estimator = background_estimator
        self.background: Optional[ArrayLike] = None
        self.background_rms: Optional[ArrayLike] = None

    def estimate_background(self, clean_image: np.ma.MaskedArray) -> None:
        """
        Estimate the background and its RMS for ``clean_image`` and store them on the instance.

        Uses box size (128, 128), filter_size (5, 5), 3-sigma clipping, the SExtractorBackground
        statistic and exclude_percentile=50 (see photutils.Background2D for their meaning).
        """
        bkg2d = self.background_estimator(clean_image, (128, 128), filter_size=(5, 5),
                                          sigma_clip=SigmaClip(sigma=3.0), bkg_estimator=SExtractorBackground(),
                                          exclude_percentile=50.0)
        self.background = bkg2d.background
        self.background_rms = bkg2d.background_rms

    def background_subtract(self, clean_image: ArrayLike) -> ArrayLike:
        """Return ``clean_image`` minus the background, estimating the background first if needed."""
        if self.background is None:
            self.estimate_background(clean_image)
        return clean_image - self.background

    def error(self, clean_image: ArrayLike, error_method: Callable = calc_total_error):
        """
        Calculate the 2D error using the background RMS.

        ``error_method`` is called as ``error_method(clean_image, background_rms, 1.0)`` (an
        effective gain of 1.0 for photutils' calc_total_error).

        :raises AttributeError: If the background has not been estimated yet.
        """
        if self.background is None:
            raise AttributeError("background must be estimated before calling error")
        return error_method(clean_image, self.background_rms, 1.0)

# ----------!!! photometry.py code starts here !!! ----------#

In [19]:
# Checkpoint marker: prints a message once execution reaches this cell.
print("everything above works!")

everything above works!


In [20]:
"""
Flows photometry code.

.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
.. codeauthor:: Emir Karamehmetoglu <emir.k@phys.au.dk>
.. codeauthor:: Simon Holmbo <sholmbo@phys.au.dk>
"""

# Version string written into each results table's metadata (PhotometryManager.add_metadata).
__version__ = get_version(pep440=False)

# Type variable bound to photutils' background-estimator base class; used to annotate the
# ``bkg`` argument of Photometry.create_photometry_object.
PhotutilsBackground = TypeVar('PhotutilsBackground', bound=photutils.background.core.BackgroundBase)
# Module-level logger; functions and classes in this notebook that call ``logger`` resolve it at call time.
logger = create_logger()


# def get_datafile(fileid: int) -> Dict:
#     """
#     Get datafile from API, log it, return.
#     """
#     datafile = api.get_datafile(fileid)
#     logger.debug("Datafile: %s", datafile)
#     return datafile


# def get_catalog(targetid: int) -> Dict:
#     catalog = api.get_catalog(targetid, output='table')
#     logger.debug(f"catalog obtained for target: {targetid}")
#     return catalog


class PSFBuilder:
    """
    Build an effective PSF (ePSF) from reference-star cutouts of a science image.
    """
    init_cutout_size: int = 29  # cutout size (pixels) at FWHM == 6 px; scaled linearly with FWHM
    min_pixels: int = 15  # lower bound on the cutout size

    def __init__(self, image: FlowsImage, target: Target, fwhm_guess: float,
                 epsf_builder: FlowsEPSFBuilder = FlowsEPSFBuilder):
        self.image = image
        self.target = target
        self.fwhm = fwhm_guess
        # Builder class; instantiated in make_epsf.
        self.epsf_builder = epsf_builder

        # To be updated later.
        self.epsf = None

    # @TODO: Move to PhotometryMediator
    @property
    def star_size(self) -> int:
        """Odd cutout size in pixels, scaled with the FWHM and never below ``min_pixels``."""
        # Make cutouts of stars using extract_stars:
        # Scales with FWHM
        size = int(np.round(self.init_cutout_size * self.fwhm / 6))
        size += 0 if size % 2 else 1  # Make sure it's odd
        size = max(size, self.min_pixels)  # Never go below min_pixels
        return size

    def extract_star_cutouts(self, star_xys: np.ndarray):
        """
        Extract star cutouts from the background-subtracted image.

        :param star_xys: Star positions, one (x, y) pair per row.
        :returns: The photutils ``extract_stars`` result for the cutouts (``star_size + 6`` pixels wide).
        """
        # Extract stars from image
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', AstropyUserWarning)
            stars = extract_stars(NDData(data=self.image.subclean.data, mask=self.image.mask),
                                  Table(star_xys, names=('x', 'y')),
                                  size=self.star_size + 6  # +6 for edge buffer
                                  )
        logger.info("Number of stars input to ePSF builder: %d", len(stars))
        return stars

    def make_epsf(self, stars):
        """
        Make an ePSF from the star cutouts.

        :returns: ``(epsf, stars)`` as returned by the ePSF builder.
        """
        # Build ePSF
        logger.info("Building ePSF...")
        builder = self.epsf_builder(
            oversampling=1, shape=1 * self.star_size,
            fitter=EPSFFitter(fit_boxsize=max(int(np.round(1.5 * self.fwhm)), 5)),
            recentering_boxsize=max(int(np.round(2 * self.fwhm)), 5),
            norm_radius=max(self.fwhm, 5), maxiters=100,
            # Progress bar only in the main process and when logging at INFO (20) or more verbose.
            progress_bar=multiprocessing.parent_process() is None and logger.getEffectiveLevel() <= 20
        )
        epsf, stars = builder(stars)

        # Report iterations as "n/max" (previously the quotient was printed as a float).
        logger.info("Built PSF model %s/%s in %s seconds",
                    epsf.fit_info['n_iter'], epsf.fit_info['max_iters'], epsf.fit_info['time'])

        return epsf, stars


class Photometry:
    """
    Aperture and PSF photometry helpers.
    """

    def __init__(self, photometry_obj: Optional[BasicPSFPhotometry] = None):
        # PSF photometry callable; set by create_photometry_object or passed in directly.
        self.photometry_obj = photometry_obj

    @staticmethod
    def calculate_appflux(apphot_tbl: Table, apertures: CircularAperture, annuli: CircularAnnulus) -> Table:
        """
        Calculate the aperture flux for the given apertures and annuli and append result to table.

        Adds ``flux_aperture`` (aperture sum minus the annulus sum scaled to the aperture area)
        and ``flux_aperture_error`` (the two sum errors combined in quadrature, with the same scaling).
        """
        # Subtract background estimated from annuli:
        bkg = (apphot_tbl['aperture_sum_1'] / annuli.area) * apertures.area
        apphot_tbl['flux_aperture'] = apphot_tbl['aperture_sum_0'] - bkg

        apphot_tbl['flux_aperture_error'] = np.sqrt(apphot_tbl['aperture_sum_err_0'] ** 2 +
                                                    (apphot_tbl[
                                                         'aperture_sum_err_1'] / annuli.area * apertures.area) ** 2)
        return apphot_tbl

    def apphot(self, coordinates: ArrayLike, image: FlowsImage, fwhm: float, use_raw: bool = False) -> Table:
        """
        Aperture photometry with radius ``fwhm`` and a 1.5-2.5 ``fwhm`` background annulus.

        Uses ``image.clean`` when ``use_raw`` is True, otherwise ``image.subclean``.
        """
        img = image.clean if use_raw else image.subclean
        apertures = CircularAperture(coordinates, r=fwhm)
        annuli = CircularAnnulus(coordinates, r_in=1.5 * fwhm, r_out=2.5 * fwhm)
        apphot_tbl = aperture_photometry(img, [apertures, annuli], mask=image.mask, error=image.error)
        return self.calculate_appflux(apphot_tbl, apertures, annuli)

    def create_photometry_object(self, fwhm: Union[float, u.Quantity], psf_model: photutils.psf.EPSFModel,
                                 fitsize: Union[int, Tuple[int]], fitter: Optional[Callable] = None,
                                 bkg: PhotutilsBackground = MedianBackground()):
        """
        Create (and store) a BasicPSFPhotometry object.

        ``fitter`` defaults to a fresh ``LevMarLSQFitter`` per call. A default-argument instance
        would be shared, with its fit state, by every photometry object created here.
        """
        if fitter is None:
            fitter = fitting.LevMarLSQFitter()
        self.photometry_obj = BasicPSFPhotometry(group_maker=DAOGroup(fwhm), bkg_estimator=bkg, psf_model=psf_model,
                                                 fitter=fitter, fitshape=fitsize, aperture_radius=fwhm)

    def psfphot(self, image: ArrayLike, init_table: Table) -> Tuple[BasicPSFPhotometry, Table]:
        """PSF photometry on init guesses table/row.

        :raises ValueError: If create_photometry_object has not been called.
        """
        if self.photometry_obj is None:
            raise ValueError('Photometry object not initialized.')
        output: Table = self.photometry_obj(image=image, init_guesses=init_table)
        return self.photometry_obj, output

    @staticmethod
    def rescale_flux_error(phot_tables: Dict[int, Table],
                           flux: float, flux_err: float, exptime: float) -> Tuple[int, float]:
        """
        Rescale the error using input phot_tables dict with keys as input fit shapes and values
        as tables/rows.

        Each entry's ``flux_fit`` and ``flux_unc`` are divided by ``exptime``. The first fit shape
        whose scaled flux is ``<= flux + flux_err`` is chosen. ``flux`` and ``flux_err`` must
        therefore already be in the same per-exptime units.

        :returns: ``(fit_shape, new_err)``, or ``(0, flux_err)`` if no fit shape qualifies
            (including when ``phot_tables`` is empty).
        """
        for fit_shape, row in phot_tables.items():
            # Values may be whole tables (use their first row) or single rows.
            row = row[0] if isinstance(row, Table) else row
            new_err = row['flux_unc'] / exptime
            new_flux = row['flux_fit'] / exptime
            if new_flux <= flux + flux_err:
                return fit_shape, new_err
        logger.warning("Rescaled psf flux errors do not overlap input flux + error to 1 sigma, using original error.")
        return 0, flux_err

    @staticmethod
    def get_fit_shapes(fwhm: Union[float, int], star_size: int, fwhm_min: int = 2,
                       fwhm_max: int = 4) -> NDArray[np.int_]:
        """
        Odd fit sizes from ``int(fwhm_min * fwhm) - 1`` up to (exclusive) ``min(int(fwhm_max * fwhm), star_size)``.

        Returns just ``[star_size]`` when the star is smaller than ``fwhm_min`` FWHMs.
        """
        if star_size / fwhm < fwhm_min:
            return np.atleast_1d(np.array(star_size))
        fit_shapes = np.arange(int(fwhm_min * fwhm) - 1, min(int(fwhm_max * fwhm), star_size), 1)
        return fit_shapes[fit_shapes % 2 == 1]  # odd


class PhotometryManager:
    """
    Drive the photometry pipeline for a single science image.

    Holds the target, the science image (and optional difference image), the background model,
    the reference catalogue and the output directories. Each pipeline stage is a method. State
    produced by one stage (cleaned references, FWHM, ePSF, init guesses, results table) is stored
    on the instance for the following stages.
    """

    def __init__(self, target: Target, image: FlowsImage, bkg: FlowsBackground,
                 references: References, directories: DirectoryProtocol,
                 fwhm: Optional[float] = None,
                 psf_builder: Optional[PSFBuilder] = None,
                 cleaner: Optional[ReferenceCleaner] = None,
                 diffimage: Optional[FlowsImage] = None):
        self.target = target
        self.image = image
        self.bkg = bkg
        self.references = references
        self.directories = directories
        # Initially possibly None:
        self.fwhm = fwhm
        self.psf_builder = psf_builder
        self.cleaner = cleaner
        self.diffimage = diffimage
        # To be updated later.
        self.clean_references = references
        self.clean_references_with_diff = references
        self.init_guesses = None
        self.init_guesses_diff = None
        self.photometry = Photometry()
        self.diff_im_exists = diffimage is not None
        self.results_table = None

    def propogate_references(self):
        """Build sky coordinates for the references and propagate them to the image obstime."""
        self.references.make_sky_coords()  # Make sky coordinates
        self.references.propagate(self.image.obstime)  # get reference catalog at obstime

    def background_subtract(self):
        """Background-subtract the science image and attach the error map (shared with the diff image)."""
        self.image.subclean = self.bkg.background_subtract(self.image.clean)
        self.image.error = self.bkg.error(self.image.clean)
        if self.diff_im_exists:
            self.diffimage.error = self.image.error

    def recalculate_image_wcs(self, cm_timeout: float):
        """Correct the image WCS against the references (``cm_timeout`` is passed as ``timeout``)."""
        self.image = correct_wcs(self.image, self.references, target=self.target, timeout=cm_timeout)

    def calculate_pixel_coordinates(self):
        """Compute the references' pixel coordinates from the current image WCS."""
        # Calculate pixel-coordinates of references:
        self.references.get_xy(self.image.wcs)
        self.references.make_pixel_columns()

    def clean_reference_stars(self, rsq_min: float = 0.15):
        """
        Mask references near the target or image edge, then clean them.

        Sets ``cleaner``, ``clean_references`` and ``fwhm``.

        :raises RuntimeError: If no references survive the edge/target masking.
        """
        # Clean out the references:
        self.cleaner = ReferenceCleaner(self.image, self.references, rsq_min=rsq_min)
        # Reject references that are too close to target or edge of the image
        masked_references = self.cleaner.mask_edge_and_target(self.target.coords)
        if not masked_references.table:
            raise RuntimeError("No clean references in field")

        # Clean masked reference star locations
        self.clean_references, self.fwhm = self.cleaner.clean_references(masked_references)

    def create_psf_builder(self) -> PSFBuilder:
        """Default PSFBuilder for the current image, target and FWHM."""
        return PSFBuilder(self.image, self.target, self.fwhm)

    def update_reference_with_epsf(self, stars):
        """Add a ``used_for_epsf`` column flagging the references kept by the ePSF builder."""
        # Store which stars were used in ePSF in the table:
        self.clean_references.table.add_column(col=[False], name='used_for_epsf')
        # id_label is 1-based, table rows are 0-based.
        self.clean_references.table['used_for_epsf'][[star.id_label - 1 for star in stars.all_good_stars]] = True
        logger.info("Number of stars used for ePSF: %d", np.sum(self.clean_references.table['used_for_epsf']))

    def create_epsf(self, psf_builder: Optional[PSFBuilder] = None):
        """
        Build and verify the ePSF, then adopt the largest ePSF FWHM as the working FWHM.

        :raises RuntimeError: If ``verify_epsf`` rejects the ePSF.
        """
        if psf_builder is None:
            psf_builder = self.create_psf_builder()

        # EPSF creation
        star_cutouts = psf_builder.extract_star_cutouts(self.cleaner.gaussian_xys)
        epsf, stars = psf_builder.make_epsf(star_cutouts)
        epsf_ok, epsf_fwhms = verify_epsf(epsf)
        if not epsf_ok:
            raise RuntimeError("Bad ePSF detected.")
        psf_builder.epsf = epsf

        # Use the largest FWHM as new FWHM
        fwhm = np.max(epsf_fwhms)
        logger.info(f"Final FWHM based on ePSF: {fwhm}")
        psf_builder.fwhm = fwhm

        # Update state
        self.fwhm = fwhm
        self.psf_builder = psf_builder
        self.update_reference_with_epsf(stars)

    def prepare_target_and_references_for_photometry(self):
        """
        Add the target to the cleaned references and build PSF initial guesses.

        The target is prepended with starid 0 (target_row 0). With a difference image, a copy of
        the table gets the target prepended again with starid -1. Its init guesses use diff_row 0
        and target_row 1.
        """
        # position in the image including target as row 0:
        self.target.calc_pixels(self.image.wcs)
        self.clean_references.add_target(self.target, starid=0)  # by default prepends target
        self.init_guesses = InitGuess(self.clean_references, target_row=0)
        if self.diff_im_exists:
            # NOTE(review): ``copy`` is shallow; confirm add_target does not mutate a table that is
            # shared with self.clean_references.
            self.clean_references_with_diff = copy(self.clean_references)
            self.clean_references_with_diff.add_target(self.target, starid=-1)
            # Only built with a diff image. Without one, clean_references_with_diff still points
            # at the uncleaned input references.
            self.init_guesses_diff = InitGuess(self.clean_references_with_diff, target_row=1, diff_row=0)

    def apphot(self) -> Table:
        """Aperture photometry of target + references; the diff-image target row is prepended if present."""
        apphot_tbl = self.photometry.apphot(self.clean_references.xy, self.image, self.fwhm)
        if self.diff_im_exists:
            # Add diff image photometry:
            diff_tbl = self.photometry.apphot(self.clean_references.xy[0], self.diffimage, self.fwhm, use_raw=True)
            apphot_tbl.insert_row(0, dict(diff_tbl[0]))
        return apphot_tbl

    def psfphot(self, fit_shape: Optional[Union[int, Tuple[int, int]]] = None) -> Table:
        """PSF photometry with ``fit_shape`` (default: ePSF star size); diff-image target row prepended if present."""
        fit_shape = self.psf_builder.star_size if fit_shape is None else fit_shape
        # PSF photometry:
        self.photometry.create_photometry_object(fwhm=self.fwhm, psf_model=self.psf_builder.epsf, fitsize=fit_shape)
        psfphot_tbl = self.raw_psf_phot()
        if self.diff_im_exists:
            # Add diff image photometry:
            diff_tbl = self.diff_psf_phot()
            psfphot_tbl.insert_row(0, dict(diff_tbl[0]))
        return psfphot_tbl

    def raw_psf_phot(self, init_guess: Optional[Table] = None) -> Table:
        """PSF photometry on the background-subtracted science image (default: all init guesses)."""
        init_guess = self.init_guesses.init_guess_full if init_guess is None else init_guess
        _, psf_tbl = self.photometry.psfphot(image=self.image.subclean, init_table=init_guess)
        return psf_tbl

    def diff_psf_phot(self) -> Table:
        """PSF photometry of the target on the difference image."""
        _, psf_tbl = self.photometry.psfphot(image=self.diffimage.clean,
                                             init_table=self.init_guesses_diff.init_guess_diff)
        return psf_tbl

    def rescale_uncertainty(self, psfphot_tbl: Table, dynamic: bool = True,
                            static_fwhm: float = 2.5, epsilon_mag: float = 0.004,
                            ensure_greater: bool = True):
        """
        Rescale the uncertainty of the PSF photometry using a variable fitsize.

        Parameters
        ----------
        psfphot_tbl : Table
            Table of PSF photometry; row 0 is the target (the diff-image target if a diff image exists).
        dynamic : bool
            Dynamically decide FWHM multiple for rescaling.
        static_fwhm : float
            FWHM multiple to use incase dynamic fails or don't want to use it. Default 2.5 determined empirically.
        epsilon_mag : float
            Small magnitude change within which new and old uncertainties are considered the same.
            Should be smaller than ~1/2 the expected uncertainty.
        ensure_greater : bool
            If the rescaled target uncertainty is smaller than the original by more than
            ``epsilon_mag``, keep the original target uncertainty and rescale only the other rows.
        """
        # Rescale psf errors from fit iteratively
        fit_shapes = self.photometry.get_fit_shapes(self.fwhm, self.psf_builder.star_size)
        fit_shape = int(static_fwhm * self.fwhm)
        fit_shape = fit_shape if fit_shape % 2 == 1 else fit_shape + 1
        if dynamic and len(fit_shapes) > 1:
            _phot_tables_dict = {}
            for fitshape in fit_shapes:
                self.photometry.create_photometry_object(
                    fwhm=self.fwhm, psf_model=self.psf_builder.epsf, fitsize=fitshape)
                # Fit the same source as row 0 of psfphot_tbl: the diff-image target when a diff
                # image exists, otherwise the science-image target. (Previously the diff result
                # was computed and then immediately overwritten.)
                if self.diff_im_exists:
                    _table = self.diff_psf_phot()
                else:
                    _table = self.raw_psf_phot(self.init_guesses.init_guess_target)
                if "flux_unc" in _table.colnames:
                    _phot_tables_dict[fitshape] = _table

            if len(_phot_tables_dict) == 0:
                logger.warning("No PSF errors found for dynamic rescaling, trying static.")
                dynamic = False
            else:
                # Find the fit shape elbow. rescale_flux_error divides the trial fluxes by exptime,
                # so the reference flux and error must be in the same per-exptime units.
                exptime = self.image.exptime
                flux = psfphot_tbl[0]['flux_fit'] / exptime
                flux_err = psfphot_tbl[0]['flux_unc'] / exptime
                dynamic_fit_shape, _ = self.photometry.rescale_flux_error(_phot_tables_dict, flux, flux_err,
                                                                          exptime)
                fit_shape = dynamic_fit_shape if dynamic_fit_shape != 0 else fit_shape

        # Recalculate all reference uncertainties using new fitsize:
        logger.info(f"Recalculating all reference uncertainties using new fitsize:"
                    f" {fit_shape} pixels, ({fit_shape/self.fwhm if dynamic else static_fwhm :.2} * FWHM).")
        psfphot_tbl_rescaled = self.psfphot(fit_shape)
        # epsilon_mag is a magnitude tolerance; convert it to flux units at the target flux using
        # sigma_mag ~= (2.5 / ln 10) * sigma_flux / flux before comparing flux uncertainties.
        epsilon_flux = np.abs(psfphot_tbl['flux_fit'][0]) * epsilon_mag / (2.5 / np.log(10))
        if ensure_greater and psfphot_tbl['flux_unc'][0] > psfphot_tbl_rescaled['flux_unc'][0] + epsilon_flux:
            logger.info("Recalculated uncertainties were smaller than original and ``ensure_greater`` was True: "
                        "Not using rescaled uncertainties for the SN.")
            psfphot_tbl['flux_unc'][1:] = psfphot_tbl_rescaled['flux_unc'][1:]
            return psfphot_tbl

        psfphot_tbl['flux_unc'] = psfphot_tbl_rescaled['flux_unc']
        return psfphot_tbl

    def make_result_table(self, psfphot_tbl: Table, apphot_tbl: Table):
        """Build ``results_table`` from the reference table and the aperture/PSF photometry tables."""
        # Build results table:
        clean_references = self.clean_references_with_diff if self.diff_im_exists else self.clean_references
        self.results_table = ResultsTable.make_results_table(clean_references.table, apphot_tbl, psfphot_tbl,
                                                             self.image)

    def calculate_mag(self, make_plot: bool = False) -> Tuple[Optional[plt.Figure], Optional[plt.Axes]]:
        """
        Compute instrumental magnitudes into ``results_table``; returns the (figure, axes) from instrumental_mag.

        :raises ValueError: If ``make_result_table`` has not been run.
        """
        if self.results_table is None:
            raise ValueError("Results table is not initialized. Run photometry first.")
        # Todo: refactor.
        # Get instrumental magnitude (currently we do too much here).
        results_table, mag_fig, mag_ax = instrumental_mag(self.results_table, self.target, make_plot)
        self.results_table = results_table
        return mag_fig, mag_ax

    def calculate_pixel_scale(self):
        """Science-image pixel scale in arcsec/pixel (square root of the celestial pixel area)."""
        # Find the pixel-scale of the science image
        pixel_area = proj_plane_pixel_area(self.image.wcs.celestial)
        pixel_scale = np.sqrt(pixel_area) * 3600  # arcsec/pixel
        logger.info("Science image pixel scale: %f", pixel_scale)
        return pixel_scale

    def add_metadata(self):
        """Write file ids, version, filter, FWHM, pixel scale, seeing, obstime and WCS into the table meta."""
        # Add metadata to the results table:
        self.results_table.meta['fileid'] = self.image.fid
        self.results_table.meta['target_name'] = self.target.name
        self.results_table.meta['version'] = __version__
        self.results_table.meta['template'] = self.image.template_fid
        self.results_table.meta['diffimg'] = self.diffimage.fid if self.diff_im_exists else None
        self.results_table.meta['photfilter'] = self.target.photfilter
        self.results_table.meta['fwhm'] = self.fwhm * u.pixel
        pixel_scale = self.calculate_pixel_scale()
        self.results_table.meta['pixel_scale'] = pixel_scale * u.arcsec / u.pixel
        self.results_table.meta['seeing'] = (self.fwhm * pixel_scale) * u.arcsec
        self.results_table.meta['obstime-bmjd'] = float(self.image.obstime.mjd)
        self.results_table.meta['used_wcs'] = str(self.image.wcs)

    @classmethod
    def create_from_fid(cls, fid: int, directories: Optional[DirectoryProtocol] = None,
                        create_directories: bool = True, datafile: Optional[Dict] = None) -> 'PhotometryManager':
        """
        Create a Photometry object from a fileid.

        Loads the science image, references and (optional) difference image via IOManager.
        Uses ``cls`` so subclasses get instances of themselves.
        """
        io = IOManager.from_fid(fid, directories=directories, create_directories=create_directories, datafile=datafile)
        return cls(target=io.target, image=io.load_science_image(), bkg=FlowsBackground(),
                   references=io.load_references(), directories=io.directories,
                   diffimage=io.load_diff_image())


def do_phot(fileid: int, cm_timeout: Optional[float] = None, make_plots: bool = True,
            directories: Optional[DirectoryProtocol] = None, datafile: Optional[Dict[str, Any]] = None,
            rescale: bool = True, rescale_dynamic: bool = True) -> ResultsTable:
    """
    Run the full photometry pipeline on one file and return the manager's results table.

    ``cm_timeout`` of None means no timeout (``np.inf``) for the WCS correction. When ``rescale``
    is True the PSF uncertainties are rescaled (dynamically if ``rescale_dynamic``). Diagnostic
    figures are written when ``make_plots`` is True.
    """
    manager = PhotometryManager.create_from_fid(fileid, directories=directories, datafile=datafile,
                                                create_directories=True)
    wcs_timeout = np.inf if cm_timeout is None else cm_timeout

    # Preparation: references at obstime, background, WCS, reference cleaning, ePSF, target row.
    manager.propogate_references()
    manager.background_subtract()
    manager.recalculate_image_wcs(wcs_timeout)
    manager.calculate_pixel_coordinates()
    manager.clean_reference_stars()
    manager.create_epsf()
    manager.prepare_target_and_references_for_photometry()

    # Photometry; the PSF table must carry an uncertainty column before any rescaling.
    aperture_table = manager.apphot()
    psf_table = ResultsTable.verify_uncertainty_column(manager.psfphot())
    if rescale:
        psf_table = manager.rescale_uncertainty(psf_table, dynamic=rescale_dynamic)

    # Results table, magnitudes and metadata.
    manager.make_result_table(psf_table, aperture_table)
    figure, axes = manager.calculate_mag(make_plot=make_plots)
    manager.add_metadata()

    if make_plots:
        do_plots(manager, figure, axes)
    return manager.results_table


def timed_photometry(fileid: int, cm_timeout: Optional[float] = None, make_plots: bool = True,
                     directories: Optional[DirectoryProtocol] = None, save: bool = True,
                     datafile: Optional[Dict[str, Any]] = None, rescale: bool = True,
                     rescale_dynamic: bool = True) -> ResultsTable:
    """
    Run ``do_phot`` for one file, optionally save the results, and log the outcome and runtime.

    When ``save`` is True and no ``directories`` are given, they are built here with
    ``Directories.from_fid`` and passed on to ``do_phot``. Otherwise ``do_phot`` would build its
    own internally and the save step would have no ``photometry_path`` to write to. The datafile
    is fetched once and shared.
    """
    # TODO: Timer should be moved out of this function.
    tic = default_timer()
    if save and directories is None:
        datafile = datafile or api.get_datafile(fileid)
        directories = Directories.from_fid(fid=fileid, create=True, datafile=datafile)
    results_table = do_phot(fileid, cm_timeout, make_plots, directories, datafile, rescale, rescale_dynamic)

    # Save the results table:
    if save:
        results_table.write(directories.photometry_path, format='ascii.ecsv', delimiter=',', overwrite=True)

    # Log result and time taken:
    logger.info("------------------------------------------------------")
    logger.info("Success!")
    logger.info("Main target: %f +/- %f", results_table[0]['mag'], results_table[0]['mag_error'])
    logger.info("Photometry took: %.1f seconds", default_timer() - tic)
    return results_table


def do_plots(pm: PhotometryManager, mag_fig: plt.Figure, mag_ax: plt.Axes):
    """
    Save diagnostic figures for a finished photometry run into ``pm``'s output folder.

    Writes original.png, background.png, positions.png and epsf.png. It also writes diffimg.png
    and diffimg_zoom.png when a difference image exists, and calibration.png when ``mag_fig`` is
    not None. ``mag_ax`` is accepted for interface compatibility but is not used here.
    """
    # Plot the image:
    fig, ax = plt.subplots(1, 2, figsize=(20, 18))
    plot_image(pm.image.clean, ax=ax[0], scale='log', cbar='right', title='Image')
    plot_image(pm.image.mask, ax=ax[1], scale='linear', cbar='right', title='Mask')
    fig.savefig(pm.directories.save_as('original.png'), bbox_inches='tight')
    plt.close(fig)

    # Plot background estimation:
    fig, ax = plt.subplots(1, 3, figsize=(20, 6))
    plot_image(pm.image.clean, ax=ax[0], scale='log', title='Original')
    plot_image(pm.bkg.background, ax=ax[1], scale='log', title='Background')
    plot_image(pm.image.subclean, ax=ax[2], scale='log', title='Background subtracted')
    fig.savefig(pm.directories.save_as('background.png'), bbox_inches='tight')
    plt.close(fig)

    # Create plot of target and reference star positions:
    # red circles = all references, green squares = cleaned Gaussian positions, red + = target.
    fig, ax = plt.subplots(1, 1, figsize=(20, 18))
    plot_image(pm.image.subclean, ax=ax, scale='log', cbar='right', title=pm.target.name)
    ax.scatter(pm.references.table['pixel_column'], pm.references.table['pixel_row'], c='r', marker='o', alpha=0.6)
    ax.scatter(pm.cleaner.gaussian_xys[:, 0], pm.cleaner.gaussian_xys[:, 1], marker='s', alpha=0.6, edgecolors='green',
               facecolors='none')
    ax.scatter(pm.target.pixel_column, pm.target.pixel_row, marker='+', s=20, c='r')
    fig.savefig(pm.directories.save_as('positions.png'), bbox_inches='tight')
    plt.close(fig)

    # Plot EPSF image (top-left) with its summed profiles along each axis (top-right, bottom-left).
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 15))
    plot_image(pm.psf_builder.epsf.data, ax=ax1, cmap='viridis')
    for a, ax in ((0, ax3), (1, ax2)):
        profile = pm.psf_builder.epsf.data.sum(axis=a)
        ax.plot(profile, 'k.-')
        ax.axvline(profile.argmax())
        ax.set_xlim(-0.5, len(profile) - 0.5)
    ax4.axis('off')
    fig.savefig(pm.directories.save_as('epsf.png'), bbox_inches='tight')
    plt.close(fig)
    del ax, ax1, ax2, ax3, ax4

    # Create two plots of the difference image (full and a +/-50 pixel zoom on the target):
    if pm.diff_im_exists:
        fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(20, 20))
        plot_image(pm.diffimage.clean, ax=ax, cbar='right', title=pm.target.name)
        ax.plot(pm.target.pixel_column, pm.target.pixel_row, marker='+', markersize=20, color='r')
        fig.savefig(pm.directories.save_as('diffimg.png'), bbox_inches='tight')
        ax.set_xlim(pm.target.pixel_column - 50, pm.target.pixel_column + 50)
        ax.set_ylim(pm.target.pixel_row - 50, pm.target.pixel_row + 50)
        fig.savefig(pm.directories.save_as('diffimg_zoom.png'), bbox_inches='tight')
        plt.close(fig)

    # Calibration (handled in magnitudes.py). The figure is Optional upstream, so guard against None.
    if mag_fig is not None:
        mag_fig.savefig(pm.directories.save_as('calibration.png'), bbox_inches='tight')
        plt.close(mag_fig)

In [None]:
# Placeholder so this cell is valid Python and Restart & Run All can continue past it.
# TODO: build the manager for a concrete file id before using it, e.g.
#   photometry_manager = PhotometryManager.create_from_fid(<fileid>)
photometry_manager = None

In [21]:
# Checkpoint marker: prints a message once execution reaches this cell.
print("Done")

Done


In [22]:
#from .fileio import DirectoryProtocol, IOManager  # noqa: E402
#from .background import FlowsBackground  # noqa: E402


In [23]:
#import sys
#sys.path.append("/flows") # go to parent dir


In [24]:
#from utilities import create_logger
#from filters import get_reference_filter
#import magnitudes


In [25]:
#from magnitudes import instrumental_mag
