# Using the MANN Package to train a Convolutional Neural Network

In this notebook, the MANN package is used to train pruned convolutional neural networks. We train two single-task networks, one on the MNIST digits dataset and one on the Fashion-MNIST dataset, and then one multitask network which performs both tasks.

In [None]:
# Load the MANN package and TensorFlow
import tensorflow as tf
import mann

In [None]:
# Load both MNIST tasks; each loader returns (x_train, y_train), (x_test, y_test)
(digit_x_train, digit_y_train), (digit_x_test, digit_y_test) = tf.keras.datasets.mnist.load_data()
(fashion_x_train, fashion_y_train), (fashion_x_test, fashion_y_test) = tf.keras.datasets.fashion_mnist.load_data()


def prepare_images(images):
    """Return `images` with a trailing channel axis, scaled by 1/255 as float32.

    The cast to float32 happens before the division on purpose: dividing the
    raw integer pixel arrays directly would promote them to float64 and
    double the memory held by every image set, while Keras computes in
    float32 by default anyway.
    """
    return images[..., None].astype('float32') / 255


# Reshape the x data so they have channels, and rescale the pixel values
digit_x_train = prepare_images(digit_x_train)
digit_x_test = prepare_images(digit_x_test)
fashion_x_train = prepare_images(fashion_x_train)
fashion_x_test = prepare_images(fashion_x_test)

# Reshape the y data into column vectors of shape (n_samples, 1)
digit_y_train = digit_y_train.reshape(-1, 1)
digit_y_test = digit_y_test.reshape(-1, 1)
fashion_y_train = fashion_y_train.reshape(-1, 1)
fashion_y_test = fashion_y_test.reshape(-1, 1)

# Create a callback to stop training early: training halts once the monitored
# quantity (Keras' default is the validation loss) has failed to improve by at
# least 0.01 for 3 epochs in a row, and the best weights seen are restored
callback = tf.keras.callbacks.EarlyStopping(
    min_delta=0.01,
    patience=3,
    restore_best_weights=True,
)

## Create the first model

In [None]:
# Input tensor with the shape of a single digit image (batch axis dropped)
input_layer = tf.keras.layers.Input(digit_x_train.shape[1:])

# Convolutional blocks: each stage stacks masked 3x3 convolutions and ends
# with a 2x2 max pool (stride 1); the first stage uses 32 filters per
# convolution and the second uses 64
x = input_layer
for stage_filters in ((32, 32), (64, 64)):
    for n_filters in stage_filters:
        x = mann.layers.MaskedConv2D(
            filters=n_filters,
            kernel_size=3,
            padding='same',
            strides=1,
            activation='relu',
        )(x)
    x = tf.keras.layers.MaxPool2D(
        pool_size=2,
        strides=1,
        padding='valid',
    )(x)

# Classifier head: flatten, two masked hidden layers, then a masked
# 10-way softmax output
x = tf.keras.layers.Flatten()(x)
for n_units in (256, 256):
    x = mann.layers.MaskedDense(n_units, activation='relu')(x)
output_layer = mann.layers.MaskedDense(10, activation='softmax')(x)

# Create the model
model = tf.keras.Model(input_layer, output_layer)

In [None]:
# The model is compiled once before masking (to prepare it for masking)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Only a subset of the digit training data is used to calculate gradients
mask_x = digit_x_train[:1000]
mask_y = digit_y_train[:1000]

# Mask (prune) the model using the MANN package
model = mann.utils.mask_model(
    model=model,           # The model to be pruned
    percentile=80,         # Percentile to mask, e.g. a value of 90 masks 90% of weights
    method='gradients',    # Masking method, either 'gradients' or 'magnitude'
    exclusive=True,        # Whether weight locations must be exclusive to each task
    x=mask_x,              # The input data
    y=mask_y,              # The expected outputs
)

# mask_model hands back the model to use from here on; compile it again
# before training
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Display the first weight array of layer 1 to inspect the model's sparsity
# NOTE(review): presumably layer 1 is the first masked convolution and pruned
# positions appear as zeros - confirm against the MANN documentation
first_masked_layer = model.layers[1]
first_masked_layer.get_weights()[0]

In [None]:
# Train on the digit data with 20% of it held out for validation; 100 epochs
# is only an upper bound because the early-stopping callback may end sooner
model.fit(
    digit_x_train,
    digit_y_train,
    validation_split=0.2,
    epochs=100,
    batch_size=128,
    callbacks=[callback],
)

# Test accuracy: fraction of test images whose most probable class matches
# the true label
digit_test_labels = model.predict(digit_x_test).argmax(axis=1).flatten()
digit_test_accuracy = (digit_test_labels == digit_y_test.flatten()).sum()/digit_y_test.shape[0]
print(f'Digit Model Accuracy: {digit_test_accuracy}')

## Create the second model

In [None]:
# Input tensor with the shape of a single fashion image (batch axis dropped)
input_layer = tf.keras.layers.Input(fashion_x_train.shape[1:])

