In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Reshape
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split

# Fetch the MNIST train/test splits, then rescale pixel values by dividing by 255.0
# so the images fall in the [0, 1] range expected downstream.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images, test_images = (split / 255.0 for split in (train_images, test_images))

# Function to create circle masks (as in Q1(b))
def create_circle_masks(images):
    """Build a filled minimum-enclosing-circle mask for every image.

    Each image is binarised with Otsu's threshold. The largest external
    contour is taken as the digit, and its minimum enclosing circle is drawn
    filled.

    Args:
        images: iterable of 2-D single-channel images. Float images are
            expected to be scaled to [0, 1]. ``uint8`` images are used as-is.

    Returns:
        np.ndarray of float masks with values in {0.0, 1.0}, one per input
        image (shape ``(0,)`` when ``images`` is empty). If no contour is
        found, that image's mask stays all zeros.
    """
    circle_masks = []
    for img in images:
        img = np.asarray(img)
        if img.dtype == np.uint8:
            # Already in OpenCV's 8-bit format; multiplying by 255 would overflow.
            img_uint8 = img
        else:
            # Round instead of truncating: the /255.0 -> *255 float round trip can
            # land a hair below the original integer, which astype() would drop by
            # one gray level. Clip guards against values slightly outside [0, 1].
            img_uint8 = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)

        # Otsu picks the threshold automatically; the 0 passed here is ignored.
        _, binary_mask = cv2.threshold(img_uint8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # findContours returns (image, contours, hierarchy) in OpenCV 3 and
        # (contours, hierarchy) in OpenCV 4; [-2] selects the contours in both.
        contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        circle_mask = np.zeros_like(binary_mask)
        if contours:
            # The largest external contour is assumed to be the digit.
            largest_contour = max(contours, key=cv2.contourArea)
            (x, y), radius = cv2.minEnclosingCircle(largest_contour)
            # Thickness -1 draws a filled disc with value 255.
            cv2.circle(circle_mask, (int(x), int(y)), int(radius), 255, -1)

        circle_masks.append(circle_mask / 255.0)  # Normalize to {0.0, 1.0}

    return np.array(circle_masks)

# Derive a ground-truth circle mask for every train/test image.
train_circle_masks, test_circle_masks = (
    create_circle_masks(imgs) for imgs in (train_images, test_images)
)

# Append a trailing channel axis so every array is (N, 28, 28, 1) for the CNN.
train_images, test_images, train_circle_masks, test_circle_masks = (
    arr.reshape(-1, 28, 28, 1)
    for arr in (train_images, test_images, train_circle_masks, test_circle_masks)
)

# One-hot targets over the 10 digit classes.
train_labels_one_hot, test_labels_one_hot = (
    to_categorical(lbl, 10) for lbl in (train_labels, test_labels)
)

# Hold out 10% of the training data for validation (fixed seed for reproducibility).
# NOTE: the integer `train_labels` array is not part of this split, so after this
# point it no longer lines up with `train_images`; use `train_labels_one_hot`.
(
    train_images, val_images,
    train_circle_masks, val_circle_masks,
    train_labels_one_hot, val_labels_one_hot,
) = train_test_split(
    train_images,
    train_circle_masks,
    train_labels_one_hot,
    test_size=0.1,
    random_state=42,
)

# Define IoU metric for circlization
def iou_metric(y_true, y_pred, threshold=0.5):
    """Return the batch-mean IoU between target masks and thresholded predictions.

    Args:
        y_true: ground-truth masks with a leading batch axis. Presumably binary
            {0, 1} values (TODO confirm for callers other than this notebook).
        y_pred: predicted probabilities with the same shape as ``y_true``.
        threshold: predictions strictly greater than this count as foreground.

    Returns:
        Scalar tensor: per-sample IoU averaged over the batch. A sample whose
        true and predicted masks are both empty scores 1.0 because of the 1e-7
        smoothing term.
    """
    # Binarise predictions.
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    # Match dtypes: TF does not auto-promote, so float64 targets multiplied by
    # float32 predictions would raise.
    y_true = tf.cast(y_true, tf.float32)

    # Reduce over every non-batch axis so masks of any rank are supported
    # (e.g. (B, H, W) or (B, H, W, C)), not only rank-4 tensors.
    axes = tf.range(1, tf.rank(y_true))
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    union = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes) - intersection

    iou = (intersection + 1e-7) / (union + 1e-7)
    return tf.reduce_mean(iou)

