
# AMS-UNet
This notebook contains the code for training and testing the AMS-UNet model. It assumes that the data pre-processing has already been completed.

Before starting, create the following folders in the root folder:
```
- AMSUNET: root folder
    |
    |_ data 
    |   |_ train: folder for training data
    |   |   |_ image: folder contains training images
    |   |   |_ label: folder contains training masks
    |   |
    |   |_ test: folder for testing data
    |   |   |_ image: folder contains testing images
    |   |   |_ label: folder contains testing masks
    |   |
    |   |_ predictions: folder to save the predicted masks
    |
    |_ trained_models
    |   |_ <trained_model_name>.hdf5: trained model file
    |
    |_ experiments
    |   |_ postprocessed_predictions: folder for saving the postprocessed images
    |   |_ raw_analysis: folder for raw experiment analysis
    |       |_ <file>.csv: CSV containing raw metrics
    |
    |_ ams_uset_notebook.ipynb
    |
    |_ postprocessing.py
    |
    |_ metrics.py
    |
    |_ model.py
        

```

# 1. Loading Libraries

Import the necessary libraries and functions.

In [22]:
import os
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt


from model import unet
from data import trainGenerator, testGenerator, saveResult

# 2. Setup Training Variables

In [23]:
# Input resolution (height, width) of the images fed to the network during training
im_height, im_width = 512, 512


# Root of the training data; it holds the image and label subfolders
train_folder = './data/train'
# Destination folder and file name of the saved model checkpoint
models_folder, model_name = './trained_models', 'model1.hdf5'


# Training schedule:
#   epochs     - number of complete passes over the training data
#   batch_size - number of samples processed before each model update
epochs, batch_size = 2, 5

In [24]:
# Function to calculate the number of samples in a directory
def count_samples(data_path, subfolder):
    """Return the number of regular files directly inside ``data_path/subfolder``.

    Args:
        data_path: Path of the parent directory.
        subfolder: Name of the subdirectory to inspect (e.g. 'image' or 'label').
            An empty string makes the function inspect ``data_path`` itself.

    Returns:
        int: How many directory entries are files. Subdirectories are not
        counted and are not descended into.
    """
    folder_path = os.path.join(data_path, subfolder)
    file_count = 0
    for entry_name in os.listdir(folder_path):
        # Only plain files count as samples; nested folders are skipped
        if os.path.isfile(os.path.join(folder_path, entry_name)):
            file_count += 1
    return file_count


In [None]:

# Number of training images found on disk
total_samples = count_samples(train_folder, 'image')

# Steps per epoch = number of batches needed to see every training sample once.
# This is a ceiling division: a partially filled last batch still costs one step.
full_batches, leftover = divmod(total_samples, batch_size)
steps_per_epoch = full_batches + 1 if leftover else full_batches

print(f"Total samples: {total_samples}")
print(f"Batch size: {batch_size}")
print(f"Steps per epoch: {steps_per_epoch}")

# 3. Training Data Generator

In [26]:
# Define data augmentation parameters for the training generator
data_gen_args = dict(
    rotation_range=0.2,         # Randomly rotate images within a range of ±20%
    width_shift_range=0.05,     # Randomly shift the image width by ±5% of the total width
    height_shift_range=0.05,    # Randomly shift the image height by ±5% of the total height
    shear_range=0.05,           # Apply random shearing transformations within ±5% range
    zoom_range=0.05,            # Randomly zoom in/out on images within ±5% range
    horizontal_flip=True,       # Randomly flip images horizontally
    fill_mode='nearest'         # Fill any gaps in transformed images using nearest neighbor pixels
)

# Create a data generator for training
myGene = trainGenerator(
    batch_size=batch_size,          # Number of images per batch (defined as `batch_size`)
    train_path=train_folder,        # Path to the folder containing training data (set to `train_folder`)
    image_folder='image',           # Subfolder name containing training images ('image')
    mask_folder='label',            # Subfolder name containing corresponding labels/masks ('label')
    aug_dict=data_gen_args,         # Dictionary of data augmentation parameters (`data_gen_args`)
    save_to_dir=None,               # Augmented images will not be saved to any directory
    target_size=(im_height, im_width)  # Resize all input images and masks to the specified dimensions (`(im_height, im_width)`)
)



### View one of the images produced by the generator

Notice how the data augmentation has been applied to the image and its mask.

In [None]:
# Pull a single (images, masks) batch from the training generator
image_batch, mask_batch = next(myGene)

# Show the first image of the batch next to its mask
plt.figure(figsize=(10, 5))
for slot, (batch, caption) in enumerate(((image_batch, "Image"), (mask_batch, "Mask")), start=1):
    plt.subplot(1, 2, slot)
    # squeeze() drops size-1 dimensions so imshow receives a 2-D array
    plt.imshow(batch[0].squeeze(), cmap="gray")
    plt.title(caption)
    plt.axis("off")
