In [None]:
# ==============================================================================
# INFERENCE LATENCY EVALUATION - TWO-WHEELER MODEL (BSA)
# Metrics: Tokenization, Inference, Decoding Time, Throughput
# ==============================================================================

# ========== CELL 1: Install Packages ==========
# IPython shell escape (runs pip in a subshell; not valid in plain Python).
# -q keeps pip's output quiet.
!pip install -q accelerate bitsandbytes peft transformers

# Note: After running Cell 1, restart the runtime, then run Cells 2-11.


# ========== CELL 2: Import Libraries ==========
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import pandas as pd
import time
import numpy as np


# ========== CELL 3: Define Dataset ==========
# Evaluation set: one {"question", "answer"} record per (question, answer) pair below.
dataset = [
    {"question": question_text, "answer": answer_text}
    for question_text, answer_text in (
        (
            "What is the recommended lubrication for the engine of the BSA D14/4 Bantam Supreme motorcycle?",
            "Engine lubrication: BSA recommends using a mixture of 10W-30 oil, with a minimum of 10W-40 oil, for the engine of the BSA D14/4 Bantam Supreme motorcycle.",
        ),
        (
            "Where should an inexperienced owner consult for assistance with major repair work?",
            "His B.S.A. dealer",
        ),
        (
            "What is the recommended procedure for claiming assistance under the B.S.A. guarantee?",
            "Claim assistance through the dealer from whom the motorcycle was purchased.",
        ),
        (
            "What is the correct address of the B.S.A. Service Department?",
            "B.S.A. MOTOR CYCLES LIMITED, SERVICE DEPARTMENT, ARMOURY ROAD, BIRMINGHAM 11",
        ),
        (
            "What is the recommended procedure for claiming assistance under the guarantee for a new motorcycle?",
            "The owner must do so through the dealer from whom the machine was purchased.",
        ),
        (
            "What is the recommended torque wrench setting for the Supreme model?",
            "1 to 3",
        ),
    )
]


# ========== CELL 4: Configure Model ==========
# Hub identifiers: the LoRA adapter and the base checkpoint it is applied to.
base_model_name = "google/gemma-2-2b"
adapter_name = "Prithwiraj731/Gemma2-2b_Two-Wheeler"

# One print call; sep="\n" yields the same three output lines.
print(
    "Model Configuration:",
    f"  Adapter: {adapter_name}",
    f"  Base Model: {base_model_name}",
    sep="\n",
)


# ========== CELL 5: Load Model ==========
print("\nLoading tokenizer...")
# The tokenizer is loaded from the adapter repository, not from the base model.
tokenizer = AutoTokenizer.from_pretrained(adapter_name)

# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading base model with 4-bit quantization...")
# 4-bit NF4 quantization with float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    dtype=torch.float16,
    quantization_config=bnb_config,
)

print("Loading LoRA adapter...")
# Wrap the quantized base model with the adapter weights, then switch to eval mode.
model = PeftModel.from_pretrained(base_model, adapter_name)
model.eval()

print("Model loaded successfully\n")


# ========== CELL 6: Define Answer Generation with Timing ==========
def generate_answer_with_timing(question, max_new_tokens=100):
    """Generate an answer for *question* and time each pipeline stage.

    Relies on the module-level ``tokenizer`` and ``model``.

    Generation is greedy (``do_sample=False``). No ``temperature`` is passed:
    it is a sampling-only setting, so combining it with ``do_sample=False``
    has no effect on the output and only triggers a transformers warning.

    The answer is taken from the generated token ids alone (everything after
    the prompt tokens) instead of slicing the decoded text by
    ``len(prompt)``, which breaks whenever decoding does not reproduce the
    prompt character for character.

    Args:
        question: Question text; a trailing newline is appended as the prompt.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        A dict with:
            answer: First line of the generated text, or
                "No answer generated" when that line is empty.
            input_tokens: Number of prompt tokens.
            output_tokens: Number of newly generated tokens.
            tokenization_ms: Time to tokenize the prompt and move the
                tensors to ``model.device``.
            inference_ms: Time spent in ``model.generate`` (CUDA-synchronized
                on both sides when CUDA is available).
            decoding_ms: Time to decode the generated tokens to text.
            total_ms: Sum of the three stage times.
            tokens_per_sec: ``output_tokens`` divided by inference seconds
                (0 if the measured inference time is not positive).
    """
    prompt = f"{question}\n"
    use_cuda = torch.cuda.is_available()

    # Tokenization time (includes the transfer to the model's device)
    tok_start = time.perf_counter()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    tok_time = time.perf_counter() - tok_start

    input_tokens = inputs['input_ids'].shape[1]

    # Inference time: synchronize so pending GPU work is not attributed to
    # (or missing from) the timed region.
    if use_cuda:
        torch.cuda.synchronize()
    inf_start = time.perf_counter()

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )

    if use_cuda:
        torch.cuda.synchronize()
    inf_time = time.perf_counter() - inf_start

    new_tokens = outputs.shape[1] - input_tokens

    # Decoding time: decode only the tokens produced after the prompt.
    dec_start = time.perf_counter()
    generated_text = tokenizer.decode(
        outputs[0][input_tokens:], skip_special_tokens=True
    )
    dec_time = time.perf_counter() - dec_start

    # Keep just the first line of the generated continuation.
    answer = generated_text.strip().split('\n')[0].strip()
    if not answer:
        answer = "No answer generated"

    total_time = tok_time + inf_time + dec_time
    tokens_per_sec = new_tokens / inf_time if inf_time > 0 else 0

    return {
        'answer': answer,
        'input_tokens': input_tokens,
        'output_tokens': new_tokens,
        'tokenization_ms': tok_time * 1000,
        'inference_ms': inf_time * 1000,
        'decoding_ms': dec_time * 1000,
        'total_ms': total_time * 1000,
        'tokens_per_sec': tokens_per_sec,
    }


