# Loading GGML / Ollama weights into LitGPT

In [1]:
import thunder
import torch
import ggmltensor


def load_ggml_weights(model, fn):
    """Overwrite every parameter of ``model`` with the dequantized tensor of the same name read from ``fn``.

    Args:
        model: Module whose ``named_parameters()`` are updated in place.
        fn: File name handed to ``ggmltensor.GgmlDataReader``.

    Each stored tensor is dequantized to the parameter's dtype, moved to the parameter's device
    and copied in transposed (``w.t()``), so the stored layout is presumably the transpose of the
    layout the parameters use -- confirm against ``ggmltensor``.
    """
    ggml_quant = ggmltensor.GgmlDataReader(fn)
    try:
        # One no_grad region for the whole load; the copies must not be recorded by autograd.
        with torch.no_grad():
            for n, p in model.named_parameters():
                qw, (typ, shape) = ggml_quant.get_parameter(n)
                w = ggmltensor.dequantize(qw, typ, shape, dtype=p.dtype).to(p.device)
                p.copy_(w.t())
    finally:
        # Release the reader even if a parameter is missing or a copy fails.
        ggml_quant.close()

In [2]:
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import sys
import time
from pathlib import Path
from typing import Any, Literal, Optional

import lightning as L
import torch
import torch._dynamo.config
import torch._inductor.config

# from lightning.fabric.plugins import BitsandbytesPrecision

from litgpt import GPT, Config, PromptStyle, Tokenizer
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint


def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    """Draw one index per distribution in ``probs`` (last dimension), keeping that dimension with size 1."""
    if not torch._dynamo.is_compiling():
        return torch.multinomial(probs, num_samples=1)
    # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly:
    # divide by exponential noise and take the arg-max instead of calling the sampler.
    exponential_noise = torch.empty_like(probs).exponential_(1)
    perturbed = probs / exponential_noise
    return torch.argmax(perturbed, dim=-1, keepdim=True)


def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Mask the low-probability tail of ``logits`` for nucleus (top-p) sampling.

    The filtering is done along the last dimension, so ``logits`` may be a single vector or
    carry leading batch dimensions; every row is filtered independently.

    Args:
        logits: Unnormalized scores; the last dimension indexes the candidates.
        top_p: Cumulative probability mass to keep.

    Returns:
        A tensor shaped like ``logits`` in which the least likely entries, whose cumulative
        probability is at most ``1 - top_p``, are replaced by ``-inf``. The most likely entry
        of every row is always kept.
    """
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Example:
    # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
    # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Keep at least 1 token always to prevent the case where no token is selected
    # In this case the most probable one (last after the ascending sort) is always kept
    sorted_indices_to_remove[..., -1:] = False
    # Map the mask from sorted order back to the original positions, row by row.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, float("-inf"))


def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0
) -> torch.Tensor:
    """Pick the next token from the logits of the last position of the first batch entry.

    Order of application: top-k cropping, temperature scaling, top-p cropping. When both
    ``temperature`` and ``top_p`` are not positive the arg-max is returned; otherwise one index is
    drawn from the resulting distribution.

    Raises:
        ValueError: If ``top_p`` is outside ``[0, 1]``.
    """
    if top_p < 0.0 or top_p > 1.0:
        raise ValueError(f"top_p must be in [0, 1], got {top_p}")

    last_logits = logits[0, -1]

    # optionally crop the logits to only the top k options
    if top_k is not None:
        k = min(top_k, last_logits.size(-1))
        top_values, top_indices = torch.topk(last_logits, k)
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        cropped = torch.full_like(last_logits, float("-inf"))
        last_logits = cropped.scatter_(-1, top_indices, top_values)

    # greedy decoding is only used when neither knob asks for sampling
    if not (temperature > 0.0 or top_p > 0.0):
        return torch.argmax(last_logits, dim=-1, keepdim=True)

    if temperature > 0.0:
        last_logits = last_logits / temperature
    # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
    if top_p < 1.0:
        last_logits = sample_top_p(last_logits, top_p)
    return multinomial_num_samples_1(torch.nn.functional.softmax(last_logits, dim=-1))


def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Run ``model`` on ``x`` at ``input_pos`` and sample one token from the resulting logits.

    ``kwargs`` are forwarded to ``sample``; the result is cast to the dtype of ``x``.
    """
    sampled = sample(model(x, input_pos), **kwargs)
    return sampled.to(dtype=x.dtype)


