In [None]:
# ===================================================================
# IMPROVED BERT SCORE EVALUATION FOR GOOGLE COLAB
# Fixed Version with Multiple Prompt Format Options
# ===================================================================

# ========== CELL 1: Install Packages ==========
!pip install -q bert-score sentencepiece accelerate bitsandbytes peft transformers

# Note: After running Cell 1, restart the runtime (Runtime -> Restart runtime),
# then run the remaining cells starting from Cell 2 (skip Cell 1 after the restart).


# ========== CELL 2: Authentication (if using gated models) ==========
# Logs this session in to the Hugging Face Hub. Per the cell title this is
# only needed for gated models. Use exactly one of the two options below.
from huggingface_hub import login

# Option A: Manual login (you'll paste your token when prompted)
login()

# Option B: Use Colab Secrets (recommended)
# Uncomment the lines below if you have stored HF_TOKEN in Colab secrets
# (and comment out the bare login() call above so you are not prompted).
# from google.colab import userdata
# hf_token = userdata.get('HF_TOKEN')
# login(token=hf_token)


# ========== CELL 3: Import Libraries ==========
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from bert_score import score
import pandas as pd
import time


# ========== CELL 4: Define Your Dataset ==========
# CUSTOMIZE THIS: replace the (question, expected answer) pairs below with
# your own test set. Each pair becomes one dict with the keys "question"
# and "answer", which is the shape the later cells read.
dataset = [
    {"question": question_text, "answer": expected_answer}
    for question_text, expected_answer in (
        (
            "What is the recommended lubrication for the engine of the BSA D14/4 Bantam Supreme motorcycle?",
            "Engine lubrication: BSA recommends using a mixture of 10W-30 oil, with a minimum of 10W-40 oil, for the engine of the BSA D14/4 Bantam Supreme motorcycle.",
        ),
        (
            "Where should an inexperienced owner consult for assistance with major repair work?",
            "His B.S.A. dealer",
        ),
        (
            "What is the recommended procedure for claiming assistance under the B.S.A. guarantee?",
            "Claim assistance through the dealer from whom the motorcycle was purchased.",
        ),
        (
            "What is the correct address of the B.S.A. Service Department?",
            "B.S.A. MOTOR CYCLES LIMITED, SERVICE DEPARTMENT, ARMOURY ROAD, BIRMINGHAM 11",
        ),
        (
            "What is the recommended procedure for claiming assistance under the guarantee for a new motorcycle?",
            "The owner must do so through the dealer from whom the machine was purchased.",
        ),
        (
            "What is the recommended torque wrench setting for the Supreme model?",
            "1 to 3",
        ),
    )
]


# ========== CELL 5: Configure Your Model ==========
# CUSTOMIZE THIS: point these two names at your own base model and adapter.
base_model_name = "google/gemma-2-2b"
adapter_name = "Prithwiraj731/Gemma2-2b_Two-Wheeler"

# IMPORTANT: prompt template that generate_answer() wraps each question in.
# Options: "simple", "instruction", "gemma", "chat" - try them until you
# find the one that works for YOUR model.
PROMPT_FORMAT = "simple"

# Echo the active configuration so it is visible in the cell output.
print("üîß Model Configuration:")
print(
    "\n".join(
        f"   {label}: {value}"
        for label, value in (
            ("Adapter", adapter_name),
            ("Base Model", base_model_name),
            ("Prompt Format", PROMPT_FORMAT),
        )
    )
)


# ========== CELL 6: Load Model ==========
# Loads the tokenizer, the 4-bit quantized base model and the LoRA adapter.
# Leaves `tokenizer` and `model` as module-level names; generate_answer()
# reads both of them.
print("\nüì• Loading tokenizer...")
# The tokenizer is fetched from the adapter repo rather than the base model.
# NOTE(review): presumably the adapter repo ships tokenizer files - confirm.
tokenizer = AutoTokenizer.from_pretrained(adapter_name)

print("üì• Loading base model with 4-bit quantization...")
# 4-bit NF4 weight quantization with float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    # NOTE(review): older transformers releases may expect `torch_dtype`
    # instead of `dtype` - confirm against the installed version.
    dtype=torch.float16
)

print("üì• Loading LoRA adapter...")
# Wrap the quantized base model with the adapter, then switch to eval mode
# because this notebook only runs inference.
model = PeftModel.from_pretrained(base_model, adapter_name)
model.eval()

# Set padding token: fall back to the EOS token when the tokenizer defines
# no pad token, since generate_answer() passes pad_token_id to generate().
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("‚úÖ Model loaded successfully!\n")


