## Imports

In [7]:
import numpy as np
import cv2 as cv
import pandas as pd

from astropy.io import fits
from astropy import wcs

from pavlidis import pavlidis
from _helpers import make_directory, prepareData, can_go_down

from fnmatch import fnmatch
from os import listdir, path as os_path
import argparse
from tqdm import tqdm
# from time import perf_counter

import warnings
from astropy.utils.exceptions import AstropyUserWarning
from astropy.wcs import FITSFixedWarning

warnings.filterwarnings(action='ignore', category=AstropyUserWarning)
warnings.filterwarnings(action='ignore', category=FITSFixedWarning)

## Variables

In [8]:
# Thresholding parameters, passed to cv.adaptiveThreshold in the extraction step:
#   block_size: size of the pixel neighbourhood used to compute the threshold for
#               each pixel; OpenCV requires an odd value greater than 1 (3, 5, 7, ...).
#   constant:   value subtracted from the Gaussian-weighted neighbourhood mean;
#               normally positive, but may be zero or negative.
# calculate_coordinates: compute the objects' pixel coordinates from the FITS
#               headers (True) or load previously saved ones from disk (False).

block_size = 19
constant = 2
calculate_coordinates = True

mode = 'train'  # choices=('train', 'test')
csv_name = 'Subtypes'  # choices=('Subtypes', 'Combined')
class_column = 'Sp type'  # choices=('Sp type', 'Cl')

data_root = '/home/stepan/Data/DFBS'

# Validate settings explicitly: `assert` is stripped under `python -O`, and an
# unknown mode would otherwise silently be treated as 'test' here while the
# extraction step checks for 'train' and 'test' separately.
if not data_root:
    raise ValueError('data_root must be set')
if mode not in ('train', 'test'):
    raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
if block_size < 3 or block_size % 2 == 0:
    raise ValueError(f'block_size must be an odd integer >= 3, got {block_size}')

# Training cut-outs go under Plates/, inference cut-outs under Inference/.
stage_folder = 'Plates' if mode == 'train' else 'Inference'
output_path = os_path.join(data_root, stage_folder, csv_name, f'{block_size}_{constant}')

fits_path = os_path.join(data_root, 'fits_files')
headers_path = os_path.join(data_root, 'fits_headers')

raw_folder = 'images'
classified_folder = 'images_classified'


## Process

### Data acquisition

In [12]:
# Load the catalogue table, find the FITS headers / plates on disk, and obtain
# the pixel coordinates of the catalogue objects on the plates.
# NOTE(review): prepareFits and getCoordinates are not in the visible import
# list; presumably they come from _helpers or another cell — confirm.
data = prepareData(path=f'{data_root}/Datasets/{csv_name}.csv')
print(data.head(), '\n')

fits_headers, fits_set = prepareFits(
    headers_path=headers_path,
    fits_path=fits_path,
    headers_pattern="*.fits.hdr",
    fits_pattern="*.fits",
)

# np.save appends '.npy' when the name lacks it, so spelling the extension out
# keeps the save and load branches pointing at the very same file.
coordinates_file = f'{data_root}/Coordinates/{csv_name}.npy'

if not calculate_coordinates:
    print('Loading coordinates ...')
    # Coordinates are stored with a +1 offset (see the other branch); undo it.
    coordinates = np.load(coordinates_file) - 1
    plates_containing_objects = fits_set
else:
    coordinates, plates_containing_objects = getCoordinates(
        fits_headers=fits_headers,
        ra_dec=data[['_RAJ2000', '_DEJ2000']],
    )
    # Persist with a +1 offset; the purpose of the offset is not visible here
    # (presumably to shift a negative sentinel) — confirm before changing.
    np.save(coordinates_file, coordinates + 1)

print('Done.\n')

### Extraction

In [None]:
# Cut out an image for every catalogue object that lies on each available plate.
# In 'train' mode each plate gets its own directory (raw images plus one folder
# per class); in 'test' mode all cut-outs go to a single folder under output_path.
#
# Bookkeeping read by the statistics cell:
#   datapoint_plates      pd_i -> {'plate', 'dx', 'dy'} of its latest successful extraction
#   all_datapoints        every pd_i looked at on any plate
#   incorrect_datapoints  pd_i -> plates where no usable object was found; an entry
#                         exists only while the object has not been extracted anywhere
#   short_20_images       pd_i whose contour edges were rejected at the catalogue pixel
#   short_images_extracted_count  extractions that needed the neighbourhood search
#
# NOTE(review): getPlateCoordinates and getContourEdges are not in the visible
# import list; presumably they come from _helpers or another cell — confirm.

datapoint_plates = {}
all_datapoints = set()  # For statistics

incorrect_datapoints = {}
short_images_extracted_count = 0
short_20_images = set()

n_headers = len(fits_headers)

print('Extracting objects by plate ...')

plates_dataset = pd.DataFrame(columns=data.columns)
plate_dir = output_path

if mode == 'test':
    make_directory(f'{plate_dir}/{raw_folder}')


def _descend(binary_img, x, y, height):
    """Step down from (x, y) while `can_go_down` allows it; return the final row.

    Stepping stops at the last image row (height - 1) at the latest.
    """
    while y < height - 1 and can_go_down(binary_img, x, y):
        y += 1
    return y


def _neighbour_object_pixels(binary_img, x, y, width):
    """Yield foreground (== 255) pixels around (x, y), excluding (x, y) itself.

    Columns x-2 .. x+2 (clipped to the image width) are scanned left to right,
    and within each column rows y, y-1, y-2 (clipped at row 0) are scanned upward.
    Pixels are tested lazily, in that order.
    """
    for i_x in range(max(0, x - 2), min(width, x + 3)):
        for i_y in range(y, max(-1, y - 3), -1):
            if i_x == x and i_y == y:
                continue
            if binary_img[i_y, i_x] == 255:
                yield i_x, i_y


