In [1]:
!pip install datasets transformers bitsandbytes peft torch Pillow

Defaulting to user installation because normal site-packages is not writeable


In [2]:
from datasets import load_from_disk

# Read the dataset that was saved to disk under "modified_hl_final".
# NOTE(review): the message printed below calls this the test set; that depends
# on what was saved to this directory -- confirm.
dataset = load_from_disk("modified_hl_final")

# Sanity check: show the first record and the overall record count.
sample_count = len(dataset)
print(dataset[0])
print(f"Total samples in test set: {sample_count}")

  from .autonotebook import tqdm as notebook_tqdm


{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x500 at 0x7F8728488BB0>, 'captions': {'action': ["she is holding a horse's reins.", 'she is showing a horse.', 'she is holding the horse'], 'object': ['A female jockey leading a white horse in front of an audience.', 'A woman is holding a horse on a leash', "A white horse at a horse show with it's trainer.", 'Woman in a competition signalling horse to stand still', 'A woman training a white horse in a contest.'], 'rationale': ['she is competing in a horse show.', 'she wants to win.', 'to train with him'], 'scene': ['at a horse show.', 'it is a horse show.', 'at a training site']}, 'confidence': {'action': [5.0, 5.0, 5.0], 'rationale': [5.0, 5.0, 4.0], 'scene': [5.0, 5.0, 4.0]}, 'purity': {'action': [-0.583791196346283, -1.024692177772522, -0.5447486639022827], 'rationale': [-1.0663983821868896, -1.0827453136444092, -1.575655221939087], 'scene': [-0.9763333797454834, -0.8623560667037964, -1.1825969219207764]}, 'diversit

In [None]:
#Reference: https://github.com/huggingface/notebooks/blob/main/peft/Fine_tune_BLIP2_on_an_image_captioning_dataset_PEFT.ipynb
from datasets import load_from_disk
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
import torch
# Question used as the text prompt for each caption axis of the dataset.
prompt_map = dict(
    scene="Where is the picture taken?",
    action="What is the subject doing?",
    rationale="Why is the subject doing it?",
)

# -----------------------------
# Config
# -----------------------------
model_id = "Salesforce/blip2-opt-2.7b"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the weights quantized to 8 bits (bitsandbytes).
quant_config = BitsAndBytesConfig(load_in_8bit=True)

# Processor (image preprocessing + tokenizer) and the captioning model itself.
processor = Blip2Processor.from_pretrained(model_id)
model = Blip2ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quant_config,
)

# Keep the vision encoder and the Q-Former frozen during fine-tuning.
for _frozen_part in (model.vision_model, model.qformer):
    _frozen_part.requires_grad_(False)

# Attach LoRA adapters to every module named "q_proj" or "k_proj".
# NOTE(review): presumably these are the language model's attention
# projections -- confirm against the model's module names.
lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)

# Report how many parameters remain trainable after wrapping with PEFT.
model.print_trainable_parameters()

# -----------------------------
# Dataset
# -----------------------------
class ImageCaptioningDataset(Dataset):
    """Flattens image records into one (image, prompt, caption) sample per caption.

    For every record and every axis in ("scene", "action", "rationale"), each
    caption in ``record["modified_captions"][axis]`` is paired with the matching
    value of ``record["confidence"][axis]`` and with the question stored in the
    module-level ``prompt_map`` for that axis.

    Args:
        dataset: iterable of records exposing the keys "image",
            "modified_captions" and "confidence".
        processor: callable that preprocesses images and exposes a
            ``tokenizer`` attribute (with an ``eos_token``).
    """

    def __init__(self, dataset, processor):
        self.processor = processor
        self.samples = []

        for record in dataset:
            picture = record["image"]
            for axis in ("scene", "action", "rationale"):
                question = prompt_map[axis]
                pairs = zip(
                    record["modified_captions"][axis],
                    record["confidence"][axis],
                )
                self.samples.extend(
                    {
                        "image": picture,
                        "prompt": question,
                        "caption": text,
                        "confidence": score,
                        # Integer part of the confidence, used as a sampling bin.
                        "bin": int(score),
                    }
                    for text, score in pairs
                )

    def __len__(self):
        """Number of flattened (image, prompt, caption) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Build the model inputs for sample ``idx``.

        Returns a dict holding the processor's image outputs (leading batch
        dimension squeezed away) plus:
          * "input_ids": prompt tokens followed by " " + caption + EOS tokens;
          * "labels": the same sequence with every prompt position set to -100,
            so only the caption (and EOS) tokens carry a label.
        """
        record = self.samples[idx]
        tokenizer = self.processor.tokenizer

        image_features = self.processor(images=record["image"], return_tensors="pt")
        item = {name: tensor.squeeze(0) for name, tensor in image_features.items()}

        # Tokenize prompt and answer separately so the prompt length is known
        # and can be masked out of the labels.
        question_ids = tokenizer(record["prompt"], add_special_tokens=False)["input_ids"]
        answer_text = " " + record["caption"] + tokenizer.eos_token
        answer_ids = tokenizer(answer_text, add_special_tokens=False)["input_ids"]

        item["input_ids"] = torch.tensor(question_ids + answer_ids)
        item["labels"] = torch.tensor([-100] * len(question_ids) + answer_ids)
        return item

    def get_bins(self):
        """Map each confidence bin to the list of sample indices that fall in it."""
        from collections import defaultdict

        indices_by_bin = defaultdict(list)
        for position, record in enumerate(self.samples):
            indices_by_bin[record["bin"]].append(position)
        return indices_by_bin

