In [5]:
from transformers import GPT2LMHeadModel, GPT2Config, AutoTokenizer
import torch

# Load GPT-2 with cross-attention
config = GPT2Config.from_pretrained("gpt2", add_cross_attention=True)
model = GPT2LMHeadModel.from_pretrained("gpt2", config=config, attn_implementation="eager")

# Tokenize caption input
tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("A cat", return_tensors="pt").input_ids

# Dummy encoder output (e.g., from ViT or ResNet)
encoder_hidden_states = torch.randn(1, 49, config.n_embd)

# Forward pass with attention output
outputs = model(
    input_ids=input_ids,
    encoder_hidden_states=encoder_hidden_states,
    output_attentions=True,
    return_dict=True
)
print(outputs.logits.shape)
# Extract cross-attn weights
cross_attentions = outputs.cross_attentions  # List of (batch, heads, tgt_len, src_len)

# Example: Cross-attn from last decoder layer
last_cross_attn = cross_attentions  # Shape: (1, num_heads, seq_len, encoder_seq_len)

print("Cross-attention shape:", len(last_cross_attn))


Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

torch.Size([1, 2, 50257])
Cross-attention shape: 12


In [10]:
print(tokenizer._pad_token_type_id)

0


In [11]:
tokenizer.decode(0)

'!'