# In [None]:
import argparse
import os
import random
import glob
import pandas as pd
import pyarrow.parquet as pq
from PIL import Image
from io import BytesIO
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from transformers import StoppingCriteriaList
from transformers import LlamaForCausalLM
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub

from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

# Keep a handle on the unpatched implementation so the wrapper can delegate to it.
original_forward = LlamaForCausalLM.forward


def new_forward(self, *args, **kwargs):
    """Delegate to the original ``LlamaForCausalLM.forward`` minus ``cache_position``.

    A ``cache_position`` keyword, if present, is dropped; every other
    positional and keyword argument is forwarded unchanged.
    """
    # NOTE(review): presumably newer transformers releases pass
    # ``cache_position`` during generation and something downstream does not
    # accept it -- confirm which call actually needs this.
    forwarded = {name: value for name, value in kwargs.items() if name != 'cache_position'}
    return original_forward(self, *args, **forwarded)


# Install the wrapper on the class itself so all instances see it.
# NOTE(review): a subclass that overrides ``forward`` is NOT affected by this
# patch -- confirm the model class built from the config inherits it.
LlamaForCausalLM.forward = new_forward

def parse_args(argv=None):
    """Parse the command-line options of the MME evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Handy in a notebook kernel or in tests, where
            ``sys.argv`` does not hold this script's options. ``None`` keeps
            the normal behaviour of reading the real command line.

    Returns:
        argparse.Namespace with ``cfg_path``, ``gpu_id``, ``root_dir``,
        ``output_file`` and ``options`` attributes.
    """
    parser = argparse.ArgumentParser(description="MME Dataset Processing")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--root-dir", required=True, help="root directory of MME dataset")
    parser.add_argument("--output-file", default="mme_results.csv", help="output file name")
    # The old help text pointed users at a non-existent ``--cfg-options`` flag;
    # ``--options`` is the only override flag this parser defines.
    # NOTE(review): presumably consumed by ``Config(args)`` -- confirm.
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file.",
    )
    return parser.parse_args(argv)


def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    The seed is ``config.run_cfg.seed`` plus ``get_rank()`` (presumably the
    distributed process rank -- confirm), so different ranks get different
    seeds.

    Args:
        config: Configuration object exposing ``run_cfg.seed``.
    """
    # NOTE(review): this function is not called anywhere in this file.
    base_seed = config.run_cfg.seed
    seed = base_seed + get_rank()

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False


class MMEDataset:
    """MME benchmark rows loaded eagerly from sharded parquet files.

    Every shard of ``split`` in ``root_dir`` is read into one pandas
    DataFrame. Indexing yields ``(image, question, answer)`` tuples.
    """

    def __init__(self, root_dir, split="test", num_shards=4):
        """Load all parquet shards of ``split`` found in ``root_dir``.

        Args:
            root_dir: Directory containing files named like
                ``{split}-XXXXX-of-{num_shards:05d}-*.parquet``.
            split: Split name used as the file-name prefix.
            num_shards: Total shard count encoded in the file names. The
                default of 4 reproduces the previously hard-coded
                ``-of-00004-`` pattern.

        Raises:
            FileNotFoundError: If no file matches the shard pattern.
        """
        self.root_dir = root_dir
        self.split = split
        self.num_shards = num_shards
        self.data = self.load_data()

    def load_data(self):
        """Read and concatenate every matching parquet shard.

        Returns:
            pandas.DataFrame containing all rows, re-indexed 0..n-1.

        Raises:
            FileNotFoundError: If the glob pattern matches no file.
        """
        file_pattern = os.path.join(
            self.root_dir, f"{self.split}-*-of-{self.num_shards:05d}-*.parquet"
        )
        # Sort so that rows come out in shard order regardless of directory order.
        parquet_files = sorted(glob.glob(file_pattern))
        if not parquet_files:
            raise FileNotFoundError(f"No parquet files found matching pattern: {file_pattern}")

        # glob only returns existing paths, so no per-file existence check is needed.
        frames = []
        for file_path in parquet_files:
            print(f"Loading data from {file_path}")
            frames.append(pq.read_table(file_path).to_pandas())

        data = pd.concat(frames, ignore_index=True)
        print(f"Loaded {len(data)} examples.")
        return data

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, question, answer)`` for positional index ``idx``.

        ``image`` is returned exactly as stored in the parquet column.
        NOTE(review): presumably a dict with a ``'bytes'`` key -- confirm.
        The previous per-item debug prints, which dumped the whole image
        payload to stdout, have been removed.
        """
        item = self.data.iloc[idx]
        return item['image'], item['question'], item['answer']