# Build a model for classification with circlization
def build_classification_circlization_model(input_shape, num_classes):
    """Create a two-headed CNN: digit classifier plus circle-mask predictor.

    A shared encoder (three 3x3 conv stages, with 2x2 max-pooling after the
    first two) feeds two heads:

    * ``classification`` - dense softmax over ``num_classes``.
    * ``circlization`` - conv stack, 4x upsampling, then a 1x1 sigmoid conv.
      This undoes the two 2x2 poolings, so the mask matches the input's
      spatial size when H and W are divisible by 4.

    Args:
        input_shape: shape of one input image, e.g. ``(28, 28, 1)``.
        num_classes: number of classification outputs.

    Returns:
        An uncompiled ``Model`` with outputs ``[class_output, circle_output]``.
    """
    inputs = Input(input_shape)

    # Shared encoder; only the first two stages are followed by pooling.
    encoded = inputs
    for stage, n_filters in enumerate((32, 64, 128)):
        encoded = Conv2D(n_filters, 3, activation='relu', padding='same')(encoded)
        if stage < 2:
            encoded = MaxPooling2D(pool_size=(2, 2))(encoded)

    # Classification head.
    head = Dense(128, activation='relu')(Flatten()(encoded))
    head = Dropout(0.5)(head)
    class_output = Dense(num_classes, activation='softmax', name='classification')(head)

    # Segmentation ("circlization") head.
    seg = encoded
    for n_filters in (64, 32, 16):
        seg = Conv2D(n_filters, 3, activation='relu', padding='same')(seg)
    seg = tf.keras.layers.UpSampling2D(size=(4, 4))(seg)
    circle_output = Conv2D(1, 1, activation='sigmoid', name='circlization')(seg)

    return Model(inputs=inputs, outputs=[class_output, circle_output])

# Instantiate the two-headed model for 28x28 grayscale digits and attach one
# loss and one metric to each named output.
model = build_classification_circlization_model(input_shape=(28, 28, 1), num_classes=10)
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=dict(
        classification='categorical_crossentropy',
        circlization='binary_crossentropy',
    ),
    metrics=dict(
        classification='accuracy',
        circlization=iou_metric,
    ),
)

# Print the layer-by-layer architecture.
model.summary()

# Fit both heads jointly; targets are keyed by output-layer name.
history = model.fit(
    x=train_images,
    y={'classification': train_labels_one_hot, 'circlization': train_circle_masks},
    validation_data=(
        val_images,
        {'classification': val_labels_one_hot, 'circlization': val_circle_masks},
    ),
    batch_size=64,
    epochs=15,
    verbose=1,
)

# Evaluate on the held-out test set. return_dict=True keys every value by its
# metric name, so the prints below do not depend on the (Keras-version-specific)
# ordering of the list that evaluate() returns otherwise. The key names match
# those reported during training and stored in history.history.
test_results = model.evaluate(
    test_images,
    {'classification': test_labels_one_hot, 'circlization': test_circle_masks},
    verbose=1,
    return_dict=True,
)

print(f"Test Loss (Total): {test_results['loss']:.4f}")
print(f"Test Classification Loss: {test_results['classification_loss']:.4f}")
print(f"Test Circlization Loss: {test_results['circlization_loss']:.4f}")
print(f"Test Classification Accuracy: {test_results['classification_accuracy']:.4f}")
print(f"Test Circlization IoU: {test_results['circlization_iou_metric']:.4f}")

# Run inference once, then turn class probabilities into label indices and
# mask probabilities into a {0, 1} float32 mask (cut-off 0.5).
class_preds, circle_preds = model.predict(test_images)
class_preds_labels = class_preds.argmax(axis=1)
circle_preds_binary = np.greater(circle_preds, 0.5).astype(np.float32)

# Calculate IoU for each test image
def calculate_iou(y_true, y_pred):
    """Intersection-over-union of two equally sized NumPy masks.

    Both arrays are flattened, so their shapes may differ as long as their
    element counts match. A 1e-7 term smooths the numerator and denominator,
    so two empty masks give an IoU of 1.0.
    """
    eps = 1e-7
    truth = y_true.reshape(-1)
    guess = y_pred.reshape(-1)
    overlap = (truth * guess).sum()
    total = truth.sum() + guess.sum()
    return (overlap + eps) / (total - overlap + eps)

# Per-image IoU on the test set. A sample only earns segmentation credit when
# its digit is classified correctly; misclassified samples score 0.0.
test_ious = [
    calculate_iou(true_mask, pred_mask) if pred_label == true_label else 0.0
    for true_mask, pred_mask, pred_label, true_label in zip(
        test_circle_masks, circle_preds_binary, class_preds_labels, test_labels
    )
]