@torch.inference_mode()
def generate(
    model: GPT,
    prompt: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        prompt: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
            In top-p sampling, the next token is sampled from the highest probability tokens
            whose cumulative probability exceeds the threshold `top_p`. When specified,
            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
            to sampling the most probable token, while `top_p=1` samples from the whole distribution.
            It can be used in conjunction with `top_k` and `temperature` with the following order
            of application:

            1. `top_k` sampling
            2. `temperature` scaling
            3. `top_p` sampling

            For more details, see https://arxiv.org/abs/1904.09751
            or https://huyenchip.com/2024/01/16/sampling.html#top_p
        eos_id: If specified, stop generating any more token once the <eos> token is triggered.
            The <eos> token itself is part of the returned sequence.

    Returns:
        The prompt followed by the generated tokens, concatenated into one tensor.

    Raises:
        NotImplementedError: If ``model.max_seq_length`` is smaller than ``max_returned_tokens - 1``.
    """
    T = prompt.size(0)
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

    device = prompt.device
    sampling_kwargs = {"temperature": temperature, "top_k": top_k, "top_p": top_p}
    tokens = [prompt]
    input_pos = torch.tensor([T], device=device)

    # First step: feed the whole prompt at positions 0..T-1.
    token = next_token(model, torch.arange(0, T, device=device), prompt.view(1, -1), **sampling_kwargs).clone()
    tokens.append(token)

    # Remaining steps: feed only the previously sampled token at the next position.
    for _ in range(2, max_returned_tokens - T + 1):
        # Checked before every step so that an <eos> produced by the very first step also stops generation.
        if eos_id is not None and token == eos_id:
            break
        token = next_token(model, input_pos, token.view(1, -1), **sampling_kwargs).clone()
        tokens.append(token)
        input_pos = input_pos.add_(1)
    return torch.cat(tokens)

In [3]:
# Notebook configuration cell: defines the generation settings as notebook globals, creates the
# Fabric object, tokenizer and prompt style, and encodes the prompt. Later cells read `encoded`,
# `prompt_length`, `max_returned_tokens`, `tokenizer`, `fabric`, `config` and the sampling settings.
with torch.inference_mode():
    prompt: str = "What food do llamas eat?"
    num_samples: int = 1
    max_new_tokens: int = 256
    top_k: Optional[int] = 50
    top_p: float = 1.0
    temperature: float = 0.8
    # NOTE: machine-specific absolute path; adjust to the local checkpoint location.
    checkpoint_dir: Path = Path(
        "/home/tv/data/firma/grid/thunder/litgpt/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/"
    )
    # NOTE: `quantize` is not read anywhere in this cell (the BitsandbytesPrecision plugin is commented out below).
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = "bnb.nf4"
    precision: Optional[str] = "bf16-true"
    compile: bool = False
    # litgpt generate base --quantize bnb.nf4 --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --max_new_tokens 256

    # The string below is a bare expression statement, not a docstring: it has no runtime effect
    # and only documents the settings above.
    """Generates text samples based on a pre-trained model and tokenizer.

    Args:
        prompt: The prompt string to use for generating the samples.
        num_samples: The number of text samples to generate.
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
            In top-p sampling, the next token is sampled from the highest probability tokens
            whose cumulative probability exceeds the threshold `top_p`. When specified,
            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
            to sampling the most probable token, while `top_p=1` samples from the whole distribution.
            It can be used in conjunction with `top_k` and `temperature` with the following order
            of application:

            1. `top_k` sampling
            2. `temperature` scaling
            3. `top_p` sampling

            For more details, see https://arxiv.org/abs/1904.09751
            or https://huyenchip.com/2024/01/16/sampling.html#top_p
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        checkpoint_dir: The checkpoint directory to load.
        quantize: Whether to quantize the model and using which method:
            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
            - bnb.int8: 8-bit quantization from bitsandbytes
            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
        precision: Indicates the Fabric precision setting to use.
        compile: Whether to compile the model.
    """
    # NOTE: this fallback is overwritten unconditionally a few lines below.
    precision = precision or get_default_supported_precision(training=False)

    # plugins = BitsandbytesPrecision(mode='nf4', dtype=torch.bfloat16)

    # Hard override: whatever was configured above, Fabric is created with bf16-true.
    precision = "bf16-true"

    fabric = L.Fabric(devices=1, precision=precision)  # , plugins=plugins)

    check_valid_checkpoint_dir(checkpoint_dir)
    config = Config.from_file(checkpoint_dir / "model_config.yaml")

    # Only used by the disabled loading cell (message and commented-out load_checkpoint call).
    checkpoint_path = checkpoint_dir / "lit_model.pth"

    tokenizer = Tokenizer(checkpoint_dir)
    # Prefer a prompt style stored with the checkpoint; otherwise derive one from the model config.
    prompt_style = (
        load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
    )

    prompt = prompt_style.apply(prompt)
    encoded = tokenizer.encode(prompt, device=fabric.device)
    prompt_length = encoded.size(0)
    # Total sequence budget (prompt + new tokens); also used to size the kv cache in later cells.
    max_returned_tokens = prompt_length + max_new_tokens

In [4]:
# Disabled cell (`if 0:` never runs): builds the LitGPT model through Fabric and fills it with
# dequantized weights via `load_ggml_weights`. The notebook uses the Thunder transform path instead.
if 0:
    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True):
        model = GPT(config)
    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens
        # enable the kv cache
        model.set_kv_cache(batch_size=1)
    model.eval()

    if compile:
        torch._dynamo.config.automatic_dynamic_shapes = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.coordinate_descent_tuning = True
        # Rebinds the module-level `next_token`, which `generate` looks up at call time.
        global next_token
        next_token = torch.compile(next_token, mode="reduce-overhead")

    model = fabric.setup_module(model)

    # NOTE(review): the path starts with "~" and is not expanded here; presumably GgmlDataReader
    # resolves it (and the Ollama manifest it points to) -- confirm.
    ggml_fn = "~/.ollama/models/manifests/registry.ollama.ai/library/llama3/latest"

    t0 = time.perf_counter()
    # load_checkpoint(fabric, model, checkpoint_path)
    # The weights are copied into the module wrapped by Fabric, not into the wrapper itself.
    load_ggml_weights(model._original_module, ggml_fn)

    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

In [5]:
# Disabled cell (`if 0:` never runs): generation loop for the model built in the disabled
# loading cell. Same steps as the active generation cell at the end of the notebook.
if 0:
    with torch.inference_mode():
        L.seed_everything(1234)
        for i in range(num_samples):
            t0 = time.perf_counter()
            y = generate(
                model,
                encoded,
                max_returned_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                eos_id=tokenizer.eos_id,
            )
            t = time.perf_counter() - t0
            # Reset every block's kv cache so the next sample does not reuse this sample's entries.
            for block in model.transformer.h:
                block.attn.kv_cache.reset_parameters()
            fabric.print(tokenizer.decode(y))
            # `y` contains the prompt as well, so subtract its length to count only new tokens.
            tokens_generated = y.size(0) - prompt_length
            fabric.print(
                f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
                file=sys.stderr,
            )
        if fabric.device.type == "cuda":
            fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

# Thunder transform

In [6]:
from collections.abc import Sequence

import thunder
from thunder.core.transform_common import Transform
from thunder.core import utils
from thunder.core import prims
import torch

from thunder.transforms.utils import (
    get_orig_and_thunder_module_proxies_from_prologue,
    get_checks,
    add_trace_output,
)

import ggmltensor

# Thunder operator executor on which the GGML-quantized matmul and embedding operators below are registered.
ggmlquant_executor = thunder.extend.OperatorExecutor("quant_ggml", version=0.1)


def ggmlquant_matmul_meta(x, qweight, ggmltype: int, shape):
    """Meta function of ``ggmlquant_matmul``: describe the output without computing it.

    ``shape`` is the two-element shape recorded for the quantized weight; its second entry must
    equal the last dimension of ``x`` and its first entry becomes the last dimension of the
    result. All other properties are taken from ``x``.
    """
    assert isinstance(shape, Sequence) and len(shape) == 2
    leading_dims = x.shape[:-1]
    assert x.shape[-1] == shape[1], f"{x.shape=}, rhs {shape=}"
    result_shape = (*leading_dims, shape[0])
    return thunder.TensorProxy(like=x, shape=result_shape)


def ggmlquant_matmul_impl(x, qweight, ggmltype: int, shape):
    """Eager implementation of ``ggmlquant_matmul``: dequantize the weight to ``x``'s dtype and multiply.

    ``ggmltype`` is the integer value of a ``ggmltensor.GgmlType``. The dequantized weight is used
    as the right-hand operand without transposing.
    """
    ggml_type = ggmltensor.GgmlType(ggmltype)
    weight = ggmltensor.dequantize(qweight, ggml_type, shape, dtype=x.dtype)
    return torch.matmul(x, weight)


def ggmlquant_embed_meta(x, qweight, ggmltype: int, shape):
    """Meta function of ``ggmlquant_embed``: the result has ``x``'s shape plus one trailing dimension.

    The trailing dimension is ``shape[1]`` of the two-element weight shape.
    NOTE(review): every other property, dtype included, is taken from ``x`` via ``like=x`` while
    the implementation dequantizes to bfloat16 -- confirm the dtype reported here is intended.
    """
    assert isinstance(shape, Sequence) and len(shape) == 2
    embedding_dim = shape[1]
    result_shape = (*x.shape, embedding_dim)
    return thunder.TensorProxy(like=x, shape=result_shape)


def ggmlquant_embed_impl(x, qweight, ggmltype: int, shape):
    """Eager implementation of ``ggmlquant_embed``: dequantize the table to bfloat16 and look up ``x``.

    ``ggmltype`` is the integer value of a ``ggmltensor.GgmlType``. The dequantized weight is
    transposed before it is used as the embedding table.
    """
    ggml_type = ggmltensor.GgmlType(ggmltype)
    table = ggmltensor.dequantize(qweight, ggml_type, shape, dtype=torch.bfloat16).t()
    return torch.nn.functional.embedding(x, table)


# Register the quantized matmul on the executor: `meta` describes the output proxy,
# `fn` is the eager implementation. The returned symbol is what the transform inserts into traces.
ggmlquant_matmul = ggmlquant_executor.register_operator(
    "ggmlquant_matmul", meta=ggmlquant_matmul_meta, fn=ggmlquant_matmul_impl
)

# Same for the quantized embedding lookup.
ggmlquant_embed = ggmlquant_executor.register_operator(
    "ggmlquant_embed", meta=ggmlquant_embed_meta, fn=ggmlquant_embed_impl
)


class GGMLQuantTransform(Transform):
    """Thunder transform that runs a module on weights read from a GGML model file.

    ``transform_module`` overrides every ``<submodule>.weight`` parameter of the Thunder module
    with the tensor stored under the same name in the file and records the GGML type and shape of
    those that are not floating point (i.e. still quantized). ``transform_traces_pre_prologue``
    then updates the prologue's parameter checks and replaces ``linear`` / ``embedding`` calls on
    the recorded weights with ``ggmlquant_matmul`` / ``ggmlquant_embed``.

    Attributes are documented next to their assignment in ``__init__``.
    """

    def __init__(self, model_file_name, device):
        # weight name -> {"typ": GGML type, "shape": shape} for weights kept in quantized form
        self.quant_states = {}
        # names of all submodules whose weight was replaced (quantized or not)
        self.quantized_submodule_names = set()
        # device the replacement tensors are moved to
        self.device = device
        # file name handed to ggmltensor.GgmlDataReader
        self.model_file_name = model_file_name

    @staticmethod
    def _update_param_check(check, param):
        """Rewrite a prologue tensor check so that it expects ``param``'s shape, device and dtype."""
        # check has args: tensor, shape, device, dtype, requires_grad
        proxy, _, _, _, _ = check.args
        thunder_device = thunder.devices.to_device(param.device)
        thunder_device_str = thunder_device.device_str()
        check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False)

    def transform_module(self, model: thunder.ThunderModule):
        """Replace the weights of ``model`` with the tensors stored in the GGML file.

        The reader is closed when the conversion finishes, also when it fails part-way.
        """
        ggml_quant = ggmltensor.GgmlDataReader(self.model_file_name)
        self.thunder_module = model

        def convert_layer_with_weight(tm, name):
            self.quantized_submodule_names.add(name)
            weight_name = f"{name}.weight"
            # Looked up only for its side effect: presumably fails here when the module has no
            # such parameter, before anything is read from the file.
            tm.get_parameter(weight_name)
            qw, (typ, shape) = ggml_quant.get_parameter(weight_name)
            tm._overrides_parameters[weight_name] = qw.to(self.device)
            # Floating-point tensors are usable as they are; only the others need the
            # type/shape metadata for dequantization at run time.
            if not qw.is_floating_point():
                self.quant_states[weight_name] = {"typ": typ, "shape": shape}

        try:
            for n, submodule in model._model.named_modules():
                if hasattr(submodule, "weight"):
                    convert_layer_with_weight(model, n)
        finally:
            ggml_quant.close()

    def transform_state_dict_for_submodule(self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict):
        """Not supported: weights come from the GGML file, not from a state dict."""
        raise NotImplementedError("load weights ...")

    def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
        """Adapt the traces to the replaced weights.

        The prologue's parameter checks are updated in place to the current parameters' metadata.
        A new computation trace is built in which ``linear`` and ``embedding`` calls whose weight
        is a recorded quantized parameter use the GGML operators; every other bound symbol is
        copied unchanged.

        Returns:
            ``(prologue_trace, new_computation_trace, epilogue_trace)``.
        """
        tm = self.thunder_module
        checks = get_checks(prologue_trace)

        # NOTE(review): this scope is pushed but never popped; kept as in the original because
        # `computation_trace` is not returned from this method -- confirm it can be dropped.
        computation_trace.push_scope([])
        quantized_proxies: dict[int, str] = {}  # id -> name

        for n in self.quant_states:
            param = tm.get_parameter(n)
            check, get_param = checks[n]
            # Remember which proxy carries this quantized weight into the computation trace.
            quantized_proxies[id(get_param.output)] = n
            self._update_param_check(check, param)
        for n, param in tm.named_parameters():
            if n not in self.quant_states:
                check, _ = checks[n]
                self._update_param_check(check, param)

        new_computation_trace = thunder.core.trace.from_trace(computation_trace)

        def with_quantized_weight(bsym, sym):
            # signature of the new symbol: sym(x, qweight, ggmltype, shape)
            qs = self.quant_states[quantized_proxies[id(bsym.args[1])]]
            new_args = (
                *bsym.args[:2],
                qs["typ"].value,  # integer value
                qs["shape"],
            )
            return bsym.from_bsym(sym=sym, subsymbols=[], args=new_args)

        for bsym in computation_trace.bound_symbols:
            if bsym.sym == thunder.torch.linear and id(bsym.args[1]) in quantized_proxies:
                assert len(bsym.args) == 3  # torch.linear(input, weight, bias)
                assert bsym.args[2] is None  # a bias is not supported
                new_bsym = with_quantized_weight(bsym, ggmlquant_matmul)
            elif bsym.sym == thunder.torch.embedding and id(bsym.args[1]) in quantized_proxies:
                # embedding is traced with seven arguments; only the plain lookup
                # (args 2 and 3 unset, args 5 and 6 False) is supported
                assert len(bsym.args) == 7
                assert bsym.args[2] is None and bsym.args[3] is None
                assert bsym.args[5] is False and bsym.args[6] is False
                new_bsym = with_quantized_weight(bsym, ggmlquant_embed)
            else:
                new_bsym = bsym.from_bsym()
            new_computation_trace.bound_symbols.append(new_bsym)

        new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass"))
        return prologue_trace, new_computation_trace, epilogue_trace

