In [1]:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Conv2D, Reshape, Dense, Flatten, Input
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import numpy as np
import os

# a. Dataset Selection and Preprocessing
# CIFAR-10 is a classification dataset with no box annotations, so it is not
# ideal for object detection; it is used here only to demonstrate the pipeline.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Seed NumPy's global RNG so the synthetic boxes below are reproducible
np.random.seed(42)


def _centered_boxes(count):
    """Return a (count, 4) array of synthetic [x1, y1, x2, y2] boxes.

    Every box is a square centred on (0.5, 0.5) whose side length is drawn
    uniformly from [0.6, 0.8), expressed as a fraction of the image size.
    """
    half_sizes = np.random.uniform(0.6, 0.8, size=count) / 2
    lower = 0.5 - half_sizes
    upper = 0.5 + half_sizes
    return np.stack([lower, lower, upper, upper], axis=1)


# Training boxes are drawn before test boxes so the random stream is consumed
# in a fixed order.
x_train_boxes = _centered_boxes(x_train.shape[0])
x_test_boxes = _centered_boxes(x_test.shape[0])

# Preprocessing function
def preprocess_data(images, labels, boxes):
    """Resize and normalise images; pass labels and boxes through unchanged.

    Args:
        images: image tensor to resize to 224x224 (the input size the
            MobileNetV2-based model in this notebook is built with).
        labels: class labels, returned as-is.
        boxes: bounding boxes, returned as-is.

    Returns:
        A ``(images, labels, boxes)`` tuple where ``images`` is float32 with
        values divided by 255.
    """
    target_size = (224, 224)
    resized = tf.image.resize(images, target_size)
    # Scale pixel values into [0, 1]
    scaled = tf.cast(resized, tf.float32) / 255.0
    # A plain tuple (rather than a dict) keeps downstream unpacking simple
    return scaled, labels, boxes

# Build the batched input pipelines: slice the arrays per sample, apply
# preprocess_data to each sample, then group into batches of 32.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train, x_train_boxes))
    .map(preprocess_data)
    .batch(32)
)

test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test, x_test_boxes))
    .map(preprocess_data)
    .batch(32)
)

# b. Model Architecture (Simplified object detection model)
def build_object_detector(num_classes=10):  # CIFAR-10 has 10 classes
    """Build a two-headed model on top of a frozen, ImageNet-pretrained MobileNetV2.

    The model takes 224x224 RGB images whose pixel values are in [0, 1] (the
    range produced by ``preprocess_data``) and returns two outputs:

    * ``class_output``: softmax probabilities over ``num_classes`` classes.
    * ``box_output``: 4 unbounded regression values for the bounding box.

    Args:
        num_classes: number of units in the classification head.

    Returns:
        An uncompiled ``tf.keras.Model`` with outputs
        ``[class_output, box_output]``.
    """
    # Input layer
    input_layer = Input(shape=(224, 224, 3))

    # The ImageNet MobileNetV2 weights were trained on pixels scaled to
    # [-1, 1]. Inputs arrive in [0, 1], so remap them here; otherwise the
    # frozen backbone sees inputs outside the range it was trained on.
    x = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(input_layer)

    # Base network - use MobileNetV2 with pretrained weights
    base_model = MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
    # Freeze the base model layers
    base_model.trainable = False

    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode regardless of how the outer model is called.
    x = base_model(x, training=False)

    # Collapse the spatial feature map into one feature vector per image
    x = tf.keras.layers.GlobalAveragePooling2D()(x)

    # Classification head
    class_output = Dense(num_classes, activation='softmax', name='class_output')(x)

    # Box regression head (linear activation)
    box_output = Dense(4, name='box_output')(x)

    # Create model
    model = Model(inputs=input_layer, outputs=[class_output, box_output])
    return model

# Create the model
model = build_object_detector()
# summary() prints the layer table itself and returns None, so wrapping it in
# print() would add a stray "None" line to the output.
model.summary()

# Compile the model: one loss and one metric per named output head
model.compile(
    optimizer='adam',
    loss={
        'class_output': 'sparse_categorical_crossentropy',
        'box_output': 'mse'
    },
    metrics={
        'class_output': 'accuracy',
        'box_output': 'mse'
    }
)

# Define a custom data generator to match the model's expected input/output format
def data_generator(dataset, batch_size=32):
    """Yield ``(images, targets)`` batches from ``dataset`` indefinitely.

    Each element of ``dataset`` must unpack into
    ``(images_batch, labels_batch, boxes_batch)``. The targets are returned as
    a dict keyed by the model's output names so Keras can route each target to
    the matching head.

    The generator restarts from the beginning of ``dataset`` whenever a pass
    finishes, so it can serve several epochs driven by ``steps_per_epoch``.
    A single-pass generator would run dry after the first pass.

    Args:
        dataset: a re-iterable collection of 3-tuples (e.g. a batched
            ``tf.data.Dataset``).
        batch_size: unused; batching is expected to be done by ``dataset``.
            Kept for backward compatibility.

    Yields:
        ``(images_batch, {'class_output': labels_batch, 'box_output': boxes_batch})``
    """
    while True:
        yielded_any = False
        for images_batch, labels_batch, boxes_batch in dataset:
            yielded_any = True
            yield images_batch, {'class_output': labels_batch, 'box_output': boxes_batch}
        # An empty dataset would otherwise spin forever without yielding
        if not yielded_any:
            return

