# Install Dependencies

In [None]:
# Dependencies are installed one command at a time, in this order; a later install
# may change versions pulled in by an earlier one.
# NOTE(review): `unsloth` is installed after `transformers`/`trl` and may re-resolve
# them — confirm the final versions with `pip list`.
# NOTE(review): `%pip` targets the running kernel's environment more reliably than `!pip`.
!pip install -U "transformers>=4.39.0"
!pip install peft bitsandbytes
# trl is pinned to an exact version.
!pip install -U "trl==0.8.3"
!pip install fsspec==2024.10.0 rich==13.7.1
!pip install unsloth
# Metric libraries used in the Evaluation section.
# NOTE(review): `sentence-transformers` is imported later but not installed here —
# presumably preinstalled on the Kaggle image; verify.
!pip install evaluate
!pip install rouge_score
!pip install bert-score
# BARTScore is cloned into the working directory and imported from there, so its
# directory has to be on sys.path before `from bart_score import BARTScorer`.
# NOTE(review): git refuses to clone into an existing non-empty directory, so this
# line errors (non-fatally) when the cell is re-run.
!git clone https://github.com/neulab/BARTScore.git

# Imports

In [None]:
import os
import sys
import unsloth
import pandas as pd
from PIL import Image
from datasets import Dataset
import torch
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration, BitsAndBytesConfig
from trl import SFTTrainer
from peft import LoraConfig
from unsloth import FastLanguageModel
from collections import defaultdict
from collections import Counter
import wandb
from transformers import TrainingArguments
import torch
from huggingface_hub import login, HfApi
from huggingface_hub import HfApi
from huggingface_hub import upload_folder
from peft import get_peft_model
from peft import PeftModel
# BARTScore is a cloned repo (not pip-installed): its directory must be on sys.path
# before `bart_score` can be imported. The guard avoids duplicate entries on re-run.
if '/kaggle/working/BARTScore' not in sys.path:
    sys.path.append('/kaggle/working/BARTScore')
from bart_score import BARTScorer
import evaluate
from tqdm import tqdm
import numpy as np
import re
from sentence_transformers import SentenceTransformer, util

## Loading the dataset

In [None]:
# === Load your dataset ===
csv_path = "/kaggle/input/vr-dataset-final-20k/annotations.csv"
image_folder = "/kaggle/input/vr-dataset-final-20k/images/unique_images"

df = pd.read_csv(csv_path)
df["image_path"] = df["image_name"].apply(lambda x: os.path.join(image_folder, x))

# Fail fast if any referenced image is missing, and report which ones so the dataset
# mount can be debugged. Each distinct path is checked only once (images repeat across
# Q/A rows).
missing_paths = [p for p in df["image_path"].unique() if not os.path.exists(p)]
assert not missing_paths, (
    f"Some image paths are invalid: {len(missing_paths)} missing, e.g. {missing_paths[:5]}"
)

# === Convert to expected format for LLaVA ===
def create_messages(row):
    """Build a single-turn LLaVA chat (image + question -> answer) from one annotation row.

    Args:
        row: Mapping with ``"question"`` and ``"answer"`` entries (e.g. a DataFrame row).

    Returns:
        A two-element list: a user turn holding an image placeholder plus the question
        text, followed by an assistant turn holding the answer text.

    NOTE(review): not called anywhere in the visible notebook; the datasets are built
    from the grouped multi-turn conversations instead.
    """
    user_turn = {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": row["question"]}],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": row["answer"]}],
    }
    return [user_turn, assistant_turn]

# -----------------------------
# Group all Q&As by image path so each image becomes one multi-turn conversation.
grouped = defaultdict(list)
for image_path, question, answer in zip(df["image_path"], df["question"], df["answer"]):
    grouped[image_path].append((question, answer))

# Create new rows: one per image
new_rows = []
for image_path, qas in grouped.items():
    # One image token at the start, then alternating user-question / assistant-answer turns.
    messages = [{"role": "user", "content": [{"type": "image"}]}]
    for question, answer in qas:
        messages.append({"role": "user", "content": [{"type": "text", "text": question}]})
        messages.append({"role": "assistant", "content": [{"type": "text", "text": answer}]})

    new_rows.append({
        "messages": messages,
        "images": [image_path]
    })

