In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import requests
import torch 
from tqdm import tqdm

# Alternative setup (currently disabled): load the "anananan116/TinyVLM"
# checkpoint and its tokenizer through transformers' from_pretrained, in
# float16 on CUDA and in eval mode, instead of the local checkpoint that a
# later cell loads.
# model = AutoModelForCausalLM.from_pretrained(
#     "anananan116/TinyVLM",
#     trust_remote_code = True,
#     torch_dtype=torch.float16,
#     ).to('cuda').eval()
# tokenizer = AutoTokenizer.from_pretrained("anananan116/TinyVLM")
# Benchmark data; the cells below read its 'test' split.
ds = load_dataset("darkyarding/MME")

In [None]:
# Display the first record of the 'test' split to inspect the fields a sample carries.
ds['test'][0]

In [None]:
from src.model.llama import get_model_and_tokenizer
import json
from PIL import Image
import torch

# Special-token definitions. The encoding is pinned because JSON documents are
# UTF-8, whereas open() would otherwise decode with the platform's locale
# encoding and could garble or reject non-ASCII tokens.
with open("configs/Specialtokens/default.json", encoding="utf-8") as f:
    special_token_map = json.load(f)

# Map each added token's 'type' to its 'token' string.
additional_tokens_dict = {x['type']: x['token'] for x in special_token_map['added_tokens']}

# Checkpoint to evaluate.
model_args = {"pretrained_model": "results/checkpoint-21000"}

# The helper returns four values; only the model and tokenizer are needed here.
model, tokenizer, _, _ = get_model_and_tokenizer(model_args, additional_tokens_dict, load_vision_model=True)

# Cast the model to half precision.
model = model.to(torch.float16)

In [None]:
# Device used for inference; the generation loop also moves its input tensors here.
device = torch.device('cuda')
# Move the model onto that device.
model = model.to(device)

In [None]:
class Collator:
    """Column-wise collate function for ``DataLoader``.

    Turns a list of per-sample dictionaries into one dictionary of lists,
    leaving every value untouched (no stacking, no tensor conversion).
    """

    def __init__(self):
        """The collator is stateless, so there is nothing to set up."""

    def __call__(self, batch):
        """Group the fields of ``batch`` by name.

        Args:
            batch: List of sample dictionaries, each providing the keys
                'question_id', 'image', 'question', 'answer' and 'category'.

        Returns:
            Dictionary with those same five keys, each mapped to the list of
            the corresponding values in batch order. Images stay a plain list
            of whatever objects the dataset yields (presumably PIL images).

        Raises:
            KeyError: If a sample lacks one of the expected keys.
        """
        field_names = ('question_id', 'image', 'question', 'answer', 'category')
        columns = {name: [] for name in field_names}

        # Walk sample by sample so fields are read in a fixed order per item.
        for sample in batch:
            for name in field_names:
                columns[name].append(sample[name])

        return columns

In [None]:
# Text after which the decoded sequence is treated as the model's reply
# (presumably emitted by the chat template before the assistant turn — confirm).
ASSISTANT_MARKER = "assistant\n\n"
# Generation below samples (do_sample=True); seeding torch's RNG makes repeated
# runs on the same setup draw the same samples.
GENERATION_SEED = 0


def extract_reply(decoded, marker=ASSISTANT_MARKER):
    """Return the part of ``decoded`` that follows the first ``marker``.

    Args:
        decoded: Full decoded sequence (prompt plus generated text).
        marker: Separator that precedes the reply.

    Returns:
        Everything after the first occurrence of ``marker``, including any
        later occurrences of the marker inside the reply. If ``marker`` is
        absent, ``decoded`` is returned unchanged so that one unexpected
        sample does not abort the whole evaluation run.
    """
    _, found, reply = decoded.partition(marker)
    return reply if found else decoded


torch.manual_seed(GENERATION_SEED)

# Rebuilt from scratch on every execution of this cell, so re-running is idempotent.
answers = []
dataloader = DataLoader(ds['test'], batch_size=16, shuffle=False, collate_fn=Collator())
with torch.no_grad():
    for batch in tqdm(dataloader):
        images = batch['image']
        questions = batch['question']
        question_ids = batch['question_id']
        # Each prompt is the image placeholder token followed by the question.
        prompts = ["<IMGPLH>" + x for x in questions]
        inputs = model.prepare_input_ids_for_generation(prompts, images, tokenizer)
        outputs = model.generate(
            input_ids=inputs['input_ids'].to(device),
            attention_mask=inputs['attention_mask'].to(device),
            encoded_image=inputs["encoded_image"],
            max_new_tokens=128,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            top_p=0.9,
            temperature=0.8
        )
        output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_text = [extract_reply(text) for text in output_text]
        # One record per sample: (question id, question, ground truth, model reply).
        for qid, ans, q, gt in zip(question_ids, output_text, questions, batch['answer']):
            answers.append((qid, q, gt, ans))

In [None]:
# Persist the collected (question id, question, ground truth, reply) records;
# the tuples are serialized as JSON arrays.
with open("results/answers.json", "w") as answers_file:
    answers_file.write(json.dumps(answers))