In [5]:
import os
import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from datetime import datetime
from tensorflow.keras.applications import ResNet50
from sklearn.metrics import precision_score, recall_score, f1_score

def calculate_metrics(stats_df):
    correct_predictions = stats_df[stats_df['ground_truth'] == stats_df['cnn_prediction']]
    accuracy = len(correct_predictions) / len(stats_df)
    precision = precision_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    recall = recall_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    f1 = f1_score(stats_df['ground_truth'], stats_df['cnn_prediction'])
    return accuracy, precision, recall, f1

def serialize_model(model, accuracy):
    if accuracy >= 0.80:
        current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        model_filename = f"trained_model_{current_datetime}.h5"
        model.save(model_filename)
        print("Model serialized successfully. Model saved as:", model_filename)
        return model_filename
    else:
        return None

def preprocess_png(image_path):
    image = cv2.imread(image_path)
    resized_image = cv2.resize(image, (224, 224))
    preprocessed_image = np.expand_dims(resized_image, axis=0)
    return preprocessed_image

def detect_lines_cnn(image_path, model):
    preprocessed_image = preprocess_png(image_path)
    prediction = model.predict(preprocessed_image, verbose=0)[0][0]
    if prediction >= 0.5:
        return True
    else:
        return False

def find_latest_dictionary_csv(folder_path):
    csv_files = [f for f in os.listdir(folder_path) if f.startswith('dictionary') and f.endswith('.csv')]
    if not csv_files:
        print("No dictionary CSV file found.")
        return None
    latest_csv = max(csv_files)
    return latest_csv

def load_dictionary(folder_path):
    latest_dictionary_csv = find_latest_dictionary_csv(folder_path)
    if latest_dictionary_csv:
        dictionary_df = pd.read_csv(os.path.join(folder_path, latest_dictionary_csv))
        return dictionary_df
    else:
        print("Exiting program.")
        exit()

def read_png_data(folder_path, dictionary_df, model):
    filenames = []
    ground_truth = []
    cnn_predictions = []
    for filename in os.listdir(folder_path):
        if filename.endswith('.png'):
            image_path = os.path.join(folder_path, filename)
            matching_row = dictionary_df[dictionary_df['output'] == filename]
            if len(matching_row) == 0:
                print(f"No label found for filename: {filename}")
                continue
            label = matching_row.iloc[0]['label']
            ground_truth.append(label)
            cnn_prediction = detect_lines_cnn(image_path, model)
            cnn_predictions.append(cnn_prediction)
            filenames.append(filename)
    return pd.DataFrame({
        'filename': filenames,
        'ground_truth': ground_truth,
        'cnn_prediction': cnn_predictions
    })

def save_metrics_to_csv(metrics_list):
    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    csv_filename = f"model_metrics_resnet_{current_datetime}.csv"
    metrics_df = pd.DataFrame(metrics_list, columns=['Accuracy', 'Precision', 'Recall', 'F1'])
    metrics_df.to_csv(csv_filename, index=False)
    print("Model metrics saved to:", csv_filename)

folder_path = './Data/fits_filtered8/'

metrics_list = []

