In [None]:
# ===================================================================
# IMPROVED BERT SCORE EVALUATION FOR 4-WHEELER MODEL (LEXUS)
# Using Optimized Settings from 2-Wheeler Testing
# ===================================================================

# ========== CELL 1: Install Packages ==========
!pip install -q bert-score sentencepiece accelerate bitsandbytes peft transformers

# Note: After running Cell 1, restart the runtime: Runtime → Restart runtime
# Then run the remaining cells from Cell 2 onward (skip Cell 1 after the restart)



# ========== CELL 2: Import Libraries ==========
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# from peft import PeftModel  # ← Not needed: the model is loaded as a full merged model, not base + adapter
from bert_score import score
import pandas as pd
import time


# ========== CELL 3: Define 4-Wheeler Dataset (Lexus) ==========
# Evaluation pairs written as (question, reference answer) tuples and then
# expanded into the {"question": ..., "answer": ...} records that the later
# cells read.
_qa_pairs = [
    (
        "What is the purpose of the SRS airbags in the vehicle?",
        "The SRS airbags are designed to deploy in the event of a crash or sudden stop, providing protection for the occupants of the vehicle.",
    ),
    (
        "What is the function of the steering wheel?",
        "Adjusting the steering wheel",
    ),
    (
        "What is the procedure for connecting a Bluetooth audio player?",
        "Connecting a Bluetooth audio player involves selecting a Bluetooth device, registering the device, and then connecting it to the vehicle's Bluetooth system.",
    ),
    (
        "If your vehicle overheats",
        "Check the coolant level and condition, and refer to the owner's manual for guidance on how to address the issue.",
    ),
    (
        "What is the recommended approach for replacing genuine Lexus parts or accessories in the vehicle?",
        "Lexus recommends using genuine Lexus parts or accessories for replacement, but other parts or accessories of matching quality can also be used.",
    ),
    (
        "What is the recommended procedure for removing and disposing of the SRS airbag and seat belt pretensioner devices from a Lexus vehicle before scrapping?",
        "Have the systems removed and disposed of by an authorized Lexus dealer or a duly qualified and equipped professional.",
    ),
]

dataset = [
    {"question": question_text, "answer": answer_text}
    for question_text, answer_text in _qa_pairs
]


# ========== CELL 4: Configure 4-Wheeler Model ==========
# CUSTOMIZE THIS: Change to your 4-wheeler model.
# This is the single place where the model and prompt settings are defined.
# Previously a second configuration cell re-assigned the same names right
# after this one, so edits made here (e.g. to PROMPT_FORMAT) were silently
# overwritten before the model was used.
adapter_name = "Prithwiraj731/FourWheeler-Gemma-2B"  # Full merged model
base_model_name = None  # Not needed for merged models

# Based on 2-wheeler testing, "simple" format worked best
PROMPT_FORMAT = "simple"


# ========== CELL 5: Show Model Configuration ==========
# Only reports the settings chosen in Cell 4; it must not re-assign them.
print("üîß Model Configuration:")
print(f"   Model: {adapter_name}")
print("   Type: Full merged model (not LoRA)")
print(f"   Prompt Format: {PROMPT_FORMAT}")


# ========== CELL 6: Load Model (FULL MODEL VERSION) ==========
# Loads the tokenizer and the model from the same repository id
# (`adapter_name`); the model is loaded directly as a complete checkpoint
# rather than as a base model plus adapter.
print("\nüì• Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(adapter_name)

# Fall back to the EOS token for padding when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("üì• Loading full model with 4-bit quantization...")
quantization_settings = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.float16,
}
bnb_config = BitsAndBytesConfig(**quantization_settings)

# Load as a complete model (not base + adapter)
load_options = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "dtype": torch.float16,
}
model = AutoModelForCausalLM.from_pretrained(adapter_name, **load_options)

# Switch to evaluation mode before generating.
model.eval()

print("‚úÖ Model loaded successfully!\n")

# ========== CELL 7: Generate Predictions ==========
# Prompt templates selectable through PROMPT_FORMAT. They use the same four
# layouts as the prompt-format diagnostic cell at the end of the notebook.
PROMPT_TEMPLATES = {
    "simple": "{question}\n",
    "instruction": "### Question:\n{question}\n\n### Answer:\n",
    "qa": "Question: {question}\nAnswer:",
    "chat": "User: {question}\nAssistant:",
}


def build_prompt(question, prompt_format=None):
    """Wrap *question* in the template named by *prompt_format*.

    Args:
        question: The raw question text.
        prompt_format: Key of PROMPT_TEMPLATES; defaults to the notebook-level
            PROMPT_FORMAT setting.

    Returns:
        The prompt string to feed to the tokenizer.

    Raises:
        ValueError: If the format name is not a key of PROMPT_TEMPLATES.
    """
    chosen_format = PROMPT_FORMAT if prompt_format is None else prompt_format
    if chosen_format not in PROMPT_TEMPLATES:
        raise ValueError(
            f"Unknown prompt format {chosen_format!r}; "
            f"expected one of {sorted(PROMPT_TEMPLATES)}"
        )
    return PROMPT_TEMPLATES[chosen_format].format(question=question)


