In [None]:
!pip install --upgrade peft transformers accelerate

In [None]:
!pip install -U bitsandbytes 

In [3]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, Blip2ForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import get_scheduler

import os



# Configuration
BATCH_SIZE = 8
LEARNING_RATE = 2e-4
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 4
WARMUP_STEPS = 100
EVAL_STEPS = 200
GRADIENT_ACCUMULATION_STEPS = 4
MAX_LENGTH = 50
DATASET_SIZE = 5000 

In [4]:
ds = load_dataset("tomytjandra/h-and-m-fashion-caption-12k", split='train')
ds_small = ds.select(range(DATASET_SIZE))
print(f"Dataset size: {len(ds_small)}")


Dataset size: 5000


In [5]:
ds_small = ds.select(range(2500))
print(f"Dataset size: {len(ds_small)}")

Dataset size: 2500


In [2]:
ds_small

NameError: name 'ds_small' is not defined

In [5]:
ds_small = ds_small.shuffle(seed=42)
train_size = int(0.9 * len(ds_small))
train_dataset_raw = ds_small.select(range(train_size))
eval_dataset_raw = ds_small.select(range(train_size, len(ds_small)))

In [6]:
train_size

2250

In [7]:
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor, max_length=MAX_LENGTH):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(
            images=item["image"], 
            text=item["text"],
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        # Remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding

In [8]:
def collate_fn(batch):
    processed_batch = {}
    
    # Process image inputs
    processed_batch["pixel_values"] = torch.stack([example["pixel_values"] for example in batch])
    
    # Process text inputs (labels)
    input_ids = [example["input_ids"] for example in batch]
    attention_mask = [example["attention_mask"] for example in batch]
    
    # Pad the input_ids and attention_mask
    max_length = max(len(ids) for ids in input_ids)
    
    # Create padded tensors
    padded_input_ids = torch.full((len(batch), max_length), processor.tokenizer.pad_token_id, dtype=torch.long)
    padded_attention_mask = torch.zeros((len(batch), max_length), dtype=torch.long)
    
    for i, (ids, mask) in enumerate(zip(input_ids, attention_mask)):
        length = len(ids)
        padded_input_ids[i, :length] = ids
        padded_attention_mask[i, :length] = mask
    
    processed_batch["input_ids"] = padded_input_ids
    processed_batch["attention_mask"] = padded_attention_mask
    
    # Set labels for calculating loss
    processed_batch["labels"] = processed_batch["input_ids"].clone()
    
    return processed_batch

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [10]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [11]:
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", 
    quantization_config=quantization_config,
    device_map="auto"
)

processor_config.json:   0%|          | 0.00/68.0 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/882 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.56M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/23.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/122k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/10.0G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

In [12]:
model = prepare_model_for_kbit_training(model)

# Configure LoRA
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj"]  # Added v_proj and o_proj
)

In [13]:
# Use PeftModelForSeq2SeqLM instead of the default
from peft import get_peft_model, PeftModelForSeq2SeqLM

# Option 1: Let PEFT determine the model type
model = get_peft_model(model, lora_config)

# Option 2: Explicitly use Seq2Seq model type
model = PeftModelForSeq2SeqLM(model, lora_config)
model.print_trainable_parameters()

# Create datasets and dataloaders
train_dataset = ImageCaptioningDataset(train_dataset_raw, processor)
eval_dataset = ImageCaptioningDataset(eval_dataset_raw, processor)

trainable params: 7,864,320 || all params: 3,752,626,176 || trainable%: 0.2096




In [14]:
train_dataloader = DataLoader(
    train_dataset, 
    shuffle=True, 
    batch_size=BATCH_SIZE, 
    collate_fn=collate_fn
)

eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn
)

In [15]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# Learning rate scheduler
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=num_training_steps
)

In [6]:
len(train_dataloader)

NameError: name 'train_dataloader' is not defined

In [17]:
num_training_steps

1128

