In [None]:
# ==============================================================================
# BLEU AND ROUGE-L EVALUATION - FOUR-WHEELER MODEL (LEXUS)
# Metrics: BLEU-1, BLEU-2, BLEU-4, ROUGE-L
# ==============================================================================

# ========== CELL 1: Install Packages ==========
# transformers: tokenizer/model loading; bitsandbytes: the 4-bit BitsAndBytesConfig
# used in Cell 6; accelerate: needed for device_map="auto"; rouge-score and nltk:
# the ROUGE-L and BLEU metrics. (`!pip` is IPython shell syntax, not plain Python.)
!pip install -q accelerate bitsandbytes transformers rouge-score nltk

# Note: After running Cell 1, restart the runtime, then run Cells 2-10 (Cell 11 is optional).


# ========== CELL 2: Import Libraries ==========
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import pandas as pd
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import time

# Fetch NLTK's Punkt tokenizer data quietly ("punkt_tab" is the resource name
# newer NLTK releases look up — confirm against the installed version).
# NOTE(review): nothing in this notebook calls an NLTK tokenizer; BLEU inputs
# are tokenized with str.split(), so these downloads appear to be unused.
for _punkt_resource in ('punkt', 'punkt_tab'):
    nltk.download(_punkt_resource, quiet=True)


# ========== CELL 3: Define Dataset ==========
# Evaluation set: each entry is {"question": <prompt text>, "answer": <reference>}.
# NOTE(review): item 2's reference ("Adjusting the steering wheel") reads like a
# heading rather than an answer, and item 4's "question" is a sentence fragment;
# both will depress BLEU/ROUGE scores — confirm they are intended.
dataset = [
    {"question": q, "answer": a}
    for q, a in (
        (
            "What is the purpose of the SRS airbags in the vehicle?",
            "The SRS airbags are designed to deploy in the event of a crash or sudden stop, providing protection for the occupants of the vehicle.",
        ),
        (
            "What is the function of the steering wheel?",
            "Adjusting the steering wheel",
        ),
        (
            "What is the procedure for connecting a Bluetooth audio player?",
            "Connecting a Bluetooth audio player involves selecting a Bluetooth device, registering the device, and then connecting it to the vehicle's Bluetooth system.",
        ),
        (
            "If your vehicle overheats",
            "Check the coolant level and condition, and refer to the owner's manual for guidance on how to address the issue.",
        ),
        (
            "What is the recommended approach for replacing genuine Lexus parts or accessories in the vehicle?",
            "Lexus recommends using genuine Lexus parts or accessories for replacement, but other parts or accessories of matching quality can also be used.",
        ),
        (
            "What is the recommended procedure for removing and disposing of the SRS airbag and seat belt pretensioner devices from a Lexus vehicle before scrapping?",
            "Have the systems removed and disposed of by an authorized Lexus dealer or a duly qualified and equipped professional.",
        ),
    )
]


# ========== CELL 4: Define BLEU and ROUGE Functions ==========
def calculate_bleu(prediction, reference):
    """Score one prediction against one reference with sentence-level BLEU.

    Both strings are lower-cased and split on whitespace. Each score uses
    NLTK's ``sentence_bleu`` with smoothing method 4 and uniform cumulative
    weights: BLEU-1 over unigrams, BLEU-2 over 1-2 grams, BLEU-4 over 1-4 grams.

    Returns:
        Tuple ``(bleu1, bleu2, bleu4)``; ``(0.0, 0.0, 0.0)`` when the
        prediction contains no tokens.
    """
    smoothing = SmoothingFunction().method4
    refs = [reference.lower().split()]
    hypothesis = prediction.lower().split()

    if not hypothesis:
        return 0.0, 0.0, 0.0

    weight_sets = (
        (1, 0, 0, 0),
        (0.5, 0.5, 0, 0),
        (0.25, 0.25, 0.25, 0.25),
    )
    return tuple(
        sentence_bleu(refs, hypothesis, weights=w, smoothing_function=smoothing)
        for w in weight_sets
    )

def calculate_rouge(prediction, reference):
    """Return the ROUGE-L F-measure of *prediction* against *reference*.

    Scored with ``rouge_score`` (stemming enabled); the reference is passed
    as the target and the prediction as the candidate.
    """
    lcs = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True).score(
        reference, prediction
    )['rougeL']
    return lcs.fmeasure


# ========== CELL 5: Configure Model ==========
# Hub repository id handed to from_pretrained() for both tokenizer and model.
model_name = "Prithwiraj731/FourWheeler-Gemma-2B"

print("\n".join([
    "Model Configuration:",
    f"  Model: {model_name}",
    "  Type: Full merged model",
]))


# ========== CELL 6: Load Model ==========
# Load the tokenizer, then the model with NF4 4-bit weights and fp16 compute;
# device_map="auto" lets accelerate decide device placement.
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# generate() is called with pad_token_id; reuse EOS if no pad token is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading full model with 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# nn.Module.eval() returns the module itself, so it can be chained here.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    dtype=torch.float16,
).eval()

print("Model loaded successfully\n")


