# Efficiently serving Large Language Models in 2bit with `aqlm` and `transformers` compiled into a CUDA graph

<a target="_blank" href="https://colab.research.google.com/github/Vahe1994/AQLM/blob/main/notebooks/aqlm_cuda_graph.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook walks through the recent `aqlm` integration, which enables efficient GPU utilization when serving LLMs quantized to 2bit.

We will learn how to load a large model in 2bit (`Llama-2-7b`) and compile a CUDA graph of it to circumvent Python overhead when serving the model.


**Install the `aqlm` library**
- It is the only extra dependency needed to run AQLM models.
- Add `[gpu]` to install the required CUDA-specific dependencies.
- To use convenient features like `device_map`, you'll also need to install `accelerate`. Proper AQLM support requires a version that includes [PR#2376](https://github.com/huggingface/accelerate/pull/2376).

In [None]:
%%capture
!pip install "aqlm[gpu]>=1.1.0"
!pip install "accelerate>=0.27.0"
!pip install git+https://github.com/huggingface/transformers.git@main
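
Because `%%capture` hides pip's output, it can help to confirm which versions actually ended up in the environment. The following is a minimal sketch using only the standard library; the helper name `installed_version` is ours for illustration and is not part of any of these packages.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Report what pip actually installed (hypothetical helper, for illustration only).
for name in ("aqlm", "accelerate", "transformers"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

If a package shows up as `not installed`, rerunning the install cell without `%%capture` should reveal the underlying pip error.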

**Load the model as usual**

The tokenizer is just a normal `Llama 2` tokenizer.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

quantized_model = AutoModelForCausalLM.from_pretrained(
    "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
    torch_dtype="auto", device_map="auto", low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf")

Do a few forward passes to initialize CUDA and trigger automatic kernel compilation. This is done separately so that it does not affect the generation speed benchmark below.

In [None]:
%%capture
output = quantized_model.generate(tokenizer("", return_tensors="pt")["input_ids"].cuda(), max_new_tokens=10)

**Measure generation speed**

In [None]:
import time

In [None]:
start = time.perf_counter()
output = quantized_model.generate(tokenizer("I'm AQLM, ", return_tensors="pt")["input_ids"].cuda(), min_new_tokens=128, max_new_tokens=128)
end = time.perf_counter()

In [None]:
print(f"Generating at {128 / (end - start):.1f} tok/s")

Generating at 8.5 tok/s
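
The pattern used above (an untimed warm-up run followed by a timed run) can be wrapped in a small helper. This is only a sketch: `tokens_per_second` and `fake_generate` are hypothetical names, and the sleep-based workload merely stands in for a model whose first call is slower than later ones.

```python
import time


def tokens_per_second(step, n_tokens, warmup=1):
    """Time one call of `step()` (producing `n_tokens` tokens) after `warmup` untimed calls."""
    for _ in range(warmup):
        step()  # untimed: absorbs one-off costs such as kernel compilation
    start = time.perf_counter()
    step()
    elapsed = time.perf_counter() - start
    return n_tokens / elapsed


# Stand-in workload: a hypothetical "generator" whose first call is slow.
calls = []


def fake_generate():
    time.sleep(0.05 if not calls else 0.01)
    calls.append(1)


rate = tokens_per_second(fake_generate, n_tokens=128)
print(f"Generating at {rate:.1f} tok/s")
```

When timing hand-written GPU loops, keep in mind that CUDA work is launched asynchronously, so calling `torch.cuda.synchronize()` before reading the clock is a common precaution.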


**Check that the output is what one would expect from Llama**

In [None]:
print(tokenizer.decode(output[0]))

<s> I'm AQLM, 20, and I'm from the UK. I'm a student at the University of Nottingham, studying English and Creative Writing. I'm a huge fan of the Harry Potter series, and I'm also a huge fan of the Marvel Cinematic Universe. I'm also a huge fan of the DC Extended Universe, and I'm also a huge fan of the Star Wars franchise. I'm also a huge fan of the Marvel Cinematic Universe, and I'm also a huge fan of the DC Extended Universe, and I'm also a huge fan


### Compile a CUDA graph

Note that `transformers` generation is not the fastest implementation, and its speed is heavily influenced by the CPU capabilities of _Google Colab_. We'll address this by using a static cache and compiling the model's forward pass into a homogeneous CUDA graph, effectively removing Python's overhead.
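
A CUDA graph replays a recorded sequence of kernel launches, so the buffers it touches need to keep the same shapes and memory locations between replays. That is the motivation for a static cache: it is allocated once at its maximum size and then written in place. The toy class below illustrates the idea in plain Python; `ToyStaticCache` is a hypothetical stand-in and not the `StaticCache` class from `transformers`.

```python
class ToyStaticCache:
    """Fixed-size key/value store: allocated once, then written in place by position."""

    def __init__(self, max_cache_len):
        self.keys = [None] * max_cache_len
        self.values = [None] * max_cache_len

    def update(self, position, key, value):
        # In-place write: the buffers keep their length and identity.
        self.keys[position] = key
        self.values[position] = value


cache = ToyStaticCache(max_cache_len=8)
buffer_id = id(cache.keys)
for pos in range(5):
    cache.update(pos, key=f"k{pos}", value=f"v{pos}")

# The buffer never grew and is still the same object.
print(len(cache.keys), id(cache.keys) == buffer_id)  # 8 True
```

A cache that grows by appending one entry per generated token would change shape at every step, which is what a recorded graph cannot accommodate.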

**We'll have to implement the logic around forward passes on our own since CUDA graphs are not yet integrated into transformers**

In [None]:
import torch

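def decode_one_tokens(model, cur_token, input_pos, cache_position):
    # `input_pos` is unused; it is kept so the call sites below stay unchanged.
    logits = model(
        cur_token, position_ids=None, cache_position=cache_position, return_dict=False, use_cache=True
    )[0]
    # Greedy decoding: take the most likely token at the last position.
    new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)

    return new_token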

MAX_NEW_TOKENS = 128
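
The selection step in `decode_one_tokens` is plain greedy decoding: take the logits at the last position and pick the id with the highest score. Indexing with `[-1]` (a list) rather than `-1` keeps the sequence dimension, so the result has shape `(batch, 1)`. The sketch below mimics that with nested lists; `greedy_next_token` and the toy numbers are made up for illustration.

```python
def greedy_next_token(logits):
    """Pick the highest-scoring token id at the last position of each sequence.

    `logits` is a nested list shaped [batch][seq_len][vocab_size].
    Returns a list shaped [batch][1], mirroring `torch.argmax(logits[:, [-1]], dim=-1)`.
    """
    next_tokens = []
    for sequence in logits:
        last_position = sequence[-1]
        best_id = max(range(len(last_position)), key=last_position.__getitem__)
        next_tokens.append([best_id])
    return next_tokens


# Toy logits: batch of 1, two positions, vocabulary of 4 (made-up numbers).
toy_logits = [[[0.1, 2.0, -1.0, 0.3],
               [0.5, 0.2, 3.1, -0.7]]]
print(greedy_next_token(toy_logits))  # [[2]]
```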

**Set up a static KV cache for generation**

In [None]:
from transformers import StaticCache

input_ids = tokenizer("I'm AQLM, ", return_tensors="pt").to("cuda")["input_ids"]
seq_length = input_ids.shape[1]
quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + MAX_NEW_TOKENS * 2 + 1)

**Allocate token ids to be generated and copy prefix ids**

In [None]:
cache_position = torch.arange(seq_length, device="cuda")
generated_ids = torch.zeros(1, seq_length + MAX_NEW_TOKENS * 2, dtype=torch.int, device="cuda")
generated_ids[:, cache_position] = input_ids.to("cuda").to(torch.int)

**Do a forward pass to fill the prefix cache and compile the kernels if necessary**

In [None]:
logits = quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token

**Compile the CUDA graph with `torch.compile` and apply the forward pass repeatedly to generate text**

In [None]:
with torch.no_grad():
    # Compile the CUDA graph
    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

    # Generate tokens one by one
    cache_position = torch.tensor([seq_length + 1], device="cuda")
    for _ in range(1, MAX_NEW_TOKENS):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position)
            generated_ids.index_copy_(1, cache_position, next_token)
        cache_position += 1

In [None]:
print(tokenizer.decode(generated_ids[0]))

<s> I'm AQLM, 20, and I'm from the UK. I'm a student at the University of Nottingham, studying English and Creative Writing. I'm a huge fan of the Harry Potter series, and I'm also a huge fan of the Marvel Cinematic Universe. I'm also a huge fan of the DC Extended Universe, and I'm also a huge fan of the Star Wars franchise. I'm also a huge fan of the Marvel Cinematic Universe, and I'm also a huge fan of the DC Extended Universe, and I'm also a huge fan<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><un
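
The trailing `<unk>` tokens are expected. `generated_ids` was zero-initialized with room for `2 * MAX_NEW_TOKENS` new tokens, only the first half has been written so far, and id 0 decodes to `<unk>` in the Llama 2 tokenizer. The sketch below replays the index bookkeeping in plain Python; `seq_length = 10` is a hypothetical prompt length chosen only for illustration.

```python
MAX_NEW_TOKENS = 128
seq_length = 10  # hypothetical prompt length; the real value depends on the tokenizer

# Mirror of the notebook's buffer: True marks a slot that has been written.
written = [False] * (seq_length + MAX_NEW_TOKENS * 2)

# Prompt tokens, then the first token produced by the prefill forward pass.
for pos in range(seq_length):
    written[pos] = True
written[seq_length] = True

# Compilation loop: range(1, MAX_NEW_TOKENS) writes the next 127 slots.
cache_position = seq_length + 1
for _ in range(1, MAX_NEW_TOKENS):
    written[cache_position] = True
    cache_position += 1

new_tokens_so_far = sum(written) - seq_length
unwritten = written.count(False)
print(new_tokens_so_far, unwritten)  # 128 128

# Timed loop: 128 more steps fill the rest of the buffer exactly.
for _ in range(MAX_NEW_TOKENS):
    written[cache_position] = True
    cache_position += 1
print(all(written), cache_position == len(written))  # True True
```

So at this point the buffer holds 128 generated tokens followed by 128 untouched zeros, and the timed loop that follows fills the remaining slots without running past the end.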

**Continue the generation while measuring the speed**

In [None]:
start = time.perf_counter()
with torch.no_grad():
    for _ in range(MAX_NEW_TOKENS):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position)
            generated_ids.index_copy_(1, cache_position, next_token)
        cache_position += 1
end = time.perf_counter()

In [None]:
print(f"Generating at {128 / (end - start):.1f} tok/s")

Generating at 24.0 tok/s


We achieved a speedup of roughly **2.8x** over normal generation using CUDA graphs (from 8.5 to 24.0 tok/s), and the generated text matches the baseline output apart from the not-yet-filled `<unk>` padding, as it should.
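
The speedup can be checked directly from the two throughput numbers reported in this notebook. The short calculation below also converts them to per-token latency, which is often easier to reason about when serving; the variable names are ours.

```python
baseline_tok_s = 8.5   # measured with `generate` in this notebook
compiled_tok_s = 24.0  # measured with the compiled decoding loop

speedup = compiled_tok_s / baseline_tok_s
baseline_ms = 1000 / baseline_tok_s
compiled_ms = 1000 / compiled_tok_s

print(f"speedup: {speedup:.2f}x")  # speedup: 2.82x
print(f"latency per token: {baseline_ms:.1f} ms -> {compiled_ms:.1f} ms")  # 117.6 ms -> 41.7 ms
```

These figures come from a single Colab run, so they should be read as indicative rather than as a stable benchmark.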