In [1]:
import sys
sys.path.append("..")
from models.gemma import GemmaAdapter as Model, extract_frames

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
CONFIG_PATH = '../configs/gemma_basic_finetune.yaml'

In [3]:
import yaml
with open(CONFIG_PATH, 'r') as f: 
    cfg = yaml.safe_load(f)
    MODEL = cfg['model']
    MODEL_ID = f"{cfg['model_space']}/{MODEL}"
    CACHE_DIR = cfg['cache_dir']
    DATASET_PATH = cfg['dataset_path']
    META_PATH = cfg['meta_path']
    EVAL_PATH = cfg['eval_meta_path']
    OUTPUT_RESULTS = f"{MODEL}_{cfg['output_prefix']}.txt"
    FPS = cfg['fps']
    NUM_FRAMES = cfg['num_frames']
    MAX_NEW_TOKENS = cfg['max_new_tokens']

with open(cfg['prompt'], 'r') as file: PROMPT = file.read().strip()

In [4]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
 
model_id = "google/gemma-3-4b-pt"  # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`
 
# Define model init arguments
model_kwargs = dict(
    attn_implementation="flash_attention_2",  
    torch_dtype=torch.bfloat16,  
    device_map="auto",  # Let torch decide how to load the model
)
 
# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
 
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.59s/it]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [5]:
from peft import LoraConfig
 
peft_config = LoraConfig(
    init_lora_weights="eva",
    lora_alpha=16,
    lora_dropout=0,
    r=8,
    bias="none",
    target_modules=[
        'down_proj',
        'o_proj',
        'k_proj',
        'q_proj',
        'gate_proj',
        'up_proj',
        'v_proj'
    ],
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)



In [6]:
from PIL import Image

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
 
        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                image_inputs.append(Image.open(element["url"]).convert("RGB"))
    return image_inputs

In [7]:
from trl import SFTConfig
 
args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        num_train_epochs=1, # For full training runs over the dataset.
        learning_rate=2e-4,
        bf16=True,
        logging_steps=200,
        save_strategy='steps',
        save_steps=200,
        save_total_limit=2,
        optim='sgd',
        weight_decay=0.01,
        lr_scheduler_type='linear',
        seed=3407,
        output_dir='outputs',
        report_to='none',     
        remove_unused_columns=False,
        dataset_text_field='',
        dataset_kwargs={'skip_prepare_dataset': True},
        max_seq_length=512,
)

In [8]:
import tempfile
from pathlib import Path

def encode_messages(video_path: str, prompt: str, cache_dir : Path, num_frames: int = 8, **kwargs):
        video_frames = extract_frames(video_path, num_frames=num_frames)
        cache_dir.mkdir(exist_ok=True, parents=True)
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
            }
        ]
        for frame_data in video_frames:
            img, timestamp = frame_data
            img_path_str = str(cache_dir / f"frame_{timestamp}.png")
            messages[1]["content"].append({"type": "text", "text": f"Frame at {timestamp} seconds:"})
            img.save(img_path_str)
            messages[1]["content"].append({"type": "image", "url": img_path_str})

        return messages
 
# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        messages = encode_messages(example["path"], PROMPT, Path(CACHE_DIR) / 'frames', fps=FPS, num_frames=NUM_FRAMES)
        image_inputs = process_vision_info(messages)
        text = processor.apply_chat_template(
            messages, add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)
 
    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
 
    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()
 
    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100
 
    batch["labels"] = labels
    return batch

In [9]:
from dataset import UCF_Crime
from pathlib import Path
# Load dataset
train_ds = UCF_Crime(
    Path(DATASET_PATH), 
    META_PATH,
)
eval_ds = UCF_Crime(
    Path(DATASET_PATH), 
    EVAL_PATH,
)

In [10]:
from trl import SFTTrainer
 
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [11]:
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()
 
# Save the final model again to the Hugging Face Hub
trainer.save_model()

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


Step,Training Loss


KeyboardInterrupt: 

In [None]:
del model
del trainer
torch.cuda.empty_cache()