In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os

In [2]:
def generate(model, input_ids: torch.Tensor, past_key_values, max_new_tokens: int = 50) -> torch.Tensor:
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)

            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    return output_ids[:, origin_len:]

In [3]:
def get_kv_cache(model, tokenizer, prompt: str) -> DynamicCache:
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()

    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache

def clean_up(cache: DynamicCache, origin_len: int):
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]

In [None]:
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name, token="", trust_remote_code=True) #INSERT HF TOKEN IN ""
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
    token="" #INSERT HF TOKEN IN ""
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Loaded {model_name}.")

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   7%|6         | 682M/9.94G [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs.hf.co/repos/ea/00/ea00943d992c7851ad9f4f4bd094a0397fb5087e0f7cba4ef003018963ea07e3/a464228d9c9427bf035cfd5a2e18bf0494d7231bfacc478ec5fc9f57a612d051?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model-00001-of-00002.safetensors%3B+filename%3D%22model-00001-of-00002.safetensors%22%3B&Expires=1743016682&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0MzAxNjY4Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9lYS8wMC9lYTAwOTQzZDk5MmM3ODUxYWQ5ZjRmNGJkMDk0YTAzOTdmYjUwODdlMGY3Y2JhNGVmMDAzMDE4OTYzZWEwN2UzL2E0NjQyMjhkOWM5NDI3YmYwMzVjZmQ1YTJlMThiZjA0OTRkNzIzMWJmYWNjNDc4ZWM1ZmM5ZjU3YTYxMmQwNTE%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=U5HFtYFbTXTcHBsLBGcXjFqZfezcWqwUYjFJ1fJJnawNWts2yXH6APxLoVvNBdGFobLSGHv%7E3MOZoZVuwY3mfxX8BA080X0tSrflV22Nh-ICTOB6ncRsN%7E5S9K%7E3dP2bDdnkgV4pdme%7Errf6uiyn5oRUxwqdmjnfA8FF54qOZjbyNjW47QGXzzXfr3sa1C62C2RBQc4pCHBBcJ7apOXJvpHFTfPnc1ntfZvpF

model-00001-of-00002.safetensors:  32%|###2      | 3.20G/9.94G [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Some parameters are on the meta device because they were offloaded to the disk and cpu.
You shouldn't move a model that is dispatched using accelerate hooks.


RuntimeError: You can't move a model that has some modules offloaded to cpu or disk.

In [None]:
import time

In [None]:
stime = time.time()
with open("Jon.txt", "r", encoding="utf-8") as f:
    doc_text = f.read()

system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers. You will read the text file
and respond with a max of 3 sentences. If you do not know, say "I don't know".
<|user|>
Context:
{doc_text}
Question:
""".strip()

ronan_cache = get_kv_cache(model, tokenizer, system_prompt)
origin_len = ronan_cache.key_cache[0].shape[-2]
print("KV cache built.")
etime = time.time()
ttime = etime - stime
print(f"total time: {ttime:.4f} seconds")

In [None]:
stime = time.time()
question1 = "What is the index 1 called?" #Change question to user input
clean_up(ronan_cache, origin_len)
input_ids_q1 = tokenizer(question1 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q1 = generate(model, input_ids_q1, ronan_cache)
answer1 = tokenizer.decode(gen_ids_q1[0], skip_special_tokens=True)
print("Q1:", question1)
print("A1:", answer1)

etime = time.time()
ttime = etime - stime
print(f"total time: {ttime:.4f} seconds")

In [None]:
stime = time.time()
question2 = "What does Jonathan like?" #Change question to user input
clean_up(ronan_cache, origin_len)
input_ids_q2 = tokenizer(question2 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q2 = generate(model, input_ids_q2, ronan_cache)
answer2 = tokenizer.decode(gen_ids_q2[0], skip_special_tokens=True)
print("Q2:", question2)
print("A2:", answer2)

etime = time.time()
ttime = etime - stime
print(f"total time: {ttime:.4f} seconds")

In [None]:
stime = time.time()
question3 = "What is the document about?" #Change question to user input
clean_up(ronan_cache, origin_len)
input_ids_q3 = tokenizer(question3 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q3 = generate(model, input_ids_q3, ronan_cache)
answer3 = tokenizer.decode(gen_ids_q3[0], skip_special_tokens=True)
print("Q3:", question3)
print("A3:", answer3)

etime = time.time()
ttime = etime - stime
print(f"total time: {ttime:.4f} seconds")

In [None]:
stime = time.time()
question4 = "Does Jonathan like Magic the Gathering?" #Change question to user input
clean_up(ronan_cache, origin_len)
input_ids_q4 = tokenizer(question4 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q4 = generate(model, input_ids_q4, ronan_cache)
answer4 = tokenizer.decode(gen_ids_q4[0], skip_special_tokens=True)
print("Q4:", question4)
print("A4:", answer4)

etime = time.time()
ttime = etime - stime
print(f"total time: {ttime:.4f} seconds")