# Notebook 3: Inference with the Fine-Tuned Model

Now for the exciting part: using our specialized model. This notebook shows how to load the base Gemma model and apply our trained LoRA adapters to it for inference.

In [None]:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

### Step 1: Load the Base Model and Tokenizer

In [None]:
base_model_name = "google/gemma-2b-it"

# Load the base model unquantized, in bfloat16. No quantization config is
# passed, so the weights are kept in a regular floating-point dtype; the LoRA
# adapters are merged into these weights in Step 2.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",  # let the library decide where to place the weights
    torch_dtype=torch.bfloat16,
)

# The tokenizer comes from the same checkpoint as the base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

### Step 2: Load and Merge the LoRA Adapters

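We use the `PeftModel` class to load our saved adapters and apply them to the base model, then merge the adapter weights into the base model so that the result is a single standalone model.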

In [None]:
adapter_path = "./gemma-pandas-expert-adapters"

# Wrap the base model with the saved LoRA adapters
peft_model = PeftModel.from_pretrained(base_model, adapter_path)

# Fold the adapter weights into the base model and drop the PEFT wrapper,
# leaving a standalone specialized model bound to `model`
model = peft_model.merge_and_unload()

print("Adapters merged successfully!")

### Step 3: Test the Specialized Model

In [None]:
def ask_expert(question: str, max_new_tokens: int = 256, temperature: float = 0.7):
    """Ask the fine-tuned model a question and print its answer.

    The question is wrapped in the same chat-turn prompt template that was
    used for training, a response is sampled from the module-level ``model``,
    and only the newly generated text (not the echoed prompt) is printed.

    Args:
        question: The natural-language question to ask.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        None. The answer is printed to stdout.
    """
    # Format the prompt using the same template as training
    prompt = f"<start_of_turn>user\nYou are a Pandas expert. Answer the following question.\n\nQuestion: {question}<end_of_turn>\n<start_of_turn>model\n"

    # Tokenize the input and move it to the device the model lives on
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate a response
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )

    # The generated sequence starts with the prompt tokens. Slice them off
    # by length instead of splitting the decoded text on "<start_of_turn>":
    # decoding with skip_special_tokens=True removes that marker, so a
    # string split on it would find nothing and fail.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = outputs[0][prompt_length:]

    # Decode only the model's answer and print it
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print(answer.strip())

# Sample questions: the first is similar to our training data,
# the second is a slightly different one
test_question = "How do I select rows from a DataFrame using a condition?"
test_question_2 = "How do I use `groupby` to get the average score per group in Pandas?"

# Ask the question that resembles the training data
print(f"--- Asking: {test_question} ---\n")
ask_expert(test_question)

# Ask the slightly different question
print(f"\n--- Asking: {test_question_2} ---\n")
ask_expert(test_question_2)