In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models, datasets, utils
import numpy as np

In [3]:
class CapsuleLayer(layers.Layer):
    def __init__(self, num_capsules, dim_capsules, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsules = dim_capsules
        self.routings = routings

    def build(self, input_shape):
        # Dimensions: (num_capsules, input-num-capsules, dim_capsules, input_dims)
        self.kernel = self.add_weight(name='capsule_kernel',
                                      shape=(self.num_capsules, input_shape[-2],
                                             self.dim_capsules, input_shape[-1]),
                                      initializer='glorot_uniform',
                                      trainable=True)

    def call(self, inputs, training=None):
        # Compute dot product between inputs and weights
        u_hat = tf.keras.backend.batch_dot(inputs, self.kernel, [-1, -1])
        return u_hat  # Quick example without routing



In [4]:

def create_capsnet(input_shape):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(256, 9, activation='relu')(inputs)
    x = layers.Conv2D(256, 9, activation='relu')(x)
    capsule = CapsuleLayer(num_capsules=10, dim_capsules=16)(x)  # Example params
    outputs = capsule
    model = models.Model(inputs=inputs, outputs=outputs, name='capsnet_model')
    return model

capsnet = create_capsnet((28, 28, 1))
capsnet.summary()





In [5]:
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

# Normalize images to the range of [0,1]
train_images = train_images.astype('float32') / 255.
test_images = test_images.astype('float32') / 255.

# Reshape images to include channels
train_images = np.expand_dims(train_images, -1)
test_images = np.expand_dims(test_images, -1)

# Convert labels to one-hot encoding
train_labels = utils.to_categorical(train_labels, 10)
test_labels = utils.to_categorical(test_labels, 10)

In [6]:
capsnet.compile(optimizer='adam', loss='mse', metrics=['accuracy'])

# Train the network
history = capsnet.fit(train_images, train_labels,
                      batch_size=128,
                      epochs=10,
                      validation_split=0.2)

Epoch 1/10


ValueError: Exception encountered when calling CapsuleLayer.call().

[1mCannot do batch_dot on inputs with different batch sizes. Received inputs with tf.shapes (128, 12, 12, 256) and (10, 12, 16, 256).[0m

Arguments received by CapsuleLayer.call():
  • inputs=tf.Tensor(shape=(128, 12, 12, 256), dtype=float32)
  • training=True

In [7]:
test_loss, test_accuracy = capsnet.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_accuracy}")

ValueError: Dimensions must be equal, but are 10 and 16 for '{{node compile_loss/mse/sub}} = Sub[T=DT_FLOAT](data_1, capsnet_model_1/capsule_layer_1/Reshape_2)' with input shapes: [?,10], [10,12,12,12,16].

In [8]:
import matplotlib.pyplot as plt

def plot_examples(images, labels, predictions):
    fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(15, 3))
    for i in range(5):
        axes[i].imshow(images[i].squeeze(), cmap='gray')
        true_label = np.argmax(labels[i])
        predicted_label = np.argmax(predictions[i])
        axes[i].set_title(f'True: {true_label}, Pred: {predicted_label}')
        axes[i].axis('off')

predictions = capsnet.predict(test_images[:5])
plot_examples(test_images[:5], test_labels[:5], predictions)

ValueError: Exception encountered when calling CapsuleLayer.call().

[1mCannot do batch_dot on inputs with different batch sizes. Received inputs with tf.shapes (5, 12, 12, 256) and (10, 12, 16, 256).[0m

Arguments received by CapsuleLayer.call():
  • inputs=tf.Tensor(shape=(5, 12, 12, 256), dtype=float32)
  • training=False