In [None]:
model.train()
global_step = 0
best_eval_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch: {epoch+1}/{NUM_EPOCHS}")
    
    # Training
    model.train()
    total_loss = 0
    
    for batch_idx, batch in enumerate(train_dataloader):
        # Move batch to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        
        # Forward pass - Remove inputs_embeds which is causing the error
        outputs = model.base_model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels,
            return_dict=True
        )
        
        # Calculate loss
        loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
        total_loss += loss.item()
        
        # Backward pass
        loss.backward()
        
        # Update weights every GRADIENT_ACCUMULATION_STEPS
        if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
            # Log metrics
            if (global_step + 1) % 10 == 0:
                avg_loss = total_loss * GRADIENT_ACCUMULATION_STEPS / 10
                print(f"Step {global_step+1}/{num_training_steps}, Loss: {avg_loss:.4f}, LR: {lr_scheduler.get_last_lr()[0]:.8f}")
                total_loss = 0
            
            global_step += 1
        
        # Evaluation
        if global_step > 0 and global_step % EVAL_STEPS == 0:
            print("\nRunning evaluation...")
            model.eval()
            eval_loss = 0
            eval_steps = 0
            
            with torch.no_grad():
                for eval_batch in eval_dataloader:
                    # Move batch to device
                    input_ids = eval_batch["input_ids"].to(device)
                    attention_mask = eval_batch["attention_mask"].to(device)
                    pixel_values = eval_batch["pixel_values"].to(device)
                    labels = eval_batch["labels"].to(device)
                    
                    # Forward pass - consistent with the fix above
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=pixel_values,
                        labels=labels,
                        return_dict=True
                    )
                    
                    eval_loss += outputs.loss.item()
                    eval_steps += 1
            
            avg_eval_loss = eval_loss / eval_steps
            print(f"Evaluation Loss: {avg_eval_loss:.4f}")
            
            
            # Generate example caption
            if len(eval_dataset) > 0:
                example = eval_dataset[0]
                pixel_values = example["pixel_values"].unsqueeze(0).to(device)
                
                # Generate with different parameters
                with torch.no_grad():
                    # Standard generation
                    standard_ids = model.generate(
                        pixel_values=pixel_values,
                        max_new_tokens=50
                    )
                    standard_caption = processor.batch_decode(standard_ids, skip_special_tokens=True)[0]
                    
                    # Generation with sampling
                    sampled_ids = model.generate(
                        pixel_values=pixel_values,
                        max_new_tokens=50,
                        do_sample=True,
                        top_k=50,
                        top_p=0.9,
                        temperature=0.7
                    )
                    sampled_caption = processor.batch_decode(sampled_ids, skip_special_tokens=True)[0]
                
                print(f"Example standard caption: {standard_caption}")
                print(f"Example sampled caption: {sampled_caption}")
                
                # Get ground truth
                example_input_ids = example["input_ids"]
                gt_caption = processor.tokenizer.decode(example_input_ids, skip_special_tokens=True)
                print(f"Ground truth caption: {gt_caption}")
                
                wandb.log({
                    "example_standard": standard_caption,
                    "example_sampled": sampled_caption,
                    "example_ground_truth": gt_caption
                }, step=global_step)
            
            # Save model checkpoint if it's the best so far
            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                print(f"New best model with eval loss: {best_eval_loss:.4f}")
                
                # Save LoRA weights
                output_dir = f"./checkpoints/epoch-{epoch+1}_step-{global_step}_loss-{avg_eval_loss:.4f}"
                model.save_pretrained(output_dir)
                print(f"Model saved to {output_dir}")
            
            # Back to training mode
            model.train()

print("\nTraining complete!")

# Final model saving
model.save_pretrained("./fashion-captioning-final")

