In [1]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [2]:
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
import gc
import csv
import os

In [3]:
gc.collect()
torch.cuda.empty_cache()

In [4]:

PROMPT = "Write a long descriptive caption for this image in a formal tone."
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"


In [5]:
# Load JoyCaption
# bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1)
# device_map=0 loads the model into the first GPU
processor = AutoProcessor.from_pretrained(MODEL_NAME)
llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="float16", device_map="auto")
llava_model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



LlavaForConditionalGeneration(
  (model): LlavaModel(
    (vision_tower): SiglipVisionModel(
      (vision_model): SiglipVisionTransformer(
        (embeddings): SiglipVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
          (position_embedding): Embedding(729, 1152)
        )
        (encoder): SiglipEncoder(
          (layers): ModuleList(
            (0-26): 27 x SiglipEncoderLayer(
              (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
              (self_attn): SiglipAttention(
                (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
              )
              (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_

In [8]:
DATASET_DIR = "test"
OUTPUT_CSV = "captions.csv"
results = []

# Build the conversation
convo = [
            {
                "role": "system",
                "content": "You are a helpful image captioner.",
            },
            {
                "role": "user",
                "content": PROMPT,
            },
        ]
# Format the conversation
        # WARNING: HF's handling of chat's on Llava models is very fragile.  This specific combination of processor.apply_chat_template(), and processor() works
        # but if using other combinations always inspect the final input_ids to ensure they are correct.  Often times you will end up with multiple <bos> tokens
        # if not careful, which can make the model perform poorly.
convo_string = processor.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
assert isinstance(convo_string, str)

with torch.no_grad():
  for filename in os.listdir(DATASET_DIR):
        if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
            continue  # skip non-image files
        # Load image
        image_path = os.path.join(DATASET_DIR, filename)
        image = Image.open(image_path)


        # Process the inputs
        inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to('cuda')
        inputs['pixel_values'] = inputs['pixel_values'].to(torch.float16)

        # Generate the captions
        generate_ids = llava_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            suppress_tokens=None,
            use_cache=True,
            temperature=0.6,
            top_k=None,
            top_p=0.9,
        )[0]

        # Trim off the prompt
        generate_ids = generate_ids[inputs['input_ids'].shape[1]:]

        # Decode the caption
        caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        caption = caption.strip()

        print(f"{filename}: {caption}")
        results.append([filename, caption])

    # Save results to CSV
with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["filename", "caption"])  # header
    writer.writerows(results)

print(f"Saved {len(results)} captions to {OUTPUT_CSV}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


102-dalmatians-puppies-to-the-rescue--1.png: This is a digitally created movie poster for "102 Dalmatians: puppies to the rescue." The foreground features two animated Dalmatian puppies with white fur and black spots, one with a red collar and a tag, and the other with a pink collar. Both puppies have expressive, cheerful faces. The background shows a large, menacing, animated figure of Cruella de Vil with an orange, wavy hairstyle and a sinister grin. The iconic Big Ben clock tower is visible to the right, set against a twilight sky. The title "102 Dalmatians" is prominently displayed in bold, white, and red letters with black spots, and "Puppies to the Rescue" is written in blue beneath it. The Disney logo is on the top left. The overall style is colorful and cartoonish.
asterix.png: This is a digital cartoon illustration featuring Asterix, a character from the popular French comic series. Asterix, a white anthropomorphic Gaul with white wings, blonde hair, and a large nose, is depic