# Convert a JAX `.npz` checkpoint to Hugging Face safetensors

In [1]:
import os
import json
import torch
import argparse
import numpy as np
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub (the reference model repo may be gated).
# Only call login() when HF_TOKEN is set: login(token=None) falls back to an
# interactive prompt, which would block a non-interactive "Run All".
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)
else:
    print("HF_TOKEN not set; skipping Hugging Face Hub login")
import transformers

import struct
def convert_void_array(arr):
    """Convert a raw 2-byte void (``|V2``) array holding bfloat16 bits to float32.

    NumPy has no native bfloat16 dtype, so bfloat16 tensors saved from JAX come
    back from ``np.load`` as opaque ``|V2`` arrays. A bfloat16 value is exactly
    the upper 16 bits of the float32 with the same sign/exponent. Widening is
    therefore a left shift by 16 into a 32-bit word, followed by reinterpreting
    that word as float32.

    Args:
        arr: NumPy array whose dtype has itemsize 2 (e.g. ``|V2``), assumed to
            contain bfloat16 bit patterns in native byte order.

    Returns:
        A new float32 array with the same shape as ``arr``.

    Raises:
        ValueError: if the items of ``arr`` are not 2 bytes wide.
    """
    if arr.dtype.itemsize != 2:
        raise ValueError(
            f"expected a 2-byte (bfloat16) void array, got dtype {arr.dtype}"
        )
    # Same-itemsize view: reinterpret each raw 2-byte item as an unsigned int.
    bf16_bits = arr.view(np.uint16)
    # The 16 stored bits must become the HIGH half of the float32 word (the low
    # mantissa bits are zero). Appending zero bytes after the bf16 bytes would
    # put them in the low half on little-endian machines, yielding denormals.
    f32_bits = bf16_bits.astype(np.uint32) << 16
    return f32_bits.view(np.float32)


# Conversion settings. Each path can be overridden with an environment variable
# so the notebook runs outside the original machine. With no overrides, the
# defaults reproduce the original hard-coded values.
jax_ckpt = os.environ.get("JAX_CKPT_DIR", "/data/austin/jax_ckpts")
# By default the converted model is written into the same directory that holds
# hf_state_dict.npz / hf_config.json.
output_dir = os.environ.get("OUTPUT_DIR", jax_ckpt)
model_id = "google/paligemma-3b-pt-224"
# Namespace mirrors what argparse would produce, so the cells below can use
# the same `args.<name>` access as a CLI version of this script.
args = argparse.Namespace(
    jax_ckpt=jax_ckpt,
    output_dir=output_dir,
    model_id=model_id,
)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load every array from the exported npz and turn it into a torch tensor.
# bfloat16 arrays come back from NumPy as raw |V2 (void) arrays because NumPy
# has no bfloat16 dtype; those are widened to float32 via convert_void_array.
print(f"Loading JAX checkpoint from {args.jax_ckpt}")
npz_path = os.path.join(args.jax_ckpt, "hf_state_dict.npz")

print("Converting JAX checkpoint dtype to float32")
params = {}
# np.load on an .npz returns an NpzFile that keeps the zip archive open;
# the context manager closes it once every array has been read.
with np.load(npz_path) as np_file:
    for k, v in np_file.items():
        if v.dtype == np.dtype('|V2'):
            print(f"Converting {k} from |V2 to float32")
            params[k] = torch.tensor(convert_void_array(v))
        else:
            params[k] = torch.tensor(v)

# Report (without failing) any tensor that did not end up as float32.
print("Verifying all params are float32")
for k, v in params.items():
    if v.dtype != torch.float32:
        print(k, v.dtype)

Loading JAX checkpoint from /data/austin/jax_ckpts
Converting JAX checkpoint dtype to float32
Converting vision_tower.vision_model.embeddings.position_embedding.weight from |V2 to float32
Verifying all params are float32


In [5]:
# Build a freshly initialized PaliGemma model from the exported HF config so its
# state_dict can be compared against the converted JAX parameters.
config_path = os.path.join(os.path.dirname(npz_path), "hf_config.json")
print(f"Loading config from {config_path}")
with open(config_path, encoding="utf-8") as f:
    config_dict = json.load(f)
config = transformers.PaliGemmaConfig(**config_dict)
model = transformers.PaliGemmaForConditionalGeneration(config)
print(f"Loaded model from config: {config}")

# Pre-flight check before load_state_dict:
# - report model weights the checkpoint does not provide;
# - report weights whose shape differs. load_state_dict raises on a size
#   mismatch even with strict=False, so surfacing these early makes the
#   failure easy to diagnose.
for k, v in model.state_dict().items():
    if k not in params:
        print(f"WARNING: {k} not found in JAX checkpoint")
    elif params[k].shape != v.shape:
        print(
            f"WARNING: shape mismatch for {k}: "
            f"checkpoint {tuple(params[k].shape)} vs model {tuple(v.shape)}"
        )

