<a href="https://colab.research.google.com/github/MarMarhoun/freelance_work/blob/main/side_projects/NLP_projs/eda_streamlit/segmentation_streamlit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical image segmentation with a deep learning model, TensorFlow, and Streamlit

In [None]:
# First, install necessary libraries:

!pip install tensorflow streamlit segmentation-models

In [None]:
%%writefile train.py
# Training script (train.py); the %%writefile magic above saves this cell to disk

import os
import pathlib

# segmentation_models must be told to use tf.keras before it is imported
os.environ['SM_FRAMEWORK'] = 'tf.keras'

import numpy as np
import tensorflow as tf
import segmentation_models as sm
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

# Set up data paths
data_dir = pathlib.Path('data/ACDC_2017_Data')

# Define the model: U-Net with a ResNet34 encoder pretrained on ImageNet
model = sm.Unet(
    'resnet34',
    encoder_weights='imagenet',
    classes=1,
    activation='sigmoid'
)

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=sm.losses.bce_jaccard_loss,
    metrics=[iou_score]
)

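# Load and preprocess the data.
# Assumption: the ACDC slices have been exported as 2D image files, with inputs in
# data_dir/train/images and same-named binary masks in data_dir/train/masks.
def load_split(subdir, subset, color_mode, interpolation):
    return tf.keras.preprocessing.image_dataset_from_directory(
        str(data_dir / 'train' / subdir),
        labels=None,              # segmentation targets are masks, not class labels
        color_mode=color_mode,
        validation_split=0.2,
        subset=subset,
        shuffle=False,            # keep images and masks in the same order
        image_size=(256, 256),
        interpolation=interpolation,
        batch_size=None
    )

def make_dataset(subset, shuffle=False):
    images = load_split('images', subset, 'rgb', 'bilinear')
    masks = load_split('masks', subset, 'grayscale', 'nearest')
    # Scale images to [0, 1] (as the app does) and binarize the masks
    ds = tf.data.Dataset.zip((images, masks)).map(
        lambda x, y: (x / 255.0, tf.cast(y > 127, tf.float32))
    )
    if shuffle:
        ds = ds.shuffle(512, seed=123)
    return ds.batch(16)

train_ds = make_dataset('training', shuffle=True)
val_ds = make_dataset('validation')

# Augment the data
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.2),
])

# Apply the same random transform to each image and its mask by stacking them
def augment(images, masks):
    stacked = data_augmentation(tf.concat([images, masks], axis=-1), training=True)
    return stacked[..., :3], tf.round(stacked[..., 3:])

train_ds = train_ds.map(augment)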

# Train the model
epochs = 50
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

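# Save the model (create the output directory first, since model.save does not)
pathlib.Path('model').mkdir(parents=True, exist_ok=True)
model.save('model/unet_acdc.h5')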

In [None]:
%%writefile app.py
# Streamlit app (app.py); the %%writefile magic above saves this cell to disk

import streamlit as st
import tensorflow as tf
from PIL import Image
import numpy as np

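# Load the model; compile=False skips the custom loss and metric, which inference does not need
model = tf.keras.models.load_model('model/unet_acdc.h5', compile=False)

# Define a function for prediction
def predict(image):
    # The encoder expects three channels, so convert grayscale or RGBA input to RGB
    image = image.convert('RGB').resize((256, 256))
    batch = np.expand_dims(np.asarray(image, dtype=np.float32) / 255.0, axis=0)
    prediction = model.predict(batch)
    # Drop the batch and channel axes to get a (256, 256) probability map
    return prediction[0, ..., 0]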

# Set up the Streamlit app
st.set_page_config(page_title="Medical Image Segmentation", page_icon=":guardsman:", layout="wide")
st.title("Medical Image Segmentation using Deep Learning")

uploaded_file = st.file_uploader("Upload an image", type="jpg")

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    prediction = predict(image)
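    # Show the predicted mask as a grayscale image (values are foreground probabilities)
    st.image(np.squeeze(prediction), caption='Prediction', clamp=True, use_column_width=True)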

In [None]:
# To train the model, run:

!python train.py

In [None]:
# After training, run the Streamlit app:

!streamlit run app.py

This example demonstrates how to train a deep learning model for medical image segmentation using TensorFlow and the U-Net architecture, and how to deploy it using Streamlit. Note that you may need to adjust the code to fit your specific use case and dataset.
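
The training script reports `iou_score` as its metric. As a rough illustration of what that number measures, here is a minimal NumPy sketch of intersection over union for a binary mask. It is not the library's implementation: the function name `iou_score_np`, the 0.5 threshold, and the `smooth` term are assumptions chosen for this example.

```python
import numpy as np

def iou_score_np(y_true, y_pred, threshold=0.5, smooth=1e-5):
    """Intersection over union between a binary mask and thresholded probabilities."""
    y_true = np.asarray(y_true, dtype=np.float32)
    # Turn predicted probabilities into a hard 0/1 mask
    y_pred = (np.asarray(y_pred, dtype=np.float32) > threshold).astype(np.float32)
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    # The smooth term avoids division by zero when both masks are empty
    return float((intersection + smooth) / (union + smooth))

# Toy 2x4 example: 3 pixels overlap, 5 pixels are in either mask
truth = np.array([[1, 1, 0, 0],
                  [1, 1, 0, 0]])
probs = np.array([[0.9, 0.8, 0.7, 0.1],
                  [0.2, 0.9, 0.1, 0.1]])
print(iou_score_np(truth, probs))
```

A score of 1 means the predicted mask matches the ground truth exactly, and a score near 0 means they barely overlap.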

## Advanced features

To extend the previous code, we will:

- Implement a more sophisticated data augmentation pipeline using Albumentations.
- Use a learning rate scheduler to adjust the learning rate during training.
- Add a callback to save the best model based on validation loss.
- Implement a sliding window approach for inference on high-resolution images.

First, install the necessary libraries:

In [None]:
!pip install albumentations

In [None]:
import os
import pathlib

# segmentation_models must be told to use tf.keras before it is imported
os.environ['SM_FRAMEWORK'] = 'tf.keras'

import albumentations as A
import numpy as np
import tensorflow as tf
import segmentation_models as sm
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score
# LearningRateScheduler is a Keras callback, not part of tensorflow_addons
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint

# Set up data paths
data_dir = pathlib.Path('data/ACDC_2017_Data')

# Define the model: U-Net with a ResNet34 encoder pretrained on ImageNet
model = sm.Unet(
    'resnet34',
    encoder_weights='imagenet',
    classes=1,
    activation='sigmoid'
)

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=sm.losses.bce_jaccard_loss,
    metrics=[iou_score]
)

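# Define data augmentation
data_augmentation = A.Compose([
    A.Rotate(limit=15),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.5),
    # CoarseDropout replaces the removed Cutout transform. Its hole count and size
    # arguments are named differently across Albumentations releases, so the
    # defaults are used here; set them to match your installed version.
    A.CoarseDropout(p=0.5),
])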

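# Load and preprocess the data.
# Assumption: inputs live in data_dir/train/images and same-named binary masks
# in data_dir/train/masks, exported as 2D image files.
def load_split(subdir, subset, color_mode, interpolation):
    return tf.keras.preprocessing.image_dataset_from_directory(
        str(data_dir / 'train' / subdir),
        labels=None,              # segmentation targets are masks, not class labels
        color_mode=color_mode,
        validation_split=0.2,
        subset=subset,
        shuffle=False,            # keep images and masks in the same order
        image_size=(256, 256),
        interpolation=interpolation,
        batch_size=None
    )

def make_pairs(subset):
    images = load_split('images', subset, 'rgb', 'bilinear')
    masks = load_split('masks', subset, 'grayscale', 'nearest')
    return tf.data.Dataset.zip((images, masks))

# Albumentations works on NumPy arrays, so wrap it with tf.numpy_function
def _augment_np(image, mask):
    out = data_augmentation(image=image.astype(np.uint8), mask=mask[..., 0].astype(np.uint8))
    return out['image'].astype(np.float32), out['mask'][..., np.newaxis].astype(np.float32)

def augment(image, mask):
    image, mask = tf.numpy_function(_augment_np, [image, mask], [tf.float32, tf.float32])
    image.set_shape((256, 256, 3))
    mask.set_shape((256, 256, 1))
    return image, mask

def finalize(image, mask):
    # Scale images to [0, 1] (as the app does) and binarize the masks
    return image / 255.0, tf.cast(mask > 127, tf.float32)

# Apply data augmentation to the training set only
train_ds = make_pairs('training').shuffle(512, seed=123).map(augment).map(finalize).batch(16)
val_ds = make_pairs('validation').map(finalize).batch(16)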

# Define a learning rate scheduler: start at 1e-4 and decay tenfold every 10 epochs
lr_scheduler = LearningRateScheduler(lambda epoch, lr: 1e-4 * 0.1 ** (epoch / 10))

# Define a callback to save the best model based on validation loss.
# The full model is saved (not just the weights) so that the app can load it with load_model.
checkpoint_callback = ModelCheckpoint(
    filepath='model/unet_acdc_best.h5',
    save_weights_only=False,
    monitor='val_loss',
    mode='min',
    save_best_only=True
)

# Train the model
epochs = 100
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[lr_scheduler, checkpoint_callback]
)

# Save the final model
model.save('model/unet_acdc_final.h5')
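
To see what the learning rate schedule in the training cell does over time, the following sketch evaluates the same formula, `1e-4 * 0.1**(epoch / 10)`, at a few epochs. The function name `exponential_decay` and its parameter names are introduced here for illustration only.

```python
def exponential_decay(epoch, initial_lr=1e-4, decay_factor=0.1, decay_epochs=10):
    """Learning rate that shrinks by decay_factor every decay_epochs epochs."""
    return initial_lr * decay_factor ** (epoch / decay_epochs)

# Print the rate at a few points in a 100-epoch run
for epoch in (0, 5, 10, 20, 50, 99):
    print(f"epoch {epoch:3d}: lr = {exponential_decay(epoch):.3e}")
```

With these settings the rate falls tenfold every 10 epochs, so it is already around 1e-9 by epoch 50 and keeps shrinking after that. If training stalls early, a slower decay (a larger `decay_epochs` or a `decay_factor` closer to 1) may be worth trying.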

Update the app.py script to include a sliding window approach for inference on high-resolution images:

In [None]:
%%writefile app.py
import streamlit as st
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2

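# Load the best model; compile=False skips the custom loss and metric, which inference does not need
model = tf.keras.models.load_model('model/unet_acdc_best.h5', compile=False)

# Define a function for prediction with a sliding window
def predict_sliding_window(image, window_size, step_size):
    """Average the predictions of overlapping windows over an (H, W, 3) uint8 array."""
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    win_h, win_w = window_size
    step_h, step_w = step_size
    height, width = image.shape[:2]
    # Pad the bottom and right edges so that the windows cover the whole image
    pad_h = win_h - height if height < win_h else (win_h - height) % step_h
    pad_w = win_w - width if width < win_w else (win_w - width) % step_w
    padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
    total = np.zeros(padded.shape[:2], dtype=np.float32)
    count = np.zeros(padded.shape[:2], dtype=np.float32)
    for i in range(0, padded.shape[0] - win_h + 1, step_h):
        for j in range(0, padded.shape[1] - win_w + 1, step_w):
            window = cv2.resize(padded[i:i + win_h, j:j + win_w], (256, 256))
            batch = np.expand_dims(window.astype(np.float32) / 255.0, axis=0)
            segmentation = model.predict(batch, verbose=0)[0, ..., 0]
            # Resize the output back to the window size before accumulating
            total[i:i + win_h, j:j + win_w] += cv2.resize(segmentation, (win_w, win_h))
            count[i:i + win_h, j:j + win_w] += 1
    # Every pixel is covered at least once, so the division is safe
    return (total / count)[:height, :width]

# Define a function for prediction on a single 256x256 input
def predict(image):
    # The encoder expects three channels, so convert grayscale or RGBA input to RGB
    image = image.convert('RGB').resize((256, 256))
    batch = np.expand_dims(np.asarray(image, dtype=np.float32) / 255.0, axis=0)
    prediction = model.predict(batch)
    # Drop the batch and channel axes to get a (256, 256) probability map
    return prediction[0, ..., 0]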

# Set up the Streamlit app
st.set_page_config(page_title="Medical Image Segmentation", page_icon=":guardsman:", layout="wide")
st.title("Medical Image Segmentation using Deep Learning")

uploaded_file = st.file_uploader("Upload an image", type="jpg")

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Sliding window for high-resolution images
    if image.height > 256 or image.width > 256:
        st.write("Processing high-resolution image...")
        window_size = (256, 256)
        step_size = (128, 128)
        # Convert to RGB so the array always has three channels
        prediction = predict_sliding_window(np.array(image.convert('RGB')), window_size, step_size)
    else:
        prediction = predict(image)

    # Display the prediction (squeeze drops any trailing channel axis)
    prediction_image = Image.fromarray((np.squeeze(prediction) * 255).astype(np.uint8))
    st.image(prediction_image, caption='Prediction', use_column_width=True)

The script now handles high-resolution images with a sliding window. If the uploaded image's height or width exceeds 256 pixels, predict_sliding_window runs the model on overlapping 256×256 windows with a stride of 128 pixels and assembles the per-window outputs into a full-size mask. Otherwise, predict runs the model once on the resized image. In both cases the prediction is then displayed with Streamlit.
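
The sliding window idea can be tried without a trained model. The sketch below tiles an image with overlapping windows, averages the per-window outputs, and crops back to the original size. The helper names `sliding_window_average` and `dummy_predict` are hypothetical, and `dummy_predict` is only a stand-in for `model.predict` (it marks bright pixels as foreground), so this shows the tiling and averaging logic rather than real segmentation.

```python
import numpy as np

def sliding_window_average(image, predict_fn, window=(256, 256), step=(128, 128)):
    """Run predict_fn on overlapping windows and average the overlapping outputs."""
    win_h, win_w = window
    step_h, step_w = step
    height, width = image.shape[:2]
    # Pad bottom and right so that the last window ends exactly at the border
    pad_h = win_h - height if height < win_h else (win_h - height) % step_h
    pad_w = win_w - width if width < win_w else (win_w - width) % step_w
    pad_spec = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (image.ndim - 2)
    padded = np.pad(image, pad_spec, mode="reflect")
    total = np.zeros(padded.shape[:2], dtype=np.float32)
    count = np.zeros(padded.shape[:2], dtype=np.float32)
    for top in range(0, padded.shape[0] - win_h + 1, step_h):
        for left in range(0, padded.shape[1] - win_w + 1, step_w):
            tile = padded[top:top + win_h, left:left + win_w]
            total[top:top + win_h, left:left + win_w] += predict_fn(tile)
            count[top:top + win_h, left:left + win_w] += 1
    # Crop the padding away again
    return (total / count)[:height, :width]

def dummy_predict(tile):
    # Hypothetical stand-in for the model: bright pixels count as foreground
    return (tile.mean(axis=-1) > 127).astype(np.float32)

# Synthetic 300x500 image with one bright rectangle
image = np.zeros((300, 500, 3), dtype=np.uint8)
image[100:200, 150:350] = 255
mask = sliding_window_average(image, dummy_predict)
print(mask.shape, mask.min(), mask.max())
```

Averaging overlapping windows, rather than letting later windows overwrite earlier ones, tends to reduce visible seams at window borders, at the cost of running the model on each pixel several times.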

To run the app, execute the following command in the terminal:

In [None]:
streamlit run app.py

This will start a local Streamlit server, and you can access the app by opening a web browser and navigating to http://localhost:8501.