diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index c1dde7046..5922e908f 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,5 @@ -from .llama_cpp import * -from .llama import * +from .llama_cpp import * # noqa: F401,F403 +from .llama import * # noqa: F401,F403 +from .raw import LlamaRaw # explicit export of thin raw API __version__ = "0.3.16" diff --git a/llama_cpp/_ggml.py b/llama_cpp/_ggml.py index 5bee8a93b..a2be13b45 100644 --- a/llama_cpp/_ggml.py +++ b/llama_cpp/_ggml.py @@ -2,11 +2,90 @@ This module provides a minimal interface for working with ggml tensors from llama-cpp-python """ +import enum import os import pathlib +import ctypes import llama_cpp._ctypes_extensions as ctypes_ext libggml_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path) +# enum ggml_log_level { +# GGML_LOG_LEVEL_NONE = 0, +# GGML_LOG_LEVEL_DEBUG = 1, +# GGML_LOG_LEVEL_INFO = 2, +# GGML_LOG_LEVEL_WARN = 3, +# GGML_LOG_LEVEL_ERROR = 4, +# GGML_LOG_LEVEL_CONT = 5, // continue previous log +# }; + +class GGMLLogLevel(enum.IntEnum): + GGML_LOG_LEVEL_NONE = 0 + GGML_LOG_LEVEL_DEBUG = 1 + GGML_LOG_LEVEL_INFO = 2 + GGML_LOG_LEVEL_WARN = 3 + GGML_LOG_LEVEL_ERROR = 4 + GGML_LOG_LEVEL_CONT = 5 # continue previous log + +# // ====== ggml-opt.h ====== + +# enum ggml_opt_build_type { +# GGML_OPT_BUILD_TYPE_FORWARD = 10, +# GGML_OPT_BUILD_TYPE_GRAD = 20, +# GGML_OPT_BUILD_TYPE_OPT = 30, +# }; +class GGMLOptBuildType(enum.IntEnum): + GGML_OPT_BUILD_TYPE_FORWARD = 10 + GGML_OPT_BUILD_TYPE_GRAD = 20 + GGML_OPT_BUILD_TYPE_OPT = 30 + + +# // built-in loss types, i.e. the built-in quantities minimized by the optimizer +# // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value +# enum ggml_opt_loss_type { +# GGML_OPT_LOSS_TYPE_MEAN, +# GGML_OPT_LOSS_TYPE_SUM, +# GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, +# GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, +# }; +class GGMLOptLossType(enum.IntEnum): + GGML_OPT_LOSS_TYPE_MEAN = 0 + GGML_OPT_LOSS_TYPE_SUM = 1 + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY = 2 + GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR = 3 + + +# // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss +# struct ggml_opt_optimizer_params { +# // AdamW optimizer parameters +# struct { +# float alpha; // learning rate +# float beta1; +# float beta2; +# float eps; // epsilon for numerical stability +# float wd; // weight decay for AdamW, use 0.0f to disable +# } adamw; +# }; +class ggml_opt_adamw_params(ctypes.Structure): + _fields_ = [ + ('alpha', ctypes.c_float), # learning rate + ('beta1', ctypes.c_float), + ('beta2', ctypes.c_float), + ('eps', ctypes.c_float), # epsilon for numerical stability + ('wd', ctypes.c_float), # weight decay for AdamW, use 0.0f to disable + ] + +class ggml_opt_optimizer_params(ctypes.Structure): + _fields_ = [ + ('adamw', ggml_opt_adamw_params), # Nested AdamW parameters + ] + + +# // callback to calculate optimizer parameters prior to a backward pass +# // userdata can be used to pass arbitrary data +# typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); +ggml_opt_get_optimizer_params = ctypes.CFUNCTYPE( + ctypes.POINTER(ggml_opt_optimizer_params), ctypes.c_void_p +) diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index b5175a7f2..1b1cd49ab 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -9,8 +9,6 @@ Tuple, 
Optional, Sequence, - Callable, - Union, ) from dataclasses import dataclass, field from contextlib import ExitStack @@ -50,7 +48,7 @@ def __init__( raise ValueError(f"Model path does not exist: {path_model}") with suppress_stdout_stderr(disable=verbose): - model = llama_cpp.llama_model_load_from_file( + model = llama_cpp.llama_load_model_from_file( self.path_model.encode("utf-8"), self.params ) @@ -64,7 +62,6 @@ def __init__( self.model = model self.vocab = vocab - self.sampler = None # LlamaModel doesn't use samplers, but some cleanup code expects this attribute def free_model(): if self.model is None: @@ -75,27 +72,28 @@ def free_model(): self._exit_stack.callback(free_model) def close(self): - if self.sampler is not None: - # NOTE: Must remove custom samplers before free or llama.cpp will try to free them - for i, _ in reversed(self.custom_samplers): - llama_cpp.llama_sampler_chain_remove(self.sampler, i) - self.custom_samplers.clear() self._exit_stack.close() def __del__(self): self.close() def vocab_type(self) -> int: - return llama_cpp.llama_vocab_type(self.vocab) + return llama_cpp.llama_vocab_type(self.model) def n_vocab(self) -> int: - return llama_cpp.llama_vocab_n_tokens(self.vocab) + return llama_cpp.llama_n_vocab(self.vocab) def n_ctx_train(self) -> int: - return llama_cpp.llama_model_n_ctx_train(self.model) + return llama_cpp.llama_n_ctx_train(self.model) def n_embd(self) -> int: - return llama_cpp.llama_model_n_embd(self.model) + return llama_cpp.llama_n_embd(self.model) + + def n_head_kv(self) -> int: + return llama_cpp.llama_model_n_head_kv(self.model) + + def n_params(self) -> int: + return llama_cpp.llama_model_n_params(self.model) def rope_freq_scale_train(self) -> float: return llama_cpp.llama_model_rope_freq_scale_train(self.model) @@ -108,9 +106,6 @@ def desc(self) -> str: def size(self) -> int: return llama_cpp.llama_model_size(self.model) - def n_params(self) -> int: - return llama_cpp.llama_model_n_params(self.model) - def get_tensor(self, name: str) -> ctypes.c_void_p: raise NotImplementedError("get_tensor is not implemented in llama.cpp") @@ -125,6 +120,12 @@ def token_get_score(self, token: int) -> float: def token_get_attr(self, token: int) -> int: return llama_cpp.llama_vocab_get_attr(self.vocab, token) + def token_is_eog(self, token: int) -> bool: + return llama_cpp.llama_vocab_is_eog(self.vocab, token) + + def token_is_control(self, token: int) -> bool: + return llama_cpp.llama_vocab_is_control(self.vocab, token) + # Special tokens def token_bos(self) -> int: @@ -133,8 +134,8 @@ def token_bos(self) -> int: def token_eos(self) -> int: return llama_cpp.llama_vocab_eos(self.vocab) - def token_cls(self) -> int: - return llama_cpp.llama_vocab_cls(self.vocab) + def token_eot(self) -> int: + return llama_cpp.llama_vocab_eot(self.vocab) def token_sep(self) -> int: return llama_cpp.llama_vocab_sep(self.vocab) @@ -142,24 +143,42 @@ def token_sep(self) -> int: def token_nl(self) -> int: return llama_cpp.llama_vocab_nl(self.vocab) - def token_prefix(self) -> int: + def token_pad(self) -> int: + return llama_cpp.llama_vocab_pad(self.vocab) + + def token_mask(self) -> int: + return llama_cpp.llama_vocab_mask(self.vocab) + + def token_cls(self) -> int: + return llama_cpp.llama_vocab_cls(self.vocab) + + def token_fim_pre(self) -> int: return llama_cpp.llama_vocab_fim_pre(self.vocab) - def token_middle(self) -> int: + def token_fim_suf(self) -> int: + return llama_cpp.llama_vocab_fim_suf(self.vocab) + + def token_fim_mid(self) -> int: return 
llama_cpp.llama_vocab_fim_mid(self.vocab) - def token_suffix(self) -> int: - return llama_cpp.llama_vocab_fim_suf(self.vocab) + def token_fim_pad(self) -> int: + return llama_cpp.llama_vocab_fim_pad(self.vocab) - def token_eot(self) -> int: - return llama_cpp.llama_vocab_eot(self.vocab) + def token_fim_rep(self) -> int: + return llama_cpp.llama_vocab_fim_rep(self.vocab) - def add_bos_token(self) -> bool: + def token_fim_sep(self) -> int: + return llama_cpp.llama_vocab_fim_sep(self.vocab) + + def get_add_bos(self) -> bool: return llama_cpp.llama_vocab_get_add_bos(self.vocab) - def add_eos_token(self) -> bool: + def get_add_eos(self) -> bool: return llama_cpp.llama_vocab_get_add_eos(self.vocab) + def get_add_sep(self) -> bool: + return llama_cpp.llama_vocab_get_add_sep(self.vocab) + # Tokenization def tokenize(self, text: bytes, add_bos: bool, special: bool): @@ -257,14 +276,12 @@ def __init__( self.verbose = verbose self._exit_stack = ExitStack() - ctx = llama_cpp.llama_init_from_model(self.model.model, self.params) + ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params) if ctx is None: raise ValueError("Failed to create llama_context") self.ctx = ctx - self.memory = llama_cpp.llama_get_memory(self.ctx) - self.sampler = None # LlamaContext doesn't manage samplers directly, but some cleanup code expects this attribute def free_ctx(): if self.ctx is None: @@ -283,29 +300,49 @@ def __del__(self): def n_ctx(self) -> int: return llama_cpp.llama_n_ctx(self.ctx) + def n_batch(self) -> int: + return llama_cpp.llama_n_batch(self.ctx) + + def n_ubatch(self) -> int: + return llama_cpp.llama_n_ubatch(self.ctx) + + def n_seq_max(self) -> int: + return llama_cpp.llama_n_seq_max(self.ctx) + def pooling_type(self) -> int: return llama_cpp.llama_pooling_type(self.ctx) - def kv_cache_clear(self): - assert self.memory is not None, "Memory is not initialized" - llama_cpp.llama_memory_clear(self.memory, True) + # // Memory API + + def get_memory(self): + return llama_cpp.llama_get_memory(self.ctx) + + def memory_clear(self, data: bool): + llama_cpp.llama_memory_clear(self.get_memory(), data) + + def memory_seq_rm(self, seq_id: int, p0: int, p1: int) -> bool: + if self.ctx is not None and seq_id >= 0: + return llama_cpp.llama_memory_seq_rm(self.get_memory(), seq_id, p0, p1) + else: + return False + + def memory_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): + llama_cpp.llama_memory_seq_cp(self.get_memory(), seq_id_src, seq_id_dst, p0, p1) - def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): - assert self.memory is not None, "Memory is not initialized" - seq_id = seq_id if seq_id >= 0 else 0 - llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1) + def memory_seq_keep(self, seq_id: int): + llama_cpp.llama_memory_seq_keep(self.get_memory(), seq_id) - def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): - assert self.memory is not None, "Memory is not initialized" - llama_cpp.llama_memory_seq_cp(self.memory, seq_id_src, seq_id_dst, p0, p1) + def memory_seq_add(self, seq_id: int, p0: int, p1: int, delta: int): + llama_cpp.llama_memory_seq_add(self.get_memory(), seq_id, p0, p1, delta) - def kv_cache_seq_keep(self, seq_id: int): - assert self.memory is not None, "Memory is not initialized" - llama_cpp.llama_memory_seq_keep(self.memory, seq_id) + def memory_seq_div(self, seq_id: int, p0: int, p1: int, d: int): + llama_cpp.llama_memory_seq_div(self.get_memory(), seq_id,
p0, p1, d) - def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): - assert self.memory is not None, "Memory is not initialized" - llama_cpp.llama_memory_seq_add(self.memory, seq_id, p0, p1, shift) + def memory_seq_pos_max(self, seq_id: int) -> int: + return llama_cpp.llama_memory_seq_pos_max(self.get_memory(), seq_id) + + def memory_seq_pos_min(self, seq_id: int) -> int: + return llama_cpp.llama_memory_seq_pos_min(self.get_memory(), seq_id) def get_state_size(self) -> int: return llama_cpp.llama_state_get_size(self.ctx) @@ -326,14 +366,6 @@ def decode(self, batch: LlamaBatch): if return_code != 0: raise RuntimeError(f"llama_decode returned {return_code}") - def encode(self, batch: LlamaBatch): - return_code = llama_cpp.llama_encode( - self.ctx, - batch.batch, - ) - if return_code != 0: - raise RuntimeError(f"llama_encode returned {return_code}") - def set_n_threads(self, n_threads: int, n_threads_batch: int): llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch) @@ -346,16 +378,12 @@ def get_logits_ith(self, i: int): def get_embeddings(self): return llama_cpp.llama_get_embeddings(self.ctx) - def get_embeddings_ith(self, i: int): - return llama_cpp.llama_get_embeddings_ith(self.ctx, i) - - def get_embeddings_seq(self, seq_id: int): - return llama_cpp.llama_get_embeddings_seq(self.ctx, seq_id) - - # Sampling functions - deprecated, use LlamaSampler instead + # Sampling functions def set_rng_seed(self, seed: int): - raise NotImplementedError("set_rng_seed is deprecated, use LlamaSampler instead") + # TODO: Fix + # llama_cpp.llama_set_rng_seed(self.ctx, seed) + raise NotImplementedError("set_rng_seed is not implemented in llama.cpp") def sample_repetition_penalties( self, @@ -366,30 +394,56 @@ def sample_repetition_penalties( penalty_freq: float, penalty_present: float, ): - raise NotImplementedError("sample_repetition_penalties is deprecated, use LlamaSampler instead") - - def sample_softmax(self, candidates: "_LlamaTokenDataArray"): - raise NotImplementedError("sample_softmax is deprecated, use LlamaSampler instead") + # llama_cpp.llama_sample_repetition_penalties( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # last_tokens_data, + # penalty_last_n, + # penalty_repeat, + # penalty_freq, + # penalty_present, + # ) + raise NotImplementedError("sample_repetition_penalties is not implemented in llama.cpp") def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): - raise NotImplementedError("sample_top_k is deprecated, use LlamaSampler instead") + # llama_cpp.llama_sample_top_k( + # self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep + # ) + raise NotImplementedError("sample_top_k is not implemented in llama.cpp") def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - raise NotImplementedError("sample_top_p is deprecated, use LlamaSampler instead") + # llama_cpp.llama_sample_top_p( + # self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + # ) + raise NotImplementedError("sample_top_p is not implemented in llama.cpp") def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - raise NotImplementedError("sample_min_p is deprecated, use LlamaSampler instead") + # llama_cpp.llama_sample_min_p( + # self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + # ) + raise NotImplementedError("sample_min_p is not implemented in llama.cpp") def sample_typical( self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int ): - raise 
NotImplementedError("sample_typical is deprecated, use LlamaSampler instead") + # llama_cpp.llama_sample_typical( + # self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + # ) + raise NotImplementedError("sample_typical is not implemented in llama.cpp") def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): - raise NotImplementedError("sample_temp is deprecated, use LlamaSampler instead") + # llama_cpp.llama_sample_temp( + # self.ctx, llama_cpp.byref(candidates.candidates), temp + # ) + raise NotImplementedError("sample_temp is not implemented in llama.cpp") def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): - raise NotImplementedError("sample_grammar is deprecated, use LlamaSampler instead") + # llama_cpp.llama_sample_grammar( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # grammar.grammar, + # ) + raise NotImplementedError("sample_grammar is not implemented in llama.cpp") def sample_token_mirostat( self, @@ -399,7 +453,15 @@ def sample_token_mirostat( m: int, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float], ) -> int: - raise NotImplementedError("sample_token_mirostat is deprecated, use LlamaSampler instead") + raise NotImplementedError("sample_token_mirostat is not implemented in llama.cpp") + # return llama_cpp.llama_sample_token_mirostat( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # tau, + # eta, + # m, + # mu, + # ) def sample_token_mirostat_v2( self, @@ -408,17 +470,33 @@ def sample_token_mirostat_v2( eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float], ) -> int: - raise NotImplementedError("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead") + raise NotImplementedError("sample_token_mirostat_v2 is not implemented in llama.cpp") + # return llama_cpp.llama_sample_token_mirostat_v2( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # tau, + # eta, + # mu, + # ) def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: - raise NotImplementedError("sample_token_greedy is deprecated, use LlamaSampler instead") + raise NotImplementedError("sample_token_greedy is not implemented in llama.cpp") + # return llama_cpp.llama_sample_token_greedy( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # ) def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: - raise NotImplementedError("sample_token is deprecated, use LlamaSampler instead") + raise NotImplementedError("sample_token is not implemented in llama.cpp") + # return llama_cpp.llama_sample_token( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # ) # Grammar def grammar_accept_token(self, grammar: LlamaGrammar, token: int): - raise NotImplementedError("grammar_accept_token is deprecated, use LlamaSampler instead") + raise NotImplementedError("grammar_accept_token is not implemented in llama.cpp") + # llama_cpp.llama_grammar_accept_token(grammar.grammar, self.ctx, token) def reset_timings(self): llama_cpp.llama_perf_context_reset(self.ctx) @@ -449,7 +527,6 @@ def __init__( raise ValueError("Failed to create llama_batch") self.batch = batch - self.sampler = None # LlamaBatch doesn't use samplers, but some cleanup code expects this attribute def free_batch(): if self.batch is None: @@ -508,11 +585,11 @@ def __init__(self, *, n_vocab: int): self.candidates = llama_cpp.llama_token_data_array( data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), size=self.n_vocab, + selected=-1, sorted=False, ) self.default_candidates_data_id = np.arange(self.n_vocab, 
dtype=np.intc) # type: ignore self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single) - self.sampler = None # LlamaTokenDataArray doesn't use samplers, but some cleanup code expects this attribute def copy_logits(self, logits: npt.NDArray[np.single]): self.candidates_data.id[:] = self.default_candidates_data_id @@ -540,9 +617,9 @@ class LlamaSamplingParams: n_prev: int = 64 n_probs: int = 0 top_k: int = 40 + top_n_sigma: float = -1.00 top_p: float = 0.95 min_p: float = 0.05 - tfs_z: float = 1.00 typical_p: float = 1.00 temp: float = 0.80 penalty_last_n: int = 64 @@ -554,6 +631,9 @@ class LlamaSamplingParams: mirostat_eta: float = 0.10 penalize_nl: bool = True + xtc_threshold: float = 0.1 + xtc_probability: float = 0.0 + grammar: str = "" cfg_negative_prompt: str = "" @@ -601,13 +681,99 @@ def sample( idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None, ): - # This method is deprecated in favor of using LlamaSampler directly - raise NotImplementedError("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead") + n_vocab = ctx_main.model.n_vocab() + id: int = 0 + + if logits_array is None: + logits = ctx_main.get_logits_ith(idx) + logits_array = np.array( + ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents, + dtype=np.single, + ) + + # apply logit_bias + for token, logit_bias in self.params.logit_bias.items(): + logits_array[token] += logit_bias + + token_data_array = LlamaTokenDataArray( + n_vocab=n_vocab + ) # TODO: Only create this once + token_data_array.copy_logits(logits_array) + + # apply penalties + if len(self.prev) > 0: + nl_token = ctx_main.model.token_nl() + nl_logit = logits_array[nl_token] + last_tokens = self.prev[-self.params.penalty_last_n :] + last_tokens_size = min(len(last_tokens), self.params.penalty_last_n) + if last_tokens_size > 0: + last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens) + ctx_main.sample_repetition_penalties( + token_data_array, + last_tokens_p, + last_tokens_size, + self.params.penalty_repeat, + self.params.penalty_freq, + self.params.penalty_present, + ) + if not self.params.penalize_nl: + token_data_array.candidates_data.logit[nl_token] = nl_logit + + if self.grammar is not None: + ctx_main.sample_grammar(token_data_array, self.grammar) + + if self.params.temp < 0: + id = token_data_array.candidates_data.id[0] + elif self.params.temp == 0: + id = ctx_main.sample_token_greedy(token_data_array) + else: + if self.params.mirostat == 1: + mirostat_m = 100 + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token_mirostat( + token_data_array, + self.params.mirostat_tau, + self.params.mirostat_eta, + mirostat_m, + ctypes.pointer(self.mirostat_mu), + ) + elif self.params.mirostat == 2: + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token_mirostat_v2( + token_data_array, + self.params.mirostat_tau, + self.params.mirostat_eta, + ctypes.pointer(self.mirostat_mu), + ) + else: + min_keep = max(1, self.params.n_probs) + ctx_main.sample_top_k( + token_data_array, self.params.top_k, min_keep=min_keep + ) + ctx_main.sample_typical( + token_data_array, self.params.typical_p, min_keep=min_keep + ) + ctx_main.sample_top_p( + token_data_array, self.params.top_p, min_keep=min_keep + ) + ctx_main.sample_min_p( + token_data_array, self.params.min_p, min_keep=min_keep + ) + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token(token_data_array) + return id def accept(self, ctx_main: 
LlamaContext, id: int, apply_grammar: bool): + if apply_grammar and self.grammar is not None: + ctx_main.grammar_accept_token(self.grammar, id) self.prev.append(id) +from typing import List, Callable, Optional, Union +import ctypes +import llama_cpp + + class CustomSampler: def __init__( self, apply_func: Callable[[llama_cpp.llama_token_data_array], None] @@ -633,127 +799,163 @@ def free_wrapper(sampler: llama_cpp.llama_sampler_p): sampler_i.clone = llama_cpp.llama_sampler_i_clone(0) sampler_i.free = llama_cpp.llama_sampler_i_free(0) - self.sampler = llama_cpp.llama_sampler() - self.sampler.iface = ctypes.pointer(sampler_i) - self.sampler.ctx = None + self.sampler = llama_cpp.llama_sampler_init(ctypes.pointer(sampler_i), None) def get_sampler(self) -> llama_cpp.llama_sampler_p: - return ctypes.pointer(self.sampler) + return self.sampler class LlamaSampler: def __init__(self): - params = llama_cpp.llama_sampler_chain_default_params() + params = llama_cpp.llama_sampler_chain_params() self.sampler = llama_cpp.llama_sampler_chain_init(params) + self.samplers: List[llama_cpp.llama_sampler_p] = [] self.custom_samplers: List[Tuple[int, CustomSampler]] = [] - self._exit_stack = ExitStack() - - def free_sampler(): - if self.sampler is not None: - # NOTE: Must remove custom samplers before free or llama.cpp will try to free them - for i, _ in reversed(self.custom_samplers): - llama_cpp.llama_sampler_chain_remove(self.sampler, i) - llama_cpp.llama_sampler_free(self.sampler) - self.sampler = None - - self._exit_stack.callback(free_sampler) - - def close(self): - self._exit_stack.close() - - def __del__(self): - self.close() def add_greedy(self): sampler = llama_cpp.llama_sampler_init_greedy() - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_dist(self, seed: int): sampler = llama_cpp.llama_sampler_init_dist(seed) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) - - def add_softmax(self): - sampler = llama_cpp.llama_sampler_init_softmax() - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_top_k(self, k: int): sampler = llama_cpp.llama_sampler_init_top_k(k) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) - def add_top_p(self, p: float, min_keep: int = 1): + def add_top_p(self, p: float, min_keep: int): sampler = llama_cpp.llama_sampler_init_top_p(p, min_keep) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) - def add_min_p(self, p: float, min_keep: int = 1): + def add_min_p(self, p: float, min_keep: int): sampler = llama_cpp.llama_sampler_init_min_p(p, min_keep) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) - def add_typical(self, p: float, min_keep: int = 1): + def add_typical(self, p: float, min_keep: int): sampler = llama_cpp.llama_sampler_init_typical(p, min_keep) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) + + def add_xtc(self, p: float, t: float, min_keep: int, seed: int): + sampler = llama_cpp.llama_sampler_init_xtc(p, t, min_keep, seed) + self._add_sampler(sampler) def add_temp(self, temp: float): sampler = llama_cpp.llama_sampler_init_temp(temp) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_temp_ext(self, t: float, delta: float, exponent: float): sampler = llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) - - def 
add_xtc(self, p: float, t: float, min_keep: int, seed: int): - sampler = llama_cpp.llama_sampler_init_xtc(p, t, min_keep, seed) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_top_n_sigma(self, n: float): sampler = llama_cpp.llama_sampler_init_top_n_sigma(n) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int): sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_mirostat_v2(self, seed: int, tau: float, eta: float): sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar): sampler = llama_cpp.llama_sampler_init_grammar( model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8") ) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) + + def convert_list_str_to_char_array_ptr(self, str_list: List[str]): + """ + Converts a list of strings to a char** array for C interop, and returns two values: + the char** array and the number of bytes in the list. + + Args: + str_list: List of string objects. + + Returns: + - A ctypes pointer to a char** array. + - The number of strings in the input list. + """ + # Encode strings to bytes + byte_list = [s.encode('utf-8') for s in str_list] + # Calculate the number of breakers + num_byte_list= len(byte_list) + # Define the type of a char pointer + char_ptr_type = ctypes.POINTER(ctypes.c_char) + # Define the type of an array of char pointers + char_ptr_array_type = char_ptr_type * num_byte_list + + # Allocate memory for the array of char pointers + char_ptr_array = char_ptr_array_type() + + # Populate the array with pointers to the byte strings + for i, byte_string in enumerate(byte_list): + # Create a null-terminated C-style string buffer + c_char_array = ctypes.create_string_buffer(byte_string) + # Cast the buffer to a char pointer and assign it to the array + char_ptr_array[i] = ctypes.cast(c_char_array, char_ptr_type) + + char_array_ptr = ctypes.cast(char_ptr_array, ctypes.POINTER(char_ptr_type)) + + # Return the char** pointer and the number of strings + return char_array_ptr, num_byte_list + + def add_grammar_lazy( + self, + model: LlamaModel, + grammar: LlamaGrammar, + trigger_tokens:list[llama_cpp.llama_token], + num_trigger_tokens: int, + trigger_words: list[str]=[] + ): + trigger_words_char_array_ptr, num_trigger_words = self.convert_list_str_to_char_array_ptr(trigger_words) + + sampler = llama_cpp.llama_sampler_init_grammar_lazy( + model.vocab, + grammar._grammar.encode("utf-8"), + grammar._root.encode("utf-8"), + trigger_words_char_array_ptr, + num_trigger_words, + trigger_tokens, + num_trigger_tokens + ) + self._add_sampler(sampler) def add_grammar_lazy_patterns( - self, - model: LlamaModel, - grammar: LlamaGrammar, - trigger_patterns: List[str], - trigger_tokens: List[int] - ): - # Convert patterns to C array - pattern_ptrs = (ctypes.c_char_p * len(trigger_patterns))() - for i, pattern in enumerate(trigger_patterns): - pattern_ptrs[i] = pattern.encode("utf-8") - - # Convert tokens to C array - token_array = (llama_cpp.llama_token * len(trigger_tokens))(*trigger_tokens) - + self, + model: LlamaModel, + grammar: LlamaGrammar, + num_trigger_patterns: int, + 
trigger_tokens:list[llama_cpp.llama_token], + num_trigger_tokens: int, + trigger_patterns: list[str]=[] + ): + trigger_patterns_char_array_ptr, num_trigger_patterns = self.convert_list_str_to_char_array_ptr(trigger_patterns) sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns( model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8"), - pattern_ptrs, - len(trigger_patterns), - token_array, - len(trigger_tokens) + trigger_patterns_char_array_ptr, + num_trigger_patterns, + trigger_tokens, + num_trigger_tokens ) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) + def add_penalties( self, + n_vocab: int, + special_eos_id: int, + linefeed_id: int, penalty_last_n: int, penalty_repeat: float, penalty_freq: float, penalty_present: float, + penalize_nl: bool, + ignore_eos: bool, ): sampler = llama_cpp.llama_sampler_init_penalties( penalty_last_n, @@ -761,96 +963,79 @@ def add_penalties( penalty_freq, penalty_present, ) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_dry( self, model: LlamaModel, - n_ctx_train: int, dry_multiplier: float, dry_base: float, dry_allowed_length: int, dry_penalty_last_n: int, - seq_breakers: List[str] + dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"] ): - # Convert seq_breakers to C array - breaker_ptrs = (ctypes.c_char_p * len(seq_breakers))() - for i, breaker in enumerate(seq_breakers): - breaker_ptrs[i] = breaker.encode("utf-8") - + + dry_seq_breakers_char_array_ptr, num_seq_breakers = self.convert_list_str_to_char_array_ptr(dry_seq_breakers) + sampler = llama_cpp.llama_sampler_init_dry( model.vocab, - n_ctx_train, + model.n_ctx_train(), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, - breaker_ptrs, - len(seq_breakers) + dry_seq_breakers_char_array_ptr, + num_seq_breakers ) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) def add_logit_bias( - self, - n_vocab: int, - logit_bias: Dict[int, float] + self, n_vocab: int, logit_bias: Dict[int, float] ): - # Convert logit_bias dict to C array - bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))() + # Construct a C array to store the contents of the logit_bias dictionary + logit_bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))() + for i, (token, bias) in enumerate(logit_bias.items()): - bias_array[i].token = token - bias_array[i].bias = bias - - sampler = llama_cpp.llama_sampler_init_logit_bias( - n_vocab, - len(logit_bias), - bias_array - ) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + logit_bias_array[i].token = token + logit_bias_array[i].bias = bias - def add_infill(self, model: LlamaModel): - sampler = llama_cpp.llama_sampler_init_infill(model.vocab) - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + sampler = llama_cpp.llama_sampler_init_logit_bias(n_vocab, len(logit_bias), logit_bias_array) + self._add_sampler(sampler) def add_custom( self, apply_func: Callable[[llama_cpp.llama_token_data_array], None] ): custom_sampler = CustomSampler(apply_func) sampler = custom_sampler.get_sampler() - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self._add_sampler(sampler) # NOTE: Must remove custom samplers before free or llama.cpp will try to free them self.custom_samplers.append( - (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler) + [llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler] ) + def _add_sampler(self, sampler: llama_cpp.llama_sampler_p): + assert 
self.sampler is not None + llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + self.samplers.append(sampler) + def get_seed(self) -> int: + assert self.sampler is not None return llama_cpp.llama_sampler_get_seed(self.sampler) - def sample(self, ctx: LlamaContext, idx: int = -1) -> int: + def sample(self, ctx: LlamaContext, idx: ctypes.c_int32) -> ctypes.c_int32: + assert self.sampler is not None + assert ctx.ctx is not None return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx) - def accept(self, token: int): - llama_cpp.llama_sampler_accept(self.sampler, token) + def close(self): + if self.sampler: + # NOTE: Must remove custom samplers before free or llama.cpp will try to free them + for i, _ in reversed(self.custom_samplers): + llama_cpp.llama_sampler_chain_remove(self.sampler, i) + llama_cpp.llama_sampler_free(self.sampler) + self.sampler = None + self.samplers.clear() + self.custom_samplers.clear() - def reset(self): - llama_cpp.llama_sampler_reset(self.sampler) - - def clone(self): - # NOTE: Custom samplers cannot be cloned due to Python callback limitations - if self.custom_samplers: - raise NotImplementedError("Cannot clone LlamaSampler that contains custom samplers") - - cloned_sampler = llama_cpp.llama_sampler_clone(self.sampler) - # Create a new wrapper around the cloned sampler - new_sampler = LlamaSampler.__new__(LlamaSampler) - new_sampler.sampler = cloned_sampler - new_sampler.custom_samplers = [] - new_sampler._exit_stack = ExitStack() - - def free_sampler(): - if new_sampler.sampler is not None: - llama_cpp.llama_sampler_free(new_sampler.sampler) - new_sampler.sampler = None - - new_sampler._exit_stack.callback(free_sampler) - return new_sampler + def __del__(self): + self.close() diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 71d94ebd8..b6c9b3d77 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -79,8 +79,14 @@ def __init__( n_threads_batch: Optional[int] = None, rope_scaling_type: Optional[ int - ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, + ] = llama_cpp.llama_rope_scaling_type.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED, + attention_type: Optional[ + int + ] = llama_cpp.llama_attention_type.LLAMA_ATTENTION_TYPE_UNSPECIFIED, + flash_attn_type: Optional[ + int + ] = llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_AUTO, rope_freq_base: float = 0.0, rope_freq_scale: float = 0.0, yarn_ext_factor: float = -1.0, @@ -91,9 +97,9 @@ def __init__( logits_all: bool = False, embedding: bool = False, offload_kqv: bool = True, - flash_attn: bool = False, op_offload: Optional[bool] = None, swa_full: Optional[bool] = None, + kv_unified: Optional[bool] = None, # Sampling Params no_perf: bool = False, last_n_tokens_size: int = 64, @@ -161,8 +167,10 @@ def __init__( n_ubatch: Physical batch size n_threads: Number of threads to use for generation n_threads_batch: Number of threads to use for batch processing - rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054 + rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggml-org/llama.cpp/pull/2054 pooling_type: Pooling type, from `enum llama_pooling_type`. 
+ attention_type: attention type to use for embeddings + flash_attn_type: when to enable Flash Attention rope_freq_base: RoPE base frequency, 0 = from model rope_freq_scale: RoPE frequency scaling factor, 0 = from model yarn_ext_factor: YaRN extrapolation mix factor, negative = from model @@ -173,9 +181,9 @@ def __init__( logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. embedding: Embedding mode only. offload_kqv: Offload K, Q, V to GPU. - flash_attn: Use flash attention. - op_offload: offload host tensor operations to device - swa_full: use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) + op_offload: whether to offload host tensor operations to device + swa_full: whether to use full-size SWA cache + kv_unified: use single unified KV buffer for the KV cache of all sequences no_perf: Measure performance timings. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model. @@ -257,28 +265,28 @@ def __init__( for i, (k, v) in enumerate(kv_overrides.items()): self._kv_overrides_array[i].key = k.encode("utf-8") if isinstance(v, bool): - self._kv_overrides_array[ - i - ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL + self._kv_overrides_array[i].tag = ( + llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL + ) self._kv_overrides_array[i].value.val_bool = v elif isinstance(v, int): - self._kv_overrides_array[ - i - ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT + self._kv_overrides_array[i].tag = ( + llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT + ) self._kv_overrides_array[i].value.val_i64 = v elif isinstance(v, float): - self._kv_overrides_array[ - i - ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT + self._kv_overrides_array[i].tag = ( + llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT + ) self._kv_overrides_array[i].value.val_f64 = v elif isinstance(v, str): # type: ignore v_bytes = v.encode("utf-8") if len(v_bytes) > 128: # TODO: Make this a constant raise ValueError(f"Value for {k} is too long: {v}") v_bytes = v_bytes.ljust(128, b"\0") - self._kv_overrides_array[ - i - ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR + self._kv_overrides_array[i].tag = ( + llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR + ) # copy min(v_bytes, 128) to str_value address = typing.cast( int, @@ -294,9 +302,9 @@ def __init__( else: raise ValueError(f"Unknown value type for {k}: {v}") - self._kv_overrides_array[ - -1 - ].key = b"\0" # ensure sentinel element is zeroed + self._kv_overrides_array[-1].key = ( + b"\0" # ensure sentinel element is zeroed + ) self.model_params.kv_overrides = self._kv_overrides_array self.n_batch = min(n_ctx, n_batch) # ??? 
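The kv-override packing above dispatches on the Python value type to pick the matching override tag. A minimal usage sketch of the `kv_overrides` keyword this code serves (the model path and metadata keys below are hypothetical; only the constructor keyword, the accepted value types, the tag names, and the 128-byte string limit come from this diff):

    from llama_cpp import Llama

    llm = Llama(
        model_path="./models/example.gguf",   # hypothetical path
        kv_overrides={
            "example.bool_key": True,     # bool  -> LLAMA_KV_OVERRIDE_TYPE_BOOL
            "example.int_key": 8,         # int   -> LLAMA_KV_OVERRIDE_TYPE_INT
            "example.float_key": 0.5,     # float -> LLAMA_KV_OVERRIDE_TYPE_FLOAT
            "example.str_key": "value",   # str   -> LLAMA_KV_OVERRIDE_TYPE_STR, must encode to <= 128 bytes
        },
    )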
@@ -316,39 +324,23 @@ def __init__( self.context_params.rope_scaling_type = ( rope_scaling_type if rope_scaling_type is not None - else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED - ) - self.context_params.pooling_type = pooling_type - self.context_params.rope_freq_base = ( - rope_freq_base if rope_freq_base != 0.0 else 0 - ) - self.context_params.rope_freq_scale = ( - rope_freq_scale if rope_freq_scale != 0.0 else 0 - ) - self.context_params.yarn_ext_factor = ( - yarn_ext_factor if yarn_ext_factor != 0.0 else 0 - ) - self.context_params.yarn_attn_factor = ( - yarn_attn_factor if yarn_attn_factor != 0.0 else 0 - ) - self.context_params.yarn_beta_fast = ( - yarn_beta_fast if yarn_beta_fast != 0.0 else 0 + else llama_cpp.llama_rope_scaling_type.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED ) - self.context_params.yarn_beta_slow = ( - yarn_beta_slow if yarn_beta_slow != 0.0 else 0 + self.context_params.pooling_type = ( + pooling_type + if pooling_type is not None + else llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED ) - self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0 - self._logits_all = logits_all if draft_model is None else True - self.context_params.embeddings = embedding # TODO: Rename to embeddings - self.context_params.offload_kqv = offload_kqv - self.context_params.flash_attn = flash_attn - - if op_offload is not None: - self.context_params.op_offload = op_offload + # Propagate embedding mode into context params (was previously omitted) + self.context_params.embeddings = embedding + # swa_full override (placed after field assignments) if swa_full is not None: self.context_params.swa_full = swa_full + if kv_unified is not None: + self.context_params.kv_unified = kv_unified + # KV cache quantization if type_k is not None: self.context_params.type_k = type_k @@ -358,6 +350,10 @@ def __init__( self.context_params.no_perf = no_perf self.last_n_tokens_size = last_n_tokens_size + # Store whether we want to keep logits for all tokens (needed for logprobs) + # Must be set before any property access that references _logits_all + self._logits_all = logits_all if draft_model is None else True + self.cache: Optional[BaseLlamaCache] = None self.lora_base = lora_base @@ -443,9 +439,9 @@ def free_lora_adapter(): self.chat_format = chat_format self.chat_handler = chat_handler - self._chat_handlers: Dict[ - str, llama_chat_format.LlamaChatCompletionHandler - ] = {} + self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = ( + {} + ) self.draft_model = draft_model @@ -637,7 +633,7 @@ def eval(self, tokens: Sequence[int]): Args: tokens: The list of tokens to evaluate. 
""" - self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + self._ctx.memory_seq_rm(0, self.n_tokens, -1) for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] n_past = self.n_tokens @@ -671,6 +667,7 @@ def eval(self, tokens: Sequence[int]): def _init_sampler( self, top_k: int = 40, + top_n_sigma: float = -1.00, top_p: float = 0.95, min_p: float = 0.05, typical_p: float = 1.0, @@ -678,55 +675,48 @@ def _init_sampler( repeat_penalty: float = 1.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, - tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_seq_breakers: list[str] = ["\n", ":", '"', "*"], penalize_nl: bool = True, + logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, ): sampler = internals.LlamaSampler() - if logits_processor is not None: - # Create and add a custom sampler - def apply_func(token_data_array: llama_cpp.llama_token_data_array_p): - size = token_data_array.contents.size - data_soa = token_data_array.contents.data - data_soa_address = ctypes.addressof(data_soa.contents) - # NOTE: This is probably broken - recarray = np.recarray( - shape=(size,), - dtype=np.dtype( - [("id", np.intc), ("logit", np.single), ("p", np.single)], - align=True, - ), - buf=(llama_cpp.llama_token_data * size).from_address( - data_soa_address - ), - ) - for logit_processor in logits_processor: - recarray.logit[:] = logit_processor(self._input_ids, recarray.logit) - - sampler.add_custom(apply_func) + if logit_bias is not None: + sampler.add_logit_bias(self.n_vocab(), logit_bias) sampler.add_penalties( - # n_vocab=self._n_vocab, - # special_eos_id=self._token_eos, - # linefeed_id=self._token_nl, + n_vocab=self._n_vocab, + special_eos_id=self._token_eos, + linefeed_id=self._token_nl, penalty_last_n=self.last_n_tokens_size, penalty_repeat=repeat_penalty, penalty_freq=frequency_penalty, penalty_present=presence_penalty, - # penalize_nl=penalize_nl, - # ignore_eos=False, + penalize_nl=penalize_nl, + ignore_eos=False, ) if grammar is not None: sampler.add_grammar(self._model, grammar) + # Store logits_processor for application just before sampling (Python-side mutation of raw logits) + # This avoids inserting a custom C sampler (which was causing instability / aborts) + self._logits_processor_chain = ( + logits_processor if logits_processor and len(logits_processor) > 0 else None + ) + if temp < 0.0: - sampler.add_softmax() sampler.add_dist(self._seed) elif temp == 0.0: sampler.add_greedy() @@ -748,18 +738,30 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p): ) else: n_probs = 0 - min_keep = max(1, n_probs) + # Ensure at least 2 candidates survive filtering so that stochastic sampling (temp>0) can produce variability across calls + min_keep = max(2, n_probs) sampler.add_top_k(top_k) sampler.add_typical(typical_p, min_keep) + sampler.add_top_n_sigma(top_n_sigma) sampler.add_top_p(top_p, min_keep) sampler.add_min_p(min_p, min_keep) sampler.add_temp(temp) sampler.add_dist(self._seed) + sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed) + sampler.add_dry( + self._model, + dry_multiplier, + dry_base, + dry_allowed_length, + dry_penalty_last_n, + dry_seq_breakers, + ) return sampler def sample( 
self, top_k: int = 40, + top_n_sigma: float = -1.00, top_p: float = 0.95, min_p: float = 0.05, typical_p: float = 1.0, @@ -767,11 +769,18 @@ def sample( repeat_penalty: float = 1.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, - tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_seq_breakers: list[str] = ["\n", ":", '"', "*"], penalize_nl: bool = True, + logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, idx: Optional[int] = None, @@ -795,6 +804,7 @@ def sample( tmp_sampler = True self._sampler = self._init_sampler( top_k=top_k, + top_n_sigma=top_n_sigma, top_p=top_p, min_p=min_p, typical_p=typical_p, @@ -802,11 +812,18 @@ def sample( repeat_penalty=repeat_penalty, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, - tfs_z=tfs_z, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, penalize_nl=penalize_nl, + logit_bias=logit_bias, logits_processor=logits_processor, grammar=grammar, ) @@ -814,6 +831,37 @@ def sample( ridx = idx - self.n_tokens if idx is not None else -1 assert self.ctx is not None + # Obtain logits pointer for processor application (no extra noise injection) + try: + c_ctx = self._ctx.ctx + # Always pull the most recent logits (-1) for deterministic modification + logits_ptr = llama_cpp.llama_get_logits_ith(c_ctx, ctypes.c_int32(-1)) + except Exception: + logits_ptr = None + + # Apply Python-side logits processors (if any) BEFORE sampler distribution sampling + # Apply logits processors if either a persistent chain or per-call processors provided + if ( + getattr(self, "_logits_processor_chain", None) or logits_processor + ) and logits_ptr: + try: + vocab = self._n_vocab + logits_view = np.ctypeslib.as_array(logits_ptr, shape=(vocab,)) + proc_logits = logits_view.copy() + input_ids = self._input_ids + base_chain = ( + list(self._logits_processor_chain) # type: ignore + if getattr(self, "_logits_processor_chain", None) + else [] + ) + # Merge persistent + per-call + processors = base_chain + (list(logits_processor) if logits_processor else []) # type: ignore + for proc in processors: # type: ignore + proc_logits = proc(input_ids, proc_logits) + logits_view[:] = proc_logits + except Exception as e: + if self.verbose: + print(f"logits_processor (pre-sample) error: {e}", file=sys.stderr) token = self._sampler.sample(self._ctx, ridx) if tmp_sampler: self._sampler = None @@ -823,6 +871,7 @@ def generate( self, tokens: Sequence[int], top_k: int = 40, + top_n_sigma: float = -1.00, top_p: float = 0.95, min_p: float = 0.05, typical_p: float = 1.0, @@ -831,11 +880,18 @@ def generate( reset: bool = True, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, - tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_seq_breakers: 
list[str] = ["\n", ":", '"', "*"], penalize_nl: bool = True, + logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, grammar: Optional[LlamaGrammar] = None, @@ -863,6 +919,7 @@ def generate( self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau) self._sampler = self._init_sampler( top_k=top_k, + top_n_sigma=top_n_sigma, top_p=top_p, min_p=min_p, typical_p=typical_p, @@ -870,11 +927,18 @@ def generate( repeat_penalty=repeat_penalty, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, - tfs_z=tfs_z, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, penalize_nl=penalize_nl, + logit_bias=logit_bias, logits_processor=logits_processor, grammar=grammar, ) @@ -915,6 +979,7 @@ def generate( while sample_idx < self.n_tokens: token = self.sample( top_k=top_k, + top_n_sigma=top_n_sigma, top_p=top_p, min_p=min_p, typical_p=typical_p, @@ -922,10 +987,17 @@ def generate( repeat_penalty=repeat_penalty, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, - tfs_z=tfs_z, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, + logit_bias=logit_bias, logits_processor=logits_processor, grammar=grammar, penalize_nl=penalize_nl, @@ -934,7 +1006,8 @@ def generate( sample_idx += 1 if stopping_criteria is not None and stopping_criteria( - self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :] + self._input_ids[:sample_idx], + self._scores[sample_idx - self.n_tokens, :], ): return tokens_or_none = yield token @@ -945,7 +1018,7 @@ def generate( if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]: self.n_tokens = sample_idx - self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + self._ctx.memory_seq_rm(0, self.n_tokens, -1) break if self.draft_model is not None: @@ -1041,7 +1114,9 @@ def embed( data: Union[List[List[float]], List[List[List[float]]]] = [] def decode_batch(seq_sizes: List[int]): - llama_cpp.llama_kv_self_clear(self._ctx.ctx) + llama_cpp.llama_memory_clear( + llama_cpp.llama_get_memory(self._ctx.ctx), True + ) self._ctx.decode(self._batch) self._batch.reset() @@ -1112,7 +1187,7 @@ def decode_batch(seq_sizes: List[int]): output = data[0] if isinstance(input, str) else data - llama_cpp.llama_kv_self_clear(self._ctx.ctx) + llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(self._ctx.ctx), True) self.reset() if return_count: @@ -1136,17 +1211,24 @@ def _create_completion( presence_penalty: float = 0.0, repeat_penalty: float = 1.0, top_k: int = 40, + top_n_sigma: float = -1.00, stream: bool = False, seed: Optional[int] = None, - tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_seq_breakers: list[str] = ["\n", ":", '"', "*"], model: Optional[str] = None, 
stopping_criteria: Optional[StoppingCriteriaList] = None, + logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, - logit_bias: Optional[Dict[int, float]] = None, ) -> Union[ Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse] ]: @@ -1154,83 +1236,48 @@ def _create_completion( completion_id: str = f"cmpl-{str(uuid.uuid4())}" created: int = int(time.time()) - bos_token_id: int = self.token_bos() - cls_token_id: int = self._model.token_cls() + bos_token_id: int = self._model.token_bos() + eos_token_id: int = self._model.token_eos() sep_token_id: int = self._model.token_sep() - prefix_token_id: int = 0 # self._model.token_prefix() # TODO: Fix - middle_token_id: int = 0 # self._model.token_middle() # TODO: Fix - suffix_token_id: int = 0 # self._model.token_suffix() # TODO: Fix - add_space_prefix: bool = ( - self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true" - ) - bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id] - eos_tokens: List[int] = [ - sep_token_id if sep_token_id != -1 else self.token_eos() - ] - - if ( - (isinstance(prompt, list) and suffix is None) - or not self._model.add_bos_token() - or bos_tokens[:1] == [-1] - ): - bos_tokens = [] - - if (isinstance(prompt, list) and suffix is None) or ( - not self._model.add_eos_token() and sep_token_id == -1 - ): - eos_tokens = [] - - suffix_space_prefix: int = 0 - # Tokenizer hack to remove leading space - if add_space_prefix and suffix_token_id >= 0 and suffix: - suffix = "☺" + suffix - suffix_space_prefix = 2 - - # If prompt is empty, initialize completion with BOS token to avoid - # detokenization including a space at the beginning of the completion - completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id] - # Add blank space to start of prompt to match OG llama tokenizer - prefix_tokens: List[int] = ( - [prefix_token_id] if prefix_token_id >= 0 and suffix is not None else [] - ) + ( - ( + prefix_token_id: int = self._model.token_fim_pre() + middle_token_id: int = self._model.token_fim_mid() + suffix_token_id: int = self._model.token_fim_suf() + # Simplified prompt handling: do not insert BOS/EOS or space-prefix hacks automatically. + if suffix is None: + if isinstance(prompt, str): + prompt_tokens: List[int] = self.tokenize( + prompt.encode("utf-8"), add_bos=False, special=True + ) + else: + prompt_tokens = list(prompt) + completion_tokens: List[int] = [] + else: + # Preserve explicit infill behavior only when suffix provided. 
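+            # Token layout produced by the infill branch below (descriptive sketch;
+            # the concrete special tokens come from the model's FIM vocabulary):
+            #   spm_infill == False: [FIM_PRE] prompt... [FIM_SUF] suffix... [FIM_MID]
+            #   spm_infill == True:  [FIM_SUF] suffix... [FIM_PRE] prompt... [FIM_MID]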
+ completion_tokens = [] + prefix_tokens: List[int] = ( + [prefix_token_id] if prefix_token_id >= 0 else [] + ) + ( self.tokenize( - prompt.encode("utf-8"), - add_bos=False, - special=(prefix_token_id < 0 or suffix is None), + prompt.encode("utf-8"), add_bos=False, special=(prefix_token_id < 0) ) - if prompt != "" - else [] + if isinstance(prompt, str) and prompt != "" + else (prompt if isinstance(prompt, list) else []) ) - if isinstance(prompt, str) - else prompt - ) - suffix_tokens: List[int] = ( - ( + suffix_tokens: List[int] = ( [suffix_token_id] + ( - self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[ - suffix_space_prefix: - ] + self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False) if suffix else [] ) + if suffix_token_id >= 0 + else [] ) - if suffix_token_id >= 0 and suffix is not None - else [] - ) - middle_tokens: List[int] = ( - [middle_token_id] if middle_token_id >= 0 and suffix is not None else [] - ) - prompt_tokens: List[int] = ( - bos_tokens - + ( - (suffix_tokens + prefix_tokens + middle_tokens) - if self.spm_infill - else (prefix_tokens + suffix_tokens + middle_tokens) - ) - + eos_tokens - ) + middle_tokens: List[int] = [middle_token_id] if middle_token_id >= 0 else [] + if self.spm_infill: + prompt_tokens = suffix_tokens + prefix_tokens + middle_tokens + else: + prompt_tokens = prefix_tokens + suffix_tokens + middle_tokens text: bytes = b"" returned_tokens: int = 0 stop = ( @@ -1238,11 +1285,7 @@ def _create_completion( ) model_name: str = model if model is not None else self.model_path - if prompt_tokens[:2] == [self.token_bos()] * 2: - warnings.warn( - f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...', - RuntimeWarning, - ) + # Duplicate BOS warning removed (no implicit BOS insertion anymore). 
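+        # With no implicit BOS insertion, callers that still want a leading BOS can build
+        # the token list themselves and pass it as `prompt`, e.g. (illustrative sketch;
+        # `llm` and `text` are placeholder names, the helpers exist on this class):
+        #   prompt_ids = [llm.token_bos()] + llm.tokenize(text.encode("utf-8"), add_bos=False, special=True)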
# NOTE: This likely doesn't work correctly for the first token in the prompt # because of the extra space added to the start of the prompt_tokens @@ -1312,28 +1355,41 @@ def logit_bias_processor( if self.verbose: print("Llama._create_completion: cache miss", file=sys.stderr) - if seed is not None: - self.set_seed(seed) + # Seeding strategy: + # - If user supplies a seed: use it (deterministic run) + # - Else: advance internal seed so successive completions differ + explicit_seed = seed is not None + if explicit_seed: + self.set_seed(seed) # type: ignore[arg-type] else: - self.set_seed(random.Random(self._seed).randint(0, 2 ** 32)) + next_seed = random.Random(self._seed).randint(0, 2**32) + self.set_seed(next_seed) finish_reason = "length" multibyte_fix = 0 for token in self.generate( prompt_tokens, top_k=top_k, + top_n_sigma=top_n_sigma, top_p=top_p, min_p=min_p, typical_p=typical_p, temp=temperature, - tfs_z=tfs_z, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, repeat_penalty=repeat_penalty, stopping_criteria=stopping_criteria, + logit_bias=logit_bias, logits_processor=logits_processor, grammar=grammar, ): @@ -1521,6 +1577,10 @@ def logit_bias_processor( if self.verbose: self._ctx.print_timings() + # Post-generation: if we used implicit seeding, advance again so the next completion starts from a new seed. + if not explicit_seed: + self._seed = random.Random(self._seed).randint(0, 2**32) + if stream: remaining_tokens = completion_tokens[returned_tokens:] remaining_text = self.detokenize( @@ -1648,7 +1708,12 @@ def logit_bias_processor( text_str = text.decode("utf-8", errors="ignore") if echo: - text_str = prompt + text_str + if isinstance(prompt, str): + text_str = prompt + text_str + else: + # When prompt supplied as token ids, reconstruct its string form for echo + prompt_text = self.detokenize(prompt).decode("utf-8", errors="ignore") + text_str = prompt_text + text_str if suffix_token_id < 0 and suffix is not None: text_str = text_str + suffix @@ -1656,18 +1721,14 @@ def logit_bias_processor( logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: text_offset = 0 if echo else len(prompt) - token_offset = 0 if echo else len(prompt_tokens[1:]) + token_offset = 0 if echo else len(prompt_tokens) text_offsets: List[int] = [] token_logprobs: List[Optional[float]] = [] tokens: List[str] = [] top_logprobs: List[Optional[Dict[str, float]]] = [] if echo: - # Remove leading BOS token if exists - all_tokens = ( - prompt_tokens[1 if prompt_tokens[0] == self.token_bos() else 0 :] - + completion_tokens - ) + all_tokens = prompt_tokens + completion_tokens else: all_tokens = completion_tokens @@ -1756,17 +1817,24 @@ def create_completion( presence_penalty: float = 0.0, repeat_penalty: float = 1.0, top_k: int = 40, + top_n_sigma: float = -1.00, stream: bool = False, seed: Optional[int] = None, - tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_seq_breakers: list[str] = ["\n", ":", '"', "*"], model: 
Optional[str] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, + logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, - logit_bias: Optional[Dict[int, float]] = None, ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: """Generate text from a prompt. @@ -1776,7 +1844,7 @@ def create_completion( max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx. temperature: The temperature to use for sampling. top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. logprobs: The number of logprobs to return. If None, no logprobs are returned. echo: Whether to echo the prompt. @@ -1785,17 +1853,24 @@ def create_completion( presence_penalty: The penalty to apply to tokens based on their presence in the prompt. repeat_penalty: The penalty to apply to repeated tokens. top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + top_n_sigma: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1.00, -1.00 = disabled). stream: Whether to stream the results. seed: The seed to use for sampling. - tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. mirostat_mode: The mirostat sampling mode. mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + xtc-probability: Sets the chance for token removal (checked once on sampler start) (default: 0.0). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + xtc-threshold: Sets a minimum probability threshold for tokens to be removed (default: 0.1). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + dry_multiplier: Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled. + dry_base`: Set the DRY repetition penalty base value. Default: `1.75` + dry_allowed_length: Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2` + dry_penalty_last_n: How many tokens to scan for repetitions. Default: `0`, where `0` is disabled and `-1` is context size. 
+ dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']` model: The name to use for the model in the completion object. stopping_criteria: A list of stopping criteria to use. + logit_bias: A logit bias to use. logits_processor: A list of logits processors to use. grammar: A grammar to use for constrained sampling. - logit_bias: A logit bias to use. Raises: ValueError: If the requested tokens exceed the context window. @@ -1819,17 +1894,24 @@ def create_completion( presence_penalty=presence_penalty, repeat_penalty=repeat_penalty, top_k=top_k, + top_n_sigma=top_n_sigma, stream=stream, seed=seed, - tfs_z=tfs_z, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, stopping_criteria=stopping_criteria, + logit_bias=logit_bias, logits_processor=logits_processor, grammar=grammar, - logit_bias=logit_bias, ) if stream: chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks @@ -1853,17 +1935,24 @@ def __call__( presence_penalty: float = 0.0, repeat_penalty: float = 1.0, top_k: int = 40, + top_n_sigma: float = -1.00, stream: bool = False, seed: Optional[int] = None, - tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_seq_breakers: list[str] = ["\n", ":", '"', "*"], model: Optional[str] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, + logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, - logit_bias: Optional[Dict[int, float]] = None, ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: """Generate text from a prompt. @@ -1873,7 +1962,7 @@ def __call__( max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx. temperature: The temperature to use for sampling. top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. logprobs: The number of logprobs to return. If None, no logprobs are returned. echo: Whether to echo the prompt. @@ -1882,17 +1971,24 @@ def __call__( presence_penalty: The penalty to apply to tokens based on their presence in the prompt. repeat_penalty: The penalty to apply to repeated tokens. top_k: The top-k value to use for sampling. 
Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + top_n_sigma: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1.00, -1.00 = disabled). stream: Whether to stream the results. seed: The seed to use for sampling. - tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. mirostat_mode: The mirostat sampling mode. mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + xtc-probability: Sets the chance for token removal (checked once on sampler start) (default: 0.0). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + xtc-threshold: Sets a minimum probability threshold for tokens to be removed (default: 0.1). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + dry_multiplier: Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled. + dry_base`: Set the DRY repetition penalty base value. Default: `1.75` + dry_allowed_length: Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2` + dry_penalty_last_n: How many tokens to scan for repetitions. Default: `0`, where `0` is disabled and `-1` is context size. + dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']` model: The name to use for the model in the completion object. stopping_criteria: A list of stopping criteria to use. + logit_bias: A logit bias to use. logits_processor: A list of logits processors to use. grammar: A grammar to use for constrained sampling. - logit_bias: A logit bias to use. Raises: ValueError: If the requested tokens exceed the context window. 
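As a usage illustration of the sampling knobs documented above, a hedged sketch of a high-level call; the model path and prompt are placeholders:

from llama_cpp import Llama

llm = Llama(model_path="./models/example-7b-q4_k_m.gguf")  # placeholder path
out = llm(
    "Q: Name three primary colors. A:",
    max_tokens=48,
    top_n_sigma=1.5,                 # restrict candidates to logits within 1.5 * sigma of the max logit
    xtc_probability=0.5,             # probability that XTC token removal is applied
    xtc_threshold=0.1,               # tokens above this probability are candidates for removal
    dry_multiplier=0.8,              # enable the DRY repetition penalty
    dry_allowed_length=2,
    dry_seq_breakers=["\n", ":", '"', "*"],
    seed=1234,                       # explicit seed -> reproducible run
)
print(out["choices"][0]["text"])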
@@ -1916,17 +2012,24 @@ def __call__( presence_penalty=presence_penalty, repeat_penalty=repeat_penalty, top_k=top_k, + top_n_sigma=top_n_sigma, stream=stream, seed=seed, - tfs_z=tfs_z, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, stopping_criteria=stopping_criteria, + logit_bias=logit_bias, logits_processor=logits_processor, grammar=grammar, - logit_bias=logit_bias, ) def create_chat_completion( @@ -1939,6 +2042,7 @@ def create_chat_completion( temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, + top_n_sigma: float = -1.00, min_p: float = 0.05, typical_p: float = 1.0, stream: bool = False, @@ -1949,14 +2053,20 @@ def create_chat_completion( presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.0, - tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_seq_breakers: list[str] = ["\n", ":", '"', "*"], model: Optional[str] = None, + logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, - logit_bias: Optional[Dict[int, float]] = None, logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, ) -> Union[ @@ -1973,7 +2083,8 @@ def create_chat_completion( temperature: The temperature to use for sampling. top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + top_n_sigma: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1.00, -1.00 = disabled). + min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. stream: Whether to stream the results. stop: A list of strings to stop generation when encountered. @@ -1983,14 +2094,20 @@ def create_chat_completion( presence_penalty: The penalty to apply to tokens based on their presence in the prompt. frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt. repeat_penalty: The penalty to apply to repeated tokens. - tfs_z: The tail-free sampling parameter. mirostat_mode: The mirostat sampling mode. mirostat_tau: The mirostat sampling tau parameter. mirostat_eta: The mirostat sampling eta parameter. + xtc-probability: Sets the chance for token removal (checked once on sampler start) (default: 0.0). 
XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + xtc-threshold: Sets a minimum probability threshold for tokens to be removed (default: 0.1).XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + dry_multiplier: Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled. + dry_base`: Set the DRY repetition penalty base value. Default: `1.75` + dry_allowed_length: Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2` + dry_penalty_last_n: How many tokens to scan for repetitions. Default: `0`, where `0` is disabled and `-1` is context size. + dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']` model: The name to use for the model in the completion object. - logits_processor: A list of logits processors to use. - grammar: A grammar to use. logit_bias: A logit bias to use. + logits_processor: A list of logits processors to use. + grammar: A grammar to use for constrained sampling. Returns: Generated chat completion or a stream of chat completion chunks. @@ -2010,6 +2127,7 @@ def create_chat_completion( temperature=temperature, top_p=top_p, top_k=top_k, + top_n_sigma=top_n_sigma, min_p=min_p, typical_p=typical_p, logprobs=logprobs, @@ -2022,14 +2140,20 @@ def create_chat_completion( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, + logit_bias=logit_bias, logits_processor=logits_processor, grammar=grammar, - logit_bias=logit_bias, ) def create_chat_completion_openai_v1( @@ -2086,6 +2210,8 @@ def __getstate__(self): n_threads_batch=self.context_params.n_threads_batch, rope_scaling_type=self.context_params.rope_scaling_type, pooling_type=self.context_params.pooling_type, + attention_type=self.context_params.attention_type, + flash_attn_type=self.context_params.flash_attn_type, rope_freq_base=self.context_params.rope_freq_base, rope_freq_scale=self.context_params.rope_freq_scale, yarn_ext_factor=self.context_params.yarn_ext_factor, @@ -2096,9 +2222,9 @@ def __getstate__(self): logits_all=self._logits_all, embedding=self.context_params.embeddings, offload_kqv=self.context_params.offload_kqv, - flash_attn=self.context_params.flash_attn, op_offload=self.context_params.op_offload, swa_full=self.context_params.swa_full, + kv_unified=self.context_params.kv_unified, # Sampling Params no_perf=self.context_params.no_perf, last_n_tokens_size=self.last_n_tokens_size, @@ -2177,6 +2303,10 @@ def n_embd(self) -> int: """Return the embedding size.""" return self._model.n_embd() + def n_head_kv(self) -> int: + """Return the head_kv size.""" + return self._model.n_head_kv() + def n_vocab(self) -> int: """Return the vocabulary size.""" return self._model.n_vocab() @@ -2185,18 +2315,34 @@ def tokenizer(self) -> LlamaTokenizer: """Return the llama tokenizer for this model.""" return LlamaTokenizer(self) + def token_bos(self) -> int: + """Return the beginning-of-sequence token.""" + 
return self._model.token_bos() + def token_eos(self) -> int: """Return the end-of-sequence token.""" return self._model.token_eos() - def token_bos(self) -> int: - """Return the beginning-of-sequence token.""" - return self._model.token_bos() + def token_eot(self) -> int: + """Return the end-of-turn token.""" + return self._model.token_eot() + + def token_sep(self) -> int: + """Return the sentence-separator token.""" + return self._model.token_sep() def token_nl(self) -> int: - """Return the newline token.""" + """Return the next-line token.""" return self._model.token_nl() + def token_pad(self) -> int: + """Return the padding token.""" + return self._model.token_pad() + + def token_mask(self) -> int: + """Return the mask token.""" + return self._model.token_mask() + def pooling_type(self) -> str: """Return the pooling type.""" return self._ctx.pooling_type() @@ -2318,7 +2464,11 @@ def from_pretrained( if additional_files: for additonal_file_name in additional_files: # find the additional shard file: - matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)] + matching_additional_files = [ + file + for file in file_list + if fnmatch.fnmatch(file, additonal_file_name) + ] if len(matching_additional_files) == 0: raise ValueError( diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index f738ab9bb..f73d6a1e8 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -5,10 +5,10 @@ import json import ctypes import dataclasses +import datetime import random import string -from datetime import datetime from contextlib import ExitStack from typing import ( Any, @@ -34,6 +34,7 @@ import llama_cpp.llama_types as llama_types import llama_cpp.llama_grammar as llama_grammar +from ._ggml import GGMLLogLevel from ._logger import logger from ._utils import suppress_stdout_stderr, Singleton @@ -55,6 +56,17 @@ # Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json LLAMA3_INSTRUCT_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" +# Source: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/tokenizer_config.json +LLAMA4_INSTRUCT_BOS_TOKEN = "<|begin_of_text|>" +LLAMA4_INSTRUCT_EOS_TOKEN = "<|eot|>" +LLAMA4_INSTRUCT_CHAT_TEMPLATE = "{% if custom_tools is defined %}\n {% set tools = custom_tools %}\n{% endif %}\n{% if not tools_in_user_message is defined %}\n {% set tools_in_user_message = true %}\n{% endif %}\n{% if not date_string is defined %}\n {% if strftime_now is defined %}\n {% set date_string = strftime_now(\"%d %b %Y\") %}\n {% else %}\n {% set date_string = \"26 Jul 2024\" %}\n {% endif %}\n{% endif %}\n{% if not tools is defined %}\n {% set tools = none %}\n{% endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{% if messages[0]['role'] == 'system' %} \n {% if messages[0]['content'] is string %}\n {% set system_message = messages[0]['content']|trim %}\n {% else %}\n {#- FIXME: The processor requires an array, always. 
#}\n {% set system_message = messages[0]['content'][0]['text']|trim %}\n {% endif %}\n {% set messages = messages[1:] %}\n {% set user_supplied_system_message = true %}\n{% else %}\n {% set system_message = \"\" %}\n {% set user_supplied_system_message = false %}\n{% endif %}\n\n{#- System message if the user supplied one #}\n{% if user_supplied_system_message %}\n {{ \"<|header_start|>system<|header_end|>\\n\\n\" }}\n {% if tools is not none %}\n {{ \"Environment: ipython\\n\" }}\n {% endif %}\n {% if tools is not none and not tools_in_user_message %}\n {{ \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{ 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{ \"Do not use variables.\\n\\n\" }}\n {% for t in tools %}\n {{ t | tojson(indent=4) }}\n {{ \"\\n\\n\" }}\n {% endfor %}\n {% endif %}\n {{ system_message }}\n {{ \"<|eot|>\" }}\n{% endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{% if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {% if messages | length != 0 %}\n {% set first_user_message = messages[0]['content']|trim %}\n {% set messages = messages[1:] %}\n {% else %}\n {{ raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{% endif %}\n {{ '<|header_start|>user<|header_end|>\\n\\n' -}}\n {{ \"Given the following functions, please respond with a JSON for a function call \" }}\n {{ \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{ 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{ \"Do not use variables.\\n\\n\" }}\n {% for t in tools %}\n {{ t | tojson(indent=4) }}\n {{ \"\\n\\n\" }}\n {% endfor %}\n {{ first_user_message + \"<|eot|>\"}}\n{% endif %}\n\n{% for message in messages %}\n {% if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{ '<|header_start|>' + message['role'] + '<|header_end|>\\n\\n' }}\n {% if message['content'] is string %}\n {{ message['content'] }}\n {% else %}\n {% for content in message['content'] %}\n {% if content['type'] == 'image' %}\n {{ '<|image|>' }}\n {% elif content['type'] == 'text' %}\n {{ content['text'] }}\n {% endif %}\n {% endfor %}\n {% endif %}\n {{ \"<|eot|>\" }}\n {% elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{ '<|header_start|>assistant<|header_end|>\\n\\n' -}}\n {{ '<|python_start|>' }}\n {% if message['content'] is string %}\n {{ message['content'] }}\n {% else %}\n {% for content in message['content'] %}\n {% if content['type'] == 'image' %}\n {{ '<|image|>' }}\n {% elif content['type'] == 'text' %}\n {{ content['text'] }}\n {% endif %}\n {% endfor %}\n {% endif %}\n {{ '<|python_end|>' }}\n {% for tool_call in message.tool_calls %}\n {{ '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{ '\"parameters\": ' }}\n {{ tool_call.function.arguments | tojson }}\n {{ \"}\" }}\n {% endfor %}\n {{ \"<|eot|>\" }}\n {% elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{ \"<|header_start|>ipython<|header_end|>\\n\\n\" }}\n {% if message.content is mapping or message.content is iterable %}\n {{ message.content | tojson }}\n {% else %}\n {{ message.content }}\n {% endif %}\n {{ \"<|eot|>\" }}\n {% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\n {{ '<|header_start|>assistant<|header_end|>\\n\\n' }}\n{% endif %}\n" + + +# Source: https://huggingface.co/openai/gpt-oss-20b/blob/main/tokenizer_config.json +GPT_OSS_BOS_TOKEN = "<|startoftext|>" +GPT_OSS_EOS_TOKEN = "<|return|>" +GPT_OSS_PAD_TOKEN = "<|endoftext|>" + ### Chat Completion Handler ### @@ -79,6 +91,7 @@ def __call__( temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, + top_n_sigma: float = -1.00, stream: bool = False, stop: Optional[Union[str, List[str]]] = [], seed: Optional[int] = None, @@ -94,10 +107,16 @@ def __call__( # llama.cpp parameters min_p: float = 0.05, typical_p: float = 1.0, - tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n:int = 0, + dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"], logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, logprobs: Optional[bool] = None, @@ -215,10 +234,6 @@ def __init__( lstrip_blocks=True, ).from_string(self.template) - @staticmethod - def strftime_now(f: str) -> str: - return datetime.now().strftime(f) - def __call__( self, *, @@ -232,17 +247,23 @@ def __call__( def raise_exception(message: str): raise ValueError(message) + def strftime_now(format_string="%Y-%m-%d %H:%M:%S") -> str: + """ + Returns the current time formatted as a string. 
+ """ + return datetime.datetime.now().strftime(format_string) + prompt = self._environment.render( messages=messages, eos_token=self.eos_token, bos_token=self.bos_token, raise_exception=raise_exception, + strftime_now=strftime_now, add_generation_prompt=self.add_generation_prompt, functions=functions, function_call=function_call, tools=tools, tool_choice=tool_choice, - strftime_now=self.strftime_now, ) stopping_criteria = None @@ -578,10 +599,17 @@ def chat_completion_handler( presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, - tfs_z: float = 1.0, + top_n_sigma: float = -1.00, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n:int = 0, + dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"], model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, @@ -681,10 +709,17 @@ def chat_completion_handler( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, stopping_criteria=stopping_criteria, @@ -1021,6 +1056,25 @@ def format_llama3( return ChatFormatterResponse(prompt=_prompt, stop=_sep) +# Chat format for Llama-4 models text only, see more details at: +# https://github.com/meta-llama/llama-models/blob/main/models/llama4/chat_format.py#L61-L316 +@register_chat_format("llama-4") +def format_llama4( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + _roles = dict( + system="<|header_start|>system<|header_end|>\n\n", + user="<|header_start|>user<|header_end|>\n\n", + assistant="<|header_start|>assistant<|header_end|>\n\n", + ) + _sep = "<|eot|>" + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_no_colon_single("", _messages, _sep) + return ChatFormatterResponse(prompt=_prompt, stop=_sep) + + @register_chat_format("alpaca") def format_alpaca( messages: List[llama_types.ChatCompletionRequestMessage], @@ -1418,10 +1472,17 @@ def functionary_chat_handler( presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, - tfs_z: float = 1.0, + top_n_sigma: float = -1.00, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n:int = 0, + dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"], model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, @@ -1624,10 +1685,17 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage): presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, 
mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=grammar, @@ -1705,10 +1773,17 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage): presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, ) # type: ignore @@ -1777,10 +1852,17 @@ def functionary_v1_v2_chat_handler( presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, - tfs_z: float = 1.0, + top_n_sigma: float = -1.00, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n:int = 0, + dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"], model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, @@ -1993,10 +2075,17 @@ def prepare_messages_for_inference( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=grammar, @@ -2056,10 +2145,17 @@ def create_completion(prompt, stop, grammar): presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=grammar, @@ -2758,10 +2854,10 @@ def _create_bitmap_from_bytes(self, image_bytes: bytes): (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)), len(image_bytes) ) - + if bitmap is None: raise ValueError("Failed to create bitmap from image bytes") - + return bitmap def __call__( @@ -2788,10 +2884,17 @@ def __call__( presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, - tfs_z: float = 1.0, + top_n_sigma: float = -1.00, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n:int 
= 0, + dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"], model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, @@ -2820,10 +2923,10 @@ def __call__( trim_blocks=True, lstrip_blocks=True, ).from_string(self.CHAT_FORMAT) - + # Get the default media marker media_marker = self._mtmd_cpp.mtmd_default_marker().decode('utf-8') - + # Replace image URLs with media markers in the template text = template.render( messages=messages, @@ -2831,7 +2934,7 @@ def __call__( eos_token=llama.detokenize([llama.token_eos()]), bos_token=llama.detokenize([llama.token_bos()]), ) - + # Replace image URLs in text with media markers for image_url in image_urls: text = text.replace(image_url, media_marker) @@ -2876,45 +2979,45 @@ def __call__( # Reset llama context llama.reset() - llama._ctx.kv_cache_clear() + llama._ctx.memory_clear(True) # Process each chunk n_past = llama_cpp.llama_pos(0) n_chunks = self._mtmd_cpp.mtmd_input_chunks_size(chunks) - + for i in range(n_chunks): chunk = self._mtmd_cpp.mtmd_input_chunks_get(chunks, i) if chunk is None: continue chunk_type = self._mtmd_cpp.mtmd_input_chunk_get_type(chunk) - - if chunk_type == self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_TEXT: + + if chunk_type == self._mtmd_cpp.mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_TEXT: # Handle text chunk n_tokens_out = ctypes.c_size_t() tokens_ptr = self._mtmd_cpp.mtmd_input_chunk_get_tokens_text( chunk, ctypes.byref(n_tokens_out) ) - + if tokens_ptr and n_tokens_out.value > 0: # Convert ctypes array to Python list tokens = [tokens_ptr[j] for j in range(n_tokens_out.value)] - + if llama.n_tokens + len(tokens) > llama.n_ctx(): raise ValueError( f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}" ) llama.eval(tokens) - - elif chunk_type in [self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_IMAGE, self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_AUDIO]: + + elif chunk_type in [self._mtmd_cpp.mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_IMAGE, self._mtmd_cpp.mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_AUDIO]: # Handle image/audio chunk using helper chunk_n_tokens = self._mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk) - + if llama.n_tokens + chunk_n_tokens > llama.n_ctx(): raise ValueError( f"Prompt exceeds n_ctx: {llama.n_tokens + chunk_n_tokens} > {llama.n_ctx()}" ) - + new_n_past = llama_cpp.llama_pos(0) result = self._mtmd_cpp.mtmd_helper_eval_chunk_single( self.mtmd_ctx, @@ -2926,10 +3029,10 @@ def __call__( False, # logits_last ctypes.byref(new_n_past) ) - + if result != 0: raise ValueError(f"Failed to evaluate chunk: error code {result}") - + # Update llama's token count llama.n_tokens = new_n_past.value @@ -3010,16 +3113,23 @@ def __call__( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=grammar, logit_bias=logit_bias, ) - + if tool is not None: tool_name = tool["function"]["name"] return _convert_completion_to_chat_function( @@ -3427,6 +3537,7 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." 
CHAT_FORMAT = ( + "{% set image_count = namespace(value=0) %}" "{% for message in messages %}" "{% if loop.first and messages[0]['role'] != 'system' %}" "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" @@ -3436,10 +3547,12 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): "{% for content in message['content'] %}" "{% if content.type == 'image_url' %}" "{% if content.image_url is string %}" - "{{ content.image_url }}" + "{% set image_count.value = image_count.value + 1 %}" + "{{ image_count.value }}: {{ content.image_url }}" "{% endif %}" "{% if content.image_url is mapping %}" - "{{ content.image_url.url }}" + "{% set image_count.value = image_count.value + 1 %}" + "{{ image_count.value }}: {{ content.image_url.url }}" "{% endif %}" "{% endif %}" "{% endfor %}" @@ -3461,11 +3574,70 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): ) +class Gemma3ChatHandler(Llava15ChatHandler): + DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." + + GEMMA3_BOI_TOKEN = "" + GEMMA3_EOI_TOKEN = "" + GEMMA3_BOS_TOKEN = "" + GEMMA3_EOS_TOKEN = "" + + CHAT_FORMAT = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% if messages[0]['content'] is string %}" + "{% set first_user_prefix = messages[0]['content'] + '\n\n' %}" + "{% else %}" + "{% set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' %}" + "{% endif %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set first_user_prefix = '' %}" + "{% endif %}" + + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}" + "{% endif %}" + + "{% if message['role'] == 'assistant' %}" + "{% set role = 'model' %}" + "{% else %}" + "{% set role = message['role'] %}" + "{% endif %}" + + "{{ '' + role + '\n' + (first_user_prefix if loop.first else '') }}" + + "{% if message['content'] is string %}" + "{{ message['content'] | trim }}" + "{% elif message['content'] is iterable %}" + "{% for item in message['content'] %}" + "{% if item['type'] == 'image_url' and item['image_url'] is string %}" + "{{ '' + item['image_url'] + '' }}" + "{% elif item['type'] == 'image_url' and item['image_url'] is mapping %}" + "{{ '' + item['image_url']['url'] + '' }}" + "{% elif item['type'] == 'text' %}" + "{{ item['text'] | trim }}" + "{% endif %}" + "{% endfor %}" + "{% else %}" + "{{ raise_exception('Invalid content type') }}" + "{% endif %}" + + "\n" + "{% endfor %}" + + "{% if add_generation_prompt %}" + "model\n" + "{% endif %}" + ) + + class Qwen25VLChatHandler(Llava15ChatHandler): DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." 
CHAT_FORMAT = ( - #"{% set image_count = namespace(value=0) %}" + "{% set image_count = namespace(value=0) %}" #"{% set video_count = namespace(value=0) %}" "{% for message in messages %}" "{% if loop.first and message['role'] != 'system' %}" @@ -3479,11 +3651,12 @@ class Qwen25VLChatHandler(Llava15ChatHandler): "{% for content in message['content'] %}" "{% if content['type'] == 'image_url' %}" "{% if content.image_url is string %}" - "{{ content.image_url }}" + "{% set image_count.value = image_count.value + 1 %}" + "Picture {{ image_count.value }}: <|vision_start|> {{ content.image_url }} <|vision_end|>" "{% else %}" - "{{ content.image_url.url }}" + "{% set image_count.value = image_count.value + 1 %}" + "Picture {{ image_count.value }}: <|vision_start|> {{ content.image_url.url }} <|vision_end|>" "{% endif %}" - #"{% set image_count.value = image_count.value + 1 %}" "{% elif content['type'] == 'text' %}" "{{ content['text'] }}" "{% endif %}" @@ -3499,7 +3672,7 @@ def __call__(self, **kwargs): # Clear state for multiple runs llama.reset() - llama._ctx.kv_cache_clear() + llama._ctx.memory_clear(True) llama.n_tokens = 0 if hasattr(llama, 'input_ids'): @@ -3519,6 +3692,7 @@ def __call__(self, **kwargs): return super().__call__(**kwargs) + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama, @@ -3539,10 +3713,17 @@ def chatml_function_calling( presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, - tfs_z: float = 1.0, + top_n_sigma: float = -1.00, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + xtc_threshold: float = 0.1, + xtc_probability: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 1.75, + dry_allowed_length: int = 2, + dry_penalty_last_n:int = 0, + dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"], model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, @@ -3670,10 +3851,17 @@ def chatml_function_calling( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=grammar, @@ -3723,10 +3911,17 @@ def chatml_function_calling( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=grammar, @@ -3767,10 +3962,17 @@ def chatml_function_calling( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + 
dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=llama_grammar.LlamaGrammar.from_string( @@ -3795,10 +3997,17 @@ def chatml_function_calling( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=llama_grammar.LlamaGrammar.from_string( @@ -3842,10 +4051,17 @@ def chatml_function_calling( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=grammar, @@ -3871,10 +4087,17 @@ def chatml_function_calling( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, - tfs_z=tfs_z, + top_n_sigma=top_n_sigma, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + xtc_threshold=xtc_threshold, + xtc_probability=xtc_probability, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_seq_breakers=dry_seq_breakers, model=model, logits_processor=logits_processor, grammar=llama_grammar.LlamaGrammar.from_string( diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 711d42a6a..9c3719324 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1,9 +1,14 @@ from __future__ import annotations -import os import ctypes +import enum +import os import pathlib +from ._ggml import ( + ggml_opt_get_optimizer_params +) + from typing import ( Callable, Union, @@ -73,7 +78,10 @@ # GGML_TYPE_I64 = 27, # GGML_TYPE_F64 = 28, # GGML_TYPE_IQ1_M = 29, -# GGML_TYPE_COUNT, +# GGML_TYPE_BF16 = 30, +# GGML_TYPE_TQ1_0 = 34, +# GGML_TYPE_TQ2_0 = 35, +# GGML_TYPE_COUNT = 39 # }; GGML_TYPE_F32 = 0 GGML_TYPE_F16 = 1 @@ -103,7 +111,10 @@ GGML_TYPE_I64 = 27 GGML_TYPE_F64 = 28 GGML_TYPE_IQ1_M = 29 -GGML_TYPE_COUNT = 30 +GGML_TYPE_BF16 = 30 +GGML_TYPE_TQ1_0 = 34 +GGML_TYPE_TQ2_0 = 35 +GGML_TYPE_COUNT = 39 # from ggml-backend.h # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); @@ -161,13 +172,17 @@ llama_context_p = NewType("llama_context_p", int) llama_context_p_ctypes = ctypes.c_void_p -# typedef struct llama_memory_i * llama_memory_t; -llama_memory_t = NewType("llama_memory_t", int) -llama_memory_t_ctypes = ctypes.c_void_p +# # struct llama_sampler; +# llama_sampler_p = NewType("llama_sampler_p", int) +# llama_sampler_p_ctypes = ctypes.c_void_p + +# struct llama_opt_params; +llama_opt_params_p = NewType("llama_opt_params_p", int) +llama_opt_params_p_ctypes = ctypes.c_void_p -# struct llama_kv_cache; 
(DEPRECATED)
-llama_kv_cache_p = NewType("llama_kv_cache_p", int)
-llama_kv_cache_p_ctypes = ctypes.c_void_p
+# typedef struct llama_memory_i * llama_memory_t;
+llama_memory_i_p = NewType("llama_memory_i_p", int)
+llama_memory_i_p_ctypes = ctypes.c_void_p
# typedef int32_t llama_pos;
llama_pos = ctypes.c_int32
@@ -242,6 +257,9 @@
# LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
# LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
# LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
+# LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
+# LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
+# LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
# };
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1
@@ -259,7 +277,7 @@
LLAMA_VOCAB_PRE_TYPE_DBRX = 13
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14
LLAMA_VOCAB_PRE_TYPE_PORO = 15
-LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16
+LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17
LLAMA_VOCAB_PRE_TYPE_VIKING = 18
LLAMA_VOCAB_PRE_TYPE_JAIS = 19
@@ -279,6 +297,9 @@
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35
+LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36
+LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37
+LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38
# // note: these values should be synchronized with ggml_rope
@@ -429,14 +450,15 @@
# LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
# LLAMA_ROPE_SCALING_TYPE_YARN = 2,
# LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
-# LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE,
+# LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
# };
-LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1
-LLAMA_ROPE_SCALING_TYPE_NONE = 0
-LLAMA_ROPE_SCALING_TYPE_LINEAR = 1
-LLAMA_ROPE_SCALING_TYPE_YARN = 2
-LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3
-LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE
+class llama_rope_scaling_type(enum.IntEnum):
+    LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1
+    LLAMA_ROPE_SCALING_TYPE_NONE = 0
+    LLAMA_ROPE_SCALING_TYPE_LINEAR = 1
+    LLAMA_ROPE_SCALING_TYPE_YARN = 2
+    LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3
+    LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN
# enum llama_pooling_type {
# LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
@@ -458,15 +480,38 @@
# LLAMA_ATTENTION_TYPE_CAUSAL = 0,
# LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1,
# };
-LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1
-LLAMA_ATTENTION_TYPE_CAUSAL = 0
-LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1
+class llama_attention_type(enum.IntEnum):
+    LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1
+    LLAMA_ATTENTION_TYPE_CAUSAL = 0
+    LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1
+
+# enum llama_flash_attn_type {
+# LLAMA_FLASH_ATTN_TYPE_AUTO = -1,
+# LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
+# LLAMA_FLASH_ATTN_TYPE_ENABLED = 1,
+# };
+class llama_flash_attn_type(enum.IntEnum):
+    LLAMA_FLASH_ATTN_TYPE_AUTO = -1
+    LLAMA_FLASH_ATTN_TYPE_DISABLED = 0
+    LLAMA_FLASH_ATTN_TYPE_ENABLED = 1
+# LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
+@ctypes_function(
+    "llama_flash_attn_type_name",
+    [ctypes.c_int],
+    ctypes.c_char_p,
+)
+def llama_flash_attn_type_name(
+    flash_attn_type: llama_flash_attn_type, /
+) -> bytes:
+    """
+    Gets the name of a llama_flash_attn_type.
+ """ # enum llama_split_mode { # LLAMA_SPLIT_MODE_NONE = 0, // single GPU # LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs -# LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported +# LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs # }; LLAMA_SPLIT_MODE_NONE = 0 LLAMA_SPLIT_MODE_LAYER = 1 @@ -507,7 +552,7 @@ class llama_token_data(ctypes.Structure): # llama_token_data * data; # size_t size; # int64_t selected; // this is the index in the data array (i.e. not the token id) -# bool sorted; +# bool sorted; // note: do not assume the data is sorted - always check this flag # } llama_token_data_array; class llama_token_data_array(ctypes.Structure): """Used to sample tokens given logits @@ -551,6 +596,7 @@ class llama_token_data_array(ctypes.Structure): # // - seq_id : the sequence to which the respective token belongs # // (if set to NULL, the sequence ID will be assumed to be 0) # // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output +# // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output # // (if set to NULL: # // - if embeddings: all tokens are output # // - if not: only the last token is output @@ -567,7 +613,7 @@ class llama_token_data_array(ctypes.Structure): # int8_t * logits; // TODO: rename this to "output" # } llama_batch; class llama_batch(ctypes.Structure): - """Input data for llama_encode/llama_decode + """Input data for llama_decode A llama_batch object can contain input about one or many sequences @@ -654,24 +700,34 @@ class llama_model_kv_override(ctypes.Structure): key: bytes value: Union[int, float, bool, bytes] - # struct llama_model_tensor_buft_override { # const char * pattern; # ggml_backend_buffer_type_t buft; # }; +class llama_model_tensor_buft_override(ctypes.Structure): + _fields_ = [ + ("pattern", ctypes.c_char_p), + ("buft", ctypes.c_void_p), + ] + if TYPE_CHECKING: + pattern: ctypes.c_char_p + buft: ctypes.c_void_p # struct llama_model_params { # // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) # ggml_backend_dev_t * devices; - +# # // NULL-terminated list of buffer types to use for tensors that match a pattern # const struct llama_model_tensor_buft_override * tensor_buft_overrides; - +# # int32_t n_gpu_layers; // number of layers to store in VRAM # enum llama_split_mode split_mode; // how to split the model across multiple GPUs -# // the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE +# // main_gpu interpretation depends on split_mode: +# // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model +# // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results +# // LLAMA_SPLIT_MODE_LAYER: ignored # int32_t main_gpu; # // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() @@ -688,11 +744,12 @@ class llama_model_kv_override(ctypes.Structure): # // override key-value pairs of the model meta data # const struct llama_model_kv_override * kv_overrides; + # // Keep the booleans together to avoid misalignment during copy-by-value. 
-# bool vocab_only; // only load the vocabulary, no weights -# bool use_mmap; // use mmap if possible -# bool use_mlock; // force system to keep model in RAM -# bool check_tensors; // validate model tensor data +# bool vocab_only; // only load the vocabulary, no weights +# bool use_mmap; // use mmap if possible +# bool use_mlock; // force system to keep model in RAM +# bool check_tensors; // validate model tensor data # bool use_extra_bufts; // use extra buffer types (used for weight repacking) # }; class llama_model_params(ctypes.Structure): @@ -700,10 +757,10 @@ class llama_model_params(ctypes.Structure): Attributes: devices (ctypes.Array[ggml_backend_dev_t]): NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) - tensor_buft_overrides (ctypes.Array[llama_model_tensor_buft_override]): NULL-terminated list of buffer types to use for tensors that match a pattern + tensor_buft_overrides(llama_model_tensor_buft_override): NULL-terminated list of buffer types to use for tensors that match a pattern n_gpu_layers (int): number of layers to store in VRAM split_mode (int): how to split the model across multiple GPUs - main_gpu (int): the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE + main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored tensor_split (ctypes.Array[ctypes.ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() progress_callback (llama_progress_callback): called with a progress value between 0.0 and 1.0. Pass NULL to disable. If the provided progress_callback returns true, model loading continues. If it returns false, model loading is immediately aborted. 
progress_callback_user_data (ctypes.ctypes.c_void_p): context pointer passed to the progress callback @@ -716,7 +773,7 @@ class llama_model_params(ctypes.Structure): if TYPE_CHECKING: devices: CtypesArray[ctypes.c_void_p] # NOTE: unused - tensor_buft_overrides: CtypesArray[llama_model_tensor_buft_override] # NOTE: unused + tensor_buft_overrides: ctypes.POINTER(llama_model_tensor_buft_override) n_gpu_layers: int split_mode: int main_gpu: int @@ -732,7 +789,7 @@ class llama_model_params(ctypes.Structure): _fields_ = [ ("devices", ctypes.c_void_p), # NOTE: unnused - ("tensor_buft_overrides", ctypes.c_void_p), # NOTE: unused + ("tensor_buft_overrides", ctypes.POINTER(llama_model_tensor_buft_override)), ("n_gpu_layers", ctypes.c_int32), ("split_mode", ctypes.c_int), ("main_gpu", ctypes.c_int32), @@ -761,6 +818,7 @@ class llama_model_params(ctypes.Structure): # enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` # enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id # enum llama_attention_type attention_type; // attention type to use for embeddings +# enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention # // ref: https://github.com/ggml-org/llama.cpp/pull/2054 # float rope_freq_base; // RoPE base frequency, 0 = from model @@ -770,14 +828,13 @@ class llama_model_params(ctypes.Structure): # float yarn_beta_fast; // YaRN low correction dim # float yarn_beta_slow; // YaRN high correction dim # uint32_t yarn_orig_ctx; // YaRN original context size -# float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default) +# float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, < 0 disabled (default) # ggml_backend_sched_eval_callback cb_eval; # void * cb_eval_user_data; # enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] # enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] - # // Abort callback # // if it returns true, execution of llama_decode() will be aborted # // currently works only with CPU execution @@ -787,11 +844,10 @@ class llama_model_params(ctypes.Structure): # // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. 
# bool embeddings; // if true, extract embeddings (together with logits) # bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU -# bool flash_attn; // use flash attention [EXPERIMENTAL] # bool no_perf; // measure performance timings # bool op_offload; // offload host tensor operations to device # bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) -# // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases +# // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some casesAdd commentMore actions # // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 # bool kv_unified; // use a unified buffer across the input sequences when computing the attention # // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix @@ -810,6 +866,7 @@ class llama_context_params(ctypes.Structure): rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type` pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) attention_type (int): attention type to use for embeddings + flash_attn_type (int): when to enable Flash Attention rope_freq_base (float): RoPE base frequency, 0 = from model rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model @@ -817,7 +874,7 @@ class llama_context_params(ctypes.Structure): yarn_beta_fast (float): YaRN low correction dim yarn_beta_slow (float): YaRN high correction dim yarn_orig_ctx (int): YaRN original context size - defrag_thold (float): defragment the KV cache if holes/size > thold, <= 0 disabled (default) + defrag_thold (float): [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default) cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval type_k (int): data type for K cache @@ -826,11 +883,10 @@ class llama_context_params(ctypes.Structure): abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback embeddings (bool): if true, extract embeddings (together with logits) offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU - flash_attn (bool): whether to use flash attention no_perf (bool): whether to measure performance timings - op_offload (bool): offload host tensor operations to device - swa_full (bool): use full-size SWA cache - kv_unified (bool): use a unified buffer across the input sequences when computing the attention + op_offload(bool): whether to offload host tensor operations to device + swa_full(bool): whether to use full-size SWA cache + kv_unified(bool): use a unified buffer across the input sequences when computing the attention """ if TYPE_CHECKING: @@ -843,6 +899,7 @@ class llama_context_params(ctypes.Structure): rope_scaling_type: int pooling_type: int attention_type: int + flash_attn_type: int rope_freq_base: float rope_freq_scale: float yarn_ext_factor: float @@ -859,11 +916,10 @@ class llama_context_params(ctypes.Structure): abort_callback_data: ctypes.c_void_p embeddings: bool offload_kqv: bool - flash_attn: bool no_perf: bool - op_offload: bool - swa_full: bool - kv_unified: bool + op_offload:bool + swa_full:bool + kv_unified:bool _fields_ = [ ("n_ctx", ctypes.c_uint32), @@ -875,6 +931,7 @@ class llama_context_params(ctypes.Structure): 
("rope_scaling_type", ctypes.c_int), ("pooling_type", ctypes.c_int), ("attention_type", ctypes.c_int), + ("flash_attn_type", ctypes.c_int), ("rope_freq_base", ctypes.c_float), ("rope_freq_scale", ctypes.c_float), ("yarn_ext_factor", ctypes.c_float), @@ -891,7 +948,6 @@ class llama_context_params(ctypes.Structure): ("abort_callback_data", ctypes.c_void_p), ("embeddings", ctypes.c_bool), ("offload_kqv", ctypes.c_bool), - ("flash_attn", ctypes.c_bool), ("no_perf", ctypes.c_bool), ("op_offload", ctypes.c_bool), ("swa_full", ctypes.c_bool), @@ -917,19 +973,19 @@ class llama_context_params(ctypes.Structure): # // model quantization parameters # typedef struct llama_model_quantize_params { -# int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() -# enum llama_ftype ftype; // quantize to this llama_ftype -# enum ggml_type output_tensor_type; // output tensor type -# enum ggml_type token_embedding_type; // token embeddings tensor type -# bool allow_requantize; // allow quantizing non-f32/f16 tensors -# bool quantize_output_tensor; // quantize output.weight -# bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored -# bool pure; // quantize all tensors to the default type -# bool keep_split; // quantize to the same number of shards -# void * imatrix; // pointer to importance matrix data -# void * kv_overrides; // pointer to vector containing overrides -# void * tensor_types; // pointer to vector containing tensor types -# void * prune_layers; // pointer to vector containing layer indices to prune +# int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() +# enum llama_ftype ftype; // quantize to this llama_ftype +# enum ggml_type output_tensor_type; // output tensor type +# enum ggml_type token_embedding_type; // token embeddings tensor type +# bool allow_requantize; // allow quantizing non-f32/f16 tensors +# bool quantize_output_tensor; // quantize output.weight +# bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored +# bool pure; // quantize all tensors to the default type +# bool keep_split; // quantize to the same number of shards +# void * imatrix; // pointer to importance matrix data +# void * kv_overrides; // pointer to vector containing overrides +# void * tensor_types; // pointer to vector containing tensor types +# void * prune_layers; // pointer to vector containing layer indices to prune # } llama_model_quantize_params; class llama_model_quantize_params(ctypes.Structure): """Parameters for llama_model_quantize @@ -1089,6 +1145,7 @@ def llama_model_quantize_default_params() -> llama_model_quantize_params: # // Initialize the llama + ggml backend # // If numa is true, use NUMA optimizations # // Call once at the start of the program +# LLAMA_API void llama_backend_init(bool numa); # LLAMA_API void llama_backend_init(void); @ctypes_function( "llama_backend_init", @@ -1097,6 +1154,7 @@ def llama_model_quantize_default_params() -> llama_model_quantize_params: ) def llama_backend_init(): """Initialize the llama + ggml backend + If numa is true, use NUMA optimizations Call once at the start of the program""" ... 
@@ -1202,7 +1260,7 @@ def llama_model_load_from_file( llama_model_p_ctypes, ) def llama_model_load_from_splits( - paths: List[bytes], n_paths: int, params: llama_model_params, / + paths: list[bytes], n_paths: int, params: llama_model_params, / ) -> Optional[llama_model_p]: """Load the model from multiple splits (support custom naming scheme) @@ -1219,7 +1277,6 @@ def llama_model_load_from_splits( None, ) def llama_model_save_to_file(model: llama_model_p, path_model: bytes, /): - """Save the model to a file""" ... @@ -1392,29 +1449,16 @@ def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: # LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); -@ctypes_function("llama_get_memory", [llama_context_p_ctypes], llama_memory_t_ctypes) -def llama_get_memory(ctx: llama_context_p, /) -> Optional[llama_memory_t]: - """Get the memory for the context""" +@ctypes_function("llama_get_memory", [llama_context_p_ctypes], llama_memory_i_p_ctypes) +def llama_get_memory(ctx: llama_context_p, /) -> Optional[llama_memory_i_p]: ... - -# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); +# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type @ctypes_function("llama_pooling_type", [llama_context_p_ctypes], ctypes.c_int) def llama_pooling_type(ctx: llama_context_p, /) -> int: ... -# DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); -@ctypes_function( - "llama_get_kv_self", - [llama_context_p_ctypes], - llama_kv_cache_p_ctypes, -) -def llama_get_kv_self(ctx: llama_context_p, /) -> Optional[llama_kv_cache_p]: - """Get the KV cache for self-attention (DEPRECATED)""" - ... - - # LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); @ctypes_function("llama_model_get_vocab", [llama_model_p_ctypes], llama_vocab_p_ctypes) def llama_model_get_vocab(model: llama_model_p, /) -> Optional[llama_vocab_p]: @@ -1451,15 +1495,15 @@ def llama_model_n_head(model: llama_model_p, /) -> int: ... -# LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); -@ctypes_function("llama_model_n_head_kv", [llama_model_p_ctypes], ctypes.c_int32) -def llama_model_n_head_kv(model: llama_model_p, /) -> int: + # LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); +@ctypes_function("llama_model_n_swa", [llama_model_p_ctypes], ctypes.c_int32) +def llama_model_n_swa(model: llama_model_p, /) -> int: ... -# LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); -@ctypes_function("llama_model_n_swa", [llama_model_p_ctypes], ctypes.c_int32) -def llama_model_n_swa(model: llama_model_p, /) -> int: +# LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); +@ctypes_function("llama_model_n_head_kv", [llama_model_p_ctypes], ctypes.c_int32) +def llama_model_n_head_kv(model: llama_model_p, /) -> int: ... @@ -1470,26 +1514,9 @@ def llama_model_rope_freq_scale_train(model: llama_model_p, /) -> float: ... 
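The model accessors above compose into a small introspection helper; a sketch assuming `model` is a valid llama_model_p returned by one of the llama_model_load_* functions earlier in this file.

import llama_cpp.llama_cpp as llama_cpp

def describe_model(model: llama_cpp.llama_model_p) -> dict:
    """Collect a few training-time hyper-parameters of a loaded model."""
    return {
        "n_head": llama_cpp.llama_model_n_head(model),
        "n_head_kv": llama_cpp.llama_model_n_head_kv(model),
        "n_swa": llama_cpp.llama_model_n_swa(model),  # sliding-window size, 0 if unused
        "rope_freq_scale": llama_cpp.llama_model_rope_freq_scale_train(model),
    }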
-# // Returns the number of classifier outputs (only valid for classifier models) -# // Undefined behavior for non-classifier models -# LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model); -@ctypes_function("llama_model_n_cls_out", [llama_model_p_ctypes], ctypes.c_uint32) -def llama_model_n_cls_out(model: llama_model_p, /) -> int: - """Returns the number of classifier outputs (only valid for classifier models)""" - ... - - -# // Returns label of classifier output by index ( Optional[bytes]: - """Returns label of classifier output by index. Returns None if no label provided""" - ... - - # LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); -@ctypes_function("llama_vocab_type", [llama_vocab_p_ctypes], ctypes.c_int) -def llama_vocab_type(vocab: llama_vocab_p, /) -> int: +@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) +def llama_vocab_type(model: llama_model_p, /) -> int: ... @@ -1695,10 +1722,6 @@ def llama_model_quantize( ... -# // -# // Adapters -# // - # // Load a LoRA adapter from file # LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( # struct llama_model * model, @@ -1714,6 +1737,93 @@ def llama_adapter_lora_init( ... +# // Functions to access the adapter's GGUF metadata scalar values +# // - The functions return the length of the string on success, or -1 on failure +# // - The output string is always null-terminated and cleared on failure +# // - When retrieving a string, an extra byte must be allocated to account for the null terminator +# // - GGUF array values are not supported by these functions + +# // Get metadata value as a string by key name +# LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size); +@ctypes_function( + "llama_adapter_meta_val_str", + [ + llama_adapter_lora_p_ctypes, + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_size_t, + ], + ctypes.c_int32, +) +def llama_adapter_meta_val_str( + adapter: llama_adapter_lora_p, + key: ctypes.c_char_p, + buf: ctypes.c_char_p, + buf_size: ctypes.c_size_t, + /, +) -> int: + """Get metadata value as a string by key name""" + ... + + +# // Get the number of metadata key/value pairs +# LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter); +@ctypes_function( + "llama_adapter_meta_count", + [llama_adapter_lora_p_ctypes], + ctypes.c_int32, +) +def llama_adapter_meta_count(adapter: llama_adapter_lora_p) -> int: + """Get the number of metadata key/value pairs""" + ... + + +# // Get metadata key name by index +# LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); +@ctypes_function( + "llama_adapter_meta_key_by_index", + [ + llama_adapter_lora_p_ctypes, + ctypes.c_int32, + ctypes.c_char_p, + ctypes.c_size_t, + ], + ctypes.c_int32, +) +def llama_adapter_meta_key_by_index( + adapter: llama_adapter_lora_p, + i: ctypes.c_int32, + buf: ctypes.c_char_p, + buf_size: ctypes.c_size_t, + /, +) -> int: + """Get metadata key name by index""" + ... 
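A sketch of the by-key metadata lookup, assuming `adapter` came from llama_adapter_lora_init; the key name used here is purely illustrative, and the buffer must leave room for the null terminator as noted above.

import ctypes
from typing import Optional

import llama_cpp.llama_cpp as llama_cpp

def adapter_meta(adapter, key: bytes, buf_size: int = 256) -> Optional[str]:
    """Return the adapter's GGUF metadata value for `key`, or None on failure."""
    buf = ctypes.create_string_buffer(buf_size)  # includes room for the null terminator
    n = llama_cpp.llama_adapter_meta_val_str(adapter, key, buf, buf_size)
    return buf.value.decode("utf-8") if n >= 0 else None

# value = adapter_meta(adapter, b"adapter.lora.alpha")  # illustrative key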
+ + +# // Get metadata value as a string by index +# LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); +@ctypes_function( + "llama_adapter_meta_val_str_by_index", + [ + llama_adapter_lora_p_ctypes, + ctypes.c_int32, + ctypes.c_char_p, + ctypes.c_size_t, + ], + ctypes.c_int32, +) +def llama_adapter_meta_val_str_by_index( + adapter: llama_adapter_lora_p, + i: ctypes.c_int32, + buf: ctypes.c_char_p, + buf_size: ctypes.c_size_t, + /, +) -> int: + """Get metadata value as a string by index""" + ... + + # // Manually free a LoRA adapter # // Note: loaded adapters will be free when the associated model is deleted # LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); @@ -1726,6 +1836,27 @@ def llama_adapter_lora_free(adapter: llama_adapter_lora_p, /): ... +# // Get the invocation tokens if the current lora is an alora +# LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); +@ctypes_function( + "llama_adapter_get_alora_n_invocation_tokens", + [llama_adapter_lora_p_ctypes], + ctypes.c_uint64, +) +def llama_adapter_get_alora_n_invocation_tokens(adapter: llama_adapter_lora_p, /) -> ctypes.c_uint64: + ... + + +# LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter); +@ctypes_function( + "llama_adapter_get_alora_invocation_tokens", + [llama_adapter_lora_p_ctypes], + ctypes.c_uint64, +) +def llama_adapter_get_alora_invocation_tokens(adapter: llama_adapter_lora_p, /) -> llama_token_p: + ... + + # // The following functions operate on a llama_context, hence the naming: llama_verb_... @@ -1829,15 +1960,18 @@ def llama_apply_adapter_cvec( # // If data == true, the data buffers will also be cleared together with the metadata # LLAMA_API void llama_memory_clear( # llama_memory_t mem, -# bool data); +# bool data); @ctypes_function( - "llama_memory_clear", - [llama_memory_t_ctypes, ctypes.c_bool], + "llama_memory_clear", [ + llama_memory_i_p_ctypes, + ctypes.c_bool + ], None, ) -def llama_memory_clear(mem: llama_memory_t, data: bool, /): - """Clear the memory contents - If data == true, the data buffers will also be cleared together with the metadata""" +def llama_memory_clear( + mem: llama_memory_i_p, + data: bool +): ... @@ -1848,13 +1982,13 @@ def llama_memory_clear(mem: llama_memory_t, data: bool, /): # // p1 < 0 : [p0, inf) # LLAMA_API bool llama_memory_seq_rm( # llama_memory_t mem, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1); +# llama_seq_id seq_id, +# llama_pos p0, +# llama_pos p1); @ctypes_function( "llama_memory_seq_rm", [ - llama_memory_t_ctypes, + llama_memory_i_p_ctypes, llama_seq_id, llama_pos, llama_pos, @@ -1862,19 +1996,17 @@ def llama_memory_clear(mem: llama_memory_t, data: bool, /): ctypes.c_bool, ) def llama_memory_seq_rm( - mem: llama_memory_t, + mem: llama_memory_i_p, seq_id: Union[llama_seq_id, int], p0: Union[llama_pos, int], p1: Union[llama_pos, int], /, ) -> bool: """Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - - Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails - - seq_id < 0 : match any sequence - p0 < 0 : [0, p1] - p1 < 0 : [p0, inf)""" + Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails + seq_id < 0 : match any sequence + p0 < 0 : [0, p1] + p1 < 0 : [p0, inf)""" ... 
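Together with the count and by-index accessors introduced above, the whole adapter metadata table can be read into a dict; a sketch with buffer sizes chosen arbitrarily.

import ctypes

import llama_cpp.llama_cpp as llama_cpp

def adapter_metadata(adapter) -> dict:
    """Enumerate all GGUF metadata key/value pairs of a LoRA adapter."""
    meta = {}
    key_buf = ctypes.create_string_buffer(256)
    val_buf = ctypes.create_string_buffer(1024)
    for i in range(llama_cpp.llama_adapter_meta_count(adapter)):
        if llama_cpp.llama_adapter_meta_key_by_index(adapter, i, key_buf, len(key_buf)) < 0:
            continue
        if llama_cpp.llama_adapter_meta_val_str_by_index(adapter, i, val_buf, len(val_buf)) < 0:
            continue
        meta[key_buf.value.decode("utf-8")] = val_buf.value.decode("utf-8")
    return meta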
@@ -1883,14 +2015,14 @@ def llama_memory_seq_rm( # // p1 < 0 : [p0, inf) # LLAMA_API void llama_memory_seq_cp( # llama_memory_t mem, -# llama_seq_id seq_id_src, -# llama_seq_id seq_id_dst, -# llama_pos p0, -# llama_pos p1); +# llama_seq_id seq_id_src, +# llama_seq_id seq_id_dst, +# llama_pos p0, +# llama_pos p1); @ctypes_function( "llama_memory_seq_cp", [ - llama_memory_t_ctypes, + llama_memory_i_p_ctypes, llama_seq_id, llama_seq_id, llama_pos, @@ -1899,7 +2031,7 @@ def llama_memory_seq_rm( None, ) def llama_memory_seq_cp( - mem: llama_memory_t, + mem: llama_memory_i_p, seq_id_src: Union[llama_seq_id, int], seq_id_dst: Union[llama_seq_id, int], p0: Union[llama_pos, int], @@ -1907,19 +2039,19 @@ def llama_memory_seq_cp( /, ): """Copy all tokens that belong to the specified sequence to another sequence - p0 < 0 : [0, p1] - p1 < 0 : [p0, inf)""" + p0 < 0 : [0, p1] + p1 < 0 : [p0, inf)""" ... # // Removes all tokens that do not belong to the specified sequence # LLAMA_API void llama_memory_seq_keep( # llama_memory_t mem, -# llama_seq_id seq_id); +# llama_seq_id seq_id); @ctypes_function( - "llama_memory_seq_keep", [llama_memory_t_ctypes, llama_seq_id], None + "llama_memory_seq_keep", [llama_memory_i_p_ctypes, llama_seq_id], None ) -def llama_memory_seq_keep(mem: llama_memory_t, seq_id: Union[llama_seq_id, int], /): +def llama_memory_seq_keep(mem: llama_memory_i_p, seq_id: Union[llama_seq_id, int], /): """Removes all tokens that do not belong to the specified sequence""" ... @@ -1929,14 +2061,14 @@ def llama_memory_seq_keep(mem: llama_memory_t, seq_id: Union[llama_seq_id, int], # // p1 < 0 : [p0, inf) # LLAMA_API void llama_memory_seq_add( # llama_memory_t mem, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1, -# llama_pos delta); +# llama_seq_id seq_id, +# llama_pos p0, +# llama_pos p1, +# llama_pos delta); @ctypes_function( "llama_memory_seq_add", [ - llama_memory_t_ctypes, + llama_memory_i_p_ctypes, llama_seq_id, llama_pos, llama_pos, @@ -1945,7 +2077,7 @@ def llama_memory_seq_keep(mem: llama_memory_t, seq_id: Union[llama_seq_id, int], None, ) def llama_memory_seq_add( - mem: llama_memory_t, + mem: llama_memory_i_p, seq_id: Union[llama_seq_id, int], p0: Union[llama_pos, int], p1: Union[llama_pos, int], @@ -1963,14 +2095,14 @@ def llama_memory_seq_add( # // p1 < 0 : [p0, inf) # LLAMA_API void llama_memory_seq_div( # llama_memory_t mem, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1, -# int d); +# llama_seq_id seq_id, +# llama_pos p0, +# llama_pos p1, +# int d); @ctypes_function( "llama_memory_seq_div", [ - llama_memory_t_ctypes, + llama_memory_i_p_ctypes, llama_seq_id, llama_pos, llama_pos, @@ -1979,7 +2111,7 @@ def llama_memory_seq_add( None, ) def llama_memory_seq_div( - mem: llama_memory_t, + mem: llama_memory_i_p, seq_id: Union[llama_seq_id, int], p0: Union[llama_pos, int], p1: Union[llama_pos, int], @@ -1987,8 +2119,8 @@ def llama_memory_seq_div( /, ): """Integer division of the positions by factor of `d > 1` - p0 < 0 : [0, p1] - p1 < 0 : [p0, inf)""" + p0 < 0 : [0, p1] + p1 < 0 : [p0, inf)""" ... 
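The llama_memory_* calls above compose into the usual context-shifting pattern (the same bookkeeping previously done through the llama_kv_self_* helpers); a sketch assuming `ctx` is a valid llama_context_p.

import llama_cpp.llama_cpp as llama_cpp

def shift_context(ctx, seq_id: int, n_keep: int, n_discard: int):
    """Drop a window of old tokens and slide the rest back so positions stay contiguous."""
    mem = llama_cpp.llama_get_memory(ctx)
    # remove tokens in [n_keep, n_keep + n_discard) from this sequence ...
    llama_cpp.llama_memory_seq_rm(mem, seq_id, n_keep, n_keep + n_discard)
    # ... then shift the remaining positions (p1 < 0 means "to the end") back by n_discard
    llama_cpp.llama_memory_seq_add(mem, seq_id, n_keep + n_discard, -1, -n_discard)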
@@ -1998,16 +2130,24 @@ def llama_memory_seq_div( # // Return -1 if the sequence is empty # LLAMA_API llama_pos llama_memory_seq_pos_min( # llama_memory_t mem, -# llama_seq_id seq_id); +# llama_seq_id seq_id); @ctypes_function( - "llama_memory_seq_pos_min", [llama_memory_t_ctypes, llama_seq_id], llama_pos + "llama_memory_seq_pos_min", + [ + llama_memory_i_p_ctypes, + llama_seq_id, + ], + ctypes.c_int32, ) def llama_memory_seq_pos_min( - mem: llama_memory_t, seq_id: Union[llama_seq_id, int], / -) -> int: + mem: llama_memory_i_p, + seq_id: Union[llama_seq_id, int] + ,/) -> int: """Returns the smallest position present in the memory for the specified sequence - This is typically non-zero only for SWA caches - Return -1 if the sequence is empty""" + This is typically non-zero only for SWA caches + Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + Return -1 if the sequence is empty + """ ... @@ -2016,275 +2156,32 @@ def llama_memory_seq_pos_min( # // Return -1 if the sequence is empty # LLAMA_API llama_pos llama_memory_seq_pos_max( # llama_memory_t mem, -# llama_seq_id seq_id); +# llama_seq_id seq_id); @ctypes_function( - "llama_memory_seq_pos_max", [llama_memory_t_ctypes, llama_seq_id], llama_pos + "llama_memory_seq_pos_max", + [ + llama_memory_i_p_ctypes, + llama_seq_id, + ], + ctypes.c_int32, ) def llama_memory_seq_pos_max( - mem: llama_memory_t, seq_id: Union[llama_seq_id, int], / -) -> int: + mem: llama_memory_i_p, + seq_id: Union[llama_seq_id, int] + ,/) -> int: """Returns the largest position present in the memory for the specified sequence - Return -1 if the sequence is empty""" + This is typically non-zero only for SWA caches + Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + Return -1 if the sequence is empty + """ ... # // Check if the memory supports shifting # LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); -@ctypes_function("llama_memory_can_shift", [llama_memory_t_ctypes], ctypes.c_bool) -def llama_memory_can_shift(mem: llama_memory_t, /) -> bool: - """Check if the memory supports shifting""" - ... - - -# // -# // KV cache for self-attention (TODO: deprecate in favor of llama_memory) -# // - -# // Returns the number of tokens in the KV cache (slow, use only for debug) -# // If a KV cell has multiple sequences assigned to it, it will be counted multiple times -# DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), -# "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); @ctypes_function( - "llama_kv_self_n_tokens", [llama_context_p_ctypes], ctypes.c_int32 -) -def llama_kv_self_n_tokens(ctx: llama_context_p, /) -> int: - """Returns the number of tokens in the KV cache (slow, use only for debug) (DEPRECATED)""" - ... - - -# // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) -# DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), -# "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); -@ctypes_function( - "llama_kv_self_used_cells", [llama_context_p_ctypes], ctypes.c_int32 -) -def llama_kv_self_used_cells(ctx: llama_context_p, /) -> int: - """Returns the number of used KV cells (DEPRECATED)""" - ... 
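pos_min/pos_max make it easy to check how much of a sequence is actually resident in memory; a small sketch that returns None for an empty sequence.

import llama_cpp.llama_cpp as llama_cpp

def seq_span(ctx, seq_id: int):
    """Return (pos_min, pos_max) for a sequence, or None if the sequence is empty."""
    mem = llama_cpp.llama_get_memory(ctx)
    lo = llama_cpp.llama_memory_seq_pos_min(mem, seq_id)
    hi = llama_cpp.llama_memory_seq_pos_max(mem, seq_id)
    return None if lo < 0 or hi < 0 else (lo, hi)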
- - -# // Clear the KV cache - both cell info is erased and KV data is zeroed -# DEPRECATED(LLAMA_API void llama_kv_self_clear( -# struct llama_context * ctx), -# "Use llama_memory_clear() instead"); -@ctypes_function( - "llama_kv_self_clear", [llama_context_p_ctypes], None -) -def llama_kv_self_clear(ctx: llama_context_p, /): - """Clear the KV cache (DEPRECATED)""" - ... - - -# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) -# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails -# // seq_id < 0 : match any sequence -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1), -# "Use llama_memory_seq_rm() instead"); -@ctypes_function( - "llama_kv_self_seq_rm", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - ], - ctypes.c_bool, -) -def llama_kv_self_seq_rm( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - /, -) -> bool: - """Remove tokens from KV cache (DEPRECATED)""" - ... - - -# // Copy all tokens that belong to the specified sequence to another sequence -# // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( -# struct llama_context * ctx, -# llama_seq_id seq_id_src, -# llama_seq_id seq_id_dst, -# llama_pos p0, -# llama_pos p1), -# "Use llama_memory_seq_cp() instead"); -@ctypes_function( - "llama_kv_self_seq_cp", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_seq_id, - llama_pos, - llama_pos, - ], - None, -) -def llama_kv_self_seq_cp( - ctx: llama_context_p, - seq_id_src: Union[llama_seq_id, int], - seq_id_dst: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - /, -): - """Copy tokens in KV cache (DEPRECATED)""" - ... - - -# // Removes all tokens that do not belong to the specified sequence -# DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_keep() instead"); -@ctypes_function( - "llama_kv_self_seq_keep", [llama_context_p_ctypes, llama_seq_id], None -) -def llama_kv_self_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /): - """Keep only specified sequence in KV cache (DEPRECATED)""" - ... - - -# // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) -# // If the KV cache is RoPEd, the KV data is updated accordingly: -# // - lazily on next llama_decode() -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_add( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1, -# llama_pos delta), -# "Use llama_memory_seq_add() instead"); -@ctypes_function( - "llama_kv_self_seq_add", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - llama_pos, - ], - None, -) -def llama_kv_self_seq_add( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - delta: Union[llama_pos, int], - /, -): - """Add delta to sequence positions in KV cache (DEPRECATED)""" - ... 
- - -# // Integer division of the positions by factor of `d > 1` -# // If the KV cache is RoPEd, the KV data is updated accordingly: -# // - lazily on next llama_decode() -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_div( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1, -# int d), -# "Use llama_memory_seq_div() instead"); -@ctypes_function( - "llama_kv_self_seq_div", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - ctypes.c_int, - ], - None, -) -def llama_kv_self_seq_div( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - d: Union[ctypes.c_int, int], - /, -): - """Divide sequence positions in KV cache (DEPRECATED)""" - ... - - -# // Returns the smallest position present in the KV cache for the specified sequence -# // This is typically non-zero only for SWA caches -# // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache -# // Return -1 if the sequence is empty -# DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_pos_min() instead"); -@ctypes_function( - "llama_kv_self_seq_pos_min", [llama_context_p_ctypes, llama_seq_id], llama_pos -) -def llama_kv_self_seq_pos_min( - ctx: llama_context_p, seq_id: Union[llama_seq_id, int], / -) -> int: - """Returns the smallest position in KV cache for sequence (DEPRECATED)""" - ... - - -# // Returns the largest position present in the KV cache for the specified sequence -# // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache -# // Return -1 if the sequence is empty -# DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_pos_max() instead"); -@ctypes_function( - "llama_kv_self_seq_pos_max", [llama_context_p_ctypes, llama_seq_id], llama_pos -) -def llama_kv_self_seq_pos_max( - ctx: llama_context_p, seq_id: Union[llama_seq_id, int], / -) -> int: - """Returns the largest position in KV cache for sequence (DEPRECATED)""" - ... - - -# // Defragment the KV cache -# // This will be applied: -# // - lazily on next llama_decode() -# DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), -# "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); -@ctypes_function("llama_kv_self_defrag", [llama_context_p_ctypes], None) -def llama_kv_self_defrag(ctx: llama_context_p, /): - """Defragment the KV cache (DEPRECATED)""" - ... - - -# // Check if the context supports KV cache shifting -# DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), -# "use llama_memory_can_shift() instead"); -@ctypes_function("llama_kv_self_can_shift", [llama_context_p_ctypes], ctypes.c_bool) -def llama_kv_self_can_shift(ctx: llama_context_p, /) -> bool: - """Check if the context supports KV cache shifting (DEPRECATED)""" - ... - - -# // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) 
-# DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), -# "simply remove this call, updates are applied lazily on the next llama_decode()"); -@ctypes_function("llama_kv_self_update", [llama_context_p_ctypes], None) -def llama_kv_self_update(ctx: llama_context_p, /): - """Apply the KV cache updates (DEPRECATED)""" + "llama_memory_can_shift", [llama_memory_i_p_ctypes], ctypes.c_bool) +def llama_memory_can_shift(mem: llama_memory_i_p, /) -> bool: ... @@ -2292,13 +2189,14 @@ def llama_kv_self_update(ctx: llama_context_p, /): # // State / sessions # // + # // Returns the *actual* size in bytes of the state # // (logits, embedding and memory) # // Only use when saving the state, not when restoring it, otherwise the size may be too small. # LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); @ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t) def llama_state_get_size(ctx: llama_context_p, /) -> int: - """Returns the *actual* size in bytes of the state (logits, embedding and memory)""" + """Returns the *actual* size in bytes of the state (rng, logits, embedding and memory) - will often be smaller after compacting tokens""" ... @@ -2306,7 +2204,8 @@ def llama_state_get_size(ctx: llama_context_p, /) -> int: # "use llama_state_get_size instead"); @ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t) def llama_get_state_size(ctx: llama_context_p, /) -> int: - """Returns the size in bytes of the state (DEPRECATED)""" + """Returns the maximum size in bytes of the state (rng, logits, embedding + and kv_cache) - will often be smaller after compacting tokens""" ... @@ -2353,7 +2252,9 @@ def llama_state_get_data( def llama_copy_state_data( ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], / ) -> int: - """Copies the state to the specified destination address (DEPRECATED)""" + """Copies the state to the specified destination address. + Destination needs to have allocated enough memory. + Returns the number of bytes copied""" ... @@ -2391,7 +2292,7 @@ def llama_state_set_data( def llama_set_state_data( ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], / ) -> int: - """Set the state reading from the specified address (DEPRECATED)""" + """Set the state reading from the specified address""" ... @@ -2440,7 +2341,7 @@ def llama_state_load_file( ctypes.c_size_t, ctypes.POINTER(ctypes.c_size_t), ], - ctypes.c_bool, + ctypes.c_size_t, ) def llama_load_session_file( ctx: llama_context_p, @@ -2449,7 +2350,7 @@ def llama_load_session_file( n_token_capacity: Union[ctypes.c_size_t, int], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], /, -) -> bool: +) -> int: ... @@ -2492,7 +2393,7 @@ def llama_state_save_file( llama_token_p, ctypes.c_size_t, ], - ctypes.c_bool, + ctypes.c_size_t, ) def llama_save_session_file( ctx: llama_context_p, @@ -2500,7 +2401,7 @@ def llama_save_session_file( tokens: CtypesArray[llama_token], n_token_count: Union[ctypes.c_size_t, int], /, -) -> bool: +) -> int: ... @@ -2571,7 +2472,7 @@ def llama_state_seq_set_data( dest_seq_id: llama_seq_id, /, ) -> int: - """Copy the sequence data into the specified sequence""" + """Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence""" ... 
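A sketch of a full state snapshot/restore round-trip with the state functions documented above; a real application would persist the buffer to disk rather than keep it in memory.

import ctypes

import llama_cpp.llama_cpp as llama_cpp

def snapshot_state(ctx) -> ctypes.Array:
    """Copy the full context state (logits, embeddings, memory) into a byte buffer."""
    size = llama_cpp.llama_state_get_size(ctx)  # actual size, only valid for saving
    buf = (ctypes.c_uint8 * size)()
    n = llama_cpp.llama_copy_state_data(ctx, buf)  # returns the number of bytes copied
    assert n <= size
    return buf

def restore_state(ctx, buf) -> int:
    """Read the state back from the buffer; returns the number of bytes consumed."""
    return llama_cpp.llama_set_state_data(ctx, buf)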
@@ -2638,6 +2539,7 @@ def llama_state_seq_load_file( # // Decoding # // + # // Return batch for single sequence of tokens # // The sequence ID will be fixed to 0 # // The position of the tokens will be tracked automatically by llama_decode @@ -2660,7 +2562,7 @@ def llama_batch_get_one( n_tokens: Union[ctypes.c_int, int], /, ) -> llama_batch: - """Return batch for single sequence of tokens + """Return batch for single sequence of tokens starting at pos_0 NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it """ @@ -2682,9 +2584,9 @@ def llama_batch_get_one( "llama_batch_init", [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32], llama_batch ) def llama_batch_init( - n_tokens: Union[ctypes.c_int32, int], - embd: Union[ctypes.c_int32, int], - n_seq_max: Union[ctypes.c_int32, int], + n_tokens: ctypes.c_int32, + embd: ctypes.c_int32, + n_seq_max: ctypes.c_int32, /, ) -> llama_batch: """Allocates a batch of tokens on the heap that can hold a maximum of n_tokens @@ -2716,7 +2618,8 @@ def llama_batch_free(batch: llama_batch, /): # struct llama_batch batch); @ctypes_function("llama_encode", [llama_context_p_ctypes, llama_batch], ctypes.c_int32) def llama_encode(ctx: llama_context_p, batch: llama_batch, /) -> int: - """Process a batch of tokens using the encoder. + """Processes a batch of tokens with the ecoder part of the encoder-decoder model. + Stores the encoder output internally for later use by the decoder cross-attention layers. 0 - success < 0 - error""" ... @@ -2736,15 +2639,16 @@ def llama_encode(ctx: llama_context_p, batch: llama_batch, /) -> int: # // < -1 - fatal error (processed ubatches will remain in the context's memory) # LLAMA_API int32_t llama_decode( # struct llama_context * ctx, -# struct llama_batch batch); +# struct llama_batch batch); @ctypes_function("llama_decode", [llama_context_p_ctypes, llama_batch], ctypes.c_int32) def llama_decode(ctx: llama_context_p, batch: llama_batch, /) -> int: - """Process a batch of tokens. - 0 - success - 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - 2 - aborted (processed ubatches will remain in the context's memory) - -1 - invalid input batch - < -1 - fatal error (processed ubatches will remain in the context's memory)""" + """Positive return values does not mean a fatal error, but rather a warning. + 0 - success + 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) + 2 - aborted (processed ubatches will remain in the context's memory) + -1 - invalid input batch + < -1 - fatal error (processed ubatches will remain in the context's memory) + """ ... @@ -2790,12 +2694,14 @@ def llama_n_threads_batch(ctx: llama_context_p, /) -> int: ... -# // Set whether the context outputs embeddings or not +# // Set whether the model is in embeddings mode or not # // TODO: rename to avoid confusion with llama_get_embeddings() # LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); @ctypes_function("llama_set_embeddings", [llama_context_p_ctypes, ctypes.c_bool], None) def llama_set_embeddings(ctx: llama_context_p, embeddings: bool, /): - """Set whether the context outputs embeddings or not""" + """ + Set whether the model is in embeddings model or not + """ ... @@ -2808,17 +2714,15 @@ def llama_set_causal_attn(ctx: llama_context_p, causal_attn: bool, /): If set to true, the model will only attend to the past tokens""" ... 
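A sketch of a single decode step that maps the documented llama_decode return codes onto Python exceptions, using the single-sequence batch helper above; `ctx` and the token list are assumed to come from the usual tokenize step.

import llama_cpp.llama_cpp as llama_cpp

def decode_tokens(ctx, tokens: list) -> None:
    """Decode one batch of tokens and raise on the non-success return codes."""
    arr = (llama_cpp.llama_token * len(tokens))(*tokens)
    batch = llama_cpp.llama_batch_get_one(arr, len(tokens))
    rc = llama_cpp.llama_decode(ctx, batch)
    if rc == 1:
        raise RuntimeError("no KV slot for the batch - reduce the batch size or grow the context")
    if rc == 2:
        raise RuntimeError("decode aborted (processed ubatches remain in the context's memory)")
    if rc < 0:
        raise RuntimeError(f"llama_decode failed with code {rc}")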
- # // Set whether the model is in warmup mode or not # // If true, all model tensors are activated during llama_decode() to load and cache their weights. # LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); @ctypes_function("llama_set_warmup", [llama_context_p_ctypes, ctypes.c_bool], None) def llama_set_warmup(ctx: llama_context_p, warmup: bool, /): - """Set whether the model is in warmup mode or not - If true, all model tensors are activated during llama_decode() to load and cache their weights.""" + """ Set whether the model is in warmup mode or not + If true, all model tensors are activated during llama_decode() to load and cache their weights""" ... - # // Set abort callback # LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); @ctypes_function( @@ -2881,10 +2785,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]: ctypes.POINTER(ctypes.c_float), ) def llama_get_logits_ith( - ctx: llama_context_p, i: Union[ctypes.c_int32, int], / -) -> CtypesArray[ctypes.c_float]: + ctx: llama_context_p, i: ctypes.c_int32, / +) -> ctypes.POINTER(ctypes.c_float): """Logits for the ith token. Equivalent to: - llama_get_logits(ctx) + i*n_vocab""" + llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab""" ... @@ -2926,7 +2830,7 @@ def llama_get_embeddings_ith( # // Get the embeddings for a sequence id # // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE -# // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence +# // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence # // otherwise: float[n_embd] (1-dimensional) # LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); @ctypes_function( @@ -2947,6 +2851,7 @@ def llama_get_embeddings_seq( # // Vocab # // + # LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); @ctypes_function( "llama_vocab_get_text", [llama_vocab_p_ctypes, llama_token], ctypes.c_char_p @@ -3000,6 +2905,8 @@ def llama_vocab_is_control( # // Special tokens + + # LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence @ctypes_function("llama_vocab_bos", [llama_vocab_p_ctypes], llama_token) def llama_vocab_bos(vocab: llama_vocab_p, /) -> llama_token: @@ -3139,7 +3046,7 @@ def llama_vocab_fim_sep(vocab: llama_vocab_p, /) -> llama_token: ... -# DEPRECATED functions + # DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); @ctypes_function( "llama_token_get_text", @@ -3353,6 +3260,7 @@ def llama_vocab_cls(vocab: llama_vocab_p, /) -> llama_token: # // The API is thread-safe. # // + # /// @details Convert the provided text into tokens. # /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. # /// @return Returns the number of tokens on success, no more than n_tokens_max @@ -3400,7 +3308,7 @@ def llama_tokenize( text_len: The length of the text. tokens: The tokens pointer must be large enough to hold the resulting tokens. n_max_tokens: The maximum number of tokens to return. - add_special: Allow adding special tokens if the model is configured to do so. + add_special: Allow adding special tokenns if the model is configured to do so. parse_special: Allow parsing special tokens. 
Returns: @@ -3458,6 +3366,23 @@ def llama_token_to_piece( ... +# # // check if token0 is contained as a prefix in token1 +# # LLAMA_API bool llama_token_is_prefix( +# # const struct llama_model * model, +# # llama_token token0, +# # llama_token token1); +# @ctypes_function( +# "llama_token_is_prefix", +# [llama_model_p_ctypes, llama_token, llama_token], +# ctypes.c_bool, +# ) +# def llama_token_is_prefix( +# model: llama_model_p, token0: Union[llama_token, int], token1: Union[llama_token, int], / +# ) -> bool: +# """Check if token0 is contained as a prefix in token1""" +# ... + + # /// @details Convert the provided tokens into text (inverse of llama_tokenize()). # /// @param text The char pointer must be large enough to hold the resulting text. # /// @return Returns the number of chars/bytes on success, no more than text_len_max. @@ -3465,7 +3390,7 @@ def llama_token_to_piece( # /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. # /// @param unparse_special If true, special tokens are rendered in the output. # LLAMA_API int32_t llama_detokenize( -# const struct llama_vocab * vocab, +# const struct llama_model * model, # const llama_token * tokens, # int32_t n_tokens, # char * text, @@ -3475,8 +3400,8 @@ def llama_token_to_piece( @ctypes_function( "llama_detokenize", [ - llama_vocab_p_ctypes, - ctypes.POINTER(llama_token), + llama_model_p_ctypes, + llama_token_p, ctypes.c_int32, ctypes.c_char_p, ctypes.c_int32, @@ -3486,7 +3411,7 @@ def llama_token_to_piece( ctypes.c_int32, ) def llama_detokenize( - vocab: llama_vocab_p, + model: llama_model_p, tokens: CtypesArray[llama_token], n_tokens: Union[ctypes.c_int, int], text: bytes, @@ -3498,7 +3423,7 @@ def llama_detokenize( """Convert the provided tokens into text (inverse of llama_tokenize()). Args: - vocab: The vocabulary to use for tokenization. + model: The model to use for tokenization. tokens: The tokens to convert. n_tokens: The number of tokens. text: The buffer to write the text to. @@ -3512,10 +3437,11 @@ def llama_detokenize( # // Chat templates # // + # /// Apply chat template. Inspired by hf apply_chat_template() on python. # /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" -# /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template -# /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model's default chat template will be used instead. +# /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template +# /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. # /// @param chat Pointer to a list of multiple llama_chat_message # /// @param n_msg Number of llama_chat_message in this chat # /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. 
@@ -3597,6 +3523,41 @@ def llama_chat_builtin_templates( # // # // Sampling API # // +# // Sample usage: +# // +# // // prepare the sampling chain at the start +# // auto sparams = llama_sampler_chain_default_params(); +# // +# // llama_sampler * smpl = llama_sampler_chain_init(sparams); +# // +# // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50)); +# // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); +# // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8)); +# // +# // // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat" +# // // this sampler will be responsible to select the actual token +# // llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed)); +# // +# // ... +# // +# // // decoding loop: +# // while (...) { +# // ... +# // +# // llama_decode(ctx, batch); +# // +# // // sample from the logits of the last token in the batch +# // const llama_token id = llama_sampler_sample(smpl, ctx, -1); +# // +# // // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.) +# // llama_sampler_accept(smpl, id); +# // ... +# // } +# // +# // llama_sampler_free(smpl); +# // +# // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). +# // # typedef void * llama_sampler_context_t; llama_sampler_context_t = ctypes.c_void_p @@ -3610,7 +3571,7 @@ def llama_chat_builtin_templates( # void (*reset) ( struct llama_sampler * smpl); // can be NULL # struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL # void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL - +# # // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph # //void (*apply_ggml) (struct llama_sampler * smpl, ...); # }; @@ -3619,8 +3580,8 @@ class llama_sampler_i(ctypes.Structure): # struct llama_sampler { -# const struct llama_sampler_i * iface; -# llama_sampler_context_t ctx; +# const struct llama_sampler_i * iface; +# llama_sampler_context_t ctx; # }; class llama_sampler(ctypes.Structure): _fields_ = [ @@ -3731,7 +3692,7 @@ def llama_sampler_free(smpl: llama_sampler_p, /): # // llama_sampler_chain # // a type of llama_sampler that can chain multiple samplers one after another - +# # LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); @ctypes_function( "llama_sampler_chain_init", @@ -3789,7 +3750,7 @@ def llama_sampler_chain_remove( # // available samplers: - +# # LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); @ctypes_function("llama_sampler_init_greedy", [], llama_sampler_p_ctypes) def llama_sampler_init_greedy() -> llama_sampler_p: @@ -3802,15 +3763,6 @@ def llama_sampler_init_dist(seed: int) -> llama_sampler_p: ... -# /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. -# /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. -# DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), -# "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); -@ctypes_function("llama_sampler_init_softmax", [], llama_sampler_p_ctypes) -def llama_sampler_init_softmax() -> llama_sampler_p: - ... 
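The C sample-usage comment above translates almost one-to-one into these bindings; a sketch assuming the chain helpers (llama_sampler_chain_default_params, llama_sampler_chain_add, llama_sampler_accept) exported elsewhere in this file, with an arbitrary seed.

import llama_cpp.llama_cpp as llama_cpp

sparams = llama_cpp.llama_sampler_chain_default_params()
smpl = llama_cpp.llama_sampler_chain_init(sparams)
llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_top_k(50))
llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_top_p(0.9, 1))
llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_temp(0.8))
# the chain should end with a selecting sampler such as "greedy", "dist" or "mirostat"
llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_dist(1234))  # seed

# inside the decoding loop, after llama_decode(ctx, batch):
#     token = llama_cpp.llama_sampler_sample(smpl, ctx, -1)
#     llama_cpp.llama_sampler_accept(smpl, token)

llama_cpp.llama_sampler_free(smpl)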
- - # /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 # /// Setting k <= 0 makes this a noop # LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @@ -3830,7 +3782,7 @@ def llama_sampler_init_top_p(p: float, min_keep: int) -> llama_sampler_p: ... -# /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 +# /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 # LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); @ctypes_function( "llama_sampler_init_min_p", @@ -3852,7 +3804,6 @@ def llama_sampler_init_typical(p: float, min_keep: int) -> llama_sampler_p: ... -# /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf # LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); @ctypes_function("llama_sampler_init_temp", [ctypes.c_float], llama_sampler_p_ctypes) def llama_sampler_init_temp(t: float) -> llama_sampler_p: @@ -3897,6 +3848,11 @@ def llama_sampler_init_top_n_sigma(n: float, /) -> llama_sampler_p: # /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +# /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +# /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +# /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +# /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +# /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. # LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( # int32_t n_vocab, # uint32_t seed, @@ -3915,6 +3871,10 @@ def llama_sampler_init_mirostat( # /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +# /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +# /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +# /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. 
A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +# /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. # LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2( # uint32_t seed, # float tau, @@ -3931,6 +3891,9 @@ def llama_sampler_init_mirostat_v2( # /// @details Intializes a GBNF grammar, see grammars/README.md for details. +# /// @param vocab The vocabulary that this grammar will be used with. +# /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. +# /// @param grammar_root The name of the start symbol for the grammar. # LLAMA_API struct llama_sampler * llama_sampler_init_grammar( # const struct llama_vocab * vocab, # const char * grammar_str, @@ -3946,42 +3909,9 @@ def llama_sampler_init_grammar( ... -# DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( -# const struct llama_vocab * vocab, -# const char * grammar_str, -# const char * grammar_root, -# const char ** trigger_words, -# size_t num_trigger_words, -# const llama_token * trigger_tokens, -# size_t num_trigger_tokens), -# "use llama_sampler_init_grammar_lazy_patterns instead"); -@ctypes_function( - "llama_sampler_init_grammar_lazy", - [ - llama_vocab_p_ctypes, - ctypes.c_char_p, - ctypes.c_char_p, - ctypes.POINTER(ctypes.c_char_p), - ctypes.c_size_t, - ctypes.POINTER(llama_token), - ctypes.c_size_t, - ], - llama_sampler_p_ctypes, -) -def llama_sampler_init_grammar_lazy( - vocab: llama_vocab_p, - grammar_str: bytes, - grammar_root: bytes, - trigger_words: CtypesArray[bytes], - num_trigger_words: int, - trigger_tokens: CtypesArray[llama_token], - num_trigger_tokens: int, - /, -) -> llama_sampler_p: - ... - - # /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 +# /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group. +# /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. # LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( # const struct llama_vocab * vocab, # const char * grammar_str, @@ -3998,7 +3928,7 @@ def llama_sampler_init_grammar_lazy( ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p), ctypes.c_size_t, - ctypes.POINTER(llama_token), + llama_token_p, ctypes.c_size_t, ], llama_sampler_p_ctypes, @@ -4056,7 +3986,7 @@ def llama_sampler_init_penalties( ctypes.c_float, ctypes.c_int32, ctypes.c_int32, - ctypes.POINTER(ctypes.c_char_p), + ctypes.POINTER(ctypes.POINTER(ctypes.c_char)), ctypes.c_size_t, ], llama_sampler_p_ctypes, @@ -4068,7 +3998,7 @@ def llama_sampler_init_dry( dry_base: float, dry_allowed_length: int, dry_penalty_last_n: int, - seq_breakers, + seq_breakers: CtypesArray[bytes], num_breakers: int, /, ) -> llama_sampler_p: @@ -4091,6 +4021,26 @@ def llama_sampler_init_logit_bias( # // this sampler is meant to be used for fill-in-the-middle infilling +# // it's supposed to be used after top_k + top_p sampling +# // +# // 1. 
if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG +# // 2. combine probs of tokens that have the same prefix +# // +# // example: +# // +# // - before: +# // "hel": 0.5 +# // "hell": 0.2 +# // "hello": 0.1 +# // "dummy": 0.1 +# // +# // - after: +# // "hel": 0.8 +# // "dummy": 0.1 +# // +# // 3. discard non-EOG tokens with low prob +# // 4. if no tokens are left -> pick EOT +# // # LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); @ctypes_function( "llama_sampler_init_infill", @@ -4113,6 +4063,15 @@ def llama_sampler_get_seed(smpl: llama_sampler_p, /) -> int: # /// @details Sample and accept a token from the idx-th output of the last evaluation +# // +# // Shorthand for: +# // const auto * logits = llama_get_logits_ith(ctx, idx); +# // llama_token_data_array cur_p = { ... init from logits ... }; +# // llama_sampler_apply(smpl, &cur_p); +# // auto token = cur_p.data[cur_p.selected].id; +# // llama_sampler_accept(smpl, token); +# // return token; +# // Returns the sampled token # LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); @ctypes_function( "llama_sampler_sample", @@ -4120,8 +4079,8 @@ def llama_sampler_get_seed(smpl: llama_sampler_p, /) -> int: llama_token, ) def llama_sampler_sample( - smpl: llama_sampler_p, ctx: llama_context_p, idx: int, / -) -> int: + smpl: llama_sampler_p, ctx: llama_context_p, idx: ctypes.c_int32, / +) -> ctypes.c_int32: ... @@ -4129,7 +4088,10 @@ def llama_sampler_sample( # // Model split # // + # /// @details Build a split GGUF final path for this chunk. +# /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" +# // Returns the split_path length. # LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); @ctypes_function( "llama_split_path", @@ -4149,6 +4111,8 @@ def llama_split_path( # /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. +# /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" +# // Returns the split_prefix length. # LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); @ctypes_function( "llama_split_prefix", @@ -4196,13 +4160,16 @@ def llama_log_set( # // # // Performance utils # // +# // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements. 
+# // + # struct llama_perf_context_data { # double t_start_ms; # double t_load_ms; # double t_p_eval_ms; # double t_eval_ms; - +# # int32_t n_p_eval; # int32_t n_eval; # int32_t n_reused; // number of times a ggml compute graph had been reused @@ -4221,7 +4188,7 @@ class llama_perf_context_data(ctypes.Structure): # struct llama_perf_sampler_data { # double t_sample_ms; - +# # int32_t n_sample; # }; class llama_perf_sampler_data(ctypes.Structure): @@ -4298,19 +4265,20 @@ def llama_perf_sampler_reset(chain: llama_sampler_p, /): # // function that returns whether or not a given tensor contains trainable parameters # typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata); -llama_opt_param_filter = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p) +llama_opt_param_filter = ctypes.CFUNCTYPE( + ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p +) + # // always returns true # LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata); -@ctypes_function( - "llama_opt_param_filter_all", - [ctypes.c_void_p, ctypes.c_void_p], - ctypes.c_bool, -) -def llama_opt_param_filter_all(tensor: ctypes.c_void_p, userdata: ctypes.c_void_p, /) -> bool: +@ctypes_function("llama_opt_param_filter_all", [ctypes.c_void_p, ctypes.c_void_p], ctypes.c_bool) +def llama_opt_param_filter_all( + tensor: llama_model_p, + userdata: ctypes.c_void_p, / +) -> bool: ... - # struct llama_opt_params { # uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0 @@ -4325,7 +4293,7 @@ class llama_opt_params(ctypes.Structure): ("n_ctx_train", ctypes.c_uint32), ("param_filter", llama_opt_param_filter), ("param_filter_ud", ctypes.c_void_p), - ("get_opt_pars", ctypes.c_void_p), # ggml_opt_get_optimizer_params - not implemented here + ("get_opt_pars", ggml_opt_get_optimizer_params), ("get_opt_pars_ud", ctypes.c_void_p), ] @@ -4333,13 +4301,16 @@ class llama_opt_params(ctypes.Structure): # LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); @ctypes_function( "llama_opt_init", - [llama_context_p_ctypes, llama_model_p_ctypes, llama_opt_params], + [llama_context_p_ctypes, llama_model_p_ctypes, llama_opt_params_p_ctypes], None, ) -def llama_opt_init(lctx: llama_context_p, model: llama_model_p, lopt_params: llama_opt_params, /): +def llama_opt_init( + lctx: llama_context_p, + model: llama_model_p, + lopt_params: llama_opt_params_p, / +): ... - # LLAMA_API void llama_opt_epoch( # struct llama_context * lctx, # ggml_opt_dataset_t dataset, @@ -4349,15 +4320,14 @@ def llama_opt_init(lctx: llama_context_p, model: llama_model_p, lopt_params: lla # ggml_opt_epoch_callback callback_train, # ggml_opt_epoch_callback callback_eval); @ctypes_function( - "llama_opt_epoch", - [ + "llama_opt_epoch",[ llama_context_p_ctypes, - ctypes.c_void_p, # ggml_opt_dataset_t - ctypes.c_void_p, # ggml_opt_result_t - ctypes.c_void_p, # ggml_opt_result_t + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, ctypes.c_int64, - ctypes.c_void_p, # ggml_opt_epoch_callback - ctypes.c_void_p, # ggml_opt_epoch_callback + ctypes.c_void_p, + ctypes.c_void_p ], None, ) @@ -4366,9 +4336,8 @@ def llama_opt_epoch( dataset: ctypes.c_void_p, result_train: ctypes.c_void_p, result_eval: ctypes.c_void_p, - idata_split: int, + idata_split: ctypes.c_int64, callback_train: ctypes.c_void_p, - callback_eval: ctypes.c_void_p, - /, + callback_eval: ctypes.c_void_p, / ): ... 
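A sketch of wiring a custom trainable-parameter filter into llama_opt_params via the llama_opt_param_filter CFUNCTYPE above; the filter shown trains everything (the same behaviour as the built-in llama_opt_param_filter_all), and get_opt_pars is left unset because providing a real ggml_opt_get_optimizer_params callback is outside the scope of this sketch.

import ctypes

import llama_cpp.llama_cpp as llama_cpp

@llama_cpp.llama_opt_param_filter
def train_everything(tensor: ctypes.c_void_p, userdata: ctypes.c_void_p) -> bool:
    # return True for every tensor, i.e. all parameters are trainable
    return True

opt_params = llama_cpp.llama_opt_params(
    n_ctx_train=0,                  # 0: use the context size of the llama_context
    param_filter=train_everything,  # callback must be kept alive for as long as it is used
    param_filter_ud=None,
    # get_opt_pars / get_opt_pars_ud left NULL in this sketch; a real caller supplies a
    # ggml_opt_get_optimizer_params callback returning AdamW parameters
)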
diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index b95c77ab5..1079c1d2e 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -131,6 +131,15 @@ def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGramm ws ::= ([ \t\n] ws)? """ +ENGLISH_GBNF = r""" +# note: this might be incomplete, mostly an example +root ::= en-char+ ([ \t\n] en-char+)* +en-char ::= letter | digit | punctuation +letter ::= [a-zA-Z] +digit ::= [0-9] +punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~] +""" + JAPANESE_GBNF = r""" root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py deleted file mode 100644 index d9dfaf5fd..000000000 --- a/llama_cpp/llava_cpp.py +++ /dev/null @@ -1,158 +0,0 @@ -from __future__ import annotations - -import os -from ctypes import ( - c_bool, - c_char_p, - c_int, - c_uint8, - c_float, - c_void_p, - POINTER, - _Pointer, # type: ignore - Structure, -) -import pathlib -from typing import ( - Union, - NewType, - Optional, - TYPE_CHECKING, -) - -import llama_cpp.llama_cpp as llama_cpp - -from llama_cpp._ctypes_extensions import ( - load_shared_library, - ctypes_function_for_shared_library, -) - -if TYPE_CHECKING: - from llama_cpp._ctypes_extensions import ( - CtypesArray, - ) - - -# Specify the base name of the shared library to load -_libllava_base_name = "llava" -_libllava_override_path = os.environ.get("LLAVA_CPP_LIB") -_libllava_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libllava_override_path is None else pathlib.Path() - -# Load the library -_libllava = load_shared_library(_libllava_base_name, _libllava_base_path) - -ctypes_function = ctypes_function_for_shared_library(_libllava) - - -################################################ -# llava.h -################################################ - -# struct clip_ctx; -clip_ctx_p = NewType("clip_ctx_p", int) -clip_ctx_p_ctypes = c_void_p - - -# struct llava_image_embed { -# float * embed; -# int n_image_pos; -# }; -class llava_image_embed(Structure): - _fields_ = [ - ("embed", POINTER(c_float)), - ("n_image_pos", c_int), - ] - - -# /** sanity check for clip <-> llava embed size match */ -# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip); -@ctypes_function( - "llava_validate_embed_size", - [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes], - c_bool, -) -def llava_validate_embed_size( - ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, / -) -> bool: - ... - - -# /** build an image embed from image file bytes */ -# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); -@ctypes_function( - "llava_image_embed_make_with_bytes", - [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int], - POINTER(llava_image_embed), -) -def llava_image_embed_make_with_bytes( - ctx_clip: clip_ctx_p, - n_threads: Union[c_int, int], - image_bytes: CtypesArray[c_uint8], - image_bytes_length: Union[c_int, int], - /, -) -> "_Pointer[llava_image_embed]": - ... 
- - -# /** build an image embed from a path to an image filename */ -# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); -@ctypes_function( - "llava_image_embed_make_with_filename", - [clip_ctx_p_ctypes, c_int, c_char_p], - POINTER(llava_image_embed), -) -def llava_image_embed_make_with_filename( - ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, / -) -> "_Pointer[llava_image_embed]": - ... - - -# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); -# /** free an embedding made with llava_image_embed_make_* */ -@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None) -def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): - ... - - -# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ -# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); -@ctypes_function( - "llava_eval_image_embed", - [ - llama_cpp.llama_context_p_ctypes, - POINTER(llava_image_embed), - c_int, - POINTER(c_int), - ], - c_bool, -) -def llava_eval_image_embed( - ctx_llama: llama_cpp.llama_context_p, - embed: "_Pointer[llava_image_embed]", - n_batch: Union[c_int, int], - n_past: "_Pointer[c_int]", - /, -) -> bool: - ... - - -################################################ -# clip.h -################################################ - - -# /** load mmproj model */ -# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); -@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes) -def clip_model_load( - fname: bytes, verbosity: Union[c_int, int], / -) -> Optional[clip_ctx_p]: - ... - - -# /** free mmproj model */ -# CLIP_API void clip_free(struct clip_ctx * ctx); -@ctypes_function("clip_free", [clip_ctx_p_ctypes], None) -def clip_free(ctx: clip_ctx_p, /): - ... - diff --git a/llama_cpp/mtmd_cpp.py b/llama_cpp/mtmd_cpp.py index a45f8f406..856033906 100644 --- a/llama_cpp/mtmd_cpp.py +++ b/llama_cpp/mtmd_cpp.py @@ -1,11 +1,14 @@ from __future__ import annotations +import enum import os from ctypes import ( c_bool, c_char_p, c_int, + c_uint, c_uint8, + c_int32, c_uint32, c_float, c_void_p, @@ -13,7 +16,6 @@ POINTER, _Pointer, # type: ignore Structure, - byref, ) import pathlib from typing import ( @@ -36,214 +38,668 @@ ) -# Specify the base name of the shared library to load +# --- mtmd library loading --- _libmtmd_base_name = "mtmd" _libmtmd_override_path = os.environ.get("MTMD_CPP_LIB") -_libmtmd_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libmtmd_override_path is None else pathlib.Path() +_libmtmd_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libmtmd_override_path is None else pathlib.Path(_libmtmd_override_path).parent -# Load the library +# Load the mtmd library _libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path) +ctypes_function_mtmd = ctypes_function_for_shared_library(_libmtmd) -ctypes_function = ctypes_function_for_shared_library(_libmtmd) ################################################ -# mtmd.h types +# mtmd.h +# /** +# * libmtmd: A library for multimodal support in llama.cpp. +# * +# * WARNING: This API is experimental and subject to many BREAKING CHANGES. 
+# * Issues related to API usage may receive lower priority support. +# * +# * For the usage, see an example in mtmd-cli.cpp +# */ ################################################ -# Opaque types + +# enum mtmd_input_chunk_type { +# MTMD_INPUT_CHUNK_TYPE_TEXT, +# MTMD_INPUT_CHUNK_TYPE_IMAGE, +# MTMD_INPUT_CHUNK_TYPE_AUDIO, +# }; +class mtmd_input_chunk_type(enum.IntEnum): + MTMD_INPUT_CHUNK_TYPE_TEXT = 0 + MTMD_INPUT_CHUNK_TYPE_IMAGE = 1 + MTMD_INPUT_CHUNK_TYPE_AUDIO = 2 + +# // opaque types + +# struct mtmd_context; mtmd_context_p = NewType("mtmd_context_p", int) mtmd_context_p_ctypes = c_void_p +# struct mtmd_bitmap; mtmd_bitmap_p = NewType("mtmd_bitmap_p", int) mtmd_bitmap_p_ctypes = c_void_p +# struct mtmd_image_tokens; mtmd_image_tokens_p = NewType("mtmd_image_tokens_p", int) mtmd_image_tokens_p_ctypes = c_void_p +# struct mtmd_input_chunk; mtmd_input_chunk_p = NewType("mtmd_input_chunk_p", int) mtmd_input_chunk_p_ctypes = c_void_p +# struct mtmd_input_chunks; mtmd_input_chunks_p = NewType("mtmd_input_chunks_p", int) mtmd_input_chunks_p_ctypes = c_void_p -# Enums -MTMD_INPUT_CHUNK_TYPE_TEXT = 0 -MTMD_INPUT_CHUNK_TYPE_IMAGE = 1 -MTMD_INPUT_CHUNK_TYPE_AUDIO = 2 -# Structures +# struct mtmd_input_text { +# const char * text; +# bool add_special; +# bool parse_special; +# }; +class mtmd_input_text(Structure): + _fields_ = [ + ("text", c_char_p), + ("add_special", c_bool), + ("parse_special", c_bool), + ] +mtmd_input_text_p = NewType("mtmd_input_text_p", int) +mtmd_input_text_p_ctypes = POINTER(mtmd_input_text) + +# struct mtmd_context_params { +# bool use_gpu; +# bool print_timings; +# int n_threads; +# enum ggml_log_level verbosity; +# const char * image_marker; // deprecated, use media_marker instead +# const char * media_marker; +# }; class mtmd_context_params(Structure): _fields_ = [ ("use_gpu", c_bool), ("print_timings", c_bool), ("n_threads", c_int), - ("verbosity", c_int), # ggml_log_level + ("verbosity", c_int), ("image_marker", c_char_p), ("media_marker", c_char_p), ] -class mtmd_input_text(Structure): - _fields_ = [ - ("text", c_char_p), - ("add_special", c_bool), - ("parse_special", c_bool), - ] - -################################################ -# mtmd.h functions -################################################ +mtmd_context_params_p = NewType("mtmd_context_params_p", int) +mtmd_context_params_p_ctypes = POINTER(mtmd_context_params) # MTMD_API const char * mtmd_default_marker(void); -@ctypes_function("mtmd_default_marker", [], c_char_p) -def mtmd_default_marker() -> bytes: +@ctypes_function_mtmd( + "mtmd_default_marker", + [], + c_char_p, +) +def mtmd_default_marker() -> c_char_p: ... + # MTMD_API struct mtmd_context_params mtmd_context_params_default(void); -@ctypes_function("mtmd_context_params_default", [], mtmd_context_params) +@ctypes_function_mtmd( + "mtmd_context_params_default", + [], + mtmd_context_params, +) def mtmd_context_params_default() -> mtmd_context_params: ... 
+ +# // initialize the mtmd context +# // return nullptr on failure # MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, # const struct llama_model * text_model, # const struct mtmd_context_params ctx_params); -@ctypes_function( - "mtmd_init_from_file", - [c_char_p, llama_cpp.llama_model_p_ctypes, mtmd_context_params], - mtmd_context_p_ctypes +@ctypes_function_mtmd( + "mtmd_init_from_file", [ + c_char_p, + llama_cpp.llama_model_p_ctypes, + mtmd_context_params, + ], + mtmd_context_p_ctypes, ) def mtmd_init_from_file( - mmproj_fname: bytes, + mmproj_fname: c_char_p, text_model: llama_cpp.llama_model_p, ctx_params: mtmd_context_params, /, -) -> Optional[mtmd_context_p]: +) -> mtmd_context_p: + """ + initialize the mtmd context + return nullptr on failure + """ ... + # MTMD_API void mtmd_free(mtmd_context * ctx); -@ctypes_function("mtmd_free", [mtmd_context_p_ctypes], None) -def mtmd_free(ctx: mtmd_context_p, /): +@ctypes_function_mtmd("mtmd_free", [mtmd_context_p_ctypes], None) +def mtmd_free(ctx: mtmd_context_p): ... +# // whether we need to set non-causal mask before llama_decode +# MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); +@ctypes_function_mtmd( + "mtmd_decode_use_non_causal", [mtmd_context_p_ctypes], c_bool) +def mtmd_decode_use_non_causal(ctx: mtmd_context_p) -> c_bool: + ... + +# // whether the current model use M-RoPE for llama_decode +# MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx); +@ctypes_function_mtmd( + "mtmd_decode_use_mrope", [mtmd_context_p_ctypes], c_bool) +def mtmd_decode_use_mrope(ctx: mtmd_context_p) -> c_bool: + ... + +# // whether the current model supports vision input # MTMD_API bool mtmd_support_vision(mtmd_context * ctx); -@ctypes_function("mtmd_support_vision", [mtmd_context_p_ctypes], c_bool) -def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool: +@ctypes_function_mtmd( + "mtmd_support_vision", [mtmd_context_p_ctypes], c_bool) +def mtmd_support_vision(ctx: mtmd_context_p) -> c_bool: + ... + +# // whether the current model supports audio input +# MTMD_API bool mtmd_support_audio(mtmd_context * ctx); +@ctypes_function_mtmd( + "mtmd_support_audio", [mtmd_context_p_ctypes], c_bool) +def mtmd_support_audio(ctx: mtmd_context_p) -> c_bool: + ... + +# // get audio bitrate in Hz, for example 16000 for Whisper +# // return -1 if audio is not supported +# MTMD_API int mtmd_get_audio_bitrate(mtmd_context * ctx); +@ctypes_function_mtmd( + "mtmd_get_audio_bitrate", [mtmd_context_p_ctypes], c_int) +def mtmd_get_audio_bitrate(ctx: mtmd_context_p) -> c_int: ... -# MTMD_API mtmd_bitmap * mtmd_bitmap_init(uint32_t nx, uint32_t ny, const unsigned char * data); -@ctypes_function( - "mtmd_bitmap_init", - [c_uint32, c_uint32, POINTER(c_uint8)], - mtmd_bitmap_p_ctypes +# // mtmd_bitmap +# // +# // if bitmap is image: +# // length of data must be nx * ny * 3 +# // the data is in RGBRGBRGB... format +# // if bitmap is audio: +# // length of data must be n_samples * sizeof(float) +# // the data is in float format (PCM F32) + +# MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data); +@ctypes_function_mtmd( + "mtmd_bitmap_init", [ + c_uint32, + c_uint32, + c_char_p, + ], + mtmd_bitmap_p_ctypes, ) def mtmd_bitmap_init( - nx: Union[c_uint32, int], - ny: Union[c_uint32, int], - data: CtypesArray[c_uint8], + nx: c_uint32, + ny: c_uint32, + data: c_char_p, /, -) -> Optional[mtmd_bitmap_p]: +) -> mtmd_bitmap_p: ... 
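As a usage note, a minimal sketch of opening a multimodal context with the bindings above (mtmd_context_params_default, mtmd_init_from_file, the capability queries); text_model is assumed to be a llama_model pointer loaded through the llama_cpp bindings, and n_threads is an arbitrary choice.

import llama_cpp.mtmd_cpp as mtmd

def open_mtmd(mmproj_path: bytes, text_model):
    """text_model: a llama_model pointer obtained from the llama_cpp bindings."""
    params = mtmd.mtmd_context_params_default()
    params.n_threads = 4            # tune for your machine
    ctx = mtmd.mtmd_init_from_file(mmproj_path, text_model, params)
    if not ctx:
        raise RuntimeError("mtmd_init_from_file returned NULL")
    if not mtmd.mtmd_support_vision(ctx) and not mtmd.mtmd_support_audio(ctx):
        mtmd.mtmd_free(ctx)
        raise RuntimeError("mmproj supports neither vision nor audio input")
    return ctx                      # release with mtmd.mtmd_free(ctx)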
-# MTMD_API void mtmd_bitmap_free(mtmd_bitmap * bitmap); -@ctypes_function("mtmd_bitmap_free", [mtmd_bitmap_p_ctypes], None) -def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /): + +# MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data); +@ctypes_function_mtmd( + "mtmd_bitmap_init_from_audio", [ + c_uint, + POINTER(c_float) + ], + mtmd_bitmap_p_ctypes, +) +def mtmd_bitmap_init_from_audio( + n_samples: c_uint, + data: POINTER(c_float), + /, +) -> mtmd_bitmap_p: ... -# MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void); -@ctypes_function("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes) -def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]: + +# MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap); +@ctypes_function_mtmd("mtmd_bitmap_get_nx", [mtmd_bitmap_p_ctypes], c_uint32) +def mtmd_bitmap_get_nx(bitmap: mtmd_bitmap_p) -> c_uint32: ... -# MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); -@ctypes_function("mtmd_input_chunks_free", [mtmd_input_chunks_p_ctypes], None) -def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p, /): +# MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap); +@ctypes_function_mtmd("mtmd_bitmap_get_ny", [mtmd_bitmap_p_ctypes], c_uint32) +def mtmd_bitmap_get_ny(bitmap: mtmd_bitmap_p) -> c_uint32: ... -# MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks); -@ctypes_function("mtmd_input_chunks_size", [mtmd_input_chunks_p_ctypes], c_size_t) -def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p, /) -> int: +# MTMD_API const unsigned char * mtmd_bitmap_get_data (const mtmd_bitmap * bitmap); +@ctypes_function_mtmd("mtmd_bitmap_get_data", [mtmd_bitmap_p_ctypes], c_char_p) +def mtmd_bitmap_get_data(bitmap: mtmd_bitmap_p) -> c_char_p: ... -# MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx); -@ctypes_function( - "mtmd_input_chunks_get", - [mtmd_input_chunks_p_ctypes, c_size_t], - mtmd_input_chunk_p_ctypes -) +# MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap); +@ctypes_function_mtmd("mtmd_bitmap_get_n_bytes", [mtmd_bitmap_p_ctypes], c_size_t) +def mtmd_bitmap_get_n_bytes(bitmap: mtmd_bitmap_p) -> c_size_t: + ... + +# MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap); +@ctypes_function_mtmd("mtmd_bitmap_is_audio", [mtmd_bitmap_p_ctypes], c_bool) +def mtmd_bitmap_is_audio(bitmap: mtmd_bitmap_p) -> c_bool: + ... + +# MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap); +@ctypes_function_mtmd("mtmd_bitmap_free", [mtmd_bitmap_p_ctypes], None) +def mtmd_bitmap_free(bitmap: mtmd_bitmap_p): + ... + +# // bitmap ID is optional, but useful for KV cache tracking +# // these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data() +# MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap); +@ctypes_function_mtmd("mtmd_bitmap_get_id", [mtmd_bitmap_p_ctypes], c_char_p) +def mtmd_bitmap_get_id(bitmap: mtmd_bitmap_p) -> c_char_p: + """ + bitmap ID is optional, but useful for KV cache tracking + these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data() + """ + ... + + +# MTMD_API void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id); +@ctypes_function_mtmd( + "mtmd_bitmap_set_id", [ + mtmd_bitmap_p_ctypes, + c_char_p, + ], None) +def mtmd_bitmap_set_id( + bitmap: mtmd_bitmap_p, + id: c_char_p, + /, +): + ... 
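Building on the bitmap API above, a short sketch of wrapping raw RGB pixels in an mtmd_bitmap and tagging it with an ID for KV-cache tracking; nx, ny and the pixel buffer are caller-supplied.

import llama_cpp.mtmd_cpp as mtmd

def make_rgb_bitmap(nx: int, ny: int, rgb: bytes, bitmap_id: str):
    # for images the buffer must be exactly nx * ny * 3 bytes (RGBRGB...)
    if len(rgb) != nx * ny * 3:
        raise ValueError("expected nx * ny * 3 bytes of RGB data")
    bitmap = mtmd.mtmd_bitmap_init(nx, ny, rgb)
    if not bitmap:
        raise RuntimeError("mtmd_bitmap_init returned NULL")
    mtmd.mtmd_bitmap_set_id(bitmap, bitmap_id.encode("utf-8"))
    return bitmap                   # release with mtmd.mtmd_bitmap_free(bitmap)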
+ + +# // mtmd_input_chunks +# // +# // this is simply a list of mtmd_input_chunk +# // the elements can only be populated via mtmd_tokenize() +# MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void); +@ctypes_function_mtmd("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes) +def mtmd_input_chunks_init() -> mtmd_input_chunks_p: + """ + this is simply a list of mtmd_input_chunk + the elements can only be populated via mtmd_tokenize() + """ + ... + + +# MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks); +@ctypes_function_mtmd("mtmd_input_chunks_size", [mtmd_input_chunks_p_ctypes], c_size_t) +def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p) -> c_size_t: + ... + + +# MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get (const mtmd_input_chunks * chunks, size_t idx); +@ctypes_function_mtmd( + "mtmd_input_chunks_get", [ + mtmd_input_chunks_p_ctypes, + c_int32, + ], mtmd_input_chunk_p_ctypes) def mtmd_input_chunks_get( - chunks: mtmd_input_chunks_p, idx: Union[c_size_t, int], / -) -> Optional[mtmd_input_chunk_p]: + chunks: mtmd_input_chunks_p, + idx: c_int32, + /, +) -> mtmd_input_chunk_p: + ... + + +# MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); +@ctypes_function_mtmd("mtmd_input_chunks_free", [mtmd_input_chunks_p_ctypes], None) +def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p): + ... + + +# // mtmd_input_chunk +# // +# // the instance will be constructed via mtmd_tokenize() +# // it will be freed along with mtmd_input_chunks +# MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type (const mtmd_input_chunk * chunk); +@ctypes_function_mtmd("mtmd_input_chunk_get_type", [mtmd_input_chunk_p_ctypes], c_int32) +def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p) -> c_int32: + """ + the instance will be constructed via mtmd_tokenize() + it will be freed along with mtmd_input_chunks + """ + ... + +# MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output); +@ctypes_function_mtmd( + "mtmd_input_chunk_get_tokens_text", + [mtmd_input_chunk_p_ctypes, POINTER(c_size_t)], + POINTER(llama_cpp.llama_token) +) +def mtmd_input_chunk_get_tokens_text( + chunk: mtmd_input_chunk_p, n_tokens_output: "_Pointer[c_size_t]", / +) -> Optional["_Pointer[llama_cpp.llama_token]"]: + ... + +# MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk); +@ctypes_function_mtmd("mtmd_input_chunk_get_tokens_image", [mtmd_input_chunk_p_ctypes], mtmd_image_tokens_p_ctypes) +def mtmd_input_chunk_get_tokens_image(chunk: mtmd_input_chunk_p) -> mtmd_image_tokens_p: ... +# MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk); +@ctypes_function_mtmd("mtmd_input_chunk_get_n_tokens", [mtmd_input_chunk_p_ctypes], c_size_t) +def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p) -> c_size_t: + ... + +# // returns nullptr for ID on text chunk +# MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk); +@ctypes_function_mtmd("mtmd_input_chunk_get_id", [mtmd_input_chunk_p_ctypes], c_char_p) +def mtmd_input_chunk_get_id(chunk: mtmd_input_chunk_p) -> c_char_p: + """ + returns nullptr for ID on text chunk + """ + ... 
+ +# // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise) +# MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk); +@ctypes_function_mtmd("mtmd_input_chunk_get_n_pos", [mtmd_input_chunk_p_ctypes], c_int32) +def mtmd_input_chunk_get_n_pos(chunk: mtmd_input_chunk_p) -> c_int32: + """ + number of temporal positions (always 1 for M-RoPE, n_tokens otherwise) + """ + ... + +# // in case you want to use custom logic to handle the chunk (i.e. KV cache management) +# // you can move the chunk ownership to your own code by copying it +# // remember to free the chunk when you are done with it +# MTMD_API mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk); +@ctypes_function_mtmd("mtmd_input_chunk_copy", [mtmd_input_chunk_p_ctypes], mtmd_input_chunk_p_ctypes) +def mtmd_input_chunk_copy(chunk: mtmd_input_chunk_p) -> mtmd_input_chunk_p: + """ + in case you want to use custom logic to handle the chunk (i.e. KV cache management) + you can move the chunk ownership to your own code by copying it + remember to free the chunk when you are done with it + """ + ... + +# MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk); +@ctypes_function_mtmd("mtmd_input_chunk_free", [mtmd_input_chunk_p_ctypes], None) +def mtmd_input_chunk_free(chunk: mtmd_input_chunk_p): + """ + remember to free the chunk when you are done with it + """ + ... + + +# // mtmd_image_tokens +# // +# // the instance will be constructed via mtmd_tokenize() +# // it will be freed along with mtmd_input_chunk +# MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate +@ctypes_function_mtmd( + "mtmd_image_tokens_get_n_tokens", [mtmd_image_tokens_p_ctypes], c_size_t) +def mtmd_image_tokens_get_n_tokens(image_tokens: mtmd_image_tokens_p) -> c_size_t: + ... + +# MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens); +@ctypes_function_mtmd( + "mtmd_image_tokens_get_nx", [mtmd_image_tokens_p_ctypes], c_size_t) +def mtmd_image_tokens_get_nx(image_tokens: mtmd_image_tokens_p) -> c_size_t: + ... + +# MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens); +@ctypes_function_mtmd( + "mtmd_image_tokens_get_ny", [mtmd_image_tokens_p_ctypes], c_size_t) +def mtmd_image_tokens_get_ny(image_tokens: mtmd_image_tokens_p) -> c_size_t: + ... + +# MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate +@ctypes_function_mtmd( + "mtmd_image_tokens_get_id", [mtmd_image_tokens_p_ctypes], c_char_p) +def mtmd_image_tokens_get_id(image_tokens: mtmd_image_tokens_p) -> c_char_p: + ... + +# // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise) +# MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate +@ctypes_function_mtmd( + "mtmd_image_tokens_get_n_pos", [mtmd_image_tokens_p_ctypes], c_int32) +def mtmd_image_tokens_get_n_pos(image_tokens: mtmd_image_tokens_p) -> c_int32: + ... + +# // tokenize an input text prompt and a list of bitmaps (images/audio) +# // the prompt must have the input image marker (default: "<__media__>") in it +# // the default marker is defined by mtmd_default_marker() +# // the marker will be replaced with the image/audio chunk +# // for example: +# // "here is an image: <__media__>\ndescribe it in detail." +# // this will gives 3 chunks: +# // 1. "here is an image: " +# // 2. (image/audio tokens) +# // 3. "\ndescribe it in detail." 
+# // number of bitmaps must be equal to the number of markers in the prompt +# // this function is thread-safe (shared ctx) +# // return values: +# // 0 on success +# // 1 on number of bitmaps not matching the number of markers +# // 2 on image preprocessing error # MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx, # mtmd_input_chunks * output, # const mtmd_input_text * text, # const mtmd_bitmap ** bitmaps, # size_t n_bitmaps); -@ctypes_function( - "mtmd_tokenize", - [ +@ctypes_function_mtmd( + "mtmd_tokenize", [ mtmd_context_p_ctypes, mtmd_input_chunks_p_ctypes, - POINTER(mtmd_input_text), + mtmd_input_text_p_ctypes, POINTER(mtmd_bitmap_p_ctypes), - c_size_t, + c_uint, ], - c_int, + c_int32, ) def mtmd_tokenize( ctx: mtmd_context_p, output: mtmd_input_chunks_p, - text: "_Pointer[mtmd_input_text]", - bitmaps: CtypesArray[mtmd_bitmap_p_ctypes], - n_bitmaps: Union[c_size_t, int], + text: mtmd_input_text_p, + bitmaps: POINTER(mtmd_bitmap_p), + n_bitmaps: c_uint, + /, +) -> c_int32: + """ + tokenize an input text prompt and a list of bitmaps (images/audio) + the prompt must have the input image marker (default: "<__media__>") in it + the default marker is defined by mtmd_default_marker() + the marker will be replaced with the image/audio chunk + return values: + 0 on success + 1 on number of bitmaps not matching the number of markers + 2 on image preprocessing error + """ + ... + +# // returns 0 on success +# // TODO: deprecate +# MTMD_API int32_t mtmd_encode(mtmd_context * ctx, +# const mtmd_image_tokens * image_tokens); +@ctypes_function_mtmd( + "mtmd_encode", [ + mtmd_context_p_ctypes, + mtmd_image_tokens_p_ctypes + ], + c_int32, +) +def mtmd_encode( + ctx: mtmd_context_p, + image_tokens: mtmd_image_tokens_p, /, -) -> int: +) -> c_int32: ... -# MTMD_API size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk); -@ctypes_function("mtmd_input_chunk_get_n_tokens", [mtmd_input_chunk_p_ctypes], c_size_t) -def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p, /) -> int: + +# // returns 0 on success +# MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx, +# const mtmd_input_chunk * chunk); +@ctypes_function_mtmd( + "mtmd_encode_chunk", [ + mtmd_context_p_ctypes, + mtmd_input_chunk_p_ctypes + ], + c_int32, +) +def mtmd_encode_chunk( + ctx: mtmd_context_p, + chunk: mtmd_input_chunk_p, + /, +) -> c_int32: ... -# MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk); -@ctypes_function("mtmd_input_chunk_get_type", [mtmd_input_chunk_p_ctypes], c_int) -def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p, /) -> int: +# // get output embeddings from the last encode pass +# // the reading size (in bytes) is equal to: +# // llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float) +# MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); +@ctypes_function_mtmd( + "mtmd_get_output_embd", [mtmd_context_p_ctypes], POINTER(c_float)) +def mtmd_get_output_embd(ctx: mtmd_context_p) -> POINTER(c_float): + """ + get output embeddings from the last encode pass + """ ... 
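A sketch of the tokenize path documented above: one media marker in the prompt, one bitmap, and a returned chunk list. ctx and bitmap are assumed to come from mtmd_init_from_file and mtmd_bitmap_init (or the helpers further down).

import ctypes
import llama_cpp.mtmd_cpp as mtmd

def tokenize_with_one_image(ctx, bitmap, prompt: str):
    marker = mtmd.mtmd_default_marker()          # default media marker, "<__media__>"
    text = mtmd.mtmd_input_text(
        text=(prompt + marker.decode("utf-8")).encode("utf-8"),
        add_special=True,
        parse_special=True,
    )
    chunks = mtmd.mtmd_input_chunks_init()
    bitmaps = (mtmd.mtmd_bitmap_p_ctypes * 1)(bitmap)
    ret = mtmd.mtmd_tokenize(ctx, chunks, ctypes.byref(text), bitmaps, 1)
    if ret != 0:
        # 1: bitmap count does not match marker count, 2: image preprocessing error
        raise RuntimeError(f"mtmd_tokenize failed with code {ret}")
    n_chunks = mtmd.mtmd_input_chunks_size(chunks)
    return chunks, n_chunks         # release with mtmd.mtmd_input_chunks_free(chunks)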
-# MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output); -@ctypes_function( - "mtmd_input_chunk_get_tokens_text", - [mtmd_input_chunk_p_ctypes, POINTER(c_size_t)], - POINTER(llama_cpp.llama_token) + +# // test function, to be used in test-mtmd-c-api.c +# MTMD_API mtmd_input_chunks * mtmd_test_create_input_chunks(void); +@ctypes_function_mtmd( + "mtmd_test_create_input_chunks", + [], + mtmd_input_chunk_p_ctypes, ) -def mtmd_input_chunk_get_tokens_text( - chunk: mtmd_input_chunk_p, n_tokens_output: "_Pointer[c_size_t]", / -) -> Optional["_Pointer[llama_cpp.llama_token]"]: +def mtmd_test_create_input_chunks() -> mtmd_input_chunk_p: ... -################################################ -# mtmd-helper.h functions -################################################ +# // +# // libmtmd helper functions +# // +# // Please note that these helpers are not guaranteed to be stable. +# // BREAKING CHANGES are expected. +# // + + +# // helper function to construct a mtmd_bitmap from a file +# // it calls mtmd_helper_bitmap_init_from_buf() internally +# // returns nullptr on failure +# // this function is thread-safe +# MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname); + +@ctypes_function_mtmd( + "mtmd_helper_bitmap_init_from_file", [mtmd_context_p_ctypes, c_char_p], mtmd_bitmap_p_ctypes) +def mtmd_helper_bitmap_init_from_file(ctx: mtmd_context_p, fname: c_char_p) -> mtmd_bitmap_p: + """ + helper function to construct a mtmd_bitmap from a file + it calls mtmd_helper_bitmap_init_from_buf() internally + returns nullptr on failure + """ + ... + + +# // helper function to construct a mtmd_bitmap from a buffer containing a file +# // supported formats: +# // image: formats supported by stb_image: jpg, png, bmp, gif, etc. +# // audio: formats supported by miniaudio: wav, mp3, flac +# // note: audio files will be auto-detected based on magic bytes +# // returns nullptr on failure +# // this function is thread-safe # MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len); -@ctypes_function( - "mtmd_helper_bitmap_init_from_buf", - [mtmd_context_p_ctypes, POINTER(c_uint8), c_size_t], - mtmd_bitmap_p_ctypes -) +@ctypes_function_mtmd( + "mtmd_helper_bitmap_init_from_buf", [mtmd_context_p_ctypes, POINTER(c_uint8), c_size_t], mtmd_bitmap_p_ctypes) def mtmd_helper_bitmap_init_from_buf( ctx: mtmd_context_p, buf: CtypesArray[c_uint8], - length: Union[c_size_t, int], + len: c_size_t, /, -) -> Optional[mtmd_bitmap_p]: +) -> mtmd_bitmap_p: + """ + helper function to construct a mtmd_bitmap from a buffer containing a file + supported formats: + image: formats supported by stb_image: jpg, png, bmp, gif, etc. + audio: formats supported by miniaudio: wav, mp3, flac + note: audio files will be auto-detected based on magic bytes + returns nullptr on failure + """ ... 
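The buffer-based helper above covers media that is already in memory; a minimal sketch, assuming ctx is an mtmd_context handle.

import ctypes
import llama_cpp.mtmd_cpp as mtmd

def bitmap_from_bytes(ctx, data: bytes):
    # decodes an in-memory image/audio file (jpg, png, wav, mp3, ...) into a bitmap
    buf = (ctypes.c_uint8 * len(data)).from_buffer_copy(data)
    bitmap = mtmd.mtmd_helper_bitmap_init_from_buf(ctx, buf, len(data))
    if not bitmap:
        raise RuntimeError("unsupported or corrupt media buffer")
    return bitmap                   # release with mtmd.mtmd_bitmap_free(bitmap)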
+
+# // helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
 # MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
-@ctypes_function("mtmd_helper_get_n_tokens", [mtmd_input_chunks_p_ctypes], c_size_t)
-def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int:
+@ctypes_function_mtmd(
+    "mtmd_helper_get_n_tokens", [mtmd_input_chunks_p_ctypes], c_size_t)
+def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p) -> c_size_t:
+    """
+    helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
+    """
     ...
+
+# // helper to count the total position of tokens from a list of chunks, useful to keep track of n_past
+# // normally, n_pos is equal to n_tokens, but for M-RoPE it is different
+# MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks);
+@ctypes_function_mtmd(
+    "mtmd_helper_get_n_pos", [mtmd_input_chunks_p_ctypes], c_int32)
+def mtmd_helper_get_n_pos(chunks: mtmd_input_chunks_p) -> c_int32:
+    """
+    helper to count the total position of tokens from a list of chunks, useful to keep track of n_past
+    normally, n_pos is equal to n_tokens, but for M-RoPE it is different
+    """
+    ...
+
+
+# // helper function that automatically:
+# // 1. run llama_decode() on text chunks
+# // 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
+# // if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error
+# // otherwise, returns 0 on success
+# // this function is NOT thread-safe
+# MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
+#                                          struct llama_context * lctx,
+#                                          const mtmd_input_chunks * chunks,
+#                                          llama_pos n_past,
+#                                          llama_seq_id seq_id,
+#                                          int32_t n_batch,
+#                                          bool logits_last,
+#                                          llama_pos * new_n_past);
+@ctypes_function_mtmd(
+    "mtmd_helper_eval_chunks", [
+        mtmd_context_p_ctypes,
+        llama_cpp.llama_context_p_ctypes,
+        mtmd_input_chunks_p_ctypes,
+        c_int32,
+        c_int32,
+        c_int32,
+        c_bool,
+        POINTER(c_int32),
+    ],
+    c_int32)
+def mtmd_helper_eval_chunks(
+    ctx: mtmd_context_p,
+    lctx: llama_cpp.llama_context_p,
+    chunks: mtmd_input_chunks_p,
+    n_past: c_int32,
+    seq_id: c_int32,
+    n_batch: c_int32,
+    logits_last: c_bool,
+    new_n_past: POINTER(c_int32),
+    /,
+) -> c_int32:
+    """
+    helper function that automatically:
+    1. run llama_decode() on text chunks
+    2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
+    if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error
+    otherwise, returns 0 on success
+    """
+    ...
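And a sketch of feeding a tokenized chunk list into a llama context with the helper above; mtmd_ctx, llama_ctx and chunks are assumed to come from the earlier calls, and the return value is the updated n_past.

import ctypes
import llama_cpp.mtmd_cpp as mtmd

def eval_chunks(mtmd_ctx, llama_ctx, chunks, n_past: int = 0, seq_id: int = 0, n_batch: int = 512):
    new_n_past = ctypes.c_int32(0)
    ret = mtmd.mtmd_helper_eval_chunks(
        mtmd_ctx, llama_ctx, chunks,
        n_past, seq_id, n_batch,
        True,                       # request logits for the last token
        ctypes.byref(new_n_past),
    )
    if ret != 0:
        raise RuntimeError(f"mtmd_helper_eval_chunks failed with code {ret}")
    return new_n_past.value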
+ + +# // works like mtmd_helper_eval_chunks(), but only for a single chunk +# // this function is NOT thread-safe # MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, # struct llama_context * lctx, # const mtmd_input_chunk * chunk, @@ -252,29 +708,72 @@ def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int: # int32_t n_batch, # bool logits_last, # llama_pos * new_n_past); -@ctypes_function( - "mtmd_helper_eval_chunk_single", - [ +@ctypes_function_mtmd( + "mtmd_helper_eval_chunk_single", [ mtmd_context_p_ctypes, llama_cpp.llama_context_p_ctypes, mtmd_input_chunk_p_ctypes, - llama_cpp.llama_pos, - llama_cpp.llama_seq_id, - c_int, + c_int32, + c_int32, + c_int32, c_bool, - POINTER(llama_cpp.llama_pos), + POINTER(c_int32), ], - c_int, -) + c_int32) def mtmd_helper_eval_chunk_single( ctx: mtmd_context_p, lctx: llama_cpp.llama_context_p, - chunk: mtmd_input_chunk_p, - n_past: llama_cpp.llama_pos, - seq_id: llama_cpp.llama_seq_id, - n_batch: Union[c_int, int], - logits_last: Union[c_bool, bool], - new_n_past: "_Pointer[llama_cpp.llama_pos]", + chunks: mtmd_input_chunk_p, + n_past: c_int32, + seq_id: c_int32, + n_batch: c_int32, + logits_last: c_bool, + new_n_past: POINTER(c_int32), + /, +) -> c_int32: + """ + works like mtmd_helper_eval_chunks(), but only for a single chunk + """ + ... + + +# // helper function to decode an image whose embeddings have already been calculated +# // this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention) +# // ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure +# MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx, +# struct llama_context * lctx, +# const mtmd_input_chunk * chunk, +# float * encoded_embd, +# llama_pos n_past, +# llama_seq_id seq_id, +# int32_t n_batch, +# llama_pos * new_n_past); +@ctypes_function_mtmd( + "mtmd_helper_decode_image_chunk", [ + mtmd_context_p_ctypes, + llama_cpp.llama_context_p_ctypes, + mtmd_input_chunk_p_ctypes, + POINTER(c_float), + c_int32, + c_int32, + c_int32, + POINTER(c_int32), + ], + c_int32) +def mtmd_helper_decode_image_chunk( + ctx: mtmd_context_p, + lctx: llama_cpp.llama_context_p, + chunks: mtmd_input_chunk_p, + encoded_embd: POINTER(c_float), + n_past: c_int32, + seq_id: c_int32, + n_batch: c_int32, + new_n_past: c_int32, /, -) -> int: +) -> c_int32: + """ + helper function to decode an image whose embeddings have already been calculated + this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention) + ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure + """ ... diff --git a/llama_cpp/raw.py b/llama_cpp/raw.py new file mode 100644 index 000000000..68c9c46fb --- /dev/null +++ b/llama_cpp/raw.py @@ -0,0 +1,271 @@ +"""Minimal transparent bindings over llama.cpp: no hidden logic, no token/header munging. + +This module provides LlamaRaw – a deliberately tiny, explicit wrapper around the +ctypes surface. It aims to (a) expose only direct llama.cpp operations and (b) +perform zero implicit mutation of user-provided data other than what the +underlying C API itself performs. + +Design principles: + 1. 1:1 parameter mapping to underlying llama.cpp defaults where possible. + 2. No automatic BOS/EOS insertion, space-prefix hacks, infill re‑ordering, + caching, seed advancement, repetition penalties, grammar, or processor chains. + 3. No Python‑side logits mutation. You receive raw logits exactly as produced. + 4. 
Resource ownership explicit: model + context + (optional) batch. User frees. + 5. Small, easy to audit code (< ~200 lines) and self‑documenting. + +Motivation: Allow advanced users / downstream frameworks to build their own +policies (prompt construction, sampling, caching, grammar) without fighting the +high‑level convenience layer. + +NOTE: High‑level `Llama` class remains unchanged and continues to provide its +richer feature set. This raw layer is opt‑in. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import ctypes +from typing import Iterable, List, Sequence, Optional + +import numpy as np + +import llama_cpp.llama_cpp as llama_cpp +import llama_cpp._internals as internals +from ._utils import suppress_stdout_stderr + +_backend_initialized = False + + +def _ensure_backend_initialized(): + global _backend_initialized + if not _backend_initialized: + with suppress_stdout_stderr(disable=True): # silent like upstream + llama_cpp.llama_backend_init() + _backend_initialized = True + + +@dataclass +class RawModelResources: + model: internals.LlamaModel + ctx: internals.LlamaContext + batch: internals.LlamaBatch + + def close(self): # explicit free + # contextlib.closing already handles destructors, but offer explicit hook + self.batch.close() + self.ctx.close() + self.model.close() + + +class LlamaRaw: + """Ultra‑thin wrapper exposing only primitive llama.cpp operations. + + Public methods intentionally mirror underlying C semantics. + """ + + def __init__( + self, + model_path: str, + *, + n_ctx: int = 512, + n_batch: int = 512, + seed: int = llama_cpp.LLAMA_DEFAULT_SEED, + n_gpu_layers: int = 0, + split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER, + main_gpu: int = 0, + tensor_split: Optional[Sequence[float]] = None, + vocab_only: bool = False, + use_mmap: bool = True, + use_mlock: bool = False, + kv_overrides: Optional[dict] = None, + logits_all: bool = False, + embedding: bool = False, + verbose: bool = False, + ) -> None: + _ensure_backend_initialized() + + self.verbose = verbose + self._seed = seed + self._logits_all = logits_all + self._n_batch = min(n_ctx, n_batch) + + # --- model params (direct copy, no extra logic) --- + mparams = llama_cpp.llama_model_default_params() + mparams.n_gpu_layers = 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers + mparams.split_mode = split_mode + mparams.main_gpu = main_gpu + if tensor_split is not None: + if len(tensor_split) > llama_cpp.LLAMA_MAX_DEVICES: + raise ValueError("tensor_split length exceeds LLAMA_MAX_DEVICES") + FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES + self._c_tensor_split = FloatArray(*tensor_split) # keep ref + mparams.tensor_split = self._c_tensor_split + mparams.vocab_only = vocab_only + mparams.use_mmap = use_mmap + mparams.use_mlock = use_mlock + self._kv_overrides = None + if kv_overrides: + # Direct translation – identical to high level but without abstraction. 
+ arr_len = len(kv_overrides) + 1 + overrides = (llama_cpp.llama_model_kv_override * arr_len)() + for i, (k, v) in enumerate(kv_overrides.items()): + overrides[i].key = k.encode("utf-8") + if isinstance(v, bool): + overrides[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL + overrides[i].value.val_bool = v + elif isinstance(v, int): + overrides[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT + overrides[i].value.val_i64 = v + elif isinstance(v, float): + overrides[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT + overrides[i].value.val_f64 = v + elif isinstance(v, str): + vb = v.encode("utf-8") + if len(vb) > 128: + raise ValueError("kv_override str too long (max 128)") + vb = vb.ljust(128, b"\0") + overrides[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR + # Directly copy into the union's byte array region + addr = ctypes.addressof(overrides[i].value) + ctypes.memmove(addr, vb, 128) + else: + raise TypeError(f"Unsupported kv_override type for {k}: {type(v)}") + overrides[-1].key = b"\0" + mparams.kv_overrides = overrides + self._kv_overrides = overrides # keep ref + + # --- context params --- + cparams = llama_cpp.llama_context_default_params() + cparams.n_ctx = n_ctx + cparams.n_batch = self._n_batch + cparams.n_ubatch = self._n_batch + cparams.embeddings = embedding + cparams.seed = seed # if supported in version, else ignored silently + if logits_all: + cparams.logits_all = 1 # some versions may ignore; harmless + + # Load model + context + self._model = internals.LlamaModel( + path_model=model_path, params=mparams, verbose=verbose + ) + self._ctx = internals.LlamaContext( + model=self._model, params=cparams, verbose=verbose + ) + self._batch = internals.LlamaBatch( + n_tokens=self._n_batch, embd=0, n_seq_max=cparams.n_ctx, verbose=verbose + ) + + self._n_vocab = self.n_vocab() + self._n_ctx = self.n_ctx() + + # ---------------------- primitive queries ---------------------- + def n_vocab(self) -> int: # direct + return self._model.n_vocab() + + def n_ctx(self) -> int: # direct + return self._ctx.n_ctx() + + def n_embd(self) -> int: + return self._model.n_embd() + + # ---------------------- tokenization (raw) ---------------------- + def tokenize( + self, text: bytes, add_bos: bool = True, parse_special: bool = False + ) -> List[int]: + # No prefix space, no BOS/EOS stripping beyond requested add_bos. + # We ask model for max potential tokens (heuristic: len(text)+8) like upstream guidance. + max_tokens = len(text) + 8 + arr = (llama_cpp.llama_token * max_tokens)() + # llama_tokenize expects a model pointer + n = llama_cpp.llama_tokenize( + self._model.model, + text, + len(text), + arr, + max_tokens, + add_bos, + parse_special, + ) + if n < 0: + raise RuntimeError("Tokenization failed (buffer too small?)") + return [arr[i] for i in range(n)] + + def detokenize(self, tokens: Sequence[int]) -> bytes: + # Straight concatenation of token pieces (no space heuristics) + parts: List[bytes] = [] + vocab = self._model.vocab + for t in tokens: + parts.append(llama_cpp.llama_token_get_text(vocab, t)) + return b"".join(parts) + + # ---------------------- evaluation ---------------------- + def eval(self, tokens: Sequence[int]): + # Evaluate tokens sequentially in mini-batches; no cache prefix tricks. 
+ for i in range(0, len(tokens), self._n_batch): + batch_slice = tokens[i : i + self._n_batch] + self._batch.set_batch(batch_slice, n_past=self._ctx.n_tokens(), logits_all=self._logits_all) # type: ignore[arg-type] + self._ctx.decode(self._batch) + + def get_logits(self) -> np.ndarray: + # Return view (copy to ensure immutability by caller) + shape = (self._n_vocab,) + logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=shape) + return logits.copy() + + # ---------------------- sampling ---------------------- + def sample_greedy(self) -> int: + # Pure greedy: argmax over last logits + logits = self.get_logits() + return int(int(np.argmax(logits))) + + # (Optional) user can implement their own top-k/p externally using returned logits. + + # ---------------------- embeddings ---------------------- + def get_embeddings(self): + if not self._ctx.ctx: # safety + raise RuntimeError("Context not initialized") + ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx) + if ptr is None: + raise RuntimeError("Embeddings not available (model not in embedding mode)") + n_embd = self.n_embd() + return ptr[:n_embd] + + # ---------------------- state I/O ---------------------- + def save_state_bytes(self) -> bytes: + size = llama_cpp.llama_get_state_size(self._ctx.ctx) + buf = (ctypes.c_uint8 * size)() + wrote = llama_cpp.llama_copy_state_data(self._ctx.ctx, buf) + if wrote != size: + raise RuntimeError("State copy size mismatch") + return bytes(buf) + + def load_state_bytes(self, data: bytes): + size = len(data) + buf = (ctypes.c_uint8 * size).from_buffer_copy(data) + wrote = llama_cpp.llama_set_state_data(self._ctx.ctx, buf) + if wrote != size: + raise RuntimeError("State restore size mismatch") + + # ---------------------- RNG seed ---------------------- + def set_seed(self, seed: int): + # Only updates internal tracking; llama.cpp global sampler seeds are separate when using custom sampling. 
+ self._seed = seed + + # ---------------------- teardown ---------------------- + def close(self): + self._batch.close() + self._ctx.close() + self._model.close() + + # Support context manager + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + def __del__(self): # safety – explicit free + try: + self.close() + except Exception: + pass diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index 11bd363b5..670a250bc 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -171,6 +171,20 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: chat_handler = llama_cpp.llama_chat_format.MiniCPMv26ChatHandler( clip_model_path=settings.clip_model_path, verbose=settings.verbose ) + elif settings.chat_format == "gemma3": + assert settings.clip_model_path is not None, "clip model not found" + if settings.hf_model_repo_id is not None: + chat_handler = ( + llama_cpp.llama_chat_format.Gemma3ChatHandler.from_pretrained( + repo_id=settings.hf_model_repo_id, + filename=settings.clip_model_path, + verbose=settings.verbose, + ) + ) + else: + chat_handler = llama_cpp.llama_chat_format.Gemma3ChatHandler( + clip_model_path=settings.clip_model_path, verbose=settings.verbose + ) elif settings.chat_format == "qwen2.5-vl": assert settings.clip_model_path is not None, "clip model not found" if settings.hf_model_repo_id is not None: @@ -267,6 +281,9 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: n_threads=settings.n_threads, n_threads_batch=settings.n_threads_batch, rope_scaling_type=settings.rope_scaling_type, + pooling_type=settings.pooling_type, + attention_type=settings.attention_type, + flash_attn_type=settings.flash_attn_type, rope_freq_base=settings.rope_freq_base, rope_freq_scale=settings.rope_freq_scale, yarn_ext_factor=settings.yarn_ext_factor, @@ -278,7 +295,9 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: logits_all=settings.logits_all, embedding=settings.embedding, offload_kqv=settings.offload_kqv, - flash_attn=settings.flash_attn, + op_offload=settings.op_offload, + swa_full=settings.swa_full, + kv_unified=settings.kv_unified, # Sampling Params last_n_tokens_size=settings.last_n_tokens_size, # LoRA Params diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 13c951241..e15bedde2 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -84,7 +84,20 @@ class ModelSettings(BaseSettings): description="The number of threads to use when batch processing. 
Use -1 for max cpu threads", ) rope_scaling_type: int = Field( - default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED + default=llama_cpp.llama_rope_scaling_type.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, + description="RoPE scaling type, from `enum llama_rope_scaling_type", + ) + pooling_type: int = Field( + default=llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED, + description="whether to pool (sum) embedding results by sequence id", + ) + attention_type: int = Field( + default=llama_cpp.llama_attention_type.LLAMA_ATTENTION_TYPE_UNSPECIFIED, + description="attention type to use for embeddings", + ) + flash_attn_type: int = Field( + default=llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_AUTO, + description="when to enable Flash Attention", ) rope_freq_base: float = Field(default=0.0, description="RoPE base frequency") rope_freq_scale: float = Field( @@ -103,8 +116,14 @@ class ModelSettings(BaseSettings): offload_kqv: bool = Field( default=True, description="Whether to offload kqv to the GPU." ) - flash_attn: bool = Field( - default=False, description="Whether to use flash attention." + op_offload: bool = Field( + default=True, description="Whether to offload host tensor operations to device" + ) + swa_full: bool = Field( + default=True, description="Whether to use full-size SWA cache" + ) + kv_unified: bool = Field( + default=True, description="enable single unified KV buffer for the KV cache of all sequences" ) # Sampling Params last_n_tokens_size: int = Field( diff --git a/scripts/debug_tokens.py b/scripts/debug_tokens.py new file mode 100644 index 000000000..c8b5177a0 --- /dev/null +++ b/scripts/debug_tokens.py @@ -0,0 +1,97 @@ +import os, multiprocessing +from huggingface_hub import hf_hub_download +import llama_cpp, llama_cpp._internals as internals + +repo_id = "Qwen/Qwen2-0.5B-Instruct-GGUF" +filename = "qwen2-0_5b-instruct-q8_0.gguf" +model_path = hf_hub_download(repo_id, filename) + +PROMPT = b"The quick brown fox jumps" +seed = 1337 + +# Low-level setup (mirrors test_real_model) +params = llama_cpp.llama_model_default_params() +params.use_mmap = llama_cpp.llama_supports_mmap() +params.use_mlock = False +params.check_tensors = False +model = internals.LlamaModel(path_model=model_path, params=params) + +cparams = llama_cpp.llama_context_default_params() +cparams.n_ctx = 32 +cparams.n_batch = 32 +cparams.n_ubatch = 32 +cparams.n_threads = multiprocessing.cpu_count() +cparams.n_threads_batch = multiprocessing.cpu_count() +cparams.logits_all = False +cparams.flash_attn = True +context = internals.LlamaContext(model=model, params=cparams) + +low_tokens = model.tokenize(PROMPT, add_bos=True, special=True) +print("low_level_tokens", low_tokens) + +sampler = internals.LlamaSampler() +sampler.add_top_k(50) +sampler.add_top_p(0.9, 1) +sampler.add_temp(0.8) +sampler.add_dist(seed) + +# Evaluate prompt once (like test_real_model loop first iteration) +batch = internals.LlamaBatch(n_tokens=len(low_tokens), embd=0, n_seq_max=1) +batch.set_batch(low_tokens, n_past=0, logits_all=False) +context.decode(batch) + +# Grab logits pointer for last token of prompt (not exposed since logits_all False, so we re-run with True) +# Re-run with logits_all True to capture distribution for last prompt token +cparams2 = llama_cpp.llama_context_default_params() +for attr in ("n_ctx", "n_batch", "n_ubatch", "n_threads", "n_threads_batch"): + setattr(cparams2, attr, getattr(cparams, attr)) +cparams2.logits_all = True +context2 = internals.LlamaContext(model=model, params=cparams2) +batch2 = 
internals.LlamaBatch(n_tokens=len(low_tokens), embd=0, n_seq_max=1) +batch2.set_batch(low_tokens, n_past=0, logits_all=True) +context2.decode(batch2) +import numpy as np + +rows = len(low_tokens) +cols = model.n_vocab() +raw_logits = context2.get_logits() +logits = np.ctypeslib.as_array(raw_logits, shape=(rows * cols,)) +last_logits = logits[(rows - 1) * cols : rows * cols] +# Extract logits for tokens 31 and 916 +print( + "low_level_last_logits_selected", + {31: float(last_logits[31]), 916: float(last_logits[916])}, +) +# Top10 tokens +top_indices = np.argsort(last_logits)[-10:][::-1] +print("low_level_top10", list(map(int, top_indices))) + +# Now high-level API path +llama = llama_cpp.Llama( + model_path, n_ctx=32, n_batch=32, n_ubatch=32, logits_all=True, flash_attn=True +) +# Simulate _create_completion prompt tokenization logic +add_bos_flag = llama._model.add_bos_token() +hi_tokens_add = llama.tokenize(PROMPT, add_bos=True, special=True) +hi_tokens_no = llama.tokenize(PROMPT, add_bos=False, special=True) +print("high_level_add_bos_tokens", hi_tokens_add) +print("high_level_no_bos_tokens", hi_tokens_no) +constructed_prompt_tokens = ( + [] if (not add_bos_flag) else [llama.token_bos()] +) + hi_tokens_no +print("constructed_prompt_tokens", constructed_prompt_tokens) + +# Evaluate using llama.eval directly to capture logits for prompt +llama.eval(constructed_prompt_tokens) +if llama._logits_all: + prompt_last_logits = llama._scores[llama.n_tokens - 1, :] + print( + "high_level_prompt_last_logits_selected", + {31: float(prompt_last_logits[31]), 916: float(prompt_last_logits[916])}, + ) + top_hi = np.argsort(prompt_last_logits)[-10:][::-1] + print("high_level_prompt_top10", list(map(int, top_hi))) +else: + print("ERROR: logits_all not enabled for high-level eval") + +print("DONE") diff --git a/scripts/diagnostics_compare.py b/scripts/diagnostics_compare.py new file mode 100644 index 000000000..a3d79c526 --- /dev/null +++ b/scripts/diagnostics_compare.py @@ -0,0 +1,152 @@ +import multiprocessing +import numpy as np +from huggingface_hub import hf_hub_download +import llama_cpp +import llama_cpp._internals as internals + +# Diagnostic comparison between low-level and high-level first token sampling. +# +# Creates two low-level contexts (n_ctx=16,32) and mirrors the sampler chain used in +# tests/test_llama.py::test_real_model, then compares first sampled token and +# full 4-token continuation against the high-level API with identical parameters. 
+# +# Run: +# python scripts/diagnostics_compare.py +# Optional env: +# LLAMA_CPP_DEBUG_DET=1 (already used by high-level path) + +PROMPT = b"The quick brown fox jumps" +SEED = 1337 +REPO_ID = "Qwen/Qwen2-0.5B-Instruct-GGUF" +FILENAME = "qwen2-0_5b-instruct-q8_0.gguf" + +model_path = hf_hub_download(REPO_ID, FILENAME) + +params = llama_cpp.llama_model_default_params() +params.use_mmap = llama_cpp.llama_supports_mmap() +params.use_mlock = llama_cpp.llama_supports_mlock() +params.check_tensors = False +model = internals.LlamaModel(path_model=model_path, params=params) + +print("Low-level comparisons:") +low_level_logits = {} +for ctx_n in (16, 32): + cparams = llama_cpp.llama_context_default_params() + cparams.n_ctx = ctx_n + cparams.n_batch = ctx_n + cparams.n_ubatch = ctx_n + cparams.n_threads = multiprocessing.cpu_count() + cparams.n_threads_batch = multiprocessing.cpu_count() + cparams.logits_all = False + cparams.flash_attn = True + + context = internals.LlamaContext(model=model, params=cparams) + tokens = model.tokenize(PROMPT, add_bos=True, special=True) + batch = internals.LlamaBatch(n_tokens=len(tokens), embd=0, n_seq_max=1) + + sampler = internals.LlamaSampler() + sampler.add_top_k(50) + sampler.add_top_p(0.9, 1) + sampler.add_temp(0.8) + sampler.add_dist(SEED) + + toks = tokens.copy() + n_eval = 0 + result = tokens.copy() + first_tok = None + + # Step 0: evaluate prompt only, capture logits for last prompt token + batch.set_batch(toks, n_past=n_eval, logits_all=False) + context.decode(batch) + n_eval += len(toks) + ll_ptr = context.get_logits_ith(-1) + if ll_ptr is not None: + low_level_logits[ctx_n] = np.ctypeslib.as_array( + ll_ptr, shape=(model.n_vocab(),) + ).copy() + # Sample first token + token_id = sampler.sample(context, -1) + first_tok = token_id + toks = [token_id] + result.append(token_id) + + # Steps 1-3: sample remaining tokens (total 4 sampled tokens) + for _ in range(3): + batch.set_batch(toks, n_past=n_eval, logits_all=False) + context.decode(batch) + n_eval += len(toks) + token_id = sampler.sample(context, -1) + toks = [token_id] + result.append(token_id) + + out_text = model.detokenize(result[len(tokens) :], special=True).decode( + "utf-8", errors="ignore" + ) + print(f" n_ctx={ctx_n} first_token={first_tok} continuation='{out_text}'") + +print("\nHigh-level comparison (n_ctx=32):") +ll = llama_cpp.Llama( + model_path, + n_ctx=32, + n_batch=32, + n_ubatch=32, + logits_all=False, + flash_attn=True, + verbose=False, +) + +# Reproduce low-level prompt eval inside high-level object +prompt_tokens_hl = ll.tokenize(PROMPT, add_bos=True, special=True) +ll.reset() +ll.eval(prompt_tokens_hl) +hl_logits_ptr = ll._ctx.get_logits_ith( + -1 +) # noqa: SLF001 access internal for diagnostics +if hl_logits_ptr is not None: + hl_logits = np.ctypeslib.as_array( + hl_logits_ptr, shape=(ll._n_vocab,) + ).copy() # noqa: SLF001 + # Compare with low-level n_ctx=32 logits + ref = low_level_logits.get(32) + if ref is not None: + diff = np.max(np.abs(ref - hl_logits)) + print(f" Logits diff (low-level n_ctx=32 vs high-level) max_abs={diff:.6f}") + # Show top 5 indices where difference is largest + delta = np.abs(ref - hl_logits) + top_idx = np.argsort(delta)[-5:][::-1] + print( + " Top differing token ids:", [(int(i), float(delta[i])) for i in top_idx] + ) + + # Show probabilities for target tokens 31 and 916 after softmax for sanity + def softmax(x): + m = np.max(x) + e = np.exp(x - m) + return e / np.sum(e) + + ref_p = softmax(ref) + hl_p = softmax(hl_logits) + for tid in (31, 916): + 
print( + f" token {tid}: ref_logit={ref[tid]:.4f} hl_logit={hl_logits[tid]:.4f} ref_p={ref_p[tid]:.6f} hl_p={hl_p[tid]:.6f}" + ) + +resp = ll.create_completion( + "The quick brown fox jumps", + max_tokens=4, + top_k=50, + top_p=0.9, + temperature=0.8, + seed=SEED, + stream=False, +) +if isinstance(resp, dict): + print(" high_level_text=", resp["choices"][0]["text"]) +else: + # If somehow streaming iterator was returned, consume first + collected = None + for part in resp: + collected = part + if collected: + print(" high_level_text=", collected["choices"][0]["text"]) +print("Done.") diff --git a/tests/test_llama.py b/tests/test_llama.py index 0a1a9f5ad..2c72bb5f3 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -66,6 +66,7 @@ def llama_cpp_model_path(): def test_real_model(llama_cpp_model_path): import os + assert os.path.exists(llama_cpp_model_path) params = llama_cpp.llama_model_default_params() @@ -114,6 +115,7 @@ def test_real_model(llama_cpp_model_path): output_text = model.detokenize(output, special=True) assert output_text == b" over the lazy dog" + def test_real_llama(llama_cpp_model_path): model = llama_cpp.Llama( llama_cpp_model_path, @@ -132,11 +134,10 @@ def test_real_llama(llama_cpp_model_path): top_k=50, top_p=0.9, temperature=0.8, - seed=1337 + seed=1337, ) assert output["choices"][0]["text"] == " over the lazy dog" - output = model.create_completion( "The capital of france is paris, 'true' or 'false'?:\n", max_tokens=4, @@ -144,22 +145,23 @@ def test_real_llama(llama_cpp_model_path): top_p=0.9, temperature=0.8, seed=1337, - grammar=llama_cpp.LlamaGrammar.from_string(""" + grammar=llama_cpp.LlamaGrammar.from_string( + """ root ::= "true" | "false" -""") +""" + ), ) assert output["choices"][0]["text"] == "true" suffix = b"rot" tokens = model.tokenize(suffix, add_bos=True, special=True) + def logit_processor_func(input_ids, logits): for token in tokens: logits[token] *= 1000 return logits - logit_processors = llama_cpp.LogitsProcessorList( - [logit_processor_func] - ) + logit_processors = llama_cpp.LogitsProcessorList([logit_processor_func]) output = model.create_completion( "The capital of france is par", @@ -168,7 +170,7 @@ def logit_processor_func(input_ids, logits): top_p=0.9, temperature=0.8, seed=1337, - logits_processor=logit_processors + logits_processor=logit_processors, ) assert output["choices"][0]["text"].lower().startswith("rot") @@ -182,9 +184,11 @@ def logit_processor_func(input_ids, logits): top_k=50, top_p=0.9, temperature=0.8, - grammar=llama_cpp.LlamaGrammar.from_string(""" + grammar=llama_cpp.LlamaGrammar.from_string( + """ root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10" -""") +""" + ), ) number_1 = output["choices"][0]["text"] @@ -194,9 +198,11 @@ def logit_processor_func(input_ids, logits): top_k=50, top_p=0.9, temperature=0.8, - grammar=llama_cpp.LlamaGrammar.from_string(""" + grammar=llama_cpp.LlamaGrammar.from_string( + """ root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10" -""") +""" + ), ) number_2 = output["choices"][0]["text"] @@ -208,9 +214,11 @@ def logit_processor_func(input_ids, logits): top_k=50, top_p=0.9, temperature=0.8, - grammar=llama_cpp.LlamaGrammar.from_string(""" + grammar=llama_cpp.LlamaGrammar.from_string( + """ root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10" -""") +""" + ), ) number_3 = output["choices"][0]["text"] @@ -228,7 +236,7 @@ def test_real_llama_embeddings(llama_cpp_model_path): n_threads_batch=multiprocessing.cpu_count(), logits_all=False, 
flash_attn=True, - embedding=True + embedding=True, ) # Smoke test for now model.embed("Hello World") diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 4227c9be4..8ff206097 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 4227c9be4268ac844921b90f31595f81236bd317 +Subproject commit 8ff206097c2bf3ca1c7aa95f9d6db779fc7bdd68
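Finally, a usage sketch for the LlamaRaw wrapper added in llama_cpp/raw.py above; the model path is a placeholder, and decoding is plain greedy argmax since the raw layer deliberately leaves sampling policy to the caller.

from llama_cpp.raw import LlamaRaw

with LlamaRaw("./model.gguf", n_ctx=512) as llm:     # path is a placeholder
    prompt = llm.tokenize(b"The quick brown fox", add_bos=True)
    llm.eval(prompt)
    generated = []
    for _ in range(8):                               # greedy continuation, 8 tokens
        token = llm.sample_greedy()
        generated.append(token)
        llm.eval([token])
    print(llm.detokenize(generated))                 # raw concatenation of token pieces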