In [None]:
import os
import tensorflow as tf
import astropy.io.fits as fits
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
import shutil

Project-level constants. Change these to match your project in each notebook.

In [None]:
# --- Directory layout -------------------------------------------------------
# The project root is the parent of the current working directory; every other
# path is derived from it.
project_directory = os.path.dirname(os.getcwd())
data_directory, dataset_directory, model_directory = (
    os.path.join(project_directory, subfolder)
    for subfolder in ("data", "tf_data", "models")
)

# Output location for the PNG files produced from the FITS files.
converted_folder_name = "converted"
converted_directory = os.path.join(data_directory, converted_folder_name)

# --- Image / batching parameters --------------------------------------------
img_height, img_width = 200, 200
batch_size = 32

# --- Run-specific settings --------------------------------------------------
model_name = "model_v1" # Change this
directory_to_process = "" # Change this

# Load the saved Keras model called `<model_name>.keras` from the models folder.
model_path = os.path.join(model_directory, f'{model_name}.keras')
model = tf.keras.models.load_model(model_path)

# One class per sub-directory of the dataset folder, in sorted order.
class_names = sorted(
    entry
    for entry in os.listdir(dataset_directory)
    if os.path.isdir(os.path.join(dataset_directory, entry))
)

Functions to process a single FITS file and save it as a PNG

In [None]:
def alter_image(data, low=2e-3, high=1):
    """Log-scale an image and rescale it linearly into ``[low, high]``.

    The transform is ``log10(data + 1)`` followed by a min-max normalisation
    that maps the smallest finite value to ``low`` and the largest to ``high``.

    Args:
        data: Array-like of pixel values.
        low: Value assigned to the darkest pixel (and to invalid pixels).
        high: Value assigned to the brightest pixel.

    Returns:
        A new array with the same shape as ``data`` whose values all lie in
        ``[low, high]``. Pixels for which the logarithm is undefined or
        non-finite (NaN/inf input, or ``data <= -1``) are set to ``low``.
        An image without any contrast is returned as a constant ``low`` image.
    """
    # Invalid pixels are handled explicitly below, so silence the warnings
    # log10 would emit for them.
    with np.errstate(divide="ignore", invalid="ignore"):
        data = np.log10(np.asarray(data) + 1)

    finite = np.isfinite(data)
    if not finite.any():
        return np.full_like(data, low)

    # Take the range from finite pixels only: a single NaN would otherwise
    # turn min/max, and therefore the whole image, into NaN.
    data_min, data_max = np.min(data[finite]), np.max(data[finite])
    value_range = data_max - data_min
    if value_range == 0:
        # Constant image: dividing by the zero range would yield NaN everywhere.
        return np.full_like(data, low)

    data = (data - data_min) / value_range
    data = data * (high - low) + low
    data = np.where(finite, data, low)

    return np.clip(data, low, high)

In [None]:
def process_fits_file(path):
    """Convert the primary image of a FITS file into a grayscale PNG.

    The image is log-scaled, min-max normalised to ``[0, 1]``, resized to
    ``img_width`` x ``img_height`` and written to ``converted_directory``
    under the original base name with a ``.png`` extension.

    Args:
        path: Path of the FITS file to convert.

    Raises:
        ValueError: If the primary HDU holds no image data.
    """
    # The context manager closes the file even if a later step raises.
    with fits.open(path) as hdul:
        raw = hdul[0].data
        if raw is None:
            raise ValueError(f"Primary HDU of {path} contains no image data")
        # Work on a float copy: assigning 1e-9 into an integer array would
        # truncate to 0 and make log10 return -inf.
        data = np.array(raw, dtype=float)

    # Floor non-positive and non-finite pixels so the logarithm is defined.
    data[~np.isfinite(data) | (data <= 0)] = 1e-9
    data = np.log10(data)

    data_min, data_max = np.min(data), np.max(data)
    if data_max > data_min:
        data_norm = (data - data_min) / (data_max - data_min)
    else:
        # Constant image: avoid dividing by a zero range.
        data_norm = np.zeros_like(data)

    # cv2.resize expects the target size as (width, height).
    data_resized = cv2.resize(data_norm, (img_width, img_height))

    stem = os.path.splitext(os.path.basename(path))[0]
    converted_path = os.path.join(converted_directory, f"{stem}.png")
    plt.imsave(converted_path, data_resized, cmap="gray")

