In [4]:
import cv2
import numpy as np
import os
import pandas as pd
from astropy.io import fits
from tensorflow.keras.models import load_model
from sklearn.metrics import precision_score, recall_score, f1_score

# Load the pre-trained Keras model once at module level; detect_lines_cnn()
# reads this global `model` when it calls model.predict().
# Alternative model files, currently disabled:
# model = load_model('trained_model.h5')
# model = load_model('trained_model_2024-04-15_17-04-20.h5')
model = load_model('resnetModel23IV2024retrained92acc.h5')

# Load labels from the latest dictionary CSV file
def load_dictionary(folder_path):
    """Load the latest ``dictionary*.csv`` file found in ``folder_path``.

    The file is selected by ``find_latest_dictionary_csv`` and read with
    ``pandas.read_csv``.

    Args:
        folder_path: Directory that is searched for the dictionary CSV.

    Returns:
        pandas.DataFrame with the contents of the selected CSV file.

    Raises:
        FileNotFoundError: If no dictionary CSV file is found in the folder.
    """
    latest_dictionary_csv = find_latest_dictionary_csv(folder_path)
    if not latest_dictionary_csv:
        # Raise instead of calling exit(): exit() ends the whole interpreter /
        # kernel session (losing all state), whereas an exception only stops
        # the current cell and tells the caller exactly what is missing.
        raise FileNotFoundError(
            f"No dictionary CSV file found in {folder_path!r}"
        )
    return pd.read_csv(os.path.join(folder_path, latest_dictionary_csv))

# Find the latest dictionary CSV file
def find_latest_dictionary_csv(folder_path):
    """Return the name of the newest dictionary CSV in ``folder_path``.

    Candidates are directory entries whose name starts with ``dictionary``
    and ends with ``.csv``. "Newest" means the lexicographically greatest
    file name, not the modification time.

    Args:
        folder_path: Directory to search.

    Returns:
        The selected file name (without the folder), or ``None`` after
        printing a message when no candidate exists.
    """
    candidates = []
    for entry in os.listdir(folder_path):
        if entry.startswith('dictionary') and entry.endswith('.csv'):
            candidates.append(entry)

    if candidates:
        # Lexicographic maximum of the file names.
        return max(candidates)

    print("No dictionary CSV file found.")
    return None

# Function to preprocess FIT image
def preprocess_fit(image_path):
    """Read a FIT image and turn it into a one-image batch for the model.

    The data of the first HDU is cast to ``uint8``, resized to 224x224 with
    OpenCV, converted from grayscale to 3-channel RGB and given a leading
    batch axis.

    Args:
        image_path: Path of the FIT file to read.

    Returns:
        numpy array with a leading batch axis of length 1 (assumes the HDU
        holds a single 2-D image, giving shape (1, 224, 224, 3) - TODO confirm).
    """
    with fits.open(image_path) as hdu_list:
        # NOTE(review): astype() does not rescale, so values outside 0..255 are
        # not preserved - confirm this matches the training-time preprocessing.
        pixels = hdu_list[0].data.astype(np.uint8)
        # 224x224 is the input size fed to the loaded model.
        pixels = cv2.resize(pixels, (224, 224))
        # Replicate the single gray channel into three RGB channels.
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)
        # Prepend the batch axis.
        batch = pixels[np.newaxis, ...]
    return batch

# Function to detect lines in FIT image using serialized model
def detect_lines_cnn(image_path, min_line_length):
    """Classify one FIT image with the module-level ``model``.

    Args:
        image_path: Path of the FIT file to classify.
        min_line_length: Accepted for the callers' sake but not used by this
            function.

    Returns:
        ``True`` when the model's first output value for the image is at
        least 0.5, otherwise ``False``.
    """
    batch = preprocess_fit(image_path)
    # First (only) sample of the batch, first output unit.
    score = model.predict(batch)[0][0]
    return bool(score >= 0.5)