def generate_answer(question, temperature=0.1, max_new_tokens=80):
    """Generate the model's answer to *question*.

    This helper was called by the loop below but never defined, so the cell
    failed with NameError. It relies on the notebook-level `tokenizer`,
    `model` and `PROMPT_FORMAT`.

    Args:
        question: The question text to ask the model.
        temperature: Sampling temperature. A positive value enables sampling
            at that temperature; ``None`` or a value <= 0 selects greedy
            decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first line of the generated continuation, stripped of surrounding
        whitespace (an empty string if nothing was generated).
    """
    prompt = build_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.2,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature
    else:
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode only the tokens produced after the prompt, so the result does not
    # depend on the decoded prompt matching the prompt string exactly.
    prompt_token_count = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    )
    return completion.strip().split("\n")[0]


print("=" * 70)
print("üöó GENERATING ANSWERS (4-WHEELER LEXUS MODEL)")
print("=" * 70)

predictions = []
references = []
generation_times = []

for i, item in enumerate(dataset):
    question = item["question"]
    reference = item["answer"]

    print(f"\nüìå Question {i+1}/{len(dataset)}")
    print(f"Q: {question[:80]}...")

    # perf_counter is a monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    prediction = generate_answer(question, temperature=0.1)
    gen_time = time.perf_counter() - start_time
    generation_times.append(gen_time)

    print(f"‚ú® Generated: {prediction[:100]}...")
    print(f"üìñ Reference: {reference[:100]}...")
    print(f"‚è±Ô∏è  Time: {gen_time:.2f}s")

    predictions.append(prediction)
    references.append(reference)

# Skip the average when nothing was generated (empty dataset) instead of
# dividing by zero.
if generation_times:
    print(f"\n‚è±Ô∏è  Average generation time: {sum(generation_times)/len(generation_times):.2f}s")


# ========== CELL 8: Calculate BERT Score ==========
banner_rule = "=" * 70
print(f"\n{banner_rule}")
print("üìä CALCULATING BERT SCORES")
print(banner_rule)

# Score every generated answer against its reference; the three results
# (P, R, F1) are indexed per question by the cells that follow.
bert_score_options = {
    "lang": "en",
    "verbose": True,
    "rescale_with_baseline": True,
}
P, R, F1 = score(predictions, references, **bert_score_options)


# ========== CELL 9: Display Detailed Results ==========
heavy_rule = "=" * 70
light_rule = "-" * 70

print("\n" + heavy_rule)
print("üìà BERT SCORE RESULTS (4-WHEELER LEXUS MODEL)")
print(heavy_rule)

# Label prefix and score tensor for each metric line, in print order.
metric_rows = (
    ("  üìä Precision: ", P),
    ("  üìä Recall:    ", R),
    ("  üìä F1 Score:  ", F1),
)

for i, item in enumerate(dataset):
    question_preview = item['question'][:60]
    generated_preview = predictions[i][:120]
    reference_preview = references[i][:120]

    print("\n" + heavy_rule)
    print(f"Question {i+1}: {question_preview}...")
    print(light_rule)
    print(f"Generated: {generated_preview}...")
    print(f"Reference: {reference_preview}...")
    print(light_rule)
    for label, values in metric_rows:
        print(f"{label}{values[i].item():.4f}")


# ========== CELL 10: Summary Statistics ==========
section_rule = "=" * 70

print("\n" + section_rule)
print("üéØ AVERAGE BERT SCORES (4-WHEELER MODEL)")
print(section_rule)

# Mean of each metric over all questions, as plain Python floats.
avg_precision, avg_recall, avg_f1 = (
    metric.mean().item() for metric in (P, R, F1)
)

print(f"\n  üìä Average Precision: {avg_precision:.4f} ({avg_precision*100:.2f}%)")
print(f"  üìä Average Recall:    {avg_recall:.4f} ({avg_recall*100:.2f}%)")
print(f"  üìä Average F1 Score:  {avg_f1:.4f} ({avg_f1*100:.2f}%)")

print("\n" + section_rule)
print("üìñ SCORE INTERPRETATION")
print(section_rule)

interpretation_lines = (
    "\nBERT Score Range: -1.0 (worst) to 1.0 (best)",
    "\nQuality Guide:",
    "  üü¢ 0.7 - 1.0  : Excellent",
    "  üü° 0.5 - 0.7  : Good",
    "  üü† 0.3 - 0.5  : Moderate",
    "  üî¥ 0.0 - 0.3  : Poor",
    "  ‚ö´ < 0.0       : Very poor",
)
for guide_line in interpretation_lines:
    print(guide_line)