In [18]:
model.train()
global_step = 0
best_eval_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch: {epoch+1}/{NUM_EPOCHS}")
    
    # Training
    model.train()
    total_loss = 0
    
    for batch_idx, batch in enumerate(train_dataloader):
        # Move batch to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        
        # Forward pass
        outputs = model.base_model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels,
            return_dict=True
        )
        
        # Calculate loss
        loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
        total_loss += loss.item()
        
        # Backward pass
        loss.backward()
        
        # Update weights every GRADIENT_ACCUMULATION_STEPS
        if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
            # Log metrics
            if (global_step + 1) % 10 == 0:
                avg_loss = total_loss * GRADIENT_ACCUMULATION_STEPS / 10
                print(f"Step {global_step+1}/{num_training_steps}, Loss: {avg_loss:.4f}, LR: {lr_scheduler.get_last_lr()[0]:.8f}")
                total_loss = 0
            
            global_step += 1
        
        # Evaluation
        if global_step > 0 and global_step % EVAL_STEPS == 0:
            print("\nRunning evaluation...")
            model.eval()
            eval_loss = 0
            eval_steps = 0
            
            with torch.no_grad():
                for eval_batch in eval_dataloader:
                    # Move batch to device
                    input_ids = eval_batch["input_ids"].to(device)
                    attention_mask = eval_batch["attention_mask"].to(device)
                    pixel_values = eval_batch["pixel_values"].to(device)
                    labels = eval_batch["labels"].to(device)
                    
                    # Forward pass - FIXED: use base_model.forward to be consistent with training
                    outputs = model.base_model.forward(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=pixel_values,
                        labels=labels,
                        return_dict=True
                    )
                    
                    eval_loss += outputs.loss.item()
                    eval_steps += 1
            
            avg_eval_loss = eval_loss / eval_steps
            print(f"Evaluation Loss: {avg_eval_loss:.4f}")
            
            
            # Generate example caption
            if len(eval_dataset) > 0:
                example = eval_dataset[0]
                pixel_values = example["pixel_values"].unsqueeze(0).to(device)
                
                # Generate with different parameters
                with torch.no_grad():
                    # Standard generation
                    standard_ids = model.generate(
                        pixel_values=pixel_values,
                        max_new_tokens=50
                    )
                    standard_caption = processor.batch_decode(standard_ids, skip_special_tokens=True)[0]
                    
                    # Generation with sampling
                    sampled_ids = model.generate(
                        pixel_values=pixel_values,
                        max_new_tokens=50,
                        do_sample=True,
                        top_k=50,
                        top_p=0.9,
                        temperature=0.7
                    )
                    sampled_caption = processor.batch_decode(sampled_ids, skip_special_tokens=True)[0]
                
                print(f"Example standard caption: {standard_caption}")
                print(f"Example sampled caption: {sampled_caption}")
                
                # Get ground truth
                example_input_ids = example["input_ids"]
                gt_caption = processor.tokenizer.decode(example_input_ids, skip_special_tokens=True)
                print(f"Ground truth caption: {gt_caption}")
            
            # Save model checkpoint if it's the best so far
            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                print(f"New best model with eval loss: {best_eval_loss:.4f}")
                
                # Save LoRA weights
                output_dir = f"./checkpoints/epoch-{epoch+1}_step-{global_step}_loss-{avg_eval_loss:.4f}"
                model.save_pretrained(output_dir)
                print(f"Model saved to {output_dir}")
            
            # Back to training mode
            model.train()

print("\nTraining complete!")

# Final model saving
model.save_pretrained("./fashion-captioning-final")


Epoch: 1/4


  return fn(*args, **kwargs)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step 10/1128, Loss: 43.4630, LR: 0.00002000
Step 20/1128, Loss: 38.6788, LR: 0.00004000
Step 30/1128, Loss: 25.4482, LR: 0.00006000
Step 40/1128, Loss: 15.8447, LR: 0.00008000
Step 50/1128, Loss: 11.9168, LR: 0.00010000
Step 60/1128, Loss: 9.2244, LR: 0.00012000
Step 70/1128, Loss: 7.4855, LR: 0.00014000

Epoch: 2/4
Step 80/1128, Loss: 6.1902, LR: 0.00016000
Step 90/1128, Loss: 5.3047, LR: 0.00018000
Step 100/1128, Loss: 4.7182, LR: 0.00020000
Step 110/1128, Loss: 4.3414, LR: 0.00019995
Step 120/1128, Loss: 3.9023, LR: 0.00019981
Step 130/1128, Loss: 3.7749, LR: 0.00019958
Step 140/1128, Loss: 3.6923, LR: 0.00019925

Epoch: 3/4
Step 150/1128, Loss: 3.5086, LR: 0.00019883
Step 160/1128, Loss: 3.3550, LR: 0.00019832
Step 170/1128, Loss: 3.1940, LR: 0.00019772
Step 180/1128, Loss: 3.1261, LR: 0.00019703
Step 190/1128, Loss: 3.0168, LR: 0.00019624
Step 200/1128, Loss: 2.9336, LR: 0.00019537

Running evaluation...
Evaluation Loss: 0.6928
Example standard caption: solid dark blue long-sleeve

In [None]:
def generate_caption(image_path, model, processor, max_new_tokens=50, do_sample=True):
    # Load and process the image
    image = Image.open(image_path).convert('RGB')
    
    # Process the image
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    
    # Generate caption
    with torch.no_grad():
        # Standard generation
        standard_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_new_tokens=max_new_tokens
        )
        standard_caption = processor.batch_decode(standard_ids, skip_special_tokens=True)[0]
        
        # Generation with sampling for more creative captions
        if do_sample:
            sampled_ids = model.generate(
                pixel_values=inputs.pixel_values,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                temperature=0.7
            )
            sampled_caption = processor.batch_decode(sampled_ids, skip_special_tokens=True)[0]
            return standard_caption, sampled_caption
        
        return standard_caption

