In [1]:
# For Colab: Run **once** at the top of the notebook.
# Remove the preinstalled torch stack so the pinned nightly build below replaces it.
!pip uninstall -y torch torchvision torchaudio
!pip install --pre --index-url https://download.pytorch.org/whl/nightly/cu124 \
            torch==2.7.0.dev20250310+cu124 torchvision torchaudio

# NOTE(review): accelerate and safetensors are installed without a version pin.
!pip install peft==0.11.1 transformers==4.45.0 accelerate safetensors
!pip install -qU "datasets>=2.18.0" "fsspec<2024.0"
# NOTE(review): loralib and einops are likewise unpinned.
!pip install -qU bitsandbytes==0.43.3 loralib einops "xformers<0.0.27"

# loralib: A PyTorch implementation of Low-Rank Adaptation (LoRA), a parameter-efficient approach to adapt a large pre-trained deep learning model.
# einops: A library that simplifies tensor operations.
# xformers: A collection of composable Transformer building blocks.

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/nightly/cu124
Collecting torch==2.7.0.dev20250310+cu124
  Downloading https://download.pytorch.org/whl/nightly/cu124/torch-2.7.0.dev20250310%2Bcu124-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/nightly/cu124/torchvision-0.22.0.dev20250226%2Bcu124-cp311-cp311-linux_x86_64.whl.metadata (6.2 kB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/nightly/cu124/torchaudio-2.6.0.dev20250226%2Bcu124-cp311-cp311-linux_x86_64.whl.metadata (6.6 kB

In [2]:
# Imports
import os
import bitsandbytes as bnb # A lightweight wrapper by Hugging Face (🤗) around CUDA custom functions, particularly 8-bit optimizers and quantization functions. It’s used to handle the quantization process in QLoRA.
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import transformers
from datasets import load_dataset
import huggingface_hub
from peft import ( # A library by 🤗 that enables parameter efficient fine tuning.
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    LlavaNextProcessor, LlavaNextForConditionalGeneration,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

# Sanity check: report whether torch can see a CUDA device in this runtime.
print(torch.cuda.is_available())

True


In [3]:
# Login to hugging face
# The access token is read from the Colab secret named 'HF' rather than being written in the notebook.
# NOTE(review): a later cell's saved output warns that the secret `HF_TOKEN` does not exist;
# the secret used here is called 'HF', so that warning is presumably about the name only — confirm.
from google.colab import userdata
key = userdata.get('HF')
huggingface_hub.login(key)
print("Login Successful")

Login Successful


In [3]:
# Load train dataset
# streaming=False loads the split as an indexable dataset; later cells call len() on it
# and read rows by integer index.
dataset = load_dataset('hamzamooraj99/PMC-VQA-1', split='train', streaming=False)
print(dataset)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/40 [00:00<?, ?it/s]

Dataset({
    features: ['Figure_path', 'Question', 'Answer', 'Choice A', 'Choice B', 'Choice C', 'Choice D', 'Answer_label', 'image'],
    num_rows: 154253
})


In [5]:
# Mount Google Drive at /content/drive; the checkpoint and Trainer output paths used by
# later cells live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Quantization Config
# QLoRA default recipe: 4-bit NF4 weights, double quantization, fp16 compute.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,                     # load the base weights in 4-bit
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization data type
    bnb_4bit_use_double_quant=True,        # nested quantization of the quantization constants
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for computation
)

In [5]:
# Load checkpoint model instead
# Resumes from a saved LoRA adapter: base model (4-bit) first, then the adapter on top.
ckpt_dir = "/content/drive/MyDrive/llava13b-batch-FT/checkpoint-5000"

# 1) Read the adapter's config; it records which base model the adapter belongs to.
# https://huggingface.co/docs/peft/en/package_reference/config#peft.PeftConfig
peft_cfg = PeftConfig.from_pretrained(ckpt_dir)  # contains base_model_name_or_path

# 2) Get the frozen, quantized base model first.
#    `trust_remote_code` is deliberately not passed: it only applies to Auto classes and
#    transformers ignores it (with a warning) for LlavaNextForConditionalGeneration.
base_model = LlavaNextForConditionalGeneration.from_pretrained(
    peft_cfg.base_model_name_or_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    quantization_config=bnb_cfg,
    device_map="auto",
)

