# SemEval 2026 Task 8 - Task B: Generation

This notebook implements **Task B: Generation** for MTRAGEval.

**Goal:** Given a conversation, generate an answer to the last user question.

**Output Format:** JSONL (one JSON object per line), where each record's `predictions` field is `[{"text": "..."}]`.

## 1. Setup

In [None]:
import os
import sys
import json
from tqdm import tqdm

# Resolve the repository root: if a `src/` directory is visible from the
# current working directory we are already at the root; otherwise assume the
# notebook runs one directory below it.
PROJECT_ROOT = os.getcwd() if os.path.exists("src") else os.path.abspath("..")

# Put the root on the import path (once) so `from src.graph import ...` resolves.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.graph import initialize_graph

In [None]:
# --- CONFIGURATION ---
TEAM_NAME = "Gbgers"
DOMAINS = ["govt", "clapnq", "fiqa", "cloud"]

# TEST MODE: set to True for quick verification, False for the full submission run.
TEST_MODE = True
TEST_QUERY_LIMIT = 5  # Conversations (one question each) per domain in test mode

CONVERSATIONS_FILE = os.path.join(PROJECT_ROOT, "dataset/human/conversations/conversations.json")
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "data/submissions")
# Test runs write to a separately named file so a quick check never
# overwrites the full submission produced with TEST_MODE = False.
OUTPUT_SUFFIX = "_TEST" if TEST_MODE else ""
OUTPUT_FILE = os.path.join(OUTPUT_DIR, f"submission_TaskB_{TEAM_NAME}{OUTPUT_SUFFIX}.jsonl")

os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Processing domains: {DOMAINS}")
if TEST_MODE:
    print(f"⚠️ TEST MODE: Processing only {TEST_QUERY_LIMIT} queries per domain.")
    print(f"   Output will be written to {OUTPUT_FILE}")

## 2. Helper Functions

In [None]:
def extract_last_user_question(messages):
    """Return the text of the most recent turn whose ``speaker`` is ``"user"``.

    Turns are scanned from the end of ``messages``. Returns ``""`` when there
    is no user turn, or when the last user turn has no ``text`` key.
    """
    user_turns = (turn for turn in reversed(messages) if turn.get("speaker") == "user")
    latest = next(user_turns, None)
    if latest is None:
        return ""
    return latest.get("text", "")

def format_input_for_output(messages):
    """Project each conversation turn onto the fields kept in the submission.

    Only ``speaker`` and ``text`` are copied, in that key order; any other
    per-message fields are dropped. A turn missing either key raises
    ``KeyError``.
    """
    kept_fields = ("speaker", "text")
    return [{field: turn[field] for field in kept_fields} for turn in messages]

## 3. Initialize Graph and Load Data

In [None]:
# Build the RAG graph via the project's `initialize_graph()`; the resulting
# `app` is invoked once per question in the generation loop.
print("🔧 Initializing RAG graph...")
app = initialize_graph()
print("✅ Graph ready.")

In [None]:
# Load conversations. Decode explicitly as UTF-8: the default encoding of
# open() is locale-dependent (e.g. cp1252 on Windows) and would mis-decode
# or fail on non-ASCII conversation text.
print("📂 Loading conversations...")
with open(CONVERSATIONS_FILE, 'r', encoding='utf-8') as f:
    all_conversations = json.load(f)

# The generation loop iterates conversations as dicts; fail early with a
# clear message if the file's top-level JSON value is not a list.
if not isinstance(all_conversations, list):
    raise TypeError(
        f"Expected a JSON list of conversations in {CONVERSATIONS_FILE}, "
        f"got {type(all_conversations).__name__}"
    )
print(f"Total conversations: {len(all_conversations)}")

## 4. Run Generation

In [None]:
def _history_through_last_user(messages):
    """Return the turns up to and including the last ``speaker == "user"`` turn.

    Stored conversations may continue past the question being answered
    (NOTE(review): presumably with the reference agent reply — confirm against
    the dataset). Those trailing turns are not the model's input and should not
    be echoed into the submission. If no user turn exists, the list is returned
    unchanged.
    """
    for idx in range(len(messages) - 1, -1, -1):
        if messages[idx].get("speaker") == "user":
            return messages[: idx + 1]
    return messages


all_results = []
failed_ids = []  # conversation IDs whose generation raised

for domain in DOMAINS:
    print(f"\n{'='*40}\n🌍 DOMAIN: {domain.upper()}\n{'='*40}")

    # Filter by domain (case-insensitive substring match on the `domain` field)
    domain_convs = [c for c in all_conversations if domain.lower() in c.get("domain", "").lower()]
    print(f"Found {len(domain_convs)} conversations")

    if not domain_convs:
        continue

    if TEST_MODE:
        # Each conversation yields at most one question, so this also caps questions.
        print(f"✂️ TEST MODE: Processing {TEST_QUERY_LIMIT} conversations")
        domain_convs = domain_convs[:TEST_QUERY_LIMIT]

    print("🚀 Running generation...")
    for conv in tqdm(domain_convs, desc=domain):
        messages = conv.get("messages", [])
        question = extract_last_user_question(messages)

        if not question:
            continue

        # Prefer an explicit identifier; `author` is only a fallback and is not
        # unique (one author can write many conversations).
        # NOTE(review): confirm which ID key the dataset/evaluator expects.
        conv_id = conv.get("conversation_id") or conv.get("id") or conv.get("author")

        try:
            # Only the final question is sent to the graph (no earlier turns).
            response = app.invoke({"question": question})
            gen_text = response.get("generation")
            if gen_text is None:
                # Avoid serializing `"text": null` into the submission.
                gen_text = "No Answer"
        except Exception as e:
            # Best-effort: keep the run going, but record which conversation failed.
            print(f"Error ({domain}, {conv_id}): {e}")
            failed_ids.append(conv_id)
            gen_text = "Error"

        all_results.append({
            "conversation_id": conv_id,
            "Collection": f"mt-rag-{domain}",
            "input": format_input_for_output(_history_through_last_user(messages)),
            "predictions": [{"text": gen_text}]
        })

print(f"\n✅ Generated {len(all_results)} answers.")
if failed_ids:
    print(f"⚠️ {len(failed_ids)} generations failed and were saved as 'Error': {failed_ids}")

## 5. Save Results

In [None]:
print(f"💾 Saving {len(all_results)} results to {OUTPUT_FILE}...")
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    for item in all_results:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
print("✅ Done!")


def _has_valid_predictions(record):
    """Return True if ``record["predictions"]`` is a non-empty list of ``{"text": str}`` dicts."""
    preds = record.get("predictions")
    return (
        isinstance(preds, list)
        and len(preds) > 0
        and all(isinstance(p, dict) and isinstance(p.get("text"), str) for p in preds)
    )


# Validation: check every record, not just the first one.
if not all_results:
    print("\033[93mVALIDATION WARNING: no results were generated.\033[0m")
else:
    bad_rows = [i for i, item in enumerate(all_results) if not _has_valid_predictions(item)]
    if not bad_rows:
        print(f"\033[92mVALIDATION PASS: Structure correct for all {len(all_results)} records.\033[0m")
    else:
        print(
            f"\033[91mVALIDATION FAIL: 'predictions' format error in {len(bad_rows)} "
            f"record(s) (first at index {bad_rows[0]}).\033[0m"
        )