# Executing _@tf.function_ annotated function second time gives error:
The following code uses TensorFlow 2.0 along with Python 3.7.5, GradientTape and _tensorflow_model_optimization_ for model optimization (model pruning) for MNIST dataset classification

In [1]:
import tensorflow as tf
import numpy as np
import math
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.sparsity import keras as sparsity
# from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
from tensorflow.keras.layers import AveragePooling2D, Conv2D
from tensorflow.keras import models, layers, datasets
from tensorflow.keras.layers import Dense, Flatten, Reshape, Input, InputLayer
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.initializers import RandomNormal
# import math
from sklearn.metrics import accuracy_score, precision_score, recall_score

In [2]:
tf.__version__

'2.0.0'

In [3]:
batch_size = 32
num_classes = 10
num_epochs = 50

In [4]:
# Data preprocessing and cleadning:
# input image dimensions
img_rows, img_cols = 28, 28

# Load MNIST dataset-
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [5]:
if tf.keras.backend.image_data_format() == 'channels_first':
    X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
    X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
    X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

print("\n'input_shape' which will be used = {0}\n".format(input_shape))


'input_shape' which will be used = (28, 28, 1)



In [6]:
# Convert datasets to floating point types-
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

# Normalize the training and testing datasets-
X_train /= 255.0
X_test /= 255.0

In [7]:
# convert class vectors/target to binary class matrices or one-hot encoded values-
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

In [8]:
# Reshape training and testing sets-
X_train = X_train.reshape(X_train.shape[0], 784)
X_test = X_test.reshape(X_test.shape[0], 784)

In [9]:
print("\nDimensions of training and testing sets are:")
print("X_train.shape = {0}, y_train = {1}".format(X_train.shape, y_train.shape))
print("X_test.shape = {0}, y_test = {1}".format(X_test.shape, y_test.shape))


Dimensions of training and testing sets are:
X_train.shape = (60000, 784), y_train = (60000, 10)
X_test.shape = (10000, 784), y_test = (10000, 10)


In [10]:
# The model is first trained without any pruning for 'num_epochs' epochs-
epochs = num_epochs

num_train_samples = X_train.shape[0]

end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs

print("'end_step parameter' for this dataset =  {0}".format(end_step))

'end_step parameter' for this dataset =  93750


In [11]:
# Specify the parameters to be used for layer-wise pruning, NO PRUNING is done here:
pruning_params_unpruned = {
    'pruning_schedule': sparsity.ConstantSparsity(
        target_sparsity=0.0, begin_step=0,
        end_step = 0, frequency=100
    )
}

In [12]:
l = tf.keras.layers

In [13]:
def pruned_nn(pruning_params):
    """
    Function to define the architecture of a neural network model
    following 300 100 architecture for MNIST dataset and using
    provided parameter which are used to prune the model.
    
    Input: 'pruning_params' Python 3 dictionary containing parameters which are used for pruning
    Output: Returns designed and compiled neural network model
    """
    
    pruned_model = Sequential()
    pruned_model.add(l.InputLayer(input_shape=(784, )))
    pruned_model.add(Flatten())
    pruned_model.add(sparsity.prune_low_magnitude(
        Dense(units = 300, activation='relu', kernel_initializer=tf.initializers.GlorotUniform()),
        **pruning_params))
    # pruned_model.add(l.Dropout(0.2))
    pruned_model.add(sparsity.prune_low_magnitude(
        Dense(units = 100, activation='relu', kernel_initializer=tf.initializers.GlorotUniform()),
        **pruning_params))
    # pruned_model.add(l.Dropout(0.1))
    pruned_model.add(sparsity.prune_low_magnitude(
        Dense(units = num_classes, activation='softmax'),
        **pruning_params))
    
    # Compile pruned CNN-
    pruned_model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        # optimizer='adam',
        optimizer=tf.keras.optimizers.Adam(lr = 0.001),
        metrics=['accuracy'])
    
    return pruned_model