In [None]:
def process_fits_map_file(path):
    """Extract one segment from a FITS cutout and save it as a grayscale PNG.

    The segment id is parsed from the file name, which must look like
    ``<prefix>_<id>.<ext>`` (the text between the first ``_`` and the
    following ``.`` is converted to ``int``). Every pixel of the ``DSCI``
    image whose ``SEG`` value differs from that id is set to 0, the result is
    passed through ``alter_image`` and written to
    ``converted_directory/<id>.png``.

    Args:
        path: Path of the FITS file; it must contain ``DSCI`` and ``SEG``
            extensions.

    Returns:
        dict with keys ``"filename"`` (base name of ``path``) and
        ``"non_zero_pixels"`` (number of pixels whose value after
        ``alter_image`` is greater than 0.5).
    """
    filename = os.path.basename(path)
    seg_id = int(filename.split('_')[1].split('.')[0])

    # Read and mask inside the context manager so the file is closed even when
    # an extension is missing; np.where builds a new array, so nothing that is
    # backed by the opened file is modified or kept alive.
    with fits.open(path) as hdul:
        science = np.where(hdul["SEG"].data == seg_id, hdul["DSCI"].data, 0)

    science = alter_image(science)

    converted_path = os.path.join(converted_directory, f"{seg_id}.png")
    plt.imsave(converted_path, science, cmap="gray")
    return {
        "filename": filename,
        "non_zero_pixels": int(np.count_nonzero(science > 0.5))
        # todo: add pixel value of the center of the segmentation
    }

Process all fits files in a directory

In [None]:
# Validate the input first so a misconfigured run does not wipe the output of
# a previous one.
if not os.path.isdir(directory_to_process):
    raise FileNotFoundError(
        f"directory_to_process={directory_to_process!r} is not an existing directory; "
        "set it in the configuration cell"
    )

# Start from an empty output directory so PNGs left over from an earlier run
# are never mixed into the predictions. A failure to remove the directory is
# raised as-is instead of surfacing later as an unrelated FileExistsError.
if os.path.isdir(converted_directory):
    shutil.rmtree(converted_directory)
os.makedirs(converted_directory, exist_ok=True)

# Sorted so the processing order (and file_info_list) is deterministic.
images_to_process = sorted(
    os.path.join(directory_to_process, filename)
    for filename in os.listdir(directory_to_process)
    if filename.endswith(".fits")
)
total_images = len(images_to_process)
file_info_list = []

with ThreadPoolExecutor() as executor:
    futures = [executor.submit(process_fits_map_file, path) for path in images_to_process]
    # Collect results in submission order; an exception raised in a worker is
    # re-raised here by future.result().
    for future in tqdm(futures, total=total_images, desc='Setting up dataset'):
        file_info_list.append(future.result())

print(f'Finished setting up dataset in {converted_folder_name}')

Run the model on every image in the converted folder and save the per-class prediction scores to a CSV file

In [None]:
# Collect the converted PNG files in a deterministic (sorted) order.
converted_images_filepaths = sorted(
    os.path.join(converted_directory, file)
    for file in os.listdir(converted_directory)
    if file.endswith('.png')
)
if not converted_images_filepaths:
    raise ValueError(f"No PNG images found in {converted_directory}")

images = []
for file in converted_images_filepaths:
    image = cv2.imread(file, cv2.IMREAD_GRAYSCALE) # Use cv2.IMREAD_COLOR to read in RGB format if model has 3 channel input
    if image is None:
        # cv2.imread signals an unreadable file by returning None.
        raise ValueError(f"Could not read image: {file}")
    images.append(image)

images = np.array(images)
predictions = model.predict(images)

# Look the pixel counts up by converted image name. file_info_list follows the
# order of the input listing, not the sorted order of the PNG paths, so
# indexing it by position would attach counts to the wrong images.
# The name is rebuilt exactly as process_fits_map_file does: "<seg_id>.png".
non_zero_pixels_by_image = {}
for info in file_info_list:
    seg_id = int(info["filename"].split('_')[1].split('.')[0])
    non_zero_pixels_by_image[f"{seg_id}.png"] = info["non_zero_pixels"]

# todo: add n amount of pixels in seg
csv_path = os.path.join(project_directory, "predictions.csv")
with open(csv_path, 'w') as csv_file:
    # One score column per class. No "Loss" column is declared because no loss
    # value is written to the rows; header and rows must have the same width.
    csv_file.write("Filename,NonZeroPixels," + ",".join([f"{name.capitalize()}_accuracy" for name in class_names]) + "\n")

    # One row per image: name, pixel count, then the model output per class.
    for image_path, prediction in zip(converted_images_filepaths, predictions):
        image_name = os.path.basename(image_path)
        # Left blank when the image was not produced by the processing cell.
        non_zero_pixels = non_zero_pixels_by_image.get(image_name, "")

        row = f"{image_name},{non_zero_pixels},{','.join(map(str, prediction))}\n"
        csv_file.write(row)

print(f"Predictions saved to {csv_path}")