In [13]:
from transformers import GPT2Config, ViTConfig, VisionEncoderDecoderConfig,VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor
import os
import sys

image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decode_model = "gpt2"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model
)
# model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

tokenizer = AutoTokenizer.from_pretrained(text_decode_model)
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)

# 设置一个少见字符为pad_token，并与eos_token区分
if "[PAD]" not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.decoder.resize_token_embeddings(len(tokenizer))
print(f"Using pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

import datasets
ds_train = datasets.load_dataset("jxie/flickr8k", split='train')
ds_eval = datasets.load_dataset("jxie/flickr8k", split='validation')

sample = ds_train[0]

import numpy as np

def tokenization_fn(captions, max_target_length):
    # 手动在caption末尾加eos_token，确保一定有eos
    if isinstance(captions, str):
        captions = captions + tokenizer.eos_token
    else:
        captions = [c + tokenizer.eos_token for c in captions]

    inputs = tokenizer(
        captions,
        padding="max_length",
        truncation=True,
        max_length=max_target_length,
        return_tensors="np",
        add_special_tokens=False  # 已经手动加了eos，不再自动加
    )
    
    input_ids = inputs["input_ids"]

    # Replace pad_token_id with -100 for loss masking
    labels = np.where(input_ids == tokenizer.pad_token_id, -100, input_ids)
    return labels

def preprocess_function(examples):
    # Tokenize the captions
    labels = tokenization_fn(examples["caption_0"], max_target_length=128)
    
    # Extract image features
    images = feature_extractor(images=examples["image"], return_tensors="np")
    
    # Prepare the final output
    return {
        "pixel_values": images["pixel_values"],
        "labels": labels,
    }
    
label_ids = preprocess_function(sample)["labels"]
label_ids = np.array(label_ids)
label_ids = np.clip(label_ids, 0, tokenizer.vocab_size - 1).astype(np.int32)
decoded = tokenizer.decode(label_ids[0], skip_special_tokens=False)  
print(f"Decoded caption: {decoded}")
print(f"Label IDs: {label_ids[0]}")

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.10.crossattention.c_attn.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.10.crossattention.c_proj.bias', 'transformer.h.10.cros

Using pad token: [PAD] (ID: 50257)
Decoded caption: A black dog is running after a white dog in the snow .<|endoftext|>!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Label IDs: [   32  2042  3290   318  2491   706   257  2330  3290   287   262  6729
   764 50256     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0 