In [1]:
import torch
from PIL import Image
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    AutoProcessor, CLIPModel
)
from transformers.modeling_outputs import BaseModelOutput

##########
# Config #
##########
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Folder written by the training run (tokenizer files, BART weights, adapter.pt).
MODEL_DIR = "./caption_model"

#########################
# Load Tokenizer + BART #
#########################

# The tokenizer is read from the same folder as the language model so that
# token ids produced by generation decode with the vocabulary used in training.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# Fine-tuned seq2seq language model (DistilBART per the training setup).
# Module.to() moves parameters in place and eval() switches off train-time
# behaviour such as dropout, so both can be issued as separate statements.
bart = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
bart.to(DEVICE)
bart.eval()

#######################################################
# Load Trained Adapter (CLIP → BART projection layer) #
#######################################################

# Load the saved adapter state dict.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered adapter.pt cannot execute arbitrary code while it is being loaded.
adapter_state = torch.load(
    f"{MODEL_DIR}/adapter.pt",
    map_location=DEVICE,
    weights_only=True,
)

# Rebuild the same Linear layer structure from the checkpoint (no hard-coded dims).
# nn.Linear stores its weight as (out_features, in_features).
adapter = torch.nn.Linear(
    adapter_state["weight"].shape[1],  # input dimension (CLIP patch dim)
    adapter_state["weight"].shape[0],  # output dimension (BART hidden dim)
    bias="bias" in adapter_state,      # only create a bias if one was saved
).to(DEVICE)

# Strict load: raises if the checkpoint keys do not match the rebuilt layer.
adapter.load_state_dict(adapter_state)
# A bare Linear behaves the same in train and eval mode; eval() is set anyway
# so every module in the pipeline is consistently in inference mode.
adapter.eval()

#############
# Load CLIP #
#############

# Same CLIP model used for feature extraction during training
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"

# use_fast=False pins the slow image processor that this checkpoint was saved
# with. Leaving it unset emits a warning and lets the library default flip to
# the fast processor in a later release, which (per that warning) changes the
# pixel values slightly and would make inference features drift over time.
# NOTE(review): the flag is also forwarded to the processor's tokenizer, which
# this notebook never uses — confirm that is acceptable.
processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID, use_fast=False)

# Full CLIP model in inference mode; only its vision tower is used below.
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE).eval()

######################
# Inference Function #
######################
def caption_image(img_path, num_beams=5, max_length=16):
    """
    Generate a text caption for a single image file.

    Pipeline:
    1. Load & preprocess image for CLIP
    2. CLIP vision tower produces patch-token embeddings
    3. Adapter maps the visual features into BART hidden space
    4. BART generates caption tokens with beam search, using the projected
       features in place of its own encoder output
    5. Decode output tokens into readable text

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        num_beams: Beam width passed to ``bart.generate``. Defaults to 5,
            the value that was previously hard-coded.
        max_length: ``max_length`` passed to ``bart.generate``. Defaults to
            16, the value that was previously hard-coded.

    Returns:
        A human-readable caption (string)
    """

    # Open inside a context manager so the file handle is always released;
    # convert() returns an independent RGB copy that outlives the handle.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")

    # Prepare image for CLIP model (resize, normalize, batch)
    inp = processor(images=img, return_tensors="pt").to(DEVICE)

    with torch.no_grad():  # inference only: no autograd bookkeeping
        # Only the final layer is consumed below, so the per-layer hidden
        # states are not requested (avoids keeping every layer's activations).
        out = clip_model.vision_model(inp["pixel_values"])

        # Drop the token at index 0 (CLS) to keep patch features only, then
        # L2-normalize each patch embedding along the feature dimension.
        patch_tokens = out.last_hidden_state[:, 1:, :]
        vis = torch.nn.functional.normalize(patch_tokens, p=2, dim=-1)

        # Project CLIP embeddings → BART encoder dimension
        enc_vis = adapter(vis.float())

        # Wrap into the output type generate() accepts as encoder_outputs
        enc_out = BaseModelOutput(last_hidden_state=enc_vis)

        gen = bart.generate(
            encoder_outputs=enc_out,
            num_beams=num_beams,    # number of candidate captions kept alive
            early_stopping=True,    # stop once num_beams finished candidates exist
            max_length=max_length,  # cap on generated caption length (tokens)
            length_penalty=1.0,     # balances shorter/longer sentences
        )

    # Convert token IDs of the top beam back to text (special tokens removed)
    return tokenizer.decode(gen[0], skip_special_tokens=True)


  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
# Sample images for a quick qualitative check of the captioner.
image_paths = [
    "tests/animals.jpg",
    "tests/desk.jpg",
    "tests/dogs.jpg",
    "tests/paper_boat.jpg",
    "tests/receipt.jpg",
    "tests/sinner_tennis.png",
]

# Caption each image in order and print one "path: caption" line per file.
for path in image_paths:
    print(f"{path}: {caption_image(path)}")



tests/animals.jpg: a close up of a baby and a bunny on a tree branch
tests/desk.jpg: A desk with a computer and keyboard on it.
tests/dogs.jpg: Two brown and white dogs are standing in the grass.
tests/paper_boat.jpg: A red and white boat in a body of water.
tests/receipt.jpg: A box of Dunkin Donuts sitting on top of a table
tests/sinner_tennis.png: A man holding a tennis racquet on top of a tennis court
