# Compare Local (Gemma 2 2B) vs NIM (Llama 3 8B)

Runs the **same queries** through both pipelines and compares the results: latency and response-length plots, plus side-by-side answers.

**Requirements:** a GPU runtime, `HF_TOKEN` and `NIM_BASE_URL` stored as Colab Secrets, and a NIM endpoint running on GKE.

<a href="https://colab.research.google.com/github/KarthikSriramGit/Project-Insight/blob/main/notebooks/04_compare_local_vs_nim.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab setup: when running on Colab, clone the repo, cd into it and install
# its requirements. Outside Colab the cell is a no-op.
try:
    import google.colab  # noqa: F401  (imported only to detect the Colab runtime)
except ImportError:
    # Not on Colab: nothing to clone or install from this cell.
    pass
else:
    import os

    _ipy = get_ipython()
    # Re-running the cell from inside the checkout must not clone a nested copy.
    if os.path.basename(os.getcwd()) != "Project-Insight":
        if not os.path.isdir("Project-Insight"):
            _ipy.system("git clone -q https://github.com/KarthikSriramGit/Project-Insight.git")
        _ipy.run_line_magic("cd", "Project-Insight")
    _ipy.system("pip install -q -r requirements.txt")

/content/Project-Insight


In [2]:
import os, sys, time, subprocess
from pathlib import Path

# Treat the current working directory as the repository root and make it
# importable so the `src.*` packages resolve.
ROOT = Path(".").resolve()
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Generate the synthetic telemetry only when the parquet file is missing.
data_path = ROOT / "data" / "synthetic" / "fleet_telemetry.parquet"
if not data_path.exists():
    subprocess.run(
        [
            # Use the kernel's own interpreter: a bare "python" may resolve to a
            # different environment (or not exist) and miss installed packages.
            sys.executable,
            "data/synthetic/generate_telemetry.py",
            "--rows", "100000",
            "--output-dir", "data/synthetic",
            "--format", "parquet",
        ],
        check=True,  # surface generator failures instead of continuing without data
        cwd=str(ROOT),
    )

def _colab_secret(name):
    """Return the Colab secret `name`, or None when it cannot be read.

    Returns None when `google.colab` is not importable (not running on Colab)
    or when the lookup itself raises (for example a missing secret), so one
    unavailable secret never discards another that was read successfully.
    """
    try:
        from google.colab import userdata
    except ImportError:
        return None
    try:
        return userdata.get(name)
    except Exception as exc:  # best-effort lookup: fall back to the environment
        print(f"Colab secret {name!r} unavailable ({type(exc).__name__}); using environment instead.")
        return None


# A Colab secret wins; otherwise use the environment variable, then the placeholder URL.
NIM_BASE_URL = _colab_secret("NIM_BASE_URL") or os.environ.get("NIM_BASE_URL", "http://YOUR_IP:8000")

# Export HF_TOKEN only when a secret value exists: os.environ rejects None, and
# an HF_TOKEN already present in the environment must be left untouched.
_hf_token = _colab_secret("HF_TOKEN")
if _hf_token:
    os.environ["HF_TOKEN"] = _hf_token

from src.query.engine import TelemetryQueryEngine
from src.query.query_config import QUERY_CONFIG
from src.query.prompts import SYSTEM_PROMPT, format_user_query

# Build the query engine over the telemetry file, pointed at the NIM endpoint.
engine = TelemetryQueryEngine(
    max_context_rows=500,
    nim_base_url=NIM_BASE_URL,
    data_path=str(data_path),
)
print("Setup ready.")

Setup ready.


## 1. Run local model (Gemma 2 2B)

In [None]:
import warnings
warnings.filterwarnings("ignore", message=".*torch_dtype.*")

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from src.inference.pipeline import InferencePipeline

# Load the local model and tokenizer, then wrap them in the inference pipeline.
model_id = "google/gemma-2-2b-it"

# Prefer the GPU when one is visible; half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    device = "cuda"
    weights_dtype = torch.float16
else:
    device = "cpu"
    weights_dtype = torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=weights_dtype).to(device)