# 4. Test with an example image
image_path = 
standard_caption, sampled_caption = generate_caption(image_path, model, processor)

print(f"Standard caption: {standard_caption}")
print(f"Sampled caption: {sampled_caption}")

In [None]:
# First load the dataset as you've done


# Modify your function to accept dataset examples instead of image paths
def generate_caption_from_dataset(dataset_example, model, processor, max_new_tokens=50, do_sample=True):
    # Get the image from the dataset example
    # First, check what fields are available
    print(f"Dataset example keys: {dataset_example.keys()}")
    
    # Access the image - the field might be called 'image', 'pixel_values', or something else
    # You'll need to adjust this based on the actual dataset structure
    image = ds_small['image']  # Adjust this key as needed
    
    # Process the image
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    
    # Generate caption
    with torch.no_grad():
        # Standard generation
        standard_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_new_tokens=max_new_tokens
        )
        standard_caption = processor.batch_decode(standard_ids, skip_special_tokens=True)[0]
        
        # Generation with sampling for more creative captions
        if do_sample:
            sampled_ids = model.generate(
                pixel_values=inputs.pixel_values,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                temperature=0.7
            )
            sampled_caption = processor.batch_decode(sampled_ids, skip_special_tokens=True)[0]
            return standard_caption, sampled_caption
        
        return standard_caption

# Test with a few examples from the dataset
for i in range(3):  # Try 3 examples
    example = ds_small[i+500]
    
    # Generate captions
    standard_caption, sampled_caption = generate_caption_from_dataset(example, model, processor)
    
    # Print results
    print(f"\nImage {i}:")
    print(f"Standard caption: {standard_caption}")
    print(f"Sampled caption: {sampled_caption}")
    
    # If the dataset has ground truth captions, print them for comparison
    if 'caption' in example:
        print(f"Ground truth: {example['caption']}")

Dataset example keys: dict_keys(['text', 'image'])


In [9]:
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel
import torch
from PIL import Image

# Path to your adapter files
adapter_path = "fashion-captioning-final"

# Set up quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

# Load the processor from the base model
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")

# Load the base model
base_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    quantization_config=quantization_config,
    device_map="auto"
)

# Load your adapter weights onto the base model
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Path to your test image
image_path = ds_small['image'][588]  # Replace with actual path

# Load and process the image
image = Image.open(image_path).convert('RGB')
inputs = processor(images=image, return_tensors="pt").to(model.device)

# Generate caption
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        num_beams=5,
    )

# Decode the caption
caption = processor.decode(outputs[0], skip_special_tokens=True)
print(f"Image: {image_path}")
print(f"Caption: {caption}")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



AttributeError: read

In [10]:
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel
import torch
import gc

# Path to your adapter files
adapter_path = "fashion-captioning-final"

# Set up quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

# Load the processor from the base model
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")

# Load the base model
base_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    quantization_config=quantization_config,
    device_map="auto"
)

# Load your adapter weights onto the base model
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Get the image directly from your dataset
image = ds_small['image'][0]

# Check what type the image is and process accordingly
if hasattr(image, 'convert'):  # It's a PIL Image
    # Just use it directly
    pass
elif isinstance(image, torch.Tensor):  # It's a tensor
    # No need to use Image.open, just use the tensor
    pass
else:
    # If it's a numpy array, convert to tensor or PIL Image
    from PIL import Image
    import numpy as np
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert('RGB')

# Process the image
inputs = processor(images=image, return_tensors="pt").to(model.device)

# Generate caption
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        num_beams=5,
    )

# Clear some memory
del inputs
torch.cuda.empty_cache()
gc.collect()

# Decode the caption
caption = processor.decode(outputs[0], skip_special_tokens=True)
print(f"Caption: {caption}")

# More memory cleanup
del outputs
torch.cuda.empty_cache()
gc.collect()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Caption: long-sleeve t-shirt - charcoal



0