for i in range(100):
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    for layer in base_model.layers:
        layer.trainable = False

    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(256, activation='relu')(x)
    predictions = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])

    dictionary_df = load_dictionary(folder_path)
    stats_df = read_png_data(folder_path, dictionary_df, model)

    accuracy, precision, recall, f1 = calculate_metrics(stats_df)

    print(i, ".: CNN Model Metrics:", accuracy, precision, recall, f1)
    model_filename = serialize_model(model, accuracy)
    metrics_list.append({'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1})

# Save the list of metrics to a CSV file
save_metrics_to_csv(metrics_list)


0 .: CNN Model Metrics: 0.47058823529411764 0.45454545454545453 0.625 0.5263157894736842


  _warn_prf(average, modifier, msg_start, len(result))


1 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
2 .: CNN Model Metrics: 0.5588235294117647 1.0 0.0625 0.11764705882352941
3 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
4 .: CNN Model Metrics: 0.4411764705882353 0.45454545454545453 0.9375 0.6122448979591837
5 .: CNN Model Metrics: 0.5882352941176471 0.5555555555555556 0.625 0.5882352941176471
6 .: CNN Model Metrics: 0.47058823529411764 0.4666666666666667 0.875 0.608695652173913
7 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
8 .: CNN Model Metrics: 0.47058823529411764 0.4375 0.4375 0.4375
9 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
10 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


11 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


12 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
13 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


14 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
15 .: CNN Model Metrics: 0.35294117647058826 0.2857142857142857 0.25 0.26666666666666666
16 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
17 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


18 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
19 .: CNN Model Metrics: 0.5 0.4838709677419355 0.9375 0.6382978723404255


  _warn_prf(average, modifier, msg_start, len(result))


20 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
21 .: CNN Model Metrics: 0.47058823529411764 0.3333333333333333 0.125 0.18181818181818182
22 .: CNN Model Metrics: 0.47058823529411764 0.45454545454545453 0.625 0.5263157894736842
23 .: CNN Model Metrics: 0.5588235294117647 0.5384615384615384 0.4375 0.4827586206896552
24 .: CNN Model Metrics: 0.5 0.3333333333333333 0.0625 0.10526315789473684
25 .: CNN Model Metrics: 0.6176470588235294 1.0 0.1875 0.3157894736842105
26 .: CNN Model Metrics: 0.5882352941176471 1.0 0.125 0.2222222222222222
27 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


28 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
29 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
30 .: CNN Model Metrics: 0.35294117647058826 0.2 0.125 0.15384615384615385
31 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


32 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
33 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


34 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
35 .: CNN Model Metrics: 0.5882352941176471 0.55 0.6875 0.6111111111111112


  _warn_prf(average, modifier, msg_start, len(result))


36 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


37 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


38 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


39 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
40 .: CNN Model Metrics: 0.5294117647058824 0.5 0.0625 0.1111111111111111
41 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
42 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
43 .: CNN Model Metrics: 0.5 0.48484848484848486 1.0 0.653061224489796
44 .: CNN Model Metrics: 0.5588235294117647 0.5172413793103449 0.9375 0.6666666666666667
45 .: CNN Model Metrics: 0.38235294117647056 0.2222222222222222 0.125 0.16
46 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
47 .: CNN Model Metrics: 0.47058823529411764 0.46875 0.9375 0.625
48 .: CNN Model Metrics: 0.5588235294117647 0.5555555555555556 0.3125 0.39999999999999997
49 .: CNN Model Metrics: 0.5 0.48484848484848486 1.0 0.653061224489796


  _warn_prf(average, modifier, msg_start, len(result))


50 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


51 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
52 .: CNN Model Metrics: 0.5882352941176471 1.0 0.125 0.2222222222222222
53 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


54 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


55 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


56 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
57 .: CNN Model Metrics: 0.47058823529411764 0.0 0.0 0.0
58 .: CNN Model Metrics: 0.38235294117647056 0.2727272727272727 0.1875 0.2222222222222222
59 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
60 .: CNN Model Metrics: 0.4117647058823529 0.16666666666666666 0.0625 0.09090909090909091
61 .: CNN Model Metrics: 0.47058823529411764 0.46875 0.9375 0.625


  _warn_prf(average, modifier, msg_start, len(result))


62 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


63 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
64 .: CNN Model Metrics: 0.47058823529411764 0.25 0.0625 0.1
65 .: CNN Model Metrics: 0.5588235294117647 0.5555555555555556 0.3125 0.39999999999999997
66 .: CNN Model Metrics: 0.38235294117647056 0.41935483870967744 0.8125 0.5531914893617021
67 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
68 .: CNN Model Metrics: 0.5882352941176471 0.5833333333333334 0.4375 0.5


  _warn_prf(average, modifier, msg_start, len(result))


69 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
70 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
71 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
72 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
73 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
74 .: CNN Model Metrics: 0.5588235294117647 0.52 0.8125 0.6341463414634146


  _warn_prf(average, modifier, msg_start, len(result))


75 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
76 .: CNN Model Metrics: 0.5294117647058824 0.5 0.9375 0.6521739130434783


  _warn_prf(average, modifier, msg_start, len(result))


77 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
78 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


79 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
80 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
81 .: CNN Model Metrics: 0.5 0.48484848484848486 1.0 0.653061224489796
82 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
83 .: CNN Model Metrics: 0.5294117647058824 0.5 0.5625 0.5294117647058824
84 .: CNN Model Metrics: 0.5588235294117647 1.0 0.0625 0.11764705882352941
85 .: CNN Model Metrics: 0.47058823529411764 0.4666666666666667 0.875 0.608695652173913
86 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
87 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
88 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999
89 .: CNN Model Metrics: 0.5 0.0 0.0 0.0
90 .: CNN Model Metrics: 0.4117647058823529 0.0 0.0 0.0


  _warn_prf(average, modifier, msg_start, len(result))


91 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
92 .: CNN Model Metrics: 0.47058823529411764 0.47058823529411764 1.0 0.6399999999999999


  _warn_prf(average, modifier, msg_start, len(result))


93 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
94 .: CNN Model Metrics: 0.5 0.48484848484848486 1.0 0.653061224489796


  _warn_prf(average, modifier, msg_start, len(result))


95 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
96 .: CNN Model Metrics: 0.5 0.4827586206896552 0.875 0.6222222222222222
97 .: CNN Model Metrics: 0.5 0.48484848484848486 1.0 0.653061224489796
98 .: CNN Model Metrics: 0.5588235294117647 1.0 0.0625 0.11764705882352941
99 .: CNN Model Metrics: 0.5294117647058824 0.0 0.0 0.0
Model metrics saved to: model_metrics_resnet_2024-05-09_16-02-22.csv


  _warn_prf(average, modifier, msg_start, len(result))
