## Introduction

In the previous notebook we were able to build a KerasNLP Backbone model from the Hugging Face configuration.

The next step is to assign the pre-trained weights to the randomly initialized Backbone model. Before assigning the weights, I like to double-check them.

In this notebook we build the KerasNLP Backbone from the Kaggle preset (with trained weights) and compare its weights with the Hugging Face `safetensors` checkpoint.

## Setup and Imports

In [None]:
!pip install -q -U safetensors
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

In [None]:
import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")

In [None]:
import keras
keras.config.set_dtype_policy("bfloat16")

import json
import numpy as np
from functools import partial

from safetensors import safe_open
from huggingface_hub import hf_hub_download

from keras_nlp.models import (
    PaliGemmaBackbone,
)

## Download Safetensor Files from HF Hub

In [None]:
hf_model_id = "google/paligemma-3b-pt-224"

transformers_config = hf_hub_download(
    repo_id=hf_model_id,
    filename="config.json"
)
safetensor_config = hf_hub_download(
    repo_id=hf_model_id,
    filename="model.safetensors.index.json"
)

config.json:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/62.6k [00:00<?, ?B/s]

In [None]:
with open(transformers_config, "r") as f:
    transformers_config = json.load(f)

with open(safetensor_config, "r") as f:
    safetensor_config = json.load(f)

In [None]:
# Map each safetensors shard filename to the local path
# it is downloaded to.
safetensor_files = {
    fname: hf_hub_download(repo_id=hf_model_id, filename=fname)
    for fname in set(safetensor_config["weight_map"].values())
}

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.74G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]
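To make the lookup concrete, here is a minimal sketch of how the index file ties a tensor name to a shard. The JSON below is a hypothetical, trimmed-down stand-in for `model.safetensors.index.json` (the key-to-shard assignments are made up), and the `/tmp/hf-cache` paths and the `locate` helper are illustrative names that replace the real `hf_hub_download` call.

```python
import json

# Hypothetical, trimmed-down stand-in for model.safetensors.index.json.
# The assignment of tensors to shards is invented for illustration.
index_json = """
{
  "weight_map": {
    "language_model.model.norm.weight": "model-00002-of-00003.safetensors",
    "multi_modal_projector.linear.bias": "model-00003-of-00003.safetensors",
    "multi_modal_projector.linear.weight": "model-00003-of-00003.safetensors"
  }
}
"""
index = json.loads(index_json)

# Many tensors share one shard, so deduplicate before downloading.
shard_names = sorted(set(index["weight_map"].values()))

# Stand-in for hf_hub_download: pretend each shard lands in a local cache.
shard_paths = {name: f"/tmp/hf-cache/{name}" for name in shard_names}


def locate(hf_weight_key):
    """Return the local shard path that holds the given tensor."""
    return shard_paths[index["weight_map"][hf_weight_key]]


print(shard_names)
print(locate("multi_modal_projector.linear.bias"))
```

Deduplicating with `set` is what keeps each multi-gigabyte shard from being requested once per tensor.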

## Download the Keras NLP backbone

In [None]:
kaggle_model_id = "pali_gemma_3b_224"
keras_backbone = PaliGemmaBackbone.from_preset(
    kaggle_model_id,
    load_weights=True,
)

Downloading from https://www.kaggle.com/api/v1/models/keras/paligemma/keras/pali_gemma_3b_224/1/download/model.safetensors...
Downloading from https://www.kaggle.com/api/v1/models/keras/paligemma/keras/pali_gemma_3b_224/1/download/model.safetensors.index.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/paligemma/keras/pali_gemma_3b_224/1/download/metadata.json...
100%|██████████| 143/143 [00:00<00:00, 172kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/paligemma/keras/pali_gemma_3b_224/1/download/config.json...
100%|██████████| 861/861 [00:00<00:00, 1.25MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/paligemma/keras/pali_gemma_3b_224/1/download/model.weights.h5...
100%|██████████| 5.45G/5.45G [04:55<00:00, 19.8MB/s]


## Check Keras and HF weights

The `check_keras_weight` function is adapted directly from [`set_keras_weight`](https://github.com/keras-team/keras-nlp/blob/be524fc3c2fe955b7977bcc49a72036eb7d92cae/keras_nlp/src/utils/transformers/safetensor_utils.py#L20) in the Keras NLP repository. Instead of assigning the Hugging Face tensor to the Keras variable, it asserts that the two are close.

In [None]:
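def check_keras_weight(
    safetensor_files,
    safetensor_config,
    keras_variable,
    hf_weight_key,
    hook_fn=None,
):
    """Assert that a Keras variable matches a Hugging Face tensor.

    `hook_fn(hf_tensor, keras_shape)` can reshape or transpose the
    Hugging Face tensor into the Keras layout before the comparison.
    """
    # Look up the shard that stores this tensor.
    safetensor_file = safetensor_files[
        safetensor_config["weight_map"][hf_weight_key]
    ]
    with safe_open(safetensor_file, framework="np") as f:
        hf_tensor = f.get_tensor(hf_weight_key)

        print(hf_weight_key)
        print(f"{hf_tensor.shape=}")
        print(f"{keras_variable.shape=}")

        if hook_fn:
            hf_tensor = hook_fn(hf_tensor, list(keras_variable.shape))

        # Loose tolerances, since the Keras variables use the
        # bfloat16 policy set above.
        np.testing.assert_allclose(
            keras_variable,
            hf_tensor,
            atol=1e-02,
            rtol=1e-02,
        )

# The name `port_weight` mirrors the conversion script; here it only checks.
port_weight = partial(
    check_keras_weight,
    safetensor_files=safetensor_files,
    safetensor_config=safetensor_config,
)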
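The `1e-2` tolerances can look generous, so here is a rough sketch of why they are plausible. It assumes the Hugging Face tensors are stored at a higher precision than the bfloat16 Keras variables, and it emulates bfloat16 by truncating float32 values to their top 16 bits. The helper `to_bfloat16_like` is a hypothetical name, and truncation is a simplification: real casts typically round to nearest, which can only shrink the error.

```python
import numpy as np


def to_bfloat16_like(x):
    """Emulate bfloat16 by keeping only the top 16 bits of each float32.

    Assumption: truncation is used for simplicity. Rounding to nearest
    would give an error no larger than this.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


rng = np.random.default_rng(0)
full = rng.normal(size=1000).astype(np.float32)  # stand-in for an HF tensor
low = to_bfloat16_like(full)                     # stand-in for a Keras variable

# bfloat16 keeps 8 significant bits, so truncation loses less than
# 2**-7 (about 0.0078) of each value.
rel_err = np.abs(low - full) / np.abs(full)
print(rel_err.max())

# The pair passes at the notebook's tolerance...
np.testing.assert_allclose(low, full, atol=1e-2, rtol=1e-2)

# ...but a float32-style tolerance rejects the same pair.
try:
    np.testing.assert_allclose(low, full, atol=1e-6, rtol=1e-6)
    strict_passed = True
except AssertionError:
    strict_passed = False
print(strict_passed)
```

In other words, a tolerance near `1e-2` is about the tightest check that still accepts bfloat16 rounding; it catches wrong layouts or wrong tensors, not tiny numerical drift.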

# Image Tower

## Embedding

In [None]:
port_weight(
    keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").vision_embeddings.patch_embedding.bias,
    hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.bias",
)

port_weight(
    keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").vision_embeddings.patch_embedding.kernel,
    hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.weight",
    hook_fn=lambda hf_tensor, keras_shape: np.transpose(
        hf_tensor,
        axes=(2, 3, 1, 0),
    ),
)

vision_tower.vision_model.embeddings.patch_embedding.weight
hf_tensor.shape=(1152, 3, 14, 14)
keras_variable.shape=TensorShape([14, 14, 3, 1152])
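The `(2, 3, 1, 0)` transpose moves the kernel from an `(out, in, height, width)` layout to `(height, width, in, out)`. The sketch below checks on tiny made-up sizes that both layouts give the same patch projection. It assumes the patch embedding is a convolution whose stride equals the kernel size, so each patch reduces to one dot product; the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
out_ch, in_ch, patch = 4, 3, 2  # small stand-ins for 1152, 3, 14

hf_kernel = rng.normal(size=(out_ch, in_ch, patch, patch))  # (O, I, H, W)
keras_kernel = np.transpose(hf_kernel, axes=(2, 3, 1, 0))   # (H, W, I, O)

# The same image patch in channels-first and channels-last layout.
patch_chw = rng.normal(size=(in_ch, patch, patch))
patch_hwc = np.transpose(patch_chw, axes=(1, 2, 0))

# With stride == kernel size, the convolution on one patch is a dot
# product over the (I, H, W) axes.
hf_out = np.einsum("oihw,ihw->o", hf_kernel, patch_chw)
keras_out = np.einsum("hwio,hwi->o", keras_kernel, patch_hwc)

print(keras_kernel.shape)
np.testing.assert_allclose(hf_out, keras_out, atol=1e-12)
```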


In [None]:
port_weight(
    keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").vision_embeddings.position_embedding.embeddings,
    hf_weight_key="vision_tower.vision_model.embeddings.position_embedding.weight",
)

vision_tower.vision_model.embeddings.position_embedding.weight
hf_tensor.shape=(256, 1152)
keras_variable.shape=TensorShape([256, 1152])


## Norms

In [None]:
port_weight(
    keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").encoder_layer_norm.gamma,
    hf_weight_key="vision_tower.vision_model.post_layernorm.weight",
)

port_weight(
    keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").encoder_layer_norm.beta,
    hf_weight_key="vision_tower.vision_model.post_layernorm.bias",
)

vision_tower.vision_model.post_layernorm.weight
hf_tensor.shape=(1152,)
keras_variable.shape=TensorShape([1152])
vision_tower.vision_model.post_layernorm.bias
hf_tensor.shape=(1152,)
keras_variable.shape=TensorShape([1152])


## Encoder

In [None]:
for index in range(keras_backbone.vit_encoder.get_layer("image_encoder").num_layers):

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].layer_norm_1.beta,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm1.bias",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].layer_norm_1.gamma,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm1.weight",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].layer_norm_2.beta,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm2.bias",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].layer_norm_2.gamma,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm2.weight",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].mlp_dense_1.kernel,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc1.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            hf_tensor,
            axes=(1, 0),
        ),
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].mlp_dense_1.bias,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc1.bias",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].mlp_dense_2.kernel,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc2.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            hf_tensor,
            axes=(1, 0),
        ),
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].mlp_dense_2.bias,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc2.bias",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].attn.key_proj.bias,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.k_proj.bias",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].attn.key_proj.kernel,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.k_proj.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            hf_tensor,
            axes=(1, 0),
        ),
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].attn.out_proj.bias,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.out_proj.bias",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].attn.out_proj.kernel,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.out_proj.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            hf_tensor,
            axes=(1, 0),
        ),
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].attn.query_proj.bias,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.q_proj.bias",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].attn.query_proj.kernel,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.q_proj.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            hf_tensor,
            axes=(1, 0),
        ),
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].attn.value_proj.bias,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.v_proj.bias",
    )

    port_weight(
        keras_variable=keras_backbone.vit_encoder.get_layer("image_encoder").resblocks[index].attn.value_proj.kernel,
        hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.v_proj.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            hf_tensor,
            axes=(1, 0),
        ),
    )

vision_tower.vision_model.encoder.layers.0.layer_norm1.bias
hf_tensor.shape=(1152,)
keras_variable.shape=TensorShape([1152])
vision_tower.vision_model.encoder.layers.0.layer_norm1.weight
hf_tensor.shape=(1152,)
keras_variable.shape=TensorShape([1152])
vision_tower.vision_model.encoder.layers.0.layer_norm2.bias
hf_tensor.shape=(1152,)
keras_variable.shape=TensorShape([1152])
vision_tower.vision_model.encoder.layers.0.layer_norm2.weight
hf_tensor.shape=(1152,)
keras_variable.shape=TensorShape([1152])
vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight
hf_tensor.shape=(4304, 1152)
keras_variable.shape=TensorShape([1152, 4304])
vision_tower.vision_model.encoder.layers.0.mlp.fc1.bias
hf_tensor.shape=(4304,)
keras_variable.shape=TensorShape([4304])
vision_tower.vision_model.encoder.layers.0.mlp.fc2.weight
hf_tensor.shape=(1152, 4304)
keras_variable.shape=TensorShape([4304, 1152])
vision_tower.vision_model.encoder.layers.0.mlp.fc2.bias
hf_tensor.shape=(1152,)
keras_variable.shape=TensorShape([1152])
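Every dense kernel in the encoder needs the same `(1, 0)` transpose, because the Hugging Face checkpoint stores linear weights as `(out, in)` while the Keras kernels printed above are `(in, out)`. A small sketch with made-up sizes shows that the two layouts compute the same thing, assuming the usual `x @ W.T + b` and `x @ K + b` conventions for the two frameworks:

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim = 6, 10  # small stand-ins for 1152 and 4304

hf_weight = rng.normal(size=(out_dim, in_dim))       # (out, in)
bias = rng.normal(size=(out_dim,))                   # identical in both layouts
keras_kernel = np.transpose(hf_weight, axes=(1, 0))  # (in, out)

x = rng.normal(size=(2, in_dim))
hf_out = x @ hf_weight.T + bias      # linear layer with an (out, in) weight
keras_out = x @ keras_kernel + bias  # dense layer with an (in, out) kernel

print(keras_kernel.shape)
np.testing.assert_allclose(hf_out, keras_out, atol=1e-12)
```

The bias needs no hook because it is one-dimensional in both checkpoints.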

## Multimodal Projection

In [None]:
port_weight(
    keras_variable=keras_backbone.vit_encoder.get_layer("image_classifier").kernel,
    hf_weight_key="multi_modal_projector.linear.weight",
    hook_fn=lambda hf_tensor, keras_shape: np.transpose(
        hf_tensor,
        axes=(1, 0),
    ),
)

port_weight(
    keras_variable=keras_backbone.vit_encoder.get_layer("image_classifier").bias,
    hf_weight_key="multi_modal_projector.linear.bias",
)

# Language Tower

## Decoder Layers

In [None]:
index = 0

In [None]:
for index in range(keras_backbone.num_layers):
    decoder_layer = keras_backbone.transformer_layers[index]

    # Norm layers
    port_weight(
        keras_variable=decoder_layer.pre_attention_norm.scale,
        hf_weight_key=f"language_model.model.layers.{index}.input_layernorm.weight",
    )
    port_weight(
        keras_variable=decoder_layer.pre_ffw_norm.scale,
        hf_weight_key=f"language_model.model.layers.{index}.post_attention_layernorm.weight",
    )

    # Attention layers
    port_weight(
        keras_variable=decoder_layer.attention.query_dense.kernel,
        hf_weight_key=f"language_model.model.layers.{index}.self_attn.q_proj.weight",
        # rearrange_patterns="(a c) b -> a b c",
        # rearrange_dims={"a": backbone.num_query_heads},
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            np.reshape(
                hf_tensor,
                (keras_shape[0], keras_shape[2], keras_shape[1]),
            ),
            axes=(0, 2, 1),
        ),
    )
    port_weight(
        keras_variable=decoder_layer.attention.key_dense.kernel,
        hf_weight_key=f"language_model.model.layers.{index}.self_attn.k_proj.weight",
        # rearrange_patterns="(a c) b -> a b c",
        # rearrange_dims={"a": backbone.num_key_value_heads},
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            np.reshape(
                hf_tensor,
                (keras_shape[0], keras_shape[2], keras_shape[1]),
            ),
            axes=(0, 2, 1),
        ),
    )
    port_weight(
        keras_variable=decoder_layer.attention.value_dense.kernel,
        hf_weight_key=f"language_model.model.layers.{index}.self_attn.v_proj.weight",
        # rearrange_patterns="(a c) b -> a b c",
        # rearrange_dims={"a": backbone.num_key_value_heads},
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            np.reshape(
                hf_tensor,
                (keras_shape[0], keras_shape[2], keras_shape[1]),
            ),
            axes=(0, 2, 1),
        ),
    )
    port_weight(
        keras_variable=decoder_layer.attention.output_dense.kernel,
        hf_weight_key=f"language_model.model.layers.{index}.self_attn.o_proj.weight",
        # rearrange_patterns="c (a b) -> a b c",
        # rearrange_dims={"a": backbone.num_query_heads},
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            np.reshape(
                hf_tensor,
                (keras_shape[2], keras_shape[0], keras_shape[1]),
            ),
            axes=(1, 2, 0),
        ),
    )

    # MLP layers
    port_weight(
        keras_variable=decoder_layer.gating_ffw.variables[0],
        hf_weight_key=f"language_model.model.layers.{index}.mlp.gate_proj.weight",
        # rearrange_patterns="b a -> a b",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
    )
    port_weight(
        keras_variable=decoder_layer.gating_ffw_2.variables[0],
        hf_weight_key=f"language_model.model.layers.{index}.mlp.up_proj.weight",
        # rearrange_patterns="b a -> a b",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
    )
    port_weight(
        keras_variable=decoder_layer.ffw_linear.variables[0],
        hf_weight_key=f"language_model.model.layers.{index}.mlp.down_proj.weight",
        # rearrange_patterns="b a -> a b",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
    )

language_model.model.layers.0.input_layernorm.weight
hf_tensor.shape=(2048,)
keras_variable.shape=TensorShape([2048])
language_model.model.layers.0.post_attention_layernorm.weight
hf_tensor.shape=(2048,)
keras_variable.shape=TensorShape([2048])
language_model.model.layers.0.self_attn.q_proj.weight
hf_tensor.shape=(2048, 2048)
keras_variable.shape=TensorShape([8, 2048, 256])
language_model.model.layers.0.self_attn.k_proj.weight
hf_tensor.shape=(256, 2048)
keras_variable.shape=TensorShape([1, 2048, 256])
language_model.model.layers.0.self_attn.v_proj.weight
hf_tensor.shape=(256, 2048)
keras_variable.shape=TensorShape([1, 2048, 256])
language_model.model.layers.0.self_attn.o_proj.weight
hf_tensor.shape=(2048, 2048)
keras_variable.shape=TensorShape([8, 256, 2048])
language_model.model.layers.0.mlp.gate_proj.weight
hf_tensor.shape=(16384, 2048)
keras_variable.shape=TensorShape([2048, 16384])
language_model.model.layers.0.mlp.up_proj.weight
hf_tensor.shape=(16384, 2048)
keras_variable.shape=TensorShape([2048, 16384])
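The attention hooks are the least obvious part: the Hugging Face projections are flat 2-D matrices, while the Keras kernels carry an explicit head axis. The sketch below replays the query and output hooks on tiny made-up sizes and checks that both layouts produce the same activations. It assumes the Hugging Face side splits the projected vector into heads with a plain reshape, and the einsum strings are my own notation rather than the ones used inside either library.

```python
import numpy as np

rng = np.random.default_rng(0)
num_heads, head_dim, hidden, seq = 2, 3, 4, 5  # stand-ins for 8, 256, 2048

x = rng.normal(size=(seq, hidden))

# Query projection: (heads * head_dim, hidden) -> (heads, hidden, head_dim)
hf_q = rng.normal(size=(num_heads * head_dim, hidden))
q_shape = [num_heads, hidden, head_dim]  # plays the role of keras_shape
keras_q = np.transpose(
    np.reshape(hf_q, (q_shape[0], q_shape[2], q_shape[1])),
    axes=(0, 2, 1),
)
hf_heads = (x @ hf_q.T).reshape(seq, num_heads, head_dim)
keras_heads = np.einsum("sh,nhd->snd", x, keras_q)
np.testing.assert_allclose(hf_heads, keras_heads, atol=1e-12)

# Output projection: (hidden, heads * head_dim) -> (heads, head_dim, hidden)
hf_o = rng.normal(size=(hidden, num_heads * head_dim))
o_shape = [num_heads, head_dim, hidden]
keras_o = np.transpose(
    np.reshape(hf_o, (o_shape[2], o_shape[0], o_shape[1])),
    axes=(1, 2, 0),
)
hf_out = hf_heads.reshape(seq, num_heads * head_dim) @ hf_o.T
keras_out = np.einsum("snd,ndh->sh", keras_heads, keras_o)
np.testing.assert_allclose(hf_out, keras_out, atol=1e-12)

print(keras_q.shape, keras_o.shape)
```

The key and value hooks follow the query pattern with a single head, which is why their Keras shapes start with `1`.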

## Norm

In [None]:
port_weight(
    keras_variable=keras_backbone.layer_norm.scale,
    hf_weight_key="language_model.model.norm.weight",
)

# Rest

In [None]:
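port_weight(
    keras_variable=keras_backbone.token_embedding.embeddings,
    hf_weight_key="language_model.model.embed_tokens.weight",
    # The Hugging Face table has more rows than the Keras one,
    # so compare only the first `keras_shape[0]` rows.
    hook_fn=lambda hf_tensor, keras_shape: hf_tensor[:keras_shape[0]],
)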

language_model.model.embed_tokens.weight
hf_tensor.shape=(257216, 2048)
keras_variable.shape=TensorShape([257152, 2048])
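The two tables differ by 257216 - 257152 = 64 rows, and the hook simply drops the trailing ones before comparing. A minimal sketch with made-up sizes shows what the slice keeps. Note that it says nothing about the dropped rows: the check never looks at them.

```python
import numpy as np

hf_vocab, keras_vocab, dim = 10, 8, 4  # stand-ins for 257216, 257152, 2048

hf_embed = np.arange(hf_vocab * dim, dtype=np.float32).reshape(hf_vocab, dim)
keras_shape = [keras_vocab, dim]

# Same slice as the hook: keep the first keras_shape[0] rows.
trimmed = hf_embed[: keras_shape[0]]
dropped_rows = hf_embed.shape[0] - trimmed.shape[0]
print(trimmed.shape, dropped_rows)

# Token ids below the Keras vocabulary size look up identical rows.
token_ids = np.array([0, 3, 7])
np.testing.assert_array_equal(trimmed[token_ids], hf_embed[token_ids])
```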


## Congratulations

Now that you have verified that all the weights match, head over to Keras NLP and build your own `convert_model.py` script [here](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/src/utils/transformers).