# Notebook for the manuscript: Finding AGN remnants with machine learning

In [None]:
import numpy as np
import pandas as pd
import scipy.stats as st
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import DetCurveDisplay, RocCurveDisplay,plot_roc_curve,plot_precision_recall_curve
from dataclasses import dataclass
import dataclasses
import matplotlib.pyplot as plt
import matplotlib
import mahotas as mh
matplotlib.rcParams.update({'font.size': 16})
from astropy.wcs import WCS
import seaborn as sns
from copy import deepcopy
import hdbscan
import os
import sys
import pyvo
%load_ext autoreload
%autoreload 2
# Insert path to pinklib.postprocessing.py for support functions
# Library available here:
sys.path.insert(0, '<path_to_pinklib_folder>')
import postprocessing as post
from pinklib.postprocessing import CutoutSettings
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier,GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.model_selection import permutation_test_score
from sklearn.inspection import permutation_importance
from joblib import Parallel, logger
from sklearn.base import is_classifier, clone
from sklearn.utils import indexable, check_random_state, _safe_indexing
from sklearn.utils.validation import _check_fit_params
from sklearn.utils.fixes import delayed
from sklearn.utils.metaestimators import _safe_split
from sklearn.metrics import check_scoring
from sklearn.model_selection._split import check_cv
from sklearn.metrics import recall_score, make_scorer, SCORERS, fbeta_score

from sklearn.utils import class_weight
import time
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score

# Settings

In [None]:
#### Input/output locations and file-name conventions
# Directory for the figures that go into the paper
paper_fig_dir = 'figures'
# Directory holding the catalogue files, and the root for everything this run writes
cat_dir = 'catalogues'
data_directory = 'remnants'

# Component, gaussian and value-added catalogue files (HDF5)
comp_path = os.path.join(cat_dir, 'LOFAR_HBA_T1_DR1_merge_ID_v1.2.comp.h5')
gaul_path = os.path.join(cat_dir, 'LOFAR_HBA_T1_DR1_catalog_v0.99.gaus.h5')
vac_path = os.path.join(cat_dir, 'LOFAR_HBA_T1_DR1_merge_ID_optical_f_v1.2.h5')
# Path to directory that contains fits-files of all Stokes-I image pointings.
LoTSS_DR2_dir = 'LoTSS_DR2/RA0h_field'

# Every catalogue is stored under the HDF5 key 'df'
comp_cat = pd.read_hdf(comp_path, 'df')
gaul_cat = pd.read_hdf(gaul_path, 'df')
value_added_catalogue = pd.read_hdf(vac_path, 'df')

# File name conventions used
################################################
# Base names (without extension) of the fits files
fits_filename = 'mosaic-blanked'
fits_rms_filename = 'mosaic.rms'
fits_new_catalogue_filename = 'mosaic.cat_new'
# Name of the trained SOM subdirectory
trained_subdirectory = None
# Output file names
cutouts_filename = 'LoTSS_DR2_value_added_variable'
cutouts_filename_final = f'{cutouts_filename}_final'
rms_cutouts_filename = 'LoTSS_DR2_value_added_variable_rms'
rms_cutouts_filename_final = f'{rms_cutouts_filename}_final'

# Output tree, where your trained SOM and your SOM-mapping should reside
output_directory = os.path.join(data_directory, 'output')
figures_dir = os.path.join(output_directory, 'figures')
map_dir = os.path.join(output_directory, 'maps')
run_dir = os.path.join(data_directory, 'run')
# Create the output tree up front; existing directories are left untouched
for _required_dir in (output_directory, map_dir, run_dir, figures_dir):
    os.makedirs(_required_dir, exist_ok=True)

# Cut-out parameters
################################################
# Resolution for each cutout in degrees
# Note: we already hit bedrock resolution (the highest res. hips-files) so increasing it won't help.
resolution = 0.0085/60
apply_clipping = False
# Enable scaling?
log_scaling = False
sinh_scaling = False
rotated_size_arcsec = 100
rotated_size = 67  # fits-pixels
# Side (in fits-pixels) of the cutout taken before rotation: rotated_size * sqrt(2),
# rounded up. Derived rather than hard-coded so it cannot drift away from
# rotated_size; for rotated_size = 67 this evaluates to 95.
fullsize = int(np.ceil(rotated_size * np.sqrt(2)))
test_fraction = 0.3
random_seed = 42
lower_sigma_limit = 1.5  # Lower clip bound is defined as local rms * lower_sigma_limit
# Name of the binary file
cutouts_bin_name = f'cutouts_preprocessed_rotatedsize_{rotated_size_arcsec}arcsec'
################################################
gpu_id = 0

# Trained SOM parameters
################################################
layout = 'quadratic'
som_label = '- 10x10 cyclic'
number_of_channels = 1
som_width, som_height, som_depth = 10, 10, 1
################################################

# Catalogue parameters
################################################
mosaic_id_key, ra_key, dec_key = 'Mosaic_ID', 'RA', 'DEC'
################################################

# Outliers parameters
################################################
debug = False
number_of_outliers_to_show = 100
max_number_of_images_to_show = 10
#####################################



# Pull the LGZ and fitted shape columns out of the value-added catalogue,
# replacing missing entries by 0 so the two flavours can simply be added.
lgz_size, lgz_width, lgz_PA, maj_size, min_size, PA = (
    value_added_catalogue[_column].fillna(0)
    for _column in ('LGZ_Size', 'LGZ_Width', 'LGZ_PA', 'Maj', 'Min', 'PA')
)

# Combined size / width / position-angle columns (LGZ value + fitted value).
# NOTE(review): this is a plain sum, so a row carrying both an LGZ value and
# a fitted value gets the two added together — confirm the catalogue fills
# only one of them per source.
for _target, _lgz_part, _fit_part in (
        ('source_size', lgz_size, maj_size),
        ('source_width', lgz_width, min_size),
        ('source_PA', lgz_PA, PA)):
    value_added_catalogue[_target] = (_lgz_part + _fit_part).astype(float)

# Load a single pointing to read the pixel scale from its header.
# The path is built from LoTSS_DR2_dir so that changing the data location in
# the settings above is enough (the directory used to be hard-coded here).
fits_path_LoTSS = os.path.join(LoTSS_DR2_dir, 'P206+50', fits_filename+'.fits')
image, hdr = post.load_fits(fits_path_LoTSS, dimensions_normal=False)
wcs = WCS(hdr,naxis=2)
# |CDELT1| times 3600 is reported below as arcsec per pixel
angular_resolution_LoTSS = abs(hdr['CDELT1'])*3600
print(f'Angular resolution: {angular_resolution_LoTSS:.2f} arcsec/pixels')

# Boolean mask of sources larger than 60 arcsec; computed once and reused
# for both the printed count and the sub-catalogue below.
_larger_than_60arcsec = value_added_catalogue['source_size'] > 60
print("# sources in HETDEX bigger than 60\":",
      np.sum(_larger_than_60arcsec))

# Load Marisa's sample of visually inspected AGN remnant candidates
# (one source name per line, no header row)
remnant_names = pd.read_csv('datasets/151remnants-candidates-hetdex.txt',
                            names=['Source_Name'],header=None).Source_Name.values
print(f"We have {len(remnant_names)} remnants.")
hetdex_bigger_than60arcsec = value_added_catalogue[_larger_than_60arcsec]
# Catalogue rows whose Source_Name appears in the remnant-candidate list
calibration_set = value_added_catalogue[value_added_catalogue.Source_Name.isin(remnant_names)]

# Adjust random forest code such that no data leakage occurs
The cell below makes sure that the SOM remnant ratio and the Haralick cluster ratios
do not leak between training and validation and between training and test data.

In [None]:
@dataclasses.dataclass
class plotSetting:
    """Container for plot options.

    ``normalize`` and ``zoom_in`` are required; the three list fields are
    optional and each instance gets its own fresh empty list.
    """
    # Required flags (how they are applied is decided by the plotting code)
    normalize: bool
    zoom_in: bool
    # Optional lists; default_factory avoids sharing one list between instances
    highlight_neurons: list = dataclasses.field(default_factory=list)
    highlight_colors: list = dataclasses.field(default_factory=list)
    legend_list: list = dataclasses.field(default_factory=list)


def insert_local_remnant_ratios(features_original, labels_, debug=False, som_shape=(5, 5)):
    """Replace cluster ids in a feature matrix by remnant statistics of that subset.

    The statistics are computed only from the rows passed in, so calling this
    on a training subset keeps information from other subsets out of it.

    Parameters
    ----------
    features_original : 2-D numpy array
        One row per source. Column 0 is read as an integer best-matching-neuron
        id and the last column as an integer Haralick cluster label; column 1
        is overwritten. The input itself is not modified (a deep copy is used).
    labels_ : iterable
        One truthy/falsy entry per row; truthy marks a remnant.
    debug : bool, default=False
        If True, show diagnostic plots of the computed counts and ratios.
    som_shape : tuple of int, default=(5, 5)
        Shape used to lay out the per-neuron debug images; only used when
        ``debug`` is True.

    Returns
    -------
    features : 2-D numpy array
        Copy of the input with column 0 = number of remnants mapped to the
        row's neuron, column 1 = that number divided by all sources mapped to
        the neuron, last column = remnant fraction of the row's Haralick
        cluster.
    remnant_dict, all_dict : dict
        Per-neuron remnant count and total count.
    remnant_dict_hard, all_dict_hard : dict
        Per-Haralick-cluster remnant count and total count.
    """
    features_ = deepcopy(features_original)

    ################################# SOM neuron counts and ratios
    bmn_train = [int(feat[0]) for feat in features_]
    bmn_train_remnants = [bmn for (bmn, label) in zip(bmn_train, labels_) if label]
    all_dict = dict(zip(*np.unique(bmn_train, return_counts=True)))
    remnant_dict = dict(zip(*np.unique(bmn_train_remnants, return_counts=True)))

    remnants_per_mapped_to_neuron = np.array(
        [remnant_dict.get(key, 0) for key in bmn_train])
    remnants_per_mapped_to_neuron_ratio = np.array(
        [remnant_dict.get(key, 0)/all_dict.get(key, 1) for key in bmn_train])

    # The transpose is a view on the copy, so row assignment writes columns
    features_T = features_.T
    features_T[0] = remnants_per_mapped_to_neuron
    features_T[1] = remnants_per_mapped_to_neuron_ratio

    ################################# HARALICK Ratios
    # Haralick cluster labels for the whole set and for the remnants only
    hara_hard_train = [int(feat[-1]) for feat in features_]
    hara_hard_train_remnants = [hara_label for (hara_label, label)
                                in zip(hara_hard_train, labels_) if label]
    # Number of sources per cluster, for the remnants and for the whole set
    remnant_dict_hard = dict(zip(*np.unique(hara_hard_train_remnants,
                                            return_counts=True)))
    all_dict_hard = dict(zip(*np.unique(hara_hard_train, return_counts=True)))

    # Remnant fraction of the cluster each row belongs to
    hara_hard_ratios = [remnant_dict_hard.get(key, 0)/all_dict_hard.get(key, 1)
                        for key in hara_hard_train]
    features_T[-1] = hara_hard_ratios

    if debug:
        # The former 'soft' Haralick plot was dropped: the dictionaries it
        # used were never defined, so debug=True always raised a NameError.
        n_neurons = int(np.prod(som_shape))
        abs_bmn_debug = np.array([remnant_dict.get(key, 0) for key in range(n_neurons)])
        rel_bmn_debug = np.array([remnant_dict.get(key, 0)/all_dict.get(key, 1)
                                  for key in range(n_neurons)])
        hara_keys = list(range(-1, 6))
        hara_hard_debug = [remnant_dict_hard.get(key, 0)/all_dict_hard.get(key, 1)
                           for key in hara_keys]

        plt.imshow(abs_bmn_debug.reshape(som_shape))
        plt.title('SOM remnant')
        plt.show()
        plt.imshow(rel_bmn_debug.reshape(som_shape))
        plt.title('SOM remnant ratio')
        plt.show()
        plt.bar(hara_keys, hara_hard_debug)
        plt.title('Hara hard remnant ratio')
        plt.show()
    return features_T.T, remnant_dict, all_dict, remnant_dict_hard, all_dict_hard



def custom_permutation_test_score(estimator,X,y,*,groups=None,cv=None,n_permutations=100,
    n_jobs=-1,random_state=0,verbose=0,scoring=None,fit_params=None,):
    """Permutation test for the significance of a cross-validated score.

    The estimator is first scored with the true targets. It is then scored
    ``n_permutations`` times with shuffled targets. In every run the true
    targets are also handed to the per-fold worker, separately from the
    (possibly shuffled) targets used for fitting and scoring.

    Returns
    -------
    score : float
        Cross-validated score obtained with the true targets.
    permutation_scores : ndarray of shape (n_permutations,)
        Scores obtained with shuffled targets.
    pvalue : float
        ``(C + 1) / (n_permutations + 1)`` where ``C`` is the number of
        permutations scoring at least as well as ``score``. A small value
        suggests a real dependency between features and targets.
    """
    X, y, groups = indexable(X, y, groups)

    splitter = check_cv(cv, y, classifier=is_classifier(estimator))
    score_fn = check_scoring(estimator, scoring=scoring)
    rng = check_random_state(random_state)

    # Reference run on the unpermuted targets. A clone is used so that all
    # runs are independent and the estimator can be pickled.
    score = _custom_permutation_test_score(
        clone(estimator), X, y, y, groups, splitter, score_fn, fit_params=fit_params
    )

    # One job per permutation; the shuffled targets are drawn here, in the
    # calling process, as the job generator is consumed.
    run_one = delayed(_custom_permutation_test_score)
    jobs = (
        run_one(clone(estimator), X, y, _shuffle(y, groups, rng), groups,
                splitter, score_fn, fit_params=fit_params)
        for _ in range(n_permutations)
    )
    permutation_scores = np.array(Parallel(n_jobs=n_jobs, verbose=verbose)(jobs))

    at_least_as_good = np.sum(permutation_scores >= score)
    pvalue = (at_least_as_good + 1.0) / (n_permutations + 1)
    return score, permutation_scores, pvalue

def insert_local_test_ratios(features_test, remnant_dict, all_dict,
                             remnant_dict_hard, all_dict_hard):
    """Fill a feature matrix with remnant statistics taken from given lookup tables.

    Column 0 (read as integer neuron id) becomes the remnant count of that
    neuron, column 1 becomes remnant count / total count of that neuron, and
    the last column (read as integer Haralick cluster label) becomes the
    remnant fraction of that cluster. Ids missing from the tables give a
    count of 0 and a denominator of 1. The input is deep-copied, not modified.
    """
    filled = deepcopy(features_test)

    # Neuron ids have to be read before column 0 is overwritten
    neuron_ids = [int(row[0]) for row in filled]

    # Transposed view on the copy: assigning a row here writes a column
    columns = filled.T
    columns[0] = [remnant_dict.get(neuron, 0) for neuron in neuron_ids]
    columns[1] = [remnant_dict.get(neuron, 0)/all_dict.get(neuron, 1)
                  for neuron in neuron_ids]

    ################################# HARALICK Ratios
    cluster_ids = [int(row[-1]) for row in filled]
    columns[-1] = [remnant_dict_hard.get(cluster, 0)/all_dict_hard.get(cluster, 1)
                   for cluster in cluster_ids]

    return columns.T

def _custom_permutation_test_score(estimator, X, real_y, y, groups, cv, scorer, fit_params):
    """Auxiliary function for custom_permutation_test_score.

    For every cross-validation fold the cluster-based feature columns are
    recomputed from the training part only (using ``real_y``) and then applied
    to the test part. The estimator is fitted and scored on ``y``.

    Parameters
    ----------
    estimator : estimator object
        Fitted anew on every fold.
    X : indexable
        Feature matrix in the layout expected by insert_local_remnant_ratios.
    real_y : indexable
        Targets used to compute the per-fold remnant ratios.
    y : indexable
        Targets (possibly permuted) used for splitting, fitting and scoring.
    groups, cv, scorer
        Passed on to ``cv.split`` / used to score the fitted estimator.
    fit_params : dict or None
        Extra keyword arguments for ``estimator.fit``.

    Returns
    -------
    float
        Mean of the per-fold scores.
    """
    fit_params = fit_params if fit_params is not None else {}
    avg_score = []
    for train, test in cv.split(X, y, groups):
        X_train_before, y_train = _safe_split(estimator, X, y, train)
        X_test_before, y_test = _safe_split(estimator, X, y, test, train)
        _, real_y_train = _safe_split(estimator, X, real_y, train)
        # Ratios are derived from the training fold only, then looked up for
        # the test fold, so the held-out data does not enter the features.
        X_train, remnant_dict, all_dict, remnant_dict_hard, all_dict_hard = \
            insert_local_remnant_ratios(X_train_before, real_y_train)
        X_test = insert_local_test_ratios(X_test_before, remnant_dict, all_dict,
                                          remnant_dict_hard, all_dict_hard)

        # Slice the ORIGINAL fit_params for this fold. Rebinding `fit_params`
        # here would hand the slice of the first fold to every later fold.
        fold_fit_params = _check_fit_params(X, fit_params, train)
        estimator.fit(X_train, y_train, **fold_fit_params)
        avg_score.append(scorer(estimator, X_test, y_test))
    return np.mean(avg_score)