# Function to read FIT files and make predictions
def test_model_on_new_data(folder_path, dictionary_df):
    """Run the CNN on every ``tic*.fit`` file in a folder and collect results.

    Args:
        folder_path: Directory containing the FIT files to classify.
        dictionary_df: DataFrame with an ``output`` column (file name) and a
            ``label`` column (ground truth) used to look up each file's label.

    Returns:
        pandas.DataFrame with one row per processed file and the columns
        ``filename``, ``ground_truth`` (``None`` when the file has no entry in
        ``dictionary_df``) and ``cnn_prediction`` (bool).
    """
    filenames = []
    ground_truth = []
    cnn_predictions = []

    # os.listdir() returns entries in arbitrary order; sorting makes the row
    # order of the result reproducible across runs and machines.
    for filename in sorted(os.listdir(folder_path)):
        if not (filename.startswith('tic') and filename.endswith('.fit')):
            continue

        image_path = os.path.join(folder_path, filename)
        cnn_predictions.append(detect_lines_cnn(image_path, min_line_length=100))

        # Look up the ground-truth label; the first matching row wins.
        matching_row = dictionary_df[dictionary_df['output'] == filename]
        if len(matching_row) == 0:
            print(f"No label found for filename: {filename}")
            ground_truth.append(None)
        else:
            ground_truth.append(matching_row['label'].values[0])

        filenames.append(filename)

    return pd.DataFrame({
        'filename': filenames,
        'ground_truth': ground_truth,
        'cnn_prediction': cnn_predictions
    })

# Define the folder containing FIT files for testing
test_folder_path = './Data/fits/'
# test_folder_path = './Data/fits_2024-04-11_13-43-56/'
# test_folder_path = './Data/fits_2024-04-11_14-06-59/'
# test_folder_path = './Data/fits_2024-04-11_14-10-55/'

# Load dictionary from the latest CSV file in the folder
dictionary_df = load_dictionary(test_folder_path)

# Test the serialized model on new FIT files
test_stats_df = test_model_on_new_data(test_folder_path, dictionary_df)

# Display DataFrame
print("Results for new FIT files:")
display(test_stats_df)

# Files without a dictionary entry have a missing ground truth. They cannot be
# scored, so exclude them from the metrics instead of counting them as wrong
# (accuracy) or letting the sklearn metric functions fail on missing values.
labeled_stats_df = test_stats_df.dropna(subset=['ground_truth'])
unlabeled_count = len(test_stats_df) - len(labeled_stats_df)
if unlabeled_count:
    print(f"Skipping {unlabeled_count} file(s) without a label when computing metrics.")

# Bring labels (e.g. 1/0) and predictions (True/False) to one common int type
# so the comparison and the sklearn metrics see consistent binary targets.
y_true = labeled_stats_df['ground_truth'].astype(int)
y_pred = labeled_stats_df['cnn_prediction'].astype(int)

# Calculate overall accuracy
correct_predictions = (y_true == y_pred).sum()
total_predictions = len(labeled_stats_df)
overall_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
print("Overall Accuracy:", overall_accuracy)

# Calculate precision, recall, and F1 score
if total_predictions > 0:
    precision = precision_score(y_true, y_pred, average='binary')
    recall = recall_score(y_true, y_pred, average='binary')
    f1 = f1_score(y_true, y_pred, average='binary')
else:
    # Nothing to score: mirror the accuracy fallback rather than calling the
    # metric functions on empty input.
    precision = recall = f1 = 0

print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)


Results for new FIT files:


Unnamed: 0,filename,ground_truth,cnn_prediction
0,tic1.fit,1,True
1,tic10.fit,1,True
2,tic11.fit,1,True
3,tic12.fit,1,False
4,tic13.fit,1,True
5,tic14.fit,1,True
6,tic15.fit,1,True
7,tic16.fit,1,True
8,tic17.fit,1,True
9,tic18.fit,1,True


Overall Accuracy: 0.7647058823529411
Precision: 0.75
Recall: 0.75
F1 Score: 0.75
