# Installing Dependencies

In [None]:
# Install necessary packages
!pip install -U transformers accelerate peft bitsandbytes
!pip install backoff
!pip install flash-attn --no-build-isolation

# Imports

In [None]:
# Imports
import os

import pandas as pd
import torch
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

## Loading the model

In [None]:
# Paths
image_folder = "/kaggle/input/vr-dataset-final/images/images"
csv_path = "/kaggle/input/vr-dataset-final/annotations.csv"

# NOTE(review): `trust_remote_code=True` executes Python code shipped in the
# model repository; consider pinning a `revision` once a known-good commit is chosen.
model_id = "microsoft/Phi-3-vision-128k-instruct"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Request 4-bit loading through an explicit BitsAndBytesConfig. Passing
# `load_in_4bit=True` directly to `from_pretrained` is deprecated.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager",
)


In [None]:
# LoRA adapter settings: adapters are attached to every linear layer
# ("all-linear"), bias terms are left untouched.
config = LoraConfig(
    target_modules="all-linear",
    bias="none",
    lora_dropout=0.05,
    lora_alpha=32,
    r=16,
)

# Wrap the base model with the adapters and report how many parameters train.
model = get_peft_model(model, config)
model.print_trainable_parameters()

## Loading the dataset

In [None]:
class Phi3VQADataset(Dataset):
    """Image/question/answer rows turned into causal-LM training samples.

    The CSV at ``csv_path`` must contain the columns ``image_name``,
    ``question`` and ``answer``. Every item holds the processor output for
    the prompt with the answer tokens appended, plus ``labels`` of the same
    length in which the prompt positions are masked out.

    Args:
        csv_path: Path of the annotations CSV.
        image_folder: Directory that ``image_name`` values are relative to.
        processor: Multimodal processor; must expose ``tokenizer`` and
            ``img_tokens`` and be callable as ``processor(text, images, ...)``.
        max_samples: If truthy, keep only the first ``max_samples`` rows.

    Note:
        Item length is prompt length plus answer length and therefore differs
        between samples. NOTE(review): batch one sample per step unless a
        padding-aware collator is used.
    """

    # Label value for positions that must not contribute to the loss
    # (same value the sibling DataCollator uses for prompt tokens).
    IGNORE_INDEX = -100

    def __init__(self, csv_path, image_folder, processor, max_samples=None):
        self.data = pd.read_csv(csv_path)
        if max_samples:
            self.data = self.data[:max_samples]
        self.image_folder = image_folder
        self.processor = processor
        self.image_token = processor.img_tokens[0]  # "<|image_1|>"

    def __len__(self):
        """Return the number of annotation rows kept."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with every tensor produced by the processor (batch dimension
            squeezed away), where ``input_ids`` is prompt + answer,
            ``attention_mask`` covers the whole sequence and ``labels`` is
            ``IGNORE_INDEX`` on the prompt and the answer token ids after it.
        """
        row = self.data.iloc[idx]
        img_path = os.path.join(self.image_folder, row["image_name"])
        # convert() returns a new, fully loaded image, so the file can be closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Use the tokenizer's chat template instead of a hand-written
        # "[INST]" wrapper so the prompt matches the model's own format.
        user_prompt = f"Answer the question about this image: {row['question']}"
        messages = [{"role": "user", "content": f"{self.image_token}\n{user_prompt}"}]
        prompt = self.processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The f-string also stringifies numeric answers read from the CSV.
        answer = f"{row['answer']}<|end|>\n<|endoftext|>"

        encoded = self.processor(prompt, [image], return_tensors="pt")
        prompt_ids = encoded["input_ids"]
        answer_ids = self.processor.tokenizer(
            answer, add_special_tokens=False, return_tensors="pt"
        )["input_ids"]

        # Teacher forcing: the answer must be part of the input sequence, and
        # the labels must have exactly the same length as the inputs.
        input_ids = torch.cat([prompt_ids, answer_ids], dim=1)
        labels = torch.cat(
            [torch.full_like(prompt_ids, self.IGNORE_INDEX), answer_ids], dim=1
        )

        # Squeeze the batch dimension
        item = {key: value.squeeze(0) for key, value in encoded.items()}
        item["input_ids"] = input_ids.squeeze(0)
        # The processor's mask only covers the prompt; extend it to the answer.
        item["attention_mask"] = torch.ones_like(item["input_ids"])
        item["labels"] = labels.squeeze(0)
        return item