# ========== CELL 7: Define Answer Generation Function ==========
def generate_answer(question, max_new_tokens=100):
    """Generate a one-line answer to *question* with deterministic decoding.

    The prompt is the question followed by a newline. Generation runs with
    ``do_sample=False`` and a repetition penalty of 1.2; only the tokens
    produced after the prompt are decoded.

    Args:
        question: Question text used verbatim as the prompt.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The first line of the stripped continuation, or
        ``"No answer generated"`` if the continuation is empty.
    """
    prompt = f"{question}\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Deterministic decoding. `temperature` is deliberately not passed:
            # it only applies when sampling, and combining it with
            # do_sample=False makes transformers emit a warning.
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )

    # Decode only the newly generated tokens. Slicing the full decoded text by
    # len(prompt) is unreliable: decoding the prompt's tokens is not guaranteed
    # to reproduce the prompt string character-for-character.
    continuation = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Keep only the first line of the continuation (presumably to drop any
    # follow-on text the model produces after its answer).
    answer = continuation.strip().split('\n')[0].strip()

    return answer or "No answer generated"


# ========== CELL 8: Generate Predictions and Calculate Metrics ==========
def _preview(text, width=70):
    """Return *text* clipped to *width* chars, adding '...' only when clipped."""
    return text[:width] + '...' if len(text) > width else text


print("="*70)
print("GENERATING ANSWERS AND CALCULATING BLEU/ROUGE METRICS")
print("="*70)

results = []
all_bleu1, all_bleu2, all_bleu4, all_rouge = [], [], [], []
generation_times = []

for i, item in enumerate(dataset):
    question = item["question"]
    reference = item["answer"]

    print(f"\nQuestion {i+1}/{len(dataset)}")
    print(f"Q: {_preview(question)}")

    # perf_counter() is monotonic and high-resolution, the right clock for
    # measuring elapsed time (time.time() can jump with wall-clock changes).
    start_time = time.perf_counter()
    prediction = generate_answer(question)
    gen_time = time.perf_counter() - start_time
    generation_times.append(gen_time)

    # Calculate metrics
    bleu1, bleu2, bleu4 = calculate_bleu(prediction, reference)
    rouge_l = calculate_rouge(prediction, reference)

    all_bleu1.append(bleu1)
    all_bleu2.append(bleu2)
    all_bleu4.append(bleu4)
    all_rouge.append(rouge_l)

    results.append({
        'question': question,
        'reference': reference,
        'prediction': prediction,
        'bleu1': bleu1,
        'bleu2': bleu2,
        'bleu4': bleu4,
        'rouge_l': rouge_l,
    })

    print(f"Generated: {_preview(prediction)}")
    print(f"Reference: {_preview(reference)}")
    print(f"BLEU-1: {bleu1:.4f} | BLEU-2: {bleu2:.4f} | BLEU-4: {bleu4:.4f}")
    print(f"ROUGE-L: {rouge_l:.4f}")
    print(f"Time: {gen_time:.2f}s")

# Guard against an empty dataset instead of raising ZeroDivisionError.
avg_gen_time = sum(generation_times) / len(generation_times) if generation_times else 0.0
print(f"\nAverage generation time: {avg_gen_time:.2f}s")


# ========== CELL 9: Display Summary Results ==========
def _mean(values):
    """Return the arithmetic mean of *values*, or 0.0 when there are none."""
    return sum(values) / len(values) if values else 0.0


print("\n" + "="*70)
print("FOUR-WHEELER MODEL - BLEU/ROUGE-L RESULTS")
print("="*70)

avg_bleu1 = _mean(all_bleu1)
avg_bleu2 = _mean(all_bleu2)
avg_bleu4 = _mean(all_bleu4)
avg_rouge = _mean(all_rouge)

print("\nAverage Scores:")
print(f"  BLEU-1:  {avg_bleu1:.4f} ({avg_bleu1*100:.2f}%)")
print(f"  BLEU-2:  {avg_bleu2:.4f} ({avg_bleu2*100:.2f}%)")
print(f"  BLEU-4:  {avg_bleu4:.4f} ({avg_bleu4*100:.2f}%)")
print(f"  ROUGE-L: {avg_rouge:.4f} ({avg_rouge*100:.2f}%)")

# The BLEU scores are cumulative (weights (1,0,0,0), (0.5,0.5,0,0) and
# (0.25,0.25,0.25,0.25) in calculate_bleu): a geometric mean of the 1..n-gram
# precisions times the brevity penalty, not a single n-gram precision.
print("\nMetric Definitions:")
print("  - BLEU-1: Unigram precision with brevity penalty")
print("  - BLEU-2: Geometric mean of 1- and 2-gram precision, with brevity penalty")
print("  - BLEU-4: Geometric mean of 1- to 4-gram precision, with brevity penalty")
print("  - ROUGE-L: Longest common subsequence F-measure")


# ========== CELL 10: Results DataFrame ==========
# Per-question table: the question is clipped to 40 characters ('...' appended
# only when clipped) and every metric is rendered as a 4-decimal string.
_table = {'Question': [], 'BLEU-1': [], 'BLEU-2': [], 'BLEU-4': [], 'ROUGE-L': []}
_metric_columns = (
    ('BLEU-1', 'bleu1'),
    ('BLEU-2', 'bleu2'),
    ('BLEU-4', 'bleu4'),
    ('ROUGE-L', 'rouge_l'),
)
for _row in results:
    _q = _row['question']
    _table['Question'].append(_q if len(_q) <= 40 else _q[:40] + '...')
    for _column, _key in _metric_columns:
        _table[_column].append(f"{_row[_key]:.4f}")
results_df = pd.DataFrame(_table)

print("\n" + "="*70)
print("DETAILED RESULTS TABLE")
print("="*70)
# `display` is the IPython/Jupyter rich-output function.
display(results_df)


# ========== CELL 11: Save Results (Optional) ==========
# Uncomment to save and download results

# results_df.to_csv('bleu_rouge_4wheeler_results.csv', index=False)
# print("\nResults saved to 'bleu_rouge_4wheeler_results.csv'")

# from google.colab import files
# files.download('bleu_rouge_4wheeler_results.csv')

