### Output-only style tuning with soft prompts (self-contained)

This notebook fine-tunes style by training ONLY on assistant outputs (no instructions), using PEFT Prompt Tuning on an instruct model.

It will:
- Install dependencies in-notebook
- Load a chat checkpoint via `transformers`
- Configure Prompt Tuning (learn virtual tokens only)
- Train on an outputs-only dataset to steer style
- Run inference on a normal user prompt

Notes:
- Adjust `MODEL_ID` to a model you can pull.
- Large models need significant VRAM; pick a smaller one if needed.
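
As a rough guide to the VRAM note above, weight memory can be estimated as parameter count times bytes per parameter. The sketch below is a back-of-envelope estimate only: `weight_memory_gib` is a hypothetical helper, and the parameter count is the "all params" figure this notebook prints later for the 8B checkpoint. Activations, the CUDA context, and optimizer state come on top.

```python
# Back-of-envelope weight memory: parameters x bytes per parameter.
# This covers weights only, not activations or optimizer state.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}

def weight_memory_gib(num_params: int, dtype: str = "bfloat16") -> float:
    """Approximate memory needed to hold the weights, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3

# 8,030,523,392 is the total parameter count reported for the 8B model.
print(f"{weight_memory_gib(8_030_523_392):.1f} GiB in bfloat16")
```

For the 8B checkpoint this comes to just under 15 GiB in half precision, so every additional copy of the model kept alive in the kernel costs about that much again.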


In [30]:
%pip install -qU transformers==4.55.2 peft accelerate datasets trl einops sentencepiece bitsandbytes "jinja2>=3.1.0"


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [31]:
# Config and outputs-only dataset
from typing import List
import json
import re

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # Change to a model you can access
OUTPUT_DIR = "./softprompt-style-outputs"
PROMPT_TOKENS = 64
MICRO_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 1
LEARNING_RATE = 5e-3
NUM_TRAIN_STEPS = 800  # longer for style steering
MAX_SEQ_LEN = 256

# Load sampled Reddit posts from JSON created by sample-posts.py
# Each item is a dict with keys: title, subreddit, self_text
with open("./data/post-sample.json", "r", encoding="utf-8") as f:
    reddit_posts: List[dict] = json.load(f)

# Basic cleaners to remove UI boilerplate and artifacts that leak into generations
BOILERPLATE_PATTERNS = [
    r"^\s*View\s+More\s+Posts\s*$",
    r"^\s*View\s+Post\s*$",
    r"^\s*Help\??\s*$",
    r"^\s*Edit:\s*.*$",
]
boilerplate_regexes = [re.compile(p, flags=re.IGNORECASE) for p in BOILERPLATE_PATTERNS]

def clean_text(text: str) -> str:
    if not text:
        return ""
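    # Normalize line endings, then collapse runs of spaces/tabs within each line.
    # Newlines are preserved so the line-based boilerplate filter below can match.
    text = re.sub(r"\r\n?", "\n", text)
    text = re.sub(r"[^\S\n]+", " ", text).strip()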
    # Remove boilerplate lines
    lines = [ln.strip() for ln in text.split("\n")]
    kept = []
    for ln in lines:
        if any(rx.match(ln) for rx in boilerplate_regexes):
            continue
        kept.append(ln)
    return "\n".join(kept).strip()

# Convert to outputs-only strings in a strict template order
style_outputs: List[str] = []
for p in reddit_posts:
    title = clean_text(p.get("title", ""))
    self_text = clean_text(p.get("self_text", ""))
    subreddit = clean_text(p.get("subreddit", ""))
    # Normalize the subreddit prefix: '/r/AskReddit' or ' r/AskReddit' becomes 'r/AskReddit'
    subreddit = re.sub(r"\s*(/)?r/", "r/", subreddit)
    style_outputs.append(
        f"title: {title}\nself_text: {self_text}\nsubreddit: {subreddit}"
    )
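
Because every training string follows the same three-field template, the same template can be used to check generations later. The sketch below is illustrative only: `format_post`, `parse_post`, and `POST_RE` are hypothetical names that are not used elsewhere in this notebook, and the parser assumes the fields appear exactly once and in template order.

```python
import re

def format_post(title: str, self_text: str, subreddit: str) -> str:
    """Render one post in the training template (same layout as style_outputs)."""
    return f"title: {title}\nself_text: {self_text}\nsubreddit: {subreddit}"

# One named group per field, in the fixed template order.
# DOTALL lets a field span several lines.
POST_RE = re.compile(
    r"^title: ?(?P<title>.*)\nself_text: ?(?P<self_text>.*)\nsubreddit: ?(?P<subreddit>.*)$",
    flags=re.DOTALL,
)

def parse_post(text: str):
    """Return a dict of the three fields, or None if the text is off-template."""
    m = POST_RE.match(text.strip())
    return m.groupdict() if m else None

sample = format_post("Hello", "Some body text", "r/test")
print(parse_post(sample))
```

A generation that `parse_post` rejects is a quick signal that the soft prompt has not learned the format.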


In [32]:

# Load tokenizer and model (8B instruct)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if bf16 else torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
    attn_implementation="eager",
)
model.config.use_cache = False
print("Loaded:", MODEL_ID)


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 59.91it/s]


Loaded: meta-llama/Meta-Llama-3-8B-Instruct


In [33]:
# Configure PEFT Prompt Tuning
from peft import PromptTuningConfig, PromptTuningInit, get_peft_model, TaskType

peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=PROMPT_TOKENS,
    prompt_tuning_init_text="You are a helpful, concise assistant.",
    tokenizer_name_or_path=MODEL_ID,
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


trainable params: 262,144 || all params: 8,030,523,392 || trainable%: 0.0033
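
The trainable count above can be checked by hand: prompt tuning learns one embedding vector per virtual token. The sketch assumes a hidden size of 4096 for this checkpoint; the other two numbers come from the config and the printed summary.

```python
num_virtual_tokens = 64        # PROMPT_TOKENS in this notebook
hidden_size = 4096             # assumed embedding width of Llama-3-8B
total_params = 8_030_523_392   # "all params" from the summary above

# One learned vector of size hidden_size per virtual token.
trainable = num_virtual_tokens * hidden_size
percent = f"{100 * trainable / total_params:.4f}"

print(trainable)  # 262144
print(percent)    # 0.0033
```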


In [34]:
# Preprocess outputs-only dataset
from datasets import Dataset

# Create text stream directly from outputs
outputs_ds = Dataset.from_dict({"text": [o.strip() for o in style_outputs]})

# Tokenize without padding; we'll pad dynamically in the collate_fn
def tokenize(batch):
    out = tokenizer(
        batch["text"],
        max_length=MAX_SEQ_LEN,
        truncation=True,
        padding=False,
        return_attention_mask=True,
    )
    out["labels"] = [ids[:] for ids in out["input_ids"]]
    return out

train_ds = outputs_ds.map(tokenize, batched=True, remove_columns=outputs_ds.column_names)
train_ds


Map: 100%|██████████| 250/250 [00:00<00:00, 4067.35 examples/s]


Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 250
})
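
One thing worth checking in the tokenization above is whether each example ends with an end-of-sequence token. Many causal-LM tokenizers add only a BOS token by default, in which case the model never sees where a post ends, which may contribute to generations that run on. The sketch below is a minimal, tokenizer-free illustration under that assumption; `append_eos` is a hypothetical helper.

```python
def append_eos(input_ids, eos_id, max_len):
    """Return ids that end in eos_id and fit within max_len tokens."""
    ids = list(input_ids)
    if ids and ids[-1] == eos_id and len(ids) <= max_len:
        return ids                 # already terminated
    ids = ids[: max_len - 1]       # leave room for the EOS token
    ids.append(eos_id)
    return ids

print(append_eos([1, 2, 3], eos_id=9, max_len=10))
```

Inside `tokenize`, this would be applied to each `input_ids` list with `tokenizer.eos_token_id` and `MAX_SEQ_LEN`, with `attention_mask` extended to match. Because the collate function labels padding as -100, only this real EOS contributes to the loss even though the pad token reuses the EOS id.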

In [35]:
# Manual training loop (no Trainer): optimizer, scheduler, step-limited training
import math
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW


def collate_fn(features):
    pad_id = tokenizer.pad_token_id
    batch_size = len(features)
    seq_lens = [len(f["input_ids"]) for f in features]
    max_len = max(seq_lens)

    input_ids = torch.full((batch_size, max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long)
    labels = torch.full((batch_size, max_len), -100, dtype=torch.long)

    for i, f in enumerate(features):
        ids = torch.tensor(f["input_ids"], dtype=torch.long)
        attn = torch.tensor(f["attention_mask"], dtype=torch.long)
        labs = torch.tensor(f["labels"], dtype=torch.long)
        L = ids.size(0)
        input_ids[i, :L] = ids
        attention_mask[i, :L] = attn
        labels[i, :L] = labs

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


train_loader = DataLoader(
    train_ds,
    batch_size=MICRO_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

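# Only the soft-prompt embeddings require gradients; skip the frozen base weights
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE)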
# Total optimizer steps we intend to take
total_optim_steps = NUM_TRAIN_STEPS
num_warmup_steps = max(1, int(0.1 * total_optim_steps))
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_optim_steps,
)

model.train()  # device placement is already handled by device_map="auto"

optimizer.zero_grad()
optim_step = 0
accumulated = 0
running_loss = 0.0
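# Repeat over the dataset until the optimizer-step budget is reached
num_epochs = math.ceil(total_optim_steps * GRAD_ACCUM_STEPS / len(train_loader))
for epoch in range(num_epochs):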
    for batch in train_loader:
        batch = {k: v.to(model.device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        (loss / GRAD_ACCUM_STEPS).backward()
        running_loss += loss.item()
        accumulated += 1
        if accumulated % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            if optim_step % 10 == 0:
                print(f"step {optim_step} loss {running_loss / GRAD_ACCUM_STEPS:.4f}")
            running_loss = 0.0
            optim_step += 1
            if optim_step >= total_optim_steps:
                break
    if optim_step >= total_optim_steps:
        break

model.save_pretrained(OUTPUT_DIR)
print("Saved prompt adapter to:", OUTPUT_DIR)


step 0 loss 4.3706
step 10 loss 3.5588
step 20 loss 3.2212
step 30 loss 3.1885
step 40 loss 3.1434
step 50 loss 2.8586
step 60 loss 3.0076
step 70 loss 2.5332
step 80 loss 2.3638
step 90 loss 2.3937
step 100 loss 2.7967
step 110 loss 2.3713
step 120 loss 2.5810
step 130 loss 3.4381
step 140 loss 2.9435
step 150 loss 2.2477
step 160 loss 2.1865
step 170 loss 2.3452
step 180 loss 2.5205
step 190 loss 4.3069
step 200 loss 2.8646
step 210 loss 2.1046
step 220 loss 3.4014
step 230 loss 1.8779
step 240 loss 2.5006
step 250 loss 3.3016
step 260 loss 2.3269
step 270 loss 1.5209
step 280 loss 2.0935
step 290 loss 2.4100
step 300 loss 1.4401
step 310 loss 2.2326
step 320 loss 2.3852
step 330 loss 2.7482
step 340 loss 2.7135
step 350 loss 2.6208
step 360 loss 2.4218
step 370 loss 2.8657
step 380 loss 2.8885
step 390 loss 3.1669
step 400 loss 2.2478
step 410 loss 2.7647
step 420 loss 2.6662
step 430 loss 2.4746
step 440 loss 2.3423
step 450 loss 3.3020
step 460 loss 2.7758
step 470 loss 2.4833
ste
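
When reading the loss log it helps to know the learning rate at each step. The sketch below reproduces, in plain Python, the shape of a linear warmup followed by linear decay as used by the scheduler in this cell: 80 warmup steps out of 800, peak 5e-3. `linear_warmup_decay_lr` is a hypothetical helper written for illustration, not the library function itself.

```python
def linear_warmup_decay_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate after `step` scheduler updates."""
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Linear decay from base_lr down to 0 at total_steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

for s in (0, 40, 80, 440, 800):
    print(s, linear_warmup_decay_lr(s, 5e-3, 80, 800))
```

Under this schedule the very first update runs with a learning rate of zero, and the rate is back to half its peak by step 440. With a batch size of 1, much of the step-to-step noise in the log reflects individual examples rather than the schedule.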

In [36]:
# Inference with tuned soft prompts (no chat template)
from peft import PeftModel
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer
import torch 

bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Reload base + adapter
base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if bf16 else torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
base = PeftModel.from_pretrained(base, OUTPUT_DIR)
base.eval()

streamer = TextStreamer(tokenizer, skip_special_tokens=True)

# Plain-text instruction prompt (no chat template); PEFT prepends the learned soft prompt
prompt = "Please generate one reddit post. Make sure to stick to the format below exactly. Don't include any extraneous characters like asterisks or other symbols. \n\n title: {title} \n self_text: {self_text} \n subreddit: {subreddit} \n Here's an example of the format: \n\ntitle: This is the title of the post! \nself_text: Here's where the content of the post goes. \nsubreddit: This is the subreddit, or the name of the community the post belongs to."

# Ensure at least one token as input to avoid empty input_ids
if prompt.strip() == "":
    token_id = (
        tokenizer.bos_token_id
        if tokenizer.bos_token_id is not None
        else (tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id)
    )
    if token_id is not None:
        inputs = {
            "input_ids": torch.tensor([[token_id]], dtype=torch.long, device=base.device),
            "attention_mask": torch.ones((1, 1), dtype=torch.long, device=base.device),
        }
    else:
        inputs = tokenizer(" ", return_tensors="pt").to(base.device)
else:
    inputs = tokenizer(prompt, return_tensors="pt").to(base.device)

with torch.no_grad():
    _ = base.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
    )


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.55it/s]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Please generate one SINGLE reddo enjoy. It's imperative that you only generate one and that you stick to the format below exactly. Don't include any extraneous characters like asterisks or other symbols. 

 title: {title} 
 self_text: {self_text} 
 subreddit: {subreddit} 
 Here's an example of the format: 

title: This is the title of the post! 
self_text: Here's where the content of the post goes. 
subreddit: This is the subreddit, or the name of the community the post belongs to. 
}
submitted to r/RandomJokes. 
View Reddit
\end{self-post}
view more View less more bath & body works 3 wick candles 10.00 bath & body works 3 wick candles 10.00 View Reddit
Pinned post
\end{self-post}
view more View less more View Reddit
Pinned post
\end{self-post}
view more View less more View Reddit
Pinned post
\end{self-post}
view more View less more View Reddit
Pinned post
\end{self-post}
view more View less more View Reddit
Pinned post
\end{self-post}
view more View less more View Reddit
Pinned post
\


In [37]:
# Inference with stop on second title to ensure exactly one post
from peft import PeftModel
from transformers import StoppingCriteria, StoppingCriteriaList
import torch

# Reload base + adapter so this cell can run on its own.
# Note: while reloading, the previous copy of the model is still on the GPU.
base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if bf16 else torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
base = PeftModel.from_pretrained(base, OUTPUT_DIR)
base.eval()

class StopOnSecondTitle(StoppingCriteria):
    def __init__(self, tokenizer, prompt_len: int):
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len
    def __call__(self, input_ids, scores, **kwargs):
        gen_ids = input_ids[0, self.prompt_len:]
        if gen_ids.numel() == 0:
            return False
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        return text.count("title:") >= 2

prompt = (
    "Please generate a reddit post that the user is likely to enjoy. "
    "It's imperative that you only generate one and that you stick to the format below exactly. "
    " title: {title} \n self_text: {self_text} \n subreddit: {subreddit} \n Here's an example of the format: \n\n"
    "title: This is the title of the post! \n"
    "self_text: Here's where the content of the post goes. \n"
    "subreddit: This is the subreddit, or the name of the community the post belongs to."
)

inputs = tokenizer(prompt, return_tensors="pt").to(base.device)
stopper = StoppingCriteriaList([StopOnSecondTitle(tokenizer, inputs["input_ids"].shape[1])])

with torch.no_grad():
    out = base.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        stopping_criteria=stopper,
        return_dict_in_generate=True,
    )

# Decode only generated tokens (exclude the prompt)
gen_tokens = out.sequences[0, inputs["input_ids"].shape[1]:]
text = tokenizer.decode(gen_tokens, skip_special_tokens=True)

# Trim to the first post only
first_idx = text.find("title:")
if first_idx != -1:
    text = text[first_idx:]
    second_idx = text.find("\ntitle:", 1)
    if second_idx != -1:
        text = text[:second_idx]

print(text.strip())



Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.56it/s]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


title: This is the title of the post! self_text: Here's where the content of the post goes. subreddit: This is the subreddit, or the name of the community the post belongs to. 
</p> 
</div> 
</div> 
</div> 
</body> 
</html> 
[... the six closing-tag lines above repeat for the rest of the output ...]

In [38]:
# Utility: generate a single reddit post (no prompt echo)
from typing import Optional
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSecondTitle(StoppingCriteria):
    def __init__(self, tokenizer, prompt_len: int):
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len
    def __call__(self, input_ids, scores, **kwargs):
        gen_ids = input_ids[0, self.prompt_len:]
        if gen_ids.numel() == 0:
            return False
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        return text.count("title:") >= 2

GEN_PROMPT = (
    "Please generate a reddit post that the user is likely to enjoy. "
    "It's imperative that you only generate one and that you stick to the format below exactly. "
    " title: {title} \n self_text: {self_text} \n subreddit: {subreddit} \n Here's an example of the format: \n\n"
    "title: This is the title of the post! \n"
    "self_text: Here's where the content of the post goes. \n"
    "subreddit: This is the subreddit, or the name of the community the post belongs to."
)

def generate_reddit_post(temperature: float = 0.7, top_p: float = 0.9, max_new_tokens: int = 256) -> str:
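    # Generate with the reloaded adapter model `base` (eval mode, KV cache on)
    # rather than the training-time `model`, which has use_cache=False.
    inputs = tokenizer(GEN_PROMPT, return_tensors="pt").to(base.device)
    stopper = StoppingCriteriaList([StopOnSecondTitle(tokenizer, inputs["input_ids"].shape[1])])
    with torch.no_grad():
        out = base.generate(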
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            stopping_criteria=stopper,
            return_dict_in_generate=True,
        )
    gen_tokens = out.sequences[0, inputs["input_ids"].shape[1]:]
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    first_idx = text.find("title:")
    if first_idx != -1:
        text = text[first_idx:]
        second_idx = text.find("\ntitle:", 1)
        if second_idx != -1:
            # Cut exactly where the second post starts
            text = text[:second_idx]
    return text.strip()
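
The trimming step can be pulled out as a pure function and checked without loading a model. This is a minimal sketch; `trim_to_first_post` is a hypothetical helper name, and it assumes a second post always begins with `title:` at the start of a line.

```python
def trim_to_first_post(text: str, marker: str = "title:") -> str:
    """Keep only the first post: from the first marker up to the next one."""
    start = text.find(marker)
    if start == -1:
        return text.strip()              # no marker found: return text unchanged
    text = text[start:]
    nxt = text.find("\n" + marker, 1)    # start of a second post, if any
    if nxt != -1:
        text = text[:nxt]                # slice at the marker, not before it
    return text.strip()

sample = "noise\ntitle: A\nself_text: B\nsubreddit: r/x\ntitle: C\nself_text: D"
print(trim_to_first_post(sample))
```

Slicing at the index of the second marker keeps the whole first post. Subtracting anything further from that index would remove characters from the end of the `subreddit` line.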



In [39]:
generate_reddit_post()

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


"I'm asking for this because I'm a bot, and I have no idea what you want to see. \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div> \n</div>"

In [40]:
# Generate 10 posts
posts = [generate_reddit_post() for _ in range(10)]
print("\n\n---\n\n".join(posts))



Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


KeyboardInterrupt: 

In [None]:
import os
from huggingface_hub import login

# Read the token from the environment (set HF_TOKEN beforehand); never hard-code it
login(token=os.environ["HF_TOKEN"])

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="bfloat16",          # H100-friendly
    device_map={"": 0},              # force GPU 0
)
model.eval()

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 79.19 GiB of which 57.00 MiB is free. Including non-PyTorch memory, this process has 79.12 GiB memory in use. Of the allocated memory 77.75 GiB is allocated by PyTorch, and 739.16 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
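
The error above reports roughly 77.75 GiB already allocated by PyTorch in this process before the first shard loaded (0/4). That suggests earlier copies of the model (`model` from training, plus the `base` reloads) were still resident on the GPU. One possible way to clear them before loading again is sketched below. `release_gpu_objects` is a hypothetical helper, and memory is only returned once no other references to the objects remain.

```python
import gc

def release_gpu_objects(namespace: dict, names) -> list:
    """Drop the named references from `namespace` and free cached CUDA blocks."""
    dropped = [n for n in names if namespace.pop(n, None) is not None]
    gc.collect()                         # collect reference cycles holding tensors
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()     # return cached blocks to the driver
    except ImportError:
        pass                             # torch not installed: nothing to empty
    return dropped

# In the notebook this would be called as:
# release_gpu_objects(globals(), ["base", "model", "pipe", "optimizer"])
```

Restarting the kernel achieves the same result more reliably if references are held elsewhere, for example in notebook output history.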

In [None]:
prompt = (
    "Please generate a reddit post that the user is likely to enjoy. "
    "Make sure to stick to the format below exactly. Don't include any extraneous characters like asterisks or other symbols. \n\n"
    "title: {title} \n"
    "self_text: {self_text} \n"
    "subreddit: {subreddit} \n"
    "Here's an example of the format: \n\n"
    "title: This is the title of the post! \n"
    "self_text: Here's where the content of the post goes. \n"
    "subreddit: This is the subreddit, or the name of the community the post belongs to."
)

out = pipe(
    prompt,
    max_new_tokens=256,              # smaller = faster
)
print(out[0]["generated_text"])

Please generate a reddit post that the user is likely to enjoy. Make sure to stick to the format below exactly. Don't include any extraneous characters like asterisks or other symbols. 

title: {title} 
self_text: {self_text} 
subreddit: {subreddit} 
Here's an example of the format: 

title: This is the title of the post! 
self_text: Here's where the content of the post goes. 
subreddit: This is the subreddit, or the name of the community the post belongs to. 

Here's a post that I think the user would enjoy: 

title: I just spent 10 hours building a treehouse and I'm not sure if it's worth it 
self_text: I've been working on this treehouse for weeks, and I finally finished it today. It's a lot bigger than I expected it to be, but I'm not sure if it's worth all the effort. The ladder is a little rickety, and the roof is a little leaky, but it's still really cool. I was thinking about adding some stairs instead of a ladder, but I'm not sure if that would make it too big. Has anyone else