In [3]:
import os
import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from datetime import datetime
from tensorflow.keras.applications import ResNet50
from sklearn.metrics import precision_score, recall_score, f1_score

def calculate_metrics(stats_df):
    """Compute accuracy, precision, recall and F1 for the CNN predictions.

    Args:
        stats_df: DataFrame with one row per evaluated image, holding the
            true label in 'ground_truth' and the thresholded model output
            in 'cnn_prediction'.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``.

    Raises:
        ValueError: If ``stats_df`` has no rows, since every metric would
            be undefined.
    """
    if len(stats_df) == 0:
        raise ValueError("Cannot compute metrics: stats_df contains no rows.")

    y_true = stats_df['ground_truth']
    y_pred = stats_df['cnn_prediction']

    # Fraction of rows where the prediction equals the label.
    accuracy = int((y_true == y_pred).sum()) / len(stats_df)

    # zero_division=0 makes the "no positive predictions / no positive
    # labels" case an explicit 0.0 instead of a 0.0 plus a runtime warning.
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    return accuracy, precision, recall, f1

def serialize_model(model, accuracy, threshold=0.80):
    """Save the model to a timestamped .h5 file if it is accurate enough.

    Args:
        model: Object exposing ``save(filename)``.
        accuracy: Accuracy achieved by ``model``.
        threshold: Minimum accuracy (inclusive) required for the model to
            be saved. Defaults to 0.80.

    Returns:
        The filename the model was saved under, or None when the accuracy
        is below ``threshold`` and nothing was written.
    """
    # Written as "not >=" so that a NaN accuracy is treated as a failure.
    if not accuracy >= threshold:
        return None

    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    model_filename = f"trained_model_{current_datetime}.h5"
    model.save(model_filename)
    print("Model serialized successfully. Model saved as:", model_filename)
    return model_filename