# ========== CELL 7: Warmup Run ==========
# One untimed, discarded generation before the measured runs.
print("Warming up model...")
_ = generate_answer_with_timing(question="Test question")
print("Warmup complete\n")


# ========== CELL 8: Run Latency Tests ==========
print("="*70)
print("MEASURING INFERENCE LATENCY")
print("="*70)

# One record per question: the question text plus every timing field
# returned by generate_answer_with_timing.
results = []

for i, item in enumerate(dataset):
    question = item["question"]

    timing = generate_answer_with_timing(question)

    results.append({
        'question': question,
        **timing
    })

    print(f"\nQuestion {i+1}/{len(dataset)}")
    # Previews are capped at 60 characters; the ellipsis is appended only
    # when text was actually cut off, so short answers are not shown as
    # if they had been truncated.
    print(f"Q: {question[:60]}{'...' if len(question) > 60 else ''}")
    print(f"A: {timing['answer'][:60]}{'...' if len(timing['answer']) > 60 else ''}")
    print(f"Input tokens:  {timing['input_tokens']}")
    print(f"Output tokens: {timing['output_tokens']}")
    print(f"Tokenization:  {timing['tokenization_ms']:.2f} ms")
    print(f"Inference:     {timing['inference_ms']:.2f} ms")
    print(f"Decoding:      {timing['decoding_ms']:.2f} ms")
    print(f"Total:         {timing['total_ms']:.2f} ms")
    print(f"Throughput:    {timing['tokens_per_sec']:.2f} tokens/sec")


# ========== CELL 9: Display Summary Results ==========
print("\n" + "=" * 70)
print("TWO-WHEELER MODEL - INFERENCE LATENCY RESULTS")
print("=" * 70)

# Mean of each per-question metric, unpacked in the order of the key tuple.
avg_input, avg_output, avg_tok, avg_inf, avg_dec, avg_total, avg_tps = (
    np.mean([r[key] for r in results])
    for key in (
        'input_tokens',
        'output_tokens',
        'tokenization_ms',
        'inference_ms',
        'decoding_ms',
        'total_ms',
        'tokens_per_sec',
    )
)

# Spread of the inference-only latency across questions.
min_latency = min(r['inference_ms'] for r in results)
max_latency = max(r['inference_ms'] for r in results)
p50_latency, p90_latency, p99_latency = np.percentile(
    [r['inference_ms'] for r in results], [50, 90, 99]
)

print("\nLatency Statistics:")
print(f"  Avg Input Tokens:     {avg_input:.1f}")
print(f"  Avg Output Tokens:    {avg_output:.1f}")
print(f"  Avg Tokenization:     {avg_tok:.2f} ms")
print(f"  Avg Inference:        {avg_inf:.2f} ms")
print(f"  Avg Decoding:         {avg_dec:.2f} ms")
print(f"  Avg Total Latency:    {avg_total:.2f} ms")
print(f"  Avg Throughput:       {avg_tps:.2f} tokens/sec")

print("\nLatency Percentiles (Inference only):")
print(f"  Min:     {min_latency:.2f} ms")
print(f"  P50:     {p50_latency:.2f} ms")
print(f"  P90:     {p90_latency:.2f} ms")
print(f"  P99:     {p99_latency:.2f} ms")
print(f"  Max:     {max_latency:.2f} ms")


# ========== CELL 10: Results DataFrame ==========
# Per-question table. Questions longer than 35 characters are cut to 35 and
# suffixed with '...'; timing columns are pre-formatted as strings.
results_df = pd.DataFrame({
    'Question': [
        r['question'] if len(r['question']) <= 35 else r['question'][:35] + '...'
        for r in results
    ],
    'In Tok': [r['input_tokens'] for r in results],
    'Out Tok': [r['output_tokens'] for r in results],
    'Inference (ms)': [format(r['inference_ms'], '.2f') for r in results],
    'Total (ms)': [format(r['total_ms'], '.2f') for r in results],
    'Tok/sec': [format(r['tokens_per_sec'], '.1f') for r in results],
})

print("\n" + "=" * 70)
print("DETAILED RESULTS TABLE")
print("=" * 70)
display(results_df)


# ========== CELL 11: Save Results (Optional) ==========
# Uncomment to save and download results

# results_df.to_csv('latency_2wheeler_results.csv', index=False)
# print("\nResults saved to 'latency_2wheeler_results.csv'")

# from google.colab import files
# files.download('latency_2wheeler_results.csv')