# Create generators
# fit() keeps drawing from the same generator object on every epoch, so the
# underlying datasets are repeated to guarantee batches never run out.
train_gen = data_generator(train_dataset.repeat())
test_gen = data_generator(test_dataset.repeat())

# c. Training and Evaluation
# Number of batches in one full pass over each dataset. Rounding up includes
# the final, smaller batch and keeps every epoch aligned with one full pass.
train_steps = int(np.ceil(len(x_train) / 32))
test_steps = int(np.ceil(len(x_test) / 32))

# Train the model
history = model.fit(
    train_gen,
    steps_per_epoch=train_steps,
    epochs=3,  # Reduced epochs for faster execution
    validation_data=test_gen,
    validation_steps=test_steps
)

# Evaluation Metrics
# A fresh generator starts at the first test batch, so evaluation covers
# exactly one pass over the test set. return_dict=True gives metrics by name
# instead of relying on their position in a list.
test_metrics = model.evaluate(
    data_generator(test_dataset),
    steps=test_steps,
    return_dict=True
)

# Friendly labels for the metric names Keras is expected to report; any other
# metric falls back to its raw name.
metric_labels = {
    'loss': 'Test Loss',
    'class_output_loss': 'Class Output Loss',
    'box_output_loss': 'Box Output Loss',
    'class_output_accuracy': 'Class Accuracy',
    'box_output_mse': 'Box MSE',
}
for metric_name, metric_value in test_metrics.items():
    label = metric_labels.get(metric_name, metric_name)
    if metric_name.endswith('accuracy'):
        print(f"{label}: {metric_value*100:.2f}%")
    else:
        print(f"{label}: {metric_value:.4f}")

# Visualization Function
def visualize_predictions(image, box, class_id, class_names, ax=None):
    """Draw an image with one bounding box and its class label.

    Args:
        image: array-like image of shape (height, width[, channels]).
        box: ``(x1, y1, x2, y2)`` in normalised coordinates, i.e. fractions of
            the image width/height.
        class_id: integer class index, or an array-like whose first element is
            the class index.
        class_names: sequence mapping class index to a display name.
        ax: optional matplotlib Axes to draw on. When omitted, a new 8x8
            figure is created and shown, which is the original behaviour.
            When given, nothing is shown so the caller can compose several
            panels into one figure.
    """
    owns_figure = ax is None
    if owns_figure:
        plt.figure(figsize=(8, 8))
        ax = plt.gca()

    ax.imshow(image)

    # Scale the normalised box by the actual image size instead of assuming 224
    img_height, img_width = np.asarray(image).shape[:2]

    # Extract box coordinates
    x1, y1, x2, y2 = box
    width = x2 - x1
    height = y2 - y1

    # Draw the bounding box
    rect = plt.Rectangle((x1*img_width, y1*img_height),
                        width*img_width, height*img_height,
                        fill=False, color='red', linewidth=2)
    ax.add_patch(rect)

    # Accept both a bare index and a one-element array, and convert to a plain
    # int so it can index any sequence type.
    class_index = int(np.ravel(class_id)[0])

    # Add class label just above the top-left corner of the box
    ax.text(x1*img_width, y1*img_height-5,
            f'{class_names[class_index]}',
            color='red', fontsize=12,
            backgroundcolor='white')

    ax.axis('off')
    if owns_figure:
        plt.show()

# CIFAR-10 class names, indexed by integer label
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Make a prediction on a test image
for images_batch, labels_batch, boxes_batch in test_dataset.take(1):
    # Get the first image from the batch
    test_image = images_batch[0]
    test_label = labels_batch[0]
    test_box = boxes_batch[0]

    # Make prediction (the model expects a leading batch dimension)
    class_pred, box_pred = model.predict(tf.expand_dims(test_image, 0))

    # Get the predicted class
    predicted_class = np.argmax(class_pred, axis=1)

    # Convert the label tensor to a plain int before indexing the name list
    true_class_index = int(test_label.numpy()[0])

    print(f"True class: {class_names[true_class_index]}")
    print(f"Predicted class: {class_names[predicted_class[0]]}")
    print(f"True box: {test_box.numpy()}")
    print(f"Predicted box: {box_pred[0]}")

    # visualize_predictions draws on a figure of its own, so wrapping the
    # calls in plt.figure/plt.subplot would only leave empty titled axes.
    # Each view is labelled with a printed heading instead.

    # Visualize the ground truth
    print('Ground Truth')
    visualize_predictions(test_image.numpy(), test_box.numpy(), test_label.numpy(), class_names)

    # Visualize the prediction
    print('Prediction')
    visualize_predictions(test_image.numpy(), box_pred[0], predicted_class, class_names)

KeyboardInterrupt: 