In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist, cifar10, fashion_mnist
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Definition of the squash function
def squash(vectors, axis=-1):
    """Apply the capsule "squash" non-linearity along ``axis``.

    Each vector ``v`` is rescaled by ``|v|^2 / (1 + |v|^2)`` and divided by its
    norm, so its direction is preserved while its length ends up below 1.
    ``tf.keras.backend.epsilon()`` is added under the square root so that an
    all-zero vector does not cause a division by zero.

    Args:
        vectors: Tensor whose ``axis`` dimension holds the capsule vectors.
        axis: Dimension along which the squared norm is computed.

    Returns:
        Tensor with the same shape as ``vectors``.
    """
    norm_sq = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
    safe_norm = tf.sqrt(norm_sq + tf.keras.backend.epsilon())
    # Evaluated left to right as (gain * vectors) / safe_norm.
    return (norm_sq / (1 + norm_sq)) * vectors / safe_norm

# Definition of the Capsule layer
class CapsuleLayer(layers.Layer):
    """Fully-connected capsule layer with dynamic routing-by-agreement.

    Every input capsule ``i`` predicts every output capsule ``j`` through its
    own transformation matrix ``W[i, j]``. The predictions are combined with
    coupling coefficients that are refined for ``num_routing`` iterations,
    based on the dot-product agreement between each prediction and the
    resulting (squashed) output.

    Args:
        num_capsules: Number of output capsules.
        dim_capsules: Dimensionality of each output capsule vector.
        num_routing: Number of routing iterations; must be >= 1.

    Input shape:
        ``(batch, input_num_capsules, input_dim_capsules)``.

    Output shape:
        ``(batch, num_capsules, dim_capsules)``.
    """

    def __init__(self, num_capsules, dim_capsules, num_routing=3, **kwargs):
        super().__init__(**kwargs)
        if num_routing < 1:
            # With zero iterations ``call`` would never assign ``outputs``.
            raise ValueError(f"num_routing must be >= 1, got {num_routing}")
        self.num_capsules = num_capsules
        self.dim_capsules = dim_capsules
        self.num_routing = num_routing

    def build(self, input_shape):
        self.input_num_capsules = input_shape[1]
        self.input_dim_capsules = input_shape[2]
        # One (input_dim x output_dim) matrix per (input capsule, output capsule) pair.
        self.W = self.add_weight(shape=[self.input_num_capsules, self.num_capsules,
                                        self.input_dim_capsules, self.dim_capsules],
                                 initializer='glorot_uniform',
                                 trainable=True)

    def call(self, inputs):
        # Prediction vectors: u_hat[b, i, j, :] = inputs[b, i, :] @ W[i, j].
        # einsum avoids materialising the (batch, in, out, in_dim, out_dim)
        # intermediate that tile + broadcast-multiply would create. The result's
        # last axis is always dim_capsules, so no squeeze is needed (squeezing
        # when dim_capsules == 1 would break the routing broadcasts below).
        inputs_hat = tf.einsum('bid,ijdk->bijk', inputs, self.W)

        # Routing logits start at zero, i.e. uniform coupling over output
        # capsules. zeros_like follows the dynamic shape and dtype of u_hat.
        b = tf.zeros_like(inputs_hat[..., 0])

        for i in range(self.num_routing):
            # Each input capsule distributes its vote across the output capsules.
            c = tf.expand_dims(tf.nn.softmax(b, axis=2), -1)
            outputs = squash(tf.reduce_sum(c * inputs_hat, axis=1))
            if i < self.num_routing - 1:
                # Raise the logits where prediction and output agree.
                agreement = tf.reduce_sum(inputs_hat * tf.expand_dims(outputs, 1), axis=-1)
                b += agreement

        return outputs

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsules, self.dim_capsules)

    def get_config(self):
        # Required so models containing this layer can be saved and reloaded.
        config = super().get_config()
        config.update({
            'num_capsules': self.num_capsules,
            'dim_capsules': self.dim_capsules,
            'num_routing': self.num_routing,
        })
        return config