# ========== CELL 7: Define Answer Generation Function with Multiple Format Options ==========
def generate_answer(question, max_new_tokens=100, temperature=0.1, prompt_format=None):
    """
    Generate a one-line answer to ``question`` with the loaded model.

    The question is wrapped in one of several prompt templates, the model
    continues the prompt with greedy decoding, and only the newly generated
    tokens are decoded, so the prompt text is never part of the answer.

    Args:
        question: The input question.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Kept for backward compatibility. Decoding is greedy
            (``do_sample=False``), so this value is not forwarded to
            ``model.generate``.
        prompt_format: "simple", "instruction", "gemma" or "chat". When
            ``None`` (the default) the module-level ``PROMPT_FORMAT`` is
            used. Any unrecognised value sends the bare question as the
            prompt.

    Returns:
        The first line of the generated continuation with surrounding
        whitespace removed, or "No answer generated" if that line is empty.
    """
    # Resolve the format at call time so changes to the global are honoured.
    fmt = PROMPT_FORMAT if prompt_format is None else prompt_format

    templates = {
        # Simple Q&A format (most common for fine-tuned models)
        "simple": "{question}\n",
        # Instruction-style format
        "instruction": "### Question:\n{question}\n\n### Answer:\n",
        # Gemma 2 chat template format
        "gemma": "<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model\n",
        # Generic chat format
        "chat": "User: {question}\nAssistant:",
    }
    # Unknown formats fall back to the bare question.
    prompt = templates.get(fmt, "{question}").format(question=question)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding for the most factual answers. `temperature`
            # only applies to sampling, so it is deliberately not passed.
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,  # Prevent repetition
        )

    # Drop the prompt by token count, not by character count. The decoded
    # text is not guaranteed to start with the prompt string (special tokens
    # such as chat turn markers may be removed by skip_special_tokens), so
    # slicing by len(prompt) or splitting on a marker can cut into the
    # answer or pick up a later hallucinated turn.
    new_tokens = outputs[0][prompt_token_count:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)

    if fmt == "gemma":
        # If the end-of-turn marker survived decoding, stop there.
        completion = completion.split("<end_of_turn>")[0]

    # Keep only the first line of the continuation.
    answer = completion.strip().split("\n")[0].strip()

    # Handle empty answers
    return answer or "No answer generated"


# ========== CELL 8: Generate Predictions ==========
# Runs generate_answer() on every dataset entry and collects three parallel
# lists: the model output, the reference answer and the generation time.
print("=" * 70)
print("ü§ñ GENERATING ANSWERS")
print("=" * 70)

predictions = []
references = []
generation_times = []

for i, item in enumerate(dataset):
    question = item["question"]
    reference = item["answer"]

    print(f"\nüìå Question {i+1}/{len(dataset)}")
    print(f"Q: {question[:80]}...")

    # Time the generation. perf_counter() is monotonic, so the measured
    # interval cannot be distorted by system clock adjustments.
    start_time = time.perf_counter()
    prediction = generate_answer(question, temperature=0.1)
    gen_time = time.perf_counter() - start_time
    generation_times.append(gen_time)

    print(f"‚ú® Generated: {prediction[:100]}...")
    print(f"üìñ Reference: {reference[:100]}...")
    print(f"‚è±Ô∏è  Time: {gen_time:.2f}s")

    predictions.append(prediction)
    references.append(reference)

# Guard the average: an empty dataset would otherwise divide by zero.
if generation_times:
    average_time = sum(generation_times) / len(generation_times)
    print(f"\n‚è±Ô∏è  Average generation time: {average_time:.2f}s")
else:
    print("\nNo questions in dataset; nothing was generated.")


# ========== CELL 9: Calculate BERT Score ==========
# Scores every generated answer against its reference. P, R and F1 hold one
# value per question and are read by the reporting cells that follow.
print("", "=" * 70, "üìä CALCULATING BERT SCORES", "=" * 70, sep="\n")

bert_score_options = {
    "lang": "en",
    "verbose": True,
    "rescale_with_baseline": True,
}
P, R, F1 = score(predictions, references, **bert_score_options)


# ========== CELL 10: Display Detailed Results ==========
# Prints one report per question: truncated texts plus its three scores.
heavy_rule = "=" * 70
light_rule = "-" * 70

print("", heavy_rule, "üìà BERT SCORE RESULTS (DETAILED)", heavy_rule, sep="\n")

for idx, entry in enumerate(dataset):
    generated_text = predictions[idx]
    reference_text = references[idx]
    # Each score is pulled out of its tensor as a plain Python float.
    precision = P[idx].item()
    recall = R[idx].item()
    f1_value = F1[idx].item()

    print("\n" + heavy_rule)
    print(f"Question {idx + 1}: {entry['question'][:60]}...")
    print(light_rule)
    print(f"Generated: {generated_text[:120]}...")
    print(f"Reference: {reference_text[:120]}...")
    print(light_rule)
    print(f"  üìä Precision: {precision:.4f}")
    print(f"  üìä Recall:    {recall:.4f}")
    print(f"  üìä F1 Score:  {f1_value:.4f}")