def _shuffle(y, groups, random_state):
    """Return a permuted copy of y; with groups, permute only within each group."""
    if groups is not None:
        # Start from the identity and permute the positions of one group at a time
        order = np.arange(len(groups))
        for label in np.unique(groups):
            members = groups == label
            order[members] = random_state.permutation(order[members])
    else:
        order = random_state.permutation(len(y))
    return _safe_indexing(y, order)

def classification_report_to_latex(report,
                                   target_names=('noncandidate', 'AGN remnant candidate')):
    """Format a classification report dictionary as a LaTeX table.

    Parameters
    ----------
    report : dict
        Must contain one entry per name in ``target_names`` plus a
        ``'weighted avg'`` entry, each a dict with the keys ``'precision'``,
        ``'recall'``, ``'f1-score'`` and ``'support'``, and a numeric
        ``'accuracy'`` entry.
    target_names : sequence of str
        Class rows to include, in this order.

    Returns
    -------
    str
        LaTeX source of a ``table`` environment (starts with a newline).
    """
    # All pieces are raw strings so the LaTeX backslashes are never
    # interpreted as (invalid) Python escape sequences.
    rows = [
        '',
        r'\begin{table}',
        r'\centering',
        r'\caption{caption}',
        r'\label{tab:perf}',
        r'\begin{tabular}{rrrrr}',
        r'\hline\hline',
        r'  & precision & recall & f1-score & support \\',
        r'\hline',
    ]
    for target in target_names:
        t = report[target]
        rows.append(
            rf"{target}  & ${t['precision']:.2f}$ & ${t['recall']:.2f}$ "
            rf"& ${t['f1-score']:.2f}$ & ${t['support']}$ \\")
    t = report['weighted avg']
    rows += [
        r'  &  &  & &  \\',
        rf"weighted average  & ${t['precision']:.2f}$ & ${t['recall']:.2f}$ "
        rf"& ${t['f1-score']:.2f}$ & ${t['support']}$ \\",
        r'\hline',
        # The accuracy row reuses the support of the weighted average
        rf"accuracy & ${report['accuracy']:.2f}$ & & &  ${t['support']}$ \\",
        r'\hline',
        r'\end{tabular}',
        r'\end{table}',
    ]
    return '\n'.join(rows)

def mask_outside_gaussian(gaussians,astropy_cutout):
    # Create indices
    yi, xi = np.indices(astropy_cutout.shape)
        
    model = np.zeros(astropy_cutout.shape,dtype=bool)
    for g in gaussians:
        gau = g(xi,yi)
        max_gau = np.max(gau)
        model[gau < 0.01*max_gau] = True
    astropy_cutout[model] = np.nan
                       
    return model, astropy_cutout


def get_significance_stats(image, cutout_wcs, source, cat, noise,debug=False):
    """Mask everything outside source and return stats on the unmasked source.

    params:
        image = 2-D cutout array; NOTE: it is modified in place (masked
            pixels are set to NaN) and returned as the residual
        cutout_wcs = WCS belonging to the cutout
        source = catalogue row of the source; its Source_Name selects the
            catalogue entries used for the mask
        cat = a source catalogue in the pandas dataframe format
        noise = value the median of the unmasked pixels is divided by
        debug = if True, print intermediate values and show the mask and
            the masked image
    returns:
        residual = the masked image (same object as `image`)
        med = NaN-ignoring median of the masked image divided by `noise`
    """
    use_source_size=True
    if debug:
        # The pixel resolution is only reported, never used in the
        # computation, so it is looked up only when it is printed.
        arcsec_per_pixel_RA, arcsec_per_pixel_DEC = post.return_cutout_pixel_resolution(
                image, cutout_wcs, verbose=False)
        print("angular res, ra, dec:", arcsec_per_pixel_RA,arcsec_per_pixel_DEC)

    # Catalogue entries that belong to this source
    local_cat = cat[source.Source_Name == cat.Source_Name]
    if debug:
        print("Rough neighbour removal found this many neighbours:",len(local_cat))

    # Create gaussians
    gaussians = post.extract_gaussian_parameters_from_component_catalogue(local_cat,cutout_wcs,
            use_source_size=use_source_size)
    # Mask all outside
    model, residual = mask_outside_gaussian(gaussians,image)
    if debug:
        plt.title('model')
        plt.imshow(model)
        plt.show()
        plt.title('residual')
        plt.imshow(residual)
        plt.show()
    med = np.nanmedian(residual)/noise
    return residual, med

## Implement gridsearch such that no data leakage occurs

In [None]:
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from collections.abc import Mapping, Sequence, Iterable
from functools import partial, reduce
from itertools import product
import numbers
import operator
import time
import warnings
from contextlib import suppress
from traceback import format_exc
import scipy.sparse as sp
from joblib import Parallel, delayed, logger

from sklearn.base import is_classifier, clone
from sklearn.utils import indexable, check_random_state, _safe_indexing
from sklearn.utils.validation import _check_fit_params
from sklearn.utils.validation import _num_samples
from sklearn.utils.validation import _deprecate_positional_args
from sklearn.utils.metaestimators import _safe_split
from sklearn.metrics._scorer import _check_multimetric_scoring, _MultimetricScorer

from sklearn.preprocessing import LabelEncoder
from numpy.ma import MaskedArray
from scipy.stats import rankdata
from sklearn.base import BaseEstimator, is_classifier, clone
from sklearn.base import MetaEstimatorMixin
from sklearn.model_selection._split import check_cv
from sklearn.model_selection._validation import _aggregate_score_dicts
from sklearn.model_selection._validation import _insert_error_scores
from sklearn.model_selection._validation import _normalize_score_results
from sklearn.exceptions import NotFittedError
from sklearn.utils import check_random_state
from sklearn.utils.random import sample_without_replacement
from sklearn.utils._tags import _safe_tags
from sklearn.utils.validation import indexable, check_is_fitted, _check_fit_params
from sklearn.utils.fixes import delayed
from sklearn.metrics._scorer import _check_multimetric_scoring

from sklearn.utils.validation import _deprecate_positional_args
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.metrics import check_scoring
from sklearn.utils import deprecated

def _check_param_grid(param_grid):
    if hasattr(param_grid, 'items'):
        param_grid = [param_grid]

    for p in param_grid:
        for name, v in p.items():
            if isinstance(v, np.ndarray) and v.ndim > 1:
                raise ValueError("Parameter array should be one-dimensional.")

            if (isinstance(v, str) or
                    not isinstance(v, (np.ndarray, Sequence))):
                raise ValueError("Parameter grid for parameter ({0}) needs to"
                                 " be a list or numpy array, but got ({1})."
                                 " Single values need to be wrapped in a list"
                                 " with one element.".format(name, type(v)))

            if len(v) == 0:
                raise ValueError("Parameter values for parameter ({0}) need "
                                 "to be a non-empty sequence.".format(name))
                
class ParameterGrid:
    """Grid of parameters with a discrete number of values for each.

    Can be used to iterate over parameter value combinations with the
    Python built-in function iter.
    The order of the generated parameter combinations is deterministic.

    Parameters
    ----------
    param_grid : dict of str to sequence, or sequence of such
        The parameter grid to explore, as a dictionary mapping estimator
        parameters to sequences of allowed values.

        An empty dict signifies default parameters.

        A sequence of dicts signifies a sequence of grids to search, and is
        useful to avoid exploring parameter combinations that make no sense
        or have no effect. See the examples below.

    Raises
    ------
    TypeError
        If ``param_grid`` is neither a mapping nor an iterable, if one of the
        grids is not a dict, or if one of the grid values is not iterable.

    Examples
    --------
    >>> param_grid = {'a': [1, 2], 'b': [True, False]}
    >>> list(ParameterGrid(param_grid)) == (
    ...    [{'a': 1, 'b': True}, {'a': 1, 'b': False},
    ...     {'a': 2, 'b': True}, {'a': 2, 'b': False}])
    True

    >>> grid = [{'kernel': ['linear']}, {'kernel': ['rbf'], 'gamma': [1, 10]}]
    >>> list(ParameterGrid(grid)) == [{'kernel': 'linear'},
    ...                               {'kernel': 'rbf', 'gamma': 1},
    ...                               {'kernel': 'rbf', 'gamma': 10}]
    True
    >>> ParameterGrid(grid)[1] == {'kernel': 'rbf', 'gamma': 1}
    True
    """

    def __init__(self, param_grid):
        if not isinstance(param_grid, (Mapping, Iterable)):
            raise TypeError('Parameter grid is not a dict or '
                            'a list ({!r})'.format(param_grid))

        if isinstance(param_grid, Mapping):
            # wrap dictionary in a singleton list to support either dict
            # or list of dicts
            param_grid = [param_grid]

        # check if all entries are dictionaries of lists
        for grid in param_grid:
            if not isinstance(grid, dict):
                raise TypeError('Parameter grid is not a '
                                'dict ({!r})'.format(grid))
            for key in grid:
                if not isinstance(grid[key], Iterable):
                    raise TypeError('Parameter grid value is not iterable '
                                    '(key={!r}, value={!r})'
                                    .format(key, grid[key]))

        self.param_grid = param_grid

    def __iter__(self):
        """Iterate over the points in the grid.

        Returns
        -------
        params : iterator over dict of str to any
            Yields dictionaries mapping each estimator parameter to one of its
            allowed values.
        """
        for p in self.param_grid:
            # Always sort the keys of a dictionary, for reproducibility
            items = sorted(p.items())
            if not items:
                yield {}
            else:
                keys, values = zip(*items)
                for v in product(*values):
                    params = dict(zip(keys, v))
                    yield params

    def __len__(self):
        """Number of points on the grid."""
        # Product over a generator of lengths; an empty grid counts as one
        # point (the default parameters).
        multiply_all = partial(reduce, operator.mul)
        return sum(multiply_all(len(v) for v in p.values()) if p else 1
                   for p in self.param_grid)

    def __getitem__(self, ind):
        """Get the parameters that would be ``ind``th in iteration

        Parameters
        ----------
        ind : int
            The iteration index

        Returns
        -------
        params : dict of str to any
            Equal to list(self)[ind]

        Raises
        ------
        IndexError
            If ``ind`` is not smaller than the number of grid points.
        """
        # This is used to make discrete sampling without replacement memory
        # efficient.
        for sub_grid in self.param_grid:
            # XXX: could memoize information used here
            if not sub_grid:
                if ind == 0:
                    return {}
                else:
                    ind -= 1
                    continue

            # Reverse so most frequent cycling parameter comes first
            keys, values_lists = zip(*sorted(sub_grid.items())[::-1])
            sizes = [len(v_list) for v_list in values_lists]
            # np.prod, not np.product: the latter alias was removed in NumPy 2.0
            total = np.prod(sizes)

            if ind >= total:
                # Try the next grid
                ind -= total
            else:
                out = {}
                for key, v_list, n in zip(keys, values_lists, sizes):
                    ind, offset = divmod(ind, n)
                    out[key] = v_list[offset]
                return out

        raise IndexError('ParameterGrid index out of range')