# Reconstruction network
def create_reconstruction_network(digit_caps, y, input_shape):
    """Build the decoder that reconstructs the input image from the capsules.

    ``y`` is broadcast over the last axis of ``digit_caps`` and multiplied in
    as a per-capsule mask. The masked capsules are flattened and decoded by
    Dense layers of width 512 and 1024 (ReLU) and prod(input_shape) (sigmoid).
    The result is then reshaped to ``input_shape``.

    Args:
        digit_caps: Capsule tensor, presumably (batch, n_class, dim) — TODO confirm.
        y: Mask tensor, presumably (batch, n_class) — it gains a trailing axis
            before the multiplication.
        input_shape: Shape of a single input image, e.g. (28, 28, 1).

    Returns:
        Keras tensor of shape (batch, *input_shape).
    """
    def apply_mask(tensors):
        capsules, mask = tensors
        return capsules * tf.expand_dims(mask, -1)

    hidden = layers.Lambda(apply_mask)([digit_caps, y])
    hidden = layers.Flatten()(hidden)
    for width in (512, 1024):
        hidden = layers.Dense(width, activation='relu')(hidden)
    pixels = layers.Dense(np.prod(input_shape), activation='sigmoid')(hidden)
    return layers.Reshape(target_shape=input_shape)(pixels)

def CapsNet(input_shape, n_class, num_routing, *, learning_rate=0.001, recon_weight=0.0005):
    """Build and compile a CapsNet with a label-masked reconstruction decoder.

    The model takes ``[image, one_hot_label]`` and returns
    ``[out_caps, reconstruction]``. ``out_caps`` is the length of each class
    capsule, used as the class score. ``reconstruction`` is the decoded image.

    Args:
        input_shape: Shape of a single input image, e.g. (28, 28, 1).
        n_class: Number of classes (= number of class capsules).
        num_routing: Routing iterations in the class-capsule layer.
        learning_rate: Adam learning rate.
        recon_weight: Loss weight of the reconstruction output.
            NOTE: the CapsNet paper's 0.0005 factor scales a reconstruction
            error *summed* over pixels. Keras' ``binary_crossentropy`` here is
            *averaged* over pixels, so 0.0005 makes the regulariser roughly
            prod(input_shape) times weaker than in the paper. Tune if the
            reconstruction term is meant to matter.

    Returns:
        A compiled Keras ``Model``.
    """
    x = layers.Input(shape=input_shape)
    y = layers.Input(shape=(n_class,))

    # First convolutional layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu')(x)

    # Second convolutional layer (feeds the primary capsules).
    # NOTE(review): the ReLU here makes every primary-capsule component >= 0
    # before squash; reference CapsNet implementations commonly use a linear
    # conv at this point — confirm this is intended.
    conv2 = layers.Conv2D(filters=256, kernel_size=9, strides=2, padding='valid', activation='relu')(conv1)

    # Primary capsules: every run of 8 consecutive channels at each spatial
    # position becomes one 8-D capsule (256 / 8 = 32 capsules per position).
    primary_caps = layers.Reshape(target_shape=[-1, 8])(conv2)
    primary_caps = layers.Lambda(squash)(primary_caps)

    # Class capsules: one 16-D capsule per class.
    digit_caps = CapsuleLayer(num_capsules=n_class, dim_capsules=16, num_routing=num_routing)(primary_caps)

    # Class scores are capsule lengths. epsilon keeps the sqrt gradient finite
    # if a capsule vector is exactly zero.
    out_caps = layers.Lambda(
        lambda z: tf.sqrt(tf.reduce_sum(tf.square(z), axis=2) + tf.keras.backend.epsilon()),
        name='out_caps')(digit_caps)

    # Reconstruction network, masked by the label input y.
    recon = create_reconstruction_network(digit_caps, y, input_shape)
    # Identity layer only to give the output a stable name for loss/metric dicts.
    recon = layers.Lambda(lambda t: t, name='reconstruction')(recon)

    model = models.Model([x, y], [out_caps, recon])
    model.compile(optimizer=optimizers.Adam(learning_rate=learning_rate),
                  loss={'out_caps': 'mse', 'reconstruction': 'binary_crossentropy'},
                  loss_weights={'out_caps': 1., 'reconstruction': recon_weight},
                  metrics={'out_caps': 'accuracy'})  # Metrics for the classification output only
    return model


# Load datasets
def _prepare_images(x):
    """Scale uint8 images to float32 in [0, 1] and ensure a trailing channel axis."""
    x = x.astype('float32') / 255.
    # Grayscale sets (MNIST, Fashion-MNIST) load as (N, H, W); Conv2D needs (N, H, W, C).
    return x[..., np.newaxis] if x.ndim == 3 else x


