22 changes: 22 additions & 0 deletions llama_cpp/_internals.py
@@ -510,6 +510,14 @@ def __del__(self):
self._llama_batch_free(self.batch)
self.batch = None

def n_tokens(self) -> int:
assert self.batch is not None
return self.batch.n_tokens

def reset(self):
assert self.batch is not None
self.batch.n_tokens = 0

def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
assert self.batch is not None
n_tokens = len(batch)
@@ -522,6 +530,20 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
self.batch.logits[i] = logits_all
self.batch.logits[n_tokens - 1] = True

def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
assert self.batch is not None
n_tokens = len(batch)
n_tokens0 = self.batch.n_tokens
self.batch.n_tokens += n_tokens
for i in range(n_tokens):
j = n_tokens0 + i
self.batch.token[j] = batch[i]
self.batch.pos[j] = i
self.batch.seq_id[j][0] = seq_id
self.batch.n_seq_id[j] = 1
self.batch.logits[j] = logits_all
self.batch.logits[n_tokens - 1] = True


class _LlamaTokenDataArray:
def __init__(self, *, n_vocab: int):
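The new add_sequence helper is what enables batched embeddings: where set_batch appends one running sequence at an n_past offset, add_sequence packs several independent prompts into a single llama_batch, restarting positions at 0 for each prompt and tagging every token with that prompt's seq_id. Below is a minimal, purely illustrative sketch of that packing; it uses a plain-Python stand-in (MockBatch) rather than the real ctypes llama_batch struct, and the token ids are made up.

    from dataclasses import dataclass, field
    from typing import List, Sequence


    @dataclass
    class MockBatch:
        # Flat, batch-wide arrays, mirroring the layout llama_decode expects.
        token: List[int] = field(default_factory=list)
        pos: List[int] = field(default_factory=list)
        seq_id: List[int] = field(default_factory=list)
        logits: List[bool] = field(default_factory=list)

        def add_sequence(self, tokens: Sequence[int], seq_id: int, logits_all: bool) -> None:
            for i, tok in enumerate(tokens):
                self.token.append(tok)
                self.pos.append(i)            # positions restart at 0 for each prompt
                self.seq_id.append(seq_id)    # which prompt this token belongs to
                self.logits.append(logits_all)
            self.logits[-1] = True            # request output for the prompt's last token


    batch = MockBatch()
    batch.add_sequence([101, 7592, 2088], seq_id=0, logits_all=False)  # made-up token ids
    batch.add_sequence([101, 2129, 2024], seq_id=1, logits_all=False)
    print(batch.pos)     # [0, 1, 2, 0, 1, 2]
    print(batch.seq_id)  # [0, 0, 0, 1, 1, 1]

Because each prompt keeps its own seq_id, one llama_decode call can evaluate many prompts at once and the embeddings can be read back per sequence afterwards.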
135 changes: 101 additions & 34 deletions llama_cpp/llama.py
@@ -717,10 +717,53 @@ def create_embedding(
Returns:
An embedding object.
"""
assert self._ctx.ctx is not None
assert self._model.model is not None
model_name: str = model if model is not None else self.model_path

# get numeric embeddings
embeds: List[List[float]]
total_tokens: int
embeds, total_tokens = self.embed(input, return_count=True) # type: ignore

# convert to CreateEmbeddingResponse
data: List[Embedding] = [
{
"object": "embedding",
"embedding": emb,
"index": idx,
}
for idx, emb in enumerate(embeds)
]

return {
"object": "list",
"data": data,
"model": model_name,
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
}
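With this change create_embedding becomes a thin wrapper: it calls embed(input, return_count=True) and repackages the vectors into an OpenAI-style CreateEmbeddingResponse. A usage sketch, assuming a local GGUF model path (hypothetical) and a model constructed with embedding=True:

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf", embedding=True)  # hypothetical path

    resp = llm.create_embedding(["first sentence", "second sentence"])
    print(resp["object"])                      # "list"
    print(len(resp["data"]))                   # 2 -- one entry per input, in order
    print(resp["data"][0]["index"])            # 0
    print(len(resp["data"][0]["embedding"]))   # n_embd of the model
    print(resp["usage"]["total_tokens"])       # prompt tokens summed over all inputs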

def embed(
self,
input: Union[str, List[str]],
normalize: bool = True,
truncate: bool = True,
return_count: bool = False,
):
"""Embed a string.

Args:
input: The utf-8 encoded string to embed.

Returns:
A list of embeddings
"""
assert self._ctx.ctx is not None
n_embd = self.n_embd()
n_ctx = self.n_ctx()

if self.context_params.embedding == False:
raise RuntimeError(
"Llama model must be created with embedding=True to call this method"
@@ -734,48 +777,72 @@ def create_embedding(
else:
inputs = input

-        data: List[Embedding] = []
# reset batch
self._batch.reset()

# decode and fetch embeddings
data: List[List[float]] = []
def decode_batch(sizes: List[int]):
assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
self._ctx.decode(self._batch)
self._batch.reset()

# store embeddings
for i, s in enumerate(sizes):
embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
:n_embd
]
norm = np.linalg.norm(embedding) if normalize else s
embedding: List[float] = [v / float(norm) for v in embedding]
data.append(embedding)

-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]
-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )

        # init state
        total_tokens = 0
        t_batch = 0
        s_sizes: List[int] = []

        # accumulate batches and encode
        for text in inputs:
            tokens = self.tokenize(text.encode("utf-8"))
            if truncate:
                tokens = tokens[:n_ctx]

            n_tokens = len(tokens)
            total_tokens += n_tokens

            # check for overrun
            if n_tokens > n_ctx:
                raise ValueError(
                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
                )

# time to eval batch
if t_batch + n_tokens > self._n_ctx:
decode_batch(s_sizes)
t_batch = 0
s_sizes = []

# add to batch
self._batch.add_sequence(tokens, len(s_sizes), False)
Review comment (Contributor):
@iamlemec There is a null pointer access in this function if n_batch < n_ctx and the prompt exceeds n_batch.

Reply (Author):
Yup! Check out #1194. It changes the bounds checking from n_ctx to n_batch.

t_batch += n_tokens
s_sizes.append(n_tokens)

# handle last batch
decode_batch(s_sizes)

if self.verbose:
llama_cpp.llama_print_timings(self._ctx.ctx)

-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
-
-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
-
-        Args:
-            input: The utf-8 encoded string to embed.
-
-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))

        output = data[0] if isinstance(input, str) else data

        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
        self.reset()

        if return_count:
            return output, total_tokens
        else:
            return output
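The reworked embed accepts either a single string or a list of strings, and return_count=True additionally reports the token total that create_embedding uses for its usage block; the method also clears the KV cache before and after decoding, since embedding batches are independent of any generation state. A usage sketch (model path hypothetical):

    import numpy as np
    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf", embedding=True)  # hypothetical path

    vec = llm.embed("hello world")             # str in  -> one List[float] out
    vecs = llm.embed(["hello", "goodbye"])     # list in -> list of embeddings out
    vecs, n_tokens = llm.embed(["hello", "goodbye"], return_count=True)

    # With normalize=True (the default) each vector has unit length, so cosine
    # similarity reduces to a dot product.
    a, b = llm.embed(["cat", "kitten"])
    print(float(np.dot(a, b)))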

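The review thread above flags a failure mode in this loop: the truncation and flush logic is bounded by n_ctx, but the llama_batch buffers only hold n_batch tokens, so when n_batch < n_ctx a prompt longer than n_batch slips through and add_sequence writes past the allocated arrays. #1194 changes the bounds checking to n_batch. A small sketch with made-up sizes showing why the n_ctx comparison misses the overflow:

    n_ctx, n_batch = 2048, 512       # hypothetical sizes; the batch buffers hold n_batch tokens
    t_batch, n_tokens = 0, 900       # a 900-token prompt: within n_ctx, but not within n_batch

    print(t_batch + n_tokens > n_ctx)    # False -> this PR's check does not flag anything
    print(t_batch + n_tokens > n_batch)  # True  -> bounding by n_batch (as in #1194) catches it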
def _create_completion(
self,