pipe = InferencePipeline(model, tokenizer, device=device, max_new_tokens=256)

def clean_resp(txt: str, prompt: str = "") -> str:
    """Strip the echoed prompt and chat-template markers from generated text.

    Args:
        txt: Raw text returned by the generation pipeline.
        prompt: The prompt that was sent; removed when `txt` starts with it.

    Returns:
        The text with the leading prompt, the turn/BOS/EOS marker tokens and
        stand-alone role-tag lines ("user" / "model") removed, then stripped
        of surrounding whitespace.
    """
    import re  # local import keeps the helper self-contained in this cell

    if prompt and txt.startswith(prompt):
        txt = txt[len(prompt):]
    for marker in ("<start_of_turn>", "<end_of_turn>", "<bos>", "<eos>"):
        txt = txt.replace(marker, "")
    # Remove role tags only when they occupy the start of a line. A plain
    # replace of "model\n" / "user\n" would also delete those words from real
    # answer text whenever a line happens to end with "model" or "user".
    txt = re.sub(r"(?m)^(?:model|user)\n", "", txt)
    return txt.strip()

def _local_context(cfg):
    """Return the context text for one QUERY_CONFIG entry."""
    if cfg.get("skip_data"):
        return "No telemetry — general knowledge question."
    frame = engine.retrieve(
        vehicle_ids=cfg.get("vehicle_ids"),
        sensor_type=cfg.get("sensor_type"),
        brake_threshold=cfg.get("brake_threshold"),
    )
    return engine._data_to_context(frame)


# Run every configured query through the local pipeline; only the generate
# call sits inside the timed region.
results_02 = []
for cfg in QUERY_CONFIG:
    question = cfg["query"]
    context = _local_context(cfg)
    body = f"{SYSTEM_PROMPT}\n\n{format_user_query(question, context)}"
    chat = f"<start_of_turn>user\n{body}<end_of_turn>\n<start_of_turn>model\n"

    started = time.perf_counter()
    generated = pipe.generate([chat], max_new_tokens=256)
    elapsed = time.perf_counter() - started

    answer = clean_resp(generated[0], chat)
    results_02.append(
        {
            "label": cfg["label"],
            "query": question,
            "answer": answer,
            "latency_s": elapsed,
            "response_chars": len(answer),
        }
    )
    print(f"[Local] {cfg['label']}: {elapsed:.2f}s")

local_total = sum(entry["latency_s"] for entry in results_02)
print(f"Local total: {local_total:.2f}s")

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/288 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

## 2. Run NIM (Llama 3 8B on GKE)

In [None]:
from src.deploy.nim_client import NIMClient

# Run every configured query against the NIM endpoint.
nim_client = NIMClient(base_url=NIM_BASE_URL, max_tokens=256)
results_03 = []

for cfg in QUERY_CONFIG:
    q = cfg["query"]
    # Build the context BEFORE starting the clock. The local cell times only the
    # generate call, so retrieval and context formatting must stay outside the
    # timed region here as well for the two latencies to be comparable.
    if cfg.get("skip_data"):
        ctx = "No telemetry — general knowledge question."
    else:
        df = engine.retrieve(
            vehicle_ids=cfg.get("vehicle_ids"),
            sensor_type=cfg.get("sensor_type"),
            brake_threshold=cfg.get("brake_threshold"),
        )
        # Same private helper the local cell uses, so both models see the same context.
        ctx = engine._data_to_context(df)
    user_msg = format_user_query(q, ctx)

    t0 = time.perf_counter()
    ans = nim_client.ask(user_msg, system_context=SYSTEM_PROMPT)
    lat = time.perf_counter() - t0

    results_03.append({"label": cfg["label"], "query": q, "answer": ans, "latency_s": lat, "response_chars": len(ans)})
    print(f"[NIM] {cfg['label']}: {lat:.2f}s")
print(f"NIM total: {sum(r['latency_s'] for r in results_03):.2f}s")

