<a href="https://colab.research.google.com/github/merveenoyan/smollm/blob/main/vision/finetuning/Smol_VLM_FT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune SmolVLM on Visual Question Answering using Consumer GPU with QLoRA

In this notebook we will fine-tune SmolVLM on the VQAv2 dataset. You can also use this notebook to fine-tune Idefics3, since both models share the same model class and architecture.

We will use some techniques in this notebook that let you fine-tune the model on an L4 with a batch size of 4 using only around 16.4 GB of VRAM. We tested the notebook in that setup, but since we had access to an A100, it was last run on an A100.

In [5]:
!pip install -q accelerate datasets peft bitsandbytes tensorboard torchvision num2words

In [6]:
!pip install -q flash-attn --no-build-isolation

We will push our model to the Hub, so we need to authenticate ourselves.

In [8]:
# from huggingface_hub import notebook_login

# notebook_login()

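In this notebook we will not do full fine-tuning but use the QLoRA method, which attaches a trainable adapter to a quantized version of the model, saving memory. To use QLoRA, set `USE_QLORA` to True. If you want to do LoRA, set `USE_QLORA` to False and `USE_LORA` to True. If you want to do full fine-tuning, set `USE_LORA` and `USE_QLORA` to False. The cell below was last run with both flags set to False, so switch `USE_QLORA` to True to follow the QLoRA path.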
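As a minimal sketch of how the two flags combine (the helper `training_mode` is hypothetical and not part of the notebook or of `peft`): the loading code in the next cell checks `USE_QLORA or USE_LORA` first and only quantizes when `USE_QLORA` is set, so `USE_QLORA=True` selects QLoRA regardless of `USE_LORA`.

```python
def training_mode(use_lora: bool, use_qlora: bool) -> str:
    """Hypothetical helper: name the fine-tuning mode the two flags select."""
    if use_qlora:
        return "qlora"  # quantized base model plus LoRA adapter
    if use_lora:
        return "lora"   # unquantized base model plus LoRA adapter
    return "full"       # no adapter, all unfrozen weights are trained


for flags in [(False, False), (True, False), (False, True), (True, True)]:
    print(flags, "->", training_mode(*flags))
```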

In [9]:
import torch
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration

USE_LORA = False
USE_QLORA = False
SMOL = True

model_id = "HuggingFaceTB/SmolVLM-Base" if SMOL else "HuggingFaceM4/Idefics3-8B-Llama3" # original one
model_id = "HuggingFaceTB/SmolVLM-Instruct"


processor = AutoProcessor.from_pretrained(
    model_id
)

if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian"
    )
    lora_config.inference_mode = False
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=bnb_config if USE_QLORA else None,
        _attn_implementation="flash_attention_2",
        device_map="auto"
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    print(model.get_nb_trainable_parameters())
else:
    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        # _attn_implementation="flash_attention_2",
    ).to("cuda")

    # if you'd like to only fine-tune LLM
    for param in model.model.vision_model.parameters():
        param.requires_grad = False

Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 21399.51it/s]
`torch_dtype` is deprecated! Use `dtype` instead!


As is, the model holds 2.7 GB of GPU RAM 💗

## Loading the dataset and Preprocessing

We will load a small portion of the VQAv2 dataset. We use only a small portion of the dataset for educational purposes.

In [None]:
from datasets import load_dataset
ds = load_dataset('merve/vqav2-small', trust_remote_code=True)

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'merve/vqav2-small' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


In [None]:
split_ds = ds["validation"].train_test_split(test_size=0.5)
train_ds = split_ds["train"]

In [None]:
train_ds

Dataset({
    features: ['multiple_choice_answer', 'question', 'image'],
    num_rows: 10717
})

Let's write our data collating function. We will apply the prompt template so that questions and answers appear together and the model can learn to answer. Then we pass the formatted prompts and images to the processor, which processes both.

In [10]:
# Look up the id of the <image> special token so it can be masked in the labels
image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]

def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image = example["image"]
        if image.mode != 'RGB':
            image = image.convert('RGB')
        question = example["question"]
        answer = example["multiple_choice_answer"]
        # One user turn (instruction, image, question) and one assistant turn (answer)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Answer briefly."},
                    {"type": "image"},
                    {"type": "text", "text": question}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": answer}
                ]
            }
        ]
        text = processor.apply_chat_template(messages, add_generation_prompt=False)
        texts.append(text.strip())
        images.append([image])

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    # Labels are a copy of the inputs; padding and image tokens are excluded from the loss
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch
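
To make the label masking concrete, here is a small pure-Python sketch of the same idea on a toy batch. The token ids (`PAD_ID = 0`, `IMAGE_ID = 9`) and the helper `mask_labels` are made up for illustration; the real ids come from the processor. Positions set to `-100` are skipped by PyTorch's cross-entropy loss (its default `ignore_index`), so padding and image placeholder tokens do not contribute to the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
PAD_ID = 0           # hypothetical pad token id
IMAGE_ID = 9         # hypothetical <image> token id


def mask_labels(input_ids, pad_id=PAD_ID, image_id=IMAGE_ID):
    """Copy input_ids into labels, replacing pad and image tokens with -100."""
    return [
        [IGNORE_INDEX if tok in (pad_id, image_id) else tok for tok in row]
        for row in input_ids
    ]


toy_batch = [
    [9, 9, 5, 6, 7],  # two image tokens followed by text tokens
    [9, 5, 6, 0, 0],  # shorter sequence, padded with two pad tokens
]
labels = mask_labels(toy_batch)
print(labels)
```

Text tokens keep their ids, so in this sketch, as in `collate_fn`, both the question and the answer tokens remain part of the training target.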

## Inference

In [11]:
from PIL import Image
import requests

image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
bee_image = Image.open(requests.get(image_url, stream=True).raw)
bee_image = bee_image.resize((384, 384*4))

image_url = "https://www.usaoncanvas.com/images/low_res_image.jpg"
woman_image = Image.open(requests.get(image_url, stream=True).raw)
woman_image = woman_image.resize((384, 100))

# image_url = "https://images.unsplash.com/photo-1590272456521-1bbe160a18ce?q=80&w=627&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"

# Load the image from URL

# Resize the image to 384x384
# image = image.resize((384, 384))

# Check image resolution
width, height = woman_image.size
print(f"Image resolution: {width} x {height}")

messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "Compare the following two images:"},
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Compare the following two images:"},
            # {"type": "image", "image": woman_image},
            # {"type": "image", "image": bee_image},
        ]
    },
    # {
    #     "role": "assistant",
    #     "content": [
    #         {"type": "text", "text": "They are the same."}
    #     ]
    # },
    # {
    #     "role": "user",
    #     "content": [
    #         {"type": "text", "text": "was your previous response correct?"},
    #     ]
    # },
]
# messages = [
#     {
#         "role": "user",
#         "content": [
#             {"type": "image", "url": woman_image},
#             {"type": "text", "text": "Can you describe the image? "},
#         ]
#     },
# ]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

# Count tokens
num_tokens = inputs['input_ids'].shape[1]
print(f"Number of tokens before generation: {num_tokens}")

# Count image tokens specifically
num_image_tokens = (inputs['input_ids'] == image_token_id).sum().item()
print(f"Number of image tokens before generation: {num_image_tokens}")

generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=64)
generated_texts = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True,
)

print(f"Number of tokens after generation: {generated_ids.shape[1]}")
print(generated_texts[0])


Image resolution: 384 x 100
Number of tokens before generation: 24
Number of image tokens before generation: 0
Number of tokens after generation: 52
System: Compare the following two images:
User: Compare the following two images:
Assistant: The first image shows a group of people standing in a field, while the second image shows a group of people sitting in a circle.


In [47]:
# Check processor image configuration
print("Processor image configuration:")
if hasattr(processor, 'image_processor'):
    image_proc = processor.image_processor
    print(f"Image processor: {type(image_proc).__name__}")
    if hasattr(image_proc, 'size'):
        print(f"Target size: {image_proc.size}")
    if hasattr(image_proc, 'max_width'):
        print(f"Max width: {image_proc.max_width}")
    if hasattr(image_proc, 'max_height'):
        print(f"Max height: {image_proc.max_height}")
    print(f"All image processor attributes: {dir(image_proc)}")


Processor image configuration:
Image processor: Idefics3ImageProcessor
Target size: {'longest_edge': 1536}
All image processor attributes: ['__backends', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slotnames__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_create_repo', '_get_files_timestamps', '_pad_image', '_processor_class', '_set_processor_class', '_upload_modified_files', 'center_crop', 'do_convert_rgb', 'do_image_splitting', 'do_normalize', 'do_pad', 'do_rescale', 'do_resize', 'fetch_images', 'from_dict', 'from_json_file', 'from_pretrained', 'get_image_processor_dict', 'get_number_of_image_patches', 'image_mean', 'image_processor_type', 'image_std', 'is_fast', 'max_image_size', 'mod

In [54]:
# Check what happens during processing
# Note: `image` must already be a PIL image; it is not defined in the cells shown above
print("\nBefore processing:")
print(f"Original image size: {image.size}")

# Process just the image to see transformation
processed = processor.image_processor(images=image, return_tensors="pt")
if 'pixel_values' in processed:
    pixel_shape = processed['pixel_values'].shape
    print(f"Processed image tensor shape: {pixel_shape}")
    # Many vision transformers return [batch, channels, height, width]; this processor
    # returns a 5-D tensor [batch, num_images, channels, height, width], so the 4-D
    # branch below does not fire here
    if len(pixel_shape) == 4:
        print(f"Processed image dimensions: {pixel_shape[2]} x {pixel_shape[3]}")


Before processing:
Original image size: (627, 1114)
Processed image tensor shape: torch.Size([1, 13, 3, 384, 384])
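
The second dimension of 13 appears consistent with the processor splitting the image into tiles: with a `longest_edge` of 1536 (printed above) and 384 x 384 tiles, a 627 x 1114 image gives 3 x 4 tiles plus one downscaled global view. The sketch below is a simplified approximation of that arithmetic, not the actual `Idefics3ImageProcessor` code; `estimate_num_views` is a hypothetical helper and the rounding details are assumptions.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def estimate_num_views(width: int, height: int,
                       longest_edge: int = 1536, tile: int = 384) -> int:
    """Approximate how many tile-sized views an image is split into.

    Assumes the image is rescaled so its longest side equals `longest_edge`,
    each side is rounded up to a multiple of `tile`, and one global view is
    appended whenever the image is split into more than one tile.
    """
    if width >= height:
        new_w, new_h = longest_edge, round(height * longest_edge / width)
    else:
        new_w, new_h = round(width * longest_edge / height), longest_edge
    cols = max(1, ceil_div(new_w, tile))
    rows = max(1, ceil_div(new_h, tile))
    tiles = cols * rows
    return tiles + 1 if tiles > 1 else tiles


print(estimate_num_views(627, 1114))  # 3 * 4 tiles + 1 global view = 13
```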


In [34]:
# Test behavior when exceeding max context length
print("Testing context length overflow behavior...\n")

# Create a very long prompt that definitely exceeds context limits
very_long_text = "Can you describe this image? " * 5000  # 5000 repetitions of a short question

messages = [
    {
        "role": "user", 
        "content": [
            {"type": "image", "url": woman_image},
            {"type": "text", "text": very_long_text},
        ]
    },
]

try:
    # Process the very long input
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    
    print(f"Input tokens after processing: {inputs['input_ids'].shape[1]}")
    print(f"Max model context: {getattr(model.config, 'max_position_embeddings', 'Not found')}")
    
    # Try to generate
    generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=64)
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
    
    print("✅ Generation successful!")
    print(f"Number of tokens after generation: {generated_ids.shape[1]}")
    print(f"Response: {generated_texts[0][-200:]}")  # Show last 200 chars
    
except Exception as e:
    print(f"❌ Error occurred: {type(e).__name__}: {e}")

Testing context length overflow behavior...

Input tokens after processing: 31557
Max model context: Not found
✅ Generation successful!
Number of tokens after generation: 31621
Response: an you describe this image? Can you describe this image? Can you describe this image? Can you describe this image? 
Assistant: < < < < < < < < < < < < < < < < < < < < < < < < < < < < < < < < < < < < <
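
The `Not found` above suggests that `max_position_embeddings` is not a top-level attribute of this model's config. For vision-language models in `transformers`, the limit usually lives on a nested text config instead. The sketch below shows one way to look it up; `find_max_positions` is a hypothetical helper, `text_config` as the nested attribute name is an assumption, and the `SimpleNamespace` objects with the value 8192 are stand-ins rather than real model configs.

```python
from types import SimpleNamespace
from typing import Optional


def find_max_positions(config) -> Optional[int]:
    """Return max_position_embeddings from the config or its nested text config."""
    value = getattr(config, "max_position_embeddings", None)
    if value is not None:
        return value
    text_config = getattr(config, "text_config", None)  # assumed attribute name
    return getattr(text_config, "max_position_embeddings", None)


# Stand-in configs with made-up values, for illustration only
flat = SimpleNamespace(max_position_embeddings=8192)
nested = SimpleNamespace(text_config=SimpleNamespace(max_position_embeddings=8192))
empty = SimpleNamespace()

print(find_max_positions(flat), find_max_positions(nested), find_max_positions(empty))
```

In the notebook this could be called as `find_max_positions(model.config)` and compared with `inputs['input_ids'].shape[1]` before calling `generate`, to flag prompts like the one above whose generation degenerates into repeated `<` tokens.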


## Training

We can now initialize `TrainingArguments` and pass them to `Trainer`.

Some notes:
- With the setup below, 8-bit QLoRA uses around 16.4 GB of VRAM (beautiful, fits comfortably inside an L4, Colab free tier).
- We use gradient accumulation to simulate a larger batch size.
- We also save memory on intermediate activations by using gradient checkpointing.

**Disclaimer:**
The techniques here aren't a free lunch. The latter two add extra compute to training and therefore slow it down a bit (for reference, on two A100s with a batch size of 16, training took 2 hrs 43 mins with 4 gradient accumulation steps; disabling gradient accumulation reduced that to 2 hrs 35 mins).
If you want to speed things up, you can experiment: reduce to 4-bit precision and use a higher batch size. Note that 4-bit might result in the model learning less.

In [6]:
from transformers import TrainingArguments, Trainer

model_name = model_id.split("/")[-1]

training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    optim="paged_adamw_8bit", # for 8-bit, keep this, else adamw_hf
    bf16=True, # underlying precision for 8bit
    output_dir=f"./{model_name}-vqav2",
    hub_model_id=f"{model_name}-vqav2",
    report_to="tensorboard",
    remove_unused_columns=False,
    gradient_checkpointing=True
)
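
As a rough sketch of what these settings imply: the effective batch size is the per-device batch size times the gradient accumulation steps times the number of devices. The calculation below assumes a single GPU and the 10,717 training rows shown earlier; `Trainer`'s exact step count can differ slightly depending on how the last partial batch is handled.

```python
import math

per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_devices = 1          # assumption: single GPU
num_train_rows = 10717   # from the train_ds printout

# Samples that contribute to each optimizer update
effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
# Approximate number of optimizer updates in one epoch
optimizer_steps_per_epoch = math.ceil(num_train_rows / effective_batch_size)

print(f"Effective batch size: {effective_batch_size}")
print(f"Approx. optimizer steps per epoch: {optimizer_steps_per_epoch}")
```

Under these assumptions, `warmup_steps=50` covers only a small part of the single epoch, and `save_steps=250` would trigger checkpoints at about steps 250 and 500.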


In [12]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [13]:
trainer.train()



Step,Training Loss
25,4.5418
50,0.3751


KeyboardInterrupt: 

In [None]:
trainer.push_to_hub()