### RAG Playground: Llama 3 8B Instruct + Reddit-style Few-shot from JSON

Goal: retrieve example Reddit posts from `./data/post-sample.json` and condition Llama 3 8B Instruct to generate one stylistically similar post.

Prompt used: "generate a reddit post that the user is likely to enjoy"


In [1]:
%pip install -qU transformers==4.55.2 sentence-transformers faiss-cpu datasets einops peft accelerate bitsandbytes jinja2>=3.1.0


Note: you may need to restart the kernel to use updated packages.


In [2]:
# Load and clean dataset; assemble corpus strings
from typing import List, Dict
import json, re, os

# Path of the JSON sample that every later cell draws its posts from.
DATA_PATH = "./data/post-sample.json"

# Validate with an explicit exception instead of `assert`: assertions are
# removed when Python runs with -O, which would silently skip this check.
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError(f"Missing {DATA_PATH}")

# The file is parsed as a whole; it is annotated (not verified) as a list of
# dict records, which is how the corpus-building loop consumes it.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    raw_posts: List[Dict] = json.load(f)

# Whole-line patterns treated as boilerplate rather than post content:
# "View More Posts" / "View Post" labels, a bare "Help"/"Help?", and "Edit:" notes.
BOILERPLATE_PATTERNS = [
    r"^\s*View\s+More\s+Posts\s*$",
    r"^\s*View\s+Post\s*$",
    r"^\s*Help\??\s*$",
    r"^\s*Edit:\s*.*$",
]
# Compiled once, case-insensitively, so clean_text does not recompile per call.
boilerplate_regexes = [re.compile(p, flags=re.IGNORECASE) for p in BOILERPLATE_PATTERNS]

def clean_text(text: str) -> str:
    """Normalise whitespace in ``text`` and drop boilerplate lines.

    Line endings (``\\r\\n`` and bare ``\\r``) are unified to ``\\n``. Each line
    then has its whitespace runs collapsed to single spaces and is stripped.
    Lines that end up empty, or that match any pattern in
    ``boilerplate_regexes``, are discarded. The surviving lines are joined
    with a single space, so the result is always one line.

    The boilerplate check must run per line *before* newlines are collapsed.
    Collapsing all whitespace first leaves a single line, so embedded
    boilerplate is never removed and a post that merely starts with
    ``Edit:`` matches as a whole and is wiped. For text that contains no
    boilerplate line, the result equals a global whitespace collapse.

    Args:
        text: Raw field value. Any falsy value (``None``, ``""``) yields ``""``.

    Returns:
        The cleaned, single-line string (possibly empty).
    """
    if not text:
        return ""
    kept = []
    for raw_line in re.sub(r"\r\n?", "\n", text).split("\n"):
        line = re.sub(r"\s+", " ", raw_line).strip()
        if not line:
            # Blank lines carry no content and would only add stray spaces.
            continue
        if any(rx.match(line) for rx in boilerplate_regexes):
            continue
        kept.append(line)
    return " ".join(kept)

# Build one retrieval document per raw post, plus a parallel metadata record.
# corpus[i] and meta[i] always describe the same post.
corpus: List[str] = []
meta: List[Dict] = []
for post in raw_posts:
    # Clean the three fields in a fixed order; a missing key cleans to "".
    title, self_text, subreddit = (
        clean_text(post.get(field, "")) for field in ("title", "self_text", "subreddit")
    )
    # Rewrite any "r/" that is preceded by optional whitespace and an optional
    # slash (e.g. "/r/name") to plain "r/".
    subreddit = re.sub(r"\s*(/)?r/", "r/", subreddit)
    corpus.append(f"title: {title}\nself_text: {self_text}\nsubreddit: {subreddit}")
    meta.append({"title": title, "subreddit": subreddit})

# Notebook display value: corpus size and the first 200 characters of document 0.
len(corpus), corpus[0][:200]


(250,
 'title: 3070ti, 6900xt or wait for new cards?\nself_text: Backstory: I ordered a $500 34" qhd 144hz monitor off best buy and they accidentally shipped me 2 of them so now I\'m going to return one lol. I ')

In [3]:
# Embed corpus and build FAISS index
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# Sentence-embedding checkpoint used for both the corpus and, later, the queries.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
embedder = SentenceTransformer(EMBED_MODEL)

# Vectors are requested already normalised and the index scores by inner
# product (IndexFlatIP).
embeddings = embedder.encode(
    corpus,
    normalize_embeddings=True,
    convert_to_numpy=True,
    show_progress_bar=True,
)
# The index dimensionality is taken from the embedding matrix's second axis.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Notebook display value: number of vectors stored in the index.
index.ntotal


  from .autonotebook import tqdm as notebook_tqdm

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.6 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/lib/python3/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/lib/python3/dist-packages/traitlets/config/application.py", line 846, in launch_instance
    app.start()
  File "/usr/lib/python3/dist-packages

AttributeError: _ARRAY_API not found

ImportError: numpy.core.multiarray failed to import

In [None]:
# Retrieval and prompt construction
from typing import List, Tuple

