In [1]:
import cv2
import numpy as np
import os
import pandas as pd
from astropy.io import fits
from tensorflow.keras.models import load_model
from keras.models import save_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from datetime import datetime
from sklearn.metrics import precision_score, recall_score, f1_score

def find_latest_dictionary_csv(folder_path):
    csv_files = [f for f in os.listdir(folder_path) if f.startswith('dictionary') and f.endswith('.csv')]
    if not csv_files:
        print("No dictionary CSV file found.")
        return None
    latest_csv = max(csv_files)
    return latest_csv

def calculate_metrics(stats_df):
    correct_predictions = stats_df[stats_df['ground_truth'] == stats_df['cnn_prediction']]
    accuracy = len(correct_predictions) / len(stats_df)
    precision = precision_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    recall = recall_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    f1 = f1_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    return accuracy, precision, recall, f1

def preprocess_fit(image_path):
    with fits.open(image_path) as hdul:
        image_data = hdul[0].data.astype(np.uint8)
        resized_image = cv2.resize(image_data, (224, 224))
        rgb_image = cv2.cvtColor(resized_image, cv2.COLOR_GRAY2RGB)
    return np.expand_dims(rgb_image, axis=0)


def detect_lines_cnn(image_path, model):
    preprocessed_image = preprocess_fit(image_path)
    prediction = model.predict(preprocessed_image, verbose=0)[0][0]
    if prediction >= 0.5:
        return True
    else:
        return False

def load_dictionary(folder_path):
    latest_dictionary_csv = find_latest_dictionary_csv(folder_path)
    if latest_dictionary_csv:
        dictionary_df = pd.read_csv(os.path.join(folder_path, latest_dictionary_csv))
        return dictionary_df
    else:
        print("Exiting program.")
        exit()

def read_fit_data(folder_path, dictionary_df, model):
    filenames = []
    ground_truth = []
    cnn_predictions = []
    for filename in os.listdir(folder_path):
        if filename.endswith('.fit') and filename.startswith('tic'):
            image_path = os.path.join(folder_path, filename)
            matching_row = dictionary_df[dictionary_df['output'] == filename]
            if len(matching_row) == 0:
                print(f"No label found for filename: {filename}")
                continue
            label = matching_row.iloc[0]['label']
            ground_truth.append(label)
            cnn_prediction = detect_lines_cnn(image_path, model)
            cnn_predictions.append(cnn_prediction)
            filenames.append(filename)
    return pd.DataFrame({
        'filename': filenames,
        'ground_truth': ground_truth,
        'cnn_prediction': cnn_predictions
    })

def train_model_on_test_data(model, folder_path):
    dictionary_df = load_dictionary(folder_path)

    stats_df = read_fit_data(folder_path, dictionary_df, model)

    accuracy, precision, recall, f1 = calculate_metrics(stats_df)
    print("Metrics on Test Data before re-training:")
    print("Accuracy:", accuracy)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

    model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])

    train_data = []
    train_labels = []
    for filename in os.listdir(folder_path):
        if filename.endswith('.fit') and filename.startswith('tic'):
            image_path = os.path.join(folder_path, filename)
            train_data.append(preprocess_fit(image_path))
            matching_row = dictionary_df[dictionary_df['output'] == filename]
            if len(matching_row) == 0:
                print(f"No label found for filename: {filename}")
                continue
            label = matching_row.iloc[0]['label']
            train_labels.append(label)
    train_data = np.array(train_data)
    train_labels = np.array(train_labels)
    print("Shape of train_data before squeezing:", train_data.shape)  
    train_data = np.squeeze(train_data)  # Remove extra dimension
    print("Shape of train_data after squeezing:", train_data.shape)


    history = model.fit(train_data, train_labels, epochs=5, batch_size=32, validation_split=0.2)

    val_accuracy = history.history['val_accuracy']
    print("Validation Accuracy during training:")
    print(val_accuracy)

    # After training, re-evaluate the model on the test data
    stats_df_after_training = read_fit_data(folder_path, dictionary_df, model)

    # Calculate metrics after re-training
    accuracy_after_training, precision_after_training, recall_after_training, f1_after_training = calculate_metrics(stats_df_after_training)
    print("Metrics on Test Data after re-training:")
    print("Accuracy:", accuracy_after_training)
    print("Precision:", precision_after_training)
    print("Recall:", recall_after_training)
    print("F1 Score:", f1_after_training)


folder_path = './Data/fits/'

# Load the serialized model
serialized_model_path = 'trained_model_2024-04-15_17-04-20.h5'
model = load_model(serialized_model_path)

# Train the model on the test data
train_model_on_test_data(model, folder_path)


# serialized_upgraded_model_path = "resnetModel25IV2024retrained92acc.h5"
# model.save(serialized_upgraded_model_path)

# print("Upgraded model serialized and saved successfully!")





Metrics on Test Data before re-training:
Accuracy: 0.6764705882352942
Precision: 0.631578947368421
Recall: 0.75
F1 Score: 0.6857142857142857
Shape of train_data before squeezing: (34, 1, 224, 224, 3)
Shape of train_data after squeezing: (34, 224, 224, 3)
Epoch 1/5


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Validation Accuracy during training:
[0.7142857313156128, 0.7142857313156128, 0.8571428656578064, 0.8571428656578064, 0.8571428656578064]
Metrics on Test Data after re-training:
Accuracy: 0.7352941176470589
Precision: 0.6842105263157895
Recall: 0.8125
F1 Score: 0.742857142857143
