### Soft prompt tuning on LLaMA 8B Instruct (self-contained)

This notebook shows a minimal, end-to-end example of soft prompt tuning (Prompt Tuning) using PEFT on a LLaMA-family instruction-tuned model.

It will:
- Install dependencies in-notebook
- Load an 8B instruct checkpoint via `transformers`
- Configure PEFT Prompt Tuning (train only virtual tokens)
- Train briefly on a tiny inline dataset
- Run inference using the tuned soft prompts

Notes:
- You may need access approval for gated models. If so, set `MODEL_ID` to a model you can pull.
- 8B models require significant RAM/VRAM; adjust to a smaller chat model if needed.


In [14]:
pip install -qU transformers==4.55.2 peft accelerate datasets trl einops sentencepiece bitsandbytes


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [15]:
from huggingface_hub import login
login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [16]:
# Config and tiny inline dataset
from typing import List, Dict

MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"  # Change to a model you can access
OUTPUT_DIR = "./softprompt-llama8b"
PROMPT_TOKENS = 32
MICRO_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8
LEARNING_RATE = 5e-3
NUM_TRAIN_STEPS = 30  # keep short for demo
MAX_SEQ_LEN = 512

# Tiny toy dataset of (instruction, output)
examples: List[Dict[str, str]] = [
    {
        "instruction": "Write a short, funny haiku about databases.",
        "output": "Tables join in love\nIndex hearts beat in queries\nACID dreams commit",
    },
    {
        "instruction": "Explain soft prompt tuning in one sentence.",
        "output": "Soft prompt tuning trains only virtual tokens that steer the frozen model.",
    },
    {
        "instruction": "List three benefits of unit tests.",
        "output": "Confidence in changes; documentation of behavior; faster debugging.",
    },
]


In [17]:
# Load tokenizer and model (8B instruct)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if bf16 else torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
model.config.use_cache = False
print("Loaded:", MODEL_ID)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded: Qwen/Qwen3-4B-Instruct-2507


In [18]:
# Configure PEFT Prompt Tuning
from peft import PromptTuningConfig, PromptTuningInit, get_peft_model, TaskType

peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=PROMPT_TOKENS,
    prompt_tuning_init_text="You are a helpful, concise assistant.",
    tokenizer_name_or_path=MODEL_ID,
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


trainable params: 81,920 || all params: 4,022,550,016 || trainable%: 0.0020


In [19]:
# Preprocess examples for supervised fine-tuning format
from datasets import Dataset

# Convert to chat-style prompt => completion text
def to_chat(example):
    system = "You are a helpful assistant."
    user = example["instruction"]
    assistant = example["output"]
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
        {"role": "assistant", "content": assistant},
    ]
    if hasattr(tokenizer, "apply_chat_template"):
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    else:
        # Fallback simple format
        text = f"<s>[SYSTEM]\n{system}\n[/SYSTEM]\n[USER]\n{user}\n[/USER]\n[ASSISTANT]\n{assistant}</s>"
    return {"text": text}

dataset = Dataset.from_list(examples).map(to_chat)

# Tokenize
def tokenize(batch):
    out = tokenizer(
        batch["text"],
        max_length=MAX_SEQ_LEN,
        truncation=True,
        padding="max_length",
        return_attention_mask=True,
    )
    out["labels"] = out["input_ids"].copy()
    return out

train_ds = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
train_ds


Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 3
})

In [20]:
# Trainer setup and brief training
import math
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW

train_loader = DataLoader(
    train_ds,
    batch_size=MICRO_BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda batch: {
        k: torch.tensor([b[k] for b in batch]) for k in batch[0].keys()
    },
)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
num_update_steps = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=max(1, num_update_steps // 10), num_training_steps=num_update_steps
)

model.train()
model = model.to(next(model.parameters()).device)

step = 0
optimizer.zero_grad()
for epoch in range(1):
    for batch in train_loader:
        batch = {k: v.to(model.device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss / GRAD_ACCUM_STEPS
        loss.backward()
        if (step + 1) % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        if step % 10 == 0:
            print(f"step {step} loss {loss.item()*GRAD_ACCUM_STEPS:.4f}")
        step += 1
        if step >= NUM_TRAIN_STEPS:
            break
    if step >= NUM_TRAIN_STEPS:
        break

model.save_pretrained(OUTPUT_DIR)
print("Saved prompt adapter to:", OUTPUT_DIR)


step 0 loss 22.9054
Saved prompt adapter to: ./softprompt-llama8b


In [21]:
# Inference with tuned soft prompts
from peft import PeftModel
from transformers import TextStreamer

# Reload base + adapter
base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if bf16 else torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
base = PeftModel.from_pretrained(base, OUTPUT_DIR)
base.eval()

streamer = TextStreamer(tokenizer, skip_special_tokens=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short, funny haiku about databases."},
]
if hasattr(tokenizer, "apply_chat_template"):
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
    prompt = "<s>[SYSTEM]\nYou are a helpful assistant.\n[/SYSTEM]\n[USER]\nWrite a short, funny haiku about databases.\n[/USER]\n[ASSISTANT]\n"

inputs = tokenizer(prompt, return_tensors="pt").to(base.device)
with torch.no_grad():
    _ = base.generate(
        **inputs,
        max_new_tokens=64,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
    )


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

system
You are a helpful assistant.
user
Write a short, funny haiku about databases.
assistant




Tables full of data,  
Joins cause chaos, indexes sigh—  
SQL says, "Just wait!" 😅


# 