plt.show()

# 4. Define and train the model

In [None]:
# Build the U-Net for inputs of shape (im_height, im_width, 1);
# the trailing 1 is the single channel of a grayscale image
model = unet(input_size=(im_height, im_width, 1))

# Make sure the checkpoint folder exists so the first save cannot fail on a missing directory
os.makedirs(models_folder, exist_ok=True)

# Checkpoint callback: writes the model to disk whenever the monitored value improves
model_checkpoint = ModelCheckpoint(
    f'{models_folder}/{model_name}',    # Path where the model will be saved
    monitor='loss',                     # Track the training loss
    verbose=1,                          # Print a message when a checkpoint is written
    save_best_only=True                 # Keep only the checkpoint with the best (lowest) loss
)

# Train from the generator.
# Model.fit accepts generators directly; Model.fit_generator is deprecated and has
# been removed from recent Keras releases, so it is no longer used here.
history = model.fit(
    myGene,                             # The training data generator
    steps_per_epoch=steps_per_epoch,    # Number of batches drawn per epoch
    epochs=epochs,                      # Number of epochs for training
    callbacks=[model_checkpoint]        # Save the best model during training
)


In [None]:
# Visualize the training history
plt.figure(figsize=(12, 4))

# Plot training loss
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()

# If accuracy metric is available
if 'accuracy' in history.history:
    plt.subplot(1, 2, 2)
    plt.plot(history.history['accuracy'], label='Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training Accuracy')
    plt.legend()

plt.show()

# 5. Testing the model

### Define the testing variables

In [30]:
# Input resolution (height, width) used at test time.
# NOTE: this rebinds the same im_height / im_width names that the training cells read,
# so re-running a training cell after this one would use 1024 instead of 512.
im_height, im_width = 1024, 1024

# Folder that holds the test images
test_folder = './data/test/image'

# Folder that receives the predicted masks
predictions_folder = "./data/predictions/"

### Load the trained model

In [None]:
# Rebuild the network at the test-time input size, then restore the trained weights from disk
model = unet(input_size=(im_height, im_width, 1))
model_path = f'{models_folder}/{model_name}'
model.load_weights(model_path)

### Setup Testing Data Generator

In [None]:
# Count the files directly inside the test image folder.
# The empty subfolder name makes count_samples inspect test_folder itself
# (os.path.join(test_folder, '') only appends a trailing separator).
total_test_samples = count_samples(test_folder, '')

# Report how many test files were found
print(f"Total samples: {total_test_samples}")

In [33]:
# Generator used for inference (passed to model.predict in the prediction cell);
# images are requested at the test-time size (im_height, im_width)
testGene = testGenerator(test_folder, target_size=(im_height, im_width))

# Second, independent pass over the test folder, fully materialised into a list.
# Only file_names from this pass is used later (passed to saveResult).
# NOTE(review): the target size is hard-coded to (512, 512) here while the inference
# generator above uses (im_height, im_width) - confirm this difference is intended.
test_data = list(testGenerator(test_folder, target_size=(512, 512)))
# Split the collected items into two parallel tuples.
# Presumably each generator item is an (image, file name) pair - verify against testGenerator.
test_images, file_names = zip(*test_data)

### Run the prediction on the testing data

In [None]:
# Run inference over the test generator, drawing one step per counted test file.
# NOTE(review): this assumes each generator step yields exactly one test image - confirm against testGenerator.
results = model.predict(testGene, steps=total_test_samples, verbose=1)

In [None]:
# Show the first two predicted masks side by side
plt.figure(figsize=(10, 5))
for position in (1, 2):
    plt.subplot(1, 2, position)
    # squeeze() drops size-1 dimensions so imshow receives a 2-D array
    plt.imshow(results[position - 1, :, :, :].squeeze(), cmap="gray")
    plt.title(f"Mask {position}")
    plt.axis("off")
plt.show()

In [36]:
# Hand the predictions to saveResult together with the output folder and the
# file names collected from the test generator.
# Presumably one mask file is written per prediction - verify against saveResult in data.py.
saveResult(predictions_folder, results, file_names)

## 5.1 Post Processing

In [37]:
from postprocessing import gaus_otsu_thresh

In [38]:
# Folder where the post-processed masks will be saved.
# Keep the trailing slash: the metrics cells build file paths from this value by plain string concatenation.
postprocessed_masks_folder = './experiments/postprocessed_predictions/'

In [None]:
# Apply Gaussian blur and Otsu thresholding on the predicted masks
# gaus_otsu_thresh: Function to perform Gaussian blur followed by Otsu thresholding
# predictions_folder: Path to the folder containing predicted masks (input folder)
# postprocessed_masks_folder: Path to save the post-processed masks (output folder)
# The returned value is indexed like a sequence of images by the preview cell that follows
# (processed_images[0], processed_images[1]); its exact type is defined in postprocessing.py.
processed_images = gaus_otsu_thresh(predictions_folder, postprocessed_masks_folder)

