<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/vit/paligemma/utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from modelling_gemma import PaliGemmaForConditionalGeneration, PaliGemmaConfig
from transformers import AutoTokenizer
import json
import glob
from safetensors import safe_open
from typing import Tuple
import os

def load_hf_model(model_path: str, device: str) -> Tuple[PaliGemmaForConditionalGeneration, AutoTokenizer]:
  """Load a PaliGemma model and its tokenizer from a local checkpoint directory.

  The directory is read for tokenizer files (via ``AutoTokenizer``), a
  ``config.json`` used to build ``PaliGemmaConfig``, and one or more
  ``*.safetensors`` weight files.

  Args:
    model_path: Path to the checkpoint directory.
    device: Device the model is moved to (passed to ``Module.to``).

  Returns:
    A ``(model, tokenizer)`` tuple. The tokenizer pads on the right.

  Raises:
    FileNotFoundError: If ``model_path`` contains no ``*.safetensors`` files.
      Without this check the model would silently keep its random
      initialisation, because the state dict is loaded with ``strict=False``.
  """
  # Load the tokenizer (right padding is required by the caller's assumptions).
  tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
  assert tokenizer.padding_side == "right", "tokenizer must pad on the right"

  # Find all the *.safetensors files; sort so shards are read in a stable order.
  safetensors_files = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
  if not safetensors_files:
    raise FileNotFoundError(f"No *.safetensors files found in {model_path!r}")

  # ... and load them one by one into a single tensors dictionary (on CPU;
  # load_state_dict copies them into the model's parameters on `device`).
  tensors = {}
  for safetensors_file in safetensors_files:
    with safe_open(safetensors_file, framework="pt", device="cpu") as f:
      for key in f.keys():
        tensors[key] = f.get_tensor(key)

  # Load the model's config
  with open(os.path.join(model_path, "config.json"), "r", encoding="utf-8") as f:
    model_config_file = json.load(f)
  config = PaliGemmaConfig(**model_config_file)

  # Create the model using the configuration
  model = PaliGemmaForConditionalGeneration(config).to(device)

  # Load the state dict of the model. strict=False tolerates keys missing from
  # the checkpoint — presumably weights that tie_weights() fills in below
  # (TODO confirm against PaliGemmaForConditionalGeneration).
  model.load_state_dict(tensors, strict=False)

  # Tie weights
  model.tie_weights()

  return (model, tokenizer)
