In [None]:
from __future__ import annotations
import os
import re
import json
import math
import argparse
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, TypedDict, Literal
from dotenv import load_dotenv


# LangGraph & LangChain imports
from langgraph.graph import StateGraph, END

try:
    from langchain_community.llms import Ollama
    HAVE_OLLAMA = True
except Exception:
    HAVE_OLLAMA = False

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.messages import HumanMessage

from duckduckgo_search import DDGS
import sympy as sp
import numpy as np

# Load environment variables
load_dotenv()

True

In [2]:
class AgentState(TypedDict):
    query: str
    plan: str
    route: Literal["web_search", "math_solver", "rag", "direct"]
    scratch: List[str]
    tool_result: Any
    final_answer: str


In [3]:
def default_embedder():
    return SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

def load_docs_from_dir(docs_dir: str) -> List[Any]:
    docs = []
    for root, _, files in os.walk(docs_dir):
        for fn in files:
            path = os.path.join(root, fn)
            ext = fn.lower().split(".")[-1]
            try:
                if ext == "pdf":
                    docs.extend(PyPDFLoader(path).load())
                elif ext in ("docx",):
                    docs.extend(Docx2txtLoader(path).load())
                elif ext in ("txt", "md"):
                    docs.extend(TextLoader(path, encoding="utf-8").load())
            except Exception as e:
                print(f"Skipping {path}: {e}")
    return docs

def build_vectorstore(chunks, persist_dir: str):
    embeddings = default_embedder()
    vs = Chroma.from_documents(chunks, embedding=embeddings, persist_directory=persist_dir)
    vs.persist()
    return vs

def ensure_vectorstore(persist_dir: str):
    embeddings = default_embedder()
    return Chroma(persist_directory=persist_dir, embedding_function=embeddings)


In [None]:
from langchain_core.prompts import ChatPromptTemplate
import requests
import json
import re
import sympy as sp
import numpy as np

OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    raise ValueError("Please set OPENROUTER_API_KEY in your .env file")

def tool_web_search(query: str, k: int = 5) -> List[Dict[str, str]]:
    results = []
    try:
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=k):
                results.append({"title": r.get("title", ""), "href": r.get("href", ""), "body": r.get("body", "")})
    except Exception as e:
        results.append({"title": "search_error", "href": "", "body": str(e)})
    return results


# --- Helper function to extract numbers ---
def extract_number(text: str):
    """Extract last numerical value from text (float or int)."""
    matches = re.findall(r"[-+]?\d*\.?\d+", text)
    if matches:
        try:
            return float(matches[-1]) if "." in matches[-1] else int(matches[-1])
        except:
            return matches[-1]
    return None


# --- Calculator Tool ---
def tool_calculator(expr: str) -> dict:
    """Enhanced calculator with structured output."""
    try:
        # Standardize expression
        expr = expr.replace("^", "**")

        # Allowed sympy functions
        allowed = {k: getattr(sp, k) for k in [
            "sin", "cos", "tan", "asin", "acos", "atan",
            "log", "ln", "exp", "sqrt", "factorial",
            "ceiling", "floor", "Abs", "diff", "integrate"
        ]}

        # Constants
        allowed.update({
            "pi": sp.pi,
            "e": sp.E,
            "inf": sp.oo,
            "golden_ratio": (1 + sp.sqrt(5)) / 2
        })

        # Parse and evaluate
        val = sp.sympify(expr, locals=allowed)
        numeric_val = sp.N(val)

        return {
            "reasoning": f"Expression `{expr}` evaluated successfully.",
            "final_answer": float(numeric_val)
        }

    except Exception as e:
        return {
            "reasoning": f"Calculator error: {e}",
            "final_answer": None
        }