## 3. Comparison plots

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def _paired_bars(ax, local_vals, nim_vals, local_name, nim_name, ylabel, title):
    """Draw per-query local/NIM bars side by side on `ax`."""
    ax.bar(x - w/2, local_vals, w, label=local_name, color="steelblue", alpha=0.8)
    ax.bar(x + w/2, nim_vals, w, label=nim_name, color="coral", alpha=0.8)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.legend()


def _summary_bars(ax, local_val, nim_val, ylabel, title):
    """Draw one local bar and one NIM bar summarising a single metric on `ax`."""
    ax.bar(["Local", "NIM"], [local_val, nim_val], color=["steelblue", "coral"], alpha=0.8)
    ax.set_ylabel(ylabel)
    ax.set_title(title)


# Shared x positions and bar width for the per-query charts.
labels = [entry["label"] for entry in results_02]
x = np.arange(len(labels))
w = 0.35

# Per-query series pulled from both result lists.
lat_02 = [entry["latency_s"] for entry in results_02]
lat_03 = [entry["latency_s"] for entry in results_03]
ch_02 = [entry["response_chars"] for entry in results_02]
ch_03 = [entry["response_chars"] for entry in results_03]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
(ax1, ax2), (ax3, ax4) = axes

# Top row: per-query latency and response length.
_paired_bars(ax1, lat_02, lat_03, "Local (Gemma 2 2B)", "NIM (Llama 3 8B)", "Latency (s)", "Query latency: Local vs NIM")
_paired_bars(ax2, ch_02, ch_03, "Local", "NIM", "Chars", "Response length")

# Bottom row: total and average latency.
_summary_bars(ax3, sum(lat_02), sum(lat_03), "Total time (s)", "Total inference time")
_summary_bars(ax4, np.mean(lat_02), np.mean(lat_03), "Avg latency (s)", "Average latency per query")

plt.tight_layout()
plt.show()

## 5. Insights summary

In [None]:
# Aggregate latency and response length for both pipelines. Averages fall back
# to 0 for an empty result list instead of raising ZeroDivisionError.
total_02 = sum(r["latency_s"] for r in results_02)
total_03 = sum(r["latency_s"] for r in results_03)
avg_02 = total_02 / len(results_02) if results_02 else 0.0
avg_03 = total_03 / len(results_03) if results_03 else 0.0
ratio = total_03 / total_02 if total_02 > 0 else 0
chars_02 = sum(r["response_chars"] for r in results_02) / len(results_02) if results_02 else 0.0
chars_03 = sum(r["response_chars"] for r in results_03) / len(results_03) if results_03 else 0.0

print("Insights:")
print(f"  • Local (Gemma 2 2B) total: {total_02:.2f}s | avg/query: {avg_02:.2f}s")
print(f"  • NIM (Llama 3 8B) total:   {total_03:.2f}s | avg/query: {avg_03:.2f}s")
if total_02 > 0 and total_03 > 0:
    # Always report a factor >= 1: "0.50x faster" reads as slower, whereas the
    # intended meaning is "2.00x faster".
    if total_03 < total_02:
        factor, direction = total_02 / total_03, "faster"
    else:
        factor, direction = total_03 / total_02, "slower"
    print(f"  • NIM is {factor:.2f}x {direction} than local (total time)")
else:
    print("  • NIM vs local speed: n/a (no timing data)")
print(f"  • Local avg response length: {chars_02:.0f} chars")
print(f"  • NIM avg response length:   {chars_03:.0f} chars")

## 4. Side-by-side answers

In [None]:
def _preview(text, limit=300):
    """Return `text` cut to `limit` characters, with '...' appended when truncated."""
    return text if len(text) <= limit else text[:limit] + "..."


# Pair the results directly rather than indexing through `labels`, a global
# that only exists after the plotting cell has run.
for local_res, nim_res in zip(results_02, results_03):
    print("=" * 70)
    print(f"Query: {local_res['label']}")
    print("-" * 70)
    print(f"Local: {_preview(local_res['answer'])}")
    print(f"NIM:   {_preview(nim_res['answer'])}")
    print()