base_model = prepare_model_for_kbit_training(base_model)

# 3) Attach the LoRA adapter and load its weights; is_trainable=True keeps the adapter
#    weights trainable so fine-tuning can continue from this checkpoint.
model = PeftModel.from_pretrained(
    base_model,
    ckpt_dir,
    is_trainable=True,
)
processor = LlavaNextProcessor.from_pretrained(peft_cfg.base_model_name_or_path)



The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Some kwargs in processor config are unused and will not have any effect: num_additional_image_tokens. 


In [7]:
# Base (not fine-tuned) model.
# NOTE(review): this cell rebinds `model` and `processor`, which the checkpoint-loading
# cell also assigns — presumably only one of the two cells is meant to be run per session.
model_id = "llava-hf/llava-v1.6-vicuna-13b-hf"

# Load Model
# `trust_remote_code` is deliberately not passed: it only applies to Auto classes and
# transformers ignores it (with a warning) for LlavaNextForConditionalGeneration.
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    quantization_config=bnb_cfg,
    device_map="auto",  # let the loader decide device placement
)
print('Model Loaded.')

processor = LlavaNextProcessor.from_pretrained(model_id)
#print("Skipped cell.")

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


config.json:   0%|          | 0.00/1.34k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/77.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/6 [00:00<?, ?it/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/2.02G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

Model Loaded.


preprocessor_config.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/23.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/176 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

Some kwargs in processor config are unused and will not have any effect: num_additional_image_tokens. 


In [None]:
# LoRA Configuration for initial training
# Guide: https://medium.com/@amodwrites/a-definitive-guide-to-qlora-fine-tuning-falcon-7b-with-peft-78f500a1f337
# NOTE(review): the saved output of this cell is "Skipped cell.", so it was presumably not
# executed in the recorded session (the adapter came from a checkpoint instead) — confirm.
'''
from peft library
The model is then prepared for QLoRA using the `prepare_model_for_kbit_training()` function.
This function initializes the model for QLoRA by setting up the necessary configurations.

`r`: The rank of the update matrices. Lower rank results in smaller update matrices with fewer trainable parameters.
`lora_alpha`: LoRA scaling factor.
`target_modules`: The modules (for example, attention blocks) to apply the LoRA update matrices.
`lora_dropout`: Dropout probability of the LoRA layers.
`bias`: Specifies if the bias parameters should be trained. Can be ‘none’, ‘all’ or ‘lora_only’.
'''

# Rebinds `model`: expects the quantized base model loaded by an earlier cell.
model = prepare_model_for_kbit_training(model)

lora_cfg = LoraConfig(                              # typical 13 B recipe
    r               = 64,
    lora_alpha      = 16,
    lora_dropout    = 0.05,
    target_modules  = ["q_proj","k_proj","v_proj","o_proj",
                       "gate_proj","up_proj","down_proj"],  # MLP + attention
    bias            = "none",
    task_type       = "CAUSAL_LM",
)
# Wrap the prepared model with fresh LoRA adapters described by lora_cfg.
model = get_peft_model(model, lora_cfg)


Skipped cell.


In [6]:
from PIL import Image, ImageSequence

# Data Formatting Class
class VQADataset(torch.utils.data.Dataset):
    """
    Map-style dataset that turns one multiple-choice VQA row into a chat prompt.

    Each row of `hf_ds` is read through these keys:
      - image                 : the figure, returned untouched (presumably a PIL image)
      - Question              : str
      - Choice A / B / C / D  : str, one option per line in the prompt
      - Answer_label, Answer  : str, joined into the target "<label>: <answer>"

    Indexing returns a `(conversation, image, answer)` tuple.
    """

    def __init__(self, hf_ds, processor):
        # Keep the raw rows, the processor, and (for convenience) its tokenizer.
        self.ds = hf_ds
        self.proc = processor
        self.tok = processor.tokenizer

    def __len__(self):
        # One sample per underlying row.
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        picture = record["image"]

        # Build the user text: instruction, question, then the four options line by line.
        instruction = (
            "Based on the image and the caption, answer the following "
            "multiple-choice question by selecting the correct letter.\n"
        )
        question_line = f"Question: {record['Question']}\n"
        option_lines = "".join(
            f"{record[column]}\n"
            for column in ("Choice A", "Choice B", "Choice C", "Choice D")
        )
        user_text = instruction + question_line + option_lines

        # User turn carries text plus an image placeholder (the image itself is handed to
        # the processor later); the assistant turn holds only the "Answer: " lead-in.
        user_turn = {
            "role": "user",
            "content": [
                {"type": "text", "text": user_text},
                {"type": "image"},
            ],
        }
        assistant_turn = {
            "role": "assistant",
            "content": [
                {"type": "text", "text": "Answer: "},
            ],
        }
        conversation = [user_turn, assistant_turn]

        # Target string, e.g. "B: <answer text>".
        target = f"{record['Answer_label']}: {record['Answer']}"

        return conversation, picture, target

In [7]:
from torch.nn.utils.rnn import pad_sequence
# Notebook-level training dataset; the collate functions read the processor and
# tokenizer from it via ds.proc / ds.tok.
ds = VQADataset(dataset, processor)

def llava_collate(batch):
    """
    Collate `(conversation, image, answer)` samples into one training batch.

    The conversations are rendered with the chat template and encoded together with the
    images by the notebook-level `ds.proc`. Each answer string is tokenized without
    special tokens and appended to its encoded prompt; labels are -100 over the prompt
    part and the answer token ids over the answer part.

    Returns a dict with `pixel_values`, `input_ids`, `attention_mask`, `image_sizes`
    and `labels`.

    NOTE(review): answer ids are appended after the processor-padded prompt, so this
    presumably relies on the tokenizer padding prompts on the left — confirm.
    """
    ignore_index = -100  # label value used for positions that should not contribute to the loss

    conversations = [sample[0] for sample in batch]
    pictures = [sample[1] for sample in batch]

    rendered = ds.proc.apply_chat_template(conversations, add_generation_prompt=True)

    # Encode text and images in one processor call.
    encoded = ds.proc(
        images=pictures,
        text=rendered,
        return_tensors="pt",
        padding=True,  # pad text
        do_pad=True,   # pad patch dim to max in this batch
    )

    full_ids, full_masks, full_labels = [], [], []
    for row, sample in enumerate(batch):
        # Answer tokens, converted to a tensor once and reused below.
        answer = torch.tensor(ds.tok(f"{sample[2]}", add_special_tokens=False).input_ids)
        context_ids = encoded["input_ids"][row]

        # Attention mask: prompt mask followed by ones for the answer tokens.
        full_masks.append(torch.cat([encoded["attention_mask"][row], torch.ones_like(answer)]))
        # Labels: prompt positions masked out, answer positions kept.
        full_labels.append(torch.cat([torch.full_like(context_ids, ignore_index), answer]))
        # Model input: prompt tokens followed by the answer tokens.
        full_ids.append(torch.cat([context_ids, answer]))

    # Answers differ in length, so right-pad every sequence to the longest one in the batch.
    pad_id = ds.tok.pad_token_id
    return {
        "pixel_values": encoded["pixel_values"],
        "input_ids": pad_sequence(full_ids, batch_first=True, padding_value=pad_id),
        "attention_mask": pad_sequence(full_masks, batch_first=True, padding_value=0),
        "image_sizes": encoded["image_sizes"],
        "labels": pad_sequence(full_labels, batch_first=True, padding_value=ignore_index),
    }


In [None]:
# REMOVE BAD SAMPLES from train set
from torch.utils.data import Subset

train_ds = VQADataset(dataset, processor)

# Row indices to exclude: three blocks of 40 consecutive records (stop index exclusive).
rm = (
    set(range(10120, 10160))
    | set(range(11400, 11440))
    | set(range(17280, 17320))
)

# Every remaining index, in the original order.
keep = [idx for idx in range(len(train_ds)) if idx not in rm]

# View over the original dataset restricted to the kept indices.
train_subset = Subset(train_ds, keep)

In [None]:
# Training Loop with hugging face Trainer API
training_args  = TrainingArguments(
    output_dir   = "/content/drive/MyDrive/llava13b-batch-FT",
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 1,
    fp16 = True,
    gradient_checkpointing = True,
    save_steps = 100, # default is 500
    save_total_limit = 3, # deletes older checkpoints
    optim = "paged_adamw_32bit",
    logging_steps = 10,
    report_to = "none",
    num_train_epochs = 1,
    ignore_data_skip=False #True # will tell Trainer not to “fast-forward” the DataLoader
)

trainer = Trainer(
    model         = model,          # 13B LoRA-patched model from previous cell
    args          = training_args,
    train_dataset = train_subset,
    data_collator = llava_collate,
)

trainer.train(resume_from_checkpoint=True)
# trainer.train()

	save_steps: 100 (from args) != 500 (from trainer_state.json)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Step,Training Loss
4010,0.1761
4020,0.1706
4030,0.1677
4040,0.1289
4050,0.1516
4060,0.1495
4070,0.1478
4080,0.1689
4090,0.1851
4100,0.1505




In [8]:
def non_train_collate(batch):
    """
    Collate samples for inference: prompts and images only, no labels.

    Each element of `batch` is indexed as `sample[0]` (conversation) and `sample[1]`
    (image). Conversations are rendered with the chat template and encoded together
    with the images by the notebook-level `ds.proc`.

    Returns a dict with `pixel_values`, `input_ids`, `attention_mask` and `image_sizes`.
    """
    conversations, pictures = [], []
    for sample in batch:
        conversations.append(sample[0])
        pictures.append(sample[1])

    rendered = ds.proc.apply_chat_template(conversations, add_generation_prompt=True)

    # Encode text and images in one processor call.
    encoded = ds.proc(
        images=pictures,
        text=rendered,
        return_tensors="pt",
        padding=True,  # pad text
        do_pad=True,   # pad patch dim to max in this batch
    )

    # Pass through only the tensors the model consumes at generation time.
    wanted = ("pixel_values", "input_ids", "attention_mask", "image_sizes")
    return {name: encoded[name] for name in wanted}

In [None]:
# Load validation dataset
# Imported here as well so this cell does not depend on the training cells having been run.
from torch.utils.data import Subset

orig_val_dataset = load_dataset('hamzamooraj99/PMC-VQA-1', split='validation', streaming=False)

# remove bad samples: validation rows 120..130 (inclusive) are excluded
rm = set(range(120, 131))

# build a list of indices you want to keep
keep = [i for i in range(len(orig_val_dataset)) if i not in rm]  # drop bad records

# wrap the original dataset
val_dataset = Subset(orig_val_dataset, keep)
print(val_dataset)
total_amnt = len(val_dataset)
print(total_amnt)

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

<torch.utils.data.dataset.Subset object at 0x7bfad7ed4d90>
22684


In [9]:
# Load the test split as an indexable dataset.
# NOTE(review): the saved output under this cell (a Subset repr and 50000) is not printed
# by this code — presumably left over from an earlier version of the cell.
from torch.utils.data import Subset
test_dataset = load_dataset('hamzamooraj99/PMC-VQA-1', split='test', streaming=False)

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

<torch.utils.data.dataset.Subset object at 0x7e2101e681d0>
50000


In [10]:
import re

# Validation or Testing Loop
def validate_or_test(start, val_amnt, ds):
    """
    Score multiple-choice accuracy of the notebook-level `model` on `ds[start:val_amnt]`.

    For every sample the prompt is collated with `non_train_collate`, completed with
    `model.generate`, and the letter extracted from the completion is compared with the
    letter extracted from the reference answer.

    Args:
        start: index of the first sample to evaluate (inclusive).
        val_amnt: index at which to stop (exclusive).
        ds: indexable dataset whose items are `(conversation, image, answer)` tuples.
            Note that `non_train_collate` still takes its processor from the
            notebook-level `ds`, not from this argument.

    Returns:
        dict with `total`, `correct`, `no_ans`, `accuracy` and `no_ans_rate`. Rates are
        computed over the samples actually evaluated (`val_amnt - start`) and are 0.0
        when that range is empty.
    """
    correct = 0
    no_ans = 0
    total = max(val_amnt - start, 0)  # number of samples actually evaluated

    # Score in eval mode so dropout (including LoRA dropout) is disabled; the previous
    # mode is restored afterwards so a later training cell is unaffected.
    was_training = model.training
    model.eval()
    try:
        for i in range(start, val_amnt):
            sample = ds[i]  # fetched once; used for both the prompt and the reference

            # Get inputs
            inputs = non_train_collate([sample])
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Autoregressively complete prompt
            output = model.generate(**inputs, max_new_tokens=25)

            # The decoded text contains the prompt too; the completion is the part after
            # the second 'ASSISTANT:' marker. Fall back to the last segment instead of
            # raising if the text has fewer markers than expected.
            output_string = processor.decode(output[0], skip_special_tokens=True)
            segments = output_string.split('ASSISTANT:')
            generated_output = segments[2] if len(segments) > 2 else segments[-1]

            # Extract letter choice from output and answer
            letter = extract_selection(generated_output).strip()
            ans = extract_selection(sample[2]).strip()
            if letter == ans:
                correct += 1
            if letter == 'Z':  # 'Z' is the "no letter found" sentinel
                no_ans += 1

            # Step tracking, relative to `start` so the running percentage is correct
            done = i - start + 1
            if done % 10 == 0:
                print(f"{done} steps complete. {correct} were correct, that's {(correct/done)*100}%. (no_ans={no_ans})")
    finally:
        if was_training:
            model.train()

    accuracy = correct / total if total else 0.0
    no_ans_rate = no_ans / total if total else 0.0

    # Final totals
    print(f"Total Steps: {total}")
    print(f"Total Correct: {correct}")
    print(f"Total No Answer: {no_ans}")
    print(f"Final Accuracy: {accuracy}")
    print(f"Final No Answer Accuracy: {no_ans_rate}")

    return {
        "total": total,
        "correct": correct,
        "no_ans": no_ans,
        "accuracy": accuracy,
        "no_ans_rate": no_ans_rate,
    }



# Regex to extract letter choice
def extract_selection(output):
    """
    Return the answer letter that starts a line of `output`, or 'Z' if there is none.

    A letter counts only when it is the first non-whitespace character of a line, is in
    A-G, and stands alone (followed by punctuation, whitespace or end of text). Words
    that merely begin with one of those letters ("Based ...", "Answer: ...") are
    therefore not mistaken for a choice. A standalone leading "A" used as an article
    is still accepted as choice A.

    Args:
        output: text to search (model completion or reference answer).

    Returns:
        The single uppercase letter found, or the sentinel 'Z' when no line qualifies.
    """
    # \b after the letter rejects letters that are just the start of a longer word.
    match = re.search(r'^\s*([A-G])\b', output, re.MULTILINE)
    return match.group(1) if match else 'Z'

# Dataset and function call
# Wrap the raw test split and score samples with indices 0..999.
# To score the validation split instead, build the dataset from val_dataset:
#val_ds = VQADataset(val_dataset, processor)
test_ds = VQADataset(test_dataset, processor)
validate_or_test(0, 1000, test_ds)

# Validation @ 3000 training steps: 120 steps complete. 54 were correct, that's 45.0%. (no_ans=0)
'''
Test (Base Model w/additional prompt instructions):
Total Steps: 1000
Total Correct: 308
Total No Answer: 2
Final Accuracy: 0.308
Final No Answer Accuracy: 0.002
'''
###############################################
'''
Test (5k steps (20k samples) QLoRA fine-tuned):
Total Steps: 1000
Total Correct: 454
Total No Answer: 0
Final Accuracy: 0.454
Final No Answer Accuracy: 0.0
'''

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


10 steps complete. 6 were correct, that's 60.0%. (no_ans=0)
20 steps complete. 8 were correct, that's 40.0%. (no_ans=0)
30 steps complete. 11 were correct, that's 36.666666666666664%. (no_ans=0)
40 steps complete. 13 were correct, that's 32.5%. (no_ans=0)
50 steps complete. 18 were correct, that's 36.0%. (no_ans=0)
60 steps complete. 21 were correct, that's 35.0%. (no_ans=0)
70 steps complete. 25 were correct, that's 35.714285714285715%. (no_ans=0)
80 steps complete. 32 were correct, that's 40.0%. (no_ans=0)
90 steps complete. 36 were correct, that's 40.0%. (no_ans=0)
100 steps complete. 43 were correct, that's 43.0%. (no_ans=0)
110 steps complete. 45 were correct, that's 40.909090909090914%. (no_ans=0)
120 steps complete. 47 were correct, that's 39.166666666666664%. (no_ans=0)
130 steps complete. 49 were correct, that's 37.69230769230769%. (no_ans=0)
140 steps complete. 54 were correct, that's 38.57142857142858%. (no_ans=0)
150 steps complete. 59 were correct, that's 39.33333333333333

'\n\n'