# Summary memory usage with CAG
    LLama 8B in 4bit quantization just load model takes - 7Gb
    With 24Gb GPU memory I can use cache with size of 32k tokens (in theory) 

In [1]:
import torch
import os
import json
from time import time
from sentence_transformers import SentenceTransformer, util
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

---
## Load data

In [2]:
with open("./data/beauty-of-ares.txt", 'r', encoding='utf-8') as f:
    book = ''.join(f.readlines())

In [3]:
print(book[:500])

Beauty of Ares
Sleeping Beauty Inc. Book 3

Stephanie Van Orman

Copyright © 2023, 2024 Stephanie Van Orman
All rights reserved. No part of this book may be reproduced or used in any manner without
written permission of the copyright owner except for the use of written quotations in a book
review.
Any reference to historical events, real people or places are used fictitiously. Names, characters,
places are products of the author’s imagination.
Front cover image by Shutterstock
Book design by S


---
# Generate cache
from paper [repo](https://github.com/hhhuang/CAG/blob/main/kvcache.py#L79)

In [4]:
def preprocess_knowledge(
    model,
    tokenizer,
    prompt: str,
) -> DynamicCache:
    """
    Prepare knowledge kv cache for CAG.
    Args:
        model: HuggingFace model with automatic device mapping
        tokenizer: HuggingFace tokenizer
        prompt: The knowledge to preprocess, which is basically a prompt

    Returns:
        DynamicCache: KV Cache
    """
    embed_device = model.model.embed_tokens.weight.device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(embed_device)
    past_key_values = DynamicCache()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False
        )
    return outputs.past_key_values


def write_kv_cache(kv: DynamicCache, path: str):
    """
    Write the KV Cache to a file.
    """
    torch.save(kv, path)

In [5]:
def prepare_kvcache(documents, filepath: str = "./data_cache/cache_knowledges.pt", answer_instruction: str = None):
    # Prepare the knowledges kvcache

    if answer_instruction is None:
        answer_instruction = "Answer the question with detailed answer usin ONLY provided context."
    knowledges = f"""
    <|begin_of_text|>
    <|start_header_id|>system<|end_header_id|>
    You are an assistant for giving short answers based on given context.<|eot_id|>
    <|start_header_id|>user<|end_header_id|>
    Context information is bellow.
    ------------------------------------------------
    {documents}
    ------------------------------------------------
    {answer_instruction}
    Question:
    """
    # Get the knowledge cache
    t1 = time()
    kv = preprocess_knowledge(model, tokenizer, knowledges)
    print("kvlen: ", kv.key_cache[0].shape[-2])
    if filepath:
        write_kv_cache(kv, filepath)
    t2 = time()
    return kv, t2 - t1

In [6]:
# Define quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,              # Load model in 4-bit precision
    bnb_4bit_quant_type="nf4",      # Normalize float 4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # Compute dtype for 4-bit base matrices
    bnb_4bit_use_double_quant=True  # Use nested quantization
)


def load_quantized_model(model_name, hf_token=None):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=hf_token
    )

    # Load model with quantization
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="cuda",         
        trust_remote_code=True,     # Required for some models
        token=hf_token
    )

    return tokenizer, model

In [7]:
tokenizer, model = load_quantized_model("meta-llama/Llama-3.1-8B-Instruct")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [8]:
def get_vram():
    free = torch.cuda.mem_get_info()[0] / 1024 ** 3
    total = torch.cuda.mem_get_info()[1] / 1024 ** 3
    total_cubes = 24
    free_cubes = int(total_cubes * free / total)
    print(f'VRAM: {total - free:.2f}/{total:.2f}GB\t VRAM:[' + (
            total_cubes - free_cubes) * '▮' + free_cubes * '▯' + ']')

In [9]:
get_vram()

VRAM: 6.99/23.99GB	 VRAM:[▮▮▮▮▮▮▮▯▯▯▯▯▯▯▯▯▯▯▯▯▯▯▯▯]


---
# Caching

In [10]:
input_ids = tokenizer.encode(book, return_tensors="pt")
input_ids.shape[1]