def preprocess_png(image_path):
    """Load an image and shape it into a single-image batch for the CNN.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The image resized to 224x224 with a leading batch axis of length 1.

    Raises:
        OSError: If the file is missing or cannot be decoded.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error would surface inside cv2.resize.
        raise OSError(f"Could not read image file: {image_path}")

    resized_image = cv2.resize(image, (224, 224))
    # NOTE(review): pixels are forwarded exactly as OpenCV decoded them (no
    # channel reordering, scaling or mean subtraction). Confirm whether the
    # ResNet50 backbone should receive `preprocess_input`-normalised data.
    return np.expand_dims(resized_image, axis=0)

def detect_lines_cnn(image_path, model):
    """Classify one image with the model and threshold its score at 0.5.

    Args:
        image_path: Path of the image to classify.
        model: Object whose ``predict(batch, verbose=0)`` result is indexed
            as ``[0][0]`` to obtain the score.

    Returns:
        Tuple ``(is_positive, score)`` where ``is_positive`` is True when
        the score is at least 0.5 and ``score`` is the raw model output.
    """
    batch = preprocess_png(image_path)
    score = model.predict(batch, verbose=0)[0][0]
    # bool() keeps the flag a plain Python bool, as with literal True/False.
    is_positive = bool(score >= 0.5)
    return is_positive, score

def find_latest_dictionary_csv(folder_path):
    """Return the name of the dictionary CSV that sorts last in a folder.

    Candidates are directory entries whose name starts with 'dictionary'
    and ends with '.csv'; the "latest" one is the lexicographic maximum of
    those names.

    Args:
        folder_path: Directory to search.

    Returns:
        The selected filename (without the folder), or None if the folder
        has no matching file, in which case a message is printed.
    """
    latest = max(
        (
            entry
            for entry in os.listdir(folder_path)
            if entry.startswith('dictionary') and entry.endswith('.csv')
        ),
        default=None,
    )
    if latest is None:
        print("No dictionary CSV file found.")
    return latest

def load_dictionary(folder_path):
    """Load the most recent dictionary CSV from a folder into a DataFrame.

    Args:
        folder_path: Directory searched for 'dictionary*.csv' files.

    Returns:
        DataFrame read from the selected CSV file.

    Raises:
        SystemExit: With exit status 1 if the folder has no dictionary CSV.
    """
    latest_dictionary_csv = find_latest_dictionary_csv(folder_path)
    if latest_dictionary_csv is None:
        print("Exiting program.")
        # Raise SystemExit directly: the bare exit() helper is installed by
        # the `site` module for interactive use and would report status 0.
        raise SystemExit(1)
    return pd.read_csv(os.path.join(folder_path, latest_dictionary_csv))

def read_png_data(folder_path, dictionary_df, model):
    """Run the model on every labelled PNG in a folder and collect results.

    Args:
        folder_path: Directory containing the '.png' files to evaluate.
        dictionary_df: DataFrame mapping filenames (column 'output') to
            their labels (column 'label').
        model: Model passed through to ``detect_lines_cnn``.

    Returns:
        DataFrame with columns 'filename', 'ground_truth', 'cnn_prediction'
        and 'cnn_raw_prediction', one row per PNG that has a label. Rows
        are ordered by filename. PNGs without a label are reported on
        stdout and skipped.
    """
    # Build the filename -> label lookup once instead of scanning the whole
    # DataFrame for every file. setdefault keeps the first match, like the
    # previous `.iloc[0]` on the filtered rows.
    label_by_filename = {}
    for name, label in zip(dictionary_df['output'], dictionary_df['label']):
        label_by_filename.setdefault(name, label)

    filenames = []
    ground_truth = []
    cnn_predictions = []
    cnn_raw_predictions = []
    # os.listdir order is arbitrary; sort so the table is reproducible.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.png'):
            continue
        if filename not in label_by_filename:
            print(f"No label found for filename: {filename}")
            continue
        image_path = os.path.join(folder_path, filename)
        cnn_prediction, cnn_raw_prediction = detect_lines_cnn(image_path, model)
        filenames.append(filename)
        ground_truth.append(label_by_filename[filename])
        cnn_predictions.append(cnn_prediction)
        cnn_raw_predictions.append(cnn_raw_prediction)

    return pd.DataFrame({
        'filename': filenames,
        'ground_truth': ground_truth,
        'cnn_prediction': cnn_predictions,
        'cnn_raw_prediction': cnn_raw_predictions,
    })

def save_metrics_to_csv(metrics_list):
    """Write the collected metrics to a timestamped CSV in the working dir.

    Args:
        metrics_list: Records passed to ``pd.DataFrame`` and laid out under
            the columns 'Accuracy', 'Precision', 'Recall' and 'F1'.

    The file is named 'model_metrics_resnet_<timestamp>.csv', written
    without the index column, and its name is printed afterwards.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_name = f"model_metrics_resnet_{timestamp}.csv"
    pd.DataFrame(
        metrics_list,
        columns=['Accuracy', 'Precision', 'Recall', 'F1'],
    ).to_csv(output_name, index=False)
    print("Model metrics saved to:", output_name)

def display_prediction_table(stats_df):
    """Print the per-file prediction table under a fixed heading line.

    Args:
        stats_df: Object to print below the heading (the prediction table).
    """
    print("Files in the checking folder:", stats_df, sep="\n")

# Folder holding both the PNG images and the 'dictionary*.csv' label file.
folder_path = './Data/fits_filtered8/'

# One metrics record per iteration, written to CSV at the end.
metrics_list = []

# The label dictionary does not change between iterations, so load it once
# up front; this also stops early, before any model is built, if it is missing.
dictionary_df = load_dictionary(folder_path)