# Instruction header placed at the top of every prompt. It spells out the
# three-field record layout, one field per line, that the model should emit.
SYSTEM_STYLE = (
    "You are a writing assistant that outputs exactly one reddit post in the format:\n"
    "title: ...\n"
    "self_text: ...\n"
    "subreddit: r/...\n"
)

# Fixed task sentence appended to the end of every prompt.
USER_TASK = "generate a reddit post that the user is likely to enjoy"

def retrieve_examples(query: str, k: int = 6) -> List[str]:
    """Return up to ``k`` corpus documents closest to ``query``.

    The query is embedded with the same normalised encoder as the corpus and
    looked up in the module-level FAISS ``index``.

    Args:
        query: Free-text query to embed.
        k: Maximum number of documents to return. Values larger than the
            index size are clamped; values <= 0 yield an empty list.

    Returns:
        Matching ``corpus`` entries in the order FAISS returns them; fewer
        than ``k`` when the index holds fewer vectors.
    """
    # FAISS pads missing neighbours with id -1. Indexing corpus[-1] would then
    # silently return the *last* document as a bogus neighbour, so the request
    # is clamped up front and negative ids are filtered defensively below.
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    _, idxs = index.search(q_emb, k)
    return [corpus[i] for i in idxs[0] if i >= 0]

def build_fewshot_prompt(query: str, k: int = 6) -> str:
    """Assemble the few-shot prompt: style header, retrieved examples, task line.

    ``query`` only selects which examples are retrieved. The closing task
    line always uses the module-level ``USER_TASK``.

    Args:
        query: Text used to retrieve the example posts.
        k: Number of examples requested from ``retrieve_examples``.

    Returns:
        The complete prompt string.
    """
    shots = retrieve_examples(query, k=k)
    sections = [
        f"{SYSTEM_STYLE}\n\n",
        "Here are style examples:\n\n",
        "\n\n".join(shots),  # examples are separated by one blank line
        f"\n\nTask: {USER_TASK}\n",
    ]
    return "".join(sections)

# Sanity check: build a prompt for the default task and show its first 500 characters.
prompt = build_fewshot_prompt(query=USER_TASK, k=8)
print(prompt[:500])


In [None]:
# Load Llama 3 8B Instruct and generate
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face checkpoint to load.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# True only when CUDA is available and device 0 reports a compute-capability
# major version of at least 8; selects bfloat16 below, otherwise float16.
bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# generate() is later given tokenizer.pad_token_id, so make sure one exists by
# falling back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16 if not bf16 else torch.bfloat16,
)
# Switch to inference mode (disables dropout and similar training-only behaviour).
model.eval()

# Marker for the start of a second post; not referenced by the code visible in this file.
STOP_TOKENS = ["\n\ntitle:"]

def generate_with_rag(query: str, k: int = 8, max_new_tokens: int = 256,
                      temperature: float = 0.7, top_p: float = 0.9,
                      repetition_penalty: float = 1.05) -> str:
    """Generate one Reddit-style post conditioned on retrieved example posts.

    A few-shot prompt is built for ``query`` and sent to the model as a single
    user turn through the tokenizer's chat template. One completion is
    sampled and trimmed to the first ``title: ...`` record.

    Args:
        query: Text used to retrieve the few-shot examples.
        k: Number of examples to retrieve.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative probability.
        repetition_penalty: Penalty passed through to ``model.generate``.

    Returns:
        The generated text, stripped. If it contains ``title:``, only the span
        from the first ``title:`` up to (not including) a following
        ``"\\ntitle:"`` is kept; otherwise the whole completion is returned.
    """
    prompt = build_fewshot_prompt(query, k=k)

    # Instruct checkpoints are trained on chat-formatted turns, so the prompt
    # is wrapped with the tokenizer's own chat template and the assistant
    # header is appended. Tokenizers without a template keep the raw-text path.
    if getattr(tokenizer, "chat_template", None):
        inputs = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
    else:
        inputs = tokenizer(prompt, return_tensors="pt")
    inputs = inputs.to(model.device)

    # Llama 3 closes an assistant turn with <|eot_id|>, which may differ from
    # tokenizer.eos_token_id. Stopping on both prevents running past the end
    # of the turn. Unknown-token lookups (None / unk id) are ignored.
    stop_ids = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if (
        isinstance(eot_id, int)
        and eot_id != tokenizer.unk_token_id
        and eot_id not in stop_ids
    ):
        stop_ids.append(eot_id)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            eos_token_id=stop_ids,
            pad_token_id=tokenizer.pad_token_id,
            return_dict_in_generate=True,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    gen = tokenizer.decode(out.sequences[0][prompt_len:], skip_special_tokens=True)

    # Keep only the first post shape
    first_idx = gen.find("title:")
    if first_idx != -1:
        gen = gen[first_idx:]
        # stop before a second title if it appears
        nxt = gen.find("\ntitle:", 1)
        if nxt != -1:
            gen = gen[:nxt]
    return gen.strip()



In [None]:
# Demo run
# Generate one post for the default task using eight retrieved examples, then show it.
demo_post = generate_with_rag(
    query="generate a reddit post that the user is likely to enjoy",
    k=8,
)
print(demo_post)
