# KV Cache Quantisation with TensorRT-LLM on a Single A100 GPU

This notebook demonstrates how I set up a **self-contained TensorRT-LLM workflow** on a single A100 GPU.  
I show baseline FP16 inference and then progressively apply **KV-cache quantisation** strategies to optimise memory usage and speed.



## 1. Environment Setup

Clone the TensorRT-LLM repository, install dependencies, and prepare CUDA.

In [None]:
!git clone https://github.com/NVIDIA/TensorRT-LLM.git
!pip install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
!pip install huggingface_hub pynvml mpi4py
!pip install -r TensorRT-LLM/examples/models/core/llama/requirements.txt

import os
os.environ["CUDA_HOME"] = "/usr/local/cuda"
print("CUDA_HOME =", os.environ["CUDA_HOME"])

# Install specific versions of cuda-python and nvidia-cudnn-cu12 for compatibility
!pip install --upgrade --force-reinstall cuda-python==12.2.1
!pip install nvidia-cudnn-cu12==8.9.2.26


Cloning into 'TensorRT-LLM'...
remote: Enumerating objects: 135576, done.
remote: Counting objects: 100% (319/319), done.
remote: Compressing objects: 100% (178/178), done.
remote: Total 135576 (delta 207), reused 141 (delta 141), pack-reused 135257 (from 2)
Receiving objects: 100% (135576/135576), 1.59 GiB | 25.90 MiB/s, done.
Resolving deltas: 100% (88806/88806), done.
Updating files: 100% (6623/6623), done.
Filtering content: 100% (2668/2668), 1.59 GiB | 20.63 MiB/s, done.
Looking in indexes: https://pypi.org/simple, https://pypi.nvidia.com
Collecting tensorrt_llm
  Downloading https://pypi.nvidia.com/tensorrt-llm/tensorrt_llm-1.2.0rc0-cp312-cp312-linux_x86_64.whl (3724.3 MB)
     3.7/3.7 GB 16.8 MB/s eta 0:00:00
Collecting colored (from tensorrt_llm)
  Downloading colored-2.3.1-py3-none-any.whl.metadata (3.6 kB)
Collecting mpi4py (from tensorrt_llm)
  Downloading mpi4py-4.1.0-cp312-cp312-ma

ERROR: Operation cancelled by user
^C


In [None]:
!pip install --upgrade --force-reinstall cuda-python

## 2. Download Model from Hugging Face

I use the Hugging Face Hub to fetch **Llama-3.2-1B**.  
The model is stored locally for conversion into TensorRT format.

In [1]:
from huggingface_hub import snapshot_download

snapshot_download(
    "meta-llama/Llama-3.2-1B",
    local_dir="tmp/hf_models/meta-llama/Llama-3.2-1B",
    max_workers=4
)

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]
LICENSE.txt:   0%|          | 0.00/7.71k [00:00<?, ?B/s]
USE_POLICY.md:   0%|          | 0.00/6.02k [00:00<?, ?B/s]
README.md:   0%|          | 0.00/41.2k [00:00<?, ?B/s]
.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]
config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]
model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]
generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]
original/consolidated.00.pth:   0%|          | 0.00/2.47G [00:00<?, ?B/s]
params.json:   0%|          | 0.00/220 [00:00<?, ?B/s]
original/tokenizer.model:   0%|          | 0.00/2.18M [00:00<?, ?B/s]
special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]
tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]
tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

'/content/tmp/hf_models/meta-llama/Llama-3.2-1B'

## 3. Optimisation Scenario 0: Baseline FP16 Engine

First, I convert the HF checkpoint and build a **TensorRT FP16 engine**.  
This serves as the baseline for comparison against quantised versions.

In [4]:
import os
os.environ["CUDA_HOME"] = "/usr/local/cuda"
print("CUDA_HOME =", os.environ["CUDA_HOME"])

CUDA_HOME = /usr/local/cuda


In [5]:
!python ./TensorRT-LLM/examples/models/core/llama/convert_checkpoint.py \
  --model_dir ./tmp/hf_models/meta-llama/Llama-3.2-1B \
  --output_dir ./tmp/trt_engines/1-gpu/ \
  --dtype float16

print("Building FP16 Engine...")
!trtllm-build --checkpoint_dir ./tmp/trt_engines/1-gpu/ \
              --output_dir ./tmp/trt_engines/llama_fp16 \
              --gemm_plugin auto

Traceback (most recent call last):
  File "/content/./TensorRT-LLM/examples/models/core/llama/convert_checkpoint.py", line 10, in <module>
    import tensorrt_llm
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/__init__.py", line 66, in <module>
    import tensorrt_llm._torch.models as torch_models
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/__init__.py", line 1, in <module>
    from .llm import LLM
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/llm.py", line 1, in <module>
    from tensorrt_llm.llmapi.llm import _TorchLLM
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/llmapi/__init__.py", line 1, in <module>
    from ..disaggregated_params import DisaggregatedParams
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/disaggregated_params.py", line 11, in <module>
    from tensorrt_llm.bindings import executor as tllme
ImportError: libcuda.so.1: cannot open shared object file: No such file or directory
Bui
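
The traceback above ends in `ImportError: libcuda.so.1`, which means the NVIDIA driver library could not be loaded when `tensorrt_llm` was imported (for instance, because the session has no GPU attached). A minimal pre-flight check, assuming a Linux runtime, might look like the sketch below; the helper name `driver_library_available` is my own and not part of TensorRT-LLM.

```python
import ctypes

def driver_library_available(name: str = "libcuda.so.1") -> bool:
    """Return True if the dynamic loader can open the given shared library."""
    try:
        ctypes.CDLL(name)
        return True
    except OSError:
        # Raised when the library is missing or cannot be loaded.
        return False

if not driver_library_available():
    print("libcuda.so.1 not found: attach a GPU runtime before importing tensorrt_llm")
```

Running this before the conversion step fails fast with a readable message instead of a long import traceback.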

## 4. Optimisation Scenario 1: INT8 KV Cache Only

Here, I quantise the **KV-cache to INT8**, while leaving weights in FP16.  
Each cached element shrinks from two bytes to one, which reduces both the cache footprint and the memory bandwidth used during long-sequence decoding.
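
To get a feel for the saving, the cache size can be estimated from the model shape. The sketch below is a back-of-the-envelope calculation, not something TensorRT-LLM reports: the layer count, KV-head count, and head dimension are my assumptions for Llama-3.2-1B and should be checked against the downloaded `config.json`, and the small overhead of INT8 scaling factors is ignored.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   bytes_per_elem, batch_size=1):
    """Approximate KV-cache size in bytes for one batch of sequences."""
    # The factor 2 accounts for storing both keys and values.
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_elem)

# Assumed Llama-3.2-1B shape; verify against config.json.
cfg = dict(num_layers=16, num_kv_heads=8, head_dim=64)
seq_len = 2000 + 512  # roughly the long-context setting in Section 7

fp16 = kv_cache_bytes(**cfg, seq_len=seq_len, bytes_per_elem=2)
int8 = kv_cache_bytes(**cfg, seq_len=seq_len, bytes_per_elem=1)
print(f"FP16: {fp16 / 2**20:.1f} MiB, INT8: {int8 / 2**20:.1f} MiB")
```

For a model this small the absolute numbers are modest; the halving matters more as batch size and sequence length grow, since the cache scales linearly with both.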

In [None]:
!mkdir -p ./tmp/trt_engines/1-gpu-int8-ckpt

!python ./TensorRT-LLM/examples/models/core/llama/convert_checkpoint.py \
  --model_dir ./tmp/hf_models/meta-llama/Llama-3.2-1B \
  --output_dir ./tmp/trt_engines/1-gpu-int8-ckpt/ \
  --dtype float16 \
  --int8_kv_cache

!trtllm-build --checkpoint_dir ./tmp/trt_engines/1-gpu-int8-ckpt/ \
              --output_dir ./tmp/trt_engines/llama_int8_kv_cache_only \
              --gemm_plugin auto

## 5. Optimisation Scenario 2: INT8 KV Cache + INT8 Weight-Only Quantisation (W8A16)

Now, I combine **INT8 KV-cache** with **INT8 weight-only quantisation**.  
This further reduces model size while keeping activations in FP16.
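
Conceptually, W8A16 stores each weight as an 8-bit integer plus a floating-point scale and converts back to FP16 before the matrix multiply. The following is a plain-Python sketch of symmetric per-channel INT8 quantisation to illustrate the idea; it is not the TensorRT-LLM kernel, and the function names are hypothetical.

```python
def quantize_int8_per_channel(weight):
    """Symmetric INT8 quantisation with one scale per output channel (row)."""
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        # Round to the nearest integer step and clamp to the INT8 range used here.
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales

def dequantize(q_rows, scales):
    """Recover approximate floating-point weights from integers and scales."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]

weight = [[0.50, -1.27, 0.03], [0.002, 0.010, -0.004]]
q, scales = quantize_int8_per_channel(weight)
restored = dequantize(q, scales)
max_err = max(abs(a - b)
              for ra, rb in zip(weight, restored)
              for a, b in zip(ra, rb))
print("quantised rows:", q)
print("max abs error:", max_err)
```

Using a separate scale per row keeps the small-magnitude second row from being rounded away by the larger values in the first.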


In [None]:
!mkdir -p ./tmp/trt_engines/1-gpu-int8-kv-wq-ckpt

!python ./TensorRT-LLM/examples/models/core/llama/convert_checkpoint.py \
  --model_dir ./tmp/hf_models/meta-llama/Llama-3.2-1B \
  --output_dir ./tmp/trt_engines/1-gpu-int8-kv-wq-ckpt/ \
  --dtype float16 \
  --int8_kv_cache \
  --use_weight_only \
  --weight_only_precision int8

!trtllm-build --checkpoint_dir ./tmp/trt_engines/1-gpu-int8-kv-wq-ckpt/ \
              --output_dir ./tmp/trt_engines/llama_int8_kv_cache_int8_wq \
              --gemm_plugin auto

## 6. Optimisation Scenario 3: INT8 KV Cache + AWQ (W4A16, Group-wise)

Finally, I apply **Activation-aware Weight Quantisation (AWQ)**:  
- Weights: INT4 (group-wise, block size = 128)  
- KV-cache: INT8  
- Activations: FP16  

This is the most aggressive compression tested here.
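
The group-wise part of this scheme can be sketched in a few lines. The example below shows only symmetric INT4 quantisation with one scale per group of 128 weights; it deliberately omits AWQ's activation-aware scale search, and the function names and the assumption of one FP16 scale per group are mine rather than taken from the quantisation toolkit.

```python
import math

def quantize_int4_groupwise(values, group_size=128):
    """Symmetric 4-bit quantisation (codes in -8..7) with one scale per group."""
    codes, scales = [], []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        max_abs = max(abs(v) for v in group)
        scale = (max_abs / 7.0) or 1.0  # avoid a zero scale for all-zero groups
        codes.extend(max(-8, min(7, round(v / scale))) for v in group)
        scales.append(scale)
    return codes, scales

def dequantize_groupwise(codes, scales, group_size=128):
    """Map each 4-bit code back to a float using its group's scale."""
    return [c * scales[i // group_size] for i, c in enumerate(codes)]

# Two groups with different magnitudes, so each gets its own scale.
values = [math.sin(i) * (1 + i // 128) for i in range(256)]
codes, scales = quantize_int4_groupwise(values)
restored = dequantize_groupwise(codes, scales)

# Storage cost, assuming one 16-bit scale per group of 128 weights.
bits_per_weight = 4 + 16 / 128
print(f"{len(scales)} groups, {bits_per_weight} bits per weight")
```

Smaller groups track local weight ranges more closely at the cost of more scale storage; 128 keeps the overhead to a fraction of a bit per weight.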


In [None]:
!mkdir -p ./tmp/trt_engines/1-gpu-int8-kv-awq-ckpt

!python ./TensorRT-LLM/examples/quantization/quantize.py \
  --model_dir ./tmp/hf_models/meta-llama/Llama-3.2-1B \
  --output_dir ./tmp/trt_engines/1-gpu-int8-kv-awq-ckpt/ \
  --dtype float16 \
  --qformat int4_awq \
  --awq_block_size 128 \
  --kv_cache_dtype int8 \
  --calib_size 32

!trtllm-build --checkpoint_dir ./tmp/trt_engines/1-gpu-int8-kv-awq-ckpt/ \
              --output_dir ./tmp/trt_engines/llama_int8_kv_cache_int4_awq \
              --gemm_plugin auto


## 7. Evaluation: Short vs Long Context

I now benchmark a PyTorch FP16 baseline and the four TensorRT-LLM engines in two settings:

1. **Short context (128 prompt tokens, 128 output tokens)**  
2. **Long context (real document text, ~2000 prompt tokens, 512 output tokens)**  

This lets me show how TensorRT-LLM optimisations (FP16, INT8 KV, W8A16, AWQ) scale from short inputs to long-context workloads where **KV cache quantisation is most beneficial**.

In [None]:
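import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from IPython.display import display
from tensorrt_llm.runtime import ModelRunner
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
max_new_short = 128
max_new_long = 512

# The helpers below rely on a global `tokenizer` and `model`; load them from the
# checkpoint downloaded in Section 2 (this is the PyTorch FP16 baseline).
hf_model_dir = "tmp/hf_models/meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(hf_model_dir)
model = AutoModelForCausalLM.from_pretrained(
    hf_model_dir, torch_dtype=torch.float16
).to(device).eval()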

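# Benchmark helpers
def bench_pt(prompt, max_new, iters=3):
    """Mean latency (ms) and tokens/s for the PyTorch model over `iters` runs."""
    # Tokenise once, outside the timed region.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    lat = []
    with torch.inference_mode():
        for _ in range(iters):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            _ = model.generate(**inputs, max_new_tokens=max_new)
            torch.cuda.synchronize()
            lat.append((time.perf_counter() - t0) * 1000)
    ms = float(np.mean(lat))
    # Assumes all max_new tokens are generated; an early EOS inflates this figure.
    tps = max_new / (ms / 1000.0)
    return ms, tps

def bench_trt(runner, prompt, max_new, iters=3):
    """Mean latency (ms) and tokens/s for a TensorRT-LLM engine over `iters` runs."""
    # ModelRunner.generate takes a list of 1-D token-id tensors, one per sequence.
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0].to(torch.int32)
    lat = []
    for _ in range(iters):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        # The Llama tokenizer defines no pad token, so EOS doubles as the pad id.
        _ = runner.generate([input_ids],
                            max_new_tokens=max_new,
                            end_id=tokenizer.eos_token_id,
                            pad_id=tokenizer.eos_token_id)
        torch.cuda.synchronize()
        lat.append((time.perf_counter() - t0) * 1000)
    ms = float(np.mean(lat))
    tps = max_new / (ms / 1000.0)
    return ms, tps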

# Try to load runners if engine dirs exist
available_runners = {}

def try_load(name, path):
    import os
    if os.path.exists(path):
        try:
            r = ModelRunner.from_dir(engine_dir=path, rank=0)
            available_runners[name] = r
            print(f"Loaded {name} from {path}")
        except Exception as e:
            print(f"Could not load {name}: {e}")

# Add your engine paths here
try_load("TensorRT-LLM FP16", "./tmp/trt_engines/llama_fp16")
try_load("TensorRT-LLM INT8 KV", "./tmp/trt_engines/llama_int8_kv_cache_only")
try_load("TensorRT-LLM INT8 KV + W8A16", "./tmp/trt_engines/llama_int8_kv_cache_int8_wq")
try_load("TensorRT-LLM INT8 KV + AWQ", "./tmp/trt_engines/llama_int8_kv_cache_int4_awq")

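# Prepare prompts
# Any long plain-text document works for the long prompt; adjust the dataset
# identifier below if it is unavailable.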
short_prompt = "Summarize the benefits of KV-cache in 3 concise bullet points."
gutenberg = load_dataset("gutenberg", "shakespeare-macbeth", split="train")
long_prompt = gutenberg[0]["text"][:4000]  # ~2000 tokens
print("Sample long prompt:\n", long_prompt[:300], "...\n")

# Run benchmarks
def run_benchmarks(prompt, max_new, label):
    results = []
    # PyTorch baseline
    pt_ms, pt_tps = bench_pt(prompt, max_new)
    results.append({"backend":"PyTorch FP16", "avg_ms":pt_ms, "tokens_per_sec":pt_tps})
    # TensorRT engines
    for name, runner in available_runners.items():
        ms, tps = bench_trt(runner, prompt, max_new)
        results.append({"backend":name, "avg_ms":ms, "tokens_per_sec":tps})
    df = pd.DataFrame(results)
    display(df)

    # Plot
    plt.figure(figsize=(8,4))
    plt.bar(df["backend"], df["tokens_per_sec"])
    plt.ylabel("Throughput (tokens/s)")
    plt.title(f"{label} (max_new={max_new})")
    plt.xticks(rotation=20, ha="right")
    plt.show()

    return df

print("### Short context benchmark")
df_short = run_benchmarks(short_prompt, max_new_short, "Short Context")

print("### Long context benchmark")
df_long = run_benchmarks(long_prompt, max_new_long, "Long Context")
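
To compare backends at a glance, each throughput can be expressed relative to the PyTorch baseline. The helper below is a small sketch of that step; `add_speedup` is a hypothetical name, and it works on the list-of-dicts form of the result tables rather than on the DataFrames directly.

```python
def add_speedup(rows, baseline="PyTorch FP16"):
    """Return copies of rows with a 'speedup' field relative to the baseline backend."""
    base = next(r["tokens_per_sec"] for r in rows if r["backend"] == baseline)
    return [{**r, "speedup": r["tokens_per_sec"] / base} for r in rows]

# Usage with the DataFrames returned by run_benchmarks:
# add_speedup(df_short.to_dict("records"))
# add_speedup(df_long.to_dict("records"))
```

A speed-up above 1.0 means the engine generates tokens faster than the baseline for that setting.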