In [14]:
# Add a pruning step callback to peg the pruning step to the optimizer's
# step. Also add a callback to add pruning summaries to tensorboard
callbacks = [
             sparsity.UpdatePruningStep(),
             # sparsity.PruningSummaries(log_dir = logdir, profile_batch=0),
             tf.keras.callbacks.EarlyStopping(
                 monitor='val_loss', patience = 3,
                 min_delta=0.001
             )
]

In [15]:
# Instantiate NN-
orig_model = pruned_nn(pruning_params_unpruned)

Instructions for updating:
Please use `layer.add_weight` method instead.


In [16]:
# Save randomly initialized weights-
orig_model.save_weights("Random_Weights-Error_Recreation.h5", overwrite = True)

In [17]:
# Train model untile convergence-
# Train unpruned Neural Network-
history_orig = orig_model.fit(
    x = X_train, y = y_train,
    batch_size = batch_size,
    epochs = epochs,
    verbose = 1,
    callbacks = callbacks,
    validation_data = (X_test, y_test),
    shuffle = True
)

Train on 60000 samples, validate on 10000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50


In [37]:
# Strip model of pruning parameters-
orig_model_stripped = sparsity.strip_pruning(orig_model)

In [18]:
# Train parameters of neural network AFTER training UNPRUNED model-
orig_model.save_weights("Trained_Weights-Error_Recreation.h5", overwrite = True)

In [19]:
# Count number of non-zero parameters in trained neural network-

print("\nIn original unpruned model, number of nonzero parameters in each layer are: \n")

orig_sum_params = 0

for layer in orig_model.trainable_weights:
    print(tf.math.count_nonzero(layer, axis = None).numpy())
    orig_sum_params += tf.math.count_nonzero(layer, axis = None).numpy()

print("\nTotal number of trainable parameters = {0}\n".format(orig_sum_params))


In original unpruned model, number of nonzero parameters in each layer are: 

235200
300
30000
100
1000
10

Total number of trainable parameters = 266610



In [20]:
# Specify the parameters to be used for layer-wise pruning, NO PRUNING is done here:
pruning_params_constantsparsity = {
    'pruning_schedule': sparsity.ConstantSparsity(
        target_sparsity=0.2674, begin_step=100,
        end_step = end_step, frequency=100
    )
}

In [21]:
# Instantiate a Nueal Network model to be pruned using parameters from above-
pruned_model = pruned_nn(pruning_params_constantsparsity)

In [22]:
# Load weights from original trained and unpruned model-
pruned_model.load_weights("Trained_Weights-Error_Recreation.h5")

In [23]:
# Train pruned NN-
history_pruned = pruned_model.fit(
    x = X_train, y = y_train,
    batch_size = batch_size,
    epochs = epochs,
    verbose = 1,
    callbacks = callbacks,
    validation_data = (X_test, y_test),
    shuffle = True
)


Train on 60000 samples, validate on 10000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50


In [24]:
# Strip the pruning wrappers from pruned model-
pruned_model_stripped = sparsity.strip_pruning(pruned_model)

In [25]:
print("\nIn pruned model, number of nonzero parameters in each layer are: \n")

pruned_sum_params = 0

for layer in pruned_model.trainable_weights:
# for layer in pruned_model_stripped.trainable_weights:
    print(tf.math.count_nonzero(layer, axis = None).numpy())
    pruned_sum_params += tf.math.count_nonzero(layer, axis = None).numpy()

print("\nTotal number of trainable parameters = {0}\n".format(pruned_sum_params))


In pruned model, number of nonzero parameters in each layer are: 

172308
300
21978
100
733
10

Total number of trainable parameters = 195429



In [27]:
print("\n% of weights pruned = {0:.2f}%\n".format(
    ((orig_sum_params - pruned_sum_params) / orig_sum_params) * 100
))


% of weights pruned = 26.70%



