In [12]:
import os

# Every dataset path used in this notebook is relative, so switch to the
# project root first.
# NOTE(review): absolute, user-specific path — consider making it configurable.
os.chdir("/home/scottc/links/scratch/causal_pool/")

import jsonlines
from pprint import pprint

# Open the JSON Lines dataset inside a context manager so the underlying file
# handle is released as soon as the sample entry has been extracted (the
# previous version left the reader open for the lifetime of the kernel).
# `dataset` stays bound afterwards, but refers to a closed, fully consumed reader.
with jsonlines.open("datasets/1k_simple/1k_simple.jsonl") as dataset:
    # Read every record and keep the one at index 3 as the working example.
    entry = list(dataset)[3]

In [13]:
from typing import Tuple

def get_metrics(entry, pred) -> Tuple[int, int]:
    """
    Score a multiple-choice answer against ``entry["ground_truth"]``.

    Returns (exactly correct or not, how many options were correct)

    Args:
        entry: Mapping with a ``"ground_truth"`` iterable. It is compared
            against 0-based option indices (``"A"`` -> 0, ``"B"`` -> 1, ...).
            NOTE(review): assumes the dataset stores ground truth as 0-based
            integers — confirm against the dataset schema.
        pred: Raw model answer, expected to be a string of distinct ASCII
            uppercase letters such as ``"AC"``. Surrounding whitespace is
            ignored. ``None`` or any non-string value scores ``(0, 0)``.

    Returns:
        ``(exactly_correct, num_correct)`` where ``exactly_correct`` is 1 only
        if the selected set equals the ground-truth set, and ``num_correct`` is
        the number of selected options that are in the ground truth. Malformed
        answers (non A-Z characters, repeated letters) score ``(0, 0)``.
    """
    # A chat completion may carry no text content at all; treat that as a miss
    # instead of crashing on iteration.
    if not isinstance(pred, str):
        return 0, 0

    # Models frequently append a newline or spaces; that is not a wrong answer.
    letters = pred.strip()

    # Only plain A-Z map onto option indices. str.isalpha()/isupper() would also
    # accept non-ASCII capitals (e.g. "É"), which cannot denote an option.
    if not all("A" <= c <= "Z" for c in letters):
        return 0, 0

    selected_options = set(ord(c) - ord("A") for c in letters)

    if len(selected_options) != len(letters):  # duplicate options
        return 0, 0

    ground_truth = set(entry["ground_truth"])

    exactly_correct = int(selected_options == ground_truth)
    num_correct = len(selected_options & ground_truth)

    return exactly_correct, num_correct

In [14]:
def build_prompt(entry):
    """
    Build the single-turn chat message list for one dataset entry.

    The message contains the entry's video, embedded as a base64 ``data:`` URL,
    followed by a text part holding the question, its lettered options
    (``A.``, ``B.``, ...) and the answering instruction.

    Args:
        entry: Mapping providing ``"video"``, ``"question"`` and ``"options"``.

    Returns:
        A one-element list with a ``"user"`` message whose ``"content"`` has a
        ``"video_url"`` part and a ``"text"`` part, in that order.
    """
    video_id = entry["video"]
    video_path = f"datasets/1k_simple/shots/{video_id}/video_{video_id}.mp4"

    # Assemble the text part piece by piece: question, one line per option,
    # then the instruction separated by a blank line.
    pieces = [f"{entry['question']}\n"]
    for index, option_text in enumerate(entry["options"]):
        letter = chr(ord("A") + index)
        pieces.append(f"{letter}. {option_text}\n")
    pieces.append(
        "\nPlease select the correct option(s). Don't write anything else than the option letter(s). Example: AC."
    )
    text_part = {"type": "text", "text": "".join(pieces)}

    # The video is read from disk and inlined into the request as a data URL.
    video_part = {
        "type": "video_url",
        "video_url": {"url": f"data:video/mp4;base64,{to_b64(video_path)}"},
    }

    user_message = {"role": "user", "content": [video_part, text_part]}
    return [user_message]

import base64

def to_b64(video_path):
    """
    Return the contents of the file at ``video_path`` as a base64 text string.

    The whole file is read into memory in binary mode before being encoded.
    """
    with open(video_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

In [29]:
from openai import OpenAI

# Server endpoint and model name used for the request below.
BASE_URL = "http://trig0002:8002/v1"
MODEL = "Qwen/Qwen3-VL-4B-Instruct"

# NOTE(review): "EMPTY" looks like a placeholder key — presumably the server
# at BASE_URL does not check it; confirm.
client = OpenAI(api_key="EMPTY", base_url=BASE_URL)

# Ask the model about the sample entry. The answer is capped at 20 tokens,
# matching the prompt's request for option letters only. Sampling is
# stochastic (temperature 0.8, no seed), so repeated runs may differ.
# Parameters inside `extra_body` are sent as extra fields of the request body.
response = client.chat.completions.create(
    model=MODEL,
    messages=build_prompt(entry),
    temperature=0.8,
    max_tokens=20,
    extra_body=dict(
        top_k=20,
        top_p=0.8,
        repetition_penalty=1.0,
        presence_penalty=1.5,
    ),
)

In [30]:
# Show the model's raw answer, then its (exact match, number correct) score.
answer = response.choices[0].message.content
print(answer)
print(get_metrics(entry, answer))

A
(0, 0)
