In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.layers import Dense, Input, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.losses import sparse_categorical_crossentropy
from tensorflow.keras.optimizers import Adadelta, Adam, SGD
from tensorflow.keras.backend import clear_session
import numpy as np

In [12]:
class MaskingSGD(SGD):
    """SGD variant that multiplies every gradient by a fixed mask before applying it.

    A mask entry of 0 zeroes the matching gradient component, so that weight
    receives no gradient-driven update; an entry of 1 leaves the gradient
    untouched. Until ``set_masks`` is called the optimizer applies the
    gradients unmasked.
    """

    def set_masks(self, masks):
        """Store the gradient masks used by ``minimize``.

        Args:
            masks: Iterable with one array-like per variable, in the same
                order as the ``var_list`` later handed to ``minimize``. Each
                mask is multiplied element-wise with that variable's
                gradient. Passing ``None`` disables masking.
        """
        self.masks = masks

    def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
        """Compute the gradients, mask them, and apply the update.

        Args:
            loss: Forwarded to ``_compute_gradients``.
            var_list: Variables to update; forwarded to ``_compute_gradients``.
            grad_loss: Forwarded to ``_compute_gradients``.
            name: Forwarded to ``apply_gradients``.
            tape: Forwarded to ``_compute_gradients``.

        Returns:
            Whatever ``apply_gradients`` returns.

        Raises:
            ValueError: If masks are set but their count differs from the
                number of (gradient, variable) pairs.
        """
        grads_and_vars = self._compute_gradients(
            loss, var_list=var_list, grad_loss=grad_loss, tape=tape)

        masks = getattr(self, 'masks', None)
        if masks is not None:
            masks = list(masks)
            grads_and_vars = list(grads_and_vars)
            # zip() would silently truncate to the shorter input, which drops
            # the trailing variables from the update without any warning.
            if len(masks) != len(grads_and_vars):
                raise ValueError(
                    "MaskingSGD got {} masks for {} variables; provide exactly "
                    "one mask per variable.".format(len(masks), len(grads_and_vars)))
            grads_and_vars = [
                # A variable without a gradient has nothing to mask; pass the
                # pair through unchanged instead of multiplying None.
                (grad if grad is None else tf.multiply(grad, mask), var)
                for mask, (grad, var) in zip(masks, grads_and_vars)
            ]
        return self.apply_gradients(grads_and_vars, name=name)

In [3]:
import numpy as np
from scipy.stats import multivariate_normal

# Sample a 30 x 30 lattice over [-1, 1] x [-1, 1]; the complex step `30j`
# asks np.mgrid for 30 evenly spaced points per axis, endpoints included.
x, y = np.mgrid[-1.0:1.0:30j, -1.0:1.0:30j]
# Flatten both coordinate grids into a single (N, 2) array of (x, y) rows.
xy = np.stack((x.ravel(), y.ravel()), axis=1)

# Gaussian centred at the origin with the same standard deviation on both
# axes; the covariance is diagonal and holds the variances (sigma squared).
mu = np.zeros(2)
sigma = np.full(2, .025)
covariance = np.diag(np.square(sigma))

# Density evaluated at every lattice point: one value per row of `xy`.
z = multivariate_normal.pdf(xy, mean=mu, cov=covariance)

In [4]:
# Sanity check: one density value per grid point (use xy.shape to inspect the inputs instead).
z.shape

(900,)

In [5]:
# Gradient masks for MaskingSGD, one per weight matrix in layer order.
# They are multiplied element-wise with the gradients, so a 1 keeps the
# gradient and a 0 zeroes it. The shapes (2, 3) and (3, 1) are meant to line up
# with the kernels of the 2 -> 3 -> 1 dense network defined in the next cell.
masks = [
    np.array(pattern)
    for pattern in (
        [[1, 0, 1], [0, 1, 1]],
        [[1], [1], [0]],
    )
]

