In [1]:
from radburst.utils.dataset import Dataset
import radburst.utils.preprocessing as prep
import radburst.utils.utils as util
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.backends.backend_pdf import PdfPages
import os


# Root of the local radburst data tree. Set the RADBURST_DATA_ROOT environment
# variable to run this notebook on another machine; the default reproduces the
# original hard-coded location exactly.
data_root = os.environ.get('RADBURST_DATA_ROOT',
                           '/mnt/c/Users/camer/OneDrive/Documents/radburst/data')
data_path = os.path.join(data_root, 'Fitfiles')
labels_path = os.path.join(data_root, 'labels', 'filtered-labels-20240309-20240701.csv')

# Create a Dataset object which loads all data from the given path (defined in dataset.py),
# passing prep.stan_rows_remove_verts as its preprocess step.
data = Dataset(data_dir=data_path,
               labels=labels_path,
               preprocess=prep.stan_rows_remove_verts)

# Label-specific subsets, used below to render burst and non-burst PDFs separately.
dataset_only_bursts = data.only_bursts()
dataset_only_nonbursts = data.only_nonbursts()

In [2]:
def process_fits(fits_path):
    """Run every preprocessing/detection stage on one FITS file.

    Stages, applied in order, each feeding the next:
    ``util.load_fits_file`` -> ``prep.stan_rows_remove_verts`` -> ``prep.blur``
    -> ``prep.create_binary_mask`` -> ``prep.morph_ops``
    -> ``prep.filtered_components``.

    Args:
        fits_path: Path of the FITS file to load.

    Returns:
        dict with the output of every stage, keyed
        'raw', 'stan_rows', 'stan_rows_blur', 'binary_mask', 'eroded_mask',
        'filtered_mask', plus 'filtered_regs' holding the region objects
        returned by ``prep.filtered_components`` (presumably the largest
        regions that pass its filter -- confirm in preprocessing.py).
    """
    stages = {}
    stages['raw'] = util.load_fits_file(fits_path)
    stages['stan_rows'] = prep.stan_rows_remove_verts(stages['raw'])
    stages['stan_rows_blur'] = prep.blur(stages['stan_rows'])
    stages['binary_mask'] = prep.create_binary_mask(stages['stan_rows_blur'])
    stages['eroded_mask'] = prep.morph_ops(stages['binary_mask'])

    # filtered_components returns (regions, mask); store the mask first to keep key order.
    kept_regions, kept_mask = prep.filtered_components(stages['eroded_mask'])
    stages['filtered_mask'] = kept_mask
    stages['filtered_regs'] = kept_regions
    return stages


def create_preprocessing_steps_fig(fits_path):
    """Build a 3x2 figure showing each preprocessing stage for one FITS file.

    Panels, in row-major order: the raw data, the standardized + blurred
    image, the binary mask, the eroded mask, the filtered mask, and the
    standardized image with a red box drawn around every filtered region.

    Args:
        fits_path: Path of the FITS file; it is processed with ``process_fits``.

    Returns:
        The matplotlib ``Figure``. Callers should ``plt.close`` it when done,
        since pyplot keeps every open figure alive.
    """
    processed = process_fits(fits_path)

    fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 10))

    # basename handles paths without any '/' (split('/')[1:][-1] raised IndexError there).
    file_name = os.path.basename(fits_path)

    # (image, title, is_mask). Mask panels are flagged explicitly rather than
    # detected from the title text, so a file name containing "mask" can no
    # longer make the raw panel render as a 0-1 gray mask.
    panels = [
        (processed['raw'], f'File: {file_name}', False),
        (processed['stan_rows_blur'], '1. Rows Standardized, Vertical Lines Removed, Blurred', False),
        (processed['binary_mask'], '2. Binary Mask', True),
        (processed['eroded_mask'], '3. Eroded Mask', True),
        (processed['filtered_mask'], '4. Filtered Mask\n(regions from 3 that meet criteria)', True),
        (processed['stan_rows'], '5. Detection \n(regions from 4 that meet criteria)', False),
    ]

    # axes.flat walks the 3x2 grid row by row, matching the panel order above.
    for axis, (img, title, is_mask) in zip(axes.flat, panels):
        if is_mask:
            axis.imshow(img, aspect='auto', cmap='gray', vmin=0, vmax=1)
        else:
            axis.imshow(img, aspect='auto')
        axis.set_title(title)
        axis.set_axis_off()

    # Overlay detections on the final "Detection" panel (bottom-right).
    # NOTE(review): imshow centres pixels on integer coordinates, so these box
    # edges sit half a pixel off the region boundary -- confirm if exactness matters.
    detection_axis = axes[2, 1]
    for reg in processed['filtered_regs']:
        width = reg.max_col - reg.min_col
        height = reg.max_row - reg.min_row
        bounding_box = patches.Rectangle((reg.min_col, reg.min_row), width, height,
                                         linewidth=1, edgecolor='r', facecolor='none')
        detection_axis.add_patch(bounding_box)

    # Lay out this specific figure rather than whatever pyplot considers "current".
    fig.tight_layout()

    return fig


# One page per labelled burst file, all collected into a single multi-page PDF.
with PdfPages('preprocessing_detection_plots.pdf') as burst_pdf:
    for rel_path in dataset_only_bursts.paths:
        burst_fig = create_preprocessing_steps_fig(fits_path=os.path.join(data_path, rel_path))
        burst_pdf.savefig(burst_fig)
        # Free the figure once saved so memory does not grow with the number of files.
        plt.close(burst_fig)

In [5]:
#for path_from_data_dir in dataset_only_bursts.paths:
#    full_fits_path = os.path.join(data_path, path_from_data_dir)
#    fig = create_preprocessing_steps_fig(fits_path=full_fits_path)
#    plt.show()

In [6]:
import random

# Plot a fixed-seed random subset of non-burst files for visual comparison with the bursts.
SEED = 42
num_samples = 25

nonburst_paths = list(dataset_only_nonbursts.paths)
# A dedicated RNG yields the same sample as random.seed(SEED) + random.sample,
# without resetting the global random state for other code.
# NOTE(review): reproducibility also requires Dataset to list `paths` in a
# stable order across runs -- confirm how it builds them.
rng = random.Random(SEED)
# Clamp the sample size: random.sample raises ValueError when asked for more
# items than exist, which would abort the cell on a small dataset.
rand_sample_nonburst_paths = rng.sample(nonburst_paths, min(num_samples, len(nonburst_paths)))

with PdfPages('preprocessing_detection_plots_nonbursts.pdf') as pdf:
    for path_from_data_dir in rand_sample_nonburst_paths:
        full_fits_path = os.path.join(data_path, path_from_data_dir)
        fig = create_preprocessing_steps_fig(fits_path=full_fits_path)
        pdf.savefig(fig)
        # Close each figure once saved to keep memory flat across pages.
        plt.close(fig)