# Efficient-CapsNet Model Train

In this notebook we provide a simple interface to train Efficient-CapsNet on the three datasets discussed in "Efficient-CapsNet: Capsule Network with Self-Attention Routing":

- MNIST (MNIST)
- smallNORB (SMALLNORB)
- Multi-MNIST (MULTIMNIST)
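The names in parentheses are the identifiers assigned to `model_name` further down. A small guard like the following catches a typo before any data is loaded. It is only a sketch: `SUPPORTED_MODELS` and `check_model_name` are hypothetical helpers, not part of `utils` or `models`.

```python
# Hypothetical helper: the three dataset identifiers listed above.
SUPPORTED_MODELS = ("MNIST", "SMALLNORB", "MULTIMNIST")


def check_model_name(name: str) -> str:
    """Return the canonical (upper-case) identifier for `name`, or raise ValueError."""
    canonical = name.strip().upper()
    if canonical not in SUPPORTED_MODELS:
        raise ValueError(
            f"Unknown model_name {name!r}; expected one of {SUPPORTED_MODELS}"
        )
    return canonical


print(check_model_name("smallNORB"))  # prints SMALLNORB
```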

The hyperparameters have only been lightly tuned, so there is plenty of room for improvement. Good luck!
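One simple way to explore that room is a small grid search. The sketch below only enumerates candidate settings with `itertools.product`; the key names `lr` and `batch_size` are hypothetical placeholders and must be matched to the keys your `config.json` actually uses.

```python
import itertools

# Hypothetical search space: key names are placeholders, not taken from config.json.
search_space = {
    "lr": [5e-4, 1e-3],
    "batch_size": [16, 32, 64],
}


def grid(space):
    """Yield one dict per combination of the values listed in `space`."""
    keys = list(space)
    for values in itertools.product(*(space[k] for k in keys)):
        yield dict(zip(keys, values))


candidates = list(grid(search_space))
print(len(candidates))  # 2 * 3 = 6 combinations
```

Each candidate can then be written into a copy of the configuration and trained in turn, keeping the run whose validation metrics look best.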

**NB**: remember to modify the "config.json" file with the appropriate parameters.
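Instead of editing "config.json" in place for every experiment, you can write a modified copy and point `Dataset(..., config_path=...)` at it. The sketch below uses only the standard `json` module and overrides top-level keys only; `write_config_variant` is a hypothetical helper, and whether `EfficientCapsNet` can also be pointed at a custom config file is not shown in this notebook, so treat that part as an assumption.

```python
import json
from pathlib import Path


def write_config_variant(src, dst, overrides):
    """Copy the JSON config at `src` to `dst`, replacing top-level keys from `overrides`.

    Raises KeyError if an override names a key that `src` does not contain,
    so a misspelled key is not silently added to the new file.
    """
    config = json.loads(Path(src).read_text())
    unknown = set(overrides) - set(config)
    if unknown:
        raise KeyError(f"Keys not present in {src}: {sorted(unknown)}")
    config.update(overrides)
    Path(dst).write_text(json.dumps(config, indent=4))
    return config
```

For example, `write_config_variant("config.json", "config_variant.json", {"lr": 5e-4})` (where `"lr"` is a hypothetical key) followed by `Dataset(model_name, config_path="config_variant.json")` leaves the original file untouched.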

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
from utils import Dataset, plotImages, plotWrongImages, plotHistory
from models import EfficientCapsNet

In [None]:
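gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Use only the first GPU and let TensorFlow allocate its memory on demand.
    # Both calls must run before any GPU has been initialized.
    tf.config.set_visible_devices(gpus[0], 'GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)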

In [None]:
# dataset/model to use: 'MNIST', 'SMALLNORB' or 'MULTIMNIST'
model_name = 'MNIST'

# 1.0 Import the Dataset

In [None]:
dataset = Dataset(model_name, config_path='config.json')

## 1.1 Visualize the imported dataset

In [None]:
n_images = 20 # number of images to be plotted
plotImages(dataset.X_test[:n_images,...,0], dataset.y_test[:n_images], n_images, dataset.class_names)
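Besides looking at the images, it can help to check how the classes are distributed in the slice being plotted. The NumPy sketch below accepts either integer class indices or one-hot rows; whether `dataset.y_test` is one-hot encoded is an assumption this notebook does not confirm, and `label_counts` is a hypothetical helper.

```python
import numpy as np


def label_counts(y, num_classes=None):
    """Return per-class counts for integer labels or one-hot/multi-hot label rows."""
    y = np.asarray(y)
    if y.ndim == 2:
        # One-hot (or multi-hot) rows: summing each column counts every active class.
        return y.sum(axis=0).astype(int)
    # Integer class indices: count occurrences of each index.
    return np.bincount(y.astype(int), minlength=num_classes or 0)
```

For instance, `label_counts(dataset.y_test[:n_images])` gives the number of plotted samples per class. Summing columns, rather than taking an `argmax`, also counts every active class when a row has more than one.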

# 2.0 Load the Model

In [None]:
model_train = EfficientCapsNet(model_name, mode='train', verbose=True)

# 3.0 Train the Model

In [None]:
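# Training/validation splits as tf.data objects, e.g. for inspecting batches.
# Note: train() below is given the Dataset object itself, not these two.
dataset_train, dataset_val = dataset.get_tf_data()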

In [None]:
history = model_train.train(dataset, initial_epoch=0)

In [None]:
plotHistory(history)
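`plotHistory` shows the learning curves; to read off the best epoch numerically you can inspect the `history.history` dictionary, assuming `train` returns a standard Keras `History` object. The metric key depends on how the model names its outputs, so `"val_loss"` below is only a default to adjust, and `best_epoch` is a hypothetical helper.

```python
def best_epoch(history_dict, key="val_loss", mode="min"):
    """Return (index, value) of the best entry for `key` in a Keras-style history dict.

    `index` is the 0-based position in `history_dict[key]`; use mode="max"
    for metrics such as accuracy where higher is better.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if key not in history_dict:
        raise KeyError(f"{key!r} not found; available keys: {sorted(history_dict)}")
    values = history_dict[key]
    pick = min if mode == "min" else max
    idx = pick(range(len(values)), key=values.__getitem__)
    return idx, values[idx]
```

Usage: `best_epoch(history.history)`. If training was resumed with `initial_epoch > 0`, `history.epoch[idx]` gives the actual epoch number rather than the position in the list.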