In [28]:
# Save weights of PRUNED and Trained model BEFORE stripping-
pruned_model.save_weights("Pruned_Weights-Error_Recreation.h5", overwrite = True)

### Create mask for maintaining the sparsity of winning tickets:
In order for pruned model to maintain it's sparsity, a mask is created which will be used by _GradientTape_ subsequently to train models.

The mask is created as follows-
1. Weights surviving the pruning are initialized to one (1)
1. Weights which are pruned are initialized to zero (0)

In [29]:
# Instantiate a new neural network model for which, the mask is to be created,
# according to the paper-
mask_model = pruned_nn(pruning_params_unpruned)

In [30]:
# Load weights of PRUNED model-
mask_model.load_weights("Pruned_Weights-Error_Recreation.h5")

In [31]:
# Strip the model of its pruning parameters-
mask_model_stripped = sparsity.strip_pruning(mask_model)

In [32]:
# For each layer, for each weight which is 0, leave it, as is.
# And for weights which survive the pruning,reinitialize it to ONE (1)-

for wts in mask_model_stripped.trainable_weights:
    wts.assign(tf.where(tf.equal(wts, 0.), 0., 1.))

### Reset remaining parameters:
Reset the remaining parameters in the _pruned_ model to their random weights when the model was initially created.

In order to extract the winning ticket from the pruned neural network, __reset__ the weights of the surviving parts of the _pruned_ neural network to their _original randomly initialized and unpruned_ weights which were received before the training of the neural network model began (from above)

In [33]:
# Instantiate a new neural network model for which, the weights are to be extracted, according to the paper-
winning_ticket_model = pruned_nn(pruning_params_unpruned)

In [34]:
# Load weights of PRUNED model-
winning_ticket_model.load_weights("Pruned_Weights-Error_Recreation.h5")

In [35]:
# Strip the model of its pruning parameters-
winning_ticket_model_stripped = sparsity.strip_pruning(winning_ticket_model)

In [38]:
# For each layer, for each weight which is 0, leave it, as is. And for weights which survive the pruning,
# reinitialize it to the value, the model received BEFORE it was trained and pruned-
for orig_wts, pruned_wts in zip(orig_model_stripped.trainable_weights, winning_ticket_model_stripped.trainable_weights):
    pruned_wts.assign(tf.where(tf.equal(pruned_wts, 0), pruned_wts, orig_wts))

In [39]:
# Save the weights (with pruning parameters) extracted to a file-
winning_ticket_model.save_weights("Winning_Ticket_Weights-Error_Recreation.h5", overwrite=True)

In [40]:
# Create training and testing datasets-
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))

In [41]:
train_dataset = train_dataset.shuffle(
    buffer_size = 20000, reshuffle_each_iteration = True).batch(batch_size = batch_size, drop_remainder = False)

test_dataset = test_dataset.batch(batch_size=batch_size, drop_remainder=False)

In [42]:
# Choose an optimizer and loss function for training-
loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(lr = 0.001)

In [43]:
# Select metrics to measure the error & accuracy of model.
# These metrics accumulate the values over epochs and then
# print the overall result-
train_loss = tf.keras.metrics.Mean(name = 'train_loss')
train_accuracy = tf.keras.metrics.BinaryAccuracy(name = 'train_accuracy')

test_loss = tf.keras.metrics.Mean(name = 'test_loss')
test_accuracy = tf.keras.metrics.BinaryAccuracy(name = 'train_accuracy')

