# Fine-tune SmolVLM2 2b with LoRA

In [1]:
!pip install -q --upgrade accelerate datasets peft bitsandbytes pyav num2words

In [2]:
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [3]:
import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("wandb")

wandb.login(key=wandb_key)
wandb.init(project="kaggle", name="smolvlm2b_finetune_lora")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhmzbo[0m ([33mhmzbo-vektrai[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [4]:
import torch
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoProcessor, BitsAndBytesConfig, AutoModelForImageTextToText
import os


USE_LORA = True
USE_QLORA = False
SMOL = False

model_id = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" if SMOL else "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

processor = AutoProcessor.from_pretrained(
    model_id
)

if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian"
    )
    lora_config.inference_mode = False
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        quantization_config=bnb_config if USE_QLORA else None,
        #_attn_implementation="flash_attention_2",
        device_map="auto"
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    print(model.get_nb_trainable_parameters())
else:
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        #_attn_implementation="flash_attention_2",
    ).to("cuda")

    # if you'd like to only fine-tune LLM
    for param in model.model.vision_model.parameters():
        param.requires_grad = False

peak_mem = torch.cuda.max_memory_allocated()
print(f"The model as is is holding: {peak_mem / 1024**3:.2f} of GPU RAM")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

(11269248, 2258054128)
The model as is is holding: 7.99 of GPU RAM


In [6]:
model = model.to(dtype=torch.bfloat16)  # Explicitly set dtype

In [8]:
from datasets import load_dataset

ds = load_dataset("TIGER-Lab/VideoFeedback", "real")

In [9]:
split_ds = ds["train"].train_test_split(test_size=0.5)
train_ds = split_ds["train"]

In [10]:
del split_ds, ds

In [11]:
print(f"prompt:  {train_ds[0]['text prompt']}, video: {train_ds[0]['video link']}")

prompt:  we zoom in on the odd figure, video: https://huggingface.co/datasets/hexuan21/VideoFeedback-videos-mp4/resolve/main/p/p009305.mp4


In [12]:
from torch.nn.utils.rnn import pad_sequence

image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]

def collate_fn(examples):
    instances = []
    for example in examples:
        prompt = example["text prompt"]

        user_content = [{"type": "text", "text": "Caption the video."}]
        user_content.append({"type": "video", "path": example["video link"]})

        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": [{"type": "text", "text": f"{prompt}"}]}
        ]

        instance = processor.apply_chat_template(messages, add_generation_prompt=False,
                                                 tokenize=True, return_dict=True, return_tensors="pt").to("cuda").to(model.dtype)
        instances.append(instance)


    input_ids = pad_sequence(
        [inst["input_ids"].squeeze(0) for inst in instances],
        batch_first=True,
        padding_value=processor.tokenizer.pad_token_id
    )
    attention_mask = pad_sequence(
        [inst["attention_mask"].squeeze(0) for inst in instances],
        batch_first=True,
        padding_value=0
    )
    labels = pad_sequence(
        [inst["input_ids"].squeeze(0).clone() for inst in instances],
        batch_first=True,
        padding_value=-100
    )

    labels[labels == image_token_id] = -100

    out = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


    # Step 1: figure out maximum frames, height, width across the batch
    pvs = [inst["pixel_values"].squeeze(0) for inst in instances if "pixel_values" in inst]
    if pvs:  # there is at least one non-None pixel_values
        max_frames = max(pv.shape[0] for pv in pvs)
        max_h = max(pv.shape[-2] for pv in pvs)
        max_w = max(pv.shape[-1] for pv in pvs)
    else:
        max_h = max_w = processor.video_size['longest_edge']
        max_frames = 1

    padded_pixel_values_list = []
    for ex in instances:
        pv = ex.get("pixel_values", None).squeeze(0)

        if pv is None:
            # text-only => fill pixel data + mask with zeros
            shape_pv = (max_frames, 3, max_h, max_w)
            padded_pv = torch.zeros(shape_pv, dtype=torch.float32)
        else:
            f, c, h, w = pv.shape
            # Prepare final storage
            padded_pv = torch.zeros(
                (max_frames, c, max_h, max_w),
                dtype=pv.dtype,
                device=pv.device
            )
            padded_pv[:f, :, :h, :w] = pv
        padded_pixel_values_list.append(padded_pv)

    out["pixel_values"] = torch.stack(padded_pixel_values_list, dim=0)
    return out

In [13]:
from transformers import TrainingArguments, Trainer

model_name = model_id.split("/")[-1]

training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    optim="adamw_hf", # for 8-bit, keep paged_adamw_8bit, else adamw_hf
    bf16=True,
    output_dir=f"./{model_name}-video-feedback",
    remove_unused_columns=False,
    report_to="wandb",
    dataloader_pin_memory=False
)


In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
)

No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [15]:
trainer.train()



Step,Training Loss
25,3.7957
50,1.2469
75,0.3383
100,0.2607
125,0.2474
150,0.2543
175,0.2123
200,0.2966
225,0.2479
250,0.2666


TrainOutput(global_step=2000, training_loss=0.2845357277393341, metrics={'train_runtime': 11532.2088, 'train_samples_per_second': 0.173, 'train_steps_per_second': 0.173, 'total_flos': 7980278517236160.0, 'train_loss': 0.2845357277393341, 'epoch': 1.0})

In [16]:
messages = [{"role": "user",
                 "content": [{"type": "text", "text": "Describe the video in details."},
                  {"type": "video", "path": "https://huggingface.co/datasets/hexuan21/VideoFeedback-videos-mp4/resolve/main/p/p000304.mp4"}]}]


inputs = processor.apply_chat_template(messages, add_generation_prompt=True,
                                          tokenize=True, return_dict=True, return_tensors="pt").to("cuda").to(model.dtype)

generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=64)
generated_texts = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True,
)

print(generated_texts[0])

User: Describe the video in details.You are provided the following series of three frames from a 0:00:03 [H:MM:SS] video.

Frame from 00:00:
Frame from 00:01:
Frame from 00:02:


Assistant: a woman in a white shirt walks up to the podium.