In [None]:
class DataCollator:
    """Turn raw VQA records into a padded training batch.

    Each example must be a mapping with ``image_name``, ``question`` and
    ``answer`` keys. Every example in the batch is encoded (not only the
    first one); sequences are right-padded to the longest one in the batch.

    Args:
        processor: Multimodal processor exposing ``tokenizer`` and callable as
            ``processor(prompt, [image], return_tensors="pt")``.
        image_dir: Directory that ``image_name`` values are relative to.
    """

    # Label value for positions excluded from the loss (prompt and padding).
    IGNORE_INDEX = -100

    def __init__(self, processor, image_dir):
        self.processor = processor
        self.image_dir = image_dir

    def _encode_example(self, example):
        """Encode one record; every returned tensor has a leading batch dim of 1."""
        # Load image using image_name
        image_path = os.path.join(self.image_dir, example["image_name"])
        # convert() returns a new, fully loaded image, so the file can be closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        user_prompt = f"Answer the question about this image: {example['question']}"
        messages = [
            {"role": "user", "content": f"<|image_1|>\n{user_prompt}"}
        ]
        prompt = self.processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        answer = f"{example['answer']}<|end|>\n<|endoftext|>"

        encoded = self.processor(prompt, [image], return_tensors="pt")
        prompt_ids = encoded["input_ids"]
        answer_ids = self.processor.tokenizer(
            answer, add_special_tokens=False, return_tensors="pt"
        )["input_ids"]

        # Inputs are prompt + answer; labels ignore the prompt tokens.
        encoded["input_ids"] = torch.cat([prompt_ids, answer_ids], dim=1)
        encoded["labels"] = torch.cat(
            [torch.full_like(prompt_ids, self.IGNORE_INDEX), answer_ids], dim=1
        )
        # The processor's mask only spans the prompt; it must cover the
        # appended answer tokens as well.
        encoded["attention_mask"] = torch.ones_like(encoded["input_ids"])
        return encoded

    @staticmethod
    def _pad_and_stack(rows, fill_value):
        """Right-pad ``(1, length)`` tensors with ``fill_value`` and stack them."""
        width = max(row.shape[1] for row in rows)
        stacked = rows[0].new_full((len(rows), width), fill_value)
        for i, row in enumerate(rows):
            stacked[i, : row.shape[1]] = row[0]
        return stacked

    def __call__(self, examples):
        """Collate ``examples`` into one batch.

        Returns:
            The processor output of the first example, updated in place so
            that every entry holds the data of the whole batch.
            ``input_ids``, ``attention_mask`` and ``labels`` are padded with
            the pad token id, ``0`` and ``IGNORE_INDEX`` respectively; all
            other tensors are concatenated along dimension 0.
        """
        encoded = [self._encode_example(example) for example in examples]
        batch = encoded[0]

        if len(encoded) > 1:
            pad_id = getattr(self.processor.tokenizer, "pad_token_id", None)
            if pad_id is None:
                pad_id = 0  # masked out by attention_mask / IGNORE_INDEX anyway
            fill_values = {
                "input_ids": pad_id,
                "attention_mask": 0,
                "labels": self.IGNORE_INDEX,
            }
            for key in list(batch.keys()):
                values = [item[key] for item in encoded]
                if key in fill_values:
                    batch[key] = self._pad_and_stack(values, fill_values[key])
                else:
                    # NOTE(review): assumes per-example image tensors share
                    # their trailing dimensions — confirm for the processor used.
                    batch[key] = torch.cat(values, dim=0)

        # Ensure gradients only for float tensors
        for key, value in list(batch.items()):
            if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
                batch[key] = value.clone().detach().requires_grad_(True)

        return batch


# Training

In [None]:
# Set up the dataset
full_dataset = Phi3VQADataset(csv_path, image_folder, processor, max_samples=5)

# 80/10/10 split. With very small datasets `int(0.1 * n)` rounds down to 0,
# which would leave the validation split empty even though evaluation is
# enabled below, so the validation split gets at least one sample whenever
# one is available. For n >= 10 the sizes are identical to a plain 80/10/10.
num_samples = len(full_dataset)
train_size = int(0.8 * num_samples)
val_size = min(max(1, int(0.1 * num_samples)), num_samples - train_size)
test_size = num_samples - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),  # reproducible split
)

# Training arguments
# NOTE(review): the Phi-3-vision processor produces a sequence whose length
# depends on the image, so samples of different lengths cannot be stacked into
# one batch without a padding-aware collator. One sample per device step with
# 16 accumulation steps keeps the effective batch size of the previous
# 8 x 2 configuration. Confirm against the processor implementation.
training_args = TrainingArguments(
    output_dir="./phi4-mm-vqa",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    learning_rate=2e-4,
    weight_decay=0.01,
    fp16=True,
    report_to="none",
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=5,
    save_strategy="steps",
    save_steps=5,
    save_total_limit=2,
    # Keep the checkpoint with the lowest validation loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    label_names=["labels"],
)



In [None]:
# Initialize the trainer.
# `processing_class` replaces the deprecated `tokenizer` keyword of Trainer;
# the tokenizer is still saved alongside every checkpoint.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=processor.tokenizer,
    # DataCollator expects raw records with image_name/question/answer keys,
    # whereas Phi3VQADataset already yields tensors, so it stays disabled:
    # data_collator=DataCollator(processor,image_folder)
)

trainer.train()

# Inference

In [None]:
def phi4_infer(image_path, question, max_new_tokens=10):
    """Answer ``question`` about the image stored at ``image_path``.

    Uses the notebook-level ``processor`` and ``model``.

    Args:
        image_path: Path of the image file to load.
        question: Question text inserted into the user turn.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded, whitespace-stripped text of the newly generated tokens
        only (the echoed prompt is removed).
    """
    # convert() returns a new, fully loaded image, so the file can be closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # The prompt must contain the image placeholder and follow the tokenizer's
    # chat template; a bare "[INST] ... [/INST]" string provides neither.
    user_prompt = f"Answer the question about this image: {question}"
    messages = [{"role": "user", "content": f"<|image_1|>\n{user_prompt}"}]
    prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[0][prompt_length:]
    return processor.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Test example
SAMPLE_INDEX = 123  # row to spot-check; clamped so short CSVs do not raise IndexError

annotations = pd.read_csv(csv_path)
if annotations.empty:
    raise ValueError(f"No annotation rows found in {csv_path}")

row = annotations.iloc[min(SAMPLE_INDEX, len(annotations) - 1)]
img_path = os.path.join(image_folder, row["image_name"])
print("Q:", row["question"])
print("A (pred):", phi4_infer(img_path, row["question"]))
print("A (GT):", row["answer"])
