# Fine-tune Gemma3n on videos

In this notebook, we will see how to fine-tune Gemma3n on a video dataset.

In [1]:
!pip install -U -q timm transformers trl peft datasets

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
!pip install opencv-python

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [3]:
!pip install tensorboard

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [None]:
from huggingface_hub import login
import os

# Read the Hugging Face access token from the HF_TOKEN environment variable
# instead of pasting it into the notebook, so the secret is never saved or
# committed together with the file.
# Create a token at https://huggingface.co/settings/tokens and export it as
# HF_TOKEN before starting Jupyter.
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()

# An empty string is not a usable token: pass None in that case so that
# `login` falls back to its interactive prompt.
login(token=HF_TOKEN or None)

In [5]:
import io
import os
import zipfile
import cv2  # Added
import numpy as np # Added

import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

from trl import (
    SFTConfig,
    SFTTrainer,
)

## Video preprocessing

In [6]:
from datasets import load_dataset, Video
import os

# Build the train/test splits in a single pipeline:
#   1. load the "test" split of blind-assist/walk (not streamed);
#   2. keep the video column undecoded (decode=False), so each entry is a
#      record with the file path rather than decoded frames;
#   3. keep only the first 50 rows so the notebook runs quickly end to end;
#   4. split those rows 90% / 10% into train and test with a fixed seed.
dataset = (
    load_dataset("blind-assist/walk", split="test")
    .cast_column("video", Video(decode=False))
    .select(range(50))
    .train_test_split(test_size=0.1, seed=42)
)

# Sanity check: report the size of each split and show one target sentence
train_split, test_split = dataset["train"], dataset["test"]
print(f"Dataset loaded! Total training examples: {len(train_split)}")
print(f"Dataset loaded! Total testing examples: {len(test_split)}")
print(f"Example data point: {train_split[0]['alter']}")


# from datasets import load_dataset
# from datasets.features import Video

# # Load with streaming to avoid downloading everything
# dataset = load_dataset(
#     "blind-assist/walk-train", 
#     split="train",
#     streaming=True
# ).cast_column("video", Video(decode=False))

# # Take only 50 examples
# dataset_list = list(dataset.take(50))

# # Convert back to Dataset object for train_test_split
# from datasets import Dataset
# dataset = Dataset.from_list(dataset_list)

## Now split
# dataset = dataset.train_test_split(test_size=0.1, seed=42)

# print(f"Training examples: {len(dataset['train'])}")
# print(f"Testing examples: {len(dataset['test'])}")
# print(f"Example data point: {dataset['train'][0]['alter']}")

Resolving data files:   0%|          | 0/1008 [00:00<?, ?it/s]

Dataset loaded! Total training examples: 45
Dataset loaded! Total testing examples: 5
Example data point: a vehicle is passing ahead, please move in the 3 - o'clock direction.


In [7]:
# Because the column was cast with decode=False, the 'video' field is a dict
# with the raw bytes (None here) and the local path of the mp4 file.
print(f"Example data point: {dataset['train'][0]['video']}")

Example data point: {'bytes': None, 'path': '/root/.cache/huggingface/hub/datasets--blind-assist--walk/snapshots/ca890433d36693e4643f60302e56a6c8622dcafc/test/20240918-youtube_short_081e0a96bac802b988a1db9df310ddd1_1min03s.mp4'}


In [8]:
# 'summary' is a free-text description of the scene; the collator does not read it.
print(f"Example data point: {dataset['train'][0]['summary']}")

Example data point: it is sunny today. while crossing the road, a white car and an electric bike pass by in the front. the electric bike rider wears a mask and black clothes. a yellow electric bike is parked at one o'clock direction. two approaching cars are at eleven o'clock direction.


In [9]:
# 'alter' is the navigation advice; the collator uses it as the assistant reply to learn.
print(f"Example data point: {dataset['train'][0]['alter']}")

Example data point: a vehicle is passing ahead, please move in the 3 - o'clock direction.


In [10]:
import cv2

# Smoke test: open the first training video with OpenCV and try to decode a
# single frame, reporting which step failed if something goes wrong.
video_path = dataset['train'][0]['video']['path']
cap = cv2.VideoCapture(video_path)

if not cap.isOpened():
    print(f"❌ Could not open video at: {video_path}")
