<a href="https://colab.research.google.com/github/Amirosimani/cache-augmented-generation-gemma/blob/main/cache_aug_generation__gemini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


|||
|----------|-------------|
| Author(s)   | amirimani@ |
| Last updated | 8/01/2025 |
<br><br>


# [Don't Do RAG: When Cache-Augmented Generation is All You Need for Knowledge Tasks](https://arxiv.org/abs/2412.15605v1)


This paper argues that Cache-Augmented Generation (CAG) removes the need for real-time retrieval in large language models. All relevant documents are preloaded into the model's long context window, and their key-value (KV) cache is precomputed once and reused at inference time, which overcomes the limitations of traditional Retrieval-Augmented Generation (RAG).

### Concepts
**key-value cache**: The key-value (KV) cache acts as a memory bank for autoregressive generative models. For every previously processed token, it stores the key and value vectors produced by each self-attention layer, so later tokens can attend to them without recomputing them.
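
To make the idea concrete, here is a minimal NumPy sketch for a single attention head in a single layer (toy random weights, not a real model; all names are illustrative). Each token's key and value are computed once and appended to a cache, and every later token attends over the cached entries instead of recomputing them. The result matches recomputing causal attention from scratch.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # toy head dimension (assumption: one head, one layer)
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

def attend(q, K, V):
    """Scaled dot-product attention of query rows q over keys K and values V."""
    scores = q @ K.T / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

def full_causal_attention(x):
    """No cache: recompute keys and values for the whole prefix at every position."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    return np.stack([attend(Q[t:t + 1], K[:t + 1], V[:t + 1])[0] for t in range(len(x))])

def cached_attention(x):
    """Process tokens one at a time, storing each token's key/value in a cache."""
    k_cache, v_cache, outputs = [], [], []
    for token in x:
        q = token @ W_q
        k_cache.append(token @ W_k)  # computed once per token...
        v_cache.append(token @ W_v)
        # ...and reused by every later token instead of being recomputed
        outputs.append(attend(q[None, :], np.array(k_cache), np.array(v_cache))[0])
    return np.stack(outputs)

x = rng.standard_normal((5, d))  # embeddings of 5 toy tokens
print(np.allclose(full_causal_attention(x), cached_attention(x)))  # True
```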

## Benefits of CAG
* Eliminates Real-time Retrieval: Preloads all necessary documents directly into the language model's context.
* Boosts Efficiency: Employs a precomputed cache to accelerate response times.
* Streamlines Architecture: Removes the need for separate retrieval systems, simplifying the overall process.

## Steps:

**Preloading External Knowledge:**
* Preprocess a collection of documents relevant to the application.
* Encode these documents into a KV cache, which captures the inference state of the LLM.
* Store the KV cache on disk or in memory for reuse during inference.
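
Persisting the cache amounts to saving each layer's key and value tensors. The sketch below assumes the cache is held as per-layer NumPy arrays shaped `(batch, heads, seq_len, head_dim)`, the layout used by the per-layer tensors of a `DynamicCache`; `save_kv_cache` and `load_kv_cache` are hypothetical helpers. With the PyTorch tensors of a real cache, one would likely use `torch.save`/`torch.load` instead of `np.savez`/`np.load`.

```python
import os
import tempfile
import numpy as np

def save_kv_cache(path, key_cache, value_cache):
    """Save per-layer key/value arrays to one .npz file (hypothetical helper)."""
    arrays = {}
    for i, (k, v) in enumerate(zip(key_cache, value_cache)):
        arrays[f"key_{i}"] = k
        arrays[f"value_{i}"] = v
    np.savez(path, **arrays)

def load_kv_cache(path):
    """Load per-layer key/value arrays written by save_kv_cache."""
    with np.load(path) as data:
        n_layers = len(data.files) // 2
        keys = [data[f"key_{i}"] for i in range(n_layers)]
        values = [data[f"value_{i}"] for i in range(n_layers)]
    return keys, values

rng = np.random.default_rng(0)
# Toy cache: 2 layers, each (batch=1, heads=4, seq_len=10, head_dim=16)
keys = [rng.standard_normal((1, 4, 10, 16)).astype(np.float16) for _ in range(2)]
values = [rng.standard_normal((1, 4, 10, 16)).astype(np.float16) for _ in range(2)]

path = os.path.join(tempfile.mkdtemp(), "kv_cache.npz")
save_kv_cache(path, keys, values)                 # preload once...
loaded_keys, loaded_values = load_kv_cache(path)  # ...reuse in later sessions
```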

**Inference:**
* Load the precomputed KV cache and append the user's query to it, so the query and the preloaded documents form a single, unified prompt.
* The LLM answers the query from the preloaded knowledge, without retrieving anything at query time.
  * A stopping criterion and a repetition penalty keep the output from repeating itself.
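
A toy NumPy sketch of why appending the query to the cache is equivalent to a unified prompt (single head, single layer, random weights; all names are illustrative). The document keys and values are computed once. Each query's tokens are appended after them, with positions continuing from the end of the prefix and a causal mask inside the query, and the output matches recomputing attention over the concatenated documents and query. The cached prefix itself is never modified, so it can serve any number of queries.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 8  # toy head dimension (assumption: one head, one layer)
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

def causal_attention(q, K, V, offset):
    """Query row i sits at absolute position offset + i and may only see
    keys at positions <= offset + i."""
    scores = q @ K.T / np.sqrt(d)
    q_pos = offset + np.arange(len(q))[:, None]
    k_pos = np.arange(len(K))[None, :]
    scores = np.where(k_pos <= q_pos, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

# Preloading: keys/values for the document prefix are computed once
docs = rng.standard_normal((12, d))
prefix_K, prefix_V = docs @ W_k, docs @ W_v

def answer_step(query):
    """Append the query's keys/values after the cached prefix and attend over both."""
    K = np.concatenate([prefix_K, query @ W_k])
    V = np.concatenate([prefix_V, query @ W_v])
    return causal_attention(query @ W_q, K, V, offset=len(prefix_K))

def full_recompute(query):
    """Baseline without a cache: process documents + query from scratch."""
    x = np.concatenate([docs, query])
    out = causal_attention(x @ W_q, x @ W_k, x @ W_v, offset=0)
    return out[len(docs):]  # outputs at the query positions only

q1, q2 = rng.standard_normal((3, d)), rng.standard_normal((5, d))
print(np.allclose(answer_step(q1), full_recompute(q1)))  # True
```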

**Cache Reset:**
* After each query, truncate the tokens that the query and its answer appended to the cache, restoring it to the preloaded context without recomputing the documents.
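
A minimal sketch of the reset, assuming per-layer key/value arrays shaped `(batch, heads, seq_len, head_dim)` (toy NumPy data, hypothetical helper names). It performs the same truncation as the notebook's `clean_up` helper: the positions added by a query and its answer are dropped, and the preloaded context is left untouched.

```python
import numpy as np

def append_tokens(key_cache, value_cache, new_k, new_v):
    """Append new per-layer keys/values along the sequence axis (axis 2)."""
    for i in range(len(key_cache)):
        key_cache[i] = np.concatenate([key_cache[i], new_k[i]], axis=2)
        value_cache[i] = np.concatenate([value_cache[i], new_v[i]], axis=2)

def reset_cache(key_cache, value_cache, origin_len):
    """Keep only the first origin_len positions (the preloaded context)."""
    for i in range(len(key_cache)):
        key_cache[i] = key_cache[i][:, :, :origin_len, :]
        value_cache[i] = value_cache[i][:, :, :origin_len, :]

rng = np.random.default_rng(0)
n_layers, shape = 2, (1, 4, 10, 16)  # (batch, heads, seq_len, head_dim)
keys = [rng.standard_normal(shape) for _ in range(n_layers)]
values = [rng.standard_normal(shape) for _ in range(n_layers)]
context_keys = [k.copy() for k in keys]  # snapshot of the preloaded context
origin_len = keys[0].shape[2]            # 10 context positions

# A query plus its generated answer adds 7 positions to every layer
extra_k = [rng.standard_normal((1, 4, 7, 16)) for _ in range(n_layers)]
extra_v = [rng.standard_normal((1, 4, 7, 16)) for _ in range(n_layers)]
append_tokens(keys, values, extra_k, extra_v)
print(keys[0].shape)  # (1, 4, 17, 16)

reset_cache(keys, values, origin_len)
print(keys[0].shape)  # (1, 4, 10, 16)
```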


## Technical considerations

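* Quantization: follows the approach in the Hugging Face blog post on KV cache quantization ([link](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)), where a 4-bit, quanto-backed cache is requested through `model.generate(..., cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4})`.
* Stopping condition: to address repetitive output, I added:
  - a stopping criterion (`stop_token`) that terminates generation when a specific pattern appears in the decoded output, such as a newline or an end-of-answer marker;
  - a repetition penalty (`repetition_penalty` > 1.0) that penalizes tokens that have already appeared.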
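
For intuition about what a 4-bit KV cache stores, here is a simplified NumPy sketch of asymmetric min-max quantization applied to a toy key tensor in groups of 16 values. It is an illustration under stated assumptions, not the quanto backend's actual algorithm: real implementations differ in details such as grouping axes and pack two 4-bit codes per byte, whereas this sketch keeps one code per `uint8`. The function names are hypothetical. Each value is reconstructed to within half a quantization step of its group.

```python
import numpy as np

def quantize_4bit(x, group_size=16):
    """Asymmetric min-max quantization to codes 0..15, per group of `group_size`
    consecutive values along the last axis (which must be divisible by group_size)."""
    groups = x.reshape(*x.shape[:-1], -1, group_size)
    lo = groups.min(axis=-1, keepdims=True)
    hi = groups.max(axis=-1, keepdims=True)
    scale = (hi - lo) / 15.0
    scale = np.where(scale == 0, 1.0, scale)  # constant groups: avoid division by zero
    codes = np.clip(np.round((groups - lo) / scale), 0, 15).astype(np.uint8)
    return codes, scale, lo

def dequantize_4bit(codes, scale, lo, shape):
    """Map 4-bit codes back to approximate float values."""
    return (codes * scale + lo).reshape(shape)

rng = np.random.default_rng(0)
# Toy key tensor: (batch, heads, seq_len, head_dim)
keys = rng.standard_normal((1, 4, 10, 64)).astype(np.float32)
codes, scale, lo = quantize_4bit(keys)
approx = dequantize_4bit(codes, scale, lo, keys.shape)
max_err = np.abs(approx - keys).max()  # bounded by half the largest step
```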


**Depending on your choice of model, use a GPU runtime.**

In [None]:
!pip install -U --quiet bitsandbytes accelerate quanto
!pip install --quiet transformers


In [None]:
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    DynamicCache
)
import torch

import os
import pandas as pd
from typing import Optional, Union
from google.colab import userdata

# utility functions

In [None]:
def generate(
    model: PreTrainedModel,
    input_ids: torch.Tensor,
    past_key_values: Optional[DynamicCache] = None,
    max_new_tokens: int = 50,
    repetition_penalty: float = 1.0,
    stop_token: Optional[str] = None,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    top_p: float = 0.9,
    temperature: float = 1.0
) -> torch.Tensor:
    """
    Generate tokens from a model using nucleus sampling with past key-value caching.

    Note: `past_key_values` is updated in place, so afterwards it also holds the prompt
    and generated tokens; use `clean_up` to restore it to the preloaded context.

    Args:
        model (PreTrainedModel): The language model for generation.
        input_ids (torch.Tensor): Input token IDs to begin generation.
        past_key_values (Optional[DynamicCache]): Cached key-value pairs for faster inference.
        max_new_tokens (int): Maximum number of new tokens to generate.
        repetition_penalty (float): Penalty for previously seen tokens (1.0 disables it).
        stop_token (Optional[str]): A string that stops generation once it appears in the output.
        tokenizer (Optional[PreTrainedTokenizer]): Tokenizer used to decode tokens for stop_token.
        top_p (float): Probability mass for nucleus sampling.
        temperature (float): Sampling temperature.

    Returns:
        torch.Tensor: The generated tokens, excluding the input prompt.
    """
    device = model.device
    input_ids = input_ids.to(device)
    prompt_len = input_ids.shape[-1]  # saved now, because input_ids is overwritten in the loop
    output_ids = input_ids.clone()
    generated_text = ""

    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )
            logits = outputs.logits[:, -1, :].float()

            # Apply repetition penalty (CTRL-style): divide positive logits and multiply
            # negative ones, so every previously seen token becomes less likely
            if repetition_penalty != 1.0:
                seen = torch.gather(logits, 1, output_ids)
                seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                logits = logits.scatter(1, output_ids, seen)

            # Apply temperature scaling
            logits = logits / temperature

            # Apply nucleus (top-p) sampling
            probabilities = torch.softmax(logits, dim=-1)
            sorted_probs, sorted_indices = torch.sort(probabilities, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses top_p is kept; always keep the top token
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False
            # Map the mask from sorted order back to vocabulary order (per batch row)
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            probabilities = probabilities.masked_fill(indices_to_remove, 0.0)
            probabilities = probabilities / probabilities.sum(dim=-1, keepdim=True)

            next_token = torch.multinomial(probabilities, num_samples=1)
            output_ids = torch.cat([output_ids, next_token], dim=-1)
            past_key_values = outputs.past_key_values
            input_ids = next_token

            # Stop once the decoded output contains stop_token
            if tokenizer and stop_token:
                generated_text += tokenizer.decode(next_token[0], skip_special_tokens=True)
                if stop_token in generated_text:
                    break

            # Stop generation if the EOS token is encountered (assumes batch size 1)
            if model.config.eos_token_id is not None and next_token.item() == model.config.eos_token_id:
                break

    return output_ids[:, prompt_len:]



def get_kv_cache(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str) -> DynamicCache:
    """
    Run the prompt through the model once and return the resulting key-value cache.

    Args:
        model (PreTrainedModel): The language model to use.
        tokenizer (PreTrainedTokenizer): Tokenizer to encode the prompt.
        prompt (str): The prompt text.

    Returns:
        DynamicCache: A cache containing key-value tensors for faster inference.
    """
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Note: `cache_implementation="quantized"` and `cache_config` are options of
    # `model.generate()`, not of the forward pass. Passing them to `model(...)` does not
    # turn this DynamicCache into a quantized one, so a regular cache is built here.
    cache = DynamicCache()
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,  # filled in place with the prompt's keys/values
            use_cache=True,
        )
    return cache

def clean_up(cache: DynamicCache, origin_len: int):
    """
    Reset the cache to the preloaded context by discarding every position after
    origin_len (the query and generated tokens appended by `generate`).

    Args:
        cache (DynamicCache): Cache object containing key and value tensors.
        origin_len (int): The original sequence length to truncate to.
    """
    # Same effect as slicing every layer's key/value tensors to [:, :, :origin_len, :]
    cache.crop(origin_len)
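
The two sampling adjustments can be inspected in isolation with NumPy on a single 1-D logit vector. `apply_repetition_penalty` is the sign-aware, CTRL-style rule used by Hugging Face's `RepetitionPenaltyLogitsProcessor` (divide positive logits by the penalty, multiply negative ones), so a previously seen token always becomes less likely; simply dividing a negative logit would instead make it *more* likely. `top_p_filter` uses the same keep rule as the nucleus step in `generate`. Both function names are hypothetical.

```python
import numpy as np

def apply_repetition_penalty(logits, previous_ids, penalty):
    """CTRL-style penalty: for every token id that already appeared, divide a
    positive logit by `penalty` and multiply a negative one by it."""
    logits = logits.copy()
    ids = np.unique(previous_ids)
    seen = logits[ids]
    logits[ids] = np.where(seen > 0, seen / penalty, seen * penalty)
    return logits

def top_p_filter(probs, top_p):
    """Nucleus filter: keep the smallest set of most likely tokens whose cumulative
    probability exceeds top_p, zero out the rest, and renormalize."""
    order = np.argsort(probs)[::-1]  # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    # Keep a token while the mass *before* it is <= top_p; the top token is always kept
    keep_sorted = np.concatenate([[True], cumulative[:-1] <= top_p])
    keep = np.zeros_like(probs, dtype=bool)
    keep[order] = keep_sorted  # back to vocabulary order
    filtered = np.where(keep, probs, 0.0)
    return filtered / filtered.sum()

# Token 0 (positive logit) and token 1 (negative logit) were already generated
penalized = apply_repetition_penalty(np.array([2.0, -1.0, 0.5, 3.0]), [0, 1, 1], penalty=2.0)
filtered = top_p_filter(np.array([0.5, 0.3, 0.15, 0.05]), top_p=0.7)
```

With `[0.5, 0.3, 0.15, 0.05]` and `top_p=0.7`, the first two tokens (cumulative mass 0.8) are kept and renormalized to `[0.625, 0.375, 0, 0]`, while the penalty turns logits `[2.0, -1.0]` into `[1.0, -2.0]`.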

In [None]:
model_name = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )
token = userdata.get('huggingface')

tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    trust_remote_code=True,
    token=token,
    quantization_config=bnb_config,
)

print(f"Loaded {model_name}.")


Loaded google/gemma-2b.


In [None]:
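f_name = "/content/NYSOMIGExclusionsList.xlsx"
df = pd.read_excel(f_name, engine="openpyxl")
text_output = df.to_string(index=False)

# Only the first 10,000 characters of the table are preloaded: rows beyond that are not
# visible to the model, and the cut can fall in the middle of a row.
# <|system|> and <|user|> are plain-text markers here, not Gemma's chat template
# (gemma-2b is a base model, not an instruction-tuned one).
system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{text_output[:10000]}
Question:
""".strip()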

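Slicing `text_output[:10000]` counts characters, so the cut can land in the middle of a table row. A small, hypothetical helper that keeps only whole lines under the same character budget avoids half rows in the preloaded context; a token budget would be more precise but needs the tokenizer.

```python
def truncate_to_whole_lines(text: str, max_chars: int) -> str:
    """Return the longest prefix of `text` that ends on a line boundary and uses
    at most max_chars characters (hypothetical helper)."""
    kept, used = [], 0
    for line in text.splitlines(keepends=True):
        if used + len(line) > max_chars:
            break  # the next row would not fit completely
        kept.append(line)
        used += len(line)
    return "".join(kept).rstrip("\n")

# Hypothetical usage in this notebook: context = truncate_to_whole_lines(text_output, 10_000)
table = "name npi\nA 1\nB 22\nC 333"
print(truncate_to_whole_lines(table, 15))  # keeps "name npi" and "A 1"
```
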
In [None]:
cache = get_kv_cache(model, tokenizer, system_prompt)
print("KV cache built.")

KV cache built.


In [None]:
# Generate answers
question = "What is NPI number for 1 MEDICAL SUPPLIES CORP?"

input_ids_q = tokenizer(question + "\n", return_tensors="pt").input_ids.to(model.device)
gen_ids_q = generate(
    model,
    input_ids_q,
    past_key_values=cache,
    max_new_tokens=50,
    repetition_penalty=1.0,
    stop_token=None,
    tokenizer=tokenizer,
    top_p=0.9,
    temperature=0.8
)
answer = tokenizer.decode(gen_ids_q[0], skip_special_tokens=True)


print("Q:", question)
print("A:", answer)

Q: What is NPI number for 1 MEDICAL SUPPLIES CORP?
A: What is NPI number for 1 MEDICAL SUPPLIES CORP?
Answer: 1407487887



# Chat interface

In [None]:
# !pip install --quiet langchain

In [None]:
# from langchain.llms.base import LLM
# from typing import Optional, List, Any

# class CustomLLM(LLM):
#     def __init__(self, model: Any, tokenizer: Any, cache: Any, **kwargs):
#         super().__init__(**kwargs)
#         self._model = model  # Use private attributes to bypass pydantic validation
#         self._tokenizer = tokenizer
#         self._cache = cache
#         self._origin_len = cache.get_seq_length()  # length of the preloaded context

#     @property
#     def _llm_type(self) -> str:
#         return "custom_llm"

#     def _call(self, prompt: str, stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any) -> str:
#         # Tokenize only the new text; the context is already in the cache (no extra BOS)
#         input_ids = self._tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(self._model.device)

#         # Generate response on top of the preloaded cache
#         gen_ids = generate(self._model, input_ids, past_key_values=self._cache)

#         # Reset the cache to the preloaded context before the next question
#         clean_up(self._cache, self._origin_len)

#         # Decode output
#         response = self._tokenizer.decode(gen_ids[0], skip_special_tokens=True)

#         # Apply stop tokens if provided
#         if stop:
#             for stop_token in stop:
#                 response = response.split(stop_token)[0]
#         return response


In [None]:
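# from langchain.chains import LLMChain
# from langchain.prompts import PromptTemplate

# # The context is already preloaded in the KV cache, so the template carries only the question
# prompt = PromptTemplate.from_template(" {question}\n")

# # Instantiate the LLM
# custom_llm = CustomLLM(model=model, tokenizer=tokenizer, cache=cache)

# # Create the LLMChain
# llm_chain = LLMChain(prompt=prompt, llm=custom_llm)

# # Chat loop (type "exit" to quit)
# while True:
#     user_question = input("You: ")
#     if user_question.lower() == "exit":
#         break
#     response = llm_chain.run(question=user_question)
#     print(f"Assistant: {response}")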



You: What is NPI number for 1 MEDICAL SUPPLIES CORP?




Assistant: <|system|>
You are an assistant who provides concise factual answers based on the context.
<|user|>
Context:
                                                                      provider_name                   license_num    npi_num                                     provider_type exclusion_effective_date
                                                         #1 MARKETING SERVICE, INC.                           NaN        NaN                                 Marketing Service               07/26/2016
                                                            1 MEDICAL SUPPLIES CORP                           NaN 1407487887                            DME & Medical Supplier               04/21/2022
                                                  1 STOP PHARMACY AND FOOD MART INC                      00028701 1275716979                                          Pharmacy               02/03/2009
                                     101 FIRST CARE PHARMACY INC AKA/DBA MI CASA

KeyboardInterrupt: Interrupted by user