In [None]:
import datetime
import os

import cv2
import keras
import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt

In [None]:
# Seed every RNG (Python, NumPy, TensorFlow) so weight init and the loader's shuffled
# file order, and with it the take/skip train/val/test split, repeat across kernel restarts.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Labelled images from ../data; class labels are inferred from sub-directory names
# (the loader's default), and the result is a batched tf.data.Dataset.
data = tf.keras.utils.image_dataset_from_directory('../data', seed=SEED)

In [None]:
# NumPy iterator over the batched dataset; used only to pull one batch for the preview below.
data_iterator = data.as_numpy_iterator()


In [None]:
# One (images, labels) batch: batch[0] holds the images, batch[1] their integer labels.
batch = next(data_iterator)

In [None]:
# Preview up to four images from the batch, each titled with its integer class label.
preview_images, preview_labels = batch[0], batch[1]
fig, ax = plt.subplots(ncols=4, figsize=(20, 20))
for axis, image, label in zip(ax, preview_images[:4], preview_labels):
    axis.imshow(image.astype(int))
    axis.title.set_text(label)

In [None]:
# Scale pixel values from [0, 255] to [0, 1].
# NOTE: this rebinds `data`, so running this cell twice divides by 255 twice —
# re-run the loading cell first if you need to re-execute it.
data = data.map(lambda x, y: (x / 255, y))

# Sanity-check the scaling on one batch (expect roughly 0.0 .. 1.0) instead of
# dumping the whole image array as cell output.
scaled_images = next(data.as_numpy_iterator())[0]
scaled_images.min(), scaled_images.max()

In [None]:
# Split sizes are counted in batches. Train and validation get their 70% / 20% share
# (rounded down); test takes the remainder so no batch is silently lost to rounding.
n_batches = len(data)
train_size = int(n_batches * .7)
val_size = int(n_batches * .2)
test_size = n_batches - train_size - val_size
assert min(train_size, val_size, test_size) > 0, (
    f"{n_batches} batches is too few for a non-empty train/val/test split"
)

In [None]:
# Carve consecutive runs of batches out of `data`:
#   [train_size batches][val_size batches][test_size batches]
# NOTE(review): image_dataset_from_directory shuffles by default; if that shuffle is
# re-applied on each iteration, these take/skip splits are not fixed across epochs and
# samples near the boundaries can move between train/val/test. Confirm, or split at
# load time with `validation_split` + `subset`.
train = data.take(train_size)
held_out = data.skip(train_size)
val = held_out.take(val_size)
test = held_out.skip(val_size).take(test_size)

In [None]:
# Number of batches in the dataset (the split sizes above are fractions of this).
len(data)

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout, Input

# Small binary CNN: three conv/pool stages, then a dense head with one sigmoid unit.
# The 256x256x3 input matches the image size the loader produces for this notebook
# and the resize used at inference time.
model = Sequential([
    Input(shape=(256, 256, 3)),
    Conv2D(16, (3, 3), 1, activation='relu'),
    MaxPooling2D(),
    Conv2D(32, (3, 3), 1, activation='relu'),
    MaxPooling2D(),
    Conv2D(16, (3, 3), 1, activation='relu'),
    MaxPooling2D(),
    Flatten(),
    Dense(256, activation='relu'),
    Dense(1, activation='sigmoid'),
])

model.compile('adam', loss=tf.losses.BinaryCrossentropy(), metrics=['accuracy'])
# The model is saved only after training: saving here would overwrite 'model.keras'
# with untrained weights, which is what anything loading that file would get if
# training were interrupted.

In [None]:
# Layer-by-layer overview: output shapes and parameter counts.
model.summary()

In [None]:
# One timestamped sub-directory per run, so re-running the notebook does not mix
# event files from different trainings in TensorBoard (`--logdir logs` still finds them all).
logdir = os.path.join('logs', datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))

In [None]:
# Log training/validation metrics under `logdir` for inspection in TensorBoard.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

In [None]:
# Train for 20 epochs, evaluating on the validation split after each one;
# `hist.history` holds the per-epoch metrics plotted below.
hist = model.fit(
    train,
    validation_data=val,
    epochs=20,
    callbacks=[tensorboard_callback],
)

In [None]:
# Persist the trained model to 'model.keras' (presumably the file later passed to
# mask_applier as model_path — confirm).
model.save("model.keras")

In [None]:
# Training vs. validation loss per epoch (epochs numbered from 1).
epochs_run = range(1, len(hist.history['loss']) + 1)
fig, ax = plt.subplots()
ax.plot(epochs_run, hist.history['loss'], color='teal', label='loss')
ax.plot(epochs_run, hist.history['val_loss'], color='orange', label='val_loss')
fig.suptitle('Loss', fontsize=20)
ax.set(xlabel='epoch', ylabel='binary cross-entropy')
ax.legend(loc="upper left")
plt.show()

In [None]:
# Training vs. validation accuracy per epoch (epochs numbered from 1).
epochs_run = range(1, len(hist.history['accuracy']) + 1)
fig, ax = plt.subplots()
ax.plot(epochs_run, hist.history['accuracy'], color='teal', label='accuracy')
ax.plot(epochs_run, hist.history['val_accuracy'], color='orange', label='val_accuracy')
fig.suptitle('Accuracy', fontsize=20)
ax.set(xlabel='epoch', ylabel='accuracy')
ax.legend(loc="upper left")
plt.show()

In [None]:
def mask_applier(model_path: str, csv_path: str, cells_path: str, images_path: str,
                 output_path: str = './detected_cells', threshold: float = 0.5) -> None:
    """
    Draw bounding boxes on the full images around every SAM mask crop that the
    classifier labels as a cell rather than noise.

    Parameters:
    - model_path:  str   : path to the Keras model to use
    - csv_path:    str   : path to the CSV with the mask data (output of the SAM mask
                           generation); must contain the columns image, cell_id, x, y, w, h
    - cells_path:  str   : directory with the individual cell crops, named
                           '<image>_<cell_id>.<ext>' (output of the SAM mask generation)
    - images_path: str   : directory with the full images, named '<image>.png'
    - output_path: str   : directory for the annotated images; created if it does not exist
    - threshold:   float : model outputs >= threshold are treated as noise (class 1),
                           lower outputs as a cell (class 0)

    Outputs:
    For every full image that has at least one crop in cells_path, writes
    '<image>.png' to output_path with a rectangle around each crop predicted to be a cell.
    Entries in cells_path that are not '<image>_<cell_id>.<ext>' files are skipped with a warning.

    Raises:
    - FileNotFoundError: a full image or a crop cannot be read
    - KeyError: a crop predicted as a cell has no matching row in the CSV
    """
    model = keras.models.load_model(model_path)
    boxes = _load_boxes(csv_path)
    crops_by_image = _group_crops_by_image(cells_path)
    os.makedirs(output_path, exist_ok=True)

    # Each full image is loaded once, annotated with all of its crops, and saved once,
    # so every image is written and none of its boxes are lost.
    for idx, og_image in enumerate(sorted(crops_by_image)):
        print(f"Image: {idx + 1}/{len(crops_by_image)}", end='\r')

        full_image_path = os.path.join(images_path, og_image)
        full_image = cv2.imread(full_image_path)
        if full_image is None:
            raise FileNotFoundError(f"Could not read full image: {full_image_path}")

        for cell_id, crop_path in crops_by_image[og_image]:
            if _is_noise(model, crop_path, threshold):
                continue
            box = boxes.get((og_image, cell_id))
            if box is None:
                raise KeyError(f"No mask row for image {og_image!r}, cell_id {cell_id}")
            # CSV values may come back as numpy scalars; hand cv2 plain ints.
            x, y, w, h = (int(v) for v in box)
            cv2.rectangle(full_image, (x, y), (x + w, y + h), 255, 10)

        cv2.imwrite(os.path.join(output_path, og_image), full_image)
    print()  # end the '\r' progress line


def _parse_crop_name(file_name: str):
    """Return ('<image>.png', cell_id) for a crop named '<image>_<cell_id>.<ext>', or None if it does not match."""
    stem = file_name.split('.')[0]
    # rpartition (not split) so image names that themselves contain '_' still parse.
    img_nbr, sep, cell_id = stem.rpartition('_')
    if not sep:
        return None
    try:
        return f'{img_nbr}.png', int(cell_id)
    except ValueError:
        return None


def _group_crops_by_image(cells_path: str) -> dict:
    """Map each full-image name to its [(cell_id, crop_path), ...] in cells_path, skipping unrecognised entries."""
    import warnings

    groups = {}
    for entry in sorted(os.listdir(cells_path)):
        file_name = os.fsdecode(entry)
        # os.path.join works whether or not cells_path ends with a separator.
        crop_path = os.path.join(cells_path, file_name)
        parsed = _parse_crop_name(file_name)
        if parsed is None or not os.path.isfile(crop_path):
            warnings.warn(f"Skipping {crop_path!r}: not a '<image>_<cell_id>.<ext>' file")
            continue
        og_image, cell_id = parsed
        groups.setdefault(og_image, []).append((cell_id, crop_path))
    return groups


def _load_boxes(csv_path: str) -> dict:
    """Index the mask CSV as {(image, cell_id): (x, y, w, h)}, keeping the first row for duplicate keys."""
    df = pd.read_csv(csv_path)
    boxes = {}
    rows = df[['image', 'cell_id', 'x', 'y', 'w', 'h']].itertuples(index=False, name=None)
    for image, cell_id, x, y, w, h in rows:
        # One pass over the CSV instead of a full DataFrame filter per crop.
        boxes.setdefault((image, cell_id), (x, y, w, h))
    return boxes


def _is_noise(model, crop_path: str, threshold: float) -> bool:
    """Classify one crop; True when the model output is >= threshold (1 = noise, 0 = cell)."""
    crop = cv2.imread(crop_path)
    if crop is None:
        raise FileNotFoundError(f"Could not read cell image: {crop_path}")
    # cv2 decodes to BGR, whereas the training images in this notebook were loaded with
    # tf.keras.utils.image_dataset_from_directory, which yields RGB.
    crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
    # Same preprocessing as training: resize to 256x256, then scale to [0, 1].
    resized = tf.image.resize(crop, (256, 256))
    # Calling the model directly avoids predict()'s per-call overhead for a single sample.
    score = model(tf.expand_dims(resized / 255, 0), training=False)
    return float(score[0][0]) >= threshold