for i in range(10):
    # Frozen ImageNet ResNet50 backbone topped with a new binary head.
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    for layer in base_model.layers:
        layer.trainable = False

    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(256, activation='relu')(x)
    predictions = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])

    # NOTE(review): the model is compiled but never fitted anywhere in this
    # script, so the two Dense layers still hold their initial weights when
    # the images are scored below. A training step (model.fit on a held-out
    # split) is needed before these metrics say anything about the model.
    stats_df = read_png_data(folder_path, dictionary_df, model)

    accuracy, precision, recall, f1 = calculate_metrics(stats_df)

    print(f"Iteration {i+1}: CNN Model Metrics - Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1: {f1}")
    # None when the accuracy was too low for the model to be saved.
    model_filename = serialize_model(model, accuracy)
    metrics_list.append({'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1})

    display_prediction_table(stats_df)

# Save the list of metrics to a CSV file
save_metrics_to_csv(metrics_list)


Iteration 1: CNN Model Metrics - Accuracy: 0.5, Precision: 0.0, Recall: 0.0, F1: 0.0
Files in the checking folder:
     filename  ground_truth  cnn_prediction  cnn_raw_prediction
0    tic1.png             1           False            0.305520
1   tic10.png             1           False            0.283045
2   tic11.png             1           False            0.293106
3   tic12.png             1           False            0.436487
4   tic13.png             1           False            0.288077
5   tic14.png             1           False            0.314262
6   tic15.png             1           False            0.266559
7   tic16.png             1           False            0.306427
8   tic17.png             1           False            0.323461
9   tic18.png             1           False            0.327351
10  tic19.png             0           False            0.352902
11   tic2.png             1           False            0.377392
12  tic20.png             0           False          

  _warn_prf(average, modifier, msg_start, len(result))


Iteration 4: CNN Model Metrics - Accuracy: 0.5294117647058824, Precision: 0.0, Recall: 0.0, F1: 0.0
Files in the checking folder:
     filename  ground_truth  cnn_prediction  cnn_raw_prediction
0    tic1.png             1           False            0.091008
1   tic10.png             1           False            0.089349
2   tic11.png             1           False            0.085393
3   tic12.png             1           False            0.146599
4   tic13.png             1           False            0.092880
5   tic14.png             1           False            0.086515
6   tic15.png             1           False            0.091339
7   tic16.png             1           False            0.094924
8   tic17.png             1           False            0.109176
9   tic18.png             1           False            0.095847
10  tic19.png             0           False            0.119645
11   tic2.png             1           False            0.126969
12  tic20.png             0           

  _warn_prf(average, modifier, msg_start, len(result))


Iteration 5: CNN Model Metrics - Accuracy: 0.5294117647058824, Precision: 0.0, Recall: 0.0, F1: 0.0
Files in the checking folder:
     filename  ground_truth  cnn_prediction  cnn_raw_prediction
0    tic1.png             1           False            0.080049
1   tic10.png             1           False            0.102583
2   tic11.png             1           False            0.099241
3   tic12.png             1           False            0.060338
4   tic13.png             1           False            0.086333
5   tic14.png             1           False            0.082612
6   tic15.png             1           False            0.061503
7   tic16.png             1           False            0.068904
8   tic17.png             1           False            0.075130
9   tic18.png             1           False            0.073368
10  tic19.png             0           False            0.081783
11   tic2.png             1           False            0.072209
12  tic20.png             0           

  _warn_prf(average, modifier, msg_start, len(result))


Iteration 9: CNN Model Metrics - Accuracy: 0.5294117647058824, Precision: 0.0, Recall: 0.0, F1: 0.0
Files in the checking folder:
     filename  ground_truth  cnn_prediction  cnn_raw_prediction
0    tic1.png             1           False            0.080335
1   tic10.png             1           False            0.127161
2   tic11.png             1           False            0.163464
3   tic12.png             1           False            0.125608
4   tic13.png             1           False            0.081549
5   tic14.png             1           False            0.100680
6   tic15.png             1           False            0.082735
7   tic16.png             1           False            0.074662
8   tic17.png             1           False            0.102638
9   tic18.png             1           False            0.067460
10  tic19.png             0           False            0.089152
11   tic2.png             1           False            0.132817
12  tic20.png             0           