# ========== CELL 11: Display Summary Statistics ==========
print("", "=" * 70, "üéØ AVERAGE BERT SCORES", "=" * 70, sep="\n")

# Collapse the per-question score tensors into plain Python floats.
avg_precision, avg_recall, avg_f1 = (
    scores.mean().item() for scores in (P, R, F1)
)

print(f"\n  üìä Average Precision: {avg_precision:.4f} ({avg_precision*100:.2f}%)")
print(f"  üìä Average Recall:    {avg_recall:.4f} ({avg_recall*100:.2f}%)")
print(f"  üìä Average F1 Score:  {avg_f1:.4f} ({avg_f1*100:.2f}%)")

# Score interpretation
# NOTE(review): Cell 9 computes scores with rescale_with_baseline=True;
# confirm the stated -1.0..1.0 range still holds for rescaled scores.
print("", "=" * 70, "üìñ SCORE INTERPRETATION", "=" * 70, sep="\n")
print("\nBERT Score Range: -1.0 (worst) to 1.0 (best)")
print("\nQuality Guide:")
print(
    "\n".join(
        (
            "  üü¢ 0.7 - 1.0  : Excellent semantic similarity",
            "  üü° 0.5 - 0.7  : Good similarity",
            "  üü† 0.3 - 0.5  : Moderate similarity",
            "  üî¥ 0.0 - 0.3  : Poor similarity",
            "  ‚ö´ < 0.0       : Very poor / opposing meaning",
        )
    )
)

# Pick the first band whose lower bound the average F1 reaches; anything
# below the lowest bound falls through to the default label.
status = next(
    (
        label
        for lower_bound, label in (
            (0.7, "üü¢ EXCELLENT"),
            (0.5, "üü° GOOD"),
            (0.3, "üü† MODERATE"),
        )
        if avg_f1 >= lower_bound
    ),
    "üî¥ NEEDS IMPROVEMENT",
)

print(f"\nYour Model Status: {status}")


# ========== CELL 12: Results DataFrame ==========
def _shorten(text, limit=50):
    """Return ``text`` as-is if it fits in ``limit`` characters, else its first ``limit`` characters plus '...'."""
    if len(text) <= limit:
        return text
    return text[:limit] + '...'


def _format_scores(values):
    """Render each score in ``values`` as a string with four decimals."""
    return [f"{value.item():.4f}" for value in values]


# One row per question; long texts are shortened to keep the table readable.
results_df = pd.DataFrame({
    'Question': [_shorten(item['question']) for item in dataset],
    'Generated': [_shorten(generated) for generated in predictions],
    'Reference': [_shorten(expected) for expected in references],
    'Precision': _format_scores(P),
    'Recall': _format_scores(R),
    'F1': _format_scores(F1),
})

print("", "=" * 70, "üìã RESULTS SUMMARY TABLE", "=" * 70, sep="\n")
display(results_df)


# ========== CELL 13: Save Results (Optional) ==========
# Uncomment to save and download results

# results_df.to_csv('bert_score_results.csv', index=False)
# print("\n‚úÖ Results saved to 'bert_score_results.csv'")

# from google.colab import files
# files.download('bert_score_results.csv')


# ========== CELL 14: Test All Prompt Formats (DIAGNOSTIC) ==========
# Run this cell to test which format works best for your model.
# It answers the first dataset question once per prompt format by
# temporarily switching the global PROMPT_FORMAT that generate_answer() reads.

print("\n" + "=" * 70)
print("üß™ TESTING ALL PROMPT FORMATS")
print("=" * 70)

test_question = dataset[0]["question"]
formats = ["simple", "instruction", "gemma", "chat"]

print(f"\nTest Question: {test_question}\n")

# Remember the configured format once and restore it in `finally`, so an
# exception during generation cannot leave the global on a diagnostic
# value for the cells that run afterwards.
old_format = PROMPT_FORMAT
try:
    for fmt in formats:
        PROMPT_FORMAT = fmt

        answer = generate_answer(test_question, max_new_tokens=80, temperature=0.1)

        print(f"\n{'='*70}")
        print(f"Format: {fmt.upper()}")
        print(f"{'-'*70}")
        print(f"Answer: {answer[:150]}")
finally:
    PROMPT_FORMAT = old_format

print("\n" + "=" * 70)
print("üí° RECOMMENDATION:")
print("Look at the outputs above and choose the format that gives actual")
print("answers (not repetitions). Then update PROMPT_FORMAT in Cell 5!")
print("=" * 70)

