Source: https://medium.com/@ronantech/cache-augmented-generation-cag-in-llms-a-step-by-step-tutorial-6ac35d415eec

Source: https://colab.research.google.com/drive/1-0eKIu6cGAZ47ROKQaF6EU-mHtvJBILV?usp=sharing

Source: https://arxiv.org/pdf/2412.15605v1

In [9]:
from IPython.display import Markdown
import torch, transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os

In [11]:
print(f"torch version: {torch.__version__}")
print(f"transformers version: {transformers.__version__}")

torch version: 2.5.1+cu121
transformers version: 4.47.1


In [2]:
def generate(model, input_ids: torch.Tensor, past_key_values, max_new_tokens: int = 512) -> torch.Tensor:
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)

            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    return output_ids[:, origin_len:]

In [3]:
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])

def get_kv_cache(model, tokenizer, prompt: str) -> DynamicCache:
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()

    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache

def clean_up(cache: DynamicCache, origin_len: int):
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]

In [4]:
from google.colab import userdata
HF_TOKEN = userdata.get('HF_TOKEN')

In [5]:
# model_name = "mistralai/Mistral-7B-Instruct-v0.1"
# model_name = "openai-community/gpt2"
model_name = 'meta-llama/Llama-3.2-1B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
    token=HF_TOKEN
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Loaded {model_name}.")

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

Loaded meta-llama/Llama-3.2-1B-Instruct.


In [6]:
with open("/content/Licence to Think - Abstract.txt", "r", encoding="utf-8") as f:
    doc_text = f.read()

system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()

ronan_cache = get_kv_cache(model, tokenizer, system_prompt)
origin_len = ronan_cache.key_cache[0].shape[-2]
print("KV cache built.")

KV cache built.


In [7]:
question1 = "Tell about author of the paper"
clean_up(ronan_cache, origin_len)
input_ids_q1 = tokenizer(question1 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q1 = generate(model, input_ids_q1, ronan_cache, max_new_tokens = 128)
answer1 = tokenizer.decode(gen_ids_q1[0], skip_special_tokens=True)
# print("Q1:", question1)
# print("A1:", answer1)
Markdown(answer1)

Dipankar Porey
Dipankar.Porey@in.ey.com
LinkedIn: linkedln.com/in/dipankarporey1996
GitHub: github.com/dipankarporey1996
Dipankar Porey is an assistant to Ernst & Young LLP.assistant

The author of the paper is Dipankar Porey.

In [8]:
question1 = "summarize the document"
clean_up(ronan_cache, origin_len)
input_ids_q1 = tokenizer(question1 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q1 = generate(model, input_ids_q1, ronan_cache, max_new_tokens = 128)
answer1 = tokenizer.decode(gen_ids_q1[0], skip_special_tokens=True)
# print("Q1:", question1)
# print("A1:", answer1)
Markdown(answer1)

The document discusses the challenges of training Generative Adversarial Networks (GANs) and proposes a novel approach to address these challenges. The author introduces a concept of "autonomous control" in GANs, where the model can independently orchestrate its thought processes, influencing the updates to node weights. The document explores the importance of conditions in the context of Conditional Generative Adversarial Networks (CGANs) and proposes a dynamic approach to capturing internal patterns within the generator and discriminator. The author questions how conditions aid in capturing internal patterns and how they contribute to robust and independent learning patterns for both the generator and discriminator.

Answer: