In [2]:
import cv2
import numpy as np
import os
import pandas as pd
from astropy.io import fits
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from datetime import datetime
from sklearn.metrics import precision_score, recall_score, f1_score

def calculate_metrics(stats_df):
    correct_predictions = stats_df[stats_df['ground_truth'] == stats_df['cnn_prediction']]
    accuracy = len(correct_predictions) / len(stats_df)
    precision = precision_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    recall = recall_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    f1 = f1_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    return accuracy, precision, recall, f1

def serialize_model(model, accuracy):
    if accuracy >= 0.80:
        current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        model_filename = f"trained_model_{current_datetime}.h5"
        model.save(model_filename)
        print("Model serialized successfully. Model saved as:", model_filename)
        return model_filename
    else:
        return None

# Function to preprocess FIT image
def preprocess_fit(image_path):
    with fits.open(image_path) as hdul:
        image_data = hdul[0].data.astype(np.uint8)
        resized_image = cv2.resize(image_data, (224, 224))
        rgb_image = cv2.cvtColor(resized_image, cv2.COLOR_GRAY2RGB)
        preprocessed_image = np.expand_dims(rgb_image, axis=0)
    return preprocessed_image

# Function to detect lines in FIT image using Hough Transform
def detect_lines_hough(image_path, min_line_length):
    preprocessed_image = preprocess_fit(image_path)
    preprocessed_image = preprocessed_image.astype(np.uint8)
    edges = cv2.Canny(preprocessed_image[0], 50, 150, apertureSize=3)
    lines = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=100, minLineLength=min_line_length, maxLineGap=10)
    if lines is not None:
        return True
    else:
        return False

# Function to detect lines in FIT image using MobileNetV2
def detect_lines_cnn(image_path, min_line_length, model):
    preprocessed_image = preprocess_fit(image_path)
    prediction = model.predict(preprocessed_image, verbose=0)[0][0]
    if prediction >= 0.5:
        return True
    else:
        return False

# Find the latest dictionary CSV file
def find_latest_dictionary_csv(folder_path):
    csv_files = [f for f in os.listdir(folder_path) if f.startswith('dictionary') and f.endswith('.csv')]
    if not csv_files:
        print("No dictionary CSV file found.")
        return None
    latest_csv = max(csv_files)
    return latest_csv

# Load labels from the latest dictionary CSV file
def load_dictionary(folder_path):
    latest_dictionary_csv = find_latest_dictionary_csv(folder_path)
    if latest_dictionary_csv:
        dictionary_df = pd.read_csv(os.path.join(folder_path, latest_dictionary_csv))
        return dictionary_df
    else:
        print("Exiting program.")
        exit()

# Function to read FIT files and labels
def read_fit_data(folder_path, dictionary_df, model):
    filenames = []
    ground_truth = []
    hough_predictions = []
    cnn_predictions = []
    for filename in os.listdir(folder_path):
        if filename.endswith('.fit') and filename.startswith('tic'):
            image_path = os.path.join(folder_path, filename)
            matching_row = dictionary_df[dictionary_df['output'] == filename]
            if len(matching_row) == 0:
                print(f"No label found for filename: {filename}")
                continue
            label = matching_row.iloc[0]['label']
            ground_truth.append(label)
            hough_prediction = detect_lines_hough(image_path, min_line_length=100)
            hough_predictions.append(hough_prediction)
            cnn_prediction = detect_lines_cnn(image_path, min_line_length=100, model=model)
            cnn_predictions.append(cnn_prediction)
            filenames.append(filename)
    return pd.DataFrame({
        'filename': filenames,
        'ground_truth': ground_truth,
        'hough_prediction': hough_predictions,
        'cnn_prediction': cnn_predictions
    })

# Define the folder containing FIT files
folder_path = './Data/fits/'

metrics_list = []

for i in range(100):
    base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    for layer in base_model.layers:
        layer.trainable = False

    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(256, activation='relu')(x)
    predictions = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])

    dictionary_df = load_dictionary(folder_path)
    stats_df = read_fit_data(folder_path, dictionary_df, model)

    accuracy, precision, recall, f1 = calculate_metrics(stats_df)

    print(i, ".: CNN Model Metrics:", accuracy, precision, recall, f1)
    
    model_filename = serialize_model(model, accuracy)
    if model_filename:
        print("Results:")
        display(stats_df)
    
    metrics_list.append({'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1})

# Save the list of metrics to a CSV file
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
csv_filename = f"model_metrics_mobilenetv2_{current_datetime}.csv"
metrics_df = pd.DataFrame(metrics_list)
metrics_df.to_csv(csv_filename, index=False)
print("Model metrics saved to:", csv_filename)


0 .: CNN Model Metrics: 0.4230769230769231 0.4583333333333333 0.8461538461538461 0.5945945945945945
1 .: CNN Model Metrics: 0.46153846153846156 0.4782608695652174 0.8461538461538461 0.6111111111111112
2 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
3 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
4 .: CNN Model Metrics: 0.4230769230769231 0.4583333333333333 0.8461538461538461 0.5945945945945945
5 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
6 .: CNN Model Metrics: 0.5384615384615384 1.0 0.07692307692307693 0.14285714285714288


  _warn_prf(average, modifier, msg_start, len(result))


7 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
8 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
9 .: CNN Model Metrics: 0.46153846153846156 0.47368421052631576 0.6923076923076923 0.5625


  _warn_prf(average, modifier, msg_start, len(result))


10 .: CNN Model Metrics: 0.5 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


11 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
12 .: CNN Model Metrics: 0.6538461538461539 1.0 0.3076923076923077 0.47058823529411764


  _warn_prf(average, modifier, msg_start, len(result))


13 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
14 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
15 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
16 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
17 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
18 .: CNN Model Metrics: 0.7307692307692307 0.6875 0.8461538461538461 0.7586206896551724
19 .: CNN Model Metrics: 0.6153846153846154 0.8 0.3076923076923077 0.4444444444444444
20 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666


  _warn_prf(average, modifier, msg_start, len(result))


21 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
22 .: CNN Model Metrics: 0.46153846153846156 0.42857142857142855 0.23076923076923078 0.3
23 .: CNN Model Metrics: 0.6153846153846154 0.8 0.3076923076923077 0.4444444444444444
24 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
25 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
26 .: CNN Model Metrics: 0.5769230769230769 0.6 0.46153846153846156 0.5217391304347826


  _warn_prf(average, modifier, msg_start, len(result))


27 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
28 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
29 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
30 .: CNN Model Metrics: 0.5769230769230769 0.5555555555555556 0.7692307692307693 0.6451612903225806
31 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
32 .: CNN Model Metrics: 0.5384615384615384 0.6 0.23076923076923078 0.33333333333333337
33 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
34 .: CNN Model Metrics: 0.46153846153846156 0.45454545454545453 0.38461538461538464 0.41666666666666663
35 .: CNN Model Metrics: 0.5384615384615384 0.5555555555555556 0.38461538461538464 0.4545454545454546
36 .: CNN Model Metrics: 0.6923076923076923 0.7777777777777778 0.5384615384615384 0.6363636363636364


  _warn_prf(average, modifier, msg_start, len(result))


37 .: CNN Model Metrics: 0.5 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


38 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
39 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
40 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
41 .: CNN Model Metrics: 0.6538461538461539 0.8333333333333334 0.38461538461538464 0.5263157894736842
42 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666


  _warn_prf(average, modifier, msg_start, len(result))


43 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
44 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666


  _warn_prf(average, modifier, msg_start, len(result))


45 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
46 .: CNN Model Metrics: 0.5 0.5 0.15384615384615385 0.23529411764705882
47 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
48 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
49 .: CNN Model Metrics: 0.4230769230769231 0.42857142857142855 0.46153846153846156 0.4444444444444445


  _warn_prf(average, modifier, msg_start, len(result))


50 .: CNN Model Metrics: 0.5 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


51 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
52 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
53 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666


  _warn_prf(average, modifier, msg_start, len(result))


54 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
55 .: CNN Model Metrics: 0.6538461538461539 0.6111111111111112 0.8461538461538461 0.7096774193548387
56 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666


  _warn_prf(average, modifier, msg_start, len(result))


57 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
58 .: CNN Model Metrics: 0.46153846153846156 0.48 0.9230769230769231 0.631578947368421


  _warn_prf(average, modifier, msg_start, len(result))


59 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
60 .: CNN Model Metrics: 0.5384615384615384 0.5555555555555556 0.38461538461538464 0.4545454545454546
61 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
62 .: CNN Model Metrics: 0.4230769230769231 0.45454545454545453 0.7692307692307693 0.5714285714285714


  _warn_prf(average, modifier, msg_start, len(result))


63 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
64 .: CNN Model Metrics: 0.46153846153846156 0.48 0.9230769230769231 0.631578947368421


  _warn_prf(average, modifier, msg_start, len(result))


65 .: CNN Model Metrics: 0.5 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


66 .: CNN Model Metrics: 0.5 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


67 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
68 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
69 .: CNN Model Metrics: 0.38461538461538464 0.42857142857142855 0.6923076923076923 0.5294117647058824
70 .: CNN Model Metrics: 0.46153846153846156 0.4782608695652174 0.8461538461538461 0.6111111111111112
71 .: CNN Model Metrics: 0.5769230769230769 0.55 0.8461538461538461 0.6666666666666667


  _warn_prf(average, modifier, msg_start, len(result))


72 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
73 .: CNN Model Metrics: 0.46153846153846156 0.47619047619047616 0.7692307692307693 0.588235294117647
74 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666


  _warn_prf(average, modifier, msg_start, len(result))


75 .: CNN Model Metrics: 0.5 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


76 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
77 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
78 .: CNN Model Metrics: 0.46153846153846156 0.4 0.15384615384615385 0.2222222222222222
79 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
80 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
81 .: CNN Model Metrics: 0.38461538461538464 0.42857142857142855 0.6923076923076923 0.5294117647058824
82 .: CNN Model Metrics: 0.4230769230769231 0.45454545454545453 0.7692307692307693 0.5714285714285714


  _warn_prf(average, modifier, msg_start, len(result))


83 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
84 .: CNN Model Metrics: 0.4230769230769231 0.42857142857142855 0.46153846153846156 0.4444444444444445
85 .: CNN Model Metrics: 0.6923076923076923 1.0 0.38461538461538464 0.5555555555555556


  _warn_prf(average, modifier, msg_start, len(result))


86 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
87 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
88 .: CNN Model Metrics: 0.5384615384615384 1.0 0.07692307692307693 0.14285714285714288
89 .: CNN Model Metrics: 0.4230769230769231 0.4583333333333333 0.8461538461538461 0.5945945945945945
90 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
91 .: CNN Model Metrics: 0.4230769230769231 0.4166666666666667 0.38461538461538464 0.4
92 .: CNN Model Metrics: 0.5384615384615384 0.5384615384615384 0.5384615384615384 0.5384615384615384
93 .: CNN Model Metrics: 0.5769230769230769 1.0 0.15384615384615385 0.2666666666666667


  _warn_prf(average, modifier, msg_start, len(result))


94 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
95 .: CNN Model Metrics: 0.5384615384615384 1.0 0.07692307692307693 0.14285714285714288


  _warn_prf(average, modifier, msg_start, len(result))


96 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
97 .: CNN Model Metrics: 0.46153846153846156 0.48 0.9230769230769231 0.631578947368421
98 .: CNN Model Metrics: 0.5 0.5 1.0 0.6666666666666666
99 .: CNN Model Metrics: 0.5769230769230769 1.0 0.15384615384615385 0.2666666666666667
Model metrics saved to: model_metrics_mobilenetv2_2024-04-17_17-23-06.csv