In [7]:
import thunder, torch

In [8]:
import thunder.tests.litgpt_model

In [9]:
import litgpt

# Build the Llama-3-8B-Instruct architecture under the "meta" device context; the weights are
# supplied later by the GGML transform. Gradients are disabled on all parameters.
with torch.device("meta"):
    m = thunder.tests.litgpt_model.GPT.from_name("Llama-3-8B-Instruct")
    m.requires_grad_(False)
    # del m.transformer.h[2:]
# enable the kv cache
device = "cuda"
# The kv cache is sized by the prompt + generation budget computed in the configuration cell.
with torch.device(device):
    m.max_seq_length = max_returned_tokens
    m.set_kv_cache(batch_size=1)
# Recompute the rope cache on the GPU. NOTE: "cuda" is repeated here rather than reusing `device`.
m.cos, m.sin = m.rope_cache(device=torch.device("cuda"))

In [10]:
# NOTE(review): the path starts with "~" and is not expanded here; presumably GgmlDataReader
# resolves it (and the Ollama manifest it points to) -- confirm.
model_file_name = "~/.ollama/models/manifests/registry.ollama.ai/library/llama3/latest"

# Compile the meta-device model with Thunder, letting the transform swap in the GGML weights on the GPU.
quant_transform = GGMLQuantTransform(model_file_name, torch.device("cuda"))
tm = thunder.jit(m, transforms=[quant_transform])

