# Create your model

#### 1. Install Dependencies and Setup

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import matplotlib
import os
import cv2
import yaml
import re
# from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
# from sklearn.utils.class_weight import compute_class_weight
# from tensorflow.keras.regularizers import l2
# from tqdm import tqdm
import openpyxl
from openpyxl.drawing.image import Image
from openpyxl.utils import get_column_letter
from io import BytesIO
from PIL import Image as PILImage
import seaborn as sns
import tempfile
from mpl_toolkits.axes_grid1 import make_axes_locatable

#### 2. Load Image Data

In [None]:
### THE METABOLITE NAME IS TAKEN FROM THE CURRENT WORKING DIRECTORY: NAME THAT FOLDER AFTER THE METABOLITE YOU WANT ###

def get_metabolite_name():
    """Return the metabolite name derived from the working directory.

    The name is the last path component of ``os.getcwd()``, so the notebook
    must be run from a folder named after the metabolite.
    """
    return os.path.basename(os.getcwd())

##### Select Version

In [3]:
# Resolve the metabolite name from the working directory and choose which
# entry under `versions` in the YAML config this run uses.
if __name__ == "__main__":
    metabolite_name = get_metabolite_name()
    version = 1 ## SELECT THE VERSION ##

In [8]:
def construct_path(metabolite_name, path_type,
                   base_dir=r'C:\Users\PC\Documents\BIOSFER\CNN'):
    """Build a project path for the given metabolite.

    Args:
        metabolite_name: Name of the metabolite (used as a folder/file part).
        path_type: One of ``"invalid"``, ``"valid"``, ``"models"`` or
            ``"excel"``.
        base_dir: Root folder of the project. Defaults to the original
            hard-coded location; pass another folder to run elsewhere.

    Returns:
        ``<base_dir>/data/<metabolite>/<path_type>`` for image folders,
        ``<base_dir>/models/<metabolite>`` for ``"models"``, and
        ``<base_dir>/model_results_<metabolite>.xlsx`` for ``"excel"``.

    Raises:
        ValueError: If ``path_type`` is not one of the supported values.
    """
    if path_type in ("invalid", "valid"):
        return os.path.join(base_dir, "data", metabolite_name, path_type)
    if path_type == "models":
        return os.path.join(base_dir, "models", metabolite_name)
    if path_type == "excel":
        return os.path.join(base_dir, f"model_results_{metabolite_name}.xlsx")
    raise ValueError("Invalid path type")
    
# Folders holding the "invalid" and "valid" image sets for this metabolite.
path_invalid = construct_path(metabolite_name, "invalid")
path_valid = construct_path(metabolite_name, "valid")

In [None]:
def load_images(path, label):
    """Load the 600x800 PNG images in ``path`` and tag them with ``label``.

    Files that do not end in ``.png``, files OpenCV cannot decode, images
    that do not have 3 or 4 channels, and images whose RGB shape is not
    ``(600, 800, 3)`` are skipped.

    Args:
        path: Directory to scan (not recursive).
        label: Label assigned to every image loaded from ``path``.

    Returns:
        A tuple ``(images, labels, df)`` where ``images`` is a list of RGB
        arrays, ``labels`` the matching list of labels, and ``df`` a
        DataFrame indexed by ``Filename`` with columns ``Image`` and
        ``Label``.
    """
    images = []
    labels = []
    data = []
    for filename in os.listdir(path):
        if not filename.endswith('.png'):
            continue
        img_path = os.path.join(path, filename)
        # Read PNG with all channels
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None for unreadable files, and cvtColor raises on
        # None or on single-channel input, so filter those out *before*
        # converting instead of after.
        if img is None or img.ndim != 3 or img.shape[2] not in (3, 4):
            continue
        # BGR(A) -> RGB; a 4-channel input comes out with 3 channels, which is
        # what the shape check below relies on.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if img.shape != (600, 800, 3):
            continue
        images.append(img)
        labels.append(label)
        data.append((img, label, filename))
    df = pd.DataFrame(data, columns=['Image', 'Label', 'Filename'])
    df.set_index('Filename', inplace=True)

    return images, labels, df