In [44]:
@tf.function
def train_one_step(model, mask_model, optimizer, x, y):
    '''
    def train_step(data, labels):
    Function to compute one step of gradient descent optimization
    '''
    with tf.GradientTape() as tape:
        # Make predictions using defined model-
        y_pred = model(x)

        # Compute loss-
        loss = loss_fn(y, y_pred)
        
    # Compute gradients wrt defined loss and weights and biases-
    grads = tape.gradient(loss, model.trainable_variables)
    
    # type(grads)
    # list
    
    # List to hold element-wise multiplication between-
    # computed gradient and masks-
    grad_mask_mul = []
    
    # Perform element-wise multiplication between computed gradients and masks-
    for grad_layer, mask in zip(grads, mask_model.trainable_weights):
        grad_mask_mul.append(tf.math.multiply(grad_layer, mask))
    
    # Apply computed gradients to model's weights and biases-
    optimizer.apply_gradients(zip(grad_mask_mul, model.trainable_variables))

    # Compute accuracy-
    train_loss(loss)
    train_accuracy(y, y_pred)

    return None

In [45]:
@tf.function
def test_step(model, optimizer, data, labels):
    """
    Function to test model performance
    on testing dataset
    """
    
    predictions = model(data)
    t_loss = loss_fn(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

    return None


### Retrain Pruned model:
The pruned model from above is retrained while _maintaining sparsity_ using mask from _mask_model_stripped_

In [46]:
# Instantiate a model
model_gt = pruned_nn(pruning_params_unpruned)

In [47]:
# Load winning ticket (from above-)
model_gt.load_weights("Winning_Ticket_Weights-Error_Recreation.h5")

In [48]:
# Strip model of pruning parameters-
model_gt_stripped = sparsity.strip_pruning(model_gt)

In [49]:
print("\nnumber of trainable parameters in original model = {0}".format(orig_sum_params))
print("number of trainable parameters in pruned model = {0}\n".format(pruned_sum_params))


number of trainable parameters in original model = 266610
number of trainable parameters in pruned model = 195429



In [50]:
# Define variables for manual Early Stopping-
best_val_loss = 1
loc_patience = 0

In [51]:
# Define variables for manual Early Stopping-
patience = 3
minimum_delta = 0.001

In [52]:
for epoch in range(num_epochs):
    
    if loc_patience >= patience:
        print("\n'EarlyStopping' called!\n")
        break
        
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
        
    for x, y in train_dataset:
        # train_step(x, y)
        train_one_step(model_gt_stripped, mask_model_stripped, optimizer, x, y)

    for x_t, y_t in test_dataset:
        # test_step(x_t, y_t)
        test_step(model_gt_stripped, optimizer, x_t, y_t)

    template = 'Epoch {0}, Loss: {1:.4f}, Accuracy: {2:.4f}, Test Loss: {3:.4f}, Test Accuracy: {4:4f}'
    
    print(template.format(epoch + 1, 
                              train_loss.result(), train_accuracy.result()*100,
                              test_loss.result(), test_accuracy.result()*100))
    
    # Count number of non-zero parameters in each layer and in total-
    # print("layer-wise manner model, number of nonzero parameters in each layer are: \n")

    model_sum_params = 0
    
    for layer in model_gt_stripped.trainable_weights:
        # print(tf.math.count_nonzero(layer, axis = None).numpy())
        model_sum_params += tf.math.count_nonzero(layer, axis = None).numpy()
    
    print("Total number of trainable parameters = {0}\n".format(model_sum_params))

    
    # Code for manual Early Stopping:
    if np.abs(test_loss.result() < best_val_loss) >= minimum_delta:
        # update 'best_val_loss' variable to lowest loss encountered so far-
        best_val_loss = test_loss.result()
        
        # reset 'loc_patience' variable-
        loc_patience = 0
        
    else:  # there is no improvement in monitored metric 'val_loss'
        loc_patience += 1  # number of epochs without any improvement
    


Epoch 1, Loss: 0.0110, Accuracy: 99.9307, Test Loss: 0.0909, Test Accuracy: 99.612961
Total number of trainable parameters = 195429

Epoch 2, Loss: 0.0082, Accuracy: 99.9489, Test Loss: 0.0937, Test Accuracy: 99.623940
Total number of trainable parameters = 195429

Epoch 3, Loss: 0.0092, Accuracy: 99.9385, Test Loss: 0.0900, Test Accuracy: 99.623955
Total number of trainable parameters = 195429

Epoch 4, Loss: 0.0073, Accuracy: 99.9530, Test Loss: 0.1048, Test Accuracy: 99.608978
Total number of trainable parameters = 195429

Epoch 5, Loss: 0.0071, Accuracy: 99.9548, Test Loss: 0.0984, Test Accuracy: 99.631973
Total number of trainable parameters = 195429

Epoch 6, Loss: 0.0064, Accuracy: 99.9566, Test Loss: 0.1090, Test Accuracy: 99.591972
Total number of trainable parameters = 195429


'EarlyStopping' called!



In [53]:
# Save weights of winning ticket trained with GradientTape
# WITH pruning parameter, so that it can be used to prune-
model_gt.save_weights("Trained_Weights-Error_Recreation.h5", overwrite=True)

### Prune the trained model further:

In [54]:
# Specify the parameters to be used for layer-wise pruning-
pruning_params_constantsparsity = {
    'pruning_schedule': sparsity.ConstantSparsity(
        target_sparsity=0.4633, begin_step=100,
        end_step = end_step, frequency=100
    )
}

In [55]:
# Instantiate a Neural Network model to be pruned-
pruned_model = pruned_nn(pruning_params_constantsparsity)

In [56]:
# Load weights of winning ticket from previous round to be PRUNED-
pruned_model.load_weights("Trained_Weights-Error_Recreation.h5")

In [57]:
print("Pruning parameters of (GradientTape) trained model\n")

# Train pruned Neural Network-
history_pruned_gt = pruned_model.fit(
    x = X_train, y = y_train,
    batch_size = batch_size,
    epochs = epochs,
    verbose = 1,
    callbacks = callbacks,
    validation_data = (X_test, y_test),
    shuffle = True
)

Pruning parameters of (GradientTape) trained model

Train on 60000 samples, validate on 10000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50


In [58]:
# Strip pruning wrapper off-
pruned_model_stripped = sparsity.strip_pruning(pruned_model)

In [59]:
# Count number of non-zero parameters [Round - 2]-
pruned_sum_params2 = 0
    
for layer in pruned_model_stripped.trainable_weights:
    pruned_sum_params2 += tf.math.count_nonzero(layer, axis = None).numpy()

print("\nPruned model, # of non-zero trainable parameters = {0}".format(pruned_sum_params2))


Pruned model, # of non-zero trainable parameters = 143280


In [60]:
# Sanity-check: confirm that 46.33% of the weights are actually pruned away from the network-
print("\n% of weights pruned away = {0:.2f}%\n".format( \
    (orig_sum_params - pruned_sum_params2) / orig_sum_params * 100))


% of weights pruned away = 46.26%



In [61]:
# Save weights of PRUNED and Trained model BEFORE stripping-
pruned_model.save_weights("Pruned_Weights-Error_Recreation.h5", overwrite = True)

### Create mask to maintain sparsity of pruned model:

In [62]:
# Instantiate a new neural network model for which, the mask is to be created,
# according to the paper-
mask_model = pruned_nn(pruning_params_unpruned)

In [63]:
# Load weights of PRUNED model-
mask_model.load_weights("Pruned_Weights-Error_Recreation.h5")

In [64]:
# Strip the model of its pruning parameters-
mask_model_stripped = sparsity.strip_pruning(mask_model)

In [65]:
# For each layer, for each weight which is 0, leave it, as is.
# And for weights which survive the pruning,reinitialize it to ONE (1)-

for wts in mask_model_stripped.trainable_weights:
    wts.assign(tf.where(tf.equal(wts, 0.), 0., 1.))

### Reset weights:
Weights which are zero are left as it is, but for non-zero weights, reset them to random weights when the model was initialized

In [66]:
# Instantiate a new neural network model for which, the weights are to be extracted, according to the paper-
winning_ticket_model = pruned_nn(pruning_params_unpruned)

In [67]:
# Load weights of PRUNED model-
winning_ticket_model.load_weights("Pruned_Weights-Error_Recreation.h5")

In [68]:
# Strip the model of its pruning parameters-
winning_ticket_model_stripped = sparsity.strip_pruning(winning_ticket_model)

In [69]:
# For each layer, for each weight which is 0, leave it, as is. And for weights which survive the pruning,
# reinitialize it to the value, the model received BEFORE it was trained and pruned-
for orig_wts, pruned_wts in zip(orig_model_stripped.trainable_weights, winning_ticket_model_stripped.trainable_weights):
    pruned_wts.assign(tf.where(tf.equal(pruned_wts, 0), pruned_wts, orig_wts))

In [70]:
# Save the weights (with pruning parameters) extracted to a file-
winning_ticket_model.save_weights("Winning_Ticket_Weights-Error_Recreation.h5", overwrite=True)

### Train pruned model (sparsity = 46.26%) using _GradientTape_

In [71]:
# Instantiate a model
model_gt = pruned_nn(pruning_params_unpruned)

In [72]:
# Load winning ticket (from above-)
model_gt.load_weights("Winning_Ticket_Weights-Error_Recreation.h5")

In [73]:
# Strip model of pruning parameters-
model_gt_stripped = sparsity.strip_pruning(model_gt)

In [74]:
print("\nnumber of trainable parameters in original model = {0}".format(orig_sum_params))
print("number of trainable parameters [in Round - 2] in pruned model = {0}\n".format(pruned_sum_params2))


number of trainable parameters in original model = 266610
number of trainable parameters [in Round - 2] in pruned model = 143280



In [75]:
# Define variables for manual Early Stopping-
best_val_loss = 1
loc_patience = 0

In [76]:
for epoch in range(num_epochs):
    
    if loc_patience >= patience:
        print("\n'EarlyStopping' called!\n")
        break
        
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
        
    for x, y in train_dataset:
        # train_step(x, y)
        train_one_step(model_gt_stripped, mask_model_stripped, optimizer, x, y)

    for x_t, y_t in test_dataset:
        # test_step(x_t, y_t)
        test_step(model_gt_stripped, optimizer, x_t, y_t)

    template = 'Epoch {0}, Loss: {1:.4f}, Accuracy: {2:.4f}, Test Loss: {3:.4f}, Test Accuracy: {4:4f}'
    
    print(template.format(epoch + 1, 
                              train_loss.result(), train_accuracy.result()*100,
                              test_loss.result(), test_accuracy.result()*100))
    
    # Count number of non-zero parameters in each layer and in total-
    # print("layer-wise manner model, number of nonzero parameters in each layer are: \n")

    model_sum_params = 0
    
    for layer in model_gt_stripped.trainable_weights:
        # print(tf.math.count_nonzero(layer, axis = None).numpy())
        model_sum_params += tf.math.count_nonzero(layer, axis = None).numpy()
    
    print("Total number of trainable parameters = {0}\n".format(model_sum_params))

    
    # Code for manual Early Stopping:
    if np.abs(test_loss.result() < best_val_loss) >= minimum_delta:
        # update 'best_val_loss' variable to lowest loss encountered so far-
        best_val_loss = test_loss.result()
        
        # reset 'loc_patience' variable-
        loc_patience = 0
        
    else:  # there is no improvement in monitored metric 'val_loss'
        loc_patience += 1  # number of epochs without any improvement


ValueError: in converted code:

    <ipython-input-44-d0ca499a4063>:29 train_one_step  *
        optimizer.apply_gradients(zip(grad_mask_mul, model.trainable_variables))
    /home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:435 apply_gradients
        self._create_slots(var_list)
    /home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/adam.py:146 _create_slots
        self.add_slot(var, 'm')
    /home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:587 add_slot
        initial_value=initial_value)
    /home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
        shape=shape)
    /home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:65 getter
        return captured_getter(captured_previous, **kwargs)
    /home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py:413 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.