In [11]:
# a = torch.randint(1, 100, (1, 64), device="cuda")
# tm(a)

In [12]:
model = tm

In [None]:
# Generation with the Thunder-compiled model: reset the kv caches, then produce `num_samples`
# completions of the encoded prompt and report timing and peak CUDA memory on stderr.
for block in model.transformer.h:
    block.attn.kv_cache.reset_parameters()
num_samples = 1
with torch.inference_mode():
    L.seed_everything(1234)
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(
            model,
            encoded,
            max_returned_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_id=tokenizer.eos_id,
        )
        t = time.perf_counter() - t0
        # Reset every block's kv cache so the next sample does not reuse this sample's entries.
        for block in model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        fabric.print(tokenizer.decode(y))
        # `y` contains the prompt as well, so subtract its length to count only new tokens.
        tokens_generated = y.size(0) - prompt_length
        fabric.print(
            f"Time  for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
        )
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

Seed set to 1234


In [None]:
model.transformer.h[0].attn.kv_cache.v.shape

In [None]:
model._forward_module.max_seq_length

In [None]:
model._forward_module.config.n_head

In [None]:
model._model.mask_cache

In [None]:
m2 = thunder.tests.litgpt_model.OverridenKVCache((1, 32, 286, 128), (1, 32, 286, 128), device=torch.device("cuda"))

In [None]:
tm2 = thunder.jit(m2)

In [None]:
input_pos = torch.tensor([2], device="cuda")
k, v = torch.randn(2, 1, 32, 1, 128, device="cuda")
tm2(input_pos, k, v)

In [None]:
m2

In [None]:
torch.Tensor.index_add

In [None]:
?? m.transformer.h[0].attn.kv_cache.forward