# Cache-Augmented Generation

This notebook demonstrates **Cache-Augmented Generation** (CAG) using:
- **Mistral** (`mistralai/Mistral-7B-Instruct-v0.1`)
- A `document.txt` file describing **YOU**.
- A simple **KV cache** mechanism with `DynamicCache`.

We’ll:
1. Load the model.
2. Preload `document.txt` into the cache.
3. Ask two questions, reusing the same cache.

Prerequisites:
1. A Hugging Face account.
2. A `.env` file containing your Hugging Face access token as `HF_TOKEN`.
3. A `document.txt` file with sentences about yourself.

For this demo, I use my own `document.txt` and ask questions about myself (Ronan Takizawa).


### Imports and the Generate Function
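We import the essential libraries (`torch`, `transformers`) and define the `generate` function. It performs greedy, token-by-token decoding: the first forward pass consumes the question, and every later step feeds only the newest token while reusing the model's `past_key_values`, so earlier tokens (including the preloaded document) are never re-encoded.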

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os

# Minimal greedy (argmax) generation loop that reuses past_key_values between steps
def generate(model, input_ids: torch.Tensor, past_key_values, max_new_tokens: int = 50) -> torch.Tensor:
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)

            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break

    # Return just the newly generated part
    return output_ids[:, origin_len:]
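The `generate` loop relies on the fact that, in causal attention, earlier tokens' keys and values never change once computed. The toy sketch below shows this in plain Python. It is not how `transformers` implements attention, and the 2-dimensional projection matrices and token vectors are made-up assumptions. It computes single-head attention two ways, once by recomputing keys/values for the whole prefix at every step and once by appending each token's key/value to a cache, and checks that both give the same outputs.

```python
import math


def project(x, w):
    """Multiply row vector x by matrix w (given as a list of rows)."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) for j in range(len(w[0]))]


def attend(q, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


# Hypothetical toy projection matrices and token embeddings
W_Q = [[1.0, 0.5], [0.2, 1.0]]
W_K = [[0.3, 1.0], [1.0, 0.1]]
W_V = [[1.0, 0.0], [0.5, 2.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0]]


def full_recompute(xs):
    """No cache: re-project keys/values for the entire prefix at every step."""
    outputs = []
    for t in range(len(xs)):
        keys = [project(x, W_K) for x in xs[: t + 1]]
        values = [project(x, W_V) for x in xs[: t + 1]]
        outputs.append(attend(project(xs[t], W_Q), keys, values))
    return outputs


def cached_incremental(xs):
    """KV cache: project each token once and append its key/value."""
    key_cache, value_cache, outputs = [], [], []
    for x in xs:
        key_cache.append(project(x, W_K))
        value_cache.append(project(x, W_V))
        outputs.append(attend(project(x, W_Q), key_cache, value_cache))
    return outputs


a = full_recompute(tokens)
b = cached_incremental(tokens)
print(all(math.isclose(u, v) for ra, rb in zip(a, b) for u, v in zip(ra, rb)))  # True
```

The cached version does one key/value projection per token instead of one per token per step. CAG extends the same idea by computing the document's part of that cache once, up front.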

### DynamicCache Setup
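This cell sets up the `DynamicCache` mechanism for storing and reusing the model's key/value states. `get_kv_cache` runs the prompt through the model once and returns the filled cache. `clean_up` truncates every layer back to `origin_len`, discarding the tokens that a user query (and its generated answer) appended. The `add_safe_globals` calls allowlist `DynamicCache` and `set` so that `torch.load`, which defaults to `weights_only=True` in recent PyTorch releases, can deserialize the saved cache later.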

In [2]:
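# Allowlist DynamicCache and the built-in set type so torch.load(weights_only=True) can unpickle a saved cache
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])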

def get_kv_cache(model, tokenizer, prompt: str) -> DynamicCache:
    # Encode prompt
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()

    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache

def clean_up(cache: DynamicCache, origin_len: int):
    # Remove any tokens appended to the original knowledge
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
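The reason `clean_up` is needed is that `generate` extends the cache in place, so every query leaves its tokens behind. The sketch below uses a hypothetical `ToyCache` (plain per-layer lists standing in for `DynamicCache` tensors, with generated tokens omitted for brevity) to show how the cache grows across queries with and without truncating back to `origin_len`.

```python
class ToyCache:
    """Hypothetical stand-in for DynamicCache: per-layer lists of cached entries."""

    def __init__(self, num_layers: int):
        self.key_cache = [[] for _ in range(num_layers)]
        self.value_cache = [[] for _ in range(num_layers)]

    def append(self, token: str) -> None:
        # A forward pass stores one key and one value per layer for each new token
        for layer in range(len(self.key_cache)):
            self.key_cache[layer].append(f"k{layer}:{token}")
            self.value_cache[layer].append(f"v{layer}:{token}")

    def seq_len(self) -> int:
        return len(self.key_cache[0])


def toy_clean_up(cache: ToyCache, origin_len: int) -> None:
    # Same idea as clean_up: keep only the first origin_len positions in every layer
    for layer in range(len(cache.key_cache)):
        cache.key_cache[layer] = cache.key_cache[layer][:origin_len]
        cache.value_cache[layer] = cache.value_cache[layer][:origin_len]


def run_queries(queries, truncate: bool):
    """Preload a toy document prefix, then run each query; return the cache length after each."""
    cache = ToyCache(num_layers=2)
    for tok in ["Context", ":", "doc", "text"]:
        cache.append(tok)
    origin_len = cache.seq_len()
    lengths = []
    for query in queries:
        if truncate:
            toy_clean_up(cache, origin_len)
        for tok in query:
            cache.append(tok)
        lengths.append(cache.seq_len())
    return lengths


queries = [["Who", "?"], ["What", "projects", "?"]]
print(run_queries(queries, truncate=True))   # [6, 7]
print(run_queries(queries, truncate=False))  # [6, 9]
```

Without truncation, the second question would also "see" the first question's tokens, so answers could leak between queries.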

### Reading the Hugging Face Token from `.env`

We read the Hugging Face access token (`HF_TOKEN`) from a `.env` file, falling back to a file named `env`, so we can log into Hugging Face and download the LLM.

In [3]:
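def get_env():
    env_dict = {}
    env_file = ".env" if os.path.exists(".env") else "env"
    if os.path.exists(env_file):
        with open(env_file, mode="r") as f:
            for line in f:
                line = line.strip()
                # Skip blank lines, comments, and lines without an assignment
                if not line or line.startswith("#") or "=" not in line:
                    continue
                # Split on the first "=" only, so values may themselves contain "="
                key, value = line.split("=", 1)
                env_dict[key.strip()] = value.strip().strip('"')
    else:
        print("No .env or env file found; HF_TOKEN may not be set.")
    return env_dict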

env = get_env()
HF_TOKEN = env.get("HF_TOKEN", None)

# Global placeholders (if needed)
model_name = None
model = None
tokenizer = None
rand_seed = None

print("Environment and imports are set.")

Environment and imports are set.


### Load Mistral
We’ll load the `mistralai/Mistral-7B-Instruct-v0.1` model in half precision (FP16) on GPU, or full precision (FP32) on CPU, and let `device_map="auto"` place the weights.


In [4]:
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",  # weights are placed automatically; no extra model.to() call needed
    trust_remote_code=True,
    token=HF_TOKEN
)
# Send inputs to the device that holds the embedding layer (GPU if available, else CPU)
device = model.model.embed_tokens.weight.device
print(f"Loaded {model_name}.")

Loaded mistralai/Mistral-7B-Instruct-v0.1.
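Why FP16 on GPU: a back-of-the-envelope estimate of the memory needed for the weights alone (excluding activations and the KV cache) is parameters × bytes per parameter. The sketch assumes the nominal 7 billion parameters implied by the model name; the exact count is slightly higher.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory for model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


NUM_PARAMS = 7e9  # assumption: nominal "7B" parameter count

fp16_gb = weight_memory_gb(NUM_PARAMS, 2)  # torch.float16: 2 bytes per parameter
fp32_gb = weight_memory_gb(NUM_PARAMS, 4)  # torch.float32: 4 bytes per parameter
print(f"FP16 ~ {fp16_gb:.0f} GB, FP32 ~ {fp32_gb:.0f} GB")  # FP16 ~ 14 GB, FP32 ~ 28 GB
```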


### Create a Knowledge Prompt from `document.txt`
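We read the file, wrap it in a short system/user prompt, and call `get_kv_cache` to precompute the KV cache for that prompt. `origin_len` records the cache length so we can truncate back to it after each query. The `<|system|>` and `<|user|>` tags are plain-text markers here; Mistral-7B-Instruct's own chat format uses `[INST]` and `[/INST]`.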

In [5]:
if not os.path.exists("document.txt"):
    raise FileNotFoundError("Please create a `document.txt` with info about Ronan Takizawa.")

with open("document.txt", "r", encoding="utf-8") as f:
    doc_text = f.read()

system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()

# Build the cache
ronan_cache = get_kv_cache(model, tokenizer, system_prompt)
origin_len = ronan_cache.key_cache[0].shape[-2]
print("KV cache built.")

KV cache built.
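How much memory does the preloaded cache use? Each layer stores a key and a value tensor shaped `(batch, num_kv_heads, seq_len, head_dim)`, which is why `origin_len` is read from `shape[-2]`. The sketch below assumes the published Mistral-7B configuration (32 layers, 8 key/value heads, head dimension 128) and FP16 storage; check `model.config` to confirm for your checkpoint.

```python
def kv_cache_bytes(seq_len: int, num_layers: int = 32, num_kv_heads: int = 8,
                   head_dim: int = 128, bytes_per_elem: int = 2) -> int:
    """Estimate the size of a KV cache covering seq_len tokens (batch size 1)."""
    # Factor 2: every layer stores one key tensor and one value tensor
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


print(kv_cache_bytes(1))             # 131072 bytes, i.e. 128 KiB per token
print(kv_cache_bytes(2000) / 2**20)  # 250.0 MiB for a hypothetical 2,000-token document
```

In the notebook, `kv_cache_bytes(origin_len)` would give the estimate for the actual `document.txt`. Memory grows linearly with document length, which is the main practical limit on how much knowledge can be preloaded.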


### Ask Questions Reusing the Cache
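Each query reuses the same preloaded knowledge, with no real-time retrieval. `clean_up` first truncates the cache back to `origin_len`, and then only the question tokens are fed to `generate`.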

In [6]:
# 1st query
question1 = "Who is Ronan Takizawa?"
clean_up(ronan_cache, origin_len)
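# add_special_tokens=False: don't insert a second BOS token after the cached prefix
input_ids_q1 = tokenizer(question1 + "\n", return_tensors="pt", add_special_tokens=False).input_ids.to(device)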
gen_ids_q1 = generate(model, input_ids_q1, ronan_cache)
answer1 = tokenizer.decode(gen_ids_q1[0], skip_special_tokens=True)
print("Q1:", question1)
print("A1:", answer1)

Q1: Who is Ronan Takizawa?
A1: Answer: Ronan Takizawa is a Colorado College computer science student, cybersecurity researcher, and tech content creator with over 100,000 followers across social media platforms. He has built several applications and systems using various


In [7]:
# 2nd query
question2 = "What are his main projects?"
clean_up(ronan_cache, origin_len)
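# add_special_tokens=False: don't insert a second BOS token after the cached prefix
input_ids_q2 = tokenizer(question2 + "\n", return_tensors="pt", add_special_tokens=False).input_ids.to(device)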
gen_ids_q2 = generate(model, input_ids_q2, ronan_cache)
answer2 = tokenizer.decode(gen_ids_q2[0], skip_special_tokens=True)
print("Q2:", question2)
print("A2:", answer2)

Q2: What are his main projects?
A2: Answer: Ronan Takizawa's main projects include Punch Analytics, Noname, TeleSpeech, and the website automation system for the Ireland-Japan Chamber of Commerce.
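The savings from reusing the cache can be counted in prompt tokens the model must process (prefill). This sketch uses hypothetical token counts and ignores generated tokens. It compares encoding the document once and then only each question against re-sending the whole document with every question.

```python
def prefill_tokens(doc_tokens: int, question_tokens: list[int], reuse_cache: bool) -> int:
    """Count prompt tokens processed across all queries (generated tokens excluded)."""
    if reuse_cache:
        # Document encoded once; each query adds only its own tokens
        return doc_tokens + sum(question_tokens)
    # Without a cache, every query re-encodes the whole document
    return sum(doc_tokens + q for q in question_tokens)


doc = 1500             # hypothetical document length in tokens
questions = [8, 7, 8]  # hypothetical question lengths in tokens
print(prefill_tokens(doc, questions, reuse_cache=True))   # 1523
print(prefill_tokens(doc, questions, reuse_cache=False))  # 4523
```

The gap grows with every additional question, since the cached approach pays the document cost only once.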


In [8]:
# Save the cache to disk
clean_up(ronan_cache, origin_len)
cache_dir = "cag_cache"
os.makedirs(cache_dir, exist_ok=True)

# Save the KV cache
torch.save(ronan_cache, os.path.join(cache_dir, "ronan_knowledge.cache"))
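A saved cache is only valid for the exact model and prompt that produced it. If `document.txt` changes, the file on disk silently goes stale. One possible safeguard is a small JSON sidecar that stores a hash of the model name and prompt, checked before reuse. This is not part of the original notebook, and the helper names `cache_fingerprint`, `write_cache_metadata` and `cache_is_current` are hypothetical.

```python
import hashlib
import json
import os
import tempfile


def cache_fingerprint(model_name: str, prompt: str) -> str:
    """SHA-256 over the model name and the exact prompt the cache was built from."""
    h = hashlib.sha256()
    h.update(model_name.encode("utf-8"))
    h.update(b"\0")  # separator so ("ab", "c") and ("a", "bc") hash differently
    h.update(prompt.encode("utf-8"))
    return h.hexdigest()


def write_cache_metadata(path: str, model_name: str, prompt: str, origin_len: int) -> None:
    """Hypothetical sidecar file written next to the .cache file."""
    meta = {"model_name": model_name,
            "fingerprint": cache_fingerprint(model_name, prompt),
            "origin_len": origin_len}
    with open(path, "w", encoding="utf-8") as f:
        json.dump(meta, f)


def cache_is_current(path: str, model_name: str, prompt: str) -> bool:
    """True only if the sidecar exists and matches the current model and prompt."""
    if not os.path.exists(path):
        return False
    with open(path, encoding="utf-8") as f:
        meta = json.load(f)
    return meta.get("fingerprint") == cache_fingerprint(model_name, prompt)


with tempfile.TemporaryDirectory() as tmp:
    meta_path = os.path.join(tmp, "ronan_knowledge.json")
    write_cache_metadata(meta_path, "mistralai/Mistral-7B-Instruct-v0.1", "example prompt", 42)
    print(cache_is_current(meta_path, "mistralai/Mistral-7B-Instruct-v0.1", "example prompt"))  # True
    print(cache_is_current(meta_path, "mistralai/Mistral-7B-Instruct-v0.1", "edited prompt"))   # False
```

In the notebook, such a check would be called with `model_name` and `system_prompt` alongside `torch.save` and `torch.load`, rebuilding the cache with `get_kv_cache` whenever it returns `False`.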


In [11]:
# Reload the cache from disk to show the context persists across sessions
loaded_cache = torch.load(os.path.join(cache_dir, "ronan_knowledge.cache"))

question3 = "What technologies has he worked with?"
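# add_special_tokens=False: don't insert a second BOS token after the cached prefix
input_ids_q3 = tokenizer(question3 + "\n", return_tensors="pt", add_special_tokens=False).input_ids.to(device)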
gen_ids_q3 = generate(model, input_ids_q3, loaded_cache)
answer3 = tokenizer.decode(gen_ids_q3[0], skip_special_tokens=True)
print("Q3:", question3)
print("A3:", answer3)

Q3: What technologies has he worked with?
A3: Answer: Python, TypeScript, Rust, Java, Shell, SQL, React, NodeJS, MongoDB, Docker, Kubernetes, AWS, GCP, Firebase, OpenCV, GraphQL.


### Done!
This minimal notebook **preloads** the document into the KV cache once, then answers every query from the **same** cached context, including after saving the cache to disk and loading it back.