# Create a new DataFrame
grouped_df = pd.DataFrame(new_rows)
#--------------------------------------------

# === Print statistics ===
print(f"Total unique images loaded: {len(grouped_df)}")

# === Split the dataset into train (80%), eval (10%), and test (10%) ===
train_df = grouped_df.sample(frac=0.8, random_state=42)
remaining_df = grouped_df.drop(train_df.index)

eval_df = remaining_df.sample(frac=0.5, random_state=42)  # 50% of the remaining dataset
test_df = remaining_df.drop(eval_df.index)  # The other 50% becomes the test set

# The three splits must partition the grouped data exactly.
assert len(train_df) + len(eval_df) + len(test_df) == len(grouped_df)

# === Create the datasets ===
# `sample()` leaves a shuffled, non-RangeIndex index. With the default preserve_index,
# Dataset.from_pandas would store it as an extra `__index_level_0__` column, so drop it.
train_dataset = Dataset.from_pandas(train_df[["messages", "images"]], preserve_index=False)
eval_dataset = Dataset.from_pandas(eval_df[["messages", "images"]], preserve_index=False)
test_dataset = Dataset.from_pandas(test_df[["messages", "images"]], preserve_index=False)

print(f"Sample formatted message (Train Dataset):\n{train_dataset[0]}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Eval dataset size: {len(eval_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")


### Create a chat template and set up the `tokenizer` and `processor`

In [None]:
# Jinja chat template for the one-word-answer LLaVA format: a fixed instruction preamble,
# then "USER: ..." / "ASSISTANT: ...<eos>" turns, where image items render as "<image>".
# When rendered with add_generation_prompt=True (as `predict` does), a trailing
# "ASSISTANT:" is appended so generation starts inside the assistant turn. With
# add_generation_prompt=False (training), the rendered text is the same as before.
LLAVA_CHAT_TEMPLATE = """A chat between a user and an AI assistant. 
The assistant must always answer with exactly one word—no spaces, no explanations, no exceptions.
If unsure, it should respond with the most likely single word.  
Never say anything beyond one word.

{% for message in messages %}
{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}
{% for item in message['content'] %}
{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}
{% endfor %}
{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}
{% endfor %}{% if add_generation_prompt %}ASSISTANT:{% endif %}"""

In [None]:
# Checkpoint for the tokenizer/processor. It is the same checkpoint as `base_model_id`
# in the Inference section, so the vocabulary and image preprocessing match the weights.
# It is defined here because nothing earlier in the notebook defines `model_id`.
model_id = "unsloth/llava-1.5-7b-hf-bnb-4bit"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor = AutoProcessor.from_pretrained(model_id)
# Have the processor tokenize with the tokenizer that carries the custom chat template.
processor.tokenizer = tokenizer

In [None]:
# Smoke-test the custom chat template on a text-only exchange (no image item).
sample_messages = [
    {"role": role, "content": [{"type": "text", "text": text}]}
    for role, text in (
        ("user", "What is in this image?"),
        ("assistant", "This appears to be a beach scene."),
    )
]
rendered_sample_prompt = tokenizer.apply_chat_template(sample_messages, tokenize=False)
print(rendered_sample_prompt)

# Inference

In [None]:
# Load the 4-bit LLaVA-1.5-7B checkpoint through Unsloth.
base_model_id = "unsloth/llava-1.5-7b-hf-bnb-4bit"
model_load_kwargs = {
    "model_name": base_model_id,
    "device_map": "auto",
    "load_in_4bit": True,
    "use_gradient_checkpointing": True,
}
model, tokenizer = FastLanguageModel.from_pretrained(**model_load_kwargs)

# `tokenizer` was just rebound to the object returned by from_pretrained, so the
# custom chat template has to be attached to it again.
# NOTE(review): `processor.tokenizer` still refers to the tokenizer from the earlier
# setup cell, not this one — confirm both agree (vocab / special tokens).
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE


In [None]:
def predict(image_path, question, max_new_tokens=20):
    """Answer one question about one image with greedy decoding.

    Uses the notebook-level ``tokenizer``, ``processor`` and ``model``.

    Args:
        image_path: Path to the image file on disk.
        question: Question text to ask about the image.
        max_new_tokens: Maximum number of tokens to generate (default 20).

    Returns:
        Tuple ``(answer, image)``: ``answer`` is the generated continuation decoded with
        special tokens skipped and surrounding whitespace stripped; ``image`` is the RGB
        ``PIL.Image`` that was fed to the model.
    """
    # Same layout as the training conversations: the image in its own user turn,
    # followed by the question in a second user turn.
    messages = [
        {"role": "user", "content": [{"type": "image"}]},
        {"role": "user", "content": [{"type": "text", "text": question}]},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # The context manager closes the file handle deterministically (PIL can keep it open,
    # e.g. for multi-frame images); convert() returns an independent in-memory copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt", padding=True).to(model.device)

    # The generated sequence starts with the prompt tokens; remember where they end.
    input_len = inputs["input_ids"].shape[1]

    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
    )

    # Keep only the newly generated tokens.
    generated_tokens = output[0][input_len:]
    answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    return answer, image

## Spot-check a single sample

In [None]:
from PIL import Image

# Spot-check the model on one raw annotation row before running the full evaluation.
# NOTE(review): row 26 of `df` may belong to the training split.
row = df.iloc[26]
question, ground_truth = row["question"], row["answer"]
ans, img = predict(row["image_path"], question)

print("Q:", question)
print("A:", ans)
print("GT:", ground_truth)

In [None]:
# Display the image that was passed to the model in the spot-check above.
img

In [None]:
# Make the cloned BARTScore repo importable as `bart_score`. This must happen before
# `from bart_score import BARTScorer` runs. The guard keeps re-runs of this cell from
# appending the same directory to sys.path again and again.
BARTSCORE_DIR = '/kaggle/working/BARTScore'
if BARTSCORE_DIR not in sys.path:
    sys.path.append(BARTSCORE_DIR)

## Evaluation

In [None]:
# Load metrics: reference-based text metrics from `evaluate`, plus a
# sentence-embedding model used for cosine similarity.
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# BARTScore setup: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
bart_scorer = BARTScorer(checkpoint='facebook/bart-large-cnn', device=device)

# Exact Match
def compute_exact_match(pred, label):
    """Return 1 if ``pred`` equals ``label`` after stripping whitespace and lower-casing, else 0."""
    normalized_pred = pred.strip().lower()
    normalized_label = label.strip().lower()
    return 1 if normalized_pred == normalized_label else 0

# Token-level F1
def compute_token_f1(pred, label):
    """SQuAD-style token F1 between a prediction and a reference answer.

    Both strings are lower-cased and split on whitespace. Overlap is counted with
    multiplicity (multiset intersection), so repeated tokens are scored correctly,
    e.g. "a a" vs "a a" gives 1.0 rather than 0.5.

    Returns:
        F1 in [0.0, 1.0]; 0.0 when there is no token overlap, which includes the
        case where either string is empty.
    """
    pred_tokens = pred.strip().lower().split()
    label_tokens = label.strip().lower().split()
    num_common = sum((Counter(pred_tokens) & Counter(label_tokens)).values())
    if num_common == 0:
        return 0.0
    # num_common > 0 implies both token lists are non-empty and precision + recall > 0.
    precision = num_common / len(pred_tokens)
    recall = num_common / len(label_tokens)
    return 2 * precision * recall / (precision + recall)

# Sentence Embedding Cosine Similarity
def compute_semantic_similarity(pred, label):
    """Cosine similarity between the `embedding_model` embeddings of ``pred`` and ``label``, as a float."""
    pred_embedding, label_embedding = (
        embedding_model.encode(text, convert_to_tensor=True) for text in (pred, label)
    )
    similarity = util.cos_sim(pred_embedding, label_embedding)
    return float(similarity)

