In [None]:
# ==============================================================================
# TASK ACCURACY EVALUATION - FOUR-WHEELER MODEL (LEXUS)
# Metrics: Exact Match, Partial Match, Keyword Score
# ==============================================================================

# ========== CELL 1: Install Packages ==========
!pip install -q accelerate bitsandbytes transformers

# Note: After running Cell 1, restart the runtime, then run Cells 2-11


# ========== CELL 2: Import Libraries ==========
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import pandas as pd
import re
import time


# ========== CELL 3: Define Dataset ==========
# Evaluation set as (question, reference answer) pairs. They are expanded below
# into the list-of-dicts shape that is read via item["question"] / item["answer"].
_QA_PAIRS = (
    (
        "What is the purpose of the SRS airbags in the vehicle?",
        "The SRS airbags are designed to deploy in the event of a crash or sudden stop, providing protection for the occupants of the vehicle.",
    ),
    (
        "What is the function of the steering wheel?",
        "Adjusting the steering wheel",
    ),
    (
        "What is the procedure for connecting a Bluetooth audio player?",
        "Connecting a Bluetooth audio player involves selecting a Bluetooth device, registering the device, and then connecting it to the vehicle's Bluetooth system.",
    ),
    (
        "If your vehicle overheats",
        "Check the coolant level and condition, and refer to the owner's manual for guidance on how to address the issue.",
    ),
    (
        "What is the recommended approach for replacing genuine Lexus parts or accessories in the vehicle?",
        "Lexus recommends using genuine Lexus parts or accessories for replacement, but other parts or accessories of matching quality can also be used.",
    ),
    (
        "What is the recommended procedure for removing and disposing of the SRS airbag and seat belt pretensioner devices from a Lexus vehicle before scrapping?",
        "Have the systems removed and disposed of by an authorized Lexus dealer or a duly qualified and equipped professional.",
    ),
)

dataset = [
    {"question": question_text, "answer": answer_text}
    for question_text, answer_text in _QA_PAIRS
]


# ========== CELL 4: Define Accuracy Metrics Functions ==========
def normalize_text(text):
    """Normalize text for comparison.

    Lowercases the input, deletes every character that is neither a word
    character nor whitespace, collapses each whitespace run to a single
    space, and trims both ends.

    Args:
        text: String to normalize.

    Returns:
        The normalized string; empty when the input holds only punctuation
        and/or whitespace.
    """
    text = re.sub(r'[^\w\s]', '', text.lower())
    # Trim *after* punctuation removal: stripping first leaves a stray edge
    # space for inputs such as "hello !" or "- yes", which then fail to
    # compare equal to "hello" / "yes".
    return re.sub(r'\s+', ' ', text).strip()

def exact_match(prediction, reference):
    """Return True when both strings are identical once normalized."""
    normalized_prediction = normalize_text(prediction)
    normalized_reference = normalize_text(reference)
    return normalized_prediction == normalized_reference

def partial_match(prediction, reference, threshold=0.3):
    """Tell whether the prediction covers enough of the reference vocabulary.

    Coverage is the share of unique normalized reference words that also
    occur in the prediction. ``threshold`` is a fraction (0.3 means 30%).
    A reference with no words never matches.
    """
    prediction_vocab = set(normalize_text(prediction).split())
    reference_vocab = set(normalize_text(reference).split())
    if not reference_vocab:
        return False
    shared_words = reference_vocab & prediction_vocab
    coverage = len(shared_words) / len(reference_vocab)
    return coverage >= threshold

def keyword_match(prediction, reference):
    """Return the fraction of reference keywords that the prediction contains.

    Keywords are the unique normalized words left after dropping a small
    fixed stopword list. The result is 0.0 when the reference has no
    keywords at all.
    """
    reference_tokens = set(normalize_text(reference).split())
    prediction_tokens = set(normalize_text(prediction).split())

    ignored = frozenset((
        'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been',
        'to', 'of', 'and', 'or', 'for', 'in', 'on', 'at', 'by', 'with',
    ))

    wanted = reference_tokens - ignored
    if not wanted:
        return 0.0

    found = (prediction_tokens - ignored) & wanted
    return len(found) / len(wanted)


# ========== CELL 5: Configure Model ==========
# Identifier handed to the tokenizer/model `from_pretrained` calls below.
model_name = "Prithwiraj731/FourWheeler-Gemma-2B"

# Echo the configuration so the run log records which checkpoint was evaluated.
print(
    "Model Configuration:",
    f"  Model: {model_name}",
    "  Type: Full merged model",
    sep="\n",
)


# ========== CELL 6: Load Model ==========
# --- Tokenizer -----------------------------------------------------------
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# generate() is given an explicit pad_token_id later on, so fall back to the
# EOS token when the tokenizer does not define a padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# --- Model (4-bit NF4 weights, float16 compute) --------------------------
print("Loading full model with 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# .eval() returns the module itself, so it can be chained onto the load call.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    dtype=torch.float16,
).eval()

print("Model loaded successfully\n")


# ========== CELL 7: Define Answer Generation Function ==========
def generate_answer(question, max_new_tokens=100):
    """Generate a single-line answer to ``question`` with greedy decoding.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        question: Question text; it is sent to the model followed by a
            newline.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The first line of the generated continuation with surrounding
        whitespace removed, or the placeholder ``"No answer generated"``
        when that line is empty.
    """
    prompt = f"{question}\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. `temperature` is deliberately not passed: it
            # only applies when sampling, and combining it with
            # do_sample=False makes transformers warn about an unused flag.
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )

    # Drop the prompt at the token level. Slicing the decoded string by
    # len(prompt) is only correct when decoding reproduces the prompt
    # character-for-character, which tokenizers do not guarantee.
    new_tokens = outputs[0][prompt_token_count:]
    continuation = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Keep only the first line of the continuation.
    answer = continuation.split('\n')[0].strip()
    return answer or "No answer generated"


# ========== CELL 8: Generate Predictions and Calculate Metrics ==========
print("=" * 70)
print("GENERATING ANSWERS AND CALCULATING ACCURACY METRICS")
print("=" * 70)


def _preview(text, limit=70):
    """Return ``text`` cut to ``limit`` characters, adding '...' only when cut."""
    return text[:limit] + "..." if len(text) > limit else text


# Per-question records plus running counters; later cells read `results`,
# `exact_matches` and `partial_matches`.
results = []
exact_matches = 0
partial_matches = 0
generation_times = []

for i, item in enumerate(dataset):
    question = item["question"]
    reference = item["answer"]

    print(f"\nQuestion {i+1}/{len(dataset)}")
    print(f"Q: {_preview(question)}")

    # perf_counter() is the clock intended for measuring elapsed intervals;
    # time.time() follows the wall clock and can jump if it is adjusted.
    start_time = time.perf_counter()
    prediction = generate_answer(question)
    gen_time = time.perf_counter() - start_time
    generation_times.append(gen_time)

    # Calculate metrics
    is_exact = exact_match(prediction, reference)
    is_partial = partial_match(prediction, reference, threshold=0.3)
    keyword_score = keyword_match(prediction, reference)

    if is_exact:
        exact_matches += 1
    if is_partial:
        partial_matches += 1

    results.append({
        'question': question,
        'reference': reference,
        'prediction': prediction,
        'exact_match': is_exact,
        'partial_match': is_partial,
        'keyword_score': keyword_score
    })

    print(f"Generated: {_preview(prediction)}")
    print(f"Reference: {_preview(reference)}")
    print(f"Exact Match: {'YES' if is_exact else 'NO'}")
    print(f"Partial Match (30%): {'YES' if is_partial else 'NO'}")
    print(f"Keyword Score: {keyword_score:.1%}")
    print(f"Time: {gen_time:.2f}s")

# An empty dataset produces no timings; report 0.0 instead of dividing by zero.
avg_gen_time = (
    sum(generation_times) / len(generation_times) if generation_times else 0.0
)
print(f"\nAverage generation time: {avg_gen_time:.2f}s")


# ========== CELL 9: Display Summary Results ==========
print("\n" + "=" * 70)
print("FOUR-WHEELER MODEL - TASK ACCURACY RESULTS")
print("=" * 70)

total = len(dataset)
# Guard against an empty dataset: the counters and `results` are filled by
# iterating the dataset, so every numerator is zero then and dividing by 1
# yields 0.0% instead of raising ZeroDivisionError.
divisor = total or 1
exact_accuracy = exact_matches / divisor * 100
partial_accuracy = partial_matches / divisor * 100
avg_keyword_score = sum(r['keyword_score'] for r in results) / divisor * 100

print("\nSummary Metrics:")
print(f"  Exact Match Accuracy:   {exact_matches}/{total} = {exact_accuracy:.1f}%")
print(f"  Partial Match Accuracy: {partial_matches}/{total} = {partial_accuracy:.1f}%")
print(f"  Average Keyword Score:  {avg_keyword_score:.1f}%")

print("\nMetric Definitions:")
print("  - Exact Match: Prediction matches reference exactly (after normalization)")
print("  - Partial Match: At least 30% of reference words appear in prediction")
print("  - Keyword Score: Percentage of important words (non-stopwords) matched")


# ========== CELL 10: Results DataFrame ==========
def _clip(text, width=45):
    """Shorten ``text`` to ``width`` characters, marking a cut with '...'."""
    if len(text) > width:
        return text[:width] + '...'
    return text


def _yes_no(flag):
    """Render a truthy/falsy value as 'YES' or 'NO'."""
    return 'YES' if flag else 'NO'


# One display row per evaluated question; long texts are clipped for readability.
results_df = pd.DataFrame({
    'Question': [_clip(row['question']) for row in results],
    'Prediction': [_clip(row['prediction']) for row in results],
    'Exact': [_yes_no(row['exact_match']) for row in results],
    'Partial': [_yes_no(row['partial_match']) for row in results],
    'Keyword%': [f"{row['keyword_score']*100:.1f}%" for row in results],
})

print("\n" + "=" * 70)
print("DETAILED RESULTS TABLE")
print("=" * 70)
display(results_df)


# ========== CELL 11: Save Results (Optional) ==========
# Uncomment to save and download results

# results_df.to_csv('task_accuracy_4wheeler_results.csv', index=False)
# print("\nResults saved to 'task_accuracy_4wheeler_results.csv'")

# from google.colab import files
# files.download('task_accuracy_4wheeler_results.csv')