# Load NA (invalid) images
na_images, na_labels, df_0 = load_images(path_invalid, 0)  # 0 for invalid

# Load normal (valid) images
normal_images, normal_labels, df_1 = load_images(path_valid, 1)  # 1 for valid

# Combine the data (invalid samples first, then valid ones)
X = na_images + normal_images
Y = na_labels + normal_labels

# Convert lists to numpy arrays
# NOTE(review): the image array is bound to lowercase `x` while `X` stays a
# list; `x` is not read again in the visible notebook — confirm whether
# `X = np.array(X)` was intended.
x = np.array(X)
Y = np.array(Y)

# Merge the two DataFrames; the result keeps `Filename` as its index
df = pd.concat([df_0, df_1])

In [None]:
# Better practice is to keep the data in DataFrames, one row per image with
# its identifier, the image and the label. Print shape and dtypes of each one.
for _title, _frame in (
    ("Invalid df description", df_0),
    ("Valid df description", df_1),
    ("Valid complete df description", df),
):
    print(_title)
    print("----------------------------------")
    print(_frame.shape)
    print(_frame.dtypes)
    print(" ")

#### 3. Load YAML config

In [8]:
def load_config(metabolite_name, version):
    """Load one version entry from the metabolite's YAML config file.

    The file ``config_<metabolite_name lowercased>.yml`` is read from the
    current working directory.

    Args:
        metabolite_name: Metabolite whose config file should be read.
        version: Key to look up under ``versions``. A float holding a whole
            number (e.g. ``1.0``) is looked up as the integer ``1``.

    Returns:
        A tuple ``(version_config, version_str)`` where ``version_config`` is
        the entry stored under ``versions`` and ``version_str`` is
        ``str(version)`` of the value as passed in.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If ``version`` is not present under ``versions``.
    """
    yaml_file = f'config_{metabolite_name.lower()}.yml'
    # Read explicitly as UTF-8: the platform default encoding (e.g. cp1252 on
    # Windows) would mis-decode non-ASCII characters in the YAML file.
    with open(yaml_file, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # Convert version to integer if it's a whole number
    version_int = int(version) if isinstance(version, float) and version.is_integer() else version

    if version_int not in config['versions']:
        raise ValueError(f"Version {version} not found in config file")

    return config['versions'][version_int], str(version)

# Configuration entry for the selected version, plus the version as a string.
config, version_str = load_config(metabolite_name, version)

#### 4. Data pre-processing

In [9]:
# Crop, resize and rescale images
def preprocess_images(images, coords, crop_size, resize_shape):
    """Crop, resize and rescale every image in ``images``.

    Args:
        images: Iterable of image arrays.
        coords: Start of the crop window; ``coords[0]`` indexes the first
            array axis and ``coords[1]`` the second.
        crop_size: Extent of the crop window along the first and second axis.
        resize_shape: Target size handed to ``cv2.resize`` as its ``dsize``.

    Returns:
        A list of ``float32`` arrays, each being the cropped and resized
        image divided by 255.
    """

    def _prepare(picture):
        row_start, col_start = coords[0], coords[1]
        window = picture[row_start:row_start + crop_size[0],
                         col_start:col_start + crop_size[1]]
        shrunk = cv2.resize(window, resize_shape, interpolation=cv2.INTER_AREA)
        return shrunk.astype(np.float32) / 255.0

    return [_prepare(picture) for picture in images]

In [10]:
# Crop window and target size come from the `preprocess` section of the
# selected config version.
coords = config['preprocess']['coords']
crop_size = config['preprocess']['crop_size']
resize_shape = tuple(config['preprocess']['resize_shape'])


# Preprocess every loaded image, in DataFrame row order
X_processed = preprocess_images(
    df['Image'].tolist(), coords, crop_size, resize_shape)

# Keep the processed image next to the original one, row by row
df['Processed'] = X_processed

#### 5. Split data

In [None]:
# Shuffle and split the data (stratified), warning if a split has no NA sample
def ensure_na_in_split(X, Y, na_label=0):
    """Shuffle and split the data into train (80%), validation (15%), test (5%).

    Both splits are stratified on the labels. Stratification does not
    guarantee that very small splits contain the NA class, so each resulting
    split is checked and a warning is issued when it has no ``na_label``
    sample.

    Args:
        X: Samples; a pandas Series/DataFrame or a NumPy array.
        Y: Labels aligned with ``X``; a pandas Series or a NumPy array.
        na_label: Label value identifying NA (invalid) samples.

    Returns:
        ``X_train, X_val, X_test, Y_train, Y_val, Y_test``.
    """
    import warnings

    def _take(values, positions):
        # pandas objects must be indexed positionally through .iloc: plain
        # [] with integers on a non-integer index relies on a deprecated
        # positional fallback.
        if hasattr(values, "iloc"):
            return values.iloc[positions]
        return values[positions]

    # Shuffle the data (fixed seed so the split is reproducible)
    np.random.seed(42)
    indices = np.arange(X.shape[0])
    np.random.shuffle(indices)
    X = _take(X, indices)
    Y = _take(Y, indices)

    # First, split into training (80%) and temporary (20%)
    X_train, X_temp, Y_train, Y_temp = train_test_split(
        X, Y, test_size=0.2, random_state=42, stratify=Y)
    # Second, split the temporary set into validation (15% of total) and
    # test (5% of total)
    X_val, X_test, Y_val, Y_test = train_test_split(
        X_temp, Y_temp, test_size=0.25, random_state=42, stratify=Y_temp)

    for split_name, split_labels in (("training", Y_train),
                                     ("validation", Y_val),
                                     ("test", Y_test)):
        if not np.any(np.asarray(split_labels) == na_label):
            warnings.warn(
                f"The {split_name} split contains no sample with label {na_label}")

    return X_train, X_val, X_test, Y_train, Y_val, Y_test

# Split the processed images and their labels into train/validation/test.
# Both inputs are pandas Series taken from `df`.
X_train, X_val, X_test, Y_train, Y_val, Y_test = ensure_na_in_split(
    df['Processed'], df['Label'])

In [12]:
# Each split holds one array per row; stack them into a single array per split
X_train_array = np.stack(X_train.values)
X_val_array = np.stack(X_val.values)
X_test_array = np.stack(X_test.values)

In [None]:
# Convert the image stacks and the label vectors to float32 TensorFlow tensors
X_train_tensor, X_val_tensor, X_test_tensor = (
    tf.convert_to_tensor(_stack, dtype=tf.float32)
    for _stack in (X_train_array, X_val_array, X_test_array)
)

Y_train_tensor, Y_val_tensor, Y_test_tensor = (
    tf.convert_to_tensor(_labels, dtype=tf.float32)
    for _labels in (Y_train, Y_val, Y_test)
)

# Report the shape of every tensor
for _name, _tensor in (
    ("X_train_tensor", X_train_tensor),
    ("X_val_tensor", X_val_tensor),
    ("X_test_tensor", X_test_tensor),
    ("Y_train_tensor", Y_train_tensor),
    ("Y_val_tensor", Y_val_tensor),
    ("Y_test_tensor", Y_test_tensor),
):
    print(f'Shape of {_name}: {_tensor.shape}')

# Report how many NA (label 0) samples ended up in each split
for _name, _split_labels in (
    ("Y_train", Y_train),
    ("Y_val", Y_val),
    ("Y_test", Y_test),
):
    print(f'Number of NA samples in {_name}: {np.sum(_split_labels == 0)}')

In [None]:
# Recover the filenames of the test samples.

# Build a DataFrame in which 'Filename' is a regular column rather than the index
df_all = df.reset_index()  # Reset index to make 'Filename' a column
# NOTE(review): `df` already carries a 'Processed' column computed with the
# same parameters, so this recomputation looks redundant — confirm.
df_all['Processed'] = df_all['Image'].apply(lambda x: preprocess_images([x], coords, crop_size, resize_shape)[0])

# Next, find which rows in df_all correspond to X_test
# by comparing the 'Processed' images element-wise

def find_matching_filenames(X_test, df_all):
    """Return, for every image in ``X_test``, the filename of its match in ``df_all``.

    An image matches a row when it is element-wise equal
    (``np.array_equal``) to that row's ``'Processed'`` array. The first
    matching row wins.

    Args:
        X_test: Iterable of processed image arrays.
        df_all: DataFrame with ``'Filename'`` and ``'Processed'`` columns.

    Returns:
        A list with one entry per test image: the matching filename, or
        ``None`` when no row matches (previously the first row's filename
        was returned in that case, which was silently wrong).
    """
    # Materialise the pairs once instead of re-scanning the DataFrame through
    # `apply` for every test image.
    candidates = list(zip(df_all['Filename'], df_all['Processed']))
    filenames = []
    for test_img in X_test:
        match = next(
            (name for name, processed in candidates
             if np.array_equal(processed, test_img)),
            None,
        )
        filenames.append(match)
    return filenames

# Get the filenames for X_test (one per test sample, in X_test order)
test_filenames = find_matching_filenames(X_test, df_all)

# Positions of the label-0 samples within the test set, then their filenames
label_0_indices = np.where(Y_test == 0)[0]
label_0_filenames = [test_filenames[i] for i in label_0_indices]

# Print the filenames
print("Test set filenames with label 0:")
for filename in label_0_filenames:
    print(filename)

# Print the count
print(f"\nTotal number of test samples with label 0: {len(label_0_filenames)}")

#### 6. Build Deep Learning Model

In [None]:
# Run the model-building code stored under `code` in the selected config
# version; it is presumably expected to define `model` (used right below).
# NOTE(review): exec() runs arbitrary Python from the YAML file — only use
# config files you trust.
exec(config['code'])
model.summary()  # type: ignore

#### 7. Train the model

In [None]:
# Run the training code stored under `training` in the selected config
# version; later cells read a `history` object, presumably defined here.
# NOTE(review): exec() runs arbitrary Python from the YAML file — only use
# config files you trust.
exec(config['training'])

##### Display and DL activation maps

In [None]:
def display_image_filtered(name_image, model, layer_name, image):
    """Show and save the activation maps of one layer for one image.

    The image is resized to the spatial size of the training tensor, passed
    through a sub-model ending at ``layer_name``, and every channel of the
    resulting feature map is drawn in a grid and saved as
    ``results_<name_image>/<name_image>_<layer_name>_<i>.jpg``.

    Args:
        name_image: Label used for the output folder and file names.
        model: Keras model containing ``layer_name``.
        layer_name: Name of the layer whose output is visualised; its output
            must be 4-D (the last axis is treated as the channel axis).
        image: Image array to feed to the model.
    """
    import math

    inp = model.inputs
    out1 = model.get_layer(layer_name).output
    feature_map_1 = tf.keras.Model(inputs=inp, outputs=out1)
    # Resize the image to match the model's input shape.
    # NOTE: the target size is read from the module-level X_train_tensor.
    # NOTE(review): only resizing is applied here — no crop and no /255
    # scaling — confirm that `image` matches what the model was trained on.
    input_img = tf.image.resize(image, (X_train_tensor.shape[1], X_train_tensor.shape[2]))
    input_img = np.expand_dims(input_img, axis=0)
    f = feature_map_1.predict(input_img)
    dim = f.shape[3]
    print(f'{layer_name} | Features Shape: {f.shape}')
    print(f'Dimension {dim}')

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(f'results_{name_image}', exist_ok=True)

    # Near-square grid with room for every channel. The previous
    # (dim//2) x (dim//2 + dim%2) grid had too few cells for dim < 4 and
    # left most cells empty for large dim.
    cols = max(1, math.ceil(math.sqrt(dim)))
    rows = max(1, math.ceil(dim / cols))

    fig = plt.figure(figsize=(30, 30))
    for i in range(dim):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.axis('off')
        ax.imshow(f[0, :, :, i])
        plt.imsave(f'results_{name_image}/{name_image}_{layer_name}_{i}.jpg', f[0, :, :, i])
    plt.show()

# Display activation maps for ten images in the training set
num_images_to_visualize = 10
for i in range(num_images_to_visualize):
    # Find the DataFrame row whose processed image equals the i-th training
    # tensor, then fetch the original image stored in that row.
    original_index = df[df['Processed'].apply(lambda x: np.array_equal(x, X_train_tensor[i]))].index[0]
    original_image = df.loc[original_index, 'Image']
    
    # Display original image
    plt.axis('off')
    plt.imshow(original_image)
    plt.show()
    
    # Display activation maps for each layer
    # NOTE(review): these layer names are hard-coded and must exist in the
    # model built from the config — confirm against the model summary.
    # NOTE(review): the original image is passed on, not the processed one
    # the model was fed during training — confirm this is intended.
    layers = ['conv2d_1','max_pooling2d_1', 'conv2d_2','max_pooling2d_2', 'conv2d_3','max_pooling2d_3']
    for layer in layers:
        display_image_filtered(f'train_image_{i}', model, layer, original_image)

##### Display filters

In [None]:
def display_filter(model, layer_name):
    """Plot the kernels of one convolutional layer in a grid.

    Kernel values are min-max normalised over the whole layer. For each
    output filter, the full kernel is shown as an RGB image when the layer
    has exactly three input channels; otherwise only the first input channel
    is shown.

    Args:
        model: Keras model containing ``layer_name``.
        layer_name: Name of a layer whose first weight array is 4-D, with
            the output filters on the last axis.
    """
    import math

    # Get layer weights; a layer built without bias only returns the kernel
    weights = model.get_layer(layer_name).get_weights()
    if not weights:
        print(f'{layer_name} | no weights to display')
        return
    kernels = weights[0]
    bias_shape = weights[1].shape if len(weights) > 1 else None
    dim = kernels.shape[3]

    print(f'{layer_name} | Filter Shape: {kernels.shape} Bias Shape: {bias_shape}')
    print(f'Dimension {dim}')

    # Normalize kernel values to [0, 1]; constant kernels would otherwise
    # divide by zero and produce NaNs.
    f_min, f_max = kernels.min(), kernels.max()
    span = f_max - f_min
    if span > 0:
        kernels = (kernels - f_min) / span
    else:
        kernels = np.zeros_like(kernels)

    # Near-square grid with room for every filter (the previous
    # dim//2 x (dim//2 + dim%2) grid was too small for dim < 4).
    cols = max(1, math.ceil(math.sqrt(dim)))
    rows = max(1, math.ceil(dim / cols))

    # Create figure with proper size and spacing
    fig = plt.figure(figsize=(30, 30))
    plt.subplots_adjust(wspace=0.3)

    for i in range(dim):
        ax = fig.add_subplot(rows, cols, i + 1)

        # Kernel of output filter i: (height, width, input_channels)
        filter_slice = kernels[:, :, :, i]

        if filter_slice.shape[-1] == 3:
            # Three input channels: show the kernel as an RGB image
            display_data = filter_slice
        else:
            # Otherwise show only the first input channel
            display_data = filter_slice[:, :, 0]

        # Display filter
        im = ax.imshow(display_data, cmap='viridis')

        # Create colorbar
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im, cax=cax)

        # Add title with value range
        ax.set_title(f'Filter {i+1}\nRange: [{display_data.min():.2f}, {display_data.max():.2f}]')

        # Remove axis ticks
        ax.set_xticks([])
        ax.set_yticks([])

    plt.suptitle(f'Filters from layer: {layer_name}', fontsize=16, y=0.95)
    plt.show()

