In [1]:
import sys
sys.path.append("..")
from models.gemma import GemmaAdapter as Model, extract_frames

In [2]:
CONFIG_PATH = '../configs/gemma_basic_finetune.yaml'

In [3]:
import yaml
with open(CONFIG_PATH, 'r') as f: 
    cfg = yaml.safe_load(f)
    MODEL = cfg['model']
    MODEL_ID = f"{cfg['model_space']}/{MODEL}"
    CACHE_DIR = cfg['cache_dir']
    DATASET_PATH = cfg['dataset_path']
    META_PATH = cfg['meta_path']
    EVAL_PATH = cfg['eval_meta_path']
    OUTPUT_RESULTS = f"{MODEL}_{cfg['output_prefix']}.txt"
    FPS = cfg['fps']
    NUM_FRAMES = cfg['num_frames']
    MAX_NEW_TOKENS = cfg['max_new_tokens']

with open(cfg['prompt'], 'r') as file: PROMPT = file.read().strip()

In [4]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
 
model_id = "google/gemma-3-4b-pt"  # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`
 
# Define model init arguments
model_kwargs = dict(
    attn_implementation="flash_attention_2",  
    torch_dtype=torch.bfloat16,  
    device_map="auto",  # Let torch decide how to load the model
)
 
# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
 
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [5]:
from peft import LoraConfig, get_peft_model
 
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=8,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
        
    )
model = get_peft_model(model, peft_config)

for name, param in model.named_parameters():
    if "lora" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False


In [6]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable: {100 * trainable_params / all_param:0.2f}%"
    )

In [7]:
print_trainable_parameters(model.model)

trainable params: 3223552 || all params: 1588518256 || trainable: 0.20%


In [8]:
from PIL import Image

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
 
        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                image_inputs.append(Image.open(element["url"]).convert("RGB"))
    return image_inputs

In [None]:
from trl import SFTConfig
 
args = SFTConfig(
        seed=309,
        num_train_epochs=10,  # Number of training epochs
        per_device_train_batch_size=1,  # Batch size for training
        per_device_eval_batch_size=1,  # Batch size for evaluation
        gradient_accumulation_steps=8,  # Steps to accumulate gradients
        gradient_checkpointing=True,  # Enable gradient checkpointing for memory efficiency
        # Optimizer and scheduler settings
        optim="adamw_torch_fused",  # Optimizer type
        learning_rate=2e-4,  # Learning rate for training
        lr_scheduler_type="linear",  # Type of learning rate scheduler
        # Logging and evaluation
        logging_steps=1,  # Steps interval for logging
        eval_steps=10,  # Steps interval for evaluation
        logging_strategy="steps",  # Strategy for logging
        eval_strategy="steps",  # Strategy for evaluation
        save_strategy="steps",  # Strategy for saving the model
        save_steps=200,  # Steps interval for saving
        metric_for_best_model="eval_loss",  # Metric to evaluate the best model
        greater_is_better=False,  # Whether higher metric values are better
        load_best_model_at_end=True,  # Load the best model after training
        # Mixed precision and gradient settings
        bf16=True,  # Use bfloat16 precision
        tf32=False,  # Use tf32 precision
        max_grad_norm=0.3,  # Maximum norm for gradient clipping
        warmup_ratio=0.03,  # Ratio of total steps for warmup
        # Hub and reporting
        push_to_hub=False,  # Whether to push model to Hugging Face Hub
        # Gradient checkpointing settings
        gradient_checkpointing_kwargs={"use_reentrant": False},  # Options for gradient checkpointing
        # Dataset configuration
        output_dir='outputs',
        report_to='mlflow',
        remove_unused_columns=False,
        dataset_text_field="",  # Text field in dataset
        dataset_kwargs={"skip_prepare_dataset": True},  # Additional dataset options
        # max_seq_length=1024  # Maximum sequence length for input
    )


In [None]:
import tempfile
from pathlib import Path

def encode_messages(video_path: str, prompt: str, label: int, cache_dir : Path, num_frames: int = 8, **kwargs):
        video_frames = extract_frames(video_path, num_frames=num_frames)
        cache_dir.mkdir(exist_ok=True, parents=True)
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": str(label)}]
            }
        ]
        for frame_data in video_frames:
            img, timestamp = frame_data
            img_path_str = str(cache_dir / f"frame_{timestamp}.png")
            messages[1]["content"].append({"type": "text", "text": f"Frame at {timestamp} seconds:"})
            img.save(img_path_str)
            messages[1]["content"].append({"type": "image", "url": img_path_str})

        return messages
 
# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        messages = encode_messages(example["path"], PROMPT, example['label'], Path(CACHE_DIR) / 'frames', fps=FPS, num_frames=NUM_FRAMES)
        image_inputs = process_vision_info(messages)
        text = processor.apply_chat_template(
            messages, add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)
 
    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
 
    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()
 
    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100
 
    batch["labels"] = labels
    return batch

In [11]:
from dataset import UCF_Crime
from pathlib import Path
# Load dataset
train_ds = UCF_Crime(
    Path(DATASET_PATH), 
    META_PATH,
)
eval_ds = UCF_Crime(
    Path(DATASET_PATH), 
    EVAL_PATH,
)

In [12]:
from trl import SFTTrainer
 
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [13]:
# !export MLFLOW_TRACKING_URI=https://mlflow.eksli.ru/mlflow
# !export MLFLOW_TRACKING_USERNAME=mlflow-wastie
# !export MLFLOW_TRACKING_PASSWORD=z7NavXr49nc9tRPMdQl9


In [14]:
import mlflow
mlflow.set_experiment("Gemma-3-4b-QLora")


<Experiment: artifact_location='file:///home/kondrashov/smiles/DeviantBehaviorResearch/train/mlruns/398996829544260123', creation_time=1753134469129, experiment_id='398996829544260123', last_update_time=1753134469129, lifecycle_stage='active', name='Gemma-3-4b-QLora', tags={}>

In [None]:
with mlflow.start_run():
    torch.cuda.empty_cache()
    # Start training, the model will be automatically saved to the Hub and the output directory
    trainer.train()
    
    # Save the final model again to the Hugging Face Hub
    trainer.save_model()

<bos><bos><start_of_turn>user
You are a helpful assistant.

Return `1` if the video shows any deviant, abnormal or criminal behaviour; return `0` if it does not. Respond with only that single digit and nothing else.Frame at 0.0 seconds:

<start_of_image><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_s

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


<bos><bos><start_of_turn>user
You are a helpful assistant.

Return `1` if the video shows any deviant, abnormal or criminal behaviour; return `0` if it does not. Respond with only that single digit and nothing else.Frame at 0.0 seconds:

<start_of_image><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_s

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss


<bos><bos><start_of_turn>user
You are a helpful assistant.

Return `1` if the video shows any deviant, abnormal or criminal behaviour; return `0` if it does not. Respond with only that single digit and nothing else.Frame at 0.0 seconds:

<start_of_image><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_s

KeyboardInterrupt: 

In [None]:
from tqdm import tqdm

with torch.inference_mode():
    for idx in tqdm(range(len(eval_ds)), total=len(eval_ds)):
        sample = eval_ds[idx]
        messages = encode_messages(sample['path'], PROMPT, Path(CACHE_DIR), NUM_FRAMES)
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device)
        input_length = inputs["input_ids"].shape[-1]
        output = model.generate(**inputs, do_sample=False)
        output = output[0][input_length:]
        response = processor.decode(output, skip_special_tokens=True)

        line = f"{sample['path']}\t{response}\n"
        print(line)

In [None]:
del model
del trainer
torch.cuda.empty_cache()