In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Input, Layer, Lambda
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam

tf.__version__

'2.4.1'

In [2]:
# Configuration options
feature_vector_length = 784
num_classes = 10

# Set the input shape
input_shape = (feature_vector_length,)

# Create the model
model = Sequential()
model.add(Dense(350, input_shape=input_shape, activation='relu'))
model.add(Dense(100, activation='relu', name='mid_layer'))
model.add(Dense(50, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))

In [3]:
y_true = Input(name='y_true', shape=(num_classes,)) 
y_pred = model.output

In [4]:
def customLoss(args):
    y_true, y_pred = args
    squared_difference = K.square(y_true - y_pred)
    return K.sum(squared_difference, axis=-1)

loss_out = Lambda(customLoss, output_shape=(1,), name='loss')([y_true, y_pred])

In [5]:
ins = [model.input] if type(model.input) is not list else model.input
trainable_model = Model(inputs=ins + [y_true], outputs=[loss_out, y_pred])

In [6]:
def rescale(model, optimizer, loss, scale_val=3):
    """Rescale `loss` w.r.t `scale_val` in mid_layer 
    """
    with tf.GradientTape() as tape:
        loss = loss # call loss fn.
    var_list = model.get_layer(name='mid_layer').trainable_variables 
    grads = tape.gradient(loss, var_list)
    if scale_val is not None:
        shared_coeff = 1.0 / scale_val
        grads = [shared_coeff * g for g in grads]
            
    return optimizer.apply_gradients(zip(grads, vars))

In [7]:
train_ins = ins + [y_true]
def loss_fn(x):
    return trainable_model(x)[0]
train_loss = loss_fn(train_ins)

In [8]:
losses = [
lambda y_true, y_pred: y_pred,  # loss is computed in Lambda layer
]

In [9]:
opt = Adam(learning_rate=0.01)
opt = rescale(trainable_model, opt, loss=train_loss, scale_val=3)

# Configure the model
model.compile(loss=losses, optimizer=opt)

AttributeError: ignored