# --- GSM8k Math Solver ---
def tool_gsm8k_math_solve(problem: str, use_ollama: bool = True, model: str = "mistral") -> dict:
    """Math Problem Solver with structured output."""
    reasoning = ""
    final_answer = None

    # --- Try DeepSeek via OpenRouter ---
    if OPENROUTER_API_KEY:
        try:
            response = requests.post(
                url="https://openrouter.ai/api/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {OPENROUTER_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "deepseek/deepseek-r1-0528-qwen3-8b:free",
                    "messages": [
                        {"role": "system", "content": "You are a math tutor. Solve step by step with reasoning."},
                        {"role": "user", "content": problem}
                    ]
                }
            )

            if response.status_code == 200:
                content = response.json()["choices"][0]["message"]["content"]
                reasoning = content
                final_answer = extract_number(content)
            else:
                reasoning = f"DeepSeek API error {response.status_code}: {response.text}"

        except Exception as e:
            reasoning = f"DeepSeek error: {e}"

    # --- Fallback Solver (simple heuristics) ---
    if not final_answer:
        numbers = list(map(float, re.findall(r"[-+]?[0-9]*\.?[0-9]+", problem)))
        problem_lower = problem.lower()

        if len(numbers) >= 2:
            operations = {
                ("sum", "total", "add", "plus"): sum(numbers),
                ("difference", "subtract", "minus"): numbers[0] - sum(numbers[1:]),
                ("product", "multiply", "times"): np.prod(numbers),
                ("divide", "ratio", "per"): numbers[0] / numbers[1] if len(numbers) == 2 else None,
                ("average", "mean"): sum(numbers) / len(numbers)
            }

            for keywords, operation in operations.items():
                if any(word in problem_lower for word in keywords) and operation is not None:
                    reasoning = f"Heuristic detected operation `{keywords[0]}`."
                    final_answer = operation
                    break

    # --- Return structured result ---
    return {
        "reasoning": reasoning if reasoning else "No detailed reasoning available.",
        "final_answer": final_answer
    }

