## 1. Train the model

In [None]:
# imports: TensorFlow, the dataset loader and the capsule network model
import tensorflow as tf
from utils.dataset import Dataset
from models.model import EfficientCapsNet

# Build the dataset object from the settings in config_STSL.json.
# `dataset` is reused by the training, evaluation and visualisation cells below.
dataset = Dataset(config_path='config_STSL.json')

# Sanity check before training: print the dataset information.
dataset.print_ds_info()

In [None]:
# Instantiate the capsule network in training mode.
# NOTE(review): verbose=True presumably prints the model summary — confirm in models.model.
model_train = EfficientCapsNet(mode='train', verbose=True)

In [None]:
# Train on the loaded dataset, starting from epoch 0.
# The return value is kept in `history` so it can be plotted afterwards.
history = model_train.train(dataset, initial_epoch=0)

In [None]:
# Plot the training history returned by model_train.train()
from utils.visualisation import plotHistory

plotHistory(history)

## 2. Test the model

In [None]:
# Instantiate the network in test mode, then load its weights.
# NOTE(review): load_graph_weights() takes no path here; presumably it reads the
# weights written during training — confirm the location in models.model.
model_test = EfficientCapsNet(mode='test', verbose=True)
model_test.load_graph_weights()

In [None]:
# Evaluate the test-mode model on the test images and their labels.
model_test.evaluate(dataset.X_test, dataset.y_test)

In [None]:
# Plot misclassified test images

from utils.visualisation import plotWrongImagesWithCharts

# predict() returns a pair, unpacked here as predictions and reconstructed images.
y_pred, reconstructed_imgs = model_test.predict(dataset.X_test)

# NOTE(review): the literal 3 is presumably the number of wrong examples to show —
# confirm against plotWrongImagesWithCharts in utils.visualisation.
plotWrongImagesWithCharts(dataset.X_test, dataset.y_test, y_pred, reconstructed_imgs, 3, dataset.class_names)

## 3. Visualise generator reconstruction

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Pass the test data through the model. predict() returns a pair; only
# `reconstructed_images` is used by the plotting code further down in this cell.
capsule_outputs, reconstructed_images = model_test.predict(dataset.X_test)

# Function to decode one-hot encoded labels
def decode_onehot(onehot_encoded):
    """Return the position of the largest entry of a one-hot encoded label.

    For a proper one-hot vector this is the integer class index.
    """
    hottest_position = np.argmax(onehot_encoded)
    return hottest_position

# Fetch one instance of each class from the test set.
# The number of classes is read from the width of the one-hot labels instead of
# being hard-coded, so a label outside a fixed 0-10 range can no longer raise
# KeyError.
# assumes dataset.y_test is 2-D (samples x classes) — TODO confirm
label_width = np.shape(dataset.y_test)[1]

first_seen = {}  # decoded class -> index of its first occurrence in the test set
for idx, label in enumerate(dataset.y_test):
    # setdefault keeps the first index seen for a class and ignores later ones
    first_seen.setdefault(int(decode_onehot(label)), idx)
    if len(first_seen) == label_width:  # Break once we've found one of each
        break

# Order the mapping by class label so it can be iterated as class 0, 1, 2, ...
# Classes that never occur in the test set are left out instead of being mapped
# to None, because indexing the image arrays with None does not select a sample.
class_indices = {class_label: first_seen[class_label] for class_label in sorted(first_seen)}


# Visualize the original and reconstructed images for each class.
# Layout: every group of `images_per_row` classes takes two subplot rows,
# originals on top and their reconstructions directly underneath.
images_per_row = 6  # Number of images per row
image_side = 128  # each image is reshaped to image_side x image_side for display

# Only classes for which a test sample was found can be drawn.
samples = [(class_label, idx) for class_label, idx in class_indices.items() if idx is not None]
num_classes = len(samples)

# Size the grid from the number of classes instead of a fixed 4 rows, so the
# subplot index stays valid for any class count (11 classes still gives 4 rows).
row_pairs = max(1, -(-num_classes // images_per_row))  # ceiling division
total_rows = 2 * row_pairs

plt.figure(figsize=(2 * images_per_row, 2 * total_rows))  # 2 units per column and per row

for i, (class_label, idx) in enumerate(samples):

    row = (i // images_per_row) * 2  # Row holding the original image
    col = i % images_per_row  # Column shared by the original and its reconstruction

    # row_offset 0 draws the original image, row_offset 1 the reconstruction below it
    for row_offset, images in enumerate((dataset.X_test, reconstructed_images)):
        plt.subplot(total_rows, images_per_row, (row + row_offset) * images_per_row + col + 1)
        plt.imshow(images[idx].reshape(image_side, image_side), cmap='gray')
        if row_offset == 0:
            plt.title(f"Class {class_label}")  # only the original carries the class title
        # thin black frame around the image, since the axes themselves are hidden
        rect = patches.Rectangle((0, 0), image_side - 1, image_side - 1, linewidth=1, edgecolor='black', facecolor='none')
        plt.gca().add_patch(rect)
        plt.axis('off')

plt.tight_layout()
plt.show()


## 4. Save the model

In [10]:
# Save the test-mode model (weights loaded above) to bin/model.keras
model_test.save_full_model('bin/model.keras')

Model saved successfully to bin/model.keras