In [13]:
def generate_caption(image, model, processor, max_new_tokens=50, do_sample=True):
    """
    Generate captions for an image
    image: Can be either a file path or a direct image object from dataset
    """
    # Check if input is a path string or already an image
    if isinstance(image, str):
        # It's a file path, open the image
        image = Image.open(image).convert('RGB')
    else:
        # It's already an image object from dataset
        # Check if it needs conversion
        if hasattr(image, 'convert'):
            # It's a PIL Image, just ensure it's RGB
            image = image.convert('RGB')
        elif isinstance(image, torch.Tensor):
            # If it's a tensor, no need for Image.open
            pass
        else:
            # If it's a numpy array, convert to PIL Image
            import numpy as np
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image).convert('RGB')
    
    # Process the image
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    
    # Generate caption
    with torch.no_grad():
        # Standard generation
        standard_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_new_tokens=max_new_tokens
        )
        standard_caption = processor.batch_decode(standard_ids, skip_special_tokens=True)[0]
        
        # Generation with sampling for more creative captions
        if do_sample:
            sampled_ids = model.generate(
                pixel_values=inputs.pixel_values,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                temperature=0.7
            )
            sampled_caption = processor.batch_decode(sampled_ids, skip_special_tokens=True)[0]
            
            # Clean up memory
            del inputs, standard_ids, sampled_ids
            torch.cuda.empty_cache()
            gc.collect()
            
            return standard_caption, sampled_caption
        
        # Clean up memory
        del inputs, standard_ids
        torch.cuda.empty_cache()
        gc.collect()
        
        return standard_caption

# Now use it with your dataset image
image_obj = ds_small['image'][500]  # Get image from dataset
standard_caption, sampled_caption = generate_caption(image_obj, model, processor)
print(f"Standard caption: {standard_caption}")
print(f"Sampled caption: {sampled_caption}")

Standard caption: a white shirt with a button down collar

Sampled caption: the button down shirt in white



In [15]:
ds_small['text'][500] 

'solid white straight-cut shirt in a cotton weave with a collar chest pocket concealed buttons down the front and long sleeves with buttoned cuffs'

In [16]:
def generate_caption(image, model, processor, max_new_tokens=50, do_sample=True):
    """
    Generate captions for an image - handles both file paths and dataset images
    """
    try:
        # Process the image based on its type
        if isinstance(image, str):
            # It's a file path
            image = Image.open(image).convert('RGB')
            pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(model.device)
            
        elif hasattr(image, 'convert'):
            # It's a PIL Image
            pixel_values = processor(images=image.convert('RGB'), return_tensors="pt").pixel_values.to(model.device)
            
        elif isinstance(image, torch.Tensor):
            # It's already a tensor, check if it needs reshaping
            if len(image.shape) == 3:  # [C, H, W]
                # Add batch dimension if needed
                pixel_values = image.unsqueeze(0).to(model.device)
            else:
                # Assume it's already properly formatted
                pixel_values = image.to(model.device)
                
        else:
            # Try to convert from numpy array
            import numpy as np
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image).convert('RGB')
                pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(model.device)
            else:
                raise TypeError(f"Unsupported image type: {type(image)}")
        
        # Generate caption
        with torch.no_grad():
            # Standard generation - match your training code exactly
            standard_ids = model.generate(
                pixel_values=pixel_values,
                max_new_tokens=max_new_tokens
            )
            standard_caption = processor.batch_decode(standard_ids, skip_special_tokens=True)[0]
            
            # Generation with sampling for more creative captions
            if do_sample:
                sampled_ids = model.generate(
                    pixel_values=pixel_values,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_k=50,
                    top_p=0.9,
                    temperature=0.7
                )
                sampled_caption = processor.batch_decode(sampled_ids, skip_special_tokens=True)[0]
                
                # Clean up memory
                del pixel_values, standard_ids, sampled_ids
                torch.cuda.empty_cache()
                gc.collect()
                
                return standard_caption, sampled_caption
            
            # Clean up memory
            del pixel_values, standard_ids
            torch.cuda.empty_cache()
            gc.collect()
            
            return standard_caption
            
    except Exception as e:
        print(f"Error generating caption: {e}")
        import traceback
        traceback.print_exc()
        return "Error generating caption", "Error generating caption"

In [17]:
# Get an image from your dataset
image_obj = ds_small['image'][500]

# Generate captions
standard_caption, sampled_caption = generate_caption(image_obj, model, processor)

print(f"Standard caption: {standard_caption}")
print(f"Sampled caption: {sampled_caption}")

Standard caption: a white shirt with a button down collar

Sampled caption: a white shirt with pockets

