# MTRAGEval - RAG Pipeline Testing

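This notebook tests the complete Self-CRAG pipeline for SemEval 2026 Task 8. Self-CRAG combines CRAG-style grading of retrieved documents with a Self-RAG-style hallucination check on the generated answer.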

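## Steps:
1. Install dependencies and load Llama 3.1 8B Instruct (4-bit quantized)
2. Initialize retrieval components (vector search + cross-encoder reranking)
3. Create the LangChain chains, then build the Self-CRAG LangGraph workflow
4. Test single-turn questions and multi-turn conversations
5. Monitor GPU memory (VRAM) usage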

## 1. Environment Setup

In [None]:
# Install dependencies (run once)
# NOTE(review): confirm these pins resolve together before relying on them:
#   - langchain-huggingface==0.0.3 is pinned next to langchain==0.1.10,
#     langchain-community==0.0.25 and sentence-transformers==2.5.1; its own
#     langchain-core / sentence-transformers requirements may not fit these pins.
#   - later cells import START and add_messages from langgraph, and
#     CrossEncoderReranker / HuggingFaceCrossEncoder from langchain(-community);
#     verify that langgraph==0.0.26 and the pinned langchain packages export them.
#   - transformers, accelerate and bitsandbytes are unpinned, so the installed
#     versions can change between runs.
!pip install -q langchain==0.1.10 langchain-community==0.0.25 langchain-huggingface==0.0.3 langgraph==0.0.26
!pip install -q transformers accelerate bitsandbytes
!pip install -q chromadb==0.4.24 sentence-transformers==2.5.1

In [None]:
import sys
import os
import torch

# Prepend the parent of the current working directory to the import path
# (presumably the project root, so project packages such as `src` resolve).
sys.path.insert(0, '../')

# Report whether CUDA is usable and, for each visible device, its name
# and total memory in GB.
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    for gpu_idx in range(torch.cuda.device_count()):
        print(f"GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
        total_gb = torch.cuda.get_device_properties(gpu_idx).total_memory / 1e9
        print(f"  Memory: {total_gb:.1f} GB")

In [None]:
# HuggingFace login for gated models (Llama 3.1)
from huggingface_hub import login

# Use Kaggle secrets or environment variable
# login(token="your_hf_token_here")

## 2. Load Llama 3.1 (4-bit Quantized)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain_huggingface import HuggingFacePipeline

# Hugging Face Hub id of the gated Llama 3.1 8B Instruct checkpoint.
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"


def load_llama_4bit():
    """Load Llama 3.1 with 4-bit NF4 quantization (not implemented yet).

    The planned implementation configures quantization through
    ``BitsAndBytesConfig`` and presumably loads ``MODEL_ID`` — TODO confirm.

    Raises:
        NotImplementedError: always, until the loader is written.
    """
    # TODO: build the BitsAndBytesConfig and load the quantized model.
    pending = "Implement 4-bit Llama loading"
    raise NotImplementedError(pending)

In [None]:
# Load LLM
# print("Loading Llama 3.1 8B (4-bit)... This may take 1-2 minutes.")
# llm = load_llama_4bit()
# print("Model loaded successfully!")

## 3. Initialize Retrieval Pipeline

In [None]:
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Chroma persistence directory (relative to the working directory) and the
# BAAI models for dense embeddings and cross-encoder reranking.
CHROMA_PERSIST_DIR = "../chromadb"
EMBEDDING_MODEL = "BAAI/bge-m3"
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"


def get_retriever_with_reranking():
    """Build a two-stage retriever (not implemented yet).

    Intended pipeline: vector search for the top 20 candidates, then
    cross-encoder reranking down to the top 5.

    Raises:
        NotImplementedError: always, until the retriever is written.
    """
    # TODO: vector search (top 20) -> cross-encoder rerank (top 5).
    pending = "Implement retriever with reranking"
    raise NotImplementedError(pending)

In [None]:
# Initialize retriever
# retriever = get_retriever_with_reranking()
# print("Retriever initialized!")

## 4. Create LangChain Chains

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field

def create_query_rewriter(llm):
    """Build the chain that rewrites context-dependent questions (not implemented yet).

    The prompt is planned to use Llama 3 special tokens.

    Args:
        llm: the model the chain will call (presumably the one from
            ``load_llama_4bit`` — TODO confirm).

    Raises:
        NotImplementedError: always, until the chain is written.
    """
    pending = "Implement query rewriter"
    raise NotImplementedError(pending)


def create_generator(llm):
    """Build the RAG answer-generation chain (not implemented yet).

    The chain is meant to fall back to an ``I_DONT_KNOW`` answer.

    Args:
        llm: the model the chain will call.

    Raises:
        NotImplementedError: always, until the chain is written.
    """
    pending = "Implement generator"
    raise NotImplementedError(pending)


def create_relevance_grader(llm):
    """Build the CRAG document-relevance grader (not implemented yet).

    The grader is planned to emit JSON output.

    Args:
        llm: the model the grader will call.

    Raises:
        NotImplementedError: always, until the grader is written.
    """
    pending = "Implement relevance grader"
    raise NotImplementedError(pending)


def create_hallucination_grader(llm):
    """Build the Self-RAG hallucination grader (not implemented yet).

    The grader is planned to emit JSON output.

    Args:
        llm: the model the grader will call.

    Raises:
        NotImplementedError: always, until the grader is written.
    """
    pending = "Implement hallucination grader"
    raise NotImplementedError(pending)

## 5. Build Self-CRAG Graph

In [None]:
from typing import TypedDict, List, Annotated, Any, Literal
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages

class GraphState(TypedDict):
    """State for the Self-CRAG workflow.

    ``messages`` carries ``add_messages`` as its reducer via ``Annotated``;
    the remaining keys declare no reducer. The two yes/no flags are typed
    as ``Literal["yes", "no"]`` so the allowed values are visible to type
    checkers instead of living only in a comment (TypedDict does not
    enforce this at runtime).
    """
    messages: Annotated[List[BaseMessage], add_messages]
    question: str
    # presumably the rewritten, context-independent form of `question` — TODO confirm
    standalone_question: str
    # retrieved items; presumably LangChain documents — TODO confirm
    documents: List[Any]
    generation: str
    documents_relevant: Literal["yes", "no"]
    is_hallucination: Literal["yes", "no"]
    # retry counter for the workflow; exact semantics TODO (graph not built yet)
    retry_count: int

In [None]:
def build_self_crag_graph(retriever, query_rewriter, generator, relevance_grader, hallucination_grader):
    """Assemble the Self-CRAG LangGraph workflow (not implemented yet).

    Planned nodes:
      * rewrite             - rewrite context-dependent queries
      * retrieve            - search the vector store
      * grade_docs          - drop irrelevant documents (CRAG)
      * generate            - produce the answer
      * hallucination_check - validate the generation (Self-RAG)
      * fallback            - return I_DONT_KNOW

    Raises:
        NotImplementedError: always, until the graph is constructed.
    """
    pending = "Implement Self-CRAG graph"
    raise NotImplementedError(pending)

In [None]:
# Build graph
# app = build_self_crag_graph(retriever, query_rewriter, generator, relevance_grader, hallucination_grader)
# print("Self-CRAG graph built!")

## 6. Test Single Turn

In [None]:
def run_single_turn(app, question: str):
    """Send one question through the compiled pipeline (not implemented yet).

    Args:
        app: the workflow object (presumably the result of
            ``build_self_crag_graph`` — TODO confirm).
        question: the user question for this turn.

    Raises:
        NotImplementedError: always, until single-turn execution is written.
    """
    pending = "Implement single turn"
    raise NotImplementedError(pending)

In [None]:
# Test single turn
# question = "Who is the CEO of Apple?"
# response = run_single_turn(app, question)
# print(f"Q: {question}")
# print(f"A: {response}")

## 7. Test Multi-Turn Conversation

In [None]:
def run_multi_turn_conversation(app, questions: list):
    """Run a sequence of questions as one conversation (not implemented yet).

    The intent is to keep chat history across turns so follow-up questions
    can depend on earlier ones.

    Args:
        app: the workflow object to run.
        questions: the questions, in conversation order.

    Raises:
        NotImplementedError: always, until multi-turn execution is written.
    """
    pending = "Implement multi-turn conversation"
    raise NotImplementedError(pending)

In [None]:
# Test multi-turn
# questions = [
#     "Who is the CEO of Apple?",
#     "How old is he?",
#     "When did he become CEO?"
# ]
# run_multi_turn_conversation(app, questions)

## 8. VRAM Monitoring

In [None]:
# Monitor GPU memory usage
!nvidia-smi