Loading config from /data/austin/jax_ckpts/hf_config.json
Loaded model from config: PaliGemmaConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "google/paligemma-3b-pt-224",
  "_vocab_size": 257216,
  "architectures": [
    "PaliGemmaForConditionalGeneration"
  ],
  "bos_token_id": 2,
  "eos_token_id": 1,
  "hidden_size": 2048,
  "image_token_index": 256127,
  "model_type": "paligemma",
  "pad_token_id": 0,
  "projection_dim": 2048,
  "text_config": {
    "hidden_size": 2048,
    "intermediate_size": 16384,
    "model_type": "gemma",
    "num_attention_heads": 8,
    "num_hidden_layers": 18,
    "num_image_tokens": 196,
    "num_key_value_heads": 1,
    "torch_dtype": "float32",
    "vocab_size": 256128
  },
  "torch_dtype": "float32",
  "transformers_version": "4.48.1",
  "vision_config": {
    "hidden_size": 768,
    "intermediate_size": 3072,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "num_image_tokens": 

In [6]:
# Copy the converted tensors into the model. strict=False makes key mismatches
# come back as lists (printed below) instead of raising.
print("Loading state dict into model")
load_result = model.load_state_dict(params, strict=False)
missing = load_result.missing_keys
unexpected = load_result.unexpected_keys
for label, keys in (("Missing", missing), ("Unexpected", unexpected)):
    print(f"{label} keys: {keys}")

# Write the config and weights to the output directory.
print(f"Saving model to {args.output_dir}")
model.save_pretrained(args.output_dir)


Loading state dict into model
Missing keys: []
Unexpected keys: []
Saving model to /data/austin/jax_ckpts


# Load a converted model from safetensors and sanity-check it with a caption

In [None]:
import torch
import requests
from PIL import Image
import transformers

# Checkpoint directory to load. The processor (tokenizer + image preprocessing)
# is taken from the reference base model rather than from this directory.
model_path = "/data/austin/jax_ckpts/cambrian_gemma_datacomp_recap-alt-10M+datacomp_recap-recap-10M+datacomp_dense-ram_owlvitv2_20250112-07-matched_mix-divdesF_divdetF_10epoch_01-19_0123"
refer_id = "google/paligemma-3b-pt-224"
model = transformers.PaliGemmaForConditionalGeneration.from_pretrained(model_path)
processor = transformers.PaliGemmaProcessor.from_pretrained(refer_id)

# Display the tokenizer's vocab size next to the row counts of the output head
# and the input embedding table, to spot a vocabulary-size mismatch.
(
    processor.tokenizer.vocab_size,
    *(
        model.state_dict()[weight_name].shape
        for weight_name in (
            "language_model.lm_head.weight",
            "language_model.model.embed_tokens.weight",
        )
    ),
)


In [None]:
# Shrink the token embedding table and the output head to the tokenizer's
# vocabulary size.
# Slicing the state_dict and calling load_state_dict() cannot do this: it copies
# into the model's existing parameters and raises a size-mismatch error when the
# sliced tensors have fewer rows. resize_token_embeddings() replaces the
# embedding/head modules, keeps the first `target_vocab` rows, and updates the
# vocab size recorded in the config.
# NOTE(review): tokenizer.vocab_size excludes added tokens. If the prompt relies
# on added special tokens (e.g. <image>), confirm whether len(processor.tokenizer)
# is the intended target instead.
target_vocab = processor.tokenizer.vocab_size
model.resize_token_embeddings(target_vocab)
assert model.get_input_embeddings().weight.shape[0] == target_vocab
assert model.get_output_embeddings().weight.shape[0] == target_vocab


In [None]:
# Download a sample image and build the model inputs for a captioning prompt.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
# Timeout so a stalled download cannot hang the notebook, raise_for_status so an
# HTTP error page is not handed to PIL, and `with` to close the connection.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    # PIL decodes lazily; read the pixels before the stream is closed.
    image.load()
prompt = "caption en"
model_inputs = processor(text=prompt, images=image, return_tensors="pt")
# Prompt length, used later to strip the echoed prompt from generate()'s output.
input_len = model_inputs["input_ids"].shape[-1]

In [None]:
# Greedy decoding (do_sample=False) of up to 100 new tokens, with autograd disabled.
with torch.inference_mode():
    output_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # generate() returns prompt + continuation; keep only the new tokens of the
    # first (only) sequence.
    generation = output_ids[0, input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)
