# Setup
Make sure to update the git config step with your own email and username.
To use the latest Molmo updates, merge the changes from [PR #33962](https://github.com/huggingface/transformers/pull/33962) into the `transformers` library. You can then use `trl` to fine-tune with LoRA.


**Make sure to restart the runtime after installing the packages above!**

In [None]:
# Necessary imports
import torch
from PIL import Image
from io import BytesIO
import base64
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
    MolmoForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model
from sklearn.model_selection import train_test_split
import logging

# Raise the "transformers" logger threshold to ERROR so that its lower-severity
# messages (info, warnings) are suppressed.
logging.getLogger("transformers").setLevel(logging.ERROR)

In [None]:
# @title Training Variables
# Identifiers passed to `from_pretrained` (model/processor) and `load_dataset`.
model_id = "Molbap/molmo-hf-7B-D" #@param {type:"string"}
dataset_id = "remyxai/openspaces-exploded" #@param {type:"string"}
# NOTE(review): `device` is not referenced by the cells visible in this
# notebook (the model is loaded with device_map="auto") — confirm it is needed.
device = "cuda"

# Hyperparameters and checkpoint directory consumed by TrainingArguments.
epochs = 1 #@param {type:"integer"}
learning_rate = 2e-5 #@param {type:"number"}
output_dir="lora_molmo_openspaces" #@param {type:"string"}

In [None]:
# @title Helper functions
def format_data(sample):
    """Turn one raw dataset sample into a list of chat turns.

    Turns whose role is "system" are dropped. Within every remaining turn,
    non-empty text items are kept as-is and non-empty image items are
    base64-decoded into RGB PIL images. An image that fails to decode is
    logged and skipped. Turns that end up with no content are omitted.

    Args:
        sample: Mapping with a "messages" list; each message has a "role"
            and a "content" list of items carrying a "type" key.

    Returns:
        A list of ``{"role": ..., "content": [...]}`` dicts.
    """
    conversation = []
    for message in sample["messages"]:
        role = message["role"]
        if role == "system":
            continue

        parts = []
        for entry in message["content"]:
            kind = entry["type"]
            if kind == "text":
                if entry["text"]:
                    parts.append({"type": "text", "text": entry["text"]})
            elif kind == "image" and entry["image"]:
                # Best effort: one undecodable image must not abort the whole run.
                try:
                    raw = base64.b64decode(entry["image"])
                    picture = Image.open(BytesIO(raw)).convert("RGB")
                    parts.append({"type": "image", "image": picture})
                except Exception as e:
                    logging.warning(f"Error processing image: {e}")

        if not parts:
            continue
        conversation.append({"role": role, "content": parts})
    return conversation


def collate_fn(examples):
    """Collate formatted conversations into a batch with causal-LM labels.

    Each conversation is rendered with the processor's chat template and
    tokenized together with the first image of its first user turn (if any).
    Labels are a copy of ``input_ids`` with padding positions set to -100,
    so they are position-aligned with the inputs, as a causal-LM loss
    requires. Tokenizing the answer separately and writing it at the start
    of the label tensor would supervise the wrong positions.

    Args:
        examples: List of conversations as returned by ``format_data``
            (each a list of ``{"role", "content"}`` dicts).

    Returns:
        A plain dict holding every processor output plus a "labels" tensor
        shaped like "input_ids".
    """
    images, conversations = [], []

    for example in examples:
        user_msg = next((msg for msg in example if msg["role"] == "user"), None)
        if user_msg is not None:
            # Only the first image of the first user turn is forwarded.
            image_entry = next(
                (item for item in user_msg["content"] if item["type"] == "image"), None
            )
            if image_entry is not None:
                images.append(image_entry["image"])

        # The conversation already contains the assistant answer, so no
        # trailing generation prompt is appended for training.
        conversations.append(
            processor.apply_chat_template(example, add_generation_prompt=False)
        )

    # Tokenize inputs (images and conversations)
    inputs = processor(images=images, text=conversations, return_tensors="pt", padding="longest")

    # Labels mirror input_ids; padded positions are excluded from the loss.
    labels = inputs["input_ids"].clone()
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None and attention_mask.shape == labels.shape:
        labels[attention_mask == 0] = -100
    else:
        # Fallback when the processor returns no usable attention mask.
        pad_token_id = processor.tokenizer.pad_token_id
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100
    # NOTE(review): prompt and image-placeholder tokens still contribute to the
    # loss; mask them as well if the processor exposes their ids — TODO confirm.
    inputs["labels"] = labels

    # Hand the Trainer a plain dict of tensors (no device transfer happens here).
    return dict(inputs.items())

In [None]:
# @title Prepare data
# Load the "train" split and convert every sample with format_data. Images are
# decoded eagerly there, so the whole formatted dataset is held in memory.
dataset = load_dataset(dataset_id, split="train")
formatted_dataset = [format_data(sample) for sample in tqdm(dataset, desc="Formatting dataset")]
# Hold out 20% of the samples; the fixed random_state makes the split repeatable.
train_dataset, eval_dataset = train_test_split(formatted_dataset, test_size=0.2, random_state=42)

Resolving data files:   0%|          | 0/71 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/71 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

Formatting dataset: 100%|██████████| 46275/46275 [08:44<00:00, 88.18it/s]


In [None]:
# @title Model Configuration
from transformers import BitsAndBytesConfig

# LoRA adapters (rank 8) on the attention and MLP projection layers.
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

# Same 4-bit settings as before, expressed through an explicit quantization
# config: passing `load_in_4bit` straight to `from_pretrained` is deprecated.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = MolmoForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", quantization_config=quantization_config
)

# Wrap the quantized base model so that only the LoRA weights are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# The processor is read as a global by collate_fn.
processor = AutoProcessor.from_pretrained(model_id)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 6,266,880 || all params: 8,027,750,912 || trainable%: 0.0781


In [None]:
# @title Training setup
training_args = TrainingArguments(
    num_train_epochs=epochs,
    per_device_train_batch_size=1,
    # Match the train batch size so evaluation has the same memory footprint.
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    warmup_steps=2,
    learning_rate=learning_rate,
    weight_decay=1e-6,
    logging_steps=100,
    save_steps=1000,
    save_total_limit=1,  # keep only the most recent checkpoint on disk
    output_dir=output_dir,
    fp16=True,
    report_to=["wandb"],
    # collate_fn consumes the raw conversation lists, so the Trainer must not
    # try to drop "unused" columns from the examples.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    # Pass the held-out split so `trainer.evaluate()` can be run on it.
    # No evaluation strategy is set here, so the training loop is unchanged.
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    args=training_args,
)

In [None]:
# @title Train
# Launch fine-tuning with the Trainer configured in the previous cell.
trainer.train()

{'loss': 9.2787, 'grad_norm': 15.886802673339844, 'learning_rate': 1.9950294451347992e-05, 'epoch': 0.002701242571582928}
{'loss': 7.4569, 'grad_norm': 8.387580871582031, 'learning_rate': 1.9896266681074075e-05, 'epoch': 0.005402485143165856}
{'loss': 6.4071, 'grad_norm': 9.110421180725098, 'learning_rate': 1.9842238910800154e-05, 'epoch': 0.008103727714748784}
{'loss': 6.0989, 'grad_norm': 11.456809043884277, 'learning_rate': 1.9788211140526233e-05, 'epoch': 0.010804970286331712}
{'loss': 5.7842, 'grad_norm': 6.366999626159668, 'learning_rate': 1.9734183370252313e-05, 'epoch': 0.01350621285791464}
{'loss': 5.682, 'grad_norm': 7.873729705810547, 'learning_rate': 1.968015559997839e-05, 'epoch': 0.01620745542949757}
{'loss': 5.555, 'grad_norm': 11.063384056091309, 'learning_rate': 1.9626127829704468e-05, 'epoch': 0.0189086980010805}
{'loss': 5.4741, 'grad_norm': 7.201417922973633, 'learning_rate': 1.957210005943055e-05, 'epoch': 0.021609940572663425}
{'loss': 5.2964, 'grad_norm': 18.3982

KeyboardInterrupt: 