def _save_extraction(pd_i, plate, plate_num, plate_dir, crop, x, y):
    """Write `crop` for catalogue row `pd_i` and record it in the bookkeeping.

    (x, y) is the pixel at which the object was actually located. It is stored
    both in `datapoint_plates` and in `plates_dataset`, so the two agree.
    """
    row = data.iloc[pd_i]
    datapoint_plates[pd_i] = {'plate': plate, 'dx': x, 'dy': y}

    full_index = f'{plate_num}_{pd_i}'
    if mode == 'train':
        file_name = f'{pd_i}__{row["Name"]}.tiff'
        image_path = f'{plate_dir}/{raw_folder}/{file_name}'
        cv.imwrite(f'{plate_dir}/{classified_folder}/{row[class_column]}/{file_name}', crop)
    else:
        image_path = f'{plate_dir}/{raw_folder}/{full_index}__{row["Name"]}.tiff'

    plates_dataset.loc[full_index] = row
    plates_dataset.loc[full_index, 'dx'] = x
    plates_dataset.loc[full_index, 'dy'] = y
    plates_dataset.loc[full_index, 'plate'] = plate
    plates_dataset.loc[full_index, 'path'] = image_path

    cv.imwrite(image_path, crop)

    # The object is now extracted, so it is no longer an incorrect datapoint.
    incorrect_datapoints.pop(pd_i, None)


for i, header_file in enumerate(tqdm(fits_headers)):
    plate = header_file.split('/')[-1].split('.')[0]
    plate_num = plate.split('_')[0][3:]

    if plate not in fits_set or plate not in plates_containing_objects:
        continue

    if mode == 'train':
        plate_dir = os_path.join(output_path, plate)
        make_directory(f'{plate_dir}/{raw_folder}')
        for class_name in np.unique(data[class_column]):
            make_directory(f'{plate_dir}/{classified_folder}/{class_name}')

    # np.array copies the pixels, so the FITS file can be closed immediately.
    with fits.open(f'{fits_path}/{plate}.fits') as fbs_plate:
        plate_img = np.array(fbs_plate[0].data, dtype=np.uint16)
    del fbs_plate  # drop the closed HDUList (and its data reference) early
    shape_y, shape_x = plate_img.shape

    # Stretch to 0..255. A constant plate would divide by zero, so use 1 instead.
    value_min = plate_img.min()
    value_range = plate_img.max() - value_min
    scaled_img = ((plate_img - value_min) / (value_range if value_range else 1) * 255).astype(np.uint8)

    # Normalise polarity so the image is mostly bright (mean >= 127.5).
    # THRESH_BINARY_INV below then marks pixels darker than their neighbourhood
    # threshold as 255.
    if np.mean(scaled_img) < 127.5:
        scaled_img = np.invert(scaled_img)

    gblur = cv.GaussianBlur(scaled_img, (3, 3), 2, 2)
    del scaled_img

    g_th = cv.adaptiveThreshold(gblur, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C,
                                cv.THRESH_BINARY_INV, blockSize=block_size, C=constant)
    del gblur

    plate_datapoints = getPlateCoordinates(coordinates[i], shape_x, shape_y)

    for pd_i in plate_datapoints:
        all_datapoints.add(pd_i)  # For statistics

        dx, dy = np.round(coordinates[i, pd_i]).astype(int)

        if g_th[dy, dx] == 255:
            # The catalogue pixel is foreground: descend, then cut out the contour.
            dy = _descend(g_th, dx, dy, shape_y)
            y1, x1, y2, x2 = getContourEdges(g_th, dx, dy)
            # NOTE(review): all(...) also rejects a legitimate edge at row or
            # column 0; presumably getContourEdges returns falsy values on failure.
            if all([y1, x1, y2, x2]):
                _save_extraction(pd_i, plate, plate_num, plate_dir,
                                 plate_img[y1:y2, x1:x2], dx, dy)
            else:
                short_20_images.add(pd_i)
            continue

        # Otherwise search the small neighbourhood for a foreground pixel and use
        # the first candidate whose contour edges are accepted.
        extracted = False
        for i_x, i_y in _neighbour_object_pixels(g_th, dx, dy, shape_x):
            y = _descend(g_th, i_x, i_y, shape_y)
            y1, x1, y2, x2 = getContourEdges(g_th, i_x, y)
            if all([y1, x1, y2, x2]):
                # Record the pixel actually used, not the catalogue position.
                _save_extraction(pd_i, plate, plate_num, plate_dir,
                                 plate_img[y1:y2, x1:x2], i_x, y)
                short_images_extracted_count += 1
                extracted = True
                break

        # An object already extracted on another plate is not "incorrect";
        # this matches the pop() performed on every successful extraction.
        if not extracted and pd_i not in datapoint_plates:
            incorrect_datapoints.setdefault(pd_i, []).append(plate)

# Save the combined table once, at the output root. In 'train' mode plate_dir
# would point at whichever plate happened to be processed last.
plates_dataset.to_csv(f'{output_path}/extracted.csv')

## Statistics (informational only)

In [None]:
# Summary of the extraction run (informational only).
# Objects whose contour edges were rejected, excluding any that were extracted
# on some plate or that are recorded as incorrect.
short_only = short_20_images.difference(incorrect_datapoints, datapoint_plates)
n_all = len(all_datapoints)
n_incorrect = len(incorrect_datapoints)
n_short = len(short_only)

print()
print('all_datapoints:', n_all)
print('incorrect_datapoints:', n_incorrect)
print('short_20_images:', n_short)
print('extracted images count:', n_all - n_short - n_incorrect)
print('short_images_extracted_count:', short_images_extracted_count)
