In [6]:
# test: what's inside the hidden_states of the whole model
from transformers import AutoProcessor, AutoModel
import torch
from PIL import Image

# load model + processor
# model_name = "google/paligemma2-3b-pt-224"
# proc = AutoProcessor.from_pretrained(model_name)
# model = AutoModel.from_pretrained(
#     model_name, device_map="auto"
# )

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-mix-224"  # or your specific version
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to("cuda:0")

# dummy input
img = Image.new("RGB", (224, 224), color="gray")
enc = proc(images=img, text="<image>", return_tensors="pt").to("cuda:0")

with torch.inference_mode():
    out = model(**enc, output_hidden_states=True, return_dict=True)

print(type(out))
print(out.keys())          # if it's a ModelOutput, this shows named attributes
print([k for k in dir(out) if not k.startswith("_")])  # inspect methods/attrs


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


<class 'transformers.models.paligemma.modeling_paligemma.PaliGemmaCausalLMOutputWithPast'>
odict_keys(['logits', 'past_key_values', 'hidden_states', 'image_hidden_states'])
['attentions', 'clear', 'copy', 'fromkeys', 'get', 'hidden_states', 'image_hidden_states', 'items', 'keys', 'logits', 'loss', 'move_to_end', 'past_key_values', 'pop', 'popitem', 'setdefault', 'to_tuple', 'update', 'values']


In [11]:
model.language_model

Gemma2Model(
  (embed_tokens): Embedding(257216, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): 

In [7]:
out.hidden_states

(tensor([[[ 0.2427, -0.1656, -0.0941,  ...,  0.0923,  0.0095,  0.2427],
          [-0.1403,  0.2756,  0.0056,  ...,  0.2858, -0.2693, -0.1462],
          [-0.0837,  0.4461, -0.3173,  ...,  0.0874,  0.2473, -0.1706],
          ...,
          [ 0.2551, -0.0745, -0.1415,  ...,  0.1745, -0.0455,  0.3002],
          [ 0.0113, -0.2812,  1.0547,  ...,  0.6650,  0.1765, -0.4834],
          [-1.4941,  1.2891, -2.9883,  ..., -0.4805, -1.5703,  1.6055]]],
        device='cuda:0'),
 tensor([[[ 2.3324e-01, -3.3421e-01, -1.1617e-01,  ..., -5.4396e-01,
           -2.4852e-03,  6.8841e-01],
          [-4.7261e-01, -3.1818e-02, -6.7618e-04,  ..., -1.2205e-01,
           -1.3581e+00, -1.5657e-01],
          [-1.5893e-02,  8.9614e-02, -5.5212e-01,  ..., -7.1152e-01,
           -1.4815e+00,  4.5848e-01],
          ...,
          [ 3.6289e-01, -1.6670e-01, -7.5930e-01,  ..., -5.5617e-01,
           -2.4051e+00,  9.7328e-01],
          [-1.1761e-02, -1.6056e-01,  5.9141e-01,  ...,  2.5445e-02,
            5

In [8]:
out.image_hidden_states

tensor([[[ 5.0556e-03, -3.4501e-03, -1.9602e-03,  ...,  1.9237e-03,
           1.9697e-04,  5.0565e-03],
         [-2.9228e-03,  5.7411e-03,  1.1630e-04,  ...,  5.9539e-03,
          -5.6102e-03, -3.0450e-03],
         [-1.7445e-03,  9.2938e-03, -6.6108e-03,  ...,  1.8202e-03,
           5.1528e-03, -3.5535e-03],
         ...,
         [ 1.2053e-03,  1.2200e-03, -1.9339e-04,  ...,  2.3134e-03,
          -3.9507e-03, -3.0026e-03],
         [ 1.0556e-03,  2.7214e-03,  7.7434e-05,  ...,  5.4123e-03,
           1.0052e-03,  7.1102e-03],
         [ 5.3140e-03, -1.5527e-03, -2.9475e-03,  ...,  3.6345e-03,
          -9.4751e-04,  6.2551e-03]]], device='cuda:0')

In [9]:
out.hidden_states[0].shape, out.image_hidden_states.shape

(torch.Size([1, 258, 2304]), torch.Size([1, 256, 2304]))