# Convolutional blocks: each stage stacks masked 3x3 convolutions and ends
# with a 2x2 max pool (stride 1); the first stage uses 32 filters per
# convolution and the second uses 64
x = input_layer
for stage_filters in ((32, 32), (64, 64)):
    for n_filters in stage_filters:
        x = mann.layers.MaskedConv2D(
            filters=n_filters,
            kernel_size=3,
            padding='same',
            strides=1,
            activation='relu',
        )(x)
    x = tf.keras.layers.MaxPool2D(
        pool_size=2,
        strides=1,
        padding='valid',
    )(x)

# Classifier head: flatten, two masked hidden layers, then a masked
# 10-way softmax output
x = tf.keras.layers.Flatten()(x)
for n_units in (256, 256):
    x = mann.layers.MaskedDense(n_units, activation='relu')(x)
output_layer = mann.layers.MaskedDense(10, activation='softmax')(x)

# Create the model
model = tf.keras.Model(input_layer, output_layer)

In [None]:
# The model is compiled once before masking (to prepare it for masking)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Only a subset of the fashion training data is used to calculate gradients
mask_x = fashion_x_train[:1000]
mask_y = fashion_y_train[:1000]

# Mask (prune) the model using the MANN package
model = mann.utils.mask_model(
    model=model,           # The model to be pruned
    percentile=80,         # Percentile to mask, e.g. a value of 90 masks 90% of weights
    method='gradients',    # Masking method, either 'gradients' or 'magnitude'
    exclusive=True,        # Whether weight locations must be exclusive to each task
    x=mask_x,              # The input data
    y=mask_y,              # The expected outputs
)

# mask_model hands back the model to use from here on; compile it again
# before training
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Train on the fashion data with 20% of it held out for validation; 100
# epochs is only an upper bound because the early-stopping callback may end
# sooner
model.fit(
    fashion_x_train,
    fashion_y_train,
    validation_split=0.2,
    epochs=100,
    batch_size=128,
    callbacks=[callback],
)

# Test accuracy: fraction of test images whose most probable class matches
# the true label
fashion_test_labels = model.predict(fashion_x_test).argmax(axis=1).flatten()
fashion_test_accuracy = (fashion_test_labels == fashion_y_test.flatten()).sum()/fashion_y_test.shape[0]
print(f'Fashion Model Accuracy: {fashion_test_accuracy}')

## Create the MANN

In [None]:
# Build the multitask model, starting with one input per task
digit_input = tf.keras.layers.Input(digit_x_train.shape[1:])
fashion_input = tf.keras.layers.Input(fashion_x_train.shape[1:])

# Convolutional blocks, mirroring the single-task architecture but using the
# multitask layer variants, which are fed the list of per-task inputs
x = [digit_input, fashion_input]
for stage_filters in ((32, 32), (64, 64)):
    for n_filters in stage_filters:
        x = mann.layers.MultiMaskedConv2D(
            filters=n_filters,
            kernel_size=3,
            padding='same',
            strides=1,
            activation='relu',
        )(x)
    x = mann.layers.MultiMaxPool2D(
        pool_size=2,
        strides=1,
        padding='valid',
    )(x)

# Pick out the first task's branch with a SelectorLayer and flatten it
sel1 = mann.layers.SelectorLayer(0)(x)
digit_flatten = tf.keras.layers.Flatten()(sel1)

# Pick out the second task's branch with a SelectorLayer and flatten it
sel2 = mann.layers.SelectorLayer(1)(x)
fashion_flatten = tf.keras.layers.Flatten()(sel2)

# Fully connected head over both flattened branches, ending in a 10-way
# softmax
x = [digit_flatten, fashion_flatten]
for n_units in (256, 256):
    x = mann.layers.MultiMaskedDense(n_units, activation='relu')(x)
output_layer = mann.layers.MultiMaskedDense(10, activation='softmax')(x)

# Create the model
model = tf.keras.Model([digit_input, fashion_input], output_layer)

In [None]:
# Compile the model so it is prepared for masking
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Perform masking: 80 is the percentile to be masked, and a 1000-example
# subset of each task's training data is supplied for the gradient-based
# method
model = mann.utils.mask_model(
    model=model,
    percentile=80,
    method='gradients',
    exclusive=True,
    x=[digit_x_train[:1000], fashion_x_train[:1000]],
    y=[digit_y_train[:1000], fashion_y_train[:1000]],
)

# Compile again after masking, before training
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train both tasks together, with 20% of the data held out for validation;
# 100 epochs is only an upper bound because of the early-stopping callback
model.fit(
    [digit_x_train, fashion_x_train],
    [digit_y_train, fashion_y_train],
    validation_split=0.2,
    epochs=100,
    batch_size=128,
    callbacks=[callback],
)

In [None]:
# Predict on both test sets at once; the model returns one output per task
digit_probs, fashion_probs = model.predict([digit_x_test, fashion_x_test])

# Convert each task's class scores into a predicted class index
digit_preds = digit_probs.argmax(axis=1)
fashion_preds = fashion_probs.argmax(axis=1)

# Per-task accuracy: matching predictions divided by the number of test labels
digit_accuracy = (digit_preds.flatten() == digit_y_test.flatten()).sum()/digit_y_test.size
fashion_accuracy = (fashion_preds.flatten() == fashion_y_test.flatten()).sum()/fashion_y_test.size

print(f'Multitask Model Digit Accuracy: {digit_accuracy}')
print(f'Multitask Model Fashion Accuracy: {fashion_accuracy}')