# Display filters for each convolutional layer
def display_all_filters(model):
    """Call ``display_filter`` for every layer whose name contains ``'conv2d'``."""
    for candidate in model.layers:
        if 'conv2d' in candidate.name:
            display_filter(model, candidate.name)

# Plot the kernels of every 'conv2d' layer of the trained model
display_all_filters(model)

#### 8. Plot Performance

In [None]:
# Select the inline backend so the figures below render in the notebook
matplotlib.use('module://matplotlib_inline.backend_inline')

# Plot Performance: accuracy (left) and loss (right) per epoch, for the
# training and validation sets
plt.figure(figsize=(11, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'],label='Training Accuracy')  # type: ignore
plt.plot(history.history['val_accuracy'],label='Validation Accuracy')  # type: ignore
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')  # type: ignore
plt.plot(history.history['val_loss'], label='Validation Loss')  # type: ignore
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.show()

# Evaluate the model on the test set
# NOTE(review): unpacking into two values assumes the model was compiled
# with exactly one metric besides the loss — confirm against the config.
test_loss, test_accuracy = model.evaluate(  # type: ignore
    X_test_tensor, Y_test_tensor, verbose=2)
print(f"Test Accuracy: {test_accuracy:.4f}")

# Predict and Generate Classification Report

y_pred = model.predict(X_test_tensor)  # type: ignore
# Threshold the predictions at 0.5 and flatten them to a 1-D class vector
y_pred_classes = (y_pred > 0.5).astype(int).reshape(-1)


print("Classification Report:")
print(classification_report(Y_test, y_pred_classes))
print("Confusion Matrix:")
print(confusion_matrix(Y_test, y_pred_classes))

In [37]:
def plot_images_with_probabilities(df, probabilities, save_path=None):
    """Draw every image of ``df`` in a square grid with its probability.

    Each image gets a text box with the probability as a percentage. The box
    is green for probabilities >= 0.995, orange for values strictly between
    0.05 and 0.995, and red otherwise.

    Args:
        df: DataFrame with an ``'Image'`` column; rows are drawn in order.
        probabilities: One probability per row of ``df``; any array-like
            that flattens to that length (e.g. shape ``(n, 1)``).
        save_path: Path or file-like object. When given, the figure is saved
            there as PNG and closed; otherwise it is shown.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    # reshape(-1) rather than squeeze: squeezing a single prediction yields a
    # 0-d array that cannot be indexed.
    probabilities = np.asarray(probabilities).reshape(-1)
    num_images = len(df)
    if num_images == 0:
        raise ValueError("df contains no images to plot")

    # Smallest square grid that fits all images
    grid_size = int(num_images**0.5)
    if grid_size**2 < num_images:
        grid_size += 1

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(100, 100),
                             squeeze=False)
    axes = axes.flatten()

    for i, (_, row) in enumerate(df.iterrows()):
        img = row['Image']  # Use the original image from the DataFrame
        axes[i].imshow(img)
        axes[i].axis('off')  # Hide axis

        # Determine the color based on the probability
        probability = float(probabilities[i])
        if probability >= 0.995:
            color = 'green'
        elif 0.05 < probability < 0.995:
            color = 'orange'
        else:
            color = 'red'

        prob_text = f"{probability * 100:.4f}%"
        axes[i].text(10, 20, prob_text, color='white', fontsize=60,
                     bbox=dict(facecolor=color, alpha=0.5))

    # Hide the unused cells of the grid
    for i in range(num_images, len(axes)):
        axes[i].axis('off')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, format='png')
        # Release the (very large) figure once it has been written out
        plt.close(fig)
    else:
        plt.show()
        
# Use the index of X_test (the filenames, when X_test is a pandas object)
# to look up the test rows; fall back to X_test itself otherwise
X_test_filenames = X_test.index if hasattr(X_test, 'index') else X_test

# Create a DataFrame for the test set
df_test = df.loc[X_test_filenames]

In [None]:
# Predict probabilities for the test tensor
y_pred_prob = model.predict(X_test_tensor)  # type: ignore
# Plot original images with probabilities
# NOTE(review): this relies on df_test rows being in the same order as
# X_test_tensor — both appear to derive from X_test; confirm.
plot_images_with_probabilities(df_test, y_pred_prob)

#### 9. Compile model results into an Excel table

In [None]:
def _figure_to_excel_image(**savefig_kwargs):
    """Render the current pyplot figure as PNG, close it, and wrap it for openpyxl."""
    buffer = BytesIO()
    plt.savefig(buffer, format='png', **savefig_kwargs)
    plt.close()
    buffer.seek(0)
    return Image(buffer)


def compile_model_results_to_excel(version, model, history, X_test, Y_test, X_test_tensor, Y_test_tensor, df_test, 
                                   metabolite_name,  # This parameter is already included
                                   excel_path=None):
    """Write the results of one model version to a sheet of an Excel workbook.

    The sheet ``'Version <version>'`` receives the model summary (as an
    image), the training history table, accuracy/loss/confusion-matrix
    plots, the test images annotated with their predicted probability, and
    the model and training code taken from the config.

    Args:
        version: Config version to report; also used in the sheet name.
        model: Trained Keras model.
        history: Object returned by training; its ``history`` dict must hold
            ``accuracy``, ``val_accuracy``, ``loss`` and ``val_loss``.
        X_test: Unused; kept for interface compatibility.
        Y_test: True labels of the test set.
        X_test_tensor: Test inputs passed to ``model.predict``.
        Y_test_tensor: Unused; kept for interface compatibility.
        df_test: DataFrame with the test images (``'Image'`` column).
        metabolite_name: Metabolite whose config is loaded; also used for
            the default workbook name.
        excel_path: Workbook to update or create. Defaults to
            ``model_results_<metabolite_name>.xlsx`` in the working directory.

    Returns:
        None. If the configuration cannot be loaded, an error is printed and
        nothing is written.
    """
    if excel_path is None:
        excel_path = f'model_results_{metabolite_name}.xlsx'
    # Render off-screen. NOTE: this switches the matplotlib backend for the
    # whole process and it stays switched after this function returns.
    matplotlib.use('Agg')

    # Load the specific version of the configuration
    try:
        config, _ = load_config(metabolite_name, version)
    except ValueError as e:
        print(f"Error: {e}")
        return
    except Exception as e:
        print(f"Error loading configuration: {e}")
        return

    # Load existing Excel file or create a new one
    try:
        workbook = openpyxl.load_workbook(excel_path)
    except FileNotFoundError:
        workbook = openpyxl.Workbook()

    # Get or create the sheet for this version. An existing sheet is replaced
    # in place: clearing cell values alone keeps the previously embedded
    # images, which would pile up on every re-run.
    sheet_name = f'Version {version}'
    if sheet_name in workbook.sheetnames:
        stale_sheet = workbook[sheet_name]
        position = workbook.index(stale_sheet)
        workbook.remove(stale_sheet)
        sheet = workbook.create_sheet(title=sheet_name, index=position)
    else:
        sheet = workbook.create_sheet(title=sheet_name)

    # 1. Model Summary, rendered as a monospace text image
    sheet['A1'] = 'Model Summary'
    stringlist = []
    model.summary(print_fn=lambda x: stringlist.append(x))
    summary_string = "\n".join(stringlist)

    plt.figure(figsize=(9, 6))
    plt.text(0.01, 0.99, summary_string, va='top', ha='left', wrap=True, fontsize=10, family='monospace')
    plt.axis('off')
    sheet.add_image(
        _figure_to_excel_image(bbox_inches='tight', pad_inches=0.1), 'A2')

    # 2. Training History. Columns are written by key so that each header
    # sits above its own values regardless of the dict's key order; any
    # additional metrics follow, labelled with their raw key.
    sheet['M1'] = 'Training History'
    known_columns = [
        ('accuracy', 'Accuracy'),
        ('loss', 'Loss'),
        ('val_accuracy', 'Val Accuracy'),
        ('val_loss', 'Val Loss'),
    ]
    known_keys = {key for key, _ in known_columns}
    extra_columns = [(key, key) for key in history.history
                     if key not in known_keys]
    for c, (key, header) in enumerate(known_columns + extra_columns, start=13):
        sheet.cell(row=2, column=c, value=header)
        for r, value in enumerate(history.history.get(key, []), start=3):
            sheet.cell(row=r, column=c, value=float(value))

    # 3. Performance Plots: accuracy and loss curves
    for train_key, val_key, metric, anchor in (
        ('accuracy', 'val_accuracy', 'Accuracy', 'R1'),
        ('loss', 'val_loss', 'Loss', 'R22'),
    ):
        plt.figure(figsize=(7, 4))
        plt.plot(history.history[train_key], label=f'Training {metric}')
        plt.plot(history.history[val_key], label=f'Validation {metric}')
        plt.title(f'Model {metric}')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()
        sheet.add_image(_figure_to_excel_image(), anchor)

    # Confusion Matrix (predictions thresholded at 0.5 and flattened to 1-D)
    y_pred = model.predict(X_test_tensor)
    y_pred_classes = (y_pred > 0.5).astype(int).reshape(-1)
    cm = confusion_matrix(Y_test, y_pred_classes)
    plt.figure(figsize=(6, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    sheet.add_image(_figure_to_excel_image(), 'R43')

    # 4. Image Probabilities
    sheet['AD1'] = 'Image Probabilities'
    figures_before = set(plt.get_fignums())
    img_buf = BytesIO()
    plot_images_with_probabilities(df_test, y_pred, save_path=img_buf)
    # Close whatever figure the plotting call left open
    for fig_num in set(plt.get_fignums()) - figures_before:
        plt.close(fig_num)
    img_buf.seek(0)
    sheet.add_image(Image(img_buf), 'AD2')

    # 5. Model Code
    sheet['A35'] = 'Model Code'
    model_code = config.get('code', 'Model code not available') if isinstance(config, dict) else 'Model code not available'
    model_code_lines = model_code.split('\n')
    for i, line in enumerate(model_code_lines, start=36):
        sheet.cell(row=i, column=1, value=line)

    # 6. Training Code. Starts at row 76 as before, or further down when the
    # model code is long enough to reach that row.
    training_header_row = max(76, 36 + len(model_code_lines) + 1)
    sheet.cell(row=training_header_row, column=1, value='Training Code')
    training_code = config.get('training', 'Training code not available') if isinstance(config, dict) else 'Training code not available'
    training_code_lines = training_code.split('\n')
    for i, line in enumerate(training_code_lines, start=training_header_row + 1):
        sheet.cell(row=i, column=1, value=line)

    # Save the workbook
    workbook.save(excel_path)
    print(f"Results for version {version} compiled and saved to {excel_path}")

# Write the results of the selected version to the Excel workbook; with no
# excel_path given, the file is created in the current working directory.
compile_model_results_to_excel(version, model, history, X_test, Y_test, X_test_tensor, Y_test_tensor, df_test, 
                               metabolite_name)

#### 10. Save the final model

In [None]:
# Save the model as <models dir>/model_<metabolite>_v<version>.keras
model_save_dir = construct_path(metabolite_name, "models")
model_filename = f'model_{metabolite_name}_v{version_str}.keras'
model_save_path = os.path.join(model_save_dir, model_filename)

os.makedirs(model_save_dir, exist_ok=True)  # Create directory if it doesn't exist
model.save(model_save_path)  # Attempt to save the model

print(f"Model saved as: {model_save_path}")