# Model Initialization
# Build the model, the visual processor and the Chat wrapper once at import
# time; the dataset-processing code below uses the module-level `args`,
# `chat` and `CONV_VISION` defined here.

conv_dict = {
    'pretrain_vicuna0': CONV_VISION_Vicuna0,
    'pretrain_llama2': CONV_VISION_LLama2,
}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

# A single device string, reused for the model, the stop tokens and Chat.
cuda_device = 'cuda:{}'.format(args.gpu_id)

model_config = cfg.model_cfg
# NOTE(review): presumably the device used when loading 8-bit weights -- confirm.
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(cuda_device)

# Conversation template chosen by model type; unknown types raise KeyError.
CONV_VISION = conv_dict[model_config.model_type]

# NOTE(review): the *train* vis-processor config of cc_sbu_align is used for
# evaluation -- confirm that is intended.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

# Token-id sequences that stop generation.
# NOTE(review): presumably the tokenization of '###' -- confirm for this tokenizer.
stop_words_ids = [
    torch.tensor(ids).to(device=cuda_device)
    for ids in ([835], [2277, 29937])
]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

chat = Chat(model, vis_processor, device=cuda_device, stopping_criteria=stopping_criteria)
print('Initialization Finished')


# MME Dataset Processing


def _decode_image(image_dict):
    """Decode an MME image cell of the form ``{'bytes': ...}`` into an RGB PIL image.

    Raises:
        ValueError: If ``image_dict`` is not a dict containing a ``'bytes'`` key.
    """
    if not (isinstance(image_dict, dict) and 'bytes' in image_dict):
        raise ValueError(f"Unsupported image format: {type(image_dict)}")
    # Decoded images may be grayscale, palette or RGBA. Force 3-channel RGB,
    # which the vision processor presumably expects (NOTE(review): confirm).
    # `convert` returns a fully loaded copy, so the source image can be closed.
    with Image.open(BytesIO(image_dict['bytes'])) as raw:
        return raw.convert('RGB')


def process_mme_dataset():
    """Run the chat model on every MME example and write its answers to CSV.

    Loads the dataset from ``args.root_dir``. Writes one row per example,
    with columns ``question``, ``ground_truth`` and ``model_answer``, to
    ``args.output_file``. Uses the module-level ``args``, ``chat`` and
    ``CONV_VISION`` created during model initialization.

    Raises:
        ValueError: If an example's image is not a dict with a ``'bytes'`` key.
    """
    dataset = MMEDataset(args.root_dir)
    results = []

    for idx in range(len(dataset)):
        image_dict, question, ground_truth = dataset[idx]
        image = _decode_image(image_dict)

        # Start every example from a fresh copy of the conversation template.
        chat_state = CONV_VISION.copy()
        img_list = []
        chat.upload_img(image, chat_state, img_list)
        chat.encode_img(img_list)
        chat.ask(question, chat_state)

        # Only element [0] of chat.answer's return value is kept as the answer.
        model_answer = chat.answer(
            conv=chat_state,
            img_list=img_list,
            num_beams=1,
            temperature=1.0,
            max_new_tokens=300,
            max_length=2000,
        )[0]

        results.append({
            'question': question,
            'ground_truth': ground_truth,
            'model_answer': model_answer,
        })

    # Explicit columns so an empty run still writes a header row.
    results_df = pd.DataFrame(results, columns=['question', 'ground_truth', 'model_answer'])
    results_df.to_csv(args.output_file, index=False)
    print(f"Results saved to {args.output_file}")


if __name__ == "__main__":
    process_mme_dataset()