class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Abstract base class for hyper parameter search with cross-validation.
    """

    @abstractmethod
    @_deprecate_positional_args
    def __init__(self, estimator, *, scoring=None, n_jobs=None,
                 refit=True, cv=None, verbose=0,
                 pre_dispatch='2*n_jobs', error_score=np.nan,
                 return_train_score=True):

        self.scoring = scoring
        self.estimator = estimator
        self.n_jobs = n_jobs
        self.refit = refit
        self.cv = cv
        self.verbose = verbose
        self.pre_dispatch = pre_dispatch
        self.error_score = error_score
        self.return_train_score = return_train_score

    @property
    def _estimator_type(self):
        return self.estimator._estimator_type

    @property
    def _pairwise(self):
        # allows cross-validation to see 'precomputed' metrics
        return getattr(self.estimator, '_pairwise', False)

    def score(self, X, y=None):
        """Returns the score on the given data, if the estimator has been refit.

        This uses the score defined by ``scoring`` where provided, and the
        ``best_estimator_.score`` method otherwise.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like of shape (n_samples, n_output) \
            or (n_samples,), default=None
            Target relative to X for classification or regression;
            None for unsupervised learning.

        Returns
        -------
        score : float
        """
        self._check_is_fitted('score')
        if self.scorer_ is None:
            raise ValueError("No score function explicitly defined, "
                             "and the estimator doesn't provide one %s"
                             % self.best_estimator_)
        if isinstance(self.scorer_, dict):
            if self.multimetric_:
                scorer = self.scorer_[self.refit]
            else:
                scorer = self.scorer_
            return scorer(self.best_estimator_, X, y)

        # callable
        score = self.scorer_(self.best_estimator_, X, y)
        if self.multimetric_:
            score = score[self.refit]
        return score

    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def score_samples(self, X):
        """Call score_samples on the estimator with the best found parameters.

        Only available if ``refit=True`` and the underlying estimator supports
        ``score_samples``.

        .. versionadded:: 0.24

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements
            of the underlying estimator.

        Returns
        -------
        y_score : ndarray of shape (n_samples,)
        """
        self._check_is_fitted('score_samples')
        return self.best_estimator_.score_samples(X)

    def _check_is_fitted(self, method_name):
        if not self.refit:
            raise NotFittedError('This %s instance was initialized '
                                 'with refit=False. %s is '
                                 'available only after refitting on the best '
                                 'parameters. You can refit an estimator '
                                 'manually using the ``best_params_`` '
                                 'attribute'
                                 % (type(self).__name__, method_name))
        else:
            check_is_fitted(self)

    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def predict(self, X):
        """Call predict on the estimator with the best found parameters.

        Only available if ``refit=True`` and the underlying estimator supports
        ``predict``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.

        """
        self._check_is_fitted('predict')
        return self.best_estimator_.predict(X)

    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def predict_proba(self, X):
        """Call predict_proba on the estimator with the best found parameters.

        Only available if ``refit=True`` and the underlying estimator supports
        ``predict_proba``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.

        """
        self._check_is_fitted('predict_proba')
        return self.best_estimator_.predict_proba(X)

    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def predict_log_proba(self, X):
        """Call predict_log_proba on the estimator with the best found parameters.

        Only available if ``refit=True`` and the underlying estimator supports
        ``predict_log_proba``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.

        """
        self._check_is_fitted('predict_log_proba')
        return self.best_estimator_.predict_log_proba(X)

    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def decision_function(self, X):
        """Call decision_function on the estimator with the best found parameters.

        Only available if ``refit=True`` and the underlying estimator supports
        ``decision_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.

        """
        self._check_is_fitted('decision_function')
        return self.best_estimator_.decision_function(X)

    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def transform(self, X):
        """Call transform on the estimator with the best found parameters.

        Only available if the underlying estimator supports ``transform`` and
        ``refit=True``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.

        """
        self._check_is_fitted('transform')
        return self.best_estimator_.transform(X)

    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def inverse_transform(self, Xt):
        """Delegate ``inverse_transform`` to the best estimator found.

        Requires ``refit=True`` and an underlying estimator that implements
        ``inverse_transform``.

        Parameters
        ----------
        Xt : indexable, length n_samples
            Data satisfying the input requirements of the underlying
            estimator.

        Returns
        -------
        Whatever ``best_estimator_.inverse_transform(Xt)`` returns.
        """
        self._check_is_fitted('inverse_transform')
        winner = self.best_estimator_
        return winner.inverse_transform(Xt)

    @property
    def n_features_in_(self):
        """Forward ``n_features_in_`` from the best estimator.

        An unfitted search raises ``AttributeError`` (chained from the
        ``NotFittedError``) so that ``hasattr(search, "n_features_in_")`` is
        False, in line with other estimators.
        """
        try:
            check_is_fitted(self)
        except NotFittedError as nfe:
            owner = self.__class__.__name__
            message = f"{owner} object has no n_features_in_ attribute."
            raise AttributeError(message) from nfe
        else:
            return self.best_estimator_.n_features_in_

    @property
    def classes_(self):
        """Forward ``classes_`` from the best estimator after a fitted check."""
        self._check_is_fitted("classes_")
        winner = self.best_estimator_
        return winner.classes_

    def _run_search(self, evaluate_candidates):
        """Drive the search by calling `evaluate_candidates` one or more times.

        Sub-classes override this hook to decide *when* candidates get
        evaluated. A grid or randomized search submits its whole search
        space in a single call, whereas a sequential strategy (for instance
        Bayesian or other model-based optimization) may inspect the results
        collected so far and then schedule further candidates.

        Parameters
        ----------
        evaluate_candidates : callable
            Takes a list of candidates (each a dict of parameter settings)
            and returns a dict with all results gathered so far, laid out
            like ``cv_results_``.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.

        Examples
        --------

        ::

            def _run_search(self, evaluate_candidates):
                'Try C=0.1 only if C=1 is better than C=10'
                all_results = evaluate_candidates([{'C': 1}, {'C': 10}])
                score = all_results['mean_test_score']
                if score[0] < score[1]:
                    evaluate_candidates([{'C': 0.1}])
        """
        raise NotImplementedError("_run_search not implemented.")

    def _check_refit_for_multimetric(self, scores):
        """Validate ``self.refit`` against the available scorer names.

        ``refit`` is accepted when it is ``False``, a callable, or a string
        that is a key of ``scores``. Any other value raises ``ValueError``.
        """
        names_a_scorer = isinstance(self.refit, str) and self.refit in scores
        acceptable = (self.refit is False or names_a_scorer
                      or callable(self.refit))
        if acceptable:
            return

        raise ValueError(
            "For multi-metric scoring, the parameter refit must be set to a "
            "scorer key or a callable to refit an estimator with the best "
            "parameter setting on the whole data and make the best_* "
            "attributes available for that metric. If this is not needed, "
            f"refit should be set to False explicitly. {self.refit!r} was "
            "passed.")

    @_deprecate_positional_args
    def fit(self, X, y=None, *, groups=None, **fit_params):
        """Run fit with all sets of parameters.

        Parameters
        ----------

        X : array-like of shape (n_samples, n_features)
            Training vector, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like of shape (n_samples, n_output) \
            or (n_samples,), default=None
            Target relative to X for classification or regression;
            None for unsupervised learning.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set. Only used in conjunction with a "Group" :term:`cv`
            instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).

        **fit_params : dict of str -> object
            Parameters passed to the ``fit`` method of the estimator

        Returns
        -------
        self : object
            This instance, with ``cv_results_``, ``n_splits_``, ``scorer_``
            and ``multimetric_`` set. ``best_index_`` and ``best_params_``
            (plus ``best_score_`` when ``refit`` is not a callable) are set
            when ``refit`` is truthy or the scoring is single-metric;
            ``best_estimator_`` and ``refit_time_`` are set when ``refit`` is
            truthy.

        Raises
        ------
        ValueError
            If no fits were performed, if the number of fits does not match
            ``n_candidates * n_splits``, or if ``refit`` is invalid for
            multi-metric scoring.
        TypeError, IndexError
            If a callable ``refit`` returns a non-integer or out-of-range
            index.
        """
        estimator = self.estimator
        cv = check_cv(self.cv, y, classifier=is_classifier(estimator))

        # Key suffix used to look up rank/mean columns in the results;
        # overwritten with `self.refit` for multi-metric scoring.
        refit_metric = "score"

        if callable(self.scoring):
            scorers = self.scoring
        elif self.scoring is None or isinstance(self.scoring, str):
            scorers = check_scoring(self.estimator, self.scoring)
        else:
            scorers = _check_multimetric_scoring(self.estimator, self.scoring)
            self._check_refit_for_multimetric(scorers)
            refit_metric = self.refit

        X, y, groups = indexable(X, y, groups)
        fit_params = _check_fit_params(X, fit_params)

        n_splits = cv.get_n_splits(X, y, groups)

        # Unfitted template; every fit below works on its own clone.
        base_estimator = clone(self.estimator)

        parallel = Parallel(n_jobs=self.n_jobs,
                            pre_dispatch=self.pre_dispatch)

        # Keyword arguments shared by every `_fit_and_score` call.
        fit_and_score_kwargs = dict(scorer=scorers,
                                    fit_params=fit_params,
                                    return_train_score=self.return_train_score,
                                    return_n_test_samples=True,
                                    return_times=True,
                                    return_parameters=False,
                                    error_score=self.error_score,
                                    verbose=self.verbose)
        results = {}
        with parallel:
            # Accumulated over all `evaluate_candidates` calls, so that
            # `_run_search` may evaluate candidates in several rounds.
            all_candidate_params = []
            all_out = []

            def evaluate_candidates(candidate_params):
                candidate_params = list(candidate_params)
                n_candidates = len(candidate_params)

                if self.verbose > 0:
                    print("Fitting {0} folds for each of {1} candidates,"
                          " totalling {2} fits".format(
                              n_splits, n_candidates, n_candidates * n_splits))

                # One task per (candidate, CV split) pair. `_fit_and_score`
                # is resolved as a module-level name when this runs.
                out = parallel(delayed(_fit_and_score)(clone(base_estimator),
                                                       X, y,
                                                       train=train, test=test,
                                                       parameters=parameters,
                                                       split_progress=(
                                                           split_idx,
                                                           n_splits),
                                                       candidate_progress=(
                                                           cand_idx,
                                                           n_candidates),
                                                       **fit_and_score_kwargs)
                               for (cand_idx, parameters),
                                   (split_idx, (train, test)) in product(
                                   enumerate(candidate_params),
                                   enumerate(cv.split(X, y, groups))))

                if len(out) < 1:
                    raise ValueError('No fits were performed. '
                                     'Was the CV iterator empty? '
                                     'Were there no candidates?')
                elif len(out) != n_candidates * n_splits:
                    raise ValueError('cv.split and cv.get_n_splits returned '
                                     'inconsistent results. Expected {} '
                                     'splits, got {}'
                                     .format(n_splits,
                                             len(out) // n_candidates))

                # For callable self.scoring, the return type is only known
                # after calling. If the return type is a dictionary, the error
                # scores can now be inserted with the correct key. The type
                # checking of out will be done in `_insert_error_scores`.
                if callable(self.scoring):
                    _insert_error_scores(out, self.error_score)
                all_candidate_params.extend(candidate_params)
                all_out.extend(out)

                # Rebind the enclosing `results` so that `fit` sees the table
                # built from everything evaluated so far.
                nonlocal results
                results = self._format_results(
                    all_candidate_params, n_splits, all_out)
                return results

            self._run_search(evaluate_candidates)

            # multimetric is determined here because in the case of a callable
            # self.scoring the return type is only known after calling
            first_test_score = all_out[0]['test_scores']
            self.multimetric_ = isinstance(first_test_score, dict)

            # check refit_metric now for a callable scorer that is multimetric
            if callable(self.scoring) and self.multimetric_:
                self._check_refit_for_multimetric(first_test_score)
                refit_metric = self.refit

        # For multi-metric evaluation, store the best_index_, best_params_ and
        # best_score_ iff refit is one of the scorer names
        # In single metric evaluation, refit_metric is "score"
        if self.refit or not self.multimetric_:
            # If callable, refit is expected to return the index of the best
            # parameter set.
            if callable(self.refit):
                self.best_index_ = self.refit(results)
                if not isinstance(self.best_index_, numbers.Integral):
                    raise TypeError('best_index_ returned is not an integer')
                if (self.best_index_ < 0 or
                   self.best_index_ >= len(results["params"])):
                    raise IndexError('best_index_ index out of range')
            else:
                # Rank 1 is the best candidate, so argmin selects it.
                self.best_index_ = results["rank_test_%s"
                                           % refit_metric].argmin()
                self.best_score_ = results["mean_test_%s" % refit_metric][
                                           self.best_index_]
            self.best_params_ = results["params"][self.best_index_]

        if self.refit:
            # we clone again after setting params in case some
            # of the params are estimators as well.
            self.best_estimator_ = clone(clone(base_estimator).set_params(
                **self.best_params_))
            # NOTE(review): this refit is given X as passed to `fit`; unlike
            # the per-fold `_fit_and_score` calls, no extra feature step is
            # applied here.
            refit_start_time = time.time()
            if y is not None:
                self.best_estimator_.fit(X, y, **fit_params)
            else:
                self.best_estimator_.fit(X, **fit_params)
            refit_end_time = time.time()
            self.refit_time_ = refit_end_time - refit_start_time

        # Store the only scorer not as a dict for single metric evaluation
        self.scorer_ = scorers

        self.cv_results_ = results
        self.n_splits_ = n_splits

        return self

    def _format_results(self, candidate_params, n_splits, out):
        """Assemble the ``cv_results_``-style dict from per-fit outputs.

        Parameters
        ----------
        candidate_params : list of dict
            Parameter settings of all evaluated candidates.
        n_splits : int
            Number of CV splits per candidate.
        out : list of dict
            One result dict per (candidate, split) fit. It is aggregated with
            ``_aggregate_score_dicts`` and must provide ``"fit_time"``,
            ``"score_time"``, ``"test_scores"`` and, when
            ``self.return_train_score`` is set, ``"train_scores"``.

        Returns
        -------
        results : dict
            Holds ``mean_*``/``std_*`` for fit and score times, one masked
            ``param_<name>`` array per parameter, the ``params`` list, and
            per-scorer ``split<i>_test_*``, ``mean_test_*``, ``std_test_*``
            and ``rank_test_*`` entries (plus ``train_*`` equivalents without
            rank when train scores are requested).
        """
        n_candidates = len(candidate_params)
        out = _aggregate_score_dicts(out)

        results = {}

        def _store(key_name, array, weights=None, splits=False, rank=False):
            """A small helper to store the scores/times to the cv_results_"""
            # When iterated first by splits, then by parameters
            # We want `array` to have `n_candidates` rows and `n_splits` cols.
            array = np.array(array, dtype=np.float64).reshape(n_candidates,
                                                              n_splits)
            if splits:
                for split_idx in range(n_splits):
                    # Uses closure to alter the results
                    results["split%d_%s"
                            % (split_idx, key_name)] = array[:, split_idx]

            array_means = np.average(array, axis=1, weights=weights)
            results['mean_%s' % key_name] = array_means

            # Only warn for score columns, not for the timing columns.
            if (key_name.startswith(("train_", "test_")) and
                    np.any(~np.isfinite(array_means))):
                warnings.warn(
                    f"One or more of the {key_name.split('_')[0]} scores "
                    f"are non-finite: {array_means}",
                    category=UserWarning
                )

            # Weighted std is not directly available in numpy
            array_stds = np.sqrt(np.average((array -
                                             array_means[:, np.newaxis]) ** 2,
                                            axis=1, weights=weights))
            results['std_%s' % key_name] = array_stds

            if rank:
                # Negate the means so that the highest mean gets rank 1.
                results["rank_%s" % key_name] = np.asarray(
                    rankdata(-array_means, method='min'), dtype=np.int32)

        _store('fit_time', out["fit_time"])
        _store('score_time', out["score_time"])
        # Use one MaskedArray and mask all the places where the param is not
        # applicable for that candidate. Use defaultdict as each candidate may
        # not contain all the params
        param_results = defaultdict(partial(MaskedArray,
                                            np.empty(n_candidates,),
                                            mask=True,
                                            dtype=object))
        for cand_idx, params in enumerate(candidate_params):
            for name, value in params.items():
                # An all masked empty array gets created for the key
                # `"param_%s" % name` at the first occurrence of `name`.
                # Setting the value at an index also unmasks that index
                param_results["param_%s" % name][cand_idx] = value

        results.update(param_results)
        # Store a list of param dicts at the key 'params'
        results['params'] = candidate_params

        test_scores_dict = _normalize_score_results(out["test_scores"])
        if self.return_train_score:
            train_scores_dict = _normalize_score_results(out["train_scores"])

        for scorer_name in test_scores_dict:
            # Compute the mean and std of the test scores. weights=None means
            # every split counts equally (no weighting by test-set size).
            _store('test_%s' % scorer_name, test_scores_dict[scorer_name],
                   splits=True, rank=True,
                   weights=None)
            if self.return_train_score:
                _store('train_%s' % scorer_name,
                       train_scores_dict[scorer_name],
                       splits=True)

        return results


class GridSearchCV(BaseSearchCV):
    """Search that evaluates every candidate produced by ``param_grid``.

    All constructor arguments except ``param_grid`` are forwarded unchanged
    to ``BaseSearchCV``. ``param_grid`` is stored on the instance and checked
    with ``_check_param_grid``.
    """

    _required_parameters = ["estimator", "param_grid"]

    @_deprecate_positional_args
    def __init__(self, estimator, param_grid, *, scoring=None,
                 n_jobs=None, refit=True, cv=None,
                 verbose=0, pre_dispatch='2*n_jobs',
                 error_score=np.nan, return_train_score=False):
        base_options = dict(
            estimator=estimator,
            scoring=scoring,
            n_jobs=n_jobs,
            refit=refit,
            cv=cv,
            verbose=verbose,
            pre_dispatch=pre_dispatch,
            error_score=error_score,
            return_train_score=return_train_score,
        )
        super().__init__(**base_options)
        self.param_grid = param_grid
        _check_param_grid(param_grid)

    def _run_search(self, evaluate_candidates):
        """Submit the full ``ParameterGrid`` of ``param_grid`` in one call."""
        full_grid = ParameterGrid(self.param_grid)
        evaluate_candidates(full_grid)
        
def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                   parameters, fit_params, return_train_score=False,
                   return_parameters=False, return_n_test_samples=False,
                   return_times=False, return_estimator=False,
                   split_progress=None, candidate_progress=None,
                   error_score=np.nan):

    """Fit estimator and compute scores for a given dataset split.

    Unlike a plain fit-and-score, the training split is first passed through
    ``insert_local_remnant_ratios`` and the test split through
    ``insert_local_test_ratios``. The latter receives the dictionaries
    returned for the training split, so the inserted features of a fold are
    derived from that fold's training part only.

    Parameters
    ----------
    estimator : estimator object
        Estimator to fit. When ``parameters`` is given they are applied with
        ``set_params`` (each value cloned first).
    X, y : indexable
        Full data and targets; split with ``_safe_split`` using ``train`` and
        ``test``.
    scorer : callable or dict
        Passed to ``_score``. A dict yields a dict of scores.
    train, test : indices
        Indices of the training and test samples of this split.
    verbose : int
        Above 1 an END line is printed per fit; above 9 a START line too.
    parameters : dict or None
        Parameter setting to apply to ``estimator``.
    fit_params : dict or None
        Extra arguments for ``estimator.fit``, adjusted to the training
        indices by ``_check_fit_params``.
    return_train_score, return_parameters, return_n_test_samples, \