else:
    ret, frame = cap.read()
    if not ret:
        print("❌ Connected, but could not read a frame. Check if the file is corrupted.")
    else:
        print(f"✅ Connection successful! Extracted a frame of shape: {frame.shape}")

# Free the capture handle whether or not the checks above passed
cap.release()

✅ Connection successful! Extracted a frame of shape: (1080, 1920, 3)


Sample 8 frames from the video

In [11]:
import cv2
from PIL import Image
import numpy as np

def downsample_video(video_path, num_frames=8):
    """Sample evenly spaced frames across the whole duration of a video.

    Args:
        video_path: Path of a video file that OpenCV can open.
        num_frames: Maximum number of frames to sample (default 8).

    Returns:
        A list of ``(PIL image, timestamp in seconds)`` tuples in playback
        order. The list is shorter than ``num_frames`` when the video has
        fewer frames than requested or when a frame cannot be decoded, and
        empty when the video reports no frames. Timestamps are 0.0 when the
        file reports no usable frame rate.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)

        if total_frames <= 0 or num_frames <= 0:
            return []

        # Evenly spaced positions from the first to the last frame. Truncating
        # to int can repeat a position when the video is shorter than
        # num_frames, so duplicates are dropped to avoid emitting the same
        # frame several times.
        positions = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        indices = sorted({int(position) for position in positions})

        frames = []
        for index in indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, image = vidcap.read()
            if not success:
                continue
            # OpenCV decodes to BGR; PIL expects RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            # Guard against a zero frame rate, which would give inf/nan timestamps
            timestamp = round(index / fps, 2) if fps > 0 else 0.0
            frames.append((pil_image, timestamp))

        return frames
    finally:
        # Always free the capture handle, including on the early return above
        vidcap.release()


print("8-frame extraction ready.")


8-frame extraction ready.


In [12]:
# Give the two splits convenient names. No columns are renamed or dropped
# here: every field of the dataset is kept as it was loaded.

# The collator defined later reads only 'video' (path to the mp4) and 'alter' (our target advice)
train_dataset = dataset["train"]
test_dataset = dataset["test"]

print(f"Ready for training with {len(train_dataset)} samples.")
print(f"Ready for testing with: {len(test_dataset)} samples.")

Ready for training with 45 samples.
Ready for testing with: 5 samples.


### Load the model

Make sure you have your Hugging Face token

In [13]:
# Checkpoint to fine-tune. The model and the processor must come from the
# same checkpoint, so the id is defined once and reused for both.
MODEL_ID = "google/gemma-3n-E4B-it"

# Load the weights in bfloat16. `torch_dtype` is deprecated in favour of
# `dtype` (the library prints a deprecation warning for the old name).
model = Gemma3nForConditionalGeneration.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Pad on the right-hand side when sequences are padded
processor.tokenizer.padding_side = "right"

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [14]:
import torch

# Summarise the PyTorch build and the CUDA devices visible to this kernel
cuda_ok = torch.cuda.is_available()
n_gpus = torch.cuda.device_count()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Number of GPUs: {n_gpus}")

if cuda_ok:
    bytes_per_gib = 1024**3
    for gpu in range(n_gpus):
        # Memory currently held by tensors vs. reserved by the caching allocator
        allocated_gib = torch.cuda.memory_allocated(gpu) / bytes_per_gib
        reserved_gib = torch.cuda.memory_reserved(gpu) / bytes_per_gib
        print(f"\nGPU {gpu}: {torch.cuda.get_device_name(gpu)}")
        print(f"  Memory allocated: {allocated_gib:.2f} GB")
        print(f"  Memory reserved: {reserved_gib:.2f} GB")

PyTorch version: 2.4.1+cu124
CUDA available: True
CUDA version: 12.4
Number of GPUs: 1

GPU 0: NVIDIA A40
  Memory allocated: 0.00 GB
  Memory reserved: 0.00 GB


In [15]:
# Ids of the tokenizer's special tokens; the collator below masks exactly
# these ids out of the training labels.
processor.tokenizer.all_special_ids

[2, 1, 3, 0, 262273, 256000, 255999, 262272, 262144, 262145]

Write our dataset collator. 

The collator also samples each video into frames, using the `downsample_video` helper written above. For better results, sample more frames.

In [16]:
def _build_conversation(example):
    """Turn one dataset row into a two-turn chat conversation.

    The user turn holds the task prompt followed by the sampled video frames,
    each preceded by a short caption with its timestamp. The assistant turn
    holds only the row's 'alter' advice, which is the text to be learned.
    """
    # The video column is loaded with decode=False, so the row carries the
    # file path rather than decoded frames.
    video_path = example["video"]["path"]
    frames = downsample_video(video_path)

    user_content = [
        {"type": "text", "text": "Based on this walking path video, provide the necessary navigation advice."}
    ]
    for image, timestamp in frames:
        user_content.append({"type": "text", "text": f"Frame at {timestamp}s:"})
        # Pass the PIL image directly (faster than saving to disk)
        user_content.append({"type": "image", "image": image})

    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": [{"type": "text", "text": example["alter"]}]},
    ]


def collate_fn(examples):
    """Collate dataset rows into a padded model batch with training labels.

    Every row in ``examples`` is converted into a conversation and the whole
    list is tokenized together, so no row is dropped when the batch size is
    larger than 1.

    Args:
        examples: List of dataset rows, each with a 'video' dict (holding the
            file 'path') and an 'alter' string.

    Returns:
        The processor output moved to the model's device, with an extra
        'labels' entry: a copy of 'input_ids' in which every special-token id
        is replaced by -100.
    """
    conversations = [_build_conversation(example) for example in examples]

    # Render the chat template and tokenize text and images in one call
    inputs = processor.apply_chat_template(
        conversations,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # Labels start as a copy of the inputs; positions that hold a special
    # token are set to -100.
    labels = inputs["input_ids"].clone()
    special_token_ids = torch.tensor(
        processor.tokenizer.all_special_ids, device=labels.device
    )
    labels[torch.isin(labels, special_token_ids)] = -100

    inputs["labels"] = labels
    return inputs

## Training

We use LoRA fine-tuning again to save space.

In [17]:
from peft import LoraConfig
# LoRA adapter settings, collected in one mapping and unpacked into the config
lora_settings = {
    "task_type": "CAUSAL_LM",
    # Which modules receive adapters and how large the adapters are
    "target_modules": "all-linear",
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    # Options that are explicitly switched off
    "bias": "none",
    "modules_to_save": None,
    "use_rslora": False,
    "use_dora": False,
}
peft_config = LoraConfig(**lora_settings)

In [18]:
# Turn gradient checkpointing on for the model before training.
# Kept for reference, the call that turns it off again:
# model.gradient_checkpointing_disable()
model.gradient_checkpointing_enable()

In [19]:
# Disable the generation cache in the model config for training.
# NOTE(review): presumably needed because the cache conflicts with gradient
# checkpointing — confirm; this setting is still in place when generating later.
model.config.use_cache = False

In [20]:
training_args = SFTConfig(
    # --- output and logging ---
    output_dir="./gemma-3n-blind-assist",
    report_to=["tensorboard"],
    logging_steps=10,
    # --- checkpointing ---
    # NOTE(review): the run below finishes after 24 optimizer steps, so a
    # save interval of 100 steps is never reached during training.
    save_steps=100,
    save_total_limit=2,  # upper bound on the number of checkpoints kept on disk
    # --- evaluation ---
    eval_strategy="epoch",
    per_device_eval_batch_size=1,
    # --- optimisation ---
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # 1 sample x 4 accumulation steps per update
    learning_rate=1e-05,
    num_train_epochs=2.0,
    bf16=True,
    gradient_checkpointing=True,
    # --- data handling ---
    # The raw rows go straight to collate_fn, which reads the 'video' and
    # 'alter' columns itself, so the dataset is neither pre-processed nor
    # stripped of columns.
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    # collate_fn already moves its tensors to the model's device
    dataloader_pin_memory=False,
    # max_seq_length=None,
    # push_to_hub=True,
)

In [21]:
# !nvidia-smi

In [22]:
# Evaluate on the held-out split unless evaluation is switched off in the config
eval_split = None if training_args.eval_strategy == "no" else dataset["test"]

# Wire the model, LoRA settings, data splits and the custom collator together
trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=peft_config,
    train_dataset=dataset["train"],
    eval_dataset=eval_split,
    data_collator=collate_fn,
    processing_class=processor.tokenizer,
)

In [23]:
# Check which device the model weights are on now that the trainer has been created
print(f"Model device: {model.device}")
print(f"Model is on CUDA: {next(model.parameters()).is_cuda}")

Model device: cuda:0
Model is on CUDA: True


In [24]:
import torch

# Same environment report as before, repeated after the trainer was created
# to see how much GPU memory is in use at this point.
has_cuda = torch.cuda.is_available()
device_total = torch.cuda.device_count()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Number of GPUs: {device_total}")

if has_cuda:
    gib = 1024**3
    for device_index in range(device_total):
        allocated = torch.cuda.memory_allocated(device_index) / gib
        reserved = torch.cuda.memory_reserved(device_index) / gib
        print(f"\nGPU {device_index}: {torch.cuda.get_device_name(device_index)}")
        print(f"  Memory allocated: {allocated:.2f} GB")
        print(f"  Memory reserved: {reserved:.2f} GB")

PyTorch version: 2.4.1+cu124
CUDA available: True
CUDA version: 12.4
Number of GPUs: 1

GPU 0: NVIDIA A40
  Memory allocated: 14.84 GB
  Memory reserved: 14.91 GB


In [25]:
# Run the fine-tuning loop. The returned TrainOutput (shown as the cell
# result) carries the final step count, the training loss and runtime metrics.
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,8.4718,0.464443,2.777475,98725.0,0.896252
2,1.8781,0.295819,2.935516,197450.0,0.936719


TrainOutput(global_step=24, training_loss=4.537518322467804, metrics={'train_runtime': 554.5166, 'train_samples_per_second': 0.162, 'train_steps_per_second': 0.043, 'total_flos': 5945891190148800.0, 'train_loss': 4.537518322467804, 'epoch': 2.0})

Test the model

In [27]:
import torch
import subprocess
from PIL import Image

# 1. Download the test video
!wget -nc https://sprproxy-1258344707.cos.ap-shanghai.myqcloud.com/seraphyuan/ilabel/blind_vlm/data_sample/20240918-youtube_short_e93770538101c9669e57265fe378776e_1m42s.frame/20240918-youtube_short_e93770538101c9669e57265fe378776e_1m42s.mp4

# 2. Get the model from the trainer (which holds the trained LoRA adapters)
model = trainer.model 
model.eval() # Set to evaluation mode

# 3. Downsample the test video to 8 frames
# Use the same function we defined earlier
video_file = "/root/.cache/huggingface/hub/datasets--blind-assist--walk/snapshots/ca890433d36693e4643f60302e56a6c8622dcafc/test/20240918-youtube_short_081e0a96bac802b988a1db9df310ddd1_1min03s.mp4"
frames = downsample_video(video_file)

# 4. Construct the prompt
# Use the SAME prompt style you used in training!
prompt_text = "Based on this video, provide the necessary navigation advice."

message = [
    {
        "role": "user",
        "content": [{"type": "text", "text": prompt_text}],
    },
]

# 5. Add frames as PIL images directly
for i, (image, timestamp) in enumerate(frames):
    message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}s:"})
    message[0]["content"].append({"type": "image", "image": image})

# 6. Prepare inputs
inputs = processor.apply_chat_template(
    message,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device).to(model.dtype)

input_len = inputs["input_ids"].shape[-1]

# 7. Generate!
with torch.inference_mode():
    generation = model.generate(
        **inputs, 
        max_new_tokens=128, 
        do_sample=False # Keep it deterministic for navigation
    )
    # Slice to get only the new tokens (the model's response)
    generation = generation[0][input_len:]

# 8. Decode and Print
decoded = processor.decode(generation, skip_special_tokens=True)
print("\n--- Model Navigation Advice ---")
print(decoded)

File ‘20240918-youtube_short_e93770538101c9669e57265fe378776e_1m42s.mp4’ already there; not retrieving.



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



--- Model Navigation Advice ---
Based on the video, here's the navigation advice:

**General Direction:** You are currently traveling down a road in a town or city. The road appears to be heading straight ahead.

**Road Conditions:** The road is paved and seems to be in decent condition. There are multiple lanes.

**Traffic:** There is moderate traffic. You can see cars, a scooter, and a motorcycle on the road.

**Road Markings:** There are white lane markings on the road.

**Sidewalks:** There are sidewalks on both sides of the road.

**Street Features:**
* **Trees:** There are
