## Data Preprocessing

In [1]:
import json
import os
import glob

# Directory scanned for `log_*.json` dialogue logs and their `*_ego.jpg` frames.
LOG_DIR = "./dataset_logs"
# Path the assembled SFT dataset (a JSON list of dialogues) is written to.
OUTPUT_FILE = "vlm_sft_dataset.json"


def _round_image_path(log_dir, scene_id, run_idx, round_idx):
    """Return the expected ego-camera frame path for one round of one run."""
    # e.g., "0010_run0_r0_ego.jpg"
    return os.path.join(log_dir, f"{scene_id}_run{run_idx}_r{round_idx}_ego.jpg")


def _build_conversation(data, log_dir):
    """Convert one parsed log dict into a list of chat messages with images attached."""
    scene_id = data.get("scene", "")
    run_idx = data.get("run_idx", 0)

    conversation = []
    # Index of the next unconsumed frame (r0, r1, r2, ...). It only advances
    # when a frame is actually attached to a message.
    round_idx = 0

    for message in data.get("dialogue", []):
        role = message.get("role")
        content = message.get("content")

        if role == "system":
            # System prompt stays text-only ...
            conversation.append({"role": "system", "content": content})

            # ... and is followed by a synthetic user turn showing the current
            # frame. The existence check mirrors the user branch, so a missing
            # frame never leaves a dangling image path in the dataset.
            img_path = _round_image_path(log_dir, scene_id, run_idx, round_idx)
            if os.path.exists(img_path):
                conversation.append(
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": img_path},
                            {
                                "type": "text",
                                "text": "This is the current camera view.",
                            },
                        ],
                    }
                )
                round_idx += 1

        elif role == "user":
            img_path = _round_image_path(log_dir, scene_id, run_idx, round_idx)
            if os.path.exists(img_path):
                # Multimodal user message: frame first, then the original text.
                conversation.append(
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": img_path},
                            {"type": "text", "text": content},
                        ],
                    }
                )
                round_idx += 1
            else:
                # Fallback if the frame is missing: keep the turn text-only.
                conversation.append({"role": "user", "content": content})

        elif role == "assistant":
            # Assistant replies are kept as plain text.
            conversation.append({"role": "assistant", "content": content})

        # Messages with any other role are dropped.

    return conversation


def create_sft_dataset(log_dir, output_file=None):
    """Build a multimodal SFT dataset from the JSON dialogue logs in ``log_dir``.

    Every ``log_*.json`` file becomes one ``{"dialogue": [...]}`` entry. Frames
    named ``{scene}_run{run_idx}_r{N}_ego.jpg`` in ``log_dir`` are attached to
    the turn after the system prompt and to user turns, in round order, whenever
    the file exists on disk.

    Args:
        log_dir: Directory containing the log files and the frames.
        output_file: Destination JSON path. Defaults to the module-level
            ``OUTPUT_FILE`` when omitted.

    Files that fail to load or parse are reported and skipped (best effort).
    """
    if output_file is None:
        output_file = OUTPUT_FILE

    sft_data = []

    # glob order is filesystem-dependent; sort so the dataset is reproducible.
    json_files = sorted(glob.glob(os.path.join(log_dir, "log_*.json")))

    print(f"Found {len(json_files)} log files. Processing...")

    for json_path in json_files:
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            sft_data.append({"dialogue": _build_conversation(data, log_dir)})
        except Exception as e:
            # Deliberate best effort: one bad log must not abort the whole build.
            print(f"Error processing {json_path}: {e}")

    # Save the final dataset
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(sft_data, f, indent=4)

    print(f"Successfully processed {len(sft_data)} dialogues.")
    print(f"Saved to {output_file}")


# Build the dataset from LOG_DIR when this cell/script is executed directly.
if __name__ == "__main__":
    create_sft_dataset(LOG_DIR)

Found 20 log files. Processing...
Successfully processed 20 dialogues.
Saved to vlm_sft_dataset.json


In [None]:
# Install/upgrade the fine-tuning dependencies into the kernel's environment.
# NOTE(review): versions are unpinned and `-U` pulls the latest releases; pin
# them (pkg==x.y.z) if this notebook must be reproducible.
%pip install -q -U peft bitsandbytes accelerate datasets trl qwen-vl-utils

## Load Model


In [2]:
import torch
from transformers import (
    Qwen3VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)

# Local checkpoint directory of the base model.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
MODEL_ID = "/project/jevans/tzhang3/models/Qwen3-VL-8B-Instruct"

# Configure 4-bit quantization to fit in VRAM
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load Processor
# min_pixels/max_pixels control the resolution.
# Reducing max_pixels saves VRAM significantly.
processor = AutoProcessor.from_pretrained(
    MODEL_ID, min_pixels=256 * 28 * 28, max_pixels=640 * 28 * 28
)

