In [None]:
import os
import shutil
from concurrent.futures import ThreadPoolExecutor

import astropy.io.fits as fits
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

Project-level constants. Change these to match your project in each notebook.

In [None]:
# Every working folder lives directly under the parent of the current
# working directory.
project_directory = os.path.dirname(os.getcwd())
data_directory, dataset_directory, model_directory = (
    os.path.join(project_directory, folder)
    for folder in ("data", "tf_data", "models")
)
# Sub-folder of data_directory that receives the PNGs converted from FITS.
converted_folder_name = "converted"

model_name = "model_v2" # Change this
directory_to_process = "/Users/jeroendenotter/Library/CloudStorage/OneDrive-SharedLibraries-MinnesotaState/Rutkowski, Michael J - PASSAGE/maps" # Change this

# The trained model is read from <model_directory>/<model_name>.keras.
model = tf.keras.models.load_model(
    os.path.join(model_directory, f'{model_name}.keras')
)

# Class labels are the names of the sub-directories of dataset_directory,
# in sorted order.
class_names = sorted(
    entry
    for entry in os.listdir(dataset_directory)
    if os.path.isdir(os.path.join(dataset_directory, entry))
)

Function to process a single FITS file and save it as a PNG

In [None]:
def process_fits_file(path):
    """Convert one FITS file into a grayscale PNG showing a single segment.

    The segment id is parsed from the file name: it is the second
    underscore-separated field, cut at its first dot (``<prefix>_<id>.fits``
    yields ``<id>``). Every pixel of the "DSCI" extension whose "SEG" value
    differs from that id is set to 0, the result is scaled with
    ``log10(x + 1)`` and saved as
    ``<data_directory>/<converted_folder_name>/<id>.png``.

    Args:
        path: Path to a FITS file that provides "DSCI" and "SEG" extensions.

    Returns:
        None. The PNG on disk is the only result.

    Raises:
        IndexError: If the file name contains no underscore.
        ValueError: If the parsed id is not an integer.
    """
    filename = os.path.basename(path)
    seg_id = int(filename.split('_')[1].split('.')[0])

    # The context manager closes the file even when an extension is missing
    # or the masking fails; a trailing close() call would be skipped then.
    with fits.open(path) as hdu_list:
        science = hdu_list["DSCI"].data
        segmentation = hdu_list["SEG"].data

        # np.where builds a new array instead of zeroing pixels inside the
        # array owned by the HDU. Both steps run while the file is open so
        # the pixel data is fully read before the handle goes away.
        masked = np.where(segmentation == seg_id, science, 0)
        scaled = np.log10(masked + 1)

    converted_path = os.path.join(data_directory, converted_folder_name, f"{seg_id}.png")
    plt.imsave(converted_path, scaled, cmap="gray")

Process all FITS files in a directory

In [None]:
# Start from an empty output folder so PNGs left over from an earlier run are
# never picked up by the prediction step. os.removedirs only removes *empty*
# directories (and then prunes empty parents), so on a populated folder it
# raised and the stale files stayed; shutil.rmtree removes the folder together
# with its contents, and any real failure (e.g. permissions) now surfaces.
converted_directory = os.path.join(data_directory, converted_folder_name)
if os.path.isdir(converted_directory):
    shutil.rmtree(converted_directory)
os.makedirs(converted_directory, exist_ok=True)

# os.listdir returns entries in arbitrary order; sorting makes the submission
# order (and therefore which failure is reported first) reproducible.
images_to_process = sorted(
    os.path.join(directory_to_process, filename)
    for filename in os.listdir(directory_to_process)
    if filename.endswith(".fits")
)
total_images = len(images_to_process)

with ThreadPoolExecutor() as executor:
    futures = [executor.submit(process_fits_file, path) for path in images_to_process]
    for future in tqdm(futures, total=total_images, desc='Setting up dataset'):
        # result() re-raises an exception raised in the worker thread, so a
        # failed conversion stops the cell instead of passing silently.
        future.result()
print(f'Finished setting up dataset in {converted_folder_name}')

Functions to extract a single image channel and to run the model on a batch of images

In [None]:
def extract_channel(image, channel=0):
    """Return a single channel of ``image`` without dropping the channel axis.

    The third axis is sliced rather than indexed, so the result keeps three
    dimensions and its last one has length 1 for a valid non-negative
    ``channel``.
    """
    channel_window = slice(channel, channel + 1)
    return image[:, :, channel_window]

def predict_batch(images, channel=0, batch_size=None):
    """Run the loaded model on one channel of every image in ``images``.

    Args:
        images: Iterable of arrays accepted by ``extract_channel``. All
            images must have the same shape so they can be stacked into a
            single batch.
        channel: Index of the channel handed to the model.
        batch_size: Forwarded to ``model.predict``. ``None`` (the default)
            leaves the choice to Keras, exactly as when the argument is
            omitted.

    Returns:
        The value returned by ``model.predict`` for the stacked batch, or an
        empty float32 array of shape ``(0, 0)`` when ``images`` is empty.

    Raises:
        ValueError: If the images do not all share the same shape.
    """
    images = list(images)
    if not images:
        # Nothing to stack and nothing to predict; returning an empty result
        # lets callers iterate over it instead of failing inside the model.
        return np.empty((0, 0), dtype=np.float32)

    # np.stack refuses mismatched shapes with an explicit error rather than
    # building a ragged array that only fails later during conversion.
    batch = np.stack([extract_channel(image, channel) for image in images])
    batch = tf.convert_to_tensor(batch, dtype=tf.float32)
    return model.predict(batch, batch_size=batch_size)

Predict all images in the converted folder and save the model's per-class output for each image to a CSV file

In [None]:
# os.listdir returns entries in arbitrary order; sorting makes the order of
# the CSV rows reproducible between runs and machines.
converted_directory = os.path.join(data_directory, converted_folder_name)
converted_images = sorted(
    os.path.join(converted_directory, filename)
    for filename in os.listdir(converted_directory)
    if filename.endswith(".png")
)
total_images = len(converted_images)
predictions = predict_batch([plt.imread(image) for image in converted_images])

csv_path = os.path.join(project_directory, "predictions.csv")
with open(csv_path, 'w') as csv_file:
    # One column per entry of class_names. There is no "Loss" column: each row
    # holds only the file name and one value per model output, so an extra
    # header field would give header and rows different field counts.
    # NOTE(review): assumes the model's output order matches the sorted
    # class_names - confirm against the training notebook.
    header = ["Filename"] + [f"{name.capitalize()}_accuracy" for name in class_names]
    csv_file.write(",".join(header) + "\n")

    for image, prediction in zip(converted_images, predictions):
        row = [os.path.basename(image)] + [f"{accuracy}" for accuracy in prediction]
        csv_file.write(",".join(row) + "\n")