In [71]:
from my_ai_utils import *
import keras
import tensorflow as tf

In [72]:
# Example input data: one small random batch that is fed to both the Keras
# LSTM and the custom LSTM.
#
# NumPy's global RNG is seeded first so that X and y are identical on every
# Restart & Run All; without a seed the printed losses and the final
# "loss decreased" assertions are evaluated on different data each run.
# NOTE(review): this pins the input data only. The Keras weight initialisation
# draws from a different RNG — confirm whether that should be seeded as well.
np.random.seed(0)

batch_size = 10
sequence_length = 5
input_size = 2
hidden_size = 3

# X holds the input sequences, y the regression targets (one vector of
# hidden_size values per batch element). Both are float32.
X = np.random.randn(batch_size, sequence_length, input_size).astype(np.float32)
y = np.random.randn(batch_size, hidden_size).astype(np.float32)

In [73]:
# Reference implementation: a Keras LSTM layer. return_sequences and
# return_state are left at their Keras defaults (both False), so the layer
# returns a single output per batch element.
model = tf.keras.layers.LSTM(units=hidden_size)

# Building the layer creates its variables before any data is passed through,
# which is what makes get_weights() usable on the next line.
model.build(input_shape=(None, sequence_length, input_size))
weights = model.get_weights()

# Implementation under test: the custom LSTM receives the Keras weights so
# both models start from the same parameters.
custom_lstm = LSTM(
    in_features=input_size,
    hidden_features=hidden_size,
    load_weights=weights,
)

In [74]:
# Forward pass test: the custom LSTM was constructed from the Keras layer's
# weights, so both must produce the same output for the same input.
y_pred = model(X)
custom_y_pred = custom_lstm(X)
print("Forward pass output:", y_pred.shape, y_pred[0].numpy())  # Expected shape: (batch_size, hidden_size)
print("Forward custom pass output :", custom_y_pred.shape, custom_y_pred[0])

# Turn the visual comparison into an enforced check so a divergence fails the
# notebook instead of relying on someone reading the printed numbers.
assert tuple(y_pred.shape) == (batch_size, hidden_size), "Unexpected Keras output shape"
# Tolerances leave room for float32 rounding differences between the two
# implementations (the recorded run differs around the 7th decimal place).
np.testing.assert_allclose(
    np.asarray(custom_y_pred),
    y_pred.numpy(),
    rtol=1e-4,
    atol=1e-5,
    err_msg="Custom LSTM forward pass diverges from the Keras reference",
)

Forward pass output: (10, 3) [ 0.08472192  0.52326137 -0.07848136]
Forward custom pass output : (10, 3) [ 0.08472192  0.52326122 -0.07848134]


In [75]:
# Define a simple loss function
loss_fn = tf.keras.losses.MeanSquaredError()

# Compute the Keras loss under a GradientTape so the forward pass is recorded;
# the update cell later calls tape.gradient(loss, ...) on this same tape.
with tf.GradientTape() as tape:
    y_pred = model(X)
    loss = loss_fn(y, y_pred)

print("Loss:", loss.numpy())

### custom
# Loss() returns the loss value together with a second result that is passed
# to custom_lstm.backward() in the update cell.
# NOTE(review): presumably that second value is the gradient of the loss with
# respect to the predictions — confirm against the Loss implementation.
custom_loss, custom_gradient = Loss()(custom_y_pred, y)
print("Custom loss:", custom_loss)

# Both losses are computed from matching predictions and the same targets, so
# they must agree; enforce it instead of comparing the printed values by eye.
np.testing.assert_allclose(
    custom_loss,
    loss.numpy(),
    rtol=1e-5,
    err_msg="Custom loss diverges from the Keras MeanSquaredError loss",
)

Loss: 1.6117493
Custom loss: 1.6117494072131708


In [76]:
# --- Custom LSTM: one update step -------------------------------------------
# Back-propagate the loss gradient, let the custom optimizer transform the
# result, then apply it to the custom LSTM's parameters.
_, custom_params_updates = custom_lstm.backward(custom_gradient)
custom_params_updates = Optimizer(lr=0.01)(custom_params_updates, 0, step=0, epoch=0)
custom_lstm.update_params(custom_params_updates)

# --- Keras LSTM: one SGD step with the same learning rate -------------------
# Optionally, perform a gradient update to check if loss decreases
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
gradients = tape.gradient(loss, model.trainable_variables)

# Perform a single optimization step
optimizer.apply_gradients(zip(gradients, model.trainable_variables))

<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=1>

In [77]:
# Compute both losses again to check that the single update step reduced them.
#
# No GradientTape is opened here: no gradient is taken from this forward pass,
# and rebinding the global `tape` would clobber the tape recorded in the loss
# cell, so re-running the update cell would differentiate `loss` against a
# tape that never saw it.
y_pred = model(X)
new_loss = loss_fn(y, y_pred)

custom_y_pred = custom_lstm(X)
new_custom_loss, _ = Loss()(custom_y_pred, y)

print("New Loss:", new_loss.numpy())
print("New custom Loss:", new_custom_loss)

# One optimization step on the same batch is expected to lower each loss.
assert new_loss < loss, "Loss did not decrease after gradient update"
assert new_custom_loss < custom_loss, "Custom Loss did not decrease after gradient update"

New Loss: 1.6111672
New custom Loss: 1.6101004530777707