def load_data(dataset):
    """Load a Keras benchmark dataset with [0, 1] images and one-hot labels.

    Args:
        dataset: One of ``'mnist'``, ``'fashion_mnist'`` or ``'cifar10'``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. Images are float32 in
        [0, 1] with an explicit channel axis. Labels are one-hot with 10 columns.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.

    Note:
        The first call downloads the data. A ``CERTIFICATE_VERIFY_FAILED``
        error means the local Python install cannot verify the download
        server's TLS certificate. Fix the CA certificate store rather than
        disabling verification.
    """
    loaders = {
        'mnist': mnist.load_data,
        'fashion_mnist': fashion_mnist.load_data,
        'cifar10': cifar10.load_data,
    }
    if dataset not in loaders:
        raise ValueError(f"Dataset not supported: {dataset!r}")

    (x_train, y_train), (x_test, y_test) = loaders[dataset]()

    # All supported datasets have 10 classes. Passing num_classes explicitly
    # keeps the train and test encodings the same width instead of inferring
    # it from each split's max label.
    num_classes = 10
    return ((_prepare_images(x_train), to_categorical(y_train, num_classes)),
            (_prepare_images(x_test), to_categorical(y_test, num_classes)))

# Train and evaluate model
def _plot_confusion_matrix(y_true, y_pred_classes, n_class, dataset):
    """Show a heatmap of the n_class x n_class confusion matrix."""
    # Pin the label set so the matrix always has n_class rows/columns matching
    # the tick labels, even if some class never appears in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred_classes, labels=list(range(n_class)))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=range(n_class), yticklabels=range(n_class))
    plt.title(f'Confusion Matrix for {dataset}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()


def train_and_evaluate(dataset, input_shape, n_class, num_routing, *, epochs=10, batch_size=100):
    """Train a CapsNet on ``dataset``, print test metrics and plot a confusion matrix.

    Args:
        dataset: Dataset name accepted by ``load_data``.
        input_shape: Shape of a single input image.
        n_class: Number of classes.
        num_routing: Routing iterations for the class-capsule layer.
        epochs: Training epochs.
        batch_size: Mini-batch size for training.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    (x_train, y_train), (x_test, y_test) = load_data(dataset)
    model = CapsNet(input_shape=input_shape, n_class=n_class, num_routing=num_routing)

    # Inputs are [image, label] (the label masks the decoder); targets are
    # [label, image] (classification + reconstruction).
    # NOTE: the test split doubles as validation data, so the validation
    # curves are test-set curves.
    history = model.fit([x_train, y_train], [y_train, x_train],
                        batch_size=batch_size, epochs=epochs,
                        validation_data=([x_test, y_test], [y_test, x_test]))

    results = model.evaluate([x_test, y_test], [y_test, x_test])
    print(f"Dataset: {dataset}, Results: {results}")

    # Only the capsule lengths (first output) are used for classification;
    # the reconstruction output is discarded.
    y_pred, _ = model.predict([x_test, y_test])
    _plot_confusion_matrix(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1), n_class, dataset)

    return history

# Compare results across datasets
# One entry per benchmark: (dataset name, per-image input shape, number of classes).
experiments = [
    ('mnist', (28, 28, 1), 10),
    ('fashion_mnist', (28, 28, 1), 10),
    ('cifar10', (32, 32, 3), 10),
]

results = {}
for dataset, input_shape, n_class in experiments:
    print(f"Training on {dataset}...")
    history = train_and_evaluate(dataset, input_shape, n_class, num_routing=3)
    results[dataset] = history

# Plot training results
for dataset, history in results.items():
    # Three side-by-side panels per dataset: loss curves, accuracy curves and
    # a per-epoch bar view of validation accuracy.
    curves = history.history
    val_acc = curves['val_out_caps_accuracy']
    fig, (loss_ax, acc_ax, bar_ax) = plt.subplots(1, 3, figsize=(16, 6))

    loss_ax.plot(curves['loss'], label=f'{dataset} Training Loss')
    loss_ax.plot(curves['val_loss'], label=f'{dataset} Validation Loss')
    loss_ax.set_title(f'{dataset} Loss')
    loss_ax.legend()

    acc_ax.plot(curves['out_caps_accuracy'], label=f'{dataset} Training Accuracy')
    acc_ax.plot(val_acc, label=f'{dataset} Validation Accuracy')
    acc_ax.set_title(f'{dataset} Accuracy')
    acc_ax.legend()

    bar_ax.bar(range(len(val_acc)), val_acc, label=f'{dataset} Validation Accuracy')
    bar_ax.set_title(f'{dataset} Validation Accuracy Per Epoch')
    bar_ax.legend()

    plt.show()

# Comparison plot
# Overlay every dataset's validation-accuracy curve on a single axes.
fig, ax = plt.subplots(figsize=(12, 8))
for dataset, history in results.items():
    ax.plot(history.history['val_out_caps_accuracy'], label=f'{dataset} Validation Accuracy')
ax.set_title('Validation Accuracy Comparison Across Datasets')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()


Training on mnist...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz: None -- [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1002)