Closed as not planned
Labels
stale (Stale issue will be autoclosed soon)
Description
Hi,
I'd like to present the results of refactoring the StatelessExecutor code.
The current constructor looks strange and suboptimal, since it creates a full context (which can be many gigabytes) only to dispose it immediately:
Context = _weights.CreateContext(_params, logger);
Context.Dispose();

public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
    Images = [];
    _weights = weights;
    _params = @params;
    _logger = logger;
    _batch = new LLamaBatch();
    Context = _weights.CreateContext(_params, logger);
    Context.Dispose();
}
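To illustrate the cost (a rough sketch, not taken from the repository samples; the model path and ContextSize value are placeholders):

using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf") { ContextSize = 8192 };
using var weights = LLamaWeights.LoadFromFile(parameters);

// With the current constructor, a context sized for the full ContextSize
// (potentially gigabytes of KV cache) is allocated here and immediately
// disposed again, purely as a side effect of construction.
var executor = new StatelessExecutor(weights, parameters);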
I don't claim this is exactly right, but I use this code and it works:
public class StatelessExecutor : ILLamaExecutor
{
    private readonly LLamaWeights _weights;
    private readonly IContextParams _params;
    private readonly ILogger? _logger;
    private readonly LLamaBatch _batch;

    // LLava Section
    public bool IsMultiModal => false;

    /// <inheritdoc />
    public LLavaWeights ClipModel => default;

    /// <inheritdoc />
    public List<byte[]> Images { get; }

    /// <summary>
    /// The context used by the executor when running the inference.
    /// </summary>
    public LLamaContext Context { get; private set; }

    /// <summary>
    /// If true, applies the default template to the prompt as defined in the rules for <a href="https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template">llama_chat_apply_template</a> template.
    /// </summary>
    public bool ApplyTemplate { get; init; }

    /// <summary>
    /// The system message to use with the prompt. Only used when <see cref="ApplyTemplate" /> is true.
    /// </summary>
    public string SystemMessage { get; init; }

    /// <summary>
    /// Create a new stateless executor which will use the given model
    /// </summary>
    /// <param name="weights"></param>
    /// <param name="params"></param>
    /// <param name="logger"></param>
    public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        Images = [];
        _weights = weights;
        _params = @params;
        _logger = logger;
        _batch = new LLamaBatch();
        // Note: no context is created here; it is created lazily in InferAsync.
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams inferenceParams, [EnumeratorCancellation] CancellationToken token = default)
    {
        try
        {
            // Create an inference context which will be disposed when this method exits.
            // It is also assigned to the Context property so callers can observe it while inference runs.
            using var context = _weights.CreateContext(_params, _logger);
            Context = context;

            // Reset the sampling pipeline (if there is one)
            inferenceParams?.SamplingPipeline?.Reset();

            // Sanity check inference params
            inferenceParams ??= new InferenceParams();
            if (inferenceParams.TokensKeep > context.ContextSize)
                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");

            // Create decoders for the token stream
            var decoder = new StreamingTokenDecoder(context);
            var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);

            if (ApplyTemplate)
            {
                var template = new LLamaTemplate(_weights.NativeHandle) { AddAssistant = true };
                if (SystemMessage != null) template.Add("system", SystemMessage);
                template.Add("user", text);
                text = PromptTemplateTransformer.ToModelPrompt(template);
            }

            // Tokenize the prompt
            var tokens = context.Tokenize(text: text, addBos: true, special: true).ToList();

            // Evaluate the prompt, in chunks smaller than the max batch size
            var n_past = 0;
            var (r, _, past) = await context.DecodeAsync(tokens, LLamaSeqId.Zero, _batch, n_past);
            n_past = past;
            if (r != DecodeResult.Ok)
                throw new LLamaDecodeError(r);

            // Begin loop, evaluating one token at a time
            var maxTokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for (var i = 0; i < maxTokens && !token.IsCancellationRequested; i++)
            {
                // Sample with the pipeline
                var id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, _batch.TokenCount - 1);

                // Check if this token should end generation
                if (id.IsEndOfGeneration(_weights.Vocab))
                    break;

                // Decode this token into text
                decoder.Add(id);
                var decoded = decoder.Read();
                yield return decoded;

                // Check if any of the antiprompts have been generated
                if (antiprocessor.Add(decoded))
                    break;

                tokens.Clear();
                tokens.Add(id);

                // When we run out of context,
                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                if (n_past + tokens.Count >= context.ContextSize)
                {
                    var canAddBos = context.Vocab.ShouldAddBOS;
                    var tokensKeep = inferenceParams.TokensKeep;

                    // Number of tokens to keep when resetting context.
                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
                    if (tokensKeep < 0 || tokensKeep > tokens.Count)
                    {
                        tokensKeep = tokens.Count;
                    }
                    else
                    {
                        tokensKeep += Convert.ToInt32(canAddBos);
                    }

                    var n_left = n_past - tokensKeep;
                    var n_discard = n_left / 2;

                    NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, LLamaSeqId.Zero, tokensKeep, tokensKeep + n_discard);
                    NativeApi.llama_kv_cache_seq_add(context.NativeHandle, LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard);

                    n_past -= n_discard;
                }

                // Evaluate with this new token
                _batch.Clear();
                _batch.Add(id, n_past++, LLamaSeqId.Zero, true);

                var returnCode = await context.DecodeAsync(_batch, token);
                if (returnCode != 0)
                    throw new LLamaDecodeError(returnCode);
            }
        }
        finally
        {
            Context = null;
        }
    }
}
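For completeness, here is roughly how I consume the refactored executor (a minimal sketch; the path, ContextSize and sampling settings are placeholders, and I set the sampling pipeline explicitly since InferAsync uses it unconditionally):

using System;
using LLama;
using LLama.Common;
using LLama.Sampling;

var parameters = new ModelParams("model.gguf") { ContextSize = 4096 };
using var weights = LLamaWeights.LoadFromFile(parameters);

// No context exists yet; it is created inside InferAsync and disposed
// when the enumeration completes, so memory is only held while inferring.
var executor = new StatelessExecutor(weights, parameters);

var inferenceParams = new InferenceParams
{
    MaxTokens = 64,
    SamplingPipeline = new DefaultSamplingPipeline(),
};

await foreach (var piece in executor.InferAsync("Hello!", inferenceParams))
{
    Console.Write(piece);
}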