In [16]:
import tensorflow as tf

@tf.custom_gradient
def custom_square(x):
    # Forward computation
    y = x ** 2

    def grad(dy):
        # Custom gradient computation
        return 2 * x * dy

    return y, grad

# Example usage
x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = custom_square(x)

# Compute the gradient
dy_dx = tape.gradient(y, x)
print("y:", y.numpy())  # Output: y: 9.0
print("dy/dx:", dy_dx.numpy())  # Output: dy/dx: 6.0


y: 9.0
dy/dx: 6.0


In [17]:
import tensorflow as tf

@tf.custom_gradient
def custom_square(x):
    # Forward computation
    y = x ** 2

    def grad(dy):
        # Custom gradient computation
        return 2 * x * dy

    return y, grad

# Define a simple model with a single parameter
class SimpleModel(tf.Module):
    def __init__(self):
        self.weight = tf.Variable(1.0)

    def __call__(self, x):
        return custom_square(self.weight * x)

# Define a simple loss function (Mean Squared Error)
def loss_fn(y_true, y_pred):
    return tf.reduce_mean((y_true - y_pred) ** 2)

# Example usage
x = tf.constant(3.0)
y_true = tf.constant(9.0)

# Instantiate the model
model = SimpleModel()

# Compute the loss
with tf.GradientTape() as tape:
    y_pred = model(x)
    loss = loss_fn(y_true, y_pred)

# Compute the gradient of the loss with respect to the model's weight
grads = tape.gradient(loss, model.weight)
print("y_pred:", y_pred.numpy())  # Output: y_pred: 9.0 (or close, depending on the weight)
print("loss:", loss.numpy())  # Output: loss: 0.0 (or close, depending on the weight)
print("dL/dw:", grads.numpy())  # Gradient of the loss with respect to the weight


y_pred: 9.0
loss: 0.0
dL/dw: -0.0


In [18]:
import tensorflow as tf

@tf.custom_gradient
def custom_square(x):
    # Forward computation
    y = x ** 2

    def grad(dy):
        # Custom gradient computation
        return 2 * x * dy

    return y, grad

@tf.custom_gradient
def custom_regularization(weight):
    # Forward computation
    reg_loss = 1.0 / weight

    def grad(dy):
        # Custom gradient for the regularization loss
        return -1.0 / (weight ** 2) * dy

    return reg_loss, grad

# Define a simple model with a single parameter
class SimpleModel(tf.Module):
    def __init__(self):
        self.weight = tf.Variable(1.0)

    def __call__(self, x):
        return custom_square(self.weight * x)

# Define a simple loss function (Mean Squared Error)
def mse_loss_fn(y_true, y_pred):
    return tf.reduce_mean((y_true - y_pred) ** 2)

# Example usage
x = tf.constant(3.0)
y_true = tf.constant(9.0)

# Instantiate the model
model = SimpleModel()

# Compute the losses
with tf.GradientTape() as tape:
    y_pred = model(x)
    mse_loss = mse_loss_fn(y_true, y_pred)
    reg_loss = custom_regularization(model.weight)
    total_loss = mse_loss + reg_loss

# Compute the gradients of the total loss with respect to the model's weight
grads = tape.gradient(mse_loss, model.weight)
print("y_pred:", y_pred.numpy())  # Output: y_pred: 9.0 (or close, depending on the weight)
print("MSE loss:", mse_loss.numpy())  # Output: MSE loss
print("Regularization loss:", reg_loss.numpy())  # Output: Regularization loss
print("Total loss:", total_loss.numpy())  # Output: Total loss
print("dL/dw:", grads.numpy())  # Gradient of the total loss with respect to the weight


y_pred: 9.0
MSE loss: 0.0
Regularization loss: 1.0
Total loss: 1.0
dL/dw: -0.0