def tool_rag_query(question: str, persist_dir: str, k: int = 4) -> Dict[str, Any]:
    """Enhanced RAG with better document retrieval"""
    try:
        vs = ensure_vectorstore(persist_dir)
        # Use semantic search with reranking
        docs = vs.similarity_search(question, k=k)
        
        # Extract and format relevant information
        contexts = []
        metadata = []
        
        for doc in docs:
            # Clean and format the context
            context = doc.page_content.strip()
            contexts.append(context)
            
            # Enhance metadata with source information
            meta = doc.metadata.copy()
            meta["relevance_score"] = doc.metadata.get("score", 1.0)
            metadata.append(meta)
        
        return {
            "contexts": contexts,
            "docs": metadata,
            "summary": "\n".join(f"Document {i+1}:\n{ctx}" for i, ctx in enumerate(contexts[:2]))
        }
    except Exception as e:
        return {"error": f"RAG not available: {e}"}

  '''if use_ollama and HAVE_OLLAMA:


In [23]:
from langchain_core.prompts import ChatPromptTemplate
import requests
import json

# Make sure to get the API key from environment variables
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    raise ValueError("Please set OPENROUTER_API_KEY in your .env file")

def query_deepseek(user_query: str) -> str:
    """Query the DeepSeek model through OpenRouter API"""
    try:
        response = requests.post(
            url="https://openrouter.ai/api/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {OPENROUTER_API_KEY}",
                "Content-Type": "application/json",
                "HTTP-Referer": "http://localhost:7860",
                "X-Title": "LangGraph Agentic Controller"
            },
            json={
                "model": "deepseek/deepseek-r1-0528-qwen3-8b:free",  # Using the free model
                "messages": [
                    {
                        "role": "system",
                        "content": """You are a routing agent. Analyze the query and decide which tool should handle it.
                        Choose from: math_solver, web_search, rag, or direct.
                        - math_solver: for calculations, math problems, equations
                        - web_search: for real-world facts, current events, people
                        - rag: for document-based questions
                        - direct: for simple queries that need no tools
                        Respond ONLY with one of these four options."""
                    },
                    {
                        "role": "user",
                        "content": user_query
                    }
                ]
            }
        )
        
        if response.status_code == 200:
            content = response.json()["choices"][0]["message"]["content"].strip().lower()
            print(f"Raw response from model: {content}")  # Added debugging
            # Ensure we get one of our valid routes
            valid_routes = {"math_solver", "web_search", "rag", "direct"}
            if content in valid_routes:
                return content
            # If model returns something else, map it to the closest valid route
            for route in valid_routes:
                if route in content:
                    return route
            return "direct"  # fallback
        else:
            print(f"OpenRouter API error: {response.status_code}")
            print(f"Error message: {response.text}")
            return "direct"
    except Exception as e:
        print(f"Error querying DeepSeek: {e}")
        return "direct"

def controller_node(state: AgentState) -> AgentState:
    """Route queries using DeepSeek through OpenRouter"""
    query = state["query"]
    try:
        route = query_deepseek(query)
        state["route"] = route
        state["plan"] = f"Plan: Using {route} to handle the query."
    except Exception as e:
        print(f"Router error: {e}")
        state["route"] = "direct"
        state["plan"] = "Plan: Fallback to direct response due to routing error."
    return state

def web_search_node(state: AgentState) -> AgentState:
    return {**state, "scratch": ["web_search done"], "tool_result": tool_web_search(state["query"])}

def math_solver_node(state: AgentState) -> AgentState:
    query = state["query"].strip()
    # Strict math expression detection
    if re.fullmatch(r"[0-9\.\s\+\-\*\/\^\(\)]+", query):
        result = tool_calculator(query)
        return {**state, "scratch": ["calculator used"], "tool_result": result}
    else:
        # Send word problems (like GSM8K) to DeepSeek solver
        result = tool_gsm8k_math_solve(query)
        return {**state, "scratch": ["gsm8k solver used"], "tool_result": result}

def rag_node(state: AgentState, persist_dir: str) -> AgentState:
    return {**state, "scratch": ["rag used"], "tool_result": tool_rag_query(state["query"], persist_dir)}

def synthesizer_node(state: AgentState) -> AgentState:
    if state["route"] == "web_search":
        items = state["tool_result"] or []
        answer = "\n".join([f"{i['title']}: {i['body']}" for i in items[:3]]) or "No results."
    elif state["route"] == "math_solver":
        answer = f"Result: {state['tool_result']}"
    elif state["route"] == "rag":
        ctxs = state["tool_result"].get("contexts", [])
        answer = "\n---\n".join(ctxs[:2]) or "No context found."
    else:
        answer = "General knowledge route chosen."
    return {**state, "final_answer": answer}

In [24]:
# Test OpenRouter connection
test_query = "What is 2+2?"
print("Testing OpenRouter connection...")
print(f"API Key present: {'Yes' if OPENROUTER_API_KEY else 'No'}")
print(f"API Key length: {len(OPENROUTER_API_KEY) if OPENROUTER_API_KEY else 0}")

try:
    # Make the API request with detailed debugging
    response = requests.post(
        url="https://openrouter.ai/api/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
            "HTTP-Referer": "http://localhost:7860",
            "X-Title": "LangGraph Agentic Controller"
        },
        json={
            "model": "deepseek/deepseek-r1-0528-qwen3-8b:free",  # Using the free model
            "messages": [
                {
                    "role": "system",
                    "content": "You are a routing agent. Choose from: math_solver, web_search, rag, or direct."
                },
                {
                    "role": "user",
                    "content": test_query
                }
            ]
        }
    )
    
    print(f"\nResponse status code: {response.status_code}")
    
    if response.status_code == 200:
        response_data = response.json()
        content = response_data["choices"][0]["message"]["content"].strip()
        print(f"\nRaw response from DeepSeek: {content}")
        
        route = query_deepseek(test_query)  # Test the actual routing function
        print(f"Final determined route: '{route}'")
        
        print("\nTest successful! The routing system is working.")
    else:
        print(f"\nError response from OpenRouter:")
        print(response.text)
        
except Exception as e:
    print(f"\nError connecting to OpenRouter: {str(e)}")
    print("Debugging information:")
    print("1. Check if your .env file exists and contains OPENROUTER_API_KEY")
    print("2. Verify your API key is valid at https://openrouter.ai/keys")
    print("3. Make sure you have internet connectivity")
    print("\nFull error:", str(e))

Testing OpenRouter connection...
API Key present: Yes
API Key length: 73

Response status code: 200

Raw response from DeepSeek: direct
Raw response from model: math_solver
Final determined route: 'math_solver'

Test successful! The routing system is working.


In [25]:
# Evaluation Framework
import numpy as np
from typing import List, Dict, Any
from tqdm import tqdm

class BenchmarkEvaluator:
    def __init__(self, pipeline):
        self.pipeline = pipeline
        
    def evaluate_lama(self, examples: List[Dict[str, str]]) -> Dict[str, float]:
        """
        Evaluate on LAMA benchmark
        examples: List of dicts with 'question' and 'answer' keys
        """
        results = []
        correct = 0
        
        for ex in tqdm(examples, desc="Evaluating LAMA"):
            result = self.pipeline({"query": ex["question"]})
            predicted = result["final_answer"].lower()
            actual = ex["answer"].lower()
            
            # Check if actual answer is in predicted text
            is_correct = actual in predicted
            if is_correct:
                correct += 1
                
            results.append({
                "question": ex["question"],
                "predicted": predicted,
                "actual": actual,
                "correct": is_correct
            })
        
        accuracy = correct / len(examples)
        return {
            "accuracy": accuracy,
            "results": results,
            "total_examples": len(examples)
        }
    
    def evaluate_gsm8k(self, examples: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Evaluate on GSM8k benchmark
        examples: List of dicts with 'question' and 'answer' keys
        """
        results = []
        correct = 0
        
        for ex in tqdm(examples, desc="Evaluating GSM8k"):
            result = self.pipeline({"query": ex["question"]})
            predicted = None
            actual = float(ex["answer"])  # Convert answer to float first
            is_correct = False
            
            try:
                # Extract numerical answer from the result
                predicted_text = result["final_answer"]
                # Look for "Result:" or "Answer:" in the text
                for prefix in ["Result:", "Answer:", "Numerical Answer:"]:
                    if prefix in predicted_text:
                        number_text = predicted_text.split(prefix)[-1].split()[0]
                        predicted = float(''.join(
                            filter(lambda x: x.isdigit() or x in '.-', number_text)))
                        break
                        
                if predicted is not None:
                    # Check if the answer is within 1% relative error
                    is_correct = abs(predicted - actual) / abs(actual) < 0.01
                    if is_correct:
                        correct += 1
                        
            except Exception:
                pass  # Keep default values for failed parsing
                
            results.append({
                "question": ex["question"],
                "predicted": predicted,
                "actual": actual,
                "correct": is_correct
            })
        
        accuracy = correct / len(examples)
        return {
            "accuracy": accuracy,
            "results": results,
            "total_examples": len(examples)
        }

# Example test sets
lama_examples = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "Who wrote Romeo and Juliet?", "answer": "William Shakespeare"},
    {"question": "What is the chemical symbol for gold?", "answer": "Au"}
]

gsm8k_examples = [
    {
        "question": "Janet has 3 apples. She buys 5 more. How many apples does she have now?",
        "answer": "8"
    },
    {
        "question": "A train travels 120 kilometers in 2 hours. What is its speed in kilometers per hour?",
        "answer": "60"
    }
]

In [26]:
def build_graph(persist_dir: Optional[str]):
    graph = StateGraph(AgentState)
    graph.add_node("controller", controller_node)
    graph.add_node("web_search", web_search_node)
    graph.add_node("math_solver", math_solver_node)
    graph.add_node("rag", lambda s: rag_node(s, persist_dir))
    graph.add_node("synth", synthesizer_node)

    graph.set_entry_point("controller")
    graph.add_conditional_edges("controller", lambda s: s["route"], {
        "web_search": "web_search",
        "math_solver": "math_solver",
        "rag": "rag",
        "direct": "synth"
    })
    graph.add_edge("web_search", "synth")
    graph.add_edge("math_solver", "synth")
    graph.add_edge("rag", "synth")
    graph.add_edge("synth", END)
    return graph.compile()


In [27]:
class Pipeline:
    def __init__(self, app):
        self.app = app
    
    def __call__(self, state):
        return self.app.invoke({
            **{
                "plan": "",
                "route": "direct",
                "scratch": [],
                "tool_result": None,
                "final_answer": ""
            },
            **state
        })

import gradio as gr
import socket
import tempfile
import os
from pathlib import Path
import mimetypes

def find_free_port(start_port=7860, max_port=7960):
    """Find a free port in the given range."""
    for port in range(start_port, max_port + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('', port))
                return port
            except OSError:
                continue
    raise OSError(f"No free ports found in range {start_port}-{max_port}")

def detect_file_type(file_path, original_filename):
    """Detect file type using file extension and basic content checking"""
    # First try by extension
    ext = os.path.splitext(original_filename)[1].lower()
    if ext in ['.pdf', '.docx', '.doc', '.txt', '.md']:
        return ext
    
    # If no extension or unrecognized, try mime type
    mime_type, _ = mimetypes.guess_type(original_filename)
    if mime_type:
        ext_mapping = {
            'application/pdf': '.pdf',
            'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
            'application/msword': '.doc',
            'text/plain': '.txt',
            'text/markdown': '.md'
        }
        if mime_type in ext_mapping:
            return ext_mapping[mime_type]
    
    # Try content-based detection (basic)
    try:
        with open(file_path, 'rb') as f:
            header = f.read(8)  # Read first 8 bytes
            
            # PDF signature check
            if header.startswith(b'%PDF'):
                return '.pdf'
            
            # DOC/DOCX check (basic)
            if header.startswith(b'PK\x03\x04'):  # DOCX is a ZIP file
                return '.docx'
            
            # For text files, try to decode as UTF-8
            try:
                with open(file_path, 'r', encoding='utf-8') as tf:
                    tf.read(1024)  # Try reading some content
                return '.txt'  # If we can read it as text, assume it's a text file
            except UnicodeDecodeError:
                pass
    except Exception:
        pass
    
    return '.txt'  # Default to text if we can't determine type

def process_uploaded_file(file_obj, persist_dir):
    """Process an uploaded file and add it to the RAG store"""
    if file_obj is None:
        return "No file uploaded"
    
    temp_path = None
    try:
        # Handle file upload from Gradio
        if hasattr(file_obj, 'name'):
            # Regular file object
            original_filename = os.path.basename(file_obj.name)
            file_content = file_obj.read()
        else:
            # Bytes object from Gradio
            original_filename = "uploaded_document"
            file_content = file_obj
        
        # Create temp file with no extension first
        with tempfile.NamedTemporaryFile(delete=False, mode='wb') as temp_file:
            if isinstance(file_content, bytes):
                temp_file.write(file_content)
            else:
                temp_file.write(file_content.encode('utf-8'))
            temp_path = temp_file.name
        
        # Detect file type
        file_ext = detect_file_type(temp_path, original_filename)
        
        # Create a new temp file with correct extension
        new_temp_path = temp_path + file_ext
        os.rename(temp_path, new_temp_path)
        temp_path = new_temp_path
        
        # Load the document based on file type
        try:
            if file_ext == '.pdf':
                docs = PyPDFLoader(temp_path).load()
            elif file_ext in ['.docx', '.doc']:
                docs = Docx2txtLoader(temp_path).load()
            else:  # .txt or .md
                # Try different encodings
                encodings = ['utf-8', 'latin-1', 'cp1252']
                docs = None
                for encoding in encodings:
                    try:
                        docs = TextLoader(temp_path, encoding=encoding).load()
                        break
                    except UnicodeDecodeError:
                        continue
                
                if docs is None:
                    return f"Unable to decode file with any supported encoding"
            
            # Process documents
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200
            )
            chunks = text_splitter.split_documents(docs)
            
            # Build or update vectorstore
            vs = build_vectorstore(chunks, persist_dir)
            
            return f"Successfully processed {original_filename} ({len(chunks)} chunks created)"
            
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            return f"Error processing file content: {str(e)}\nDetails:\n{error_details}"
            
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"Error handling file: {str(e)}\nDetails:\n{error_details}"
        
    finally:
        # Clean up temp file
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except Exception:
                pass

def run_gradio(persist_dir="./rag_store"):
    # Ensure RAG directory exists
    os.makedirs(persist_dir, exist_ok=True)
    
    app = build_graph(persist_dir)
    pipe = Pipeline(app)

    def chat_fn(msg, history):
        result = pipe({"query": msg})
        ans = f"**Route:** {result['route']}\n\n**Plan:** {result['plan']}\n\n### Answer\n{result['final_answer']}"
        return history + [[msg, ans]]

    # Find an available port
    try:
        port = find_free_port()
        print(f"Starting Gradio server on port {port}")
        
        with gr.Blocks() as demo:
            gr.Markdown("# 🤖 LangGraph Agentic Controller (Notebook Version)")
            
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(height=500)
                    msg = gr.Textbox(label="Your query")
                    
                with gr.Column(scale=1):
                    gr.Markdown("### Document Upload")
                    file_output = gr.Textbox(label="Upload Status", interactive=False)
                    upload_button = gr.File(
                        label="Upload Document",
                        file_types=[".pdf", ".docx", ".doc", ".txt", ".md"],
                        type="binary"
                    )
                    
            # Set up event handlers
            msg.submit(chat_fn, [msg, chatbot], chatbot)
            upload_button.change(
                fn=lambda f: process_uploaded_file(f, persist_dir),
                inputs=[upload_button],
                outputs=[file_output]
            )
            
        demo.launch(server_name="0.0.0.0", server_port=port, share=True)
    except Exception as e:
        print(f"Error starting Gradio server: {e}")
        raise

In [28]:
# Launch the Gradio app
try:
    run_gradio()
except Exception as e:
    print(f"Failed to start Gradio app: {e}")
    print("Try closing any other Gradio apps or Jupyter notebooks that might be using the ports")

Starting Gradio server on port 7862


  chatbot = gr.Chatbot(height=500)


* Running on local URL:  http://0.0.0.0:7862

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


2025/08/26 17:48:21 [W] [service.go:132] login to server failed: read tcp 10.1.163.176:51256->44.237.78.176:7000: wsarecv: An existing connection was forcibly closed by the remote host.


Raw response from model: math_solver


In [29]:
# Run benchmarks before launching the app
print("Running benchmarks...")

# Initialize pipeline and evaluator
app = build_graph("./rag_store")
pipe = Pipeline(app)
evaluator = BenchmarkEvaluator(pipe)

# Run LAMA benchmark
print("\nEvaluating LAMA benchmark:")
lama_results = evaluator.evaluate_lama(lama_examples)
print(f"LAMA Accuracy: {lama_results['accuracy']:.2%}")
print("\nSample LAMA Results:")
for r in lama_results['results'][:3]:  # Show first 3 examples
    print(f"Q: {r['question']}")
    print(f"A (predicted): {r['predicted']}")
    print(f"A (actual): {r['actual']}")
    print(f"Correct: {r['correct']}\n")

# Run GSM8k benchmark
print("\nEvaluating GSM8k benchmark:")
gsm8k_results = evaluator.evaluate_gsm8k(gsm8k_examples)
print(f"GSM8k Accuracy: {gsm8k_results['accuracy']:.2%}")
print("\nSample GSM8k Results:")
for r in gsm8k_results['results'][:3]:  # Show first 3 examples
    print(f"Q: {r['question']}")
    print(f"A (predicted): {r['predicted']}")
    print(f"A (actual): {r['actual']}")
    print(f"Correct: {r['correct']}\n")

Running benchmarks...

Evaluating LAMA benchmark:


Evaluating LAMA:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating LAMA:  33%|███▎      | 1/3 [00:07<00:15,  7.56s/it]

Raw response from model: direct


Evaluating LAMA:  67%|██████▋   | 2/3 [00:13<00:06,  6.53s/it]

Raw response from model: direct


Evaluating LAMA: 100%|██████████| 3/3 [00:18<00:00,  6.15s/it]


Raw response from model: direct
LAMA Accuracy: 0.00%

Sample LAMA Results:
Q: What is the capital of France?
A (predicted): general knowledge route chosen.
A (actual): paris
Correct: False

Q: Who wrote Romeo and Juliet?
A (predicted): general knowledge route chosen.
A (actual): william shakespeare
Correct: False

Q: What is the chemical symbol for gold?
A (predicted): general knowledge route chosen.
A (actual): au
Correct: False


Evaluating GSM8k benchmark:


Evaluating GSM8k:   0%|          | 0/2 [00:00<?, ?it/s]

Raw response from model: math_solver


Evaluating GSM8k:  50%|█████     | 1/2 [00:17<00:17, 17.91s/it]

Raw response from model: math_solver


Evaluating GSM8k: 100%|██████████| 2/2 [00:38<00:00, 19.23s/it]

GSM8k Accuracy: 0.00%

Sample GSM8k Results:
Q: Janet has 3 apples. She buys 5 more. How many apples does she have now?
A (predicted): None
A (actual): 8.0
Correct: False

Q: A train travels 120 kilometers in 2 hours. What is its speed in kilometers per hour?
A (predicted): None
A (actual): 60.0
Correct: False