return_times, return_estimator : bool
        Select which optional keys are added to the result.
    split_progress, candidate_progress : tuple of (index, total) or None
        Only used to build the progress part of the log messages.
    error_score : 'raise' or numeric
        Score assigned when ``estimator.fit`` raises; ``'raise'`` re-raises.

    Returns
    -------
    result : dict
        Always has ``"fit_failed"`` and ``"test_scores"``. Depending on the
        ``return_*`` flags it also has ``"train_scores"``,
        ``"n_test_samples"``, ``"fit_time"``, ``"score_time"``,
        ``"parameters"`` and ``"estimator"``.

    Raises
    ------
    ValueError
        If fitting fails and ``error_score`` is neither ``'raise'`` nor a
        number.
    """
    progress_msg = ""
    if verbose > 2:
        if split_progress is not None:
            progress_msg = f" {split_progress[0]+1}/{split_progress[1]}"
        if candidate_progress and verbose > 9:
            progress_msg += (f"; {candidate_progress[0]+1}/"
                             f"{candidate_progress[1]}")

    # `params_msg` is only defined for verbose > 1; every later use is
    # guarded by a verbose check that implies verbose > 1.
    if verbose > 1:
        if parameters is None:
            params_msg = ''
        else:
            sorted_keys = sorted(parameters)  # Ensure deterministic o/p
            params_msg = (', '.join(f'{k}={parameters[k]}'
                                    for k in sorted_keys))
    if verbose > 9:
        start_msg = f"[CV{progress_msg}] START {params_msg}"
        print(f"{start_msg}{(80 - len(start_msg)) * '.'}")

    # Adjust length of sample weights
    fit_params = fit_params if fit_params is not None else {}
    fit_params = _check_fit_params(X, fit_params, train)

    if parameters is not None:
        # clone after setting parameters in case any parameters
        # are estimators (like pipeline steps)
        # because pipeline doesn't clone steps in fit
        cloned_parameters = {}
        for k, v in parameters.items():
            cloned_parameters[k] = clone(v, safe=False)

        estimator = estimator.set_params(**cloned_parameters)

    # The clock starts before splitting, so `fit_time` also covers the split
    # and the feature insertion below.
    start_time = time.time()

    X_train_before, y_train = _safe_split(estimator, X, y, train)
    X_test_before, y_test = _safe_split(estimator, X, y, test, train)
    
    # Adjust features such that there is no data leakage across the k-folds
    # NOTE(review): this step sits outside the try block below, so an
    # exception raised here propagates regardless of `error_score`.
    X_train, remnant_dict, all_dict, remnant_dict_hard, all_dict_hard, \
          = insert_local_remnant_ratios(
        X_train_before,y_train)
    # The test features are built from the training-split dictionaries only.
    X_test = insert_local_test_ratios(X_test_before, remnant_dict, all_dict, \
                                        remnant_dict_hard, all_dict_hard)

    
    
    result = {}
    try:
        if y_train is None:
            estimator.fit(X_train, **fit_params)
        else:
            estimator.fit(X_train, y_train, **fit_params)

    except Exception as e:
        # Note fit time as time until error
        fit_time = time.time() - start_time
        score_time = 0.0
        if error_score == 'raise':
            raise
        elif isinstance(error_score, numbers.Number):
            # Substitute `error_score` for every score of this fit.
            if isinstance(scorer, dict):
                test_scores = {name: error_score for name in scorer}
                if return_train_score:
                    train_scores = test_scores.copy()
            else:
                test_scores = error_score
                if return_train_score:
                    train_scores = error_score
            warnings.warn("Estimator fit failed. The score on this train-test"
                          " partition for these parameters will be set to %f. "
                          "Details: \n%s" %
                          (error_score, format_exc()),
                          FitFailedWarning)
        else:
            raise ValueError("error_score must be the string 'raise' or a"
                             " numeric value. (Hint: if using 'raise', please"
                             " make sure that it has been spelled correctly.)")
        result["fit_failed"] = True
    else:
        result["fit_failed"] = False

        fit_time = time.time() - start_time
        test_scores = _score(estimator, X_test, y_test, scorer)
        score_time = time.time() - start_time - fit_time
        if return_train_score:
            # Train scores are computed on the transformed training features.
            train_scores = _score(estimator, X_train, y_train, scorer)

    if verbose > 1:
        total_time = score_time + fit_time
        end_msg = f"[CV{progress_msg}] END "
        result_msg = params_msg + (";" if params_msg else "")
        # Per-scorer scores are only printed when the scores are a dict.
        if verbose > 2 and isinstance(test_scores, dict):
            for scorer_name in sorted(test_scores):
                result_msg += f" {scorer_name}: ("
                if return_train_score:
                    scorer_scores = train_scores[scorer_name]
                    result_msg += f"train={scorer_scores:.3f}, "
                result_msg += f"test={test_scores[scorer_name]:.3f})"
        result_msg += f" total time={logger.short_format_time(total_time)}"

        # Right align the result_msg
        end_msg += "." * (80 - len(end_msg) - len(result_msg))
        end_msg += result_msg
        print(end_msg)

    result["test_scores"] = test_scores
    if return_train_score:
        result["train_scores"] = train_scores
    if return_n_test_samples:
        result["n_test_samples"] = _num_samples(X_test)
    if return_times:
        result["fit_time"] = fit_time
        result["score_time"] = score_time
    if return_parameters:
        result["parameters"] = parameters
    if return_estimator:
        result["estimator"] = estimator
    return result

def _score(estimator, X_test, y_test, scorer):
    """Compute the score(s) of an estimator on a given test set.

    Will return a dict of floats if `scorer` is a dict, otherwise a single
    float is returned.
    """
    if isinstance(scorer, dict):
        # will cache method calls if needed. scorer() returns a dict
        scorer = _MultimetricScorer(**scorer)
    if y_test is None:
        scores = scorer(estimator, X_test)
    else:
        scores = scorer(estimator, X_test, y_test)

    error_msg = ("scoring must return a number, got %s (%s) "
                 "instead. (scorer=%s)")
    if isinstance(scores, dict):
        for name, score in scores.items():
            if hasattr(score, 'item'):
                with suppress(ValueError):
                    # e.g. unwrap memmapped scalars
                    score = score.item()
            if not isinstance(score, numbers.Number):
                raise ValueError(error_msg % (score, type(score), name))
            scores[name] = score
    else:  # scalar
        if hasattr(scores, 'item'):
            with suppress(ValueError):
                # e.g. unwrap memmapped scalars
                scores = scores.item()
        if not isinstance(scores, numbers.Number):
            raise ValueError(error_msg % (scores, type(scores), scorer))
    return scores


# Create LoTSS-DR2 cutouts of 151 remnants

In [None]:
# Create LOTSS cutout sets
# This cell extracts cutouts for `calibration_set`, writes them to a PINK
# binary file and writes a bash script with the command that maps that binary
# onto the SOM stored at `s.map_to_binpath`.
# Get cutout size in pixels
rotated_size_arcsec = 100
rotated_size = int(round(rotated_size_arcsec/angular_resolution_LoTSS))
# The full cutout is sqrt(2) times wider than the rotated cutout
# (presumably so the rotated square fits at any angle; TODO confirm).
fullsize = int(np.ceil(rotated_size*np.sqrt(2)))
print(f'Full cutout sizes will be {fullsize} pixels or {fullsize*angular_resolution_LoTSS} arcsec')
print(f'Rotated cutout sizes will be {rotated_size} pixels or {rotated_size_arcsec} arcsec')
LoTSS_DR2_dir='LoTSS_DR2/RA0h_field'


overwrite=False
# Settings for this run (ID 312); the cutouts are mapped to the SOM of run 314.
s4 = CutoutSettings(experiment='LoTSS only; no clip; size invariant', 
            run_id=312, store_filename='LoTSSDR2_151remnants_noclip_sizeinvariant', 
            fullsize=fullsize, data_dir=LoTSS_DR2_dir,arcsec_per_pixel=angular_resolution_LoTSS, 
            overwrite=overwrite, apply_clipping=False, lower_sigma_limit=1.5,
            upper_sigma_limit=1e9, variable_size=True,
            map_to_run_id=314, 
            normalize=False,zoom_in=False,
            map_to_binpath='LoTSS_Lockman_sizeinvariant_resolved_10mJy_noclip_9x9_ID314/resultLoTSS_Lockman_sizeinvariant_resolved_10mJy_noclip_9x9x95_0.32305409446133365_0.0030974898406490115_ID314.bin')
s = s4
debug = False
np.random.seed(42)
# NOTE(review): the script file stays open for the whole cell although it is
# only written once, by `print(mapstring, file=f)` near the end.
with open(os.path.join(run_dir,'map_151remnants_to_lockman_LoTSDR2.sh'),'w') as f:
    
    # Define parameters
    cutouts_bin_name = s.store_filename
    #cutouts_binary_path
    run_id = s.run_id
    gpu_id = 0
    som_size = 2
    overwrite=s.overwrite
    # Create SOM info object
    # The width, height, depth and channel count passed here are overwritten
    # by the attribute assignments that follow.
    som = post.SOM([], number_of_channels, som_size, som_size, som_depth, "quadratic", 
               output_directory, trained_subdirectory, som_label, rotated_size, run_id)
    som.som_width = 9
    som.som_height = 9
    som.som_depth = 1
    som.number_of_channels = 1
    som.gauss_decrease = 0.9
    som.gauss_end = 0.3
    som.learning_constraint_decrease = 0.8
    som.learning_constraint = som.som_width*som.som_height/100
    som.gauss_start = max(som.som_width, som.som_height)/2
    som.random_seed = 42
    som.pbc = False
    som.init = "zero"
    som.layout = "quadratic"
    som.training_dataset_name = cutouts_bin_name
    som.save()
    som.print()
    # NOTE(review): recomputes the same two values as at the top of the cell.
    rotated_size = int(round(rotated_size_arcsec/angular_resolution_LoTSS))
    fullsize = int(np.ceil(rotated_size*np.sqrt(2)))

    # Get numpy list of cutouts and update catalogue to only contain sources that successfully got extracted.
    # NOTE(review): the next three names are not used again in this cell; the
    # call below reads the values from `s` directly.
    store_filename = cutouts_bin_name
    lower_sigma_limit = s.lower_sigma_limit
    upper_sigma_limit = s.upper_sigma_limit
    calibration_cutouts, calibration_cat_extracted = post.fits_to_cutouts_using_astropy(
        fits_filename, s.store_filename, 
        s.fullsize, calibration_set, mosaic_id_key, ra_key, dec_key, s.data_dir, 
        arcsec_per_pixel=s.arcsec_per_pixel,
        single_field=False,
        gaus_cat=gaul_cat, remove_neighbours=True,component_cat=comp_cat,
        rely_on_catalogue_size=True,sort=False, dimensions_normal=True, 
        overwrite=s.overwrite, 
        apply_clipping=s.apply_clipping, lower_sigma_limit=s.lower_sigma_limit,
        upper_sigma_limit=s.upper_sigma_limit, rescale=True,
        variable_size=s.variable_size, destination_size=s.fullsize, 
        rough_remove_neighbours=False, full_cat = value_added_catalogue,
        get_local_rms=True,
        apply_mask=True, verbose=False, mode='partial')    

    # Plot the first five cutouts (only when `debug` is set)
    if debug:
        for i in range(5):
            # NOTE(review): this condition is always true because of `1==1`.
            if 1==1 or not s.apply_clipping:
                print('RA DEC and totalflux, and peakflux')
                print(calibration_cat_extracted.iloc[i].RA, calibration_cat_extracted.iloc[i].DEC)
                print(calibration_cat_extracted.iloc[i].Total_flux, calibration_cat_extracted.iloc[i].Peak_flux)
                print('LoTSs')
                post.plot_cutout2D(calibration_cutouts[i], 
                                            wcs=None, sqrt=True,colorbar=True,cmap='viridis')

    # Check cutouts shape
    print("Cutouts shape (# cutouts, # channels, width, height) = ", np.shape(calibration_cutouts))
    # Write cutouts to binary file (required file format for PINK software)
    cutouts_binary_path_lotss = os.path.join(data_directory, cutouts_bin_name + '.bin')
    post.write_numpy_to_binary_v2(cutouts_binary_path_lotss, calibration_cutouts, 
                                    som.layout, overwrite=True)

    # Write mapping bash scripts
    map_path = os.path.join(map_dir, 
                f'{cutouts_bin_name}_ID{s.run_id}_mapped_to_ID{s.map_to_run_id}.bin')
    # Keep the mapping path and SOM on the settings object before saving it.
    s.map_path = map_path
    s.som = som
    s.save(output_directory)
    mapstring =post.map_dataset_to_trained_som(som, cutouts_binary_path_lotss, 
                map_path, s.map_to_binpath,
        gpu_id, use_gpu=True, verbose=True, version=2,
        alternate_neuron_dimension=None, use_cuda_visible_devices=True,
        rotation_path=None, circular_shape=False)
    print(mapstring, file=f)
    print("Number of extracted large cutouts:", len(calibration_cutouts))
    # NOTE(review): `map_path` is the .bin path given to the mapping helper,
    # not the path of the .sh file opened above; confirm the intended message.
    print("Bash script to map to SOM:", map_path)

# Discard potential nearby SFGs

In [None]:
# Calibration sources: load the settings object saved under run 312 and
# derive a copy (run 367) for the SFG-filtered calibration set.
run_id=312
settingsfile = post.load_pickle(os.path.join(output_directory,
                                f'SOM_settings_object_id{run_id}.pkl'))
run_id=367
cali_SFGfiltered_settingsfile = deepcopy(settingsfile)
cali_SFGfiltered_settingsfile.run_id = run_id
cali_SFGfiltered_settingsfile.store_filename = \
f'LoTSSDR2_{len(calibration_cutouts)}remnants_noclip_sizeinvariant_SFGfiltered'

# Get all HETDEX sources > 60 that were successfully extracted
# From here on `settingsfile` refers to the run-318 settings.
run_id=318
settingsfile = post.load_pickle(os.path.join(output_directory,
                                f'SOM_settings_object_id{run_id}.pkl'))

# Derive a copy (run 340) for the SFG-filtered HETDEX set.
run_id=340
hetdex_SFGfiltered_settingsfile = deepcopy(settingsfile)
hetdex_SFGfiltered_settingsfile.run_id = run_id
hetdex_SFGfiltered_settingsfile.store_filename += '_SFGfiltered' 

# Load catalogue
cat_path = os.path.join(output_directory,'catalogue_'+ \
    settingsfile.store_filename +'.h5')
cat_hetdex_bigger_than60arcsec_extracted = pd.read_hdf(cat_path)
# Sources in the initial catalogue that are missing from the extracted one.
extraction_fail_count = len(hetdex_bigger_than60arcsec)-len(cat_hetdex_bigger_than60arcsec_extracted)
print(f"Initial value-added catalogue contains {len(hetdex_bigger_than60arcsec)} sources "
     f"bigger than 60arcsec.\n"
     f"We cannot create cutouts for {extraction_fail_count} sources, "
      f"leaving us with {len(cat_hetdex_bigger_than60arcsec_extracted)} sources.")
# Load cutouts
cutouts_path = os.path.join(output_directory,
    settingsfile.store_filename +'.npy')
print("Loading cutouts from:",cutouts_path )
cutouts_hetdex_bigger_than60arcsec_extracted = np.load(cutouts_path)
print(f"Of these {len(cutouts_hetdex_bigger_than60arcsec_extracted)} cutouts:")
print(f"{len(calibration_cutouts)} have been accepted as AGN remnant candidate through visual inspection.")

In [None]:
# Filter based on nearby SFG (starforming galaxies)
# The same filter, with the same thresholds, is applied to three catalogues:
# the calibration set, the not-calibration set and all HETDEX sources >60".
print("For the calibration set:")
accepted_ids_cali, rejected_ids_cali = post.filter_cutouts_with_galaxy_scale_emission(
    list(range(len(calibration_cat_extracted))), calibration_cat_extracted,
   radio_to_optical_extent_ratio=10, object_search_radius_in_arcsec=10, verbose=True)
# Sanity check: no calibration source may be rejected by the filter.
assert len(rejected_ids_cali)==0, "calibration set should not contain sfg"

# NOTE(review): `notcalibration_cat_extracted` is not created in this cell;
# it has to exist already when the cell runs.
print("\nFor the notcalibration set:")
accepted_ids_notcali, rejected_ids_notcali = post.filter_cutouts_with_galaxy_scale_emission(
    list(range(len(notcalibration_cat_extracted))), notcalibration_cat_extracted,
   radio_to_optical_extent_ratio=10, object_search_radius_in_arcsec=10, verbose=False)
print(f'{len(rejected_ids_notcali)} out of {len(notcalibration_cat_extracted)} sources were '
      'rejected because the angular extent of the optical emission is larger than '
      '10 times the extend of the radio emission.')


# Ids are positions (0..n-1) in the extracted HETDEX catalogue.
accepted_ids, rejected_ids = post.filter_cutouts_with_galaxy_scale_emission(
    list(range(len(cat_hetdex_bigger_than60arcsec_extracted))), 
    cat_hetdex_bigger_than60arcsec_extracted,
   radio_to_optical_extent_ratio=10, object_search_radius_in_arcsec=10, verbose=False)
print("\nFor all HETDEX sources >60arcsec:")
print(f'{len(rejected_ids)} out of {len(cat_hetdex_bigger_than60arcsec_extracted)} sources were '
      'rejected because the angular extent of the optical emission is larger than '
      '10 times the extend of the radio emission.')




In [None]:
# Tally how many filtered sources are actually SFG
# Optional visual check of the rejected HETDEX sources; skipped unless
# `debug` is switched on.
debug=False
if debug:
    # Binary file belonging to the settings currently held in `settingsfile`.
    bin_path =  os.path.join(data_directory,
                        settingsfile.store_filename +'.bin')
    post.plot_LoTSS_and_PANSTARRS(1000, rejected_ids, 
            cat_hetdex_bigger_than60arcsec_extracted, bin_path, 
        version=2, save=False,         
         save_dir=None, save_index=None, query_SIMBAD_for_source_description=False,
            overwrite=False, 
         print_radio_to_optical_extent=False,
        title_lotss='LoTSS DR2',
         print_FIRST_cat_message=False,model_maj_min_angle_degree=False,                            
         arcsec_per_pixel_lotss=1.5, arcsec_per_pixel_PANSTARRS=0.25, plot_reticle=True)


In [None]:
# Create new files after filtering
# Each call takes the accepted ids, the cutouts, the catalogue and the
# settings object of one set and returns the filtered counterparts and paths.
# cali
cali_SFGfiltered_settingsfile, cali_SFGfiltered_cat, cali_SFGfiltered_cat_path, \
    cali_SFGfiltered_cutouts, cali_SFGfiltered_cutouts_path, \
    cali_SFGfiltered_bin_path = post.create_new_files_after_filtering(
    accepted_ids_cali, calibration_cutouts,
    calibration_cat_extracted, cali_SFGfiltered_settingsfile,
    data_directory,output_directory)
# all >60"
hetdex_SFGfiltered_settingsfile, hetdex_SFGfiltered_cat, hetdex_SFGfiltered_cat_path, \
    hetdex_SFGfiltered_cutouts, hetdex_SFGfiltered_cutouts_path, \
    hetdex_SFGfiltered_bin_path = post.create_new_files_after_filtering(
    accepted_ids,cutouts_hetdex_bigger_than60arcsec_extracted,
    cat_hetdex_bigger_than60arcsec_extracted,hetdex_SFGfiltered_settingsfile,
    data_directory,output_directory)

# NOTE(review): `notcali_SFGfiltered_cutouts` is not created in this cell (no
# matching call above); the last two prints need it to exist already.
print(f"\n So we have {len(hetdex_SFGfiltered_cutouts)} hetdex cutouts:")
print(f"{len(cali_SFGfiltered_cutouts)} have been accepted as AGN remnant candidate through visual inspection.")
print(f"and {len(notcali_SFGfiltered_cutouts)} have been rejected as AGN remnant candidate through visual inspection.")
print(f"So {len(notcali_SFGfiltered_cutouts)+len(cali_SFGfiltered_cutouts)} visually inspected sources left.")
#print(f"{len(hetdex_bigger_than60arcsec)-151-extraction_fail_count+1} non remnant after extraction and 150 remnants.")

#  Split data into train and test (hold-out) set
We also reduce the class imbalance by undersampling the non-candidates.

In [None]:
# Labels: True for every source whose name also appears in the calibration
# (visually accepted) catalogue, False otherwise.
# Series.isin does a hash-based lookup, which avoids scanning the whole
# calibration name array once per source.
source_names = hetdex_SFGfiltered_cat.Source_Name.values
labels = hetdex_SFGfiltered_cat.Source_Name.isin(
    cali_SFGfiltered_cat.Source_Name).to_numpy()

def undersample_non_candidates_upto(labels,upto=1050,random_seed=42):
    """Randomly undersample the non-candidate class down to `upto` members.

    Every entry with a truthy label is kept. Of the entries whose label
    equals False, a random subset of exactly `upto` is kept.

    Parameters
    ----------
    labels : sequence of bool
        One label per source.
    upto : int
        Number of non-candidates to keep. Must be smaller than the number of
        non-candidates present in `labels`.
    random_seed : int
        Passed to np.random.seed, so the global NumPy random state is
        re-seeded as a side effect.

    Returns
    -------
    list
        Sorted positions (into `labels`) of the retained entries.

    Raises
    ------
    ValueError
        If `upto` is not smaller than the number of non-candidates, i.e. the
        request would not undersample the class.
    """
    np.random.seed(random_seed)
    labels = np.asarray(labels)
    # np.flatnonzero always returns integer positions, also when a class is
    # empty, so the result can safely be used for indexing.
    non_candidate_idx = np.flatnonzero(labels == False)  # noqa: E712
    candidate_idx = np.flatnonzero(labels.astype(bool))
    if upto >= len(non_candidate_idx):
        raise ValueError(
            "You are asking to oversample this class: "
            f"upto={upto} but only {len(non_candidate_idx)} non-candidates "
            "are available.")
    # Shuffle the non-candidate positions and keep the first `upto` of them
    np.random.shuffle(non_candidate_idx)
    return sorted(np.concatenate([non_candidate_idx[:upto], candidate_idx]))
    
np.random.seed(random_seed)
# Split data into train and test set
# Stratified split on positional indices: idx_train and idx_test are
# positions into the full SFG-filtered catalogue / cutout array.
test_size=0.3
idx_train, idx_test = train_test_split(list(range(len(labels))), test_size=test_size,
            stratify=labels, random_state=random_seed)

# Train/test views of the catalogue, labels, names and cutouts
cat_train = hetdex_SFGfiltered_cat.iloc[idx_train]
cat_test = hetdex_SFGfiltered_cat.iloc[idx_test]
labels_train_full = labels[idx_train]
labels_test = labels[idx_test]
source_names_train = source_names[idx_train]
source_names_test = source_names[idx_test]
hetdex_cutouts_train = hetdex_SFGfiltered_cutouts[idx_train]
cutouts_test = hetdex_SFGfiltered_cutouts[idx_test]

# Keep all training candidates and 10 non-candidates per candidate.
# The returned ids are positions into the *training subset*
# (labels_train_full / cat_train / hetdex_cutouts_train), not into the
# full catalogue.
undersampled_train_ids = undersample_non_candidates_upto(labels_train_full,
                                 upto=int(10*np.sum(labels_train_full)),
                                 random_seed=random_seed)

# Candidate-only ids. Note the two different index spaces:
# undersampled_train_ids_cali indexes the training subset, while
# idx_test_cali indexes the full SFG-filtered catalogue.
undersampled_train_ids_cali = np.array([tid
 for tid in undersampled_train_ids if labels_train_full[tid]])
idx_test_cali = np.array([tid
 for tid in idx_test if labels[tid]])
# Labels of the undersampled training set
labels_train = labels_train_full[undersampled_train_ids]

# Numbers for approach diagram

In [None]:
# Build one long-format table (log10 size, log10 flux, dataset tag) so that
# seaborn can colour the two samples through `hue`.
def _log_size_flux_table(catalogue, dataset_name):
    """Return log10 source size and total flux of `catalogue`, tagged with `dataset_name`."""
    return pd.DataFrame({
        'source_size': np.log10(catalogue["source_size"]),
        'Total_flux': np.log10(catalogue["Total_flux"]),
        'Dataset': [dataset_name] * len(catalogue),
    })


plot_data1 = _log_size_flux_table(hetdex_SFGfiltered_cat, 'All sources >60arcsec')
plot_data2 = _log_size_flux_table(cali_SFGfiltered_cat, 'Calibration sources')
plot_data = pd.concat([plot_data1, plot_data2])

# Scatter plot in the centre, per-dataset density histograms on the margins
g = sns.JointGrid(
    data=plot_data,
    x='source_size',
    y="Total_flux",
    hue='Dataset',
    marginal_ticks=True,
)
g.plot_joint(sns.scatterplot, legend=False, size=1)
g.plot_marginals(
    sns.histplot,
    stat='density',
    common_norm=False,
    element='step',
    fill=True,
)
g.set_axis_labels('Major axis length log10[arcsec]', 'Total flux log10[mJy]')
plt.savefig(
    os.path.join(paper_fig_dir, "sourcesize_vs_totalflux.pdf"),
    bbox_inches='tight',
)
plt.show()

In [None]:
# Print the sample sizes at each stage of the selection, partly formatted as
# LaTeX text for the manuscript.
print(f"{np.sum(value_added_catalogue['source_size'] > 60)} sources in hetdex are >60arcsec,"
     f" of which {len(remnant_names)} are remnant candidates.")
print(f"Cutout extraction succeeds for {len(cat_hetdex_bigger_than60arcsec_extracted)}"
      f" sources in hetdex >60arcsec of which {len(calibration_cat_extracted)} remnant candidates.")
print(f"After excluding nearby SFGs we have {len(hetdex_SFGfiltered_cutouts)} sources left"
      f" of which {len(cali_SFGfiltered_cutouts)} are remnant candidates.")
# NOTE(review): the literal 23 below and the totals "3,908" and "150" are
# hard-coded in the text instead of being derived from the data; they will
# silently go stale if the input catalogue or the filtering changes.
print(f"""\nAfter visual inspection comparing the LoTSS and the Panoramic Survey Telescope 
and Rapid Response System 1 (Pan-STARRS1) $3\\pi$ steradian survey \\citep{{panstarrs}} images,
we found that this label turns out to be correct for ${(len(rejected_ids)-23)/len(rejected_ids)*100:.1f}\\%$ 
(${(len(rejected_ids)-23)}/{len(rejected_ids)}$) of these sources.
When classifying radio sources as AGN remnant candidates in even larger datasets in the 
future, we will not visually inspect these potential nearby SFGs but simply discard them.
In this work too, we discard all ${len(rejected_ids)}$ sources labelled as potential 
nearby SFGs, leaving us with $3,908$ radio sources of which $150$ candidates.
This means that we do not consider roughly ${23/len(hetdex_SFGfiltered_cutouts)*100:.1f}\\%$ of sources $>60$ arcsec that could 
have been AGN remnant candidates.""")
# Training set before undersampling
print(f"\n{1-test_size:.0%} of sources for training set:\n"
     f"{len(source_names_train)} sources, of which {np.sum(labels_train_full)} candidates")

# Training set after undersampling the non-candidates
print(f"\nAfter undersampling the majority class, we reduce the training set to:\n"
     f"{len(undersampled_train_ids)} sources, of which {np.sum(labels_train)} candidates")

# NOTE(review): "$1,050$", "1 in 26" and "1 in 10" are hard-coded as well;
# the undersampling target itself is computed as 10x the number of training
# candidates in the split cell.
print(f"\nSpecifically we undersampled the majority class from ${len(source_names_train)-np.sum(labels_train_full)}$ to $1,050$ sources,"
      f" increasing the ratio AGN remnant candidate versus not-yet-inspected from 1 in 26 to"
      " 1 in 10.")

# Hold-out test set
print(f"\n{test_size:.0%} of sources for test set:\n"
     f"{len(source_names_test)} sources, of which {np.sum(labels_test)} candidates")

In [None]:
# Write four derived datasets (undersampled train, train candidates, test,
# test candidates). Each one gets a deep copy of the SFG-filtered settings
# with its own store_filename and run_id; map_to_id is 370 for all of them,
# which is the run_id of the undersampled training set (presumably the SOM
# each dataset gets mapped to -- semantics are defined in pinklib).

# All (undersampled) train data
train_settings = deepcopy(hetdex_SFGfiltered_settingsfile)
train_settings.store_filename = 'LoTSSDR2_hetdex_morethan60arcsec_noclip_sizeinvariant_SFGfiltered_train_undersampled'
train_settings.run_id = 370
train_settings.map_to_id = 370

# undersampled_train_ids are positions into the training subset, so they are
# paired with hetdex_cutouts_train / cat_train here.
train_settings, train_cat, train_cat_path, \
    train_cutouts, train_cutouts_path, \
    train_bin_path = post.create_new_files_after_filtering(
    undersampled_train_ids, hetdex_cutouts_train,
    cat_train, train_settings,
    data_directory,output_directory)

# Just the train candidates
cali_train_settings = deepcopy(hetdex_SFGfiltered_settingsfile)
cali_train_settings.store_filename = 'LoTSSDR2_hetdex_morethan60arcsec_noclip_sizeinvariant_SFGfiltered_train_candidates'
cali_train_settings.run_id = 371
cali_train_settings.map_to_id = 370

cali_train_settings, cali_train_cat, cali_train_cat_path, \
    cali_train_cutouts, cali_train_cutouts_path, \
    cali_train_bin_path = post.create_new_files_after_filtering(
    undersampled_train_ids_cali, hetdex_cutouts_train,
    cat_train, cali_train_settings,
    data_directory,output_directory)

# All test data
test_settings = deepcopy(hetdex_SFGfiltered_settingsfile)
test_settings.store_filename = 'LoTSSDR2_hetdex_morethan60arcsec_noclip_sizeinvariant_SFGfiltered_test'
test_settings.run_id = 372
test_settings.map_to_id = 370

# idx_test holds positions into the full SFG-filtered set, so it is paired
# with the full cutout array and catalogue here.
test_settings, test_cat, test_cat_path, \
    test_cutouts, test_cutouts_path, \
    test_bin_path = post.create_new_files_after_filtering(
    idx_test, hetdex_SFGfiltered_cutouts,
    hetdex_SFGfiltered_cat,test_settings,
    data_directory,output_directory)

# Just the test candidates
cali_test_settings = deepcopy(hetdex_SFGfiltered_settingsfile)
cali_test_settings.store_filename = 'LoTSSDR2_hetdex_morethan60arcsec_noclip_sizeinvariant_SFGfiltered_test_candidates'
cali_test_settings.run_id = 373
cali_test_settings.map_to_id = 370

cali_test_settings, cali_test_cat, cali_test_cat_path, \
    cali_test_cutouts, cali_test_cutouts_path, \
    cali_test_bin_path = post.create_new_files_after_filtering(
    idx_test_cali, hetdex_SFGfiltered_cutouts,
    hetdex_SFGfiltered_cat,cali_test_settings,
    data_directory,output_directory)

# Cutout examples

In [None]:
# Gather examples
# Split the training cutouts into candidates and non-candidates and shuffle
# each group (candidates first, so the random draws keep their order).
np.random.seed(random_seed)
n_examples=5
is_candidate = train_cat.Source_Name.isin(cali_train_cat.Source_Name).to_numpy()
all_train_cutouts = np.asarray(train_cutouts)
# Boolean indexing returns copies, so shuffling leaves train_cutouts untouched
candi_examples = all_train_cutouts[is_candidate]
np.random.shuffle(candi_examples)
noncandi_examples = all_train_cutouts[~is_candidate]
np.random.shuffle(noncandi_examples)

matplotlib.rcParams.update({'font.size': 10})
f, axx = plt.subplots(n_examples,2,figsize=(4.5,10),
                       constrained_layout = True)
axx[0,0].set_title("AGN remnant candidates")
axx[0,1].set_title("Not yet inspected")
# Crop every (square) cutout to its central 1/sqrt(2) fraction.
# The side length is read from the data explicitly instead of from a loop
# variable that happened to be left over.
ll = np.shape(all_train_cutouts[0])[0]
w = int(ll/np.sqrt(2))
ss=int((ll-w)/2)
print(ll,w,ss)
for i in range(n_examples):
    # Explicit end index: `ss:-ss` would give an empty image when ss == 0
    axx[i,0].imshow(candi_examples[i][ss:ll-ss,ss:ll-ss])
    axx[i,1].imshow(noncandi_examples[i][ss:ll-ss,ss:ll-ss])
for axi in axx.ravel():
    axi.set_axis_off()
plt.savefig(os.path.join(paper_fig_dir,'cutout_examples.pdf'),bbox_inches='tight')
plt.show()


# Train SOM with train set

In [None]:
# Train SOM
# This cell (1) configures and saves the SOM description, (2) writes the
# training cutouts to a PINK binary, (3) writes the bash script that trains
# the SOM, (4) inspects the trained SOM and (5) maps the train candidates and
# the test set onto it.
s = deepcopy(train_settings)
run_id = s.run_id
print("Run Id of to train som is:", run_id)
print("Its name is:", s.store_filename)

s.data_dir = 'LoTSS_DR2/RA0h_field'
with open(os.path.join(
    data_directory,'run','trainSOM_SFGfiltered_undersampled.sh'),'w') as f:

    # Define parameters
    cutouts_bin_name = s.store_filename
    #cutouts_binary_path
    gpu_id = 0
    som_size = 2
    overwrite=s.overwrite
    print("overwrite:", overwrite)
    som = s.som

    # Create SOM info object
    som.som_width = 9
    som.som_height = 9
    som.som_depth = 1
    som.number_of_channels = 1
    som.gauss_decrease = 0.9
    som.gauss_end = 0.3
    som.learning_constraint_decrease = 0.8
    som.learning_constraint = som.som_width*som.som_height/100
    som.gauss_start = max(som.som_width, som.som_height)/2
    som.random_seed = 42
    som.pbc = False
    som.init = "zero"
    som.layout = "quadratic"
    som.training_dataset_name = s.store_filename
    som.run_id = s.run_id
    som.print()
    som.output_directory = output_directory
    som.save()
    
    rotated_size = int(round(rotated_size_arcsec/angular_resolution_LoTSS))
    fullsize = int(np.ceil(rotated_size*np.sqrt(2)))
    # Fixed typo: this used to assign to a non-existent attribute `data_sir`
    s.data_dir = 'LoTSS_DR2/RA0h_field'
    print("Data dir", s.data_dir)
    # Get numpy list of cutouts and update catalogue to only contain sources that succesfully got extracted.
    store_filename = s.store_filename

    # Check cutouts shape
    print("Cutouts shape (# cutouts, # channels, width, height) = ", 
          np.shape(hetdex_SFGfiltered_cutouts))
    # Write cutouts to binary file (required file format for PINK software)
    cutouts_binary_path = os.path.join(data_directory, s.store_filename + '.bin')
    assert train_bin_path == cutouts_binary_path
    post.write_numpy_to_binary_v2(train_bin_path, train_cutouts,
                                    som.layout, overwrite=True)
    # Write train bash scripts
    s.som = som
    s.save(output_directory)
    s.flip_axis1=False
    s.flip_axis2=False
    s.rot90=False
    # Write train bash script
    s.som = som
    bash_path = os.path.join(run_dir,f'{s.store_filename}_ID{run_id}.sh')
    bashstring = post.write_bash_script_to_run_pink_v2(som, run_id, s.store_filename,
                                                            gpu_id,
                                                            data_directory.replace(
                                                                'data2','data'),
        output_directory.replace('data2','data'), bash_path, verbose=False)
    # The returned script text is also written to the file opened above
    print(bashstring,file=f)
    print("Bash script to train SOM:", bash_path)
    
    # Inspect the trained SOM of this run.
    # NOTE(review): this presumably requires the bash script above to have
    # been run already -- confirm against post.inspect_trained_som.
    force_chosen_som_index=25
    print(s.run_id)
    som_v2, data_map_v2, data_som_v2, trained_path_v2, \
        distance_to_bmu_sorted_down_id_v2, \
        closest_prototype_id_v2, cat_v2, _ = post.inspect_trained_som(s.run_id, 
                data_directory, output_directory, figures_dir, fit_threshold=0.077,
             verbose=False, align_prototypes=True, normalize=False,compress=True,
             flip_axis0=False,flip_axis1=False,rot90=False,
                version=2, zoom_in=True, save=False, 
                catalogue_name=train_cat,
                overwrite=True, force_chosen_som_index=force_chosen_som_index)
    print("\nRun bash script to train SOM:", bash_path)
    cali_SFGfiltered_settingsfile.map_to_binpath = trained_path_v2
    print('\n\nSOM',som_v2.print(),'\n\n')
    
    # Map train remnants to trained SOM
    map_path = os.path.join(map_dir,f'{cali_train_settings.store_filename}_ID{cali_train_settings.run_id}_mapped_to_ID{s.run_id}.bin')
    new_bin_path =  os.path.join(data_directory,cali_train_settings.store_filename +'.bin')
    mapstring = post.map_dataset_to_trained_som(som, 
                new_bin_path, 
                map_path, trained_path_v2,
        gpu_id, use_gpu=True, verbose=True, version=2,
        alternate_neuron_dimension=None, use_cuda_visible_devices=True,
        rotation_path=None, circular_shape=False)
    cali_train_settings.map_path = map_path
    cali_train_settings.save(output_directory)
    
    # Map test set to trained SOM
    map_path = os.path.join(map_dir,f'{test_settings.store_filename}_ID{test_settings.run_id}_mapped_to_ID{s.run_id}.bin')
    new_bin_path =  os.path.join(data_directory,
                        test_settings.store_filename +'.bin')
    mapstring = post.map_dataset_to_trained_som(som, 
                new_bin_path, 
                map_path, trained_path_v2,
        gpu_id, use_gpu=True, verbose=True, version=2,
        alternate_neuron_dimension=None, use_cuda_visible_devices=True,
        rotation_path=None, circular_shape=False)
    test_settings.map_path = map_path
    test_settings.save(output_directory)
    
    # The training set itself is not re-mapped: its settings take the SOM and
    # map path from the inspected SOM object.
    map_path = os.path.join(map_dir,f'{train_settings.store_filename}_ID{train_settings.run_id}_mapped_to_ID{s.run_id}.bin')
    new_bin_path =  os.path.join(data_directory,train_settings.store_filename +'.bin')
    train_settings.som = som_v2
    train_settings.map_path = som_v2.map_path
    train_settings.save(output_directory)


In [None]:
# Plot the trained 9x9 SOM next to two heatmaps: all training sources and the
# training candidates, both mapped to that SOM. The figure is saved to
# paper_fig_dir and the returned data maps are kept for later cells.
matplotlib.rcParams.update({'font.size': 18})

# Captions of the second and third panel (one-element lists, indexed below)
captions2=[f'All training sources\nmapped to trained SOM']
captions3=[f'{len(cali_train_cutouts)} sources labelled as\nAGN remnant candidate\nduring initial visual inspection\nmapped to trained SOM']

cali_train_settings.zoom_in = True
# compress=False: use the full (uncompressed) SOM
data_map_train, data_map_train_remnant, counts1, counts2 = post.plot_som_and_two_heatmaps(
    train_settings, cali_train_settings,
   output_directory,map_dir, save=True,compress=False,
    save_path=os.path.join(paper_fig_dir,f'trained_som_id{train_settings.run_id}_and_heatmaps.pdf'),
    specific_som=train_settings,
    highlight=[],highlight_colors=[],
    annotations_for_paper=False,
    save_dir=figures_dir,
    caption2=captions2[0], caption3=captions3[0],
    caption1=f'9x9 SOM\ntrained with {len(train_cutouts)} radio sources')

In [None]:
# Same figure as for the full SOM, but for the compressed SOM (compress=True).
# The resulting data maps feed the Euclidean-norm feature later on.
matplotlib.rcParams.update({'font.size': 18})

# Captions of the second and third panel
captions2 = ['All training sources\nmapped to trained SOM']
captions3 = [
    f'{len(cali_train_cutouts)} sources labelled as\nAGN remnant candidate\n'
    'during initial visual inspection\nmapped to trained compressed SOM'
]

cali_train_settings.zoom_in = True
(data_map_train_compressed,
 data_map_train_remnant_compressed,
 counts1,
 counts2) = post.plot_som_and_two_heatmaps(
    train_settings,
    cali_train_settings,
    output_directory,
    map_dir,
    save=True,
    compress=True,
    save_path=os.path.join(
        paper_fig_dir,
        f'trained_som_id{train_settings.run_id}_compressed_and_heatmaps.pdf'),
    specific_som=train_settings,
    highlight=[],
    highlight_colors=[],
    annotations_for_paper=False,
    save_dir=figures_dir,
    caption2=captions2[0],
    caption3=captions3[0],
    caption1=f'5x5 compressed SOM\ntrained with {len(train_cutouts)} radio sources',
)

In [None]:
# For the test set
# The two heatmaps should be similar
# Maps the test set onto the compressed SOM; only the second returned data
# map (the one belonging to test_settings) is kept. Nothing is saved.
matplotlib.rcParams.update({'font.size': 18})

captions2=['Sources >60\" \nin HETDEX\nnearby SFGs filtered\nmapped to trained SOM']
captions3=[f'{len(test_cutouts)} sources from our test set\nmapped to trained SOM']

_, data_map_test_compressed, counts1, counts2 = post.plot_som_and_two_heatmaps(
    train_settings, test_settings,
   output_directory,map_dir, save=False,compress=True,
    #specific_som=train_settings,
    highlight=[],highlight_colors=[],
    annotations_for_paper=False,
    save_dir=figures_dir,
    caption2=captions2[0], caption3=captions3[0],
    caption1=f'5x5 compressed SOM\ntrained with {len(train_cutouts)} radio sources')

# Exploring alternative morphological metrics

## Sort using Concentration index (C)

In [None]:
def pixelvalues_within_radius(image, radius, debug=False):
    """Return the fraction of the summed pixel values inside a central circle.

    The circle is centred on pixel (n_rows // 2, n_cols // 2) and contains
    every pixel whose distance to that pixel is at most `radius`.

    Parameters
    ----------
    image : 2D array
        Image to measure; it is not modified.
    radius : float
        Circle radius in pixels. Must not exceed either image dimension.
    debug : bool
        If True, show the image with all pixels outside the circle set to 2.

    Returns
    -------
    float
        Sum of the pixels inside the circle divided by the sum of all pixels
        (NaN or inf when the image sums to zero).
    """
    assert len(np.shape(image)) == 2
    # np.shape returns (rows, columns); keeping that order makes the mask
    # match the image for non-square input as well.
    n_rows, n_cols = np.shape(image)
    assert radius <= n_rows and radius <= n_cols
    image_copy = np.array(image)
    initial_sum = np.sum(image_copy)
    centre_row, centre_col = n_rows // 2, n_cols // 2

    # Pixel offsets with respect to the central pixel
    y, x = np.ogrid[-centre_row:n_rows - centre_row,
                    -centre_col:n_cols - centre_col]
    mask = x*x + y*y > radius*radius

    if debug:
        # debug plot
        image_copy[mask] = 2
        plt.figure()
        plt.imshow(image_copy, origin='lower', cmap='viridis')
        plt.grid(False)
        plt.show()

    image_copy[mask] = 0
    return np.sum(image_copy)/initial_sum
    
    
def concentration_index(image, r_inner_fraction=0.2, r_outer_fraction=0.8):
    """Return the concentration index C = 5*log10(r_outer/r_inner) of an image.

    Follows the definition in the Conselice 2014 review on galaxy
    morphology: r_inner and r_outer are the radii of the central circles
    that contain `r_inner_fraction` and `r_outer_fraction` of the total flux.
    Radii are sampled on 100 steps between 0 and half the image width; for
    each fraction the first sampled radius whose enclosed flux fraction
    exceeds it is used.

    The base-10 logarithm is used, as in the cited definition. Values are
    therefore a factor ln(10) smaller than those from the earlier
    natural-log version of this function.

    Parameters
    ----------
    image : square 2D array
    r_inner_fraction, r_outer_fraction : float
        Flux fractions that define the inner and outer radius.

    Returns
    -------
    float
        The concentration index, or NaN when it is undefined because one of
        the radii is zero (the fraction is already exceeded by the central
        pixel, or is never reached within half the image width).
    """
    width, height = np.shape(image)
    assert width == height

    radii = np.linspace(0,width/2,100)
    fluxfractions = np.array([pixelvalues_within_radius(image, r) for r in radii])

    def first_radius_exceeding(fraction):
        # First sampled radius whose enclosed flux fraction exceeds `fraction`
        reached = np.flatnonzero(fluxfractions > fraction)
        return radii[reached[0]] if reached.size else 0.0

    r_inner = first_radius_exceeding(r_inner_fraction)
    r_outer = first_radius_exceeding(r_outer_fraction)
    if r_inner <= 0 or r_outer <= 0:
        # The ratio of radii is undefined; avoid dividing by zero / log(0)
        return np.nan
    return 5*np.log10(r_outer/r_inner)

## Sort using Gini coefficient (G)

In [None]:
def gini_index(image):
    """Return the Gini coefficient of an image, as defined in Lisker 2008.

    Only pixels inside the circle returned by flux_radius(image, 0.95) are
    used. With the n absolute pixel values sorted in increasing order
    (|x_1| <= ... <= |x_n|) the coefficient is

        G = sum_i (2*i - n - 1) * |x_i| / (|mean(x)| * n * (n - 1)),

    with the rank i running from 1 to n.

    Parameters
    ----------
    image : 2D array
        Image to measure; it is not modified. Integer images are handled too.

    Returns
    -------
    float
        The Gini coefficient, or NaN when fewer than two pixels are selected
        or the selected pixels average to zero.
    """
    assert len(np.shape(image)) == 2
    n_rows, n_cols = np.shape(image)
    pixels = np.asarray(image, dtype=float)

    # Radius of the central circle holding more than 95% of the flux
    # (0 if that fraction is never reached within half the image width)
    new_radius = flux_radius(image, flux_fraction_contained=0.95)

    # Select the pixels inside that circle; offsets are relative to the
    # central pixel and follow the (rows, columns) order of the array
    centre_row, centre_col = n_rows // 2, n_cols // 2
    y, x = np.ogrid[-centre_row:n_rows - centre_row,
                    -centre_col:n_cols - centre_col]
    selected = pixels[x*x + y*y <= new_radius**2]
    selected = selected[~np.isnan(selected)]
    n = selected.size

    f_average = np.abs(np.mean(selected)) if n else 0.0
    if n < 2 or f_average == 0:
        # The normalisation f_average*n*(n-1) would be zero
        return np.nan
    sorted_pixels = np.sort(np.abs(selected))
    # Ranks start at 1; starting at 0 would bias every result downwards
    ranks = np.arange(1, n + 1)
    return np.sum((2*ranks - n - 1)*sorted_pixels)/(f_average*n*(n-1))
def flux_radius(image, flux_fraction_contained= 0.95):
    """Return the radius of the central circle holding a given flux fraction.

    Radii are sampled on 100 steps between 0 and half the image width. The
    first sampled radius for which pixelvalues_within_radius exceeds
    `flux_fraction_contained` is returned.

    Parameters
    ----------
    image : 2D array
    flux_fraction_contained : float
        Fraction of the summed pixel values the circle has to exceed.

    Returns
    -------
    float
        Radius in pixels, or 0 when the fraction is never exceeded within
        half the image width.
    """
    assert len(np.shape(image)) == 2
    width, _ = np.shape(image)
    for radius in np.linspace(0,width/2,100):
        if pixelvalues_within_radius(image, radius, debug=False) > flux_fraction_contained:
            return radius
    return 0

## Sort using Clumpiness index (S)

In [None]:
from scipy.ndimage import gaussian_filter
def clumpiness_index(image, smoothing_kernel=4, debug=False):
    """Return the Clumpiness index of an image.

    Based on the definition in the Conselice 2014 review on galaxy
    morphology. Computed here as 10 times the summed absolute difference
    between the image and its Gaussian-smoothed version, divided by the
    summed pixel values of the image.

    Parameters
    ----------
    image : square 2D array
        Image to measure; it is not modified. Integer images are converted
        to float so the smoothing is not done in integer arithmetic.
    smoothing_kernel : float
        Sigma handed to scipy.ndimage.gaussian_filter.
    debug : bool
        If True, print the index and show the image and its smoothed version.
        This no longer changes the returned value.

    Returns
    -------
    float
        The clumpiness index (NaN or inf when the image sums to zero).
    """
    assert len(np.shape(image)) == 2
    width, height = np.shape(image)
    assert width == height
    pixels = np.asarray(image)
    if not np.issubdtype(pixels.dtype, np.floating):
        pixels = pixels.astype(float)
    initial_sum = np.sum(pixels)
    smoothed_image = gaussian_filter(pixels, smoothing_kernel)
    # Compute the result before any debug plotting touches pixel values
    clumpiness = 10*(np.sum(np.abs(pixels-smoothed_image))/initial_sum)
    if debug:
        # Show concerning image
        print("Clumpiness:", clumpiness)
        for panel in (pixels, smoothed_image):
            shown = panel.copy()
            # Corner pixel is set to 1 on the plotted copy only (presumably
            # to give both panels the same colour range)
            shown[0,0] = 1
            plt.imshow(shown, origin='lower', cmap='viridis')
            plt.colorbar()
            plt.grid(False)
            plt.show()
    return clumpiness

# Plot morphological metrics

In [None]:
# Morphological metrics (concentration, clumpiness, Gini) per cutout, plus the
# Euclidean norm to the best matching neuron of the compressed SOM.
def _metric_arrays(images):
    """Return (concentrations, clumpinesses, ginis) for a set of cutouts."""
    concentrations = np.array([
        concentration_index(img, r_inner_fraction=0.2, r_outer_fraction=0.8)
        for img in images])
    clumpinesses = np.array([
        clumpiness_index(img, smoothing_kernel=8) for img in images])
    ginis = np.array([gini_index(img) for img in images])
    return concentrations, clumpinesses, ginis


# Candidates in the training set
cali_concentrations, cali_clumpinesses, cali_ginis = _metric_arrays(cali_train_cutouts)
cali_euclid = np.min(data_map_train_remnant_compressed, axis=1)

# All sources in the (undersampled) training set
hetdex_concentrations, hetdex_clumpinesses, hetdex_ginis = _metric_arrays(train_cutouts)
hetdex_euclid = np.min(data_map_train_compressed, axis=1)
#hetdex_SFGfiltered_data_map_compressed, data_map_remnant_compressed, data_map_nonremnant_compressed

In [None]:
# Plot Euclidean norm vs concentration index
# More generally: a 4x4 grid of pairwise scatter plots of the four metrics
# (Euclidean norm, concentration, Gini, clumpiness). Only the panels below
# the diagonal are drawn.
rows,cols = 4,4
ms=4
matplotlib.rcParams.update({'font.size': 14})

f, ax = plt.subplots(cols,rows,figsize=(10,10))
# No gaps between panels, so neighbouring panels share their axes visually
spacing=0
plt.subplots_adjust(wspace=spacing,hspace=spacing)
# dc: metrics of the training candidates, da: metrics of all training sources
dc = {'Concentration index':cali_concentrations,'Gini coefficient':cali_ginis,
     'Clumpiness index':cali_clumpinesses,'Euclidean norm':cali_euclid}
da = {'Concentration index':hetdex_concentrations,'Gini coefficient':hetdex_ginis,
     'Clumpiness index':hetdex_clumpinesses,'Euclidean norm':hetdex_euclid}
    
for row_idx, ys in enumerate(['Euclidean norm','Concentration index','Gini coefficient','Clumpiness index']):
    #for col_idx, ys in enumerate(['Euclidean norm','Concentration index','Gini index','Clumpiness index']):
    for col_idx, xs in enumerate(['Euclidean norm','Concentration index','Gini coefficient','Clumpiness index']):
        # With rows == 4 this is true for col_idx >= row_idx: the diagonal
        # and everything above it is hidden
        if rows-row_idx +col_idx > 3:
            ax[row_idx,col_idx].axis('off')
            continue
        # Tick marks and x labels only on the bottom row
        if rows -1 != row_idx:
            ax[row_idx,col_idx].xaxis.set_major_locator(plt.NullLocator())
        else:
            ax[row_idx,col_idx].set_xlabel(xs.replace(' ','\n'))
        # Tick marks and y labels only on the first column
        if 0 != col_idx:
            ax[row_idx,col_idx].yaxis.set_major_locator(plt.NullLocator())
        else:
            ax[row_idx,col_idx].set_ylabel(ys.replace(' ','\n'))
        # NOTE(review): `da` holds all training sources, not only rejected
        # ones; no legend is drawn in this cell, so the labels are not shown.
        l1 = ax[row_idx,col_idx].scatter(da[xs],da[ys],
                    s=ms,label='Rejected candidates')
        l2 = ax[row_idx,col_idx].scatter(dc[xs],dc[ys],
                    s=ms, label='Accepted candidates')
    
    
plt.savefig(os.path.join(paper_fig_dir,"metrics_cali_hetdex.pdf"),bbox_inches='tight')
plt.show()
# Value ranges of the concentration index for both samples
print(f"Min and max value comparisons")
print(f"cali concentration ranges from {np.min(cali_concentrations):.1f} to {np.max(cali_concentrations):.1f}")
print(f"all concentration ranges from {np.min(hetdex_concentrations):.1f} to {np.max(hetdex_concentrations):.1f}")

In [None]:
# Plot random example cutouts from the low (<20th percentile) and high (>80th percentile) tail of each morphological metric
from matplotlib.patches import Ellipse
np.random.seed(random_seed)
matplotlib.rcParams.update({'font.size': 18})
def plot_optical_metric_examples(title, condition, savetitle='', save=True,flux_radii=None):
    """Show up to nine randomly chosen training cutouts that satisfy `condition`.

    Parameters
    ----------
    title : str
        Figure title.
    condition : boolean array
        Selects the cutouts from the global `train_cutouts`.
    savetitle : str
        File name (without extension) of the pdf written to `paper_fig_dir`.
    save : bool
        Whether to save the figure.
    flux_radii : array or None
        Optional radius per training cutout (same order as `train_cutouts`).
        If given, a red circle with that radius is drawn around the centre of
        each cutout and its diameter is printed at the centre.

    The selection is shuffled with the global NumPy random state.
    """
    w,h=3,3
    fig, ax = plt.subplots(w,h, figsize=(10,10.5),constrained_layout = True)
    axr = ax.ravel()
    plt.suptitle(title)
    # Hide every axis up front, so panels stay empty (instead of showing a
    # blank frame) when fewer than w*h cutouts satisfy the condition
    for axis in axr:
        axis.axis('off')
    cuts = deepcopy(train_cutouts[condition])
    shuf = list(range(len(cuts)))
    np.random.shuffle(shuf)
    if flux_radii is not None:
        # Apply the same selection and order to the radii as to the cutouts
        flux_r = deepcopy(flux_radii[condition])
        flux_r = flux_r[shuf]
    for i_c, c in enumerate(cuts[shuf][:w*h]):
        axr[i_c].imshow(c)
        if flux_radii is not None:
            n_rows, n_cols = c.shape
            # Matplotlib expects (x, y) = (column, row)
            centre = (n_cols/2, n_rows/2)
            axr[i_c].add_patch(Ellipse(xy=centre,width=2*flux_r[i_c],height=2*flux_r[i_c],
                        edgecolor='r', fc='None', lw=2))
            axr[i_c].text(centre[0],centre[1],f"{2*flux_r[i_c]:.1f}")
    if save:
        plt.savefig(os.path.join(paper_fig_dir,savetitle+'.pdf'))
    plt.show()
# For every metric, show examples below the 20th and above the 80th
# percentile of its distribution over the training set.
for metric_name, metric_values in (
        ('Gini coefficient', hetdex_ginis),
        ('Concentration index', hetdex_concentrations),
        ('Clumpiness index', hetdex_clumpinesses)):
    low, high = np.percentile(metric_values, [20, 80])
    file_stem = metric_name.replace(' ', '_')
    plot_optical_metric_examples(
        f'{metric_name} < {low:.2f}', metric_values < low,
        savetitle=f'{file_stem}_low')
    plot_optical_metric_examples(
        f'{metric_name} > {high:.2f}', metric_values > high,
        savetitle=f'{file_stem}_high')


# Decision Tree time

In [None]:
# We start from SFG filtered cat >60arcsec

# Compile source features and source labels (agn remnant candidate or not)
# Features to collect:
# 1. Number of AGN remnant candidates in the best matching neuron of the source
# 2. Total flux / peak flux
# 3. Total flux
# 4. Major axis / minor axis
# 5. Euclidean norm
# 6. Concentration index
# 7. Gini coefficient
# 8. Clumpiness index
# 9. Angular separation to the nearest cluster? (as a higher magnetic background field should allow us to observe AGN remnants for a longer time?)
# Or even better: Distance to the nearest cluster in Mpc
# Finally Labels
# NOTE(review): the list above is a wish list. The table built below contains
# the best-matching-neuron index, total/peak flux, major/minor axis, Euclidean
# norm, concentration, Gini, clumpiness and the Haralick cluster label; total
# flux and the cluster distance are not included.
np.random.seed(random_seed)

# Best matching neuron: index of the smallest entry per row of the data map
bmn = np.argmin(data_map_train_compressed,axis=1)
# Number of training sources; every feature array must have this length
l = len(bmn)
# Total flux / peak flux
total_to_peak_fluxes = train_cat.Total_flux.values/train_cat.Peak_flux.values
assert l == len(total_to_peak_fluxes)
# Major axis / minor axis
major_to_minor_axes = train_cat.source_size.values/train_cat.source_width.values
assert l == len(major_to_minor_axes)
# Euclidean norm: the smallest entry per row of the data map
euclidean_norms = np.min(data_map_train_compressed,axis=1)
assert l == len(euclidean_norms)
# Concentration index
all_concentrations = np.array([concentration_index(image, r_inner_fraction=0.2, 
        r_outer_fraction=0.8) for image in train_cutouts])
assert l == len(all_concentrations)
# Gini coefficient
all_ginis = np.array([gini_index(image) for image in train_cutouts])
assert l == len(all_ginis)
# Clumpiness index
all_clumpinesses = np.array([clumpiness_index(image, smoothing_kernel=8) 
        for image in train_cutouts])
assert l == len(all_clumpinesses)
# Haralick features
haralicks = []
for cutout in train_cutouts:
    #calculate haralick features
    # NOTE(review): assumes cutout values lie in [0, 1] -- TODO confirm. A
    # value of exactly 1.0 becomes 256, which does not fit in uint8 and will
    # not be stored as 255 by the cast below. Any fix has to be applied to
    # the test-set cell as well to keep the features comparable.
    data = deepcopy(cutout)*256
    # Reduce the cutout intensity levels to 256 integer values
    data = data.astype(np.uint8)
    # Further reduce the cutout intensity levels to 32 integer values
    # as advocated by M.A. Tahir, A. Bouridane, F. Kurugollu, A. Amira
    # Accelerating computation of GLCM and Haralick texture features on reconfigurable hardware
    #data = np.array([list(map(int,aa/8)) for aa in data])
    haralick_features = mh.features.haralick(data, return_mean=True)
    haralicks.append(haralick_features)
haralicks = np.array(haralicks)
#set up the HDBSCAN clusterer - see params for more details, but this seems to do okay
# prediction_data=True is needed so the fitted clusterer can later be used
# with hdbscan.approximate_predict on the test set
clusterer = hdbscan.HDBSCAN(min_cluster_size=32,min_samples=32,
                            algorithm='best', prediction_data=True,
                            cluster_selection_method='eom',
                            metric='euclidean')

clusterer = clusterer.fit(haralicks)

#get a list of unique labels and the number of counts of each label
#label = -1 is 'noise' - i.e. couldn't be fit to a cluster
clustered_haralicks = clusterer.labels_
ulab,ucount= np.unique(clustered_haralicks,return_counts=True)
print(f"Hdbscan generated {len(ulab)} clusters: {ulab} Containing {ucount} sources each.")
assert l == len(clustered_haralicks)

# Labels: True if the source is one of the training candidates
labels_train = np.array([sn in cali_train_cat.Source_Name.values
 for sn in train_cat.Source_Name.values])

# Combine all features into one table (no train/test split happens here; the
# split was made earlier). The first column is the source name, the remaining
# columns are the features.
combined_data = pd.DataFrame({
                            'Source Name':train_cat.Source_Name.values,
                             'Remnants per SOM neuron':bmn,
                            # bmn will be turned into a ratio within
                            # each cross-fold
                            'Remnants ratio per SOM neuron':bmn,
                             'Total/peak flux':total_to_peak_fluxes,
                             'Major/minor axis':major_to_minor_axes,
                             'Euclidean norm':euclidean_norms,
                             'Concentration index':all_concentrations,
                             'Gini coefficient':all_ginis,
                             'Clumpiness index':all_clumpinesses,
                            # clustered_haralicks will be turned into a ratio within
                            # each cross-fold
                             'Clustered Haralick ratio':clustered_haralicks,
})
features_train=combined_data.to_numpy()
# Feature names: all columns except the leading 'Source Name'
feature_list=combined_data.keys()[1:]
print("len dataframe:", len(combined_data))

In [None]:
# Now for the test set
# Same feature table as for the training set, built from the test catalogue,
# test cutouts and the test data map. The HDBSCAN clusterer fitted on the
# training set is reused; it is not refitted.
np.random.seed(random_seed)

# Best matching neuron: index of the smallest entry per row of the data map
bmn = np.argmin(data_map_test_compressed,axis=1)
# Number of test sources; every feature array must have this length
l = len(bmn)
# Total flux / peak flux
total_to_peak_fluxes = test_cat.Total_flux.values/test_cat.Peak_flux.values
assert l == len(total_to_peak_fluxes)
# Major axis / minor axis
major_to_minor_axes = test_cat.source_size.values/test_cat.source_width.values
assert l == len(major_to_minor_axes)
# Euclidean norm: the smallest entry per row of the data map
euclidean_norms = np.min(data_map_test_compressed,axis=1)
assert l == len(euclidean_norms)
# Concentration index
all_concentrations = np.array([concentration_index(image, r_inner_fraction=0.2, 
        r_outer_fraction=0.8) for image in test_cutouts])
assert l == len(all_concentrations)
# Gini coefficient
all_ginis = np.array([gini_index(image) for image in test_cutouts])
assert l == len(all_ginis)
# Clumpiness index
all_clumpinesses = np.array([clumpiness_index(image, smoothing_kernel=8) 
        for image in test_cutouts])
assert l == len(all_clumpinesses)
# Haralick features
haralicks = []
for cutout in test_cutouts:
    #calculate haralick features
    # NOTE(review): assumes cutout values lie in [0, 1] -- TODO confirm. A
    # value of exactly 1.0 becomes 256, which does not fit in uint8 and will
    # not be stored as 255 by the cast below. Any fix has to be applied to
    # the training-set cell as well to keep the features comparable.
    data = deepcopy(cutout)*256
    # Reduce the cutout intensity levels to 256 integer values
    data = data.astype(np.uint8)
    # Further reduce the cutout intensity levels to 32 integer values
    # as advocated by M.A. Tahir, A. Bouridane, F. Kurugollu, A. Amira
    # Accelerating computation of GLCM and Haralick texture features on reconfigurable hardware
    #data = np.array([list(map(int,aa/8)) for aa in data])
    haralick_features = mh.features.haralick(data, return_mean=True)
    haralicks.append(haralick_features)
haralicks = np.array(haralicks)
# Use previously trained clusterer
# https://hdbscan.readthedocs.io/en/latest/prediction_tutorial.html
test_labels, strengths = hdbscan.approximate_predict(clusterer, haralicks)

#get a list of unique labels and the number of counts of each label
#label = -1 is 'noise' - i.e. couldn't be fit to a cluster
ulab,ucount= np.unique(test_labels,return_counts=True)
print(f"Hdbscan generated {len(ulab)} clusters: {ulab} Containing {ucount} sources each.")
assert l == len(test_labels)

# Labels: True if the source is one of the test candidates
labels_test = np.array([sn in cali_test_cat.Source_Name.values
 for sn in test_cat.Source_Name.values])

# Combine all features into one table with the same columns, in the same
# order, as the training table (no split happens here).
combined_data = pd.DataFrame({
                            'Source Name':test_cat.Source_Name.values,
                             'Remnants per SOM neuron':bmn,
                            # bmn will be turned into a ratio within
                            # each cross-fold
                            'Remnants ratio per SOM neuron':bmn, #np.array([0 for _ in range(l)]),
                             'Total/peak flux':total_to_peak_fluxes,
                             'Major/minor axis':major_to_minor_axes,
                             'Euclidean norm':euclidean_norms,
                             'Concentration index':all_concentrations,
                             'Gini coefficient':all_ginis,
                             'Clumpiness index':all_clumpinesses,
                            # clustered_haralicks will be turned into a ratio within
                            # each cross-fold
                             'Clustered Haralick ratio':test_labels,
})
# NOTE: `combined_data` now refers to the test table; the training table is
# only available through `features_train` from here on.
features_test=combined_data.to_numpy()
print("len dataframe:", len(combined_data))

In [None]:
### Sanity check: hdbscan.approximate_predict should reproduce the labels from
### the original clustering when it is fed the very same cutouts
# Recompute the Haralick features of the first ten training cutouts.
# A plain loop is used on purpose: `cutout` stays defined at notebook level
# afterwards, and a later cell reads it.
haralicks = []
for cutout in train_cutouts[:10]:
    # Same preprocessing as elsewhere: scale by 256, then cast to 8-bit integers
    data = (deepcopy(cutout)*256).astype(np.uint8)
    haralick_features = mh.features.haralick(data, return_mean=True)
    haralicks.append(haralick_features)
haralicks = np.array(haralicks)
debug_labels, strengths = hdbscan.approximate_predict(clusterer, haralicks)
# Labels assigned to the same ten cutouts by the original clustering
first_ten = clustered_haralicks[:10]
assert all(first_ten == debug_labels)

In [None]:
# Show example cutouts of the sources in each Haralick cluster
np.random.seed(random_seed)
matplotlib.rcParams.update({'font.size': 18})
def plot_haralick_examples(title, condition, savetitle='', save=True):
    """Show up to nine randomly chosen training cutouts in a 3x3 grid.

    Parameters
    ----------
    title : str
        Title placed above the grid.
    condition : boolean mask or index array
        Selects the cutouts to sample from the global ``train_cutouts``.
    savetitle : str
        File name stem; with ``save`` the figure is written to
        ``<paper_fig_dir>/<savetitle>.pdf``.
    save : bool
        Whether to write the figure to disk before showing it.

    The examples are drawn with ``np.random.shuffle``, so the choice depends
    on the state of NumPy's global random generator.
    """
    w,h=3,3
    fig, ax = plt.subplots(w,h, figsize=(10,10.5),constrained_layout = True)
    axr = ax.ravel()
    plt.suptitle(title)
    # Hide the frame of every panel up front, so that panels left empty
    # (fewer than w*h matching cutouts) are blank instead of showing axes
    for panel in axr:
        panel.axis('off')
    # Indexing with `condition` already yields a new array; no extra copy needed
    selected = train_cutouts[condition]
    shuf = list(range(len(selected)))
    np.random.shuffle(shuf)
    # Only the first w*h shuffled positions are looked up
    for panel, example in zip(axr, selected[shuf[:w*h]]):
        panel.imshow(example)
    if save:
        plt.savefig(os.path.join(paper_fig_dir,savetitle+'.pdf'))
    plt.show()
save=False
# One figure per cluster: (cluster label, figure title, output file stem).
# Label -1 collects the sources that could not be fit to a cluster.
haralick_figures = [
    (-1, "Haralick cluster -1 ('noise' cluster)", 'haralick_minus1'),
    (0, 'Haralick cluster 0', 'haralick_0'),
    (1, 'Haralick cluster 1', 'haralick_1'),
    (2, 'Haralick cluster 2', 'haralick_2'),
    (3, 'Haralick cluster 3', 'haralick_3'),
]
for cluster_label, figure_title, file_stem in haralick_figures:
    plot_haralick_examples(figure_title, clustered_haralicks == cluster_label,
                           save=save, savetitle=file_stem)


In [None]:
def _split_off_names(table):
    """Return (names, features): column 0 and the remaining columns of *table*."""
    names = np.array([row[0] for row in table])
    values = np.array([row[1:] for row in table])
    return names, values

# Column 0 of each feature table holds the source name; separate it from the
# actual features.
# NOTE: this overwrites features_train/features_test in place, so running the
# cell a second time would strip off another column.
sourcenames_train, features_train = _split_off_names(features_train)
sourcenames_test, features_test = _split_off_names(features_test)


# Class balance of both sets, as fractions ...
train_label_frame = pd.DataFrame({"labels_train":labels_train})
test_label_frame = pd.DataFrame({"labels_test":labels_test})
print(train_label_frame.labels_train.value_counts(normalize = True))
print(test_label_frame.labels_test.value_counts(normalize = True))
# ... and as absolute counts
print(f"Train: {np.sum(labels_train)} remnants, {len(labels_train)-np.sum(labels_train)} non-remnants")
print(f"Test: {np.sum(labels_test)} remnants, {len(labels_test)-np.sum(labels_test)} non-remnants")

## Employ gridsearch to find good RF parameters

In [None]:
# Sanity check: the feature table and the label vector should have the same length
len(features_train), len(labels_train), features_train.shape, labels_train.shape

In [None]:
np.random.seed(random_seed)
# Model selection is scored with fbeta_score at beta=2 (the F2-score)
f_scorer = make_scorer(fbeta_score, beta=2)
scoring = f_scorer
# Stratified, shuffled 5-fold split; identical for both branches below
cv = StratifiedKFold(n_splits=5,shuffle=True,random_state=random_seed)
# Switch: True reruns the grid search, False reuses the hard-coded result
redo=False
if not redo:
    # Set up RF with previously found best values
    stored_best_params = {'class_weight': {0: 0.04, 1: 0.96}, 'max_depth': 8, 'max_features': 0.3}
    clf = RandomForestClassifier(n_estimators=5000, random_state=random_seed,
                                 max_depth = 8, max_features=0.3,
                                 class_weight={0: 0.04, 1: 0.96})
    # Mimic the attribute GridSearchCV would have provided, so later cells can
    # read clf.best_params_ in either branch
    clf.best_params_ = stored_best_params
    # Same layout as in the grid-search branch: (parameters, score)
    unbalanced_par_score = (deepcopy(stored_best_params), 0.6754247726575159)
else:
    # Base RF; the grid below overrides max_depth, max_features and class_weight
    rf = RandomForestClassifier(n_estimators=5000, random_state=random_seed,
                                max_depth = 2, max_features=0.2,
                                class_weight='balanced')
    # Parameters to do the grid search over
    class_weight_options = [{0: x, 1: 1.0-x} for x in [0.01,0.04,0.16,0.32,0.5]]
    param_grid = [{'max_depth': [2,4,8,16,32],
                   'max_features': [0.2,0.3,0.4,0.5],
                   'class_weight': class_weight_options,
                  } ]
    # Grid search with the full (unbalanced) train dataset
    start = time.time()
    clf = GridSearchCV(rf, param_grid, scoring=scoring, n_jobs=-1, refit=True,
                       cv=cv, verbose=0, return_train_score=False)
    clf.fit(features_train, labels_train)
    print(f"Cross-validation custom gridsearch took ${(time.time()-start)/60:.2f}$ min.")
    unbalanced_par_score = (clf.best_params_, clf.best_score_)
    print("Best parameters:",unbalanced_par_score)

## Get feature importance and test set performance

In [None]:
np.random.seed(random_seed)
# Fill the ratio features of the training set; the returned dictionaries are
# then used to fill the same features of the test set
(features_train2, remnant_dict, all_dict,
 remnant_dict_hard, all_dict_hard) = insert_local_remnant_ratios(
    features_train, labels_train, debug=False)
features_test2 = insert_local_test_ratios(
    features_test, remnant_dict, all_dict, remnant_dict_hard, all_dict_hard)
# Final training on the full unbalanced training set with the selected parameters
chosen_params = unbalanced_par_score[0]
best_rf = RandomForestClassifier(
    n_estimators=1000,
    random_state=random_seed,
    max_depth=chosen_params['max_depth'],
    max_features=chosen_params['max_features'],
    class_weight=chosen_params['class_weight'],
)
best_rf.fit(features_train2, labels_train)
# Paragraph for the manuscript describing the selected parameters
print(f"""In the cross-validation phase of training our RF, 
the grid-search reveals that a maximum tree depth of ${clf.best_params_['max_depth']}$,
a maximum feature-ratio of ${clf.best_params_['max_features']}$ 
and a class weight of ${clf.best_params_['class_weight'][0]}$ for our majority class 
(and ${clf.best_params_['class_weight'][1]}$ for our minority class),
are good parameter choices for our RF when optimizing for recall. 
Refitting our RF using these parameters on the full training set yields the following performance score on the test set.""")
print()

In [None]:
# The forest's built-in feature importances, one value per feature
importances = list(best_rf.feature_importances_)
# (feature name, importance rounded to two decimals), most important first
feature_importances = sorted(
    ((feature, round(importance, 2))
     for feature, importance in zip(feature_list, importances)),
    key=lambda pair: pair[1],
    reverse=True,
)
# Print the ranking as rows of a LaTeX table
for pair in feature_importances:
    print('{:20} & ${}$ \\\\'.format(*pair))

# Predict on the test set with the classifier's default decision rule
predictions_rf = best_rf.predict(features_test2)
print(classification_report(labels_test, predictions_rf))
# Plot the confusion matrix
cm = confusion_matrix(labels_test, predictions_rf, labels=best_rf.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=best_rf.classes_)
disp.plot()
plt.title('For full testset')
plt.ylabel('AGN remnant candidate')
plt.xlabel('AGN remnant candidate predicted')
plt.show()

## Find best threshold for our use case

In [None]:
np.random.seed(random_seed)
thres=0.25 #f2 5000trees

# A source is predicted positive when the second column of predict_proba
# exceeds the threshold
train_probabilities = best_rf.predict_proba(features_train2)[:, 1]
pred = train_probabilities > thres
cm_train = confusion_matrix(labels_train,pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm_train,
                              display_labels=best_rf.classes_)
disp.plot(cmap='gray')
plt.title('Train set')
plt.ylabel('AGN remnant candidate')
plt.xlabel('AGN remnant candidate predicted')
plt.show()

In [None]:
np.random.seed(random_seed)
print("recall",unbalanced_par_score)
thres=0.25#f2 5000trees

# A source is predicted positive when the second column of predict_proba
# exceeds the threshold
test_probabilities = best_rf.predict_proba(features_test2)[:, 1]
pred = test_probabilities > thres
cm = confusion_matrix(labels_test,pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=best_rf.classes_)
disp.plot(cmap='gray')
# plt.title('For full testset')
plt.ylabel('AGN remnant candidate')
plt.xlabel('AGN remnant candidate predicted')
plt.savefig(os.path.join(paper_fig_dir,'confusion.pdf'),bbox_inches='tight')
plt.show()
# Row 1 of the confusion matrix holds the true positives' row:
# missed ones in column 0, recovered ones in column 1
missed, recovered = cm[1,0], cm[1,1]
print("Recall:", recovered/(missed+recovered), missed, recovered)

report = classification_report(
    labels_test, pred, output_dict=True,
    target_names=['noncandidate','AGN remnant candidate'])
print("Latex performance report for given prediction threshold:")
print('Performance of trained RF on test set with prediction threshold set to guarantee full recall of the AGN remnant candidates.')
print(classification_report_to_latex(report))

In [None]:
# F-beta score computed directly from confusion-matrix counts
def fb(tn,fn,tp,fp,b=2):
    """Return the F-beta score for the given confusion-matrix counts.

    Parameters
    ----------
    tn : number
        True negatives. Not used by the formula; accepted so that all four
        confusion-matrix entries can be passed.
    fn, tp, fp : number
        False negatives, true positives and false positives.
    b : number
        The beta of the F-beta score; the default 2 gives the F2-score.

    Returns
    -------
    (1+b^2)*tp / ((1+b^2)*tp + b^2*fn + fp), or 0.0 when that denominator is
    zero (no true positives, false negatives or false positives), instead of
    dividing by zero.
    """
    weighted_tp = (1+b**2)*tp
    denominator = weighted_tp + fn*b**2 + fp
    if denominator == 0:
        return 0.0
    return weighted_tp / denominator

# F2-scores from the thresholded confusion matrices of both sets
print("Train set F2-score is", fb(cm_train[0,0],cm_train[1,0],cm_train[1,1],cm_train[0,1],b=2))
print("Test set F2-score is", fb(cm[0,0],cm[1,0],cm[1,1],cm[0,1],b=2))
# oud ("old"): all test sources, i.e. every entry of the confusion matrix
oud = cm[0,0]+cm[0,1]+cm[1,0]+cm[1,1]
# nieuw ("new"): only the sources predicted positive (column 1)
nieuw=cm[0,1]+cm[1,1]
print(oud,nieuw)
# Fraction of sources that no longer needs inspecting; computed as
# (old - new) / old so that a reduction is reported as a positive percentage
print(f"Reducing visual inspection by {(oud-nieuw)/oud:.2%}")

## Plot false positives

In [None]:
# Map every source name to its row position in hetdex_SFGfiltered_cat
catalogue_names = hetdex_SFGfiltered_cat.Source_Name.values
sn_to_iloc = dict(zip(catalogue_names, range(len(catalogue_names))))

In [None]:
def plot_from_sn(sn):
    """Display the cutout that belongs to the source named *sn*.

    The name is translated to a row position through ``sn_to_iloc`` and the
    cutout is taken from ``hetdex_SFGfiltered_cutouts``.
    """
    plt.imshow(hetdex_SFGfiltered_cutouts[sn_to_iloc[sn]])
    plt.show()

    
# Walk through the first test sources and show every misclassified one.
# The candidate names go into a set once, so the membership test inside the
# loop no longer scans the whole catalogue column on every iteration.
candidate_names = set(cali_SFGfiltered_cat.Source_Name.values)
last_was_fp = False
for i, (prediction, sn) in enumerate(zip(pred, sourcenames_test)):
    # Only the sources at positions 0..400 are inspected
    if i>400: break
    is_candidate = sn in candidate_names
    if prediction and not is_candidate:
        print("False positive: Could be remnant?")
    elif is_candidate and not prediction:
        print("False negative (weird candidate)?")
    else:
        # Correctly classified sources (true positives and true negatives)
        # are not shown
        continue
    plot_from_sn(sn)
    last_was_fp = True
#for i, (prediction, sn) in enumerate(zip(pred, sourcenames_test)):


In [None]:
def cutout_from_sn(sn):
    """Return the cutout that belongs to the source named *sn*.

    The name is translated to a row position through ``sn_to_iloc`` and the
    cutout is taken from ``hetdex_SFGfiltered_cutouts``.
    """
    return hetdex_SFGfiltered_cutouts[sn_to_iloc[sn]]

# Crop window: only the central part of every cutout is shown, with a side of
# 1/sqrt(2) times the full side. The side length is read from the cutouts that
# are actually plotted rather than from a loop variable left over from
# another cell.
ll = np.shape(hetdex_SFGfiltered_cutouts[0])[0]
w = int(ll/np.sqrt(2))
ss=int((ll-w)/2)
# `-ss or None` keeps the full cutout in the corner case ss == 0
crop = slice(ss, -ss or None)
# The example grids hold at most limit x limit cutouts
limit=6
rnames = cali_SFGfiltered_cat.Source_Name.values
rname_set = set(rnames)


def _select_examples(want_positive, want_candidate):
    """Return (cutouts, names) of the test sources in one confusion-matrix cell.

    A source is selected when its thresholded prediction equals
    *want_positive* and its presence in the candidate catalogue equals
    *want_candidate*.
    """
    names = np.array([sn for prediction, sn in zip(pred, sourcenames_test)
                      if bool(prediction) == want_positive
                      and (sn in rname_set) == want_candidate])
    cutouts = np.array([cutout_from_sn(sn) for sn in names])
    return cutouts, names


def _draw_sample(n_available, n_wanted):
    """Return reproducible random positions, drawn without replacement.

    At most *n_available* positions are drawn, so a category with fewer than
    *n_wanted* members no longer makes ``np.random.choice`` raise.
    """
    np.random.seed(random_seed)
    return np.random.choice(n_available, size=min(n_wanted, n_available),
                            replace=False)


def _plot_example_grid(cutouts, tag, title, filename):
    """Plot the cropped *cutouts* on a square grid and save it as *filename*.

    Every panel is annotated with *tag* followed by its 1-based number; the
    figure is written to ``paper_fig_dir``.
    """
    # Smallest square grid that fits all cutouts (at least 1x1);
    # squeeze=False keeps the axes in a 2-D array even for a 1x1 grid
    ww = max(1, int(np.ceil(np.sqrt(len(cutouts)))))
    f, a = plt.subplots(ww,ww,figsize=(10,10),constrained_layout=True,
                        squeeze=False)
    plt.suptitle(title)
    panels = a.ravel()
    for axx in panels:
        axx.set_axis_off()
    for i, (axx, example) in enumerate(zip(panels, cutouts)):
        axx.text(0,0,f'{tag}{i+1}',color='white',ha='left',va='bottom',size=12)
        axx.imshow(example[crop, crop], origin='lower')
    plt.savefig(os.path.join(paper_fig_dir,filename),bbox_inches='tight')
    plt.show()


# True positives
tp_cutouts, tp_names = _select_examples(True, True)
# Draw a random sample without replacement
draws = _draw_sample(len(tp_cutouts), limit**2)
tp_cutouts = tp_cutouts[draws]
tp_names = tp_names[draws]
_plot_example_grid(tp_cutouts, 'TP', 'True positives (AGN remnant candidates)',
                   'output_examples_TP.pdf')

# False positives (maybe candidates?)
fp_cutouts, fp_names = _select_examples(True, False)
draws = _draw_sample(len(fp_cutouts), limit**2)
fp_cutouts = fp_cutouts[draws]
fp_names = fp_names[draws]
_plot_example_grid(fp_cutouts, 'FP', 'False positives (more likely candidates)',
                   'output_examples_FP.pdf')

# True negatives; the grid is sized by the true negatives themselves
tn_cutouts, tn_names = _select_examples(False, False)
draws = _draw_sample(len(tn_cutouts), limit**2)
tn_cutouts = tn_cutouts[draws]
tn_names = tn_names[draws]
_plot_example_grid(tn_cutouts, 'TN', 'True negatives (less likely candidates)',
                   'output_examples_TN.pdf')


In [None]:
# Independent copy of the catalogue, indexed by source name for .loc lookups
cat = deepcopy(hetdex_SFGfiltered_cat).set_index('Source_Name')

In [None]:
# Write the names and coordinates of the plotted examples to text files
def _write_source_list(path, tag, names):
    """Write one numbered line per source in *names* to the file *path*.

    Each line holds *tag* plus a 1-based number, the source name, and the RA
    and DEC looked up in ``cat``.
    """
    with open(path,'w') as f:
        print('#, Source Name, RA [deg], DEC [deg]', file=f)
        for number, sn in enumerate(names, start=1):
            entry = cat.loc[sn]
            print(f'{tag}{number}, {sn}, {entry.RA}, {entry.DEC}', file=f)

_write_source_list('36false_positives_noTotalFlux.txt', 'FP', fp_names)
_write_source_list('36true_negatives_noTotalFlux.txt', 'TN', tn_names)

## Test classifier significance

In [None]:
np.random.seed(random_seed)
# Evaluate the significance of a cross-validated score with permutations.
# NOTE(review): this receives features_train, not features_train2 -
# presumably the custom routine fills the ratio features per fold; confirm.
start =time.time()
permutation_outcome = custom_permutation_test_score(
    best_rf,
    features_train,
    labels_train,
    cv=cv,
    scoring=scoring,
    n_permutations=1000,
    random_state=random_seed,
)
score_test, perm_scores_test, pvalue_test = permutation_outcome
print(f"Time taken: {time.time()-start:.0f} sec")
# Last expression of the cell, so the notebook displays it
score_test,pvalue_test

In [None]:
fig, ax = plt.subplots()
matplotlib.rcParams.update({'font.size': 14})
# Histogram of the permutation scores, with the score on the original data
# marked by a dashed red line
ax.hist(perm_scores_test, density=False, histtype='step',linewidth=3)
ax.axvline(score_test, ls="--", color="r",linewidth=3)
score_label = f"F2-score on\n original data: {score_test*100:.1f}%\n(p-value: {pvalue_test:.3f})"
# The annotation sits just left of the line, at a fixed height of 200 counts
ax.text(score_test-0.02, 200, score_label, fontsize=12,ha='right',va='top')
ax.set_xlabel("F2-score")
ax.set_ylabel("Counts")
figure_path = os.path.join(paper_fig_dir, "test_score_151vshetdex_f2score.pdf")
plt.savefig(figure_path,bbox_inches='tight')
plt.show()

## Test permutation feature importance

In [None]:
np.random.seed(random_seed)
# Permutation importance of every feature, on the train and on the test set
n_repeats=40
start =time.time()


def _importance_frame(perm_result):
    """Return (order, frame) for one permutation_importance result.

    *order* sorts the features by ascending mean importance; *frame* has one
    column per feature, in that order, and one row per repeat.
    """
    order = perm_result.importances_mean.argsort()
    frame = pd.DataFrame(
        perm_result.importances[order].T,
        columns=feature_list[order],
    )
    return order, frame


# Settings shared by both runs
permutation_settings = dict(n_repeats=n_repeats, scoring=scoring,
                            random_state=random_seed, n_jobs=-1)

result = permutation_importance(
    best_rf, features_train2, labels_train, **permutation_settings)
sorted_importances_idx, importances_train = _importance_frame(result)

result = permutation_importance(
    best_rf, features_test2, labels_test, **permutation_settings)
sorted_importances_idx, importances_test = _importance_frame(result)

print(f"Time taken: {time.time()-start:.0f} sec")

In [None]:
matplotlib.rcParams.update({'font.size': 10})
# One horizontal box plot per data set: (importances, title, output file)
importance_figures = [
    (importances_train, "Permutation Importances (train set)", "perm_train_id370.pdf"),
    (importances_test, "Permutation Importances (test set)", "perm_test_id370.pdf"),
]
for importance_frame, figure_title, file_name in importance_figures:
    ax = importance_frame.plot.box(vert=False, whis=10,figsize=(8,4))
    ax.set_title(figure_title)
    # Dashed reference line at zero importance
    ax.axvline(x=0, color="k", linestyle="--")
    ax.set_xlabel("Decrease in F2-score when permutating feature",loc='right')
    plt.savefig(os.path.join(paper_fig_dir,file_name),bbox_inches='tight')
    plt.show()
# Running sum of the rounded feature importances. It has to start as a
# number: adding a float to a list with += raises a TypeError.
total=0.0
for label,val in feature_importances:
    print(f"{label},")
    total+=val