def collate_fn(batch):
    """Merge per-sample dicts into one padded batch.

    Args:
        batch: list of dicts, each with 1-D "input_ids" and "labels" tensors
            and a "pixel_values" tensor (all pixel tensors must share a shape).

    Returns:
        dict with "input_ids" (right-padded with the tokenizer's pad token id),
        "labels" (right-padded with -100), "attention_mask" (1 where
        "input_ids" differs from the pad token id, else 0) and "pixel_values"
        (stacked along a new leading dimension).
    """
    pad_id = processor.tokenizer.pad_token_id
    pad = torch.nn.utils.rnn.pad_sequence

    token_seqs, label_seqs, images = [], [], []
    for item in batch:
        token_seqs.append(item["input_ids"])
        label_seqs.append(item["labels"])
        images.append(item["pixel_values"])

    stacked_images = torch.stack(images)

    # Right-pad every sequence to the longest one in this batch.
    padded_tokens = pad(token_seqs, batch_first=True, padding_value=pad_id)
    padded_labels = pad(label_seqs, batch_first=True, padding_value=-100)

    # Positions holding the pad token id are masked out of attention.
    mask = (padded_tokens != pad_id).long()

    return {
        "input_ids": padded_tokens,
        "labels": padded_labels,
        "attention_mask": mask,
        "pixel_values": stacked_images,
    }


# Load dataset
import csv

from torch.utils.data import WeightedRandomSampler
from tqdm import tqdm

# Run settings, gathered in one place so they are easy to find and tune.
RUN_PREFIX = "SAMPLINGEOSLABEL"  # prefix shared by every checkpoint / CSV of this run
NUM_EPOCHS = 5
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
MAX_GRAD_NORM = 1.0
SAMPLER_SEED = 42  # makes the weighted sampling order reproducible across runs

AVG_LOSS_CSV = f"{RUN_PREFIX}average_epoch_losses.csv"
BATCH_LOSS_CSV = f"{RUN_PREFIX}batch_losses.csv"


def _write_loss_csv(path, header, rows):
    """Overwrite ``path`` with one ``header`` row followed by ``rows``."""
    with open(path, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(rows)


avg_losses = []    # (epoch number, mean loss of that epoch)
batch_losses = []  # (global batch index, loss of that batch)
hf_dataset = load_from_disk("modified_hl_final")
train_dataset = ImageCaptioningDataset(hf_dataset, processor)
if len(train_dataset) == 0:
    raise ValueError("Training dataset is empty: no captions were found to train on.")

bins = train_dataset.get_bins()
bin_sizes = {k: len(v) for k, v in bins.items()}
print(bin_sizes)

# Sample weights: 1 / size of the sample's bin, so every confidence bin
# carries the same total sampling weight regardless of how many samples it has.
weights = torch.tensor(
    [1.0 / bin_sizes[s["bin"]] for s in train_dataset.samples],
    dtype=torch.double,
)

sampler = WeightedRandomSampler(
    weights,
    num_samples=len(train_dataset),
    replacement=True,
    generator=torch.Generator().manual_seed(SAMPLER_SEED),
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    collate_fn=collate_fn,
)

# -----------------------------
# Training loop
# -----------------------------
# Only parameters that require gradients are optimized and clipped; frozen
# parameters never receive a gradient, so handing them over is wasted work.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)
model.train()

# `device` is defined in the config section above and reused here.
steps_per_epoch = len(train_dataloader)

for epoch in range(NUM_EPOCHS):
    print(f"\n Epoch {epoch + 1}")
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}", leave=False)
    epoch_losses = []

    for idx, batch in enumerate(progress_bar):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device, torch.float16)
        labels = batch["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels,
        )

        loss = outputs.loss
        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        global_batch_idx = epoch * steps_per_epoch + idx
        batch_losses.append((global_batch_idx, loss_value))
        epoch_losses.append(loss_value)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad()

        # Update tqdm bar
        progress_bar.set_description(f"Epoch {epoch + 1} | Batch {idx + 1}")
        progress_bar.set_postfix({"loss": loss_value})

    avg_loss = sum(epoch_losses) / len(epoch_losses)
    avg_losses.append((epoch + 1, avg_loss))

    print(f"Epoch {epoch + 1} average loss: {avg_loss:.4f}")
    model.save_pretrained(f"{RUN_PREFIX}blip2-finetuned-full-epoch{epoch + 1}")
    processor.save_pretrained(f"{RUN_PREFIX}blip2-finetuned-full-epoch{epoch + 1}")

    # Rewrite both loss logs after every epoch so an interrupted run still
    # leaves the loss history of all completed epochs on disk.
    _write_loss_csv(AVG_LOSS_CSV, ["Epoch", "Average Loss"], avg_losses)
    _write_loss_csv(BATCH_LOSS_CSV, ["Batch", "Loss"], batch_losses)

print(f"Saved average losses to {AVG_LOSS_CSV}")
print(f"Saved batch losses to {BATCH_LOSS_CSV}")



2025-05-24 11:57:33.415068: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-24 11:57:33.427210: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748080653.442007 3172877 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748080653.446457 3172877 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1748080653.457561 3172877 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

trainable params: 5,242,880 || all params: 3,750,004,736 || trainable%: 0.1398
{5: 75309, 4: 32592, 2: 2219, 3: 10552, 1: 799}

 Epoch 1


                                                                                      

Epoch 1 average loss: 1.5043

 Epoch 2


Epoch 2 | Batch 1561:  21%|██        | 1561/7592 [16:34<1:06:52,  1.50it/s, loss=1.62] 

In [None]:
# Count every caption across the three axes; an axis missing from a record's
# "modified_captions" contributes zero.
total_captions = sum(
    len(example["modified_captions"].get(axis, []))
    for example in dataset
    for axis in ("scene", "action", "rationale")
)
print("Total captions:", total_captions)
