# Baseline Evaluation: Mistral 7B Instruct

Evaluate the base Mistral 7B Instruct model (no LoRA adapter) on structured JSON generation to establish a baseline.

**Run in Colab with a GPU runtime (T4 or better).**

In [None]:
# Install dependencies (IPython shell magic; -q suppresses pip's progress output).
# Packages: transformers (model/tokenizer), accelerate (device_map="auto"),
# bitsandbytes (4-bit quantization), torch.
!pip install -q transformers accelerate bitsandbytes torch

In [None]:
# Clone repo if running in Colab
import os
if not os.path.exists('lora-support-json'):
    !git clone https://github.com/YOUR_USERNAME/lora-support-json.git
    %cd lora-support-json
else:
    print("Repo already cloned")

In [None]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Dict, List
import re
from collections import Counter

# Report the runtime so results can be tied to the hardware they came from.
print("PyTorch version:", torch.__version__)
has_cuda = torch.cuda.is_available()
print("CUDA available:", has_cuda)
gpu_label = torch.cuda.get_device_name(0) if has_cuda else "None"
print("GPU:", gpu_label)

In [None]:
# Load eval dataset: a JSONL file with one chat-formatted example per line.
EVAL_PATH = "data/eval.jsonl"

eval_data = []
with open(EVAL_PATH, 'r', encoding='utf-8') as f:
    for line in f:
        # Skip blank lines (e.g. a trailing newline at EOF); json.loads would
        # otherwise raise JSONDecodeError on them.
        if not line.strip():
            continue
        eval_data.append(json.loads(line))

# Fail loudly with a clear message rather than an IndexError on eval_data[0].
if not eval_data:
    raise ValueError(f"No examples found in {EVAL_PATH}")

print(f"Loaded {len(eval_data)} eval examples")
print("\nExample:")
ex = eval_data[0]
# Message index 1 is the user prompt; index 2 is the reference answer JSON.
print("USER:", ex['messages'][1]['content'][:100])
print("EXPECTED:", ex['messages'][2]['content'][:100])

In [None]:
# Load Mistral 7B with 4-bit NF4 quantization so it fits on a T4 GPU.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

# bfloat16 compute is only natively supported on compute capability >= 8.0
# (Ampere and newer). The T4 this notebook targets is 7.5, so fall back to
# float16 there instead of relying on slow or unsupported bf16 kernels.
if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8:
    COMPUTE_DTYPE = torch.bfloat16
else:
    COMPUTE_DTYPE = torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
    bnb_4bit_use_double_quant=True,
)

print(f"Loading {MODEL_NAME} (compute dtype: {COMPUTE_DTYPE})...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the pad token (presumably the tokenizer ships without one).
# Left padding keeps prompts adjacent to generated tokens if batching is added.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    # trust_remote_code is deliberately omitted: Mistral is implemented in
    # transformers itself, and the flag would execute arbitrary Python code
    # shipped in the Hub repository.
)

print("Model loaded successfully!")

In [None]:
INTENTS = ["refund", "cancel", "billing", "tech_support", "shipping", "other"]
PRIORITIES = ["low", "medium", "high"]

# Exceptions json.loads can raise on bad input: ValueError (incl.
# json.JSONDecodeError) for malformed text, TypeError for non-str input such
# as None, RecursionError for pathologically deep nesting.
_JSON_PARSE_ERRORS = (ValueError, TypeError, RecursionError)


def is_valid_json(s: str) -> bool:
    """Return True if ``s`` parses as JSON (any JSON value, not only objects)."""
    try:
        json.loads(s)
    except _JSON_PARSE_ERRORS:
        # Narrow catch: a bare except would also swallow KeyboardInterrupt and
        # make a long evaluation loop impossible to stop.
        return False
    return True


def is_schema_compliant(s: str) -> bool:
    """Return True if ``s`` is a JSON object matching the ticket schema.

    Requirements checked:
      * top-level value is an object whose keys are exactly ``intent``,
        ``priority``, ``entities``, ``needs_clarification``,
        ``clarifying_question`` in that order (list comparison, so a
        reordering is rejected; NOTE(review): confirm order-sensitivity is
        intended);
      * ``intent`` is one of INTENTS and ``priority`` one of PRIORITIES;
      * ``entities`` is an object containing ``order_id`` and ``product``;
      * ``needs_clarification`` is a bool.
    ``clarifying_question`` is only required to be present.
    """
    try:
        obj = json.loads(s)
    except _JSON_PARSE_ERRORS:
        return False

    # Lists, strings, numbers and null can never satisfy the schema.
    if not isinstance(obj, dict):
        return False

    required_keys = ["intent", "priority", "entities", "needs_clarification", "clarifying_question"]
    if list(obj) != required_keys:
        return False
    if obj["intent"] not in INTENTS or obj["priority"] not in PRIORITIES:
        return False

    entities = obj["entities"]
    # Must be an object: on a string, `in` is a substring test, so
    # "order_id product" would otherwise be accepted.
    if not isinstance(entities, dict):
        return False
    if "order_id" not in entities or "product" not in entities:
        return False

    return isinstance(obj["needs_clarification"], bool)

