In [20]:
import os
import json
import torch
from PIL import Image
from datasets import Dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    TrainingArguments,
    Trainer
)
from peft import prepare_model_for_kbit_training, get_peft_model, LoraConfig

# --- CONFIG ---
MODEL_NAME = "Qwen/Qwen2-VL-2B"
JSON_PATH = "train_dataset.json"  # Make sure this is the correct path
OUTPUT_DIR = "./qwen2-vl-lora-output"
EPOCHS = 5
BATCH_SIZE = 1
LEARNING_RATE = 5e-5

# --- LOAD JSON ---
with open(JSON_PATH, "r") as f:
    raw_data = json.load(f)

# --- CLEAN & CHECK IMAGE PATHS ---
valid_data = []
for item in raw_data:
    image_path = item["image"].strip().replace("\\", "/")  # normalize slashes
    if os.path.exists(image_path):
        valid_data.append({
            "image": image_path,
            "query": item["query"],
            "description": item["description"]
        })
    else:
        print(f"⚠️ Skipping missing image: {image_path}")

# --- LOAD PROCESSOR ---
processor = AutoProcessor.from_pretrained(MODEL_NAME)

# --- PREPROCESS FUNCTION ---
def preprocess(example):
    image = Image.open(example["image"]).convert("RGB")
    query = example["query"]
    answer = example["description"]

    inputs = processor(images=image, text=query, return_tensors="pt", padding="max_length", truncation=True)
    labels = processor.tokenizer(answer, return_tensors="pt", padding="max_length", truncation=True).input_ids

    inputs["labels"] = labels[0]
    return {k: v.squeeze(0) if isinstance(v, torch.Tensor) and v.ndim > 1 else v for k, v in inputs.items()}

# --- CONVERT TO DATASET ---
dataset = Dataset.from_list(valid_data)
dataset = dataset.map(preprocess)

# --- LOAD MODEL (4-bit + LoRA) ---
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_NAME,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto"
)

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="VISION2SEQ"
)

model = get_peft_model(model, lora_config)

# --- TRAINING ARGS ---
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    learning_rate=LEARNING_RATE,
    fp16=True,
    num_train_epochs=EPOCHS,
    logging_steps=10,
    save_strategy="epoch",
    remove_unused_columns=False,
    report_to="none"
)

# --- TRAINER ---
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

# --- TRAIN ---
trainer.train()

# --- SAVE ---
model.save_pretrained(f"{OUTPUT_DIR}/lora")
processor.save_pretrained(f"{OUTPUT_DIR}/lora")


JSONDecodeError: Invalid control character at: line 3 column 79 (char 84)

In [2]:
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "image": "C:/AI/Github/Reconnaissance_drone_report/Data/Images/earthquake/download1.jpg",
                "max_pixels": 360 * 420,
                "fps": 1.0,
            },
            {"type": "text", "text": "Consider yourself as a airforce pilot who is operating a drone at this moment, explain this event brief."},
        ],
    }
]

In [None]:
import time
st=time.time()
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")
generated_ids = model.generate(**inputs, max_new_tokens=512)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
end=time.time()
print("".join(output_text))
print("Time took to generate: ",end-st)

In [2]:

torch.cuda.reset_peak_memory_stats()

In [4]:
torch.cuda.empty_cache()

In [13]:
!pip install peft

Collecting peft
  Downloading peft-0.15.1-py3-none-any.whl.metadata (13 kB)
Downloading peft-0.15.1-py3-none-any.whl (411 kB)
Installing collected packages: peft
Successfully installed peft-0.15.1