In [None]:
# Show the first two post-processed masks side by side
plt.figure(figsize=(10, 5))
for position in (1, 2):
    plt.subplot(1, 2, position)
    plt.imshow(processed_images[position - 1], cmap="gray")
    plt.title(f"Mask {position}")
    plt.axis("off")
plt.show()

## 5.2 Calculate the evaluation metrics

In [41]:
import numpy as np
import pandas as pd
from metrics import find_metrics, tpr, fpr, f_score, iou_score, sensitivity, specificity, accuracy 

In [42]:
# Output location of the metrics CSV
save_to = './experiments/raw_analysis' # change output folder as needed
# Number of test images; not referenced by the other cells visible in this notebook
no_images = total_test_samples
csv_file_name = 'metrics.csv'

# Input locations for the analysis.
# Both paths are joined to file names by plain string concatenation in the analysis loop,
# so each one must end with a path separator.
label_mono_path = postprocessed_masks_folder # folder for the predictions that we need to analyze
label_truth_path = './data/test/label/' # ground truth folder

In [43]:

# Per-image metric accumulators, filled by the analysis loop.
# The generator expression creates a distinct empty list for every name.
(image_list, TPR, FPR, F_SCORE, IOU, IOU_SCORE, SENSITIVITY,
 SPECIFICITY, ACCURACY, TP, TN, FP, FN) = ([] for _ in range(13))


In [None]:
# Running counter of analyzed images (used for progress reporting by the analysis loop)
i = 0

# os.listdir returns entries in arbitrary order. Sort both listings so that the
# n-th prediction lines up with the n-th ground-truth file when the two lists are zipped.
predictions = sorted(os.listdir(label_mono_path))
print('working on {}'.format(label_mono_path))
print(predictions)
labels = sorted(os.listdir(label_truth_path))
print(labels)

# zip() silently stops at the shorter list, so surface a count mismatch early
if len(predictions) != len(labels):
    print('WARNING: found {} predictions but {} labels; unmatched files will be skipped'.format(
        len(predictions), len(labels)))


In [None]:
# Compute the evaluation metrics for every (prediction, ground truth) pair.
# NOTE: the accumulator lists are created in an earlier cell; re-run that cell before
# re-running this one, otherwise the results of both runs pile up in the same lists.
for i, (prediction, label) in enumerate(zip(predictions, labels), start=1):
    # print(f'analyzing {prediction}') # optional

    # Both folder paths end with a separator, so plain concatenation yields the file path
    truth_file = label_truth_path + label
    predicted_file = label_mono_path + prediction

    total_bg_pxl_truth, total_obj_pxl_truth, tp, tn, fp, fn = find_metrics(truth_file, predicted_file)

    # Record the metrics of this image exactly once
    # (the block used to be duplicated, which produced two rows per image)
    image_list.append(prediction)
    TP.append(tp)
    TN.append(tn)
    FP.append(fp)
    FN.append(fn)
    TPR.append(tpr(tp, total_obj_pxl_truth))
    FPR.append(fpr(fp, total_bg_pxl_truth))
    F_SCORE.append(f_score(tp, fp, fn))
    IOU_SCORE.append(iou_score(truth_file, predicted_file))
    SENSITIVITY.append(sensitivity(tp, fn))
    SPECIFICITY.append(specificity(fp, tn))
    ACCURACY.append(accuracy(tp, fp, tn, fn))

    # Progress report every 10 images
    if i % 10 == 0:
        print('analyzed ', i, ' images')

# One row per analyzed image
df = pd.DataFrame({
    'Image': image_list,
    'TPR': TPR,
    'FPR': FPR,
    'F-Score': F_SCORE,
    'IOU Score': IOU_SCORE,
    'Sensitivity': SENSITIVITY,
    'Specificity': SPECIFICITY,
    'Accuracy': ACCURACY,
    'TP': TP,
    'TN': TN,
    'FP': FP,
    'FN': FN
})


# Dataset-level averages
print('average TPR: ', np.average(TPR))
print('average FPR: ', np.average(FPR))
print('average F-Score: ', np.average(F_SCORE))
# IOU is never filled by this loop; averaging an empty list yields nan plus a warning,
# so only report it when something populated it
if IOU:
    print('average IOU: ', np.average(IOU))
print('average sensitivity: ', np.average(SENSITIVITY))
print('average specificity: ', np.average(SPECIFICITY))
print('average accuracy: ', np.average(ACCURACY))
print('average IOU score: ', np.average(IOU_SCORE))

# Save to CSV (create the output folder first so the write cannot fail on a missing directory)
os.makedirs(save_to, exist_ok=True)
df.to_csv(os.path.join(save_to, csv_file_name), index=False)
print(f"Metrics saved to {save_to}/{csv_file_name}")

In [None]:
df.head()