In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import cv2

# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the images to [0, 1] range
train_images = train_images / 255.0
test_images = test_images / 255.0

# Create Otsu masks for training and testing data
def create_otsu_masks(images):
    masks = []
    for img in images:
        # Scale back to 0-255 for Otsu thresholding
        img_255 = (img * 255).astype(np.uint8)
        _, binary_mask = cv2.threshold(img_255, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        masks.append(binary_mask)
    return np.array(masks)

# Generate ground truth masks using Otsu thresholding
train_masks = create_otsu_masks(train_images)
test_masks = create_otsu_masks(test_images)

# Reshape for CNN input (add channel dimension)
train_images = train_images.reshape(-1, 28, 28, 1)
test_images = test_images.reshape(-1, 28, 28, 1)
train_masks = train_masks.reshape(-1, 28, 28, 1)
test_masks = test_masks.reshape(-1, 28, 28, 1)

# Split training data into training and validation sets
train_images, val_images, train_masks, val_masks = train_test_split(
    train_images, train_masks, test_size=0.1, random_state=42
)

# Define IoU metric
def iou_metric(y_true, y_pred):
    # Threshold predictions
    y_pred = tf.cast(y_pred > 0.5, tf.float32)
    
    # Calculate intersection and union
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3]) - intersection
    
    # Calculate IoU
    iou = tf.reduce_mean((intersection + 1e-7) / (union + 1e-7))
    return iou

# Build U-Net model for segmentation
def build_unet_model(input_shape):
    inputs = Input(input_shape)
    
    # Encoder (downsampling path)
    conv1 = Conv2D(32, 3, activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, 3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    conv2 = Conv2D(64, 3, activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, 3, activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    # Bridge
    conv3 = Conv2D(128, 3, activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, 3, activation='relu', padding='same')(conv3)
    
    # Decoder (upsampling path)
    up1 = UpSampling2D(size=(2, 2))(conv3)
    up1 = Conv2D(64, 2, activation='relu', padding='same')(up1)
    merge1 = concatenate([conv2, up1], axis=3)
    conv4 = Conv2D(64, 3, activation='relu', padding='same')(merge1)
    conv4 = Conv2D(64, 3, activation='relu', padding='same')(conv4)
    
    up2 = UpSampling2D(size=(2, 2))(conv4)
    up2 = Conv2D(32, 2, activation='relu', padding='same')(up2)
    merge2 = concatenate([conv1, up2], axis=3)
    conv5 = Conv2D(32, 3, activation='relu', padding='same')(merge2)
    conv5 = Conv2D(32, 3, activation='relu', padding='same')(conv5)
    
    # Output layer
    outputs = Conv2D(1, 1, activation='sigmoid')(conv5)
    
    model = Model(inputs=inputs, outputs=outputs)
    return model

# Create and compile the model
model = build_unet_model((28, 28, 1))
model.compile(optimizer=Adam(learning_rate=0.001), 
              loss='binary_crossentropy', 
              metrics=[iou_metric])

# Display model summary
model.summary()

# Train the model
history = model.fit(
    train_images, train_masks,
    validation_data=(val_images, val_masks),
    batch_size=64,
    epochs=10,
    verbose=1
)

# Evaluate on test set
test_loss, test_iou = model.evaluate(test_images, test_masks, verbose=1)
print(f"Test IoU: {test_iou:.4f}")

# Plot training history
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')

plt.subplot(1, 2, 2)
plt.plot(history.history['iou_metric'])
plt.plot(history.history['val_iou_metric'])
plt.title('Model IoU')
plt.ylabel('IoU')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='lower right')
plt.tight_layout()
plt.show()

# Visualize some predictions
def visualize_predictions(images, masks, predictions, num_samples=5):
    indices = np.random.choice(range(len(images)), num_samples, replace=False)
    
    plt.figure(figsize=(15, 5*num_samples))
    for i, idx in enumerate(indices):
        # Original image
        plt.subplot(num_samples, 3, i*3 + 1)
        plt.imshow(images[idx].reshape(28, 28), cmap='gray')
        plt.title('Original Image')
        plt.axis('off')
        
        # Ground truth mask
        plt.subplot(num_samples, 3, i*3 + 2)
        plt.imshow(masks[idx].reshape(28, 28), cmap='gray')
        plt.title('Ground Truth Mask')
        plt.axis('off')
        
        # Predicted mask
        plt.subplot(num_samples, 3, i*3 + 3)
        plt.imshow(predictions[idx].reshape(28, 28), cmap='gray')
        plt.title('Predicted Mask')
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

# Make predictions on test set
predictions = model.predict(test_images)
predictions_binary = (predictions > 0.5).astype(np.float32)

# Calculate IoU for each test image
def calculate_iou(y_true, y_pred):
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    intersection = np.sum(y_true_f * y_pred_f)
    union = np.sum(y_true_f) + np.sum(y_pred_f) - intersection
    return (intersection + 1e-7) / (union + 1e-7)

test_ious = [calculate_iou(test_masks[i], predictions_binary[i]) for i in range(len(test_masks))]
mean_iou = np.mean(test_ious)
print(f"Mean IoU on test set: {mean_iou:.4f}")

# Visualize some predictions
visualize_predictions(test_images, test_masks, predictions_binary)


Epoch 1/10


TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type uint8 of argument 'x'.