In [16]:
# Tiny bias-free regression network: 2 inputs -> 3 ReLU units -> 1 linear output.
# Without biases the only trainable variables are the two kernels, which is
# what the two entries of `masks` correspond to.
input_layer = Input(shape=(2,))
d = Dense(units=3, activation="relu", use_bias=False)(input_layer)
o = Dense(units=1, activation="linear", use_bias=False)(d)
model = Model(inputs=input_layer, outputs=o)
model.summary()

# Attach the masks to the optimizer, then compile the model with it.
opt = MaskingSGD()
opt.set_masks(masks)
model.compile(loss='mean_squared_error', optimizer=opt, metrics=['mean_squared_error'])

Model: "model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 2)]               0         
                                                                 
 dense_6 (Dense)             (None, 3)                 6         
                                                                 
 dense_7 (Dense)             (None, 1)                 3         
                                                                 
Total params: 9
Trainable params: 9
Non-trainable params: 0
_________________________________________________________________


In [17]:
# Snapshot of the kernels before fitting, kept for comparison with the post-fit values.
model.get_weights()

[array([[-0.81048656, -0.12635726,  0.7718816 ],
        [ 0.1200515 , -0.02018356, -0.46926993]], dtype=float32),
 array([[-0.670567  ],
        [-0.25856918],
        [-1.2105925 ]], dtype=float32)]

In [18]:
# One epoch over the grid points, regressing the Gaussian density z from the (x, y) inputs.
model.fit(xy, z, epochs=1)



<keras.callbacks.History at 0x141423d71c0>

In [19]:
# Kernels after the single epoch; entries whose mask value is 0 should equal the pre-fit values.
model.get_weights()

[array([[-0.77114314, -0.12635726,  0.6647044 ],
        [ 0.1200515 , -0.01912871, -0.40267518]], dtype=float32),
 array([[-0.6221142 ],
        [-0.25081754],
        [-1.2105925 ]], dtype=float32)]

In [None]:
# Write the fitted model to "test_model" so that the next cell can reload it.
model.save("test_model")



In [None]:
# The custom_objects key has to be the class name stored in the saved optimizer
# config ("MaskingSGD"); a different key such as "MSGD" is never matched, so
# the custom optimizer class cannot be resolved by name.
# NOTE(review): MaskingSGD keeps its masks as a plain attribute and does not
# add them to the optimizer config, so they presumably need to be re-applied
# with set_masks() on the loaded optimizer — confirm.
model_load = tf.keras.models.load_model("test_model", custom_objects={"MaskingSGD": MaskingSGD})

OSError: ignored

In [None]:
# Inspect which optimizer class the reloaded model ended up with.
model_load.optimizer

<keras.optimizers.optimizer_experimental.sgd.SGD at 0x7f92b635cb80>

In [None]:
# Drop the graph and Keras session state left over from the earlier model
# before building a new one.
tf.compat.v1.reset_default_graph()
clear_session()

# Bias-free MLP for 28 x 28 inputs: flatten -> 200 ReLU -> 80 ReLU -> 10-way softmax.
# (A wider 10_000-unit first hidden layer is left disabled.)
input_layer = Input(shape=(28, 28))
x = Flatten()(input_layer)
x = Dense(200, activation="relu", use_bias=False)(x)
x = Dense(80, activation="relu", use_bias=False)(x)
prediction_layer = Dense(units=10, activation='softmax', use_bias=False)(x)

model = Model(input_layer, prediction_layer)
model.summary()

# No masks are set on this optimizer, so MaskingSGD applies the gradients unmasked.
loss = sparse_categorical_crossentropy
optimizer = MaskingSGD()
model.compile(loss=loss, optimizer=optimizer, metrics=['sparse_categorical_accuracy'])

# Save the compiled but untrained model to an HDF5 file.
model.save("mnist_before_training_maskingsgd.h5")

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28)]          0         
                                                                 
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 200)               156800    
                                                                 
 dense_1 (Dense)             (None, 80)                16000     
                                                                 
 dense_2 (Dense)             (None, 10)                800       
                                                                 
Total params: 173,600
Trainable params: 173,600
Non-trainable params: 0
_________________________________________________________________