# Clean model outputs
def clean_answer(answer):
    """Reduce a raw model generation to a bare short answer.

    Steps:
      1. Drop every ``word:`` prefix such as "ASSISTANT:" or "USER:".
      2. Drop leading non-word characters (bullets, dashes, quotes). Digits are kept,
         so numeric answers like "3" survive.
      3. Keep only the text before the first "?" or newline.
      4. Strip surrounding whitespace and trailing punctuation (e.g. "Yes." -> "Yes").
    """
    answer = re.sub(r"\b\w+:\s*", "", answer)
    answer = re.sub(r"^\W+", "", answer)
    first_segment = re.split(r"[?\n]", answer)[0].strip()
    return re.sub(r"\W+$", "", first_segment)

# Metric storage (one entry per scored Q/A pair)
exact_matches = []
token_f1s = []
rouge_scores = []
bert_scores = []
bart_scores = []
semantic_similarities = []

# True: score only the last Q/A pair of each image. False: score every Q/A pair.
EVAL_LAST_QA_ONLY = False


def _message_text(message):
    """Return the text of the first ``"text"`` item in a chat message, or None if there is none."""
    for item in message["content"]:
        if item.get("type") == "text":
            return item.get("text")
    return None


def _iter_qa_pairs(messages):
    """Yield ``(question, answer)`` for each user turn immediately followed by an assistant turn.

    The image-only opening user turn is skipped because the next turn is another
    user turn, not an assistant turn.
    """
    for current, following in zip(messages, messages[1:]):
        if current["role"] == "user" and following["role"] == "assistant":
            question, answer = _message_text(current), _message_text(following)
            if question is not None and answer is not None:
                yield question, answer


# Predictions and references, aligned by index, for the batched metrics below.
predictions = []
references = []

# Evaluation loop: answer every selected Q/A pair of every eval image.
for sample in tqdm(eval_dataset, desc="Evaluating"):
    image_path = sample["images"][0]
    qa_pairs = list(_iter_qa_pairs(sample["messages"]))
    if EVAL_LAST_QA_ONLY:
        qa_pairs = qa_pairs[-1:]

    for user_question, expected_answer in qa_pairs:
        # Get model prediction
        pred_raw, _ = predict(image_path, user_question)
        pred_answer = clean_answer(pred_raw)

        predictions.append(pred_answer)
        references.append(expected_answer)

        # Cheap per-pair metrics
        exact_matches.append(compute_exact_match(pred_answer, expected_answer))
        token_f1s.append(compute_token_f1(pred_answer, expected_answer))
        semantic_similarities.append(compute_semantic_similarity(pred_answer, expected_answer))

# Model-based metrics: one batched call each instead of one call per pair.
if predictions:
    # use_aggregator=False returns one ROUGE-L F-measure per pair (no bootstrap aggregate).
    rouge_result = rouge.compute(
        predictions=predictions,
        references=references,
        use_stemmer=True,
        use_aggregator=False,
    )
    rouge_scores.extend(rouge_result["rougeL"])

    bert_result = bertscore.compute(predictions=predictions, references=references, lang="en")
    bert_scores.extend(bert_result["f1"])

    bart_scores.extend(bart_scorer.score(predictions, references))

# Summary
print("\n🔍 Evaluation Metrics:")
if not predictions:
    print("  (no Q/A pairs were scored)")
else:
    print(f"  - Q/A pairs scored:       {len(predictions)}")
    print(f"  - Exact Match:            {np.mean(exact_matches):.4f}")
    print(f"  - Token-level F1:         {np.mean(token_f1s):.4f}")
    print(f"  - ROUGE-L:                {np.mean(rouge_scores):.4f}")
    print(f"  - BERTScore (F1):         {np.mean(bert_scores):.4f}")
    print(f"  - BARTScore:              {np.mean(bart_scores):.4f}")
    print(f"  - Semantic Cosine Sim.:   {np.mean(semantic_similarities):.4f}")
