# Training a Single Model for Computer Vision and Natural Language Processing

In this notebook, we train a single MANN model to perform both a computer vision task (image classification) and a natural language processing task (sentiment analysis). We train and test on the Fashion MNIST dataset and the IMDB sentiment analysis dataset.

In [1]:
# Load the packages required for the experiment
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
import numpy as np
import mann

## Load and preprocess the data

In [2]:
# Load the data
(fashion_x_train, fashion_y_train), (fashion_x_test, fashion_y_test) = tf.keras.datasets.fashion_mnist.load_data()
(imdb_x_train, imdb_y_train), (imdb_x_test, imdb_y_test) = tf.keras.datasets.imdb.load_data(num_words = 10000)

# Preprocess each of the input datasets. For the images, normalize each of the pixels to values between 0 and 1.
# For the reviews, truncate and/or pad the lengths to 500 words each
fashion_x_train = fashion_x_train/255
fashion_x_test = fashion_x_test/255
imdb_x_train = tf.keras.preprocessing.sequence.pad_sequences(imdb_x_train, maxlen = 500)
imdb_x_test = tf.keras.preprocessing.sequence.pad_sequences(imdb_x_test, maxlen = 500)

# Reshape the target data to have one column
fashion_y_train = fashion_y_train.reshape(-1, 1)
fashion_y_test = fashion_y_test.reshape(-1, 1)
imdb_y_train = imdb_y_train.reshape(-1, 1)
imdb_y_test = imdb_y_test.reshape(-1, 1)
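
With its default settings, `pad_sequences` pads short sequences at the front and, for long sequences, drops tokens from the front so that the last `maxlen` tokens are kept. The following is a minimal pure-Python sketch of that behavior for illustration only. The helper name `pad_pre` is hypothetical and is not part of Keras.

```python
def pad_pre(sequences, maxlen, value=0):
    """Left-pad short sequences and keep only the last `maxlen` tokens of long ones."""
    padded = []
    for seq in sequences:
        seq = list(seq)
        # Truncate from the front so the most recent tokens survive
        seq = seq[max(len(seq) - maxlen, 0):]
        # Pad at the front with the fill value
        padded.append([value] * (maxlen - len(seq)) + seq)
    return padded


reviews = [[5, 8, 2], [1, 2, 3, 4, 5, 6]]
padded_reviews = pad_pre(reviews, maxlen=4)
print(padded_reviews)
```

In the notebook, `maxlen = 500` plays the role of `maxlen=4` here, so every review becomes a row of exactly 500 integer word indices.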

## Create the model

In [3]:
# Create the input block for the fashion data, which includes an input layer, a flatten layer, and a masked dense layer
fashion_input = tf.keras.layers.Input(fashion_x_train.shape[1:])
fashion_flatten = tf.keras.layers.Flatten()(fashion_input)
fashion_reshape = mann.layers.MaskedDense(512, activation = 'relu')(fashion_flatten)

# Create the input block for the reviews data, which includes an input layer, an embedding layer, a flatten layer,
# and a masked dense layer of equal output shape to the masked dense layer for the fashion input block
imdb_input = tf.keras.layers.Input(imdb_x_train.shape[1:])
imdb_embedding = tf.keras.layers.Embedding(10000, 2)(imdb_input)
imdb_flatten = tf.keras.layers.Flatten()(imdb_embedding)
imdb_reshape = mann.layers.MaskedDense(512, activation = 'relu')(imdb_flatten)

# Now that the shapes align for each of the tasks, we can push the data through multitask layers
x = mann.layers.MultiMaskedDense(256, activation = 'relu')([fashion_reshape, imdb_reshape])
x = mann.layers.MultiMaskedDense(256, activation = 'relu')(x)
x = mann.layers.MultiMaskedDense(256, activation = 'relu')(x)
x = mann.layers.MultiMaskedDense(256, activation = 'relu')(x)
x = mann.layers.MultiMaskedDense(256, activation = 'relu')(x)
x = mann.layers.MultiMaskedDense(256, activation = 'relu')(x)

# Output block for the fashion data
fashion_selector = mann.layers.SelectorLayer(0)(x)
fashion_output = mann.layers.MaskedDense(10, activation = 'softmax')(fashion_selector)

# Output block for the IMDB data
imdb_selector = mann.layers.SelectorLayer(1)(x)
imdb_output = mann.layers.MaskedDense(1, activation = 'sigmoid')(imdb_selector)

# Instantiate the model and compile it
model = tf.keras.models.Model([fashion_input, imdb_input], [fashion_output, imdb_output])
model.compile(
    loss = ['sparse_categorical_crossentropy', 'binary_crossentropy'],
    metrics = 'accuracy',
    optimizer = 'adam'
)

# Mask (prune) the model and recompile for training
model = mann.utils.mask_model(
    model,
    90,
    x = [fashion_x_train[:1000], imdb_x_train[:1000]],
    y = [fashion_y_train[:1000], imdb_y_train[:1000]]
)
model.compile(
    loss = ['sparse_categorical_crossentropy', 'binary_crossentropy'],
    metrics = 'accuracy',
    optimizer = 'adam'
)
model.summary()

Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 500)]        0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 28, 28)]     0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 500, 2)       20000       input_2[0][0]                    
__________________________________________________________________________________________________
flatten (Flatten)               (None, 784)          0           input_1[0][0]                    
_________________
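
To make the idea of masking concrete, the sketch below builds a binary mask that keeps only the largest 10% of weights by magnitude and applies it in a dense forward pass. This is a conceptual illustration under stated assumptions: the magnitude criterion and the helper names `magnitude_mask` and `masked_dense` are ours, and `mann.utils.mask_model` may select weights differently (the call above passes sample data, which a pure magnitude criterion would not need).

```python
import numpy as np


def magnitude_mask(weights, percentile):
    """Return a 0/1 mask keeping weights whose magnitude is at or above the percentile."""
    threshold = np.percentile(np.abs(weights), percentile)
    return (np.abs(weights) >= threshold).astype(weights.dtype)


def masked_dense(x, weights, mask, bias):
    """ReLU dense layer in which masked-out weights contribute nothing."""
    return np.maximum(x @ (weights * mask) + bias, 0.0)


rng = np.random.default_rng(0)
weights = rng.normal(size=(784, 512))   # same shape as the first fashion dense layer
mask = magnitude_mask(weights, 90)      # prune the smallest 90% of weights
kept_fraction = mask.mean()

out = masked_dense(rng.normal(size=(4, 784)), weights, mask, np.zeros(512))
print(kept_fraction, out.shape)
```

With a percentile of 90, roughly one weight in ten survives, which is the sense in which the model above is pruned to 90%.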

## Train the model

In [4]:
# This cell compiles the model for training task 1 (fashion) and trains the model for that task
callback = tf.keras.callbacks.EarlyStopping(min_delta = 0.01, patience = 3, restore_best_weights = True)
model.compile(
    loss = ['sparse_categorical_crossentropy', 'binary_crossentropy'],
    metrics = 'accuracy',
    optimizer = 'adam',
    loss_weights = [1, 0]
)
model.fit(
    [fashion_x_train, np.zeros((fashion_x_train.shape[0], imdb_x_train.shape[1]))],
    [fashion_y_train, np.zeros(fashion_y_train.shape)],
    batch_size = 512,
    epochs = 100,
    callbacks = [callback],
    validation_split = 0.2
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100


Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100


<keras.callbacks.History at 0x177969610>
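
When a Keras model has several outputs, the quantity it minimizes is the weighted sum of the per-output losses. The small sketch below shows why `loss_weights = [1, 0]` trains only the fashion task even though the IMDB branch receives dummy zeros. The loss values are made up for illustration, and `weighted_total_loss` is a hypothetical helper, not a Keras function.

```python
def weighted_total_loss(losses, loss_weights):
    """Combine per-output losses the way a weighted multi-output objective does."""
    return sum(w * l for w, l in zip(loss_weights, losses))


# Hypothetical per-output losses: [fashion, imdb]
per_task_losses = [0.42, 0.69]

fashion_only = weighted_total_loss(per_task_losses, [1, 0])
imdb_only = weighted_total_loss(per_task_losses, [0, 1])
print(fashion_only, imdb_only)
```

Because the IMDB loss is multiplied by zero, the dummy IMDB inputs and targets have no influence on the objective during this phase.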

In [5]:
# This cell compiles the model for training task 2 (IMDB) and trains the model for that task
model.compile(
    loss = ['sparse_categorical_crossentropy', 'binary_crossentropy'],
    metrics = 'accuracy',
    optimizer = 'adam',
    loss_weights = [0, 1]
)
model.fit(
    [np.zeros((imdb_x_train.shape[0],) + fashion_x_train.shape[1:]), imdb_x_train],
    [np.zeros(imdb_y_train.shape[0]), imdb_y_train],
    batch_size = 128,
    epochs = 100,
    callbacks = [callback],
    validation_split = 0.2
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100


<keras.callbacks.History at 0x358ceb070>
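
Both training runs stop well short of 100 epochs because of the `EarlyStopping` callback, which by default monitors the validation loss. The following is a simplified sketch of the stopping rule with `min_delta = 0.01` and `patience = 3`. It is not the Keras implementation, the validation losses are made up, and `early_stopping_epoch` is a hypothetical helper.

```python
def early_stopping_epoch(val_losses, min_delta=0.01, patience=3):
    """Return (epoch at which training stops, epoch holding the best loss), both 0-indexed."""
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement of more than min_delta: remember it and reset the counter
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Patience never ran out: training ends at the final epoch
    return len(val_losses) - 1, best_epoch


history = [1.0, 0.8, 0.799, 0.798, 0.797, 0.5]
stopped_at, best_at = early_stopping_epoch(history)
print(stopped_at, best_at)  # → 4 1
```

Note that the improvements after epoch 1 are smaller than `min_delta`, so they do not reset the counter. With `restore_best_weights = True`, the model would be rolled back to the weights from the best epoch.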

## Remove masks

The following cell removes the masks from the model. The masks are needed only during training and add a large number of weights that are no longer required afterward.

In [6]:
simplified_model = mann.utils.remove_layer_masks(model)
simplified_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 500)]        0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 28, 28)]     0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 500, 2)       20000       input_2[0][0]                    
__________________________________________________________________________________________________
flatten (Flatten)               (None, 784)          0           input_1[0][0]                    
______________________________________________________________________________________________
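
One way to see why the masks can be dropped after training: a mask has the same shape as the weights it covers, and multiplying the two once gives a single array that produces identical outputs. The sketch below illustrates this with NumPy. It is a conceptual example under our own assumptions, not a description of how `mann.utils.remove_layer_masks` is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(size=(512, 256))
# Random 0/1 mask keeping about 10% of the weights (for illustration only)
mask = (rng.random(size=weights.shape) >= 0.9).astype(weights.dtype)

x = rng.normal(size=(4, 512))

# Training-time layer: stores both the weights and the mask
masked_out = x @ (weights * mask)

# Inference-time layer: stores only the pre-multiplied weights
folded = weights * mask
folded_out = x @ folded

params_with_mask = weights.size + mask.size
params_folded = folded.size
print(params_with_mask, params_folded)
```

In this sketch the stored array count for the layer halves while the outputs stay the same.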

## Get predictions and report performance

In [7]:
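# The model always expects both inputs, so pass zeros for the task that is not being evaluated.
# Output 0 holds the fashion class probabilities; output 1 holds the IMDB sentiment probability.
fashion_preds = simplified_model.predict([fashion_x_test, np.zeros((fashion_x_test.shape[0], imdb_x_test.shape[1]))])[0].argmax(axis = 1)
imdb_preds = (simplified_model.predict([np.zeros((imdb_x_test.shape[0],) + fashion_x_test.shape[1:]), imdb_x_test])[1].flatten() >= 0.5).astype(int)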
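
The two post-processing steps used above can be seen in isolation on toy arrays: `argmax` picks the most probable class from a softmax output, and a 0.5 threshold turns a sigmoid output into a binary label. The probability values below are made up for illustration.

```python
import numpy as np

# Toy softmax output: 2 samples, 3 classes
softmax_out = np.array([[0.1, 0.7, 0.2],
                        [0.6, 0.3, 0.1]])
# Toy sigmoid output: 3 samples, 1 column
sigmoid_out = np.array([[0.91], [0.12], [0.50]])

class_preds = softmax_out.argmax(axis=1)                    # index of the largest probability
binary_preds = (sigmoid_out.flatten() >= 0.5).astype(int)   # 1 if probability is at least 0.5

print(class_preds, binary_preds)
```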

In [8]:
print('Fashion Test Performance:')
print('\n')
print(confusion_matrix(fashion_y_test, fashion_preds))
print(classification_report(fashion_y_test, fashion_preds))

Fashion Test Performance:


[[836   0  14  29   2   1 110   0   8   0]
 [  3 962   5  22   3   0   4   0   1   0]
 [ 15   3 795   5 110   0  70   0   2   0]
 [ 28   6   7 862  51   1  43   0   2   0]
 [  0   0 146  26 773   0  54   0   1   0]
 [  0   0   0   0   0 949   0  23   8  20]
 [175   1 120  22 133   0 534   0  15   0]
 [  0   0   0   0   0  24   0 924   0  52]
 [  0   0   4   7   4   7  35   3 939   1]
 [  0   0   0   0   0  13   0  35   1 951]]
              precision    recall  f1-score   support

           0       0.79      0.84      0.81      1000
           1       0.99      0.96      0.98      1000
           2       0.73      0.80      0.76      1000
           3       0.89      0.86      0.87      1000
           4       0.72      0.77      0.74      1000
           5       0.95      0.95      0.95      1000
           6       0.63      0.53      0.58      1000
           7       0.94      0.92      0.93      1000
           8       0.96      0.94      0.95      1000


In [9]:
print('IMDB Test Performance:')
print('\n')
print(confusion_matrix(imdb_y_test, imdb_preds))
print(classification_report(imdb_y_test, imdb_preds))

IMDB Test Performance:


[[11065  1435]
 [ 1901 10599]]
              precision    recall  f1-score   support

           0       0.85      0.89      0.87     12500
           1       0.88      0.85      0.86     12500

    accuracy                           0.87     25000
   macro avg       0.87      0.87      0.87     25000
weighted avg       0.87      0.87      0.87     25000
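
As a check on how the report relates to the confusion matrix, the figures for the IMDB task can be recomputed by hand. In scikit-learn's `confusion_matrix`, rows are true labels and columns are predicted labels. The sketch below uses the matrix printed above.

```python
# Confusion matrix from the IMDB output above: rows = true label, columns = predicted label
cm = [[11065, 1435],
      [1901, 10599]]

total = sum(sum(row) for row in cm)
accuracy = (cm[0][0] + cm[1][1]) / total

# Metrics for the positive class (label 1)
precision_pos = cm[1][1] / (cm[0][1] + cm[1][1])
recall_pos = cm[1][1] / (cm[1][0] + cm[1][1])

print(round(accuracy, 2), round(precision_pos, 2), round(recall_pos, 2))
```

These round to 0.87, 0.88, and 0.85, matching the accuracy row and the class 1 row of the report.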



## Save and load the model

To show that the model can be safely saved and loaded with the custom layers we have developed, we provide the following code:

In [10]:
simplified_model.save('cv_and_nlp_model.h5')
loaded_model = tf.keras.models.load_model('cv_and_nlp_model.h5', custom_objects = mann.utils.get_custom_objects())
loaded_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 500)]        0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 28, 28)]     0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 500, 2)       20000       input_2[0][0]                    
__________________________________________________________________________________________________
flatten (Flatten)               (None, 784)          0           input_1[0][0]                    
______________________________________________________________________________________________