# Lab 2 – Base Model Loading and Inference
**Part 2 of the 7-Lab Hands-On SLM Training Series**

This lab builds on the concepts from the post **How to Train Your Dragon: Customize a Small Language Model for Your Domain**:
https://www.linkedin.com/pulse/how-train-your-dragon-customize-small-language-model-domain-patel-cvzjc

In Lab 1, we established a clean, reproducible environment. In Lab 2, we will:
- Load a high-quality Small Language Model (SLM) in 4-bit to fit common GPUs.
- Run initial prompts to understand the base model’s behavior.
- Explore generation parameters (temperature, top_p, max_new_tokens, repetition_penalty).
- Log results for later comparison after fine-tuning.

> Tip: If loading the default model fails with a CUDA out-of-memory (OOM) error, switch to the TinyLlama fallback in Section 3.


## 1. Runtime and GPU Check
Run the cell below to confirm your runtime has a GPU. On Google Colab, choose **Runtime → Change runtime type → T4 or L4 GPU**.


In [None]:
#@title Check GPU
import os, sys, subprocess

def _run(cmd):
    try:
        out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True, text=True, timeout=30)
        print(out)
    except Exception as e:
        print(f"Command failed: {cmd}\n{e}")

print("Python:", sys.version)
print("CUDA visible devices:", os.environ.get("CUDA_VISIBLE_DEVICES"))

print("\n=== nvidia-smi ===")
_run("nvidia-smi || echo 'No NVIDIA GPU found'")

try:
    import torch
    print("\nPyTorch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("CUDA device:", torch.cuda.get_device_name(0))
        print("Total VRAM (GB):", round(torch.cuda.get_device_properties(0).total_memory/1e9, 2))
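except Exception as e:
    # The install cell below does not install PyTorch; Colab ships with a CUDA build preinstalled
    print("PyTorch is not importable. On Colab it is preinstalled; elsewhere, install a CUDA-enabled build before continuing.\n", e)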
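For a machine-readable version of the same check, `nvidia-smi` can report only the GPU name and total memory with `--query-gpu=name,memory.total --format=csv,noheader,nounits` (memory is then a bare number in MiB). The sketch below parses that output; the `parse_gpu_query` helper and the sample line are illustrative assumptions, not part of the lab.

```python
from typing import List, Tuple

def parse_gpu_query(output: str) -> List[Tuple[str, float]]:
    """Parse `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits`.

    Each non-empty line looks like "<name>, <total MiB>". Returns (name, total GiB) pairs.
    """
    gpus = []
    for line in output.strip().splitlines():
        if not line.strip():
            continue
        # Split on the last comma so the memory value is always the final field
        name, mem_mib = (part.strip() for part in line.rsplit(",", 1))
        gpus.append((name, round(float(mem_mib) / 1024, 2)))
    return gpus

# Hypothetical sample output for a single GPU. In a notebook you could capture the real output with:
# subprocess.check_output(["nvidia-smi", "--query-gpu=name,memory.total",
#                          "--format=csv,noheader,nounits"], text=True)
sample = "Tesla T4, 15360\n"
print(parse_gpu_query(sample))
```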


## 2. Install Required Libraries
We use:
- `transformers` for model and tokenizer
- `accelerate` for device placement and mixed precision
- `bitsandbytes` for 4-bit quantization
- `sentencepiece` for tokenizers used by Mistral/LLaMA

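Rerun this cell if an install times out. If the notebook later reports a version mismatch for a package that was already imported, restart the runtime and run the cells again from the top.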


In [None]:
#@title Install libraries
!pip -q install --upgrade transformers accelerate bitsandbytes sentencepiece > /dev/null

import importlib
mods = ["transformers", "accelerate", "bitsandbytes", "sentencepiece"]
for m in mods:
    try:
        importlib.import_module(m)
        print(f"OK: {m}")
    except Exception as e:
        print(f"Missing: {m}", e)
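The import check above confirms that each module loads. Recording the exact installed versions as well makes this baseline easier to reproduce when you compare it against fine-tuned runs later. A minimal sketch using the standard-library `importlib.metadata`; the `installed_versions` helper name is hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for pkg in packages:
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None
    return versions

for name, version in installed_versions(
    ["transformers", "accelerate", "bitsandbytes", "sentencepiece"]
).items():
    print(f"{name}: {version or 'NOT INSTALLED'}")
```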


## 3. Choose a Base Model
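Default is **mistralai/Mistral-7B-Instruct-v0.2**, a 7B-parameter instruction-tuned model that fits on a 16 GB T4 once loaded in 4-bit. If the download fails with an authorization error, accept the model's terms on its Hugging Face page and log in with an access token (for example, `huggingface_hub.login()`).  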
If you hit memory errors, switch to the fallback **TinyLlama/TinyLlama-1.1B-Chat-v1.0**.
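If you prefer not to switch models by hand, the choice can be made from the detected VRAM. This is a sketch under an assumption: the 8 GB cut-off in the hypothetical `pick_base_model` helper is a rough guess for a 4-bit 7B model plus generation headroom, not a measured requirement.

```python
from typing import Optional

DEFAULT_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
FALLBACK_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

def pick_base_model(total_vram_gb: Optional[float], min_vram_gb: float = 8.0) -> str:
    """Return the default model if the GPU has at least `min_vram_gb`, else the fallback.

    `min_vram_gb` is an assumed threshold, not an official requirement.
    """
    if total_vram_gb is None or total_vram_gb < min_vram_gb:
        return FALLBACK_MODEL
    return DEFAULT_MODEL

print(pick_base_model(15.0))  # T4-class GPU
print(pick_base_model(6.0))   # smaller GPU
print(pick_base_model(None))  # no GPU detected
```

In the notebook you could pass `torch.cuda.get_device_properties(0).total_memory / 1e9`, or `None` when `torch.cuda.is_available()` is `False`.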


In [None]:
#@title Select your base model
#@markdown Recommended default: mistralai/Mistral-7B-Instruct-v0.2
#@markdown Fallback if you see CUDA OOM: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_name = "mistralai/Mistral-7B-Instruct-v0.2"  #@param ["mistralai/Mistral-7B-Instruct-v0.2", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
print("Using model:", model_name)


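## 4. Load the Model in 4-bit (bitsandbytes NF4)
Storing weights in 4 bits takes roughly a quarter of the memory of 16-bit weights, which lets a 7B model fit on a Colab or consumer GPU. Computation still runs in `float16` (`bnb_4bit_compute_dtype`), and double quantization also compresses the quantization constants for a small extra saving.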
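The "roughly a quarter" figure follows from simple arithmetic: weight memory is about parameters × bits per parameter ÷ 8 bytes. The sketch below assumes a rounded count of 7.0 billion parameters and covers weights only; activations, the KV cache, quantization constants, and layers left in higher precision all add to real usage.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

n = 7.0e9  # rounded parameter count for a "7B" model (assumption)
for bits in (16, 8, 4):
    print(f"{bits:>2}-bit weights: ~{weight_memory_gb(n, bits):.1f} GB")
```

This gives about 14 GB at 16-bit, 7 GB at 8-bit, and 3.5 GB at 4-bit, which is why the 4-bit model leaves room for generation on a 16 GB T4.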


In [None]:
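#@title Load model and tokenizer (4-bit)
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# NF4 4-bit weights with fp16 compute (T4-class GPUs lack native bfloat16 support)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",              # let accelerate place layers on the available GPU(s)
    quantization_config=bnb_config,
    torch_dtype=torch.float16,      # dtype for the layers that are not quantized
)

print("Loaded:", model_name)
print(f"Approx. memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")

# Mistral and LLaMA-style tokenizers often define no pad token; reuse EOS so padding works
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token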


## 5. Inference Helper
The function below formats the prompt with the tokenizer's chat template when one exists (e.g., Mistral Instruct), applies the generation parameters, and returns only the newly generated text.  
Without a chat template, it falls back to a simple plain-text prompt.
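Some chat templates reject a separate `system` role (older Mistral Instruct templates, for example, require strictly alternating user/assistant turns). A common workaround is to fold the system instructions into the user message. The pure-Python sketch below shows only the message-building step; the `build_messages` name and the `supports_system_role` flag are hypothetical and would be decided by trying the template first.

```python
from typing import Dict, List, Optional

def build_messages(user_prompt: str, system_prompt: Optional[str] = None,
                   supports_system_role: bool = True) -> List[Dict[str, str]]:
    """Build a chat message list, merging the system prompt into the user turn
    when the template does not accept a separate system role."""
    if not system_prompt:
        return [{"role": "user", "content": user_prompt}]
    if supports_system_role:
        return [{"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}]
    # Prepend the instructions to the single user message instead
    return [{"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}]

print(build_messages("Define RAG.", "Be concise.", supports_system_role=False))
```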


In [None]:
#@title Define a generate_text helper
from typing import Optional

def apply_chat_template(tokenizer, user_prompt: str, system_prompt: Optional[str] = None) -> str:
    """Format a prompt with the tokenizer's chat template, or a plain-text fallback."""
    if getattr(tokenizer, "chat_template", None):
        messages = [{"role": "user", "content": user_prompt}]
        if system_prompt:
            messages.insert(0, {"role": "system", "content": system_prompt})
        try:
            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception:
            # Some templates reject a "system" role; fold it into the user turn instead
            if not system_prompt:
                raise
            merged = [{"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}]
            return tokenizer.apply_chat_template(merged, tokenize=False, add_generation_prompt=True)
    # Fallback for models without a chat template: plain-text prompt
    if system_prompt:
        return f"System: {system_prompt}\nUser: {user_prompt}\nAssistant:"
    return user_prompt

def generate_text(
    prompt: str,
    system: Optional[str] = None,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
) -> str:
    text = apply_chat_template(tokenizer, prompt, system_prompt=system)
    # Chat templates usually insert the BOS token themselves; avoid adding a second one
    add_special = not (tokenizer.bos_token and text.startswith(tokenizer.bos_token))
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=add_special).to(model.device)
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # temperature and top_p only take effect when sampling
        gen_kwargs.update(temperature=temperature, top_p=top_p)
    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)
    # Keep only the newly generated tokens (drop the echoed prompt)
    gen_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)


## 6. Quick Smoke Test
Run a few prompts to observe the model's out-of-the-box behavior. Outputs are sampled, so they will vary from run to run.


In [None]:
#@title Try a few prompts
system_prompt = "You are a concise, accurate assistant. Answer clearly and avoid speculation."
prompts = [
    "Summarize the core differences between SLMs and LLMs for an executive audience.",
    "Give three use cases where a domain-tuned SLM can outperform a general-purpose LLM.",
    "Rewrite this sentence in a more formal tone: 'We gotta cut costs but keep quality.'",
]

for i, p in enumerate(prompts, 1):
    print(f"--- Prompt {i} ---")
    out = generate_text(p, system=system_prompt, max_new_tokens=200, temperature=0.7, top_p=0.9)
    print(out.strip(), "\n")
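Because `generate_text` samples by default, rerunning a prompt gives different text. Fixing the random seed before each call (for example with `transformers.set_seed(42)`) makes baseline outputs repeatable, which helps when you compare them with the fine-tuned model. The toy sampler below illustrates the idea with Python's `random` module; it is an analogy for seeded sampling versus greedy decoding, not the model's actual decoding code.

```python
import random
from typing import List, Sequence

def sample_tokens(probs: Sequence[float], n: int, seed: int) -> List[int]:
    """Draw n token ids from a fixed distribution using a seeded RNG."""
    rng = random.Random(seed)
    token_ids = list(range(len(probs)))
    return rng.choices(token_ids, weights=probs, k=n)

def greedy_token(probs: Sequence[float]) -> int:
    """Greedy decoding (do_sample=False) always picks the highest-probability token."""
    return max(range(len(probs)), key=lambda i: probs[i])

probs = [0.5, 0.3, 0.15, 0.05]  # hypothetical next-token distribution
print(sample_tokens(probs, 8, seed=42))
print(sample_tokens(probs, 8, seed=42))  # same seed, same draws
print(greedy_token(probs))
```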


## 7. Parameter Exploration
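Experiment with generation settings to see how they affect style and determinism: a lower `temperature` makes output more focused and repeatable, while higher values (especially with a larger `top_p`) make it more varied.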
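Temperature divides the logits before the softmax, and top-p (nucleus) sampling keeps the smallest set of highest-probability tokens whose cumulative probability reaches `top_p`, then renormalizes. The pure-Python sketch below applies both to a made-up four-token distribution; it mirrors the idea rather than the exact Hugging Face implementation.

```python
import math
from typing import List

def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_filter(probs: List[float], top_p: float) -> List[float]:
    """Zero out tokens outside the nucleus and renormalize the rest."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    kept_total = sum(probs[i] for i in keep)
    return [probs[i] / kept_total if i in keep else 0.0 for i in range(len(probs))]

logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for four candidate tokens
for t in (0.2, 0.7, 1.0):
    p = softmax_with_temperature(logits, t)
    print(f"T={t}: " + ", ".join(f"{x:.3f}" for x in p))
print("top_p=0.9:", [round(x, 3) for x in top_p_filter(softmax_with_temperature(logits, 1.0), 0.9)])
```

At `T=0.2` almost all probability mass lands on the top token, which is why low-temperature runs in the sweep look nearly deterministic.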


In [None]:
#@title Parameter sweep
base_prompt = "Explain retrieval-augmented generation in two short paragraphs."
sweep = [
    dict(temperature=0.2, top_p=0.9, repetition_penalty=1.05),
    dict(temperature=0.7, top_p=0.9, repetition_penalty=1.05),
    dict(temperature=1.0, top_p=0.95, repetition_penalty=1.05),
]

for cfg in sweep:
    print(f"\n### Settings: {cfg}")
    txt = generate_text(
        base_prompt,
        system="You are an expert technical writer.",
        max_new_tokens=220,
        temperature=cfg["temperature"],
        top_p=cfg["top_p"],
        repetition_penalty=cfg["repetition_penalty"],
    )
    print(txt.strip())
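The sweep above holds `repetition_penalty` fixed. In Hugging Face `transformers`, the penalty is applied to the logits of every token that already appears in the sequence (prompt included): positive logits are divided by the penalty and negative logits are multiplied by it, so values above 1.0 make repeats less likely. The sketch below reproduces that rule on plain lists; the token ids and logits are made up.

```python
from typing import List, Sequence

def apply_repetition_penalty(logits: List[float], previous_ids: Sequence[int],
                             penalty: float) -> List[float]:
    """Penalize each token id that already occurred (once per id, however often it appeared)."""
    out = list(logits)
    for tok in set(previous_ids):
        # Shrink positive logits and push negative ones further down
        out[tok] = out[tok] * penalty if out[tok] < 0 else out[tok] / penalty
    return out

logits = [2.0, 1.0, -0.5, 0.3]  # hypothetical next-token logits
seen = [0, 2, 0]                # ids 0 and 2 already appear in the sequence
print(apply_repetition_penalty(logits, seen, penalty=1.1))
```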


## 8. Structured Logging
Use this template to log prompts, parameters, and outputs for later comparison (e.g., in Lab 6).


In [None]:
#@title Log results to CSV
import pandas as pd
from pathlib import Path

log_path = Path("/content") if Path("/content").exists() else Path("/mnt/data")
log_path.mkdir(parents=True, exist_ok=True)
csv_file = log_path / "lab2_inference_log.csv"

# Define the run once so the logged parameters always match the generation call
system = "You are an expert technical writer."
prompt = "Explain retrieval-augmented generation in two short paragraphs."
params = dict(temperature=0.7, top_p=0.9, max_new_tokens=220, repetition_penalty=1.1)

records = [{
    "timestamp": pd.Timestamp.now(tz="UTC").isoformat(),
    "model_name": model_name,
    "system_prompt": system,
    "prompt": prompt,
    **params,
    "output": generate_text(prompt, system=system, **params),
}]

df = pd.DataFrame.from_records(records)
if csv_file.exists():
    df_existing = pd.read_csv(csv_file)
    df = pd.concat([df_existing, df], ignore_index=True)

df.to_csv(csv_file, index=False)
print("Logged results to:", csv_file)
df.tail(1)
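Re-reading the whole CSV on every run is fine for a handful of rows. If you log many runs, appending rows directly avoids that. A minimal sketch with the standard-library `csv` module, assuming a fixed column list; the `append_log_row` helper and the `FIELDS` list are hypothetical, and the demo writes to a temporary directory instead of the lab's log path.

```python
import csv
import tempfile
from pathlib import Path
from typing import Dict

FIELDS = ["timestamp", "model_name", "system_prompt", "prompt",
          "temperature", "top_p", "max_new_tokens", "repetition_penalty", "output"]

def append_log_row(csv_file: Path, row: Dict[str, object]) -> None:
    """Append one record, writing the header only when the file is new or empty."""
    write_header = not csv_file.exists() or csv_file.stat().st_size == 0
    with csv_file.open("a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)  # missing keys are written as ""
        if write_header:
            writer.writeheader()
        writer.writerow(row)

with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo_log.csv"
    for temp in (0.2, 0.7):
        append_log_row(demo_path, {"model_name": "demo", "prompt": "hi",
                                   "temperature": temp, "output": "..."})
    with demo_path.open(newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))

print(len(rows), "rows;", [r["temperature"] for r in rows])
```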


## 9. Next Steps
- Capture a small suite of representative prompts from your domain and log baseline outputs.
- Identify weaknesses you expect fine-tuning to address (terminology use, formatting, factual grounding).
- Save your `lab2_inference_log.csv` for use in Lab 6 (evaluation and comparison).
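One lightweight way to keep a domain prompt suite stable across labs is a JSON Lines file with one prompt per line. The sketch assumes a hypothetical `domain_prompts.jsonl` file with `id` and `prompt` fields and reuses prompts from this lab as placeholders; adapt the names and content to your domain.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List

def save_prompt_suite(path: Path, prompts: List[Dict[str, str]]) -> None:
    """Write one JSON object per line (JSON Lines)."""
    with path.open("w", encoding="utf-8") as f:
        for item in prompts:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

def load_prompt_suite(path: Path) -> List[Dict[str, str]]:
    """Read a JSON Lines prompt suite, skipping blank lines."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Placeholder suite reusing prompts from this lab; replace with your own domain prompts
suite = [
    {"id": "exec-summary", "prompt": "Summarize the core differences between SLMs and LLMs for an executive audience."},
    {"id": "formal-rewrite", "prompt": "Rewrite this sentence in a more formal tone: 'We gotta cut costs but keep quality.'"},
]

with tempfile.TemporaryDirectory() as tmp:
    suite_path = Path(tmp) / "domain_prompts.jsonl"
    save_prompt_suite(suite_path, suite)
    loaded = load_prompt_suite(suite_path)

print(len(loaded), "prompts; first id:", loaded[0]["id"])
```

In the notebook, you would keep the file next to `lab2_inference_log.csv` and loop over `load_prompt_suite(...)`, calling `generate_text(item["prompt"], ...)` and logging each result.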

When ready, proceed to **Lab 3 – Data Loading and Tokenization**.
