In [3]:
# Mount Google Drive so the notebook can read and write under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Install the third-party packages imported below (IPython shell magics;
# these lines only run inside a notebook, not as a plain Python script).
!pip install datasets
!pip install langchain
!pip install -U langchain-community
!pip install chromadb
# Left disabled: only relevant to the FAISS vector-store option.
# !pip install faiss-gpu
import torch
from datasets import load_dataset
from transformers import LlavaForConditionalGeneration, AutoProcessor
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma # FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PIL import Image
import pandas as pd
import networkx as nx
from tqdm import tqdm
import re

# Root directory on the mounted Drive: problem images are opened relative to
# it and the per-method result CSVs are written into it.
IMAGE_PATH = "/content/drive/MyDrive/MATH-V-main"

class MathVisionEvaluator:
    """Evaluate LLaVA-1.5 on MathVision with zero-shot, RAG and graph-RAG prompts.

    Construction loads the dataset, the LLaVA model and processor and a
    sentence-embedding model, then builds a vector store and a problem graph
    from the dataset's ``test`` split. ``evaluate`` iterates the ``testmini``
    split and records one prediction per example and method.

    NOTE(review): the retrieval sources are built from ``test`` while the
    evaluation runs on ``testmini``. If the two splits overlap, retrieval could
    hand the model the evaluated question together with its answer, so the
    retrieval helpers drop entries whose question equals the query.
    """

    def __init__(self, dataset_name="MathLLMs/MathVision"):
        """Load dataset, model, processor and embeddings; build retrieval indexes.

        Args:
            dataset_name: Hugging Face dataset identifier passed to
                ``load_dataset``.
        """
        self.dataset = load_dataset(dataset_name)
        self.model_id = "llava-hf/llava-1.5-7b-hf"
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            torch_dtype=torch.float16,
            device_map='auto'
        )
        self.processor = AutoProcessor.from_pretrained(self.model_id)

        # Embedding model used by the vector store for similarity search.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Vector store (RAG) and graph (graph-RAG) are both built eagerly.
        self._initialize_knowledge_base()
        self._initialize_problem_graph()

        self.results = {'zero_shot': [], 'rag': [], 'graph_rag': []}

    def _initialize_knowledge_base(self, method='Chroma'):
        """Build ``self.vector_store`` from the ``test`` split.

        Each example becomes one text document (question, answer, subject,
        level) with ``subject`` and ``level`` stored as metadata.

        Args:
            method: ``'Chroma'`` builds a Chroma store persisted under
                ``./chroma_db``; any other value builds a FAISS index over
                500-character chunks.
        """
        documents = []
        metadatas = []
        for example in self.dataset['test']:
            documents.append(
                f"Question: {example['question']}\nAnswer: {example['answer']}\n"
                f"Subject: {example['subject']}\nlevel: {example['level']}"
            )
            metadatas.append({
                'subject': example['subject'],
                'level': example['level']
            })

        if method == 'Chroma':
            # NOTE(review): the persist directory is reused between runs;
            # re-instantiating the evaluator may add the same documents to an
            # existing collection again -- confirm and clear it if so.
            self.vector_store = Chroma.from_texts(
                texts=documents,
                metadatas=metadatas,
                embedding=self.embeddings,
                persist_directory="./chroma_db"
            )
        else:
            # FAISS is not imported at module level, so import it where it is
            # needed instead of failing with a NameError.
            from langchain.vectorstores import FAISS

            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50
            )
            # Carry the metadata onto the chunks; the subject filter used at
            # query time has nothing to match against otherwise.
            texts = text_splitter.create_documents(documents, metadatas=metadatas)
            self.vector_store = FAISS.from_documents(texts, self.embeddings)

    def _initialize_problem_graph(self):
        """Build ``self.problem_graph`` linking problems of similar subject and level.

        Every ``test`` example becomes a node keyed by its position in the
        split. Two nodes are connected (weight 1) when they share a subject
        and their levels differ by at most one.
        """
        self.problem_graph = nx.Graph()

        # Keep subject/level in plain lists so the pairwise pass below does
        # not go back to the dataset or the graph for every comparison.
        subjects = []
        levels = []
        for i, example in enumerate(self.dataset['test']):
            self.problem_graph.add_node(i,
                                        subject=example['subject'],
                                        level=example['level'],
                                        question=example['question'],
                                        answer=example['answer'])
            subjects.append(example['subject'])
            levels.append(example['level'])

        total = len(subjects)
        for i in range(total):
            for j in range(i + 1, total):
                if subjects[i] == subjects[j] and abs(levels[i] - levels[j]) <= 1:
                    self.problem_graph.add_edge(i, j, weight=1)

    def _get_relevant_context(self, question, subject, level, method='rag'):
        """Return worked examples to prepend to the prompt.

        Args:
            question: Text of the problem being solved.
            subject: Subject label of the problem.
            level: Numeric difficulty level of the problem.
            method: ``'rag'`` queries the vector store; any other value uses
                the problem graph.

        Returns:
            A string of retrieved entries separated by blank lines (possibly
            empty). Entries whose question equals ``question`` are excluded so
            the problem's own answer is never part of its context.
        """
        if method == 'rag':
            # Ask for one hit more than needed so three remain after the
            # query's own document (if present) has been dropped.
            hits = self.vector_store.similarity_search(
                f"Question: {question} level: {level}", k=4,
                filter={"subject": subject}
            )
            own_prefix = f"Question: {question}\nAnswer: "
            kept = [
                doc for doc in hits
                if not doc.page_content.startswith(own_prefix)
            ][:3]
            return "\n\n".join(doc.page_content for doc in kept)

        # graph_rag: anchor on the first two nodes with matching subject and
        # a level within one of the query's level.
        anchors = []
        for node in self.problem_graph.nodes():
            node_data = self.problem_graph.nodes[node]
            if (node_data['subject'] == subject and
                    abs(node_data['level'] - level) <= 1):
                anchors.append(node)

        # Collect up to two neighbours per anchor, skipping the query itself
        # and anything already selected through the other anchor.
        seen = set()
        context_problems = []
        for anchor in anchors[:2]:
            picked = 0
            for neighbor in self.problem_graph.neighbors(anchor):
                if picked == 2:
                    break
                data = self.problem_graph.nodes[neighbor]
                if neighbor in seen or data['question'] == question:
                    continue
                seen.add(neighbor)
                context_problems.append(data)
                picked += 1

        return "\n\n".join(
            f"Question: {p['question']}\nAnswer: {p['answer']}"
            for p in context_problems
        )

    def _prepare_input(self, example, method='zero_shot'):
        """Build the processor inputs (prompt and optional image) for one example.

        Args:
            example: Dataset row; reads ``question``, ``subject``, ``level``
                and, when present, ``image`` (a path relative to
                ``IMAGE_PATH``).
            method: ``'rag'`` or ``'graph_rag'`` prepend retrieved context;
                anything else yields the plain zero-shot prompt.

        Returns:
            The processor output moved to the model's device as float16.
        """
        base_prompt = f"Solve the following math problem step by step, given the image attached. Write the final answer after <Answer:> \n{example['question']}"

        if method in ['rag', 'graph_rag']:
            context = self._get_relevant_context(
                example['question'],
                example['subject'],
                example['level'],
                method=method
            )
            base_prompt = f"Here are some similar problems and their solutions:\n{context}\n\nNow solve this problem:\n{base_prompt}"

        has_image = 'image' in example and bool(example['image'])

        # Only emit an image placeholder when an image is actually supplied,
        # so the prompt and the processor's image inputs stay in agreement.
        content = [{"type": "text", "text": base_prompt}]
        if has_image:
            content.insert(0, {"type": "image"})
        conversation = [{"role": "user", "content": content}]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)

        if has_image:
            # convert() returns a loaded copy, so the file handle can be
            # closed as soon as the block exits.
            with Image.open(f"{IMAGE_PATH}/{example['image']}") as handle:
                images = [handle.convert("RGB")]
        else:
            images = None

        inputs = self.processor(
            images=images,
            text=[prompt],
            padding=True,
            return_tensors="pt"
        ).to(self.model.device, torch.float16)

        return inputs

    def _decode_new_tokens(self, generate_ids, prompt_length):
        """Decode only the tokens generated after the prompt.

        Args:
            generate_ids: 2-D batch of token ids as returned by ``generate``
                (prompt tokens followed by new tokens).
            prompt_length: Number of prompt tokens at the start of each row.

        Returns:
            List of decoded strings, one per row, without the prompt text.
        """
        return self.processor.batch_decode(
            generate_ids[:, prompt_length:],
            skip_special_tokens=True
        )

    def evaluate(self, methods=('zero_shot', 'rag', 'graph_rag')):
        """Run the model over ``testmini`` once per method.

        A CSV checkpoint of the rows collected so far is written to
        ``IMAGE_PATH`` every 50 examples.

        Args:
            methods: Iterable of method names to run.

        Returns:
            ``self.results``: a dict mapping each method name to its list of
            result rows (methods that were not run keep their previous value).
        """
        for method in methods:
            model_results = []
            for i, example in enumerate(tqdm(self.dataset['testmini'],
                                             desc=f"Evaluating Llava - {method}")):
                inputs = self._prepare_input(example, method=method)

                # Inference only: no autograd bookkeeping needed.
                with torch.inference_mode():
                    generate_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=200,
                    )

                # The prompt itself contains "<Answer:>", so it must be
                # stripped before the answer marker is searched for.
                generated_text = self._decode_new_tokens(
                    generate_ids, inputs["input_ids"].shape[1]
                )
                extracted_solution = self._extract_solution(generated_text)

                model_results.append({
                    'question': example['question'],
                    'ground_truth': example['answer'],
                    'model_prediction': extracted_solution,
                    'method': method
                })

                if i % 50 == 0:
                    pd.DataFrame(model_results).to_csv(f'{IMAGE_PATH}/{method}_results.csv')

            self.results[method] = model_results

        return self.results

    def _extract_solution(self, generated_text):
        """Return the text following ``<Answer:>`` in the first decoded string.

        Args:
            generated_text: List of decoded strings; only the first is used.

        Returns:
            The stripped text after the first ``<Answer:>`` marker, the whole
            first string when no marker is present, or ``""`` for an empty
            list.
        """
        if not generated_text:
            return ""
        first = generated_text[0]
        solution_match = re.search(r'<Answer:>\s*(.*)', first, re.DOTALL)
        if solution_match:
            return solution_match.group(1).strip()
        return first

# Usage: run the zero-shot baseline only; add 'rag' and/or 'graph_rag' to the
# list to evaluate the retrieval-augmented prompts as well.
evaluator = MathVisionEvaluator()
results = evaluator.evaluate(['zero_shot']) # 'zero_shot', 'rag', 'graph_rag'

# Save results
# evaluate() returns an entry for every known method, including empty lists
# for methods that were not run; skip those so an earlier run's CSV is not
# overwritten with an empty file.
for method, rows in results.items():
    if not rows:
        continue
    pd.DataFrame(rows).to_csv(f'{IMAGE_PATH}/{method}_results.csv')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating Llava - zero_shot: 100%|██████████| 304/304 [23:49<00:00,  4.70s/it]
