In [3]:
from tensorflow import keras
import json

# Load the full model
epoch = 1
mlp_model = keras.models.load_model(f"checkpoint_epoch_{epoch}.keras")

# Load the metadata
with open(f"checkpoint_metadata_epoch_{epoch}.json", "r") as f:
    metadata = json.load(f)

print(f"Model loaded successfully!")
print(f"Epoch: {metadata['epoch']}, Validation Accuracy: {metadata['val_accuracy']}")


ValueError: The `{arg_name}` of this `Lambda` layer is a Python lambda. Deserializing it is unsafe. If you trust the source of the config artifact, you can override this error by passing `safe_mode=False` to `from_config()`, or calling `keras.config.enable_unsafe_deserialization().

In [2]:
from keras.models import Sequential
from keras import regularizers
from keras.layers import (Input, Conv2D, BatchNormalization, ReLU, MaxPooling2D, 
                          Flatten, Dense, Dropout, Lambda)
from keras.initializers import HeNormal
import keras.ops as K

def get_model(hidden_units, output_units, input_shape, rate, l2_coeff=1e-5):
    """
    Creates a face verification model that outputs normalized embeddings.
    """

    model = Sequential([Input(shape=input_shape)])

    # --- Convolutional blocks / Feature extraction backbone ---

    # note we use he kaiming initialization for the weights
    model.add(Conv2D(32, (3, 3), padding='same', kernel_initializer=HeNormal(),
                     kernel_regularizer=regularizers.l2(l2_coeff)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(MaxPooling2D((2, 2)))

    # 2nd block
    model.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=HeNormal(),
                     kernel_regularizer=regularizers.l2(l2_coeff)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(MaxPooling2D((2, 2)))

    # 3rd block
    model.add(Conv2D(128, (3, 3), padding='same', kernel_initializer=HeNormal(),
                     kernel_regularizer=regularizers.l2(l2_coeff)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(MaxPooling2D((2, 2)))

    model.add(Flatten())

    # --- Fully connected layers ---
    for units in hidden_units:
        model.add(Dense(units, kernel_initializer=HeNormal(),
                        kernel_regularizer=regularizers.l2(l2_coeff)))
        model.add(BatchNormalization())
        model.add(ReLU())
        model.add(Dropout(rate))

    # --- Output layer + normalization ---
    model.add(Dense(output_units, kernel_initializer=HeNormal()))
    model.add(Lambda(lambda x: x / K.norm(x, axis=1, keepdims=True)))

    return model

model = get_model(
    hidden_units=[1024, 128],
    output_units=128,
    input_shape=(112, 112, 3),
    rate=0.5
)
model.load_state_dict(state_dict)

2025-03-03 23:21:20.039737: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-03 23:21:20.054969: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-03 23:21:20.059715: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-03 23:21:20.070945: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
I0000 00:00:1741044082.589729  405689 cuda_executor.c

AttributeError: 'Sequential' object has no attribute 'load_state_dict'

In [None]:
import torch

# Load the checkpoint
checkpoint = torch.load(f"checkpoint_epoch_{epoch}.pth")

# Load the model state dictionary
mlp_model.load_state_dict(checkpoint['model_state_dict'])

# Verify if model weights match the saved checkpoint
for key in checkpoint['model_state_dict']:
    if torch.equal(mlp_model.state_dict()[key], checkpoint['model_state_dict'][key]):
        print(f"✔ Weights for {key} loaded correctly")
    else:
        print(f"❌ Mismatch in {key}")