mean_iou = np.mean(test_ious)
print(f"Mean IoU on test set (with classification constraint): {mean_iou:.4f}")

# Plot training history: one panel per tracked quantity, train vs. validation.
plt.figure(figsize=(15, 5))
for panel, (metric_key, panel_title, y_label, legend_loc) in enumerate(
    (
        ('classification_accuracy', 'Classification Accuracy', 'Accuracy', 'lower right'),
        ('circlization_iou_metric', 'Circlization IoU', 'IoU', 'lower right'),
        ('loss', 'Model Loss', 'Loss', 'upper right'),
    ),
    start=1,
):
    plt.subplot(1, 3, panel)
    plt.plot(history.history[metric_key])
    plt.plot(history.history[f'val_{metric_key}'])
    plt.title(panel_title)
    plt.ylabel(y_label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc=legend_loc)

plt.tight_layout()
plt.show()

# Visualize some predictions
def visualize_predictions(images, masks, pred_masks, true_labels, pred_labels, num_samples=5, ious=None):
    """Plot random samples as rows of (image, ground-truth circle, predicted circle).

    Args:
        images, masks, pred_masks: indexable collections of equal length whose
            items can be reshaped to 28x28.
        true_labels, pred_labels: per-sample class labels.
        num_samples: number of random rows to draw; capped at ``len(images)``.
        ious: optional per-sample IoU values for the predicted-mask titles.
            When omitted, each sample's IoU is computed with
            ``calculate_iou(masks[idx], pred_masks[idx])``. It is set to 0.0
            when the predicted label differs from the true label. This matches
            the classification-gated rule used for ``test_ious``, but applies
            to whatever data is passed in instead of reading a global.
    """
    # Sampling without replacement fails if more rows are requested than exist.
    num_samples = min(num_samples, len(images))
    if num_samples <= 0:
        return
    indices = np.random.choice(len(images), num_samples, replace=False)

    plt.figure(figsize=(15, 5*num_samples))
    for i, idx in enumerate(indices):
        # Original image
        plt.subplot(num_samples, 3, i*3 + 1)
        plt.imshow(images[idx].reshape(28, 28), cmap='gray')
        plt.title(f'Image (True: {true_labels[idx]}, Pred: {pred_labels[idx]})')
        plt.axis('off')

        # Ground truth mask
        plt.subplot(num_samples, 3, i*3 + 2)
        plt.imshow(masks[idx].reshape(28, 28), cmap='gray')
        plt.title('Ground Truth Circle')
        plt.axis('off')

        # Predicted mask
        if ious is not None:
            iou_val = ious[idx]
        elif true_labels[idx] == pred_labels[idx]:
            iou_val = calculate_iou(masks[idx], pred_masks[idx])
        else:
            iou_val = 0.0
        plt.subplot(num_samples, 3, i*3 + 3)
        plt.imshow(pred_masks[idx].reshape(28, 28), cmap='gray')
        plt.title(f'Predicted Circle (IoU: {iou_val:.4f})')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Show a handful of random test samples alongside their true/predicted circles.
visualize_predictions(
    images=test_images,
    masks=test_circle_masks,
    pred_masks=circle_preds_binary,
    true_labels=test_labels,
    pred_labels=class_preds_labels,
)


Epoch 1/15
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m98s[0m 103ms/step - circlization_iou_metric: 0.7657 - circlization_loss: 0.2581 - classification_accuracy: 0.8439 - classification_loss: 0.4816 - loss: 0.7397 - val_circlization_iou_metric: 0.8020 - val_circlization_loss: 0.1988 - val_classification_accuracy: 0.9852 - val_classification_loss: 0.0501 - val_loss: 0.2489
Epoch 2/15
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m228s[0m 270ms/step - circlization_iou_metric: 0.7999 - circlization_loss: 0.1974 - classification_accuracy: 0.9775 - classification_loss: 0.0785 - loss: 0.2759 - val_circlization_iou_metric: 0.8009 - val_circlization_loss: 0.1924 - val_classification_accuracy: 0.9893 - val_classification_loss: 0.0380 - val_loss: 0.2305
Epoch 3/15
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 112ms/step - circlization_iou_metric: 0.8027 - circlization_loss: 0.1920 - classification_accuracy: 0.9849 - classification_loss: 0.0502 - 

KeyboardInterrupt: 