111016

In [11]:
book_crop = tokenizer.decode(input_ids[0, :32000])
book_crop = '\n'.join(book_crop.split('\n')[:-1])

In [12]:
knowledge_cache, prepare_time = prepare_kvcache(book_crop, filepath=None, answer_instruction=None)
kv_len = knowledge_cache.key_cache[0].shape[-2]
print(f"KVcache prepared in {prepare_time} seconds")

kvlen:  32061
KVcache prepared in 6.443361282348633 seconds


In [13]:
get_vram()

VRAM: 23.99/23.99GB	 VRAM:[▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮]


In [14]:
torch.cuda.memory_reserved(0) / 1024 ** 3

21.904296875

---
# Ask some questions

In [15]:
def clean_up(kv: DynamicCache, origin_len: int):
    """
    Truncate the KV Cache to the original length.
    """
    for i in range(len(kv.key_cache)):
        kv.key_cache[i] = kv.key_cache[i][:, :, :origin_len, :]
        kv.value_cache[i] = kv.value_cache[i][:, :, :origin_len, :]

def generate(
    model,
    input_ids: torch.Tensor,
    past_key_values,
    max_new_tokens: int = 300
) -> torch.Tensor:
    """
    Generate text with greedy decoding.

    Args:
        model: HuggingFace model with automatic device mapping
        input_ids: Input token ids
        past_key_values: KV Cache for knowledge
        max_new_tokens: Maximum new tokens to generate
    """

    embed_device = model.model.embed_tokens.weight.device

    origin_ids = input_ids
    input_ids = input_ids.to(embed_device)

    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(
                input_ids=next_token, 
                past_key_values=past_key_values,
                use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1).unsqueeze(-1)
            next_token = next_token.to(embed_device)

            past_key_values = outputs.past_key_values

            output_ids = torch.cat([output_ids, next_token], dim=1)

            if next_token.item() in model.config.eos_token_id:
                break
    return output_ids[:, origin_ids.shape[-1]:]

In [16]:
def ask_question(prompt, nocache=False, max_new_tokens=500):
    t1 = time()
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output = generate(model, input_ids, DynamicCache() if nocache else knowledge_cache, max_new_tokens=max_new_tokens)
    t2 = time() - t1
    clean_up(knowledge_cache, kv_len)
    print(f"generate time: {t2:.2f} sec\ntokens/sec: {len(output[0])/t2:.2f}\n\n" + tokenizer.decode(output[0], skip_special_tokens=True, temperature=None))

---

In [17]:
ask_question("What is the story in the book?")

generate time: 0.45 sec
tokens/sec: 2.22




In [18]:
ask_question("Make summary from context")

generate time: 33.20 sec
tokens/sec: 13.55

 of the story so far.

The story so far is about Lisbet, a young woman who was sold to Vantz Bloomburg, a trillionaire on Mars, by her father through the Sleeping Beauty Inc. agency. Lisbet is struggling to adjust to her new life on Mars and her role as Vantz's wife. She is frustrated with her situation and feels like she is being treated like a possession rather than a person. She is also struggling to understand the true nature of her husband's business dealings and the circumstances surrounding her arrival on Mars.

Lisbet is trying to navigate her new life on Mars, but she is finding it difficult to connect with her husband, Vantz, who is distant and uncommunicative. She is also struggling to understand the true nature of her role as a public relations officer for Vantz's terraforming project. She is frustrated with the lack of transparency and the secrecy surrounding the project.

Lisbet is also struggling to connect with her new environ

In [19]:
ask_question("Who is the main 10 persons in book?")

generate time: 0.18 sec
tokens/sec: 5.54




In [20]:
ask_question("Who is the Vantz Bloomburg?")

generate time: 0.18 sec
tokens/sec: 5.67




In [21]:
ask_question("What is Sleeping Beauty?")

generate time: 0.19 sec
tokens/sec: 5.39




In [22]:
ask_question("question: Where lived Lisbet?")

generate time: 0.19 sec
tokens/sec: 5.15


