# Semantic Segmentation of Tomatoes with U-Net

## Table of Contents

1. [Introduction](#1--introduction)
2. [Setup and Imports](#2--setup-and-imports)
3. [Data Loading and Preprocessing](#3--data-loading-and-preprocessing)
4. [Data Augmentation](#4--data-augmentation)
5. [U-Net Model Definition](#5--u-net-model-definition)
6. [Training](#6--training)
7. [Evaluation](#7--evaluation)
8. [Visualization](#8--visualization)

## 1- Introduction
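This notebook trains a U-Net model to segment tomatoes in images. Semantic segmentation assigns a class label to every pixel of an image rather than a single label to the whole image. Here there are two classes, tomato and background, so the model outputs one probability per pixel. U-Net was originally designed for biomedical image segmentation. Its contracting path captures context, its expanding path restores spatial resolution, and skip connections carry fine spatial detail from the first path to the second. Combined with data augmentation, this design makes it a common choice when only a modest number of annotated images is available.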
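As a concrete illustration of per-pixel classification, the label for an H×W image is an H×W mask. Because this notebook's model ends in a single sigmoid channel trained with binary cross-entropy, each pixel is scored independently and the per-pixel losses are averaged. The NumPy sketch below uses a made-up 3×3 mask. `pixelwise_bce` is a hypothetical helper written for illustration, not Keras's implementation.

```python
import numpy as np

def pixelwise_bce(mask, prob, eps=1e-7):
    """Mean binary cross-entropy over every pixel of a single-channel mask (illustrative helper)."""
    prob = np.clip(prob, eps, 1 - eps)  # avoid log(0)
    per_pixel = -(mask * np.log(prob) + (1 - mask) * np.log(1 - prob))
    return per_pixel.mean()

# Toy 3x3 mask: 1 = tomato pixel, 0 = background (made-up values)
mask = np.array([[0, 1, 1],
                 [0, 1, 1],
                 [0, 0, 0]], dtype=np.float32)
prob = np.full_like(mask, 0.5)  # an uninformative model predicts 0.5 for every pixel
print(pixelwise_bce(mask, prob))  # ≈ 0.693, i.e. ln 2 per pixel

# Thresholding the probabilities gives the hard per-pixel label map
pred_mask = (prob > 0.5).astype(np.uint8)
```

A perfect prediction (`prob` equal to `mask`) drives the loss towards 0, while thresholding at 0.5 turns probabilities into the binary mask that is later compared with the ground truth.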

## 2- Setup and Imports

```python
import os
from glob import glob
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.metrics import MeanIoU
```


## 3- Data Loading and Preprocessing

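```python
# Define paths; '*.png' matches every PNG in each folder
base_dir = 'data/tomato_dataset'
train_images = sorted(glob(os.path.join(base_dir, 'Train', '*.png')))
train_masks = sorted(glob(os.path.join(base_dir, 'Mask', '*.png')))
val_images = sorted(glob(os.path.join(base_dir, 'Test', '*.png')))
val_masks = sorted(glob(os.path.join(base_dir, 'Mask2', '*.png')))
test_images = sorted(glob(os.path.join(base_dir, 'Train2', '*.png')))
test_masks = sorted(glob(os.path.join(base_dir, 'Test2', '*.png')))

def load_pngs(paths, mask=False, size=(128, 128)):
    """Decode PNGs and resize to the U-Net input size; images -> [0, 1], masks -> {0, 1}."""
    arrs = [tf.image.resize(tf.io.decode_png(tf.io.read_file(p), channels=1 if mask else 3),
                            size, method='nearest' if mask else 'bilinear').numpy() for p in paths]
    arr = np.stack(arrs).astype('float32')
    return (arr > 0).astype('float32') if mask else arr / 255.0  # assumes mask background is 0

X_train, y_train = load_pngs(train_images), load_pngs(train_masks, mask=True)
X_val, y_val = load_pngs(val_images), load_pngs(val_masks, mask=True)
X_test, y_test = load_pngs(test_images), load_pngs(test_masks, mask=True)
```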
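The image and mask lists above are paired purely by sorted position, so a single missing or extra file shifts every later pair out of alignment without raising an error. A minimal safeguard is to pair files by name and fail loudly on mismatches. This assumes each mask file has the same name (stem) as its image, which is an assumption about this dataset rather than something the notebook states. `pair_by_stem` is a hypothetical helper.

```python
import os
from glob import glob

def pair_by_stem(image_dir, mask_dir, pattern='*.png'):
    """Pair image and mask files whose names (without extension) match.

    Hypothetical helper: assumes every mask reuses its image's file stem.
    Returns (image_paths, mask_paths) in matching order and raises
    ValueError if any file has no partner in the other folder.
    """
    def by_stem(folder):
        return {os.path.splitext(os.path.basename(p))[0]: p
                for p in glob(os.path.join(folder, pattern))}

    images, masks = by_stem(image_dir), by_stem(mask_dir)
    unpaired = sorted(images.keys() ^ masks.keys())  # stems present on only one side
    if unpaired:
        raise ValueError(f'unpaired files: {unpaired}')
    stems = sorted(images)
    return [images[s] for s in stems], [masks[s] for s in stems]
```

If the naming assumption holds, `pair_by_stem(os.path.join(base_dir, 'Train'), os.path.join(base_dir, 'Mask'))` returns the training image and mask paths in matching order.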


## 4- Data Augmentation

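```python
# Identical geometric augmentation settings for images and masks
data_gen_args = dict(rotation_range=20, width_shift_range=0.1, height_shift_range=0.1,
                     shear_range=0.1, zoom_range=0.1, horizontal_flip=True, fill_mode='nearest')
image_datagen = ImageDataGenerator(**data_gen_args)
# interpolation_order=0 (nearest neighbour) keeps augmented mask pixels at exactly 0 or 1
mask_datagen = ImageDataGenerator(**data_gen_args, interpolation_order=0)
# No .fit() call is needed: fit() only computes the statistics used by featurewise_center,
# featurewise_std_normalization and zca_whitening, none of which are enabled here.
seed = 42  # same seed for both flows so they draw the same shuffle order and transforms
train_generator = (
    pair for pair in zip(image_datagen.flow(X_train, batch_size=16, seed=seed),
                         mask_datagen.flow(y_train, batch_size=16, seed=seed))
)  # yields (image_batch, mask_batch) tuples indefinitely; model.fit relies on steps_per_epoch
```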
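Passing the same seed to both `flow()` calls is meant to make the two iterators draw identical shuffles and random transform parameters, so every rotation, shift and flip applied to an image is also applied to its mask. The NumPy sketch below shows the same principle with only a random horizontal flip and an integer shift. It is a simplified stand-in, not how `ImageDataGenerator` is implemented, and `paired_augment` is a hypothetical name.

```python
import numpy as np

def paired_augment(image, mask, rng, max_shift=2):
    """Apply one randomly drawn flip + shift identically to an image and its mask.

    image: (H, W, C) array; mask: (H, W, 1) array. Simplified illustration:
    only horizontal flips and integer translations are supported.
    """
    flip = rng.random() < 0.5                                # drawn once, used for both arrays
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    out = []
    for arr in (image, mask):
        if flip:
            arr = arr[:, ::-1]                               # flip along the width axis
        # np.roll wraps pixels around the border; ImageDataGenerator fills instead
        arr = np.roll(arr, shift=(dy, dx), axis=(0, 1))
        out.append(arr)
    return out[0], out[1]

rng = np.random.default_rng(42)
mask = np.zeros((8, 8, 1), dtype=np.float32)
mask[2:5, 3:6] = 1.0
image = np.repeat(mask, 3, axis=-1)  # toy RGB image whose pixels equal the mask
aug_img, aug_mask = paired_augment(image, mask, rng)
print(np.array_equal(aug_img[..., :1], aug_mask))  # True: image and mask stay aligned
```

Drawing the random parameters once and reusing them is the essential design choice; drawing them separately per array would misalign images and masks.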


## 5- U-Net Model Definition

```python
def unet(input_size=(128, 128, 3)):
    """Compact U-Net with two pooling levels (the original U-Net uses four)."""
    inputs = layers.Input(input_size)
    # Contracting path: two 3x3 convs per level, then 2x2 max pooling halves H and W
    c1 = layers.Conv2D(64, 3, activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(64, 3, activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D()(c1)
    c2 = layers.Conv2D(128, 3, activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(128, 3, activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D()(c2)
    # Bottleneck at 1/4 of the input resolution
    c5 = layers.Conv2D(512, 3, activation='relu', padding='same')(p2)
    c5 = layers.Conv2D(512, 3, activation='relu', padding='same')(c5)
    # Expanding path: stride-2 transposed convs double H and W; each result is concatenated
    # with the contracting-path features at the same resolution (skip connection)
    u6 = layers.Conv2DTranspose(128, 2, strides=2, padding='same')(c5)
    u6 = layers.concatenate([u6, c2])
    c6 = layers.Conv2D(128, 3, activation='relu', padding='same')(u6)
    c6 = layers.Conv2D(128, 3, activation='relu', padding='same')(c6)
    u7 = layers.Conv2DTranspose(64, 2, strides=2, padding='same')(c6)
    u7 = layers.concatenate([u7, c1])
    c7 = layers.Conv2D(64, 3, activation='relu', padding='same')(u7)
    c7 = layers.Conv2D(64, 3, activation='relu', padding='same')(c7)
    # 1x1 conv + sigmoid: per-pixel probability that the pixel belongs to a tomato
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(c7)
    return models.Model(inputs=[inputs], outputs=[outputs])

# Instantiate and compile. BinaryIoU thresholds the sigmoid outputs at 0.5 and averages IoU
# over background and tomato. MeanIoU casts predictions to integer class IDs, so raw
# probabilities below 1.0 would all be counted as background.
model = unet()
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=['accuracy',
                       tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5, name='mean_iou')])
model.summary()
```
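Each `MaxPooling2D` halves the height and width and each stride-2 `Conv2DTranspose` doubles them, so the `concatenate` skip connections only line up when both input dimensions are divisible by 2² = 4 in this two-level network. The plain-Python sketch below re-traces the layers of `unet()` by hand (it does not inspect the Keras model). It reports the feature-map shapes and counts parameters as k·k·C_in·C_out weights plus C_out biases per convolution. `trace_unet` and `conv_params` are hypothetical helpers.

```python
def conv_params(k, c_in, c_out):
    """Weights + biases of a k x k Conv2D or Conv2DTranspose layer."""
    return k * k * c_in * c_out + c_out

def trace_unet(h=128, w=128, c=3):
    """Trace (H, W, C) through the two-level U-Net by hand and count its parameters."""
    if h % 4 or w % 4:
        raise ValueError('two 2x2 poolings need H and W divisible by 4')
    shapes, total, skips = {}, 0, []
    # Contracting path: two 3x3 convs per level, then 2x2 max pooling
    for name, f in (('c1', 64), ('c2', 128)):
        total += conv_params(3, c, f) + conv_params(3, f, f)
        c = f
        shapes[name] = (h, w, c)
        skips.append((h, w, c))
        h, w = h // 2, w // 2
    # Bottleneck: two 3x3 convs with 512 filters
    total += conv_params(3, c, 512) + conv_params(3, 512, 512)
    c = 512
    shapes['c5'] = (h, w, c)
    # Expanding path: 2x2 stride-2 transposed conv, concat with skip, two 3x3 convs
    for name, f in (('c6', 128), ('c7', 64)):
        total += conv_params(2, c, f)
        h, w = h * 2, w * 2
        sh, sw, sc = skips.pop()
        if (h, w) != (sh, sw):
            raise ValueError('skip connection shape mismatch')
        total += conv_params(3, f + sc, f) + conv_params(3, f, f)
        c = f
        shapes[name] = (h, w, c)
    total += conv_params(1, c, 1)  # 1x1 sigmoid output layer
    shapes['outputs'] = (h, w, 1)
    return shapes, total

shapes, total = trace_unet()
print(shapes['c5'], shapes['outputs'], total)
```

For the default 128×128×3 input the bottleneck is 32×32×512 and the count comes to 4,058,817 parameters, which should agree with the total printed by `model.summary()` as long as the hand transcription matches the model.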


## 6- Training

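```python
callbacks = [
    # Save the model whenever val_loss (the default monitored quantity) improves
    tf.keras.callbacks.ModelCheckpoint('unet_tomato.h5', save_best_only=True),
    # Stop after 10 epochs without val_loss improvement and restore the best epoch's weights
    tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
]
history = model.fit(
    train_generator,
    validation_data=(X_val, y_val),
    epochs=50,
    steps_per_epoch=len(X_train) // 16,  # one pass over the training set per epoch (batch size 16)
    callbacks=callbacks,
)
```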
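`EarlyStopping(patience=10, restore_best_weights=True)` tracks the best validation loss seen so far and stops once 10 consecutive epochs fail to improve on it. `ModelCheckpoint(save_best_only=True)` writes a file on those same improvements. A simplified pure-Python sketch of that bookkeeping is shown below. It ignores `min_delta` and the callbacks' other options, uses an invented loss curve, and `early_stopping_epoch` is a hypothetical helper.

```python
def early_stopping_epoch(val_losses, patience=10):
    """Return (stop_epoch, best_epoch) for a simplified EarlyStopping on val_loss.

    Epochs are 0-indexed. stop_epoch is the last epoch that runs; best_epoch is the
    epoch whose weights restore_best_weights would bring back.
    """
    best, best_epoch, wait = float('inf'), 0, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:             # improvement: a save_best_only checkpoint is written here
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:    # `patience` epochs in a row without improvement
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch  # ran every epoch

# Invented validation-loss curve: improves for 5 epochs, then plateaus
losses = [0.60, 0.45, 0.38, 0.33, 0.31] + [0.32] * 20
print(early_stopping_epoch(losses, patience=10))  # (14, 4)
```

With this curve training would stop after epoch 14, and the weights from epoch 4 (the lowest validation loss) are the ones both callbacks keep.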


## 7- Evaluation

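```python
# Load the best checkpoint written by ModelCheckpoint during training
model.load_weights('unet_tomato.h5')
# Evaluate on the held-out test set; return_dict=True pairs each metric name with its value
results = model.evaluate(X_test, y_test, return_dict=True)
print(results)
```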
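Pixel accuracy can look high simply because most pixels are background, which is why an overlap metric is reported as well. For a thresholded prediction P and ground truth G, IoU = |P ∩ G| / |P ∪ G| and the Dice coefficient = 2|P ∩ G| / (|P| + |G|). The NumPy sketch below computes both per class and averages IoU over background and tomato, pooling all pixels the way a streaming confusion-matrix metric does. As in Keras's IoU metrics, a class absent from both masks is left out of the mean. `iou_and_dice` is a hypothetical helper and assumes 0/1 ground-truth masks.

```python
import numpy as np

def iou_and_dice(y_true, y_prob, threshold=0.5):
    """Per-class IoU and Dice for binary masks, plus mean IoU over both classes.

    y_true: array of 0/1 labels; y_prob: predicted probabilities of the same shape.
    All pixels are pooled, so passing a whole test set gives dataset-level scores.
    """
    pred = np.asarray(y_prob) > threshold
    true = np.asarray(y_true) > 0.5
    scores = {}
    for cls, (p, t) in (('background', (~pred, ~true)), ('tomato', (pred, true))):
        inter = np.logical_and(p, t).sum()
        union = np.logical_or(p, t).sum()
        size_sum = p.sum() + t.sum()
        # A class absent from both masks gets NaN and is excluded from the mean
        scores[cls] = {'iou': inter / union if union else np.nan,
                       'dice': 2 * inter / size_sum if size_sum else np.nan}
    scores['mean_iou'] = np.nanmean([scores['background']['iou'], scores['tomato']['iou']])
    return scores

y_true = np.array([[0, 1, 1, 0]], dtype=np.float32)
y_prob = np.array([[0.1, 0.9, 0.4, 0.7]], dtype=np.float32)
print(iou_and_dice(y_true, y_prob))  # tomato IoU 1/3, Dice 0.5
```

On the real data this would be called as `iou_and_dice(y_test, model.predict(X_test))`. For a single class the two scores are related by Dice = 2·IoU / (1 + IoU).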


## 8- Visualization

```python
# Predict one test sample and threshold the sigmoid probabilities at 0.5
idx = 0  # change index for different samples
pred = model.predict(X_test[idx:idx + 1])[0, :, :, 0] > 0.5
panels = [('Image', X_test[idx], None),
          ('Ground Truth', y_test[idx, :, :, 0], 'gray'),
          ('Prediction', pred, 'gray')]
plt.figure(figsize=(12, 4))
for i, (title, img, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, i)
    plt.title(title)
    plt.imshow(img, cmap=cmap)
    plt.axis('off')
plt.show()
```
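Side-by-side panels make small misses hard to spot. One option is a single colour-coded error map: true positives in green, false positives in red, false negatives in blue and true negatives in black. The NumPy sketch below returns an RGB array that `plt.imshow` can display directly. The colour scheme and the `error_overlay` name are illustrative choices, not part of the original notebook.

```python
import numpy as np

def error_overlay(y_true, y_pred):
    """RGB map of a binary prediction: green = TP, red = FP, blue = FN, black = TN.

    y_true, y_pred: (H, W) arrays of 0/1 values or booleans.
    Returns a float32 (H, W, 3) array with values in [0, 1].
    """
    t = np.asarray(y_true).astype(bool)
    p = np.asarray(y_pred).astype(bool)
    rgb = np.zeros(t.shape + (3,), dtype=np.float32)
    rgb[p & t] = (0.0, 1.0, 0.0)    # tomato pixels correctly found
    rgb[p & ~t] = (1.0, 0.0, 0.0)   # background predicted as tomato
    rgb[~p & t] = (0.0, 0.0, 1.0)   # tomato pixels that were missed
    return rgb

overlay = error_overlay([[1, 1], [0, 0]], [[1, 0], [1, 0]])
print(overlay[0, 0], overlay[0, 1], overlay[1, 0], overlay[1, 1])
```

With the variables from the visualization cell, `plt.imshow(error_overlay(y_test[idx, :, :, 0], pred))` would display the map for the selected sample.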