# Load Model
# `dtype` replaces the deprecated `torch_dtype` keyword (the library warns
# "`torch_dtype` is deprecated! Use `dtype` instead!" when the old name is used).
model = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    dtype=torch.bfloat16,
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# @title 4. Apply LoRA Adapters
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Prepare the quantized base model for k-bit training before attaching adapters.
# NOTE(review): this rebinds `model`, so re-running the cell wraps an already
# wrapped model; restart from the model-loading cell before re-running.
model = prepare_model_for_kbit_training(model)

# LoRA configuration: rank-16 adapters (alpha equal to the rank) on the
# attention projections (q/k/v/o) and the MLP projections (gate/up/down).
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=[],  # We don't need to save embed_tokens for code generation
)

# Wrap the model with the adapters and report how many parameters will train
# (the recorded run shows ~43.6M trainable, about 0.5% of all parameters).
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 43,646,976 || all params: 8,810,770,672 || trainable%: 0.4954


In [7]:
import json
import os
from torch.utils.data import Dataset as TorchDataset
from qwen_vl_utils import process_vision_info

class RobotDataset(TorchDataset):
    """Torch dataset turning stored dialogues into tokenized multimodal SFT samples.

    Each item is one dialogue from the JSON file at ``data_path`` (a list of
    ``{"dialogue": [...]}`` entries). Only assistant replies contribute to the
    loss; every other position in ``labels`` is -100.

    Args:
        data_path: Path to the dataset JSON file.
        processor: Processor providing ``apply_chat_template``, ``tokenizer``
            and the multimodal ``__call__``.
        image_base_dir: Directory used to re-locate image files whose stored
            path no longer exists (matched by file name).
        max_length: Token budget passed to the processor (default 16000,
            chosen large so image tokens are not truncated).
        padding: Padding strategy passed to the processor. The default
            ``"max_length"`` keeps the original fixed-size samples; pass
            ``False`` to leave padding to the data collator.
    """

    # Image references with these prefixes are not local file paths and are
    # passed through untouched.
    _NON_LOCAL_PREFIXES = ("http://", "https://", "file://", "data:")

    def __init__(self, data_path, processor, image_base_dir, max_length=16000, padding="max_length"):
        self.processor = processor
        self.image_base_dir = image_base_dir
        self.max_length = max_length
        self.padding = padding

        # Stored as 'self.examples' (not 'self.data') to avoid property conflicts.
        with open(data_path, 'r', encoding="utf-8") as f:
            self.examples = json.load(f)

        # Token patterns delimiting an assistant reply. They are constant for
        # a given tokenizer, so encode them once instead of per item.
        tokenizer = processor.tokenizer
        self._header_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
        self._end_token_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

    def __len__(self):
        return len(self.examples)

    def _resolve_image(self, item):
        """Return ``item`` with its image path re-located under ``image_base_dir`` if needed.

        Only local string paths that do not exist are touched, and only when a
        file with the same name exists in ``image_base_dir``. The input dict is
        never mutated.
        """
        if not isinstance(item, dict) or item.get("type") != "image":
            return item
        path = item.get("image")
        if not isinstance(path, str) or not self.image_base_dir:
            return item
        if path.startswith(self._NON_LOCAL_PREFIXES) or os.path.exists(path):
            return item
        candidate = os.path.join(self.image_base_dir, os.path.basename(path))
        if not os.path.exists(candidate):
            return item
        fixed = dict(item)
        fixed["image"] = candidate
        return fixed

    def _clean_messages(self, messages):
        """Normalize messages: string content becomes a text part, image paths are resolved."""
        cleaned_messages = []
        for msg in messages:
            new_msg = msg.copy()
            content = new_msg['content']
            if isinstance(content, str):
                # Type fix: wrap plain strings in the list-of-parts form.
                new_msg['content'] = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                # Path fix: build a new list so the stored example is not mutated.
                new_msg['content'] = [self._resolve_image(part) for part in content]
            cleaned_messages.append(new_msg)
        return cleaned_messages

    @staticmethod
    def _build_labels(input_ids, attention_mask, header_ids, end_token_id):
        """Return labels that keep only assistant replies; everything else is -100.

        A reply starts right after ``header_ids`` and runs through the next
        ``end_token_id`` (inclusive). If no end token follows (e.g. truncated
        input), the reply extends to the end of the sequence. Padding positions
        (``attention_mask == 0``) are always masked.
        """
        labels = input_ids.clone()
        labels[:] = -100

        ids = input_ids.tolist()
        seq_len = len(ids)
        header_len = len(header_ids)

        i = 0
        while i < seq_len - header_len:
            if ids[i : i + header_len] == header_ids:
                content_start = i + header_len
                try:
                    content_end = ids.index(end_token_id, content_start)
                except ValueError:
                    # Unterminated reply: supervise the rest of the sequence.
                    labels[content_start:] = input_ids[content_start:]
                    break
                labels[content_start : content_end + 1] = input_ids[content_start : content_end + 1]
                i = content_end
            i += 1

        # Never train on padding, even inside an unterminated reply.
        labels[attention_mask == 0] = -100
        return labels

    def __getitem__(self, idx):
        example = self.examples[idx]

        # Pre-process: fix paths and types
        cleaned_messages = self._clean_messages(example["dialogue"])

        # Apply template & tokenize
        text = self.processor.apply_chat_template(cleaned_messages, tokenize=False, add_generation_prompt=False)
        image_inputs, video_inputs = process_vision_info(cleaned_messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Standard squeeze for the batch dim: (1, seq_len) -> (seq_len)
        input_ids = inputs["input_ids"].squeeze(0)
        attention_mask = inputs["attention_mask"].squeeze(0)
        labels = self._build_labels(input_ids, attention_mask, self._header_ids, self._end_token_id)

        result = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

        if "pixel_values" in inputs:
            # Kept exactly as the processor returns it (no squeeze).
            result["pixel_values"] = inputs["pixel_values"]

            # Grid THW is expected as (N_images, 3).
            grid = inputs["image_grid_thw"]

            # If it comes out as (1, N, 3), squeeze the batch dim -> (N, 3).
            if grid.dim() == 3:
                result["image_grid_thw"] = grid.squeeze(0)
            # A 2-D grid, (1, 3) or (N, 3), is kept as is: squeezing (1, 3)
            # would give (3,) and crash downstream.
            else:
                result["image_grid_thw"] = grid

        return result

# Instantiate Dataset
# Reads the JSON produced by the preprocessing cell; "dataset_logs" is passed
# as the image base directory.
train_dataset = RobotDataset("vlm_sft_dataset.json", processor, "dataset_logs")
print(f"Dataset loaded with {len(train_dataset)} samples.")

Dataset loaded with 20 samples.


In [8]:
from dataclasses import dataclass
from typing import Dict, List, Union

@dataclass
class Qwen3VLDataCollator:
    """Batch per-sample tensors, right-padding the text fields and concatenating image fields.

    Attributes:
        pad_token_id: Value used to pad ``input_ids``. When left as ``None`` the
            collator falls back to the notebook-level
            ``processor.tokenizer.pad_token_id``, so ``Qwen3VLDataCollator()``
            keeps working; passing it explicitly removes that hidden
            dependency on global state.
    """

    pad_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into one batch dict.

        ``input_ids`` / ``labels`` / ``attention_mask`` are right-padded to the
        longest sample (with the pad id, -100 and 0 respectively).
        ``pixel_values`` and ``image_grid_thw`` are concatenated along dim 0
        across the samples that have images and omitted if none do.
        """
        pad_id = self.pad_token_id
        if pad_id is None:
            # Fallback for callers that construct the collator without arguments.
            pad_id = processor.tokenizer.pad_token_id

        pad = torch.nn.utils.rnn.pad_sequence
        batch = {
            "input_ids": pad([f["input_ids"] for f in features], batch_first=True, padding_value=pad_id),
            "attention_mask": pad([f["attention_mask"] for f in features], batch_first=True, padding_value=0),
            # -100 keeps padded positions out of the loss.
            "labels": pad([f["labels"] for f in features], batch_first=True, padding_value=-100),
        }

        # Text-only samples carry no image tensors and are skipped here.
        pixel_values = [f["pixel_values"] for f in features if "pixel_values" in f]
        image_grid_thw = [f["image_grid_thw"] for f in features if "image_grid_thw" in f]

        if pixel_values:
            batch["pixel_values"] = torch.cat(pixel_values, dim=0)
            batch["image_grid_thw"] = torch.cat(image_grid_thw, dim=0)

        return batch

In [None]:
from transformers import Trainer, TrainingArguments

# Trainer
# Training configuration: 3 epochs, one sample per device step with 4-step
# gradient accumulation (effective batch of 4 per optimizer step), bf16.
# With the 20-sample dataset the recorded run finished at global_step=15.
training_args = TrainingArguments(
    output_dir="./qwen3-vl-robot-planner",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=1,
    save_strategy="epoch",
    remove_unused_columns=False, 
    report_to="none"
)

# The custom collator batches the dataset's dict samples (text tensors plus
# optional pixel_values / image_grid_thw).
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=Qwen3VLDataCollator(),
)

# Long-running: the recorded run took about 2554 s (train_runtime).
trainer.train()



  return fn(*args, **kwargs)


Step,Training Loss
1,0.6711
2,0.6772
3,0.5642
4,0.5083
5,0.6497
6,0.6365
7,0.5372
8,0.4841
9,0.5508
10,0.5474


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


TrainOutput(global_step=15, training_loss=0.5629600346088409, metrics={'train_runtime': 2553.5841, 'train_samples_per_second': 0.023, 'train_steps_per_second': 0.006, 'total_flos': 4.715013086208e+16, 'train_loss': 0.5629600346088409, 'epoch': 3.0})