def extract_json_from_text(text: str) -> str:
    """Extract the first JSON object embedded in model output, as a string.

    Markdown code-fence markers are removed first. The text is then scanned for
    the first ``{`` at which a complete JSON object decodes, so objects nested
    to any depth and braces inside string values are handled correctly.

    If no position decodes, falls back to a regex that matches a brace-balanced
    span with at most one nesting level (the match may itself be invalid JSON),
    and finally to the stripped text. Callers are expected to validate the
    result with ``is_valid_json``.
    """
    # Remove markdown code blocks
    text = re.sub(r'```json\s*', '', text)
    text = re.sub(r'```\s*', '', text)

    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            # raw_decode returns the absolute index just past the decoded value.
            _, end = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            start = text.find('{', start + 1)
            continue
        return text[start:end]

    # Fallback: best-effort brace match, even if it is not valid JSON.
    match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text)
    if match:
        return match.group(0)
    return text.strip()

print("Validation functions ready")

In [None]:
def run_inference(example: Dict, max_new_tokens: int = 256) -> str:
    """Generate the model's completion for a single eval example.

    The prompt uses the Mistral Instruct ``[INST] ... [/INST]`` format. The
    first two messages (treated as system, then user) are joined by a blank
    line. Decoding is greedy (``do_sample=False``).

    Args:
        example: Record whose ``'messages'`` list holds dicts with a
            ``'content'`` string; only indices 0 and 1 are read.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded completion, with the prompt tokens removed and surrounding
        whitespace stripped.
    """
    messages = example['messages'][:2]  # system + user

    # No literal "<s>" in the text: the tokenizer already adds BOS by default
    # (add_special_tokens=True), so including it here produced a double BOS.
    prompt = f"[INST] {messages[0]['content']}\n\n{messages[1]['content']} [/INST]"

    # No truncation: truncating from the right (the default side) would cut
    # off the trailing "[/INST]" and make the model continue the user text
    # instead of answering it.
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs['input_ids'].shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. `temperature` is not passed: it has no effect
            # without sampling, and transformers warns when it is set anyway.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()

# Smoke-test the full pipeline (generate -> extract -> validate) on one example.
test_output = run_inference(eval_data[0])
extracted = extract_json_from_text(test_output)

print("Test output:")
print(test_output)
print("\nValid JSON?", is_valid_json(extracted))
print("Schema compliant?", is_schema_compliant(extracted))

In [None]:
# Run evaluation on full eval set (or subset for speed)
EVAL_SIZE = 100  # Use 100 for quick test, 800 for full eval

eval_subset = eval_data[:EVAL_SIZE]
# The dataset may hold fewer than EVAL_SIZE examples; report the real count.
n_eval = len(eval_subset)

results = []
valid_json_count = 0
schema_compliant_count = 0
intent_correct = 0
total = 0


def _fmt_rate(count: int, denom: int) -> str:
    """Format ``count/denom (pct%)``; reports 0.0% instead of dividing by zero."""
    pct = count / denom * 100 if denom else 0.0
    return f"{count}/{denom} ({pct:.1f}%)"


print(f"Running inference on {n_eval} examples...\n")

for i, example in enumerate(eval_subset):
    if i % 10 == 0:
        print(f"Progress: {i}/{n_eval}")

    # Reference answer: message index 2 holds the expected JSON string.
    expected = json.loads(example['messages'][2]['content'])

    output = run_inference(example)
    json_str = extract_json_from_text(output)

    # Schema compliance presupposes valid JSON, so only check it on success.
    valid_json = is_valid_json(json_str)
    schema_valid = valid_json and is_schema_compliant(json_str)

    if valid_json:
        valid_json_count += 1
    if schema_valid:
        schema_compliant_count += 1
        # Intent is scored only for schema-compliant outputs; all other
        # examples count as incorrect (the denominator is `total`).
        if json.loads(json_str)['intent'] == expected['intent']:
            intent_correct += 1

    results.append({
        'user_message': example['messages'][1]['content'],
        'expected': expected,
        'predicted_raw': output,
        'predicted_json': json_str,
        'valid_json': valid_json,
        'schema_compliant': schema_valid
    })

    total += 1

print("\n" + "="*60)
print("BASELINE EVALUATION RESULTS")
print("="*60)
print(f"Total examples: {total}")
print(f"Valid JSON: {_fmt_rate(valid_json_count, total)}")
print(f"Schema compliant: {_fmt_rate(schema_compliant_count, total)}")
print(f"Intent accuracy: {_fmt_rate(intent_correct, total)}")
print("="*60)

In [None]:
# Print up to three failing examples (invalid JSON or schema violations)
# for a quick qualitative look at what the base model gets wrong.
print("\nSample failures:")
print("=" * 60)
failures = [res for res in results if not (res['valid_json'] and res['schema_compliant'])]
for num, fail in enumerate(failures[:3], start=1):
    print(f"\nFAILURE {num}:")
    print(f"USER: {fail['user_message'][:80]}...")
    print(f"PREDICTED: {fail['predicted_raw'][:150]}...")
    print(f"Valid JSON: {fail['valid_json']}, Schema compliant: {fail['schema_compliant']}")
    print("-" * 60)

In [None]:
# Save aggregate metrics plus per-example details to baseline_results.json.
def _safe_rate(count: int, denom: int) -> float:
    """Return ``count / denom``, or 0.0 for an empty run instead of raising."""
    return count / denom if denom else 0.0


with open('baseline_results.json', 'w', encoding='utf-8') as f:
    json.dump({
        'model': MODEL_NAME,
        'eval_size': total,
        'valid_json_rate': _safe_rate(valid_json_count, total),
        'schema_compliance_rate': _safe_rate(schema_compliant_count, total),
        'intent_accuracy': _safe_rate(intent_correct, total),
        'detailed_results': results
    }, f, indent=2)

print("\nResults saved to baseline_results.json")
print("Download this file to commit to your repo!")