# Lower bounds checked from best to worst; the first one the average F1
# reaches decides the status, otherwise the fallback label is used.
status_tiers = (
    (0.7, "üü¢ EXCELLENT"),
    (0.5, "üü° GOOD"),
    (0.3, "üü† MODERATE"),
)
status = next(
    (label for lower_bound, label in status_tiers if avg_f1 >= lower_bound),
    "üî¥ NEEDS IMPROVEMENT",
)

print(f"\n4-Wheeler Model Status: {status}")


# ========== CELL 11: Results DataFrame ==========
def _clip(text, limit=50):
    """Return *text* as-is if it has at most *limit* characters, otherwise
    its first *limit* characters followed by '...'."""
    if len(text) > limit:
        return text[:limit] + '...'
    return text


def _as_score_strings(values):
    """Format each score in *values* as a string with four decimal places."""
    return [f"{value.item():.4f}" for value in values]


# One row per question: clipped texts plus the three formatted scores.
results_df = pd.DataFrame({
    'Question': [_clip(item['question']) for item in dataset],
    'Generated': [_clip(generated) for generated in predictions],
    'Reference': [_clip(expected) for expected in references],
    'Precision': _as_score_strings(P),
    'Recall': _as_score_strings(R),
    'F1': _as_score_strings(F1),
})

table_rule = "=" * 70
print("\n" + table_rule)
print("üìã RESULTS SUMMARY TABLE")
print(table_rule)
display(results_df)


# ========== CELL 12: Save Results (Optional) ==========
# Uncomment to save and download

# results_df.to_csv('bert_score_4wheeler_results.csv', index=False)
# print("\n‚úÖ Results saved to 'bert_score_4wheeler_results.csv'")

# from google.colab import files
# files.download('bert_score_4wheeler_results.csv')


# ========== CELL 13: Compare with 2-Wheeler ==========
# BERT F1 of the 2-wheeler (BSA) model used as the comparison baseline.
# Kept in one place so the printed figure and the comparison logic cannot
# drift apart (it was previously repeated as a literal five times).
TWO_WHEELER_F1 = 0.0321

print("\n" + "=" * 70)
print("üìä MODEL COMPARISON")
print("=" * 70)
print(f"\nüèçÔ∏è  2-Wheeler (BSA) Model BERT F1: {TWO_WHEELER_F1:.4f} ({TWO_WHEELER_F1*100:.2f}%)")
print(f"üöó 4-Wheeler (Lexus) Model BERT F1: {avg_f1:.4f} ({avg_f1*100:.2f}%)")

# Size of the gap between the two F1 scores, as a percentage of the baseline.
diff = abs(avg_f1 - TWO_WHEELER_F1) / TWO_WHEELER_F1 * 100

if avg_f1 > TWO_WHEELER_F1:
    print(f"\n‚úÖ 4-Wheeler performs {diff:.1f}% BETTER than 2-Wheeler")
elif avg_f1 < TWO_WHEELER_F1:
    print(f"\n‚ö†Ô∏è 4-Wheeler performs {diff:.1f}% WORSE than 2-Wheeler")
else:
    print("\n‚û°Ô∏è Both models perform similarly")


# ========== CELL 14: Test Different Formats (Diagnostic) ==========
# Run this if results are poor to test other prompt formats.
# Asks the first dataset question once per prompt layout, using greedy
# decoding, and prints the first line of each answer for comparison.

print("\n" + "=" * 70)
print("üß™ TESTING ALL PROMPT FORMATS")
print("=" * 70)

test_question = dataset[0]["question"]
formats_to_test = {
    "simple": f"{test_question}\n",
    "instruction": f"### Question:\n{test_question}\n\n### Answer:\n",
    "qa": f"Question: {test_question}\nAnswer:",
    "chat": f"User: {test_question}\nAssistant:"
}

print(f"\nTest Question: {test_question}\n")

for fmt_name, fmt_prompt in formats_to_test.items():
    inputs = tokenizer(fmt_prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=80,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2
        )

    # Decode only the newly generated tokens. Cutting the decoded text at
    # len(fmt_prompt) assumed the decoded prompt is character-for-character
    # identical to the prompt string, which would shift the cut point
    # whenever detokenization does not round-trip exactly.
    completion = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    )
    answer = completion.strip().split('\n')[0]

    print(f"\n{'='*70}")
    print(f"Format: {fmt_name.upper()}")
    print(f"{'-'*70}")
    print(f"Answer: {answer[:150]}")

print("\n" + "=" * 70)
print("üí° If 'simple' format doesn't work well, update PROMPT_FORMAT in Cell 4")
print("=" * 70)

