From 9931d0ef60572a4605018878141f4d33de9dc56b Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sat, 27 Sep 2025 07:52:40 +0200 Subject: [PATCH 1/5] Mtmd Implementation base --- LLama.Examples/ExampleRunner.cs | 4 +- .../Examples/BatchedExecutorLLava.cs | 91 ----- .../Examples/BatchedExecutorMtmd.cs | 126 +++++++ ...ecute.cs => MtmdInteractiveModeExecute.cs} | 83 +++-- LLama.Examples/LLama.Examples.csproj | 2 +- LLama.Unittest/Constants.cs | 6 +- LLama.Unittest/LLama.Unittest.csproj | 16 +- LLama.Unittest/MtmdExecutorTests.cs | 81 ++++ LLama.Unittest/MtmdWeightsTests.cs | 140 +++++++ .../Native/SafeLlamaModelHandleTests.cs | 32 -- .../SafeLlamaModelHandleVocabularyTests.cs | 42 --- LLama/Abstractions/ILLamaExecutor.cs | 7 +- LLama/Batched/BatchedExecutor.cs | 67 +++- LLama/Batched/Conversation.cs | 242 ++++++++++-- LLama/Batched/ConversationExtensions.cs | 14 +- LLama/LLamaExecutorBase.cs | 135 ++++--- LLama/LLamaInstructExecutor.cs | 211 ++++++++++- LLama/LLamaInteractExecutor.cs | 311 ++++++++++++---- LLama/LLamaSharp.csproj | 4 +- LLama/LLamaStatelessExecutor.cs | 6 +- LLama/LLavaWeights.cs | 137 ------- LLama/Native/LLavaImageEmbed.cs | 19 - LLama/Native/Load/NativeLibraryConfig.cs | 32 +- LLama/Native/Load/NativeLibraryUtils.cs | 2 +- LLama/Native/MtmdContextParams.cs | 148 ++++++++ LLama/Native/MtmdImageEmbed.cs | 20 + LLama/Native/NativeApi.LLava.cs | 63 ---- LLama/Native/NativeApi.Load.cs | 22 +- LLama/Native/NativeApi.Mtmd.cs | 312 ++++++++++++++++ LLama/Native/NativeApi.cs | 119 +++++- LLama/Native/SafeLlavaImageEmbedHandle.cs | 162 -------- LLama/Native/SafeLlavaModelHandle.cs | 137 ------- LLama/Native/SafeMtmdEmbed.cs | 247 +++++++++++++ LLama/Native/SafeMtmdInputChunk.cs | 150 ++++++++ LLama/Native/SafeMtmdInputChunks.cs | 103 ++++++ LLama/Native/SafeMtmdModelHandle.cs | 349 ++++++++++++++++++ LLama/Properties/InternalsVisibleTo.cs | 3 + LLama/SafeMtmdWeights.cs | 80 ++++ docs/Examples/LLavaInteractiveModeExecute.md | 129 ------- docs/Examples/MtmdInteractiveModeExecute.md | 41 ++ mkdocs.yml | 4 +- 41 files changed, 2832 insertions(+), 1067 deletions(-) delete mode 100644 LLama.Examples/Examples/BatchedExecutorLLava.cs create mode 100644 LLama.Examples/Examples/BatchedExecutorMtmd.cs rename LLama.Examples/Examples/{LlavaInteractiveModeExecute.cs => MtmdInteractiveModeExecute.cs} (59%) create mode 100644 LLama.Unittest/MtmdExecutorTests.cs create mode 100644 LLama.Unittest/MtmdWeightsTests.cs delete mode 100644 LLama.Unittest/Native/SafeLlamaModelHandleTests.cs delete mode 100644 LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs delete mode 100644 LLama/LLavaWeights.cs delete mode 100644 LLama/Native/LLavaImageEmbed.cs create mode 100644 LLama/Native/MtmdContextParams.cs create mode 100644 LLama/Native/MtmdImageEmbed.cs delete mode 100644 LLama/Native/NativeApi.LLava.cs create mode 100644 LLama/Native/NativeApi.Mtmd.cs delete mode 100644 LLama/Native/SafeLlavaImageEmbedHandle.cs delete mode 100644 LLama/Native/SafeLlavaModelHandle.cs create mode 100644 LLama/Native/SafeMtmdEmbed.cs create mode 100644 LLama/Native/SafeMtmdInputChunk.cs create mode 100644 LLama/Native/SafeMtmdInputChunks.cs create mode 100644 LLama/Native/SafeMtmdModelHandle.cs create mode 100644 LLama/Properties/InternalsVisibleTo.cs create mode 100644 LLama/SafeMtmdWeights.cs delete mode 100644 docs/Examples/LLavaInteractiveModeExecute.md create mode 100644 docs/Examples/MtmdInteractiveModeExecute.md diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index 
c073cd4cd..23f07c6a1 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -15,7 +15,7 @@ public class ExampleRunner { "Chat Session: Automatic conversation", TalkToYourself.Run }, { "Chat Session: Chinese characters", ChatChineseGB2312.Run }, { "Executor: Interactive mode chat", InteractiveModeExecute.Run }, - { "Executor: Llava Interactive mode chat", LlavaInteractiveModeExecute.Run }, + { "Executor: Mtmd Interactive mode chat", MtmdInteractiveModeExecute.Run }, { "Executor: Instruct mode chat", InstructModeExecute.Run }, { "Executor: Stateless mode chat", StatelessModeExecute.Run }, { "Save and Load: chat session", SaveAndLoadSession.Run }, @@ -33,7 +33,7 @@ public class ExampleRunner { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run }, { "Batched Executor: Fork", BatchedExecutorFork.Run }, { "Batched Executor: Rewind", BatchedExecutorRewind.Run }, - { "Batched Executor: LLava", BatchedExecutorLLava.Run }, + { "Batched Executor: Mtmd", BatchedExecutorMtmd.Run }, { "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run }, { "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run }, { "Custom Sampling Pipeline", CustomSampler.Run }, diff --git a/LLama.Examples/Examples/BatchedExecutorLLava.cs b/LLama.Examples/Examples/BatchedExecutorLLava.cs deleted file mode 100644 index a131e994e..000000000 --- a/LLama.Examples/Examples/BatchedExecutorLLava.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System.Text; -using LLama.Batched; -using LLama.Common; -using LLama.Native; -using LLama.Sampling; -using Spectre.Console; - -namespace LLama.Examples.Examples; - -/// -/// Demonstrates using LLava (image embeddings) with the batched executor. -/// -public class BatchedExecutorLLava -{ - /// - /// How many tokens of response to generate - /// - public const int TokenCount = 64; - - public static async Task Run() - { - // Load model weights - var parameters = new ModelParams(UserSettings.GetModelPath()); - using var model = await LLamaWeights.LoadFromFileAsync(parameters); - using var llava = await LLavaWeights.LoadFromFileAsync(UserSettings.GetMMProjPath()); - - // Decide on the prompt - var prompt = model.Tokenize(AnsiConsole.Ask("Prompt (or ENTER for default):", "\nUSER: Provide a full description of the image.\nASSISTANT: "), true, false, Encoding.UTF8); - - // Get image and show it - var image = UserSettings.GetImagePath(); - AnsiConsole.Write(new CanvasImage(image)); - - // Create an executor with one conversation - using var executor = new BatchedExecutor(model, parameters); - using var conversation = executor.Create(); - - // Embed the image - SafeLlavaImageEmbedHandle embedding = null!; - await AnsiConsole - .Status() - .StartAsync("[yellow]Embedding image with CLIP[/]", async _ => - { - // ReSharper disable once AccessToDisposedClosure - embedding = llava.CreateImageEmbeddings(await File.ReadAllBytesAsync(image)); - }); - - // Pass in the image and run inference until the entire image has been processed - await AnsiConsole - .Status() - .StartAsync("[yellow]Processing image embedding with language model[/]", async _ => - { - conversation.Prompt(embedding); - while (executor.BatchedTokenCount > 0) - await executor.Infer(); - }); - - // Prompt with the text prompt - conversation.Prompt(prompt); - - // Run inference loop - var decoder = new StreamingTokenDecoder(executor.Context); - var sampler = new DefaultSamplingPipeline(); - await AnsiConsole - .Progress() - .StartAsync(async ctx => - { - var task = ctx.AddTask("Generating Response"); - 
task.MaxValue = TokenCount;
-
-                // Run a normal inference loop
-                for (var i = 0; i < TokenCount; i++)
-                {
-                    task.Increment(1);
-
-                    await executor.Infer();
-
-                    var token = sampler.Sample(executor.Context.NativeHandle, conversation.GetSampleIndex());
-                    if (token.IsEndOfGeneration(executor.Context.Vocab))
-                        break;
-
-                    decoder.Add(token);
-                    conversation.Prompt(token);
-                }
-            });
-
-        // Print final result
-        var str = decoder.Read();
-        AnsiConsole.MarkupInterpolated($"[green]{str}[/]");
-    }
-}
\ No newline at end of file
diff --git a/LLama.Examples/Examples/BatchedExecutorMtmd.cs b/LLama.Examples/Examples/BatchedExecutorMtmd.cs
new file mode 100644
index 000000000..b62f8b120
--- /dev/null
+++ b/LLama.Examples/Examples/BatchedExecutorMtmd.cs
@@ -0,0 +1,126 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using LLama.Batched;
+using LLama.Common;
+using LLama.Exceptions;
+using LLama.Native;
+using LLama.Sampling;
+using Spectre.Console;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// Demonstrates how to evaluate an image with MTMD helpers and continue generation by
+/// manually scheduling batches, similar to what the batched executor does internally.
+/// </summary>
+public class BatchedExecutorMtmd
+{
+    /// <summary>
+    /// Number of completion tokens to generate after sending the image prompt.
+    /// </summary>
+    public const int TokenCount = 10000;
+
+    public static async Task Run()
+    {
+        // Load the base LLM and its clip/mtmd sidecar weights so the executor has everything it needs.
+        var parameters = new ModelParams(UserSettings.GetModelPath());
+        using var model = await LLamaWeights.LoadFromFileAsync(parameters);
+        var mtmdParams = MtmdContextParams.Default(); // reuse llama.cpp defaults for helper settings
+        mtmdParams.UseGpu = false;
+        var marker = mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "";
+
+        using var mtmd = await SafeMtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams); // multimodal helper weights
+
+        using var executor = new BatchedExecutor(model, parameters, mtmd); // drives batched token + chunk evaluation
+
+        // Prepend the media marker so the helper knows where to inject the encoded image tokens.
+        var defaultPrompt = "\nUSER: Provide a full description of the image.\nASSISTANT: ";
+        var promptSuffix = AnsiConsole.Ask("Prompt (or ENTER for default):", defaultPrompt);
+        var promptText = string.Concat(marker, promptSuffix);
+
+        var imagePath = UserSettings.GetImagePath();
+        AnsiConsole.Write(new CanvasImage(imagePath));
+
+        var vocab = executor.Context.NativeHandle.ModelHandle.Vocab;
+
+        // Simple low-temperature sampler keeps the demo deterministic-ish.
+        var sampler = new DefaultSamplingPipeline
+        {
+            Temperature = 0.1f
+        };
+
+        // Stream decoded text to the console as soon as tokens arrive.
+        var decoder = new StreamingTokenDecoder(executor.Context)
+        {
+            DecodeSpecialTokens = false
+        };
+
+        try
+        {
+            // Each conversation tracks its own KV cache sequence IDs.
+            var conversation = executor.Create();
+            // enqueue the image so MtmdHelper sees it
+            conversation.QueueMedia(imagePath);
+            // schedule multimodal prompt
+            conversation.Prompt(promptText, addBos: true, special: true);
+
+            Console.ForegroundColor = ConsoleColor.Yellow;
+            Console.WriteLine("Prompt queued with multimodal chunks. Generating response...\n");
+            Console.ResetColor();
+
+            var remaining = TokenCount;
+
+            // Run one decode/sampling/prompt cycle – mirrors the batched executor inner loop.
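+            // Returns true while the demo loop should continue; false once the KV cache is full, EOS is sampled, or the token budget runs out.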
+            async Task<bool> ProcessNextAsync()
+            {
+                var decodeResult = await executor.Infer();
+                if (decodeResult == DecodeResult.NoKvSlot) // KV cache exhausted – surface to the user
+                {
+                    Console.ForegroundColor = ConsoleColor.Red;
+                    Console.WriteLine("Insufficient KV cache space for multimodal evaluation.");
+                    Console.ResetColor();
+                    return false;
+                }
+
+                if (decodeResult != DecodeResult.Ok)
+                    throw new RuntimeError($"Failed to evaluate batch: {decodeResult}.");
+
+                if (!conversation.RequiresSampling) // another conversation may still be queued
+                    return true;
+
+                var token = conversation.Sample(sampler); // pull logits (or -1 for mtmd chunk) and sample
+                if (token.IsEndOfGeneration(vocab))
+                    return false;
+
+                decoder.Add(token);
+                var delta = decoder.Read();
+                if (!string.IsNullOrEmpty(delta))
+                    Console.Write(delta);
+
+                sampler.Accept(token); // keep sampler state in sync
+                conversation.Prompt(token); // feed the accepted token back into the batch
+                remaining--;
+                return remaining > 0;
+            }
+
+            while (remaining > 0 && await ProcessNextAsync()) // continue until EOS or budget is reached
+            {
+            }
+
+            Console.WriteLine();
+        }
+        catch (IOException ex)
+        {
+            Console.ForegroundColor = ConsoleColor.Red;
+            Console.WriteLine($"Could not load media '{imagePath}': {ex.Message}");
+            Console.ResetColor();
+        }
+        catch (RuntimeError ex)
+        {
+            Console.ForegroundColor = ConsoleColor.Red;
+            Console.WriteLine($"MTMD processing failed: {ex.Message}");
+            Console.ResetColor();
+        }
+    }
+}
diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
similarity index 59%
rename from LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
rename to LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
index 8cbf58dcd..ca0de3b77 100644
--- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
+++ b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
@@ -1,3 +1,5 @@
+using System.Collections.Generic;
+using System.IO;
 using System.Text.RegularExpressions;
 using LLama.Common;
 using Spectre.Console;
@@ -6,27 +8,32 @@
 namespace LLama.Examples.Examples
 {
-    // This example shows how to chat with LLaVA model with both image and text as input.
+    // This example shows how to chat with an MTMD model using both image and text as input.
     // It uses the interactive executor to inference.
-    public class LlavaInteractiveModeExecute
+    public class MtmdInteractiveModeExecute
     {
         public static async Task Run()
         {
             string multiModalProj = UserSettings.GetMMProjPath();
             string modelPath = UserSettings.GetModelPath();
             string modelImage = UserSettings.GetImagePath();
-            const int maxTokens = 1024;
+            const int maxTokens = 2048;
 
             var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
 
             var parameters = new ModelParams(modelPath);
+            var mtmdParameters = MtmdContextParams.Default();
+            mtmdParameters.UseGpu = false;
+
             using var model = await LLamaWeights.LoadFromFileAsync(parameters);
             using var context = model.CreateContext(parameters);
-
-            // Llava Init
-            using var clipModel = await LLavaWeights.LoadFromFileAsync(multiModalProj);
-
+
+            // Mtmd Init
+            using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(multiModalProj, model, mtmdParameters );
+
+            var mediaMarker = mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "";
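+            // Each {image-path} placeholder in the prompt is later replaced with this marker so the MTMD tokenizer knows where the media belongs.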
""; + var ex = new InteractiveExecutor(context, clipModel); Console.ForegroundColor = ConsoleColor.Yellow; @@ -40,7 +47,7 @@ public static async Task Run() Temperature = 0.1f }, - AntiPrompts = new List { "\nUSER:" }, + AntiPrompts = new List { "\nASSISTANT:" }, MaxTokens = maxTokens }; @@ -48,30 +55,53 @@ public static async Task Run() do { - // Evaluate if we have images + // Evaluate if we have media // - var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imageCount = imageMatches.Count(); - var hasImages = imageCount > 0; + var mediaMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); + var mediaCount = mediaMatches.Count(); + var hasMedia = mediaCount > 0; - if (hasImages) + if (hasMedia) { - var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); + var mediaPathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); + var mediaPaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); - List imageBytes; + var embeds = new List(); + var imageList = new List(); + var imageExtensions = new HashSet(StringComparer.OrdinalIgnoreCase) + { + ".png", + ".jpg", + ".jpeg", + ".bmp", + ".gif", + ".webp" + }; + try { - imageBytes = imagePaths.Select(File.ReadAllBytes).ToList(); + foreach (var mediaPath in mediaPaths) + { + var extension = Path.GetExtension(mediaPath); + if (!string.IsNullOrEmpty(extension) && imageExtensions.Contains(extension)) + { + // Keep the raw image data so the caller can reuse or inspect the images later. + imageList.Add(File.ReadAllBytes(mediaPath)); + } + + var embed = clipModel.LoadMedia(mediaPath); + embeds.Add(embed); + } } catch (IOException exception) { Console.ForegroundColor = ConsoleColor.Red; Console.Write( - $"Could not load your {(imageCount == 1 ? "image" : "images")}:"); + $"Could not load your {(mediaCount == 1 ? "media" : "medias")}:"); Console.Write($"{exception.Message}"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Please try again."); + clipModel.ClearMedia(); break; } @@ -81,19 +111,17 @@ public static async Task Run() // https://github.com/ggerganov/llama.cpp/discussions/3620 ex.Context.NativeHandle.MemorySequenceRemove( LLamaSeqId.Zero, -1, -1 ); - int index = 0; - foreach (var path in imagePathsWithCurlyBraces) + // Replace placeholders with media markers (one marker per image) + foreach (var path in mediaPathsWithCurlyBraces) { - // First image replace to tag " : ""); + prompt = prompt.Replace(path, mediaMarker, StringComparison.Ordinal); } - Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message."); Console.WriteLine(); - foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes)) ?? 
Array.Empty()) + foreach (var consoleImage in imageList.Select(image => new CanvasImage(image.ToArray()))) { consoleImage.MaxWidth = 50; AnsiConsole.Write(consoleImage); @@ -108,10 +136,9 @@ public static async Task Run() // Initialize Images in executor // - foreach (var image in imagePaths) - { - ex.Images.Add(await File.ReadAllBytesAsync(image)); - } + ex.Embeds.Clear(); + foreach (var embed in embeds) + ex.Embeds.Add(embed); } Console.ForegroundColor = Color.White; diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index 330e77386..6ca8c7210 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -9,7 +9,7 @@ true true - 12 + 13 1701;1702;8604;SKEXP0001;SKEXP0050;SKEXP0052;SKEXP0003 diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs index d501b189b..f705f1609 100644 --- a/LLama.Unittest/Constants.cs +++ b/LLama.Unittest/Constants.cs @@ -9,9 +9,9 @@ internal static class Constants public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf"; public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf"; - public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf"; - public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf"; - public static readonly string LLavaImage = "Models/extreme-ironing-taxi-610x427.jpg"; + public static readonly string MtmdModelPath = "Models/gemma-3-4b-it-Q4_K_M.gguf"; + public static readonly string MtmdMmpPath = "Models/gemma-mmproj-model-f16.gguf"; + public static readonly string MtmdImage = "Models/extreme-ironing-taxi-610x427.jpg"; /// /// Calculate GpuLayer Count to use in UnitTest diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 8f9f075d8..ca3ea8854 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -52,16 +52,16 @@ jina-reranker-v1-tiny-en-FP16.gguf - - https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf + + https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_M.gguf Models - llava-v1.6-mistral-7b.Q3_K_XS.gguf + gemma-3-4b-it-Q4_K_M.gguf - - https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/mmproj-model-f16.gguf + + https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/mmproj-model-f16.gguf Models - mmproj-model-f16.gguf + gemma-mmproj-model-f16.gguf @@ -142,10 +142,10 @@ PreserveNewest - + PreserveNewest - + PreserveNewest diff --git a/LLama.Unittest/MtmdExecutorTests.cs b/LLama.Unittest/MtmdExecutorTests.cs new file mode 100644 index 000000000..75a96b261 --- /dev/null +++ b/LLama.Unittest/MtmdExecutorTests.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using LLama.Common; +using LLama.Native; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace LLama.Unittest; + +[Trait("Category", "NoCI")] +public class MtmdExecutorTests : IDisposable +{ + private readonly LLamaWeights _weights; + private readonly MtmdContextParams _mtmdParams; + private readonly SafeMtmdWeights _mtmd; + private readonly ModelParams _modelParams; + + public MtmdExecutorTests() + { + _modelParams = new ModelParams(Constants.MtmdModelPath) + { + ContextSize = 1024 * 8, + GpuLayerCount = Constants.CIGpuLayerCount, + }; + + _weights = LLamaWeights.LoadFromFile(_modelParams); + + _mtmdParams = 
MtmdContextParams.Default();
+        _mtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount);
+        _mtmdParams.UseGpu = false;
+
+        _mtmd = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams);
+    }
+
+    public void Dispose()
+    {
+        _mtmd.Dispose();
+        _weights.Dispose();
+    }
+
+    [Fact]
+    public async Task InteractiveExecutor_EvaluateChunks_DoesNotRetokenize()
+    {
+        using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
+        var executor = new InteractiveExecutor(context, _mtmd, NullLogger.Instance);
+        var marker = _mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "";
+        var prompt = $"{marker}\nDescribe the image succinctly.";
+
+        executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));
+
+        await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
+        {
+            Assert.True(false, "Prefill should not emit generated text");
+        }
+
+        var diagnostics = executor.GetDiagnostics();
+        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
+        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
+        Assert.Equal(0, diagnostics.PendingEmbedCount);
+    }
+
+    [Fact]
+    public async Task InstructExecutor_MtmdPromptAdvancesPastTokensOnce()
+    {
+        using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
+        var executor = new InstructExecutor(context, _mtmd, logger: NullLogger.Instance);
+        executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));
+
+        var prompt = $"{_mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""} Provide details.";
+
+        await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
+        {
+        }
+
+        var diagnostics = executor.GetDiagnostics();
+        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
+        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
+        Assert.Equal(0, diagnostics.PendingEmbedCount);
+    }
+}
diff --git a/LLama.Unittest/MtmdWeightsTests.cs b/LLama.Unittest/MtmdWeightsTests.cs
new file mode 100644
index 000000000..947bbd1ea
--- /dev/null
+++ b/LLama.Unittest/MtmdWeightsTests.cs
@@ -0,0 +1,140 @@
+using System;
+using System.IO;
+using LLama.Common;
+using LLama.Native;
+using Xunit;
+
+namespace LLama.Unittest
+{
+    // Test the same things as llama model + image embeddings
+    //
+    public sealed class MtmdWeightTests
+        : IDisposable
+    {
+        private readonly LLamaWeights _llamaWeights;
+        private readonly SafeMtmdWeights _safeMtmdWeights;
+        private readonly LLamaContext _context;
+        private readonly MtmdContextParams _mtmdParams;
+        private readonly string _mediaMarker;
+
+        public MtmdWeightTests()
+        {
+            var @params = new ModelParams(Constants.MtmdModelPath)
+            {
+                // MTMD models require a big context
+                ContextSize = 1024 * 32,
+                GpuLayerCount = Constants.CIGpuLayerCount,
+            };
+            _llamaWeights = LLamaWeights.LoadFromFile(@params);
+
+            _mtmdParams = MtmdContextParams.Default();
+            _mtmdParams.NThreads = Constants.CIGpuLayerCount;
+            _mtmdParams.UseGpu = false; // keep tests portable across environments without GPU
+
+            _mediaMarker = _mtmdParams.MediaMarker ??
throw new InvalidOperationException("MTMD media marker unavailable."); + + _safeMtmdWeights = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams); + _context = _llamaWeights.CreateContext(@params); + } + + public void Dispose() + { + _context.Dispose(); + _safeMtmdWeights.Dispose(); + _llamaWeights.Dispose(); + } + + private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) + { + _safeMtmdWeights.ClearMedia(); + + var embed = loadEmbed(); + Assert.NotNull(embed); + + using (embed) + { + Assert.True(embed.Nx > 0); + Assert.True(embed.Ny > 0); + Assert.False(embed.IsAudio); + Assert.True(embed.GetDataSpan().Length > 0); + + var status = _safeMtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); + Assert.Equal(0, status); + Assert.NotNull(chunks); + + return chunks!; + } + } + + private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) + { + long nPast = 0; + var eval = _safeMtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true); + Assert.Equal(0, eval); + Assert.True(nPast > 0); + } + + [Fact,Trait("Category", "NoCI")] + public void EmbedImageAsFileName() + { + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + AssertChunksEvaluate(chunks); + } + + [Fact,Trait("Category", "NoCI")] + public void EmbedImageAsBinary() + { + var imageBytes = File.ReadAllBytes(Constants.MtmdImage); + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(imageBytes)); + AssertChunksEvaluate(chunks); + } + + [Fact,Trait("Category", "NoCI")] + public void TokenizeProvidesChunkMetadata() + { + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + + Assert.True(chunks.Size > 0); + + ulong totalTokens = 0; + long totalPositions = 0; + var imageChunks = 0; + + foreach (var chunk in chunks.Enumerate()) + { + totalTokens += chunk.NTokens; + totalPositions += chunk.NPos; + + if (chunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Image) + { + imageChunks++; + + var copy = chunk.Copy(); + try + { + Assert.NotNull(copy); + if (copy != null) + { + Assert.Equal(chunk.NTokens, copy.NTokens); + Assert.Equal(chunk.NPos, copy.NPos); + } + } + finally + { + copy?.Dispose(); + } + } + } + + Assert.True(imageChunks > 0); + Assert.True(totalTokens > 0); + Assert.Equal(totalTokens, _safeMtmdWeights.CountTokens(chunks)); + Assert.Equal(totalPositions, _safeMtmdWeights.CountPositions(chunks)); + Assert.True(_safeMtmdWeights.SupportsVision); + Assert.False(_safeMtmdWeights.SupportsAudio); + + var audioBitrate = _safeMtmdWeights.AudioBitrate; + Assert.True(audioBitrate <= 0); + } + } +} diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs deleted file mode 100644 index f3e5798f2..000000000 --- a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System.Runtime.InteropServices; -using System.Text; -using LLama.Common; -using LLama.Extensions; -using Xunit; - -namespace LLama.Unittest.Native; - -public class SafeLlamaModelHandleTests -{ - private readonly LLamaWeights _model; - - public SafeLlamaModelHandleTests() - { - var @params = new ModelParams(Constants.GenerativeModelPath2) - { - ContextSize = 1, - GpuLayerCount = Constants.CIGpuLayerCount - }; - _model = LLamaWeights.LoadFromFile(@params); - } - - // Note: This test is flakey, it appears to often (but not always) fail the first 
time it is run after downloading the model file, but then succeed every time after! - //[SkippableFact] - //public void MetadataValByKey_ReturnsCorrectly() - //{ - // Skip.If(RuntimeInformation.IsOSPlatform(OSPlatform.OSX), "Skipping this test on macOS because for some reason the meta data is incorrect, but the rest of tests work well on mscOS [Check later!]."); - // const string key = "general.name"; - // var template = _model.NativeHandle.MetadataValueByKey(key); - // var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span); - //} -} diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs deleted file mode 100644 index 1ce53f395..000000000 --- a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System.Text; -using System.Xml.Linq; -using LLama.Common; -using LLama.Extensions; -using Microsoft.Extensions.Logging; - - -namespace LLama.Unittest.Native; - -public class SafeLlamaModelHandleVocabularyTests: IDisposable -{ - private readonly LLamaWeights _model; - - public SafeLlamaModelHandleVocabularyTests() - { - var @params = new ModelParams(Constants.RerankingModelPath) - { - ContextSize = 0, - PoolingType = LLama.Native.LLamaPoolingType.Rank, - GpuLayerCount = Constants.CIGpuLayerCount - }; - _model = LLamaWeights.LoadFromFile(@params); - } - - public void Dispose() - { - _model.Dispose(); - } - - [Fact] - public void GetLLamaTokenString() - { - var bos = _model.Vocab.BOS; - var eos = _model.Vocab.EOS; - - var bosStr = _model.Vocab.LLamaTokenToString(bos, true); - var eosStr = _model.Vocab.LLamaTokenToString(eos, true); - - Assert.Equal("", bosStr); - Assert.Equal("", eosStr); - } -} diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index 9a2233287..92276e4a6 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.Threading; +using LLama.Native; namespace LLama.Abstractions { @@ -22,12 +23,12 @@ public interface ILLamaExecutor /// /// Multi-Modal Projections / Clip Model weights /// - public LLavaWeights? ClipModel { get; } + public SafeMtmdWeights? ClipModel { get; } /// - /// List of images: List of images in byte array format. + /// List of media: List of media for Multi-Modal models. /// - public List Images { get; } + public List Embeds { get; } /// /// Asynchronously infers a response from the model. diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs index cdb1835e4..1d47bb0b8 100644 --- a/LLama/Batched/BatchedExecutor.cs +++ b/LLama/Batched/BatchedExecutor.cs @@ -17,6 +17,7 @@ public sealed class BatchedExecutor { private int _nextSequenceId; private readonly List _batchQueue = [ ]; + private string? _mtmdMarker; /// /// Set to 1 using interlocked exchange while inference is running @@ -60,12 +61,20 @@ public sealed class BatchedExecutor /// The model to use /// Parameters to create a new context public BatchedExecutor(LLamaWeights model, IContextParams contextParams) + : this(model, contextParams, null) + { + } + + public BatchedExecutor(LLamaWeights model, IContextParams contextParams, SafeMtmdWeights? clipModel) { Model = model; Context = model.CreateContext(contextParams); + ClipModel = clipModel; Epoch = 1; } + public SafeMtmdWeights? 
ClipModel { get; } + /// /// Start a new /// @@ -254,6 +263,23 @@ internal LLamaSeqId GetNextSequenceId() return (end, Epoch + (uint)_batchQueue.Count * 2); } + internal ulong QueueMtmdBatch(Conversation conversation, Conversation.MtmdChunkSequence sequence) + { + if (ClipModel is null) + throw new InvalidOperationException("This batched executor is not configured for multimodal inference."); + + var batch = new MtmdChunkBatch(ClipModel, conversation, sequence); + _batchQueue.Add(batch); + return Epoch + (uint)_batchQueue.Count * 2; + } + + internal string GetMtmdMarker() + { + if (ClipModel is null) + throw new InvalidOperationException("This batched executor is not configured for multimodal inference."); + return _mtmdMarker ??= NativeApi.MtmdDefaultMarker() ?? ""; + } + #region batches private interface IBatch { @@ -285,5 +311,44 @@ public Task DecodeAsync(LLamaContext ctx, CancellationToken token) return ctx.DecodeAsync(Batch, token); } } + + private class MtmdChunkBatch : IBatch + { + private readonly SafeMtmdWeights _clipModel; + private readonly Conversation _conversation; + private readonly Conversation.MtmdChunkSequence _sequence; + + public MtmdChunkBatch(SafeMtmdWeights clipModel, Conversation conversation, Conversation.MtmdChunkSequence sequence) + { + _clipModel = clipModel; + _conversation = conversation; + _sequence = sequence; + } + + public int ItemCount => Math.Max(1, _sequence.TotalTokens); + + public Task DecodeAsync(LLamaContext ctx, CancellationToken token) + { + try + { + var nPast = _conversation.GetMtmdPast(); + var status = _clipModel.EvaluateChunks(_sequence.Chunks, ctx.NativeHandle, ref nPast, + (int)_conversation.ConversationId, checked((int)ctx.BatchSize), logitsLast: true); + if (status != 0) + { + _conversation.OnMtmdEvaluationFailed(status); + return Task.FromResult(DecodeResult.DecodeFailed); + } + + _conversation.OnMtmdEvaluationCompleted(nPast, _sequence); + return Task.FromResult(DecodeResult.Ok); + } + catch + { + _conversation.OnMtmdEvaluationFailed(-1); + return Task.FromResult(DecodeResult.DecodeFailed); + } + } + } #endregion -} \ No newline at end of file +} diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index fcc94ae8f..807542b79 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Text.Json; using CommunityToolkit.HighPerformance.Buffers; +using LLama.Exceptions; using LLama.Native; namespace LLama.Batched; @@ -21,6 +22,12 @@ public sealed class Conversation /// Indicates if this conversation has been "forked" and may share logits with another conversation. /// private bool _forked; + private readonly List _mtmdEmbeds = new(); + private int? _mtmdLogitsIndex; + private MtmdChunkSequence? _pendingMtmdSequence; + private readonly List _embed_inps = new(); + private readonly List _session_tokens = new(); + private int _consumedTokensCount; /// /// Stores the indices to sample from. Contains valid items. 
@@ -65,6 +72,46 @@ internal Conversation(BatchedExecutor batch, LLamaSeqId id) Executor = batch; } + internal sealed class MtmdChunkSequence : IDisposable + { + public SafeMtmdInputChunks Chunks { get; } + public List TextTokens { get; } + public int TotalPositions { get; } + public int TotalTokens => TextTokens.Count; + + private MtmdChunkSequence(SafeMtmdInputChunks chunks, List textTokens, int totalPositions) + { + Chunks = chunks; + TextTokens = textTokens; + TotalPositions = totalPositions; + } + + public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, SafeMtmdWeights clipModel) + { + var textTokens = new List(); + + foreach (var chunk in chunks.Enumerate()) + { + using (chunk) + { + if (chunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + continue; + + foreach (var token in chunk.GetTextTokensSpan()) + textTokens.Add((LLamaToken)unchecked((int)token)); + } + } + + var totalPositions = (int)clipModel.CountPositions(chunks); + return new MtmdChunkSequence(chunks, textTokens, totalPositions); + } + + public void Dispose() + { + Chunks.Dispose(); + } + } + /// /// Finalizer for Conversation /// @@ -83,6 +130,11 @@ public void Dispose() return; _disposed = true; + _pendingMtmdSequence?.Dispose(); + _pendingMtmdSequence = null; + + DisposeQueuedMedia(); + // Remove this conversation from the KV cache Executor.Context.NativeHandle.MemorySequenceRemove(ConversationId, -1, -1); @@ -206,6 +258,43 @@ private void AssertCanBePrompted() if (RequiresInference) throw new AlreadyPromptedConversationException(); + + _mtmdLogitsIndex = null; + } + + public void QueueMedia(string path) + { + AssertCanBePrompted(); + + if (Executor.ClipModel is null) + throw new InvalidOperationException("This conversation is not configured for multimodal prompts."); + + var embed = Executor.ClipModel.LoadMedia(path); + _mtmdEmbeds.Add(embed); + _mtmdLogitsIndex = null; + } + + public void QueueMedia(SafeMtmdEmbed embed) + { + AssertCanBePrompted(); + + if (Executor.ClipModel is null) + throw new InvalidOperationException("This conversation is not configured for multimodal prompts."); + + _mtmdEmbeds.Add(embed); + _mtmdLogitsIndex = null; + } + + public void Prompt(string promptText, bool addBos = true, bool special = true) + { + if (Executor.ClipModel != null && _mtmdEmbeds.Count > 0) + { + PromptMultimodal(promptText, addBos); + return; + } + + var tokens = Executor.Context.Tokenize(promptText, addBos, special); + Prompt(tokens); } /// @@ -246,6 +335,7 @@ public void Prompt(List tokens, bool allLogits = false) public void Prompt(ReadOnlySpan tokens, bool allLogits = false) { AssertCanBePrompted(); + _mtmdLogitsIndex = null; // No point doing anything if there is no actual prompt! if (tokens.Length == 0) @@ -289,6 +379,59 @@ public void Prompt(ReadOnlySpan tokens, bool allLogits = false) // Unset the forked flag. Since this conversation has just been prompted it's no longer // sharing anything with any other conversations. 
_forked = false; + _mtmdLogitsIndex = null; + } + + private void PromptMultimodal(string text, bool addBos) + { + AssertCanBePrompted(); + + if (Executor.ClipModel is null) + throw new InvalidOperationException("This conversation is not configured for multimodal prompts."); + if (_mtmdEmbeds.Count == 0) + throw new InvalidOperationException("Queue media before prompting with multimodal input."); + + var marker = Executor.GetMtmdMarker(); + var prompt = text; + + if (prompt.Contains("")) + prompt = prompt.Replace("", marker); + + if (!prompt.Contains(marker)) + { + var suffix = string.Concat(Enumerable.Repeat(marker, _mtmdEmbeds.Count)); + prompt = string.Concat(prompt, suffix); + } + + SafeMtmdInputChunks? chunks = null; + try + { + _mtmdLogitsIndex = null; + var status = Executor.ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); + if (status != 0 || chunks is null) + { + Executor.ClipModel.ClearMedia(); + throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); + } + + var sequence = MtmdChunkSequence.Create(chunks, Executor.ClipModel); + _pendingMtmdSequence = sequence; + + var epoch = Executor.QueueMtmdBatch(this, sequence); + chunks = null; + + if (_batchSampleIndices.Length == 0) + _batchSampleIndices = new int[4]; + + _batchSampleCount = 0; + _requiredEpoch = epoch; + _forked = false; + } + finally + { + DisposeQueuedMedia(); + chunks?.Dispose(); + } } /// @@ -305,32 +448,7 @@ public void Prompt(LLamaToken token) Span span = [ token ]; Prompt(span); } - - /// - /// Prompt this conversation with an image embedding - /// - /// - public void Prompt(SafeLlavaImageEmbedHandle embedding) - { - AssertCanBePrompted(); - - if (embedding.Model.EmbeddingDimensions != Executor.Model.EmbeddingSize) - throw new ArgumentException($"Embedding dimension mismatch between image embedding ({embedding.Model.EmbeddingDimensions}) and model ({Executor.Model.EmbeddingSize})"); - - for (var i = 0; i < embedding.Model.PatchCount; i++) - { - // Get a batch with space - (var batch, _requiredEpoch) = Executor.GetEmbeddingBatch(); - - batch.Add( - (i, embedding), - static (Span dest, (int index, SafeLlavaImageEmbedHandle embedding) tup) => tup.embedding.GetEmbedding(dest, tup.index), - _end++, - ConversationId, - i == embedding.Model.PatchCount - 1 - ); - } - } + /// /// Prompt this conversation with embeddings @@ -339,6 +457,7 @@ public void Prompt(SafeLlavaImageEmbedHandle embedding) public void Prompt(ReadOnlySpan embeddings) { AssertCanBePrompted(); + _mtmdLogitsIndex = null; var dim = Executor.Model.EmbeddingSize; var count = embeddings.Length / dim; @@ -385,6 +504,75 @@ public void Modify(ModifyKvCache modifier) _requiredEpoch = 0; } + internal long GetMtmdPast() => _end.Value; + + internal void OnMtmdEvaluationCompleted(long newPast, MtmdChunkSequence sequence) + { + _pendingMtmdSequence?.Dispose(); + _pendingMtmdSequence = null; + + _end = (LLamaPos)checked((int)newPast); + + if (_batchSampleIndices.Length == 0) + _batchSampleIndices = new int[4]; + + _batchSampleCount = 1; + _batchSampleIndices[0] = 0; + _mtmdLogitsIndex = -1; + _requiredEpoch = Executor.Epoch + 1; + _forked = false; + + if (sequence.TextTokens.Count > 0) + { + _embed_inps.AddRange(sequence.TextTokens); + _session_tokens.AddRange(sequence.TextTokens); + } + + var fillerToken = GetFillerToken(Executor.GetMtmdMarker()); + var fillerCount = Math.Max(0, sequence.TotalPositions - sequence.TotalTokens); + for (var i = 0; i < fillerCount; i++) + _embed_inps.Add(fillerToken); + + _consumedTokensCount 
= _embed_inps.Count; + sequence.Dispose(); + } + + internal void OnMtmdEvaluationFailed(int status) + { + _pendingMtmdSequence?.Dispose(); + _pendingMtmdSequence = null; + _mtmdLogitsIndex = null; + _requiredEpoch = Executor.Epoch; + DisposeQueuedMedia(); + } + + internal int? MtmdLogitsIndex => _mtmdLogitsIndex; + + private LLamaToken GetFillerToken(string marker) + { + var markerTokens = Executor.Context.Tokenize(marker, addBos: false, special: true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Executor.Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default; + } + + private void DisposeQueuedMedia() + { + if (_mtmdEmbeds.Count == 0) + return; + + foreach (var embed in _mtmdEmbeds) + embed.Dispose(); + + _mtmdEmbeds.Clear(); + Executor.ClipModel?.ClearMedia(); + } + /// /// Provides direct access to the KV cache of a . /// See for how to use this. @@ -629,4 +817,4 @@ internal State() } } #endregion -} \ No newline at end of file +} diff --git a/LLama/Batched/ConversationExtensions.cs b/LLama/Batched/ConversationExtensions.cs index eb0192061..3e25d3f43 100644 --- a/LLama/Batched/ConversationExtensions.cs +++ b/LLama/Batched/ConversationExtensions.cs @@ -18,7 +18,11 @@ public static class ConversationExtensions /// public static LLamaToken Sample(this Conversation conversation, SafeLLamaSamplerChainHandle sampler, int offset = 0) { - return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset)); + var ctx = conversation.Executor.Context.NativeHandle; + if (conversation.MtmdLogitsIndex == -1) + return sampler.Sample(ctx, -1); + + return sampler.Sample(ctx, conversation.GetSampleIndex(offset)); } /// @@ -30,7 +34,11 @@ public static LLamaToken Sample(this Conversation conversation, SafeLLamaSampler /// public static LLamaToken Sample(this Conversation conversation, ISamplingPipeline sampler, int offset = 0) { - return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset)); + var ctx = conversation.Executor.Context.NativeHandle; + if (conversation.MtmdLogitsIndex == -1) + return sampler.Sample(ctx, -1); + + return sampler.Sample(ctx, conversation.GetSampleIndex(offset)); } /// @@ -82,4 +90,4 @@ public static void ShiftLeft(this Conversation conversation, int count, int keep return end.Value - count; }); } -} \ No newline at end of file +} diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 36989006e..212194bea 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -32,11 +32,11 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected int _consumedTokensCount; // n_consume /// - /// + /// Number of tokens consumed from the session cache during the current run. /// protected int _n_session_consumed; /// - /// + /// Number of prompt tokens that match the loaded session cache prefix. /// protected int _n_matching_session_tokens; /// @@ -52,7 +52,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected List _embed_inps = new(); /// - /// + /// Tokens recovered from the session file and reused to warm up the KV cache. /// protected List _session_tokens = new(); /// @@ -76,21 +76,21 @@ public bool IsMultiModal } /// - public LLavaWeights? ClipModel { get; } + public SafeMtmdWeights? 
ClipModel { get; } /// - public List Images { get; } + public List Embeds { get; } private readonly StreamingTokenDecoder _decoder; /// - /// + /// Initialize a stateful executor bound to a specific context. /// - /// - /// + /// LLama context used for all native interactions. + /// Optional logger for diagnostic output. protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) { - Images = new List(); + Embeds = new List(); _logger = logger; Context = context; _pastTokensCount = 0; @@ -101,22 +101,22 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) } /// - /// + /// Initialize a multimodal executor with the supplied MTMD weights. /// - /// - /// - /// - public StatefulExecutorBase(LLamaContext context, LLavaWeights lLavaWeights, ILogger? logger = null) : + /// LLama context used for all native interactions. + /// Multimodal weights to associate with this executor. + /// Optional logger for diagnostic output. + public StatefulExecutorBase(LLamaContext context, SafeMtmdWeights safeMtmdWeights, ILogger? logger = null) : this( context, logger ) { - ClipModel = lLavaWeights; + ClipModel = safeMtmdWeights; } /// - /// This API is currently not verified. + /// Attach a session cache file so the executor can reuse previous KV state if compatible. /// - /// - /// + /// Path to the llama.cpp session file. + /// The current executor instance for fluent configuration. /// /// public StatefulExecutorBase WithSessionFile(string filename) @@ -173,9 +173,9 @@ public StatefulExecutorBase WithSessionFile(string filename) } /// - /// This API has not been verified currently. + /// Persist the current session cache to disk. /// - /// + /// Destination path for the llama.cpp session file. public void SaveSessionFile(string filename) { var session_token_array = _session_tokens.ToArray(); @@ -203,7 +203,7 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep) } /// - /// Try to reuse the matching prefix from the session file. + /// Try to reuse the matching prompt prefix from the loaded session cache before evaluating new tokens. /// protected virtual void TryReuseMatchingPrefix() { @@ -236,66 +236,66 @@ protected virtual void TryReuseMatchingPrefix() } /// - /// Decide whether to continue the loop. + /// Determine whether the inference loop should continue processing tokens. /// - /// - /// + /// Mutable state associated with the current inference. + /// true to continue generating; otherwise false. protected abstract Task GetLoopCondition(InferStateArgs args); /// - /// Preprocess the inputs before the inference. + /// Prepare the executor for inference by tokenizing input and updating cached state. /// - /// - /// + /// Prompt text to process. + /// Mutable state associated with the current inference. protected abstract Task PreprocessInputs(string? text, InferStateArgs args); /// - /// Do some post processing after the inference. + /// Perform any post-processing on the generated tokens. /// - /// - /// - /// + /// Parameters controlling sampling. + /// Mutable state associated with the current inference. + /// A tuple indicating whether generation should stop and any extra outputs to emit. protected abstract Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args); /// - /// The core inference logic. + /// Core inference loop that advances the model by one step. /// - /// - /// + /// Parameters controlling sampling. + /// Mutable state associated with the current inference. 
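+        /// <returns>A task that completes once the inference step has been evaluated.</returns>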
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args); /// - /// Save the current state to a file. + /// Save the executor state to a serialized snapshot file. /// - /// + /// Destination file for the serialized state. public abstract Task SaveState(string filename); /// - /// Get the current state data. + /// Capture the executor state in a serializable object. /// - /// + /// State snapshot suitable for persistence. public abstract ExecutorBaseState GetStateData(); /// - /// Load the state from data. + /// Restore executor state from a previously captured snapshot. /// - /// + /// State snapshot created by . public abstract Task LoadState(ExecutorBaseState data); /// - /// Load the state from a file. + /// Restore executor state from a serialized snapshot file. /// - /// + /// Path to the snapshot produced by . public abstract Task LoadState(string filename); /// - /// Execute the inference. + /// Execute an asynchronous inference session. /// - /// The prompt. If null, generation will continue where it left off previously. - /// - /// - /// + /// Optional prompt; when null generation resumes from prior state. + /// Sampling parameters to apply; defaults are used when null. + /// Cancellation token for cooperative cancellation. + /// Stream of decoded text segments as they become available. public virtual async IAsyncEnumerable InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); @@ -370,12 +370,12 @@ public virtual async Task PrefillPromptAsync(string prompt) } /// - /// State arguments that are used in single inference + /// Mutable state passed between inference callbacks during a single generation pass. /// protected class InferStateArgs { /// - /// + /// Anti-prompts that terminate generation when encountered. /// public IList? Antiprompts { get; set; } /// @@ -383,20 +383,23 @@ protected class InferStateArgs /// public int RemainedTokens { get; set; } /// - /// + /// Indicates whether generated tokens should be returned to the caller. /// public bool ReturnValue { get; set; } /// - /// + /// Signals that the executor should pause and wait for additional user input. /// public bool WaitForInput { get; set; } /// - /// + /// Indicates whether the session cache should be persisted after inference completes. /// public bool NeedToSaveSession { get; set; } } #pragma warning disable CS1591, CS8618 // Missing XML and irrelevant nullable warnings + /// + /// Serializable snapshot of executor state used for persistence and restart. + /// [JsonConverter(typeof(PolymorphicJSONConverter))] public class ExecutorBaseState { @@ -434,5 +437,33 @@ public class ExecutorBaseState public float? 
MirostatMu { get; set; } } #pragma warning restore + + internal ExecutorDiagnostics GetDiagnostics() + { + return new ExecutorDiagnostics( + _embed_inps.Count, + _consumedTokensCount, + _pastTokensCount, + _embeds.Count); + } + } +} + +namespace LLama +{ + internal readonly struct ExecutorDiagnostics + { + public ExecutorDiagnostics(int embedCount, int consumedCount, int pastCount, int pendingEmbeds) + { + EmbedCount = embedCount; + ConsumedCount = consumedCount; + PastCount = pastCount; + PendingEmbedCount = pendingEmbeds; + } + + public int EmbedCount { get; } + public int ConsumedCount { get; } + public int PastCount { get; } + public int PendingEmbedCount { get; } } } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 331591fba..2069061d5 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Tasks; @@ -24,6 +25,9 @@ public class InstructExecutor private readonly string _instructionPrefix; private LLamaToken[] _inp_pfx; private LLamaToken[] _inp_sfx; + private SafeMtmdInputChunks? _mtmdChunks; + private string? _mtmdMarker; + private readonly string _instructionSuffix; /// /// @@ -41,6 +45,20 @@ public InstructExecutor(LLamaContext context, _inp_pfx = Context.Tokenize(instructionPrefix, true, true); _inp_sfx = Context.Tokenize(instructionSuffix, false, true); _instructionPrefix = instructionPrefix; + _instructionSuffix = instructionSuffix; + } + + public InstructExecutor(LLamaContext context, + SafeMtmdWeights clipModel, + string instructionPrefix = "\n\n### Instruction:\n\n", + string instructionSuffix = "\n\n### Response:\n\n", + ILogger? logger = null) + : base(context, clipModel, logger) + { + _inp_pfx = Context.Tokenize(instructionPrefix, true, true); + _inp_sfx = Context.Tokenize(instructionSuffix, false, true); + _instructionPrefix = instructionPrefix; + _instructionSuffix = instructionSuffix; } /// @@ -67,6 +85,7 @@ public override ExecutorBaseState GetStateData() /// public override Task LoadState(ExecutorBaseState data) { + DisposeMtmdChunks(); if(data is InstructExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; @@ -126,7 +145,14 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args) { // When running the first input (prompt) in inteactive mode, we should specially process it. if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously."); - _embed_inps = Context.Tokenize(text, true, true).ToList(); + if (!IsMultiModal) + { + _embed_inps = Context.Tokenize(text, true, true).ToList(); + } + else + { + return PreprocessMtmd(text, args, addBos: true, replaceExisting: true); + } } else { @@ -139,20 +165,161 @@ protected override Task PreprocessInputs(string? 
text, InferStateArgs args) { text += "\n"; } - _embed_inps.AddRange(_inp_pfx); + if (!IsMultiModal) + { + _embed_inps.AddRange(_inp_pfx); - var line_inp = Context.Tokenize(text, false, true); - _embed_inps.AddRange(line_inp); + var line_inp = Context.Tokenize(text, false, true); + _embed_inps.AddRange(line_inp); - _embed_inps.AddRange(_inp_sfx); + _embed_inps.AddRange(_inp_sfx); - args.RemainedTokens -= line_inp.Length; + args.RemainedTokens -= line_inp.Length; + } + else + { + var builder = new StringBuilder(); + builder.Append(_instructionPrefix); + builder.Append(text); + builder.Append(_instructionSuffix); + return PreprocessMtmd(builder.ToString(), args, addBos: false, replaceExisting: false); + } } } return Task.CompletedTask; } + private void DisposeMtmdChunks() + { + _mtmdChunks?.Dispose(); + _mtmdChunks = null; + } + + private void DisposeEmbeds() + { + if (Embeds.Count == 0) + return; + + foreach (var embed in Embeds) + embed.Dispose(); + + Embeds.Clear(); + } + + private string GetMtmdMarker() + { + if (_mtmdMarker is not null) + return _mtmdMarker; + + _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; + return _mtmdMarker; + } + + private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) + { + if (totalPositions <= tokens.Count) + return new List(tokens); + + var result = new List(totalPositions); + result.AddRange(tokens); + result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); + return result; + } + + private LLamaToken GetFillerToken(string marker) + { + var markerTokens = Context.Tokenize(marker, false, true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default(LLamaToken); + } + + private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting) + { + if (ClipModel is null) + throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); + + DisposeMtmdChunks(); + + var marker = GetMtmdMarker(); + var prompt = text; + + if (Embeds.Count > 0) + { + if (prompt.Contains("")) + prompt = prompt.Replace("", marker); + + if (!prompt.Contains(marker)) + { + var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); + prompt = string.Concat(prompt, suffix); + } + } + + SafeMtmdInputChunks? chunks = null; + try + { + var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); + if (status != 0 || chunks is null) + { + ClipModel.ClearMedia(); + throw new RuntimeError($"Failed to tokenize multimodal prompt. 
Status: {status}."); + } + + _mtmdChunks = chunks; + + var tokens = new List(); + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + continue; + + foreach (var token in scopedChunk.GetTextTokensSpan()) + tokens.Add(unchecked((int)token)); + } + + var totalPositions = (int)ClipModel.CountPositions(chunks); + var fillerToken = GetFillerToken(marker); + + if (replaceExisting) + { + _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); + _consumedTokensCount = 0; + } + else + { + if (_embed_inps.Count == 0) + _embed_inps = new List(); + + _embed_inps.AddRange(tokens); + var fillerCount = totalPositions - tokens.Count; + if (fillerCount > 0) + _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); + + args.RemainedTokens -= tokens.Count; + } + } + catch + { + chunks?.Dispose(); + _mtmdChunks = null; + throw; + } + finally + { + DisposeEmbeds(); + } + + return Task.CompletedTask; + } + /// protected override async Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args) { @@ -213,11 +380,43 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In _n_session_consumed = _session_tokens.Count; } } + else if (IsMultiModal && _mtmdChunks is not null) + { + _is_prompt_run = false; + var nPast = (long)_pastTokensCount; + var previousConsumed = _consumedTokensCount; + var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[InstructExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); + } + + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); + + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + { + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); + _n_session_consumed = _session_tokens.Count; + } + + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); + } _embeds.Clear(); if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { + if (inferenceParams.MaxTokens == 0) + { + _embeds.Clear(); + args.WaitForInput = true; + args.ReturnValue = false; + return; + } // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 7c9558ee3..a6ead60fa 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -8,6 +8,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Tasks; +using LLama; using LLama.Exceptions; using LLama.Sampling; using Microsoft.Extensions.Logging; @@ -20,30 +21,30 @@ namespace LLama /// public class InteractiveExecutor : StatefulExecutorBase { + // Indicates whether the executor is currently evaluating the initial prompt or a follow-up turn. private bool _is_prompt_run = true; - - // LLava - private int _EmbedImagePosition = -1; - private List _imageEmbedHandles = new List(); - private bool _imageInPrompt = false; + + // MTMD multimodal state + private SafeMtmdInputChunks? _mtmdChunks; // Pending chunk collection produced by the multimodal tokenizer. + private string? 

        /// <summary>
-        ///
+        /// Create an interactive executor for text-only inference.
        /// </summary>
-        /// <param name="context"></param>
-        /// <param name="logger"></param>
+        /// <param name="context">LLama context to operate against.</param>
+        /// <param name="logger">Optional logger for diagnostic output.</param>
        public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
            : base(context, logger)
        {
        }

        /// <summary>
-        ///
+        /// Create an interactive multimodal executor that can process text alongside media inputs.
        /// </summary>
-        /// <param name="context"></param>
-        /// <param name="clipModel"></param>
-        /// <param name="logger"></param>
-        public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null)
+        /// <param name="context">LLama context to operate against.</param>
+        /// <param name="clipModel">Multimodal weights (MTMD) to attach to the executor.</param>
+        /// <param name="logger">Optional logger for diagnostic output.</param>
+        public InteractiveExecutor(LLamaContext context, SafeMtmdWeights clipModel, ILogger? logger = null)
            : base(context, clipModel, logger)
        {
        }
@@ -70,6 +71,7 @@ public override ExecutorBaseState GetStateData()
        /// <inheritdoc />
        public override Task LoadState(ExecutorBaseState data)
        {
+            DisposeMtmdChunks();
            if (data is InteractiveExecutorState state)
            {
                _n_session_consumed = state.ConsumedSessionCount;
@@ -108,15 +110,20 @@ public override async Task LoadState(string filename)
        }

        /// <summary>
-        /// Define whether to continue the loop to generate responses.
+        /// Decide whether generation should continue for the current iteration.
        /// </summary>
-        /// <returns></returns>
+        /// <param name="args">Mutable inference state.</param>
+        /// <returns>true to keep generating; otherwise false.</returns>
        protected override Task<bool> GetLoopCondition(InferStateArgs args)
        {
            return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
        }

-        /// <inheritdoc />
+        /// <summary>
+        /// Preprocess the incoming prompt or continuation text before inference.
+        /// </summary>
+        /// <param name="text">Prompt text or continuation provided by the caller.</param>
+        /// <param name="args">Mutable inference state.</param>
        protected override Task PreprocessInputs(string? text, InferStateArgs args)
        {
            if (_is_prompt_run)
@@ -129,7 +136,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                }
                else
                {
-                    PreprocessLlava(text, args, true);
+                    PreprocessMtmd(text, args, true);
                }
            }
            else
@@ -150,7 +157,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                }
                else
                {
-                    PreprocessLlava(text, args, false);
+                    PreprocessMtmd(text, args, false);
                }
            }
        }
@@ -158,51 +165,171 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
            return Task.CompletedTask;
        }

-        /// <inheritdoc />
-        private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
-        {
-            // If the prompt contains the <image> tag extract this.
-            _imageInPrompt = text.Contains("<image>");
-            if (_imageInPrompt && IsMultiModal)
+        /// <summary>
+        /// Release any queued multimodal chunks and reset state.
+        /// </summary>
+        private void DisposeMtmdChunks()
+        {
+            _mtmdChunks?.Dispose();
+            _mtmdChunks = null;
+        }
+
+        /// <summary>
+        /// Dispose and clear any pending multimodal embeddings queued for evaluation.
+        /// </summary>
+        private void DisposeEmbeds()
+        {
+            if (Embeds.Count == 0)
+            {
+                return;
+            }
+
+            foreach (var embed in Embeds)
+            {
+                embed.Dispose();
+            }
+
+            Embeds.Clear();
+        }
+
+        /// <summary>
+        /// Retrieve the marker token used to signal media segments to the tokenizer.
+        /// </summary>
+        private string GetMtmdMarker()
+        {
+            if (_mtmdMarker is not null)
+            {
+                return _mtmdMarker;
+            }
+
+            _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "<__media__>";
+            return _mtmdMarker;
+        }
+
+        private static List<LLamaToken> BuildTokensWithFiller(List<LLamaToken> tokens, int totalPositions, LLamaToken fillerToken)
+        {
+            if (totalPositions <= tokens.Count)
+                return new List<LLamaToken>(tokens);
+
+            var result = new List<LLamaToken>(totalPositions);
+            result.AddRange(tokens);
+            result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
+            return result;
+        }
+
+        private LLamaToken GetFillerToken(string marker)
+        {
+            var markerTokens = Context.Tokenize(marker, false, true);
+            if (markerTokens.Length > 0)
+                return markerTokens[markerTokens.Length - 1];
+
+            var eos = Context.Vocab.EOS;
+            if (eos.HasValue)
+                return eos.Value;
+
+            return default(LLamaToken);
+        }
+
+        /// <summary>
+        /// Preprocess multimodal prompts by aligning media markers and tokenizing via MTMD helpers.
+        /// </summary>
+        /// <param name="text">Prompt text containing optional media markers.</param>
+        /// <param name="args">Mutable inference state.</param>
+        /// <param name="addBos">Whether to treat the prompt as a fresh run and add the BOS token.</param>
+        private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true)
+        {
+            if (ClipModel is null)
            {
-                foreach (var image in Images)
+                throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+            }
+
+            DisposeMtmdChunks();
+
+            var marker = GetMtmdMarker();
+            var prompt = text;
+
+            if (Embeds.Count > 0)
+            {
+                if (prompt.Contains("<image>"))
                {
-                    _imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel!.NativeHandle, Context, image));
+                    prompt = prompt.Replace("<image>", marker);
                }
-                int imageIndex = text.IndexOf("<image>");
-                // Tokenize segment 1 (before <image> tag)
-                string preImagePrompt = text.Substring(0, imageIndex);
-                var segment1 = Context.Tokenize(preImagePrompt, addBos, true);
-                // Remember the position to add the image embeddings
-                _EmbedImagePosition = segment1.Length;
-                string postImagePrompt = text.Substring(imageIndex + 7);
-                var segment2 = Context.Tokenize(postImagePrompt, false, true);
-                _embed_inps.AddRange(segment1);
-                _embed_inps.AddRange(segment2);
+                if (!prompt.Contains(marker))
+                {
+                    var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); // Ensure tokenizer sees one marker per embed.
+                    prompt = string.Concat(prompt, suffix);
+                }
            }
-            else
+
+            SafeMtmdInputChunks? chunks = null;
+            try
            {
+                var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+                if (status != 0 || chunks is null)
+                {
+                    ClipModel.ClearMedia();
+                    throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+                }
+
+                _mtmdChunks = chunks; // Own the chunk collection until evaluation completes.
+
+                var tokens = new List<LLamaToken>();
+                foreach (var chunk in chunks.Enumerate())
+                {
+                    using var scopedChunk = chunk;
+                    if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
+                    {
+                        continue;
+                    }
+
+                    foreach (var token in scopedChunk.GetTextTokensSpan())
+                    {
+                        tokens.Add(unchecked((int)token));
+                    }
+                }
+
+                var totalPositions = (int)ClipModel.CountPositions(chunks);
+                var fillerToken = GetFillerToken(marker);
+
                if (addBos)
                {
-                    _embed_inps = Context.Tokenize(text, true, true).ToList();
+                    _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken);
+                    _consumedTokensCount = 0;
                }
                else
                {
-                    var line_inp = Context.Tokenize(text, false, true);
-                    _embed_inps.AddRange(line_inp);
-                    args.RemainedTokens -= line_inp.Length;
+                    if (_embed_inps.Count == 0)
+                        _embed_inps = new List<LLamaToken>();
+
+                    _embed_inps.AddRange(tokens);
+                    var fillerCount = totalPositions - tokens.Count;
+                    if (fillerCount > 0)
+                        _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount));
+
+                    args.RemainedTokens -= tokens.Count;
                }
            }
+            catch
+            {
+                chunks?.Dispose();
+                _mtmdChunks = null;
+                throw;
+            }
+            finally
+            {
+                DisposeEmbeds(); // Flush any embeds decoded in prior step; MTMD replays them via chunk eval.
+            }
+
            return Task.CompletedTask;
        }

        /// <summary>
-        /// Return whether to break the generation.
+        /// Decide whether generation should stop based on antiprompts, token limits, or end-of-generation markers.
        /// </summary>
-        /// <param name="inferenceParams"></param>
-        /// <param name="args"></param>
-        /// <returns></returns>
+        /// <param name="inferenceParams">Sampling parameters controlling generation.</param>
+        /// <param name="args">Mutable inference state.</param>
+        /// <returns>Tuple describing whether to stop and any additional outputs to emit.</returns>
        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
        {
            if (_embed_inps.Count <= _consumedTokensCount)
@@ -253,51 +380,87 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                HandleRunOutOfContext(tokensToKeep);
            }

-            TryReuseMatchingPrefix();
+            if (_mtmdChunks is null)
+            {
+                TryReuseMatchingPrefix();
+            }

-            // Changes to support Multi-Modal LLMs.
-            //
-            (DecodeResult, int, int) header, end, result;
-            if (IsMultiModal && _EmbedImagePosition > 0)
+            if (IsMultiModal && _mtmdChunks is not null)
            {
-                // Tokens previous to the images
-                header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
-                _pastTokensCount = header.Item3;
-
-                if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
-
-                // Images
-                foreach( var image in _imageEmbedHandles )
-                    ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);
-
-                // Post-image Tokens
-                end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
-                _pastTokensCount = end.Item3;
-
-                _EmbedImagePosition = -1;
-                _imageEmbedHandles.Clear();
-                Images.Clear();
+                var nPast = (long)_pastTokensCount;
+                var previousConsumed = _consumedTokensCount;
+                var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0,
+                    nBatch: checked((int)Context.BatchSize), logitsLast: true);
+                if (evalStatus != 0)
+                {
+                    _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus);
+                    DisposeMtmdChunks();
+                    throw new RuntimeError($"Failed to evaluate multimodal chunks. 
Status: {evalStatus}."); + } + + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); + + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + { + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); + _n_session_consumed = _session_tokens.Count; + } + + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); } else { - result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); + var result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); _pastTokensCount = result.Item3; if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1); + + if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + { + _session_tokens.AddRange(_embeds); + _n_session_consumed = _session_tokens.Count; + } } - + } + else if (IsMultiModal && _mtmdChunks is not null) + { + _is_prompt_run = false; + var nPast = (long)_pastTokensCount; + var previousConsumed = _consumedTokensCount; + var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); + } + + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); - if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) { - _session_tokens.AddRange(_embeds); + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); _n_session_consumed = _session_tokens.Count; } - } + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); + } + _embeds.Clear(); if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { + if (inferenceParams.MaxTokens == 0) + { + _embeds.Clear(); + args.WaitForInput = true; + args.ReturnValue = false; + return; + } // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) { @@ -344,10 +507,10 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In } /// - /// The descriptor of the state of the interactive executor. + /// Serializable state specific to the interactive executor. /// public class InteractiveExecutorState - : ExecutorBaseState + : StatefulExecutorBase.ExecutorBaseState { /// /// Whether the executor is running for the first time (running the prompt). diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index f53de7069..e827585b7 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -3,7 +3,7 @@ netstandard2.0;net8.0 LLama enable - 12 + 13 AnyCPU;x64;Arm64 True @@ -17,7 +17,7 @@ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 LLama, LLM, GPT, ChatGPT, NLP, AI, Chat Bot, SciSharp - LLamaSharp is a cross-platform library to run 🦙LLaMA/LLaVA model (and others) in your local device. + LLamaSharp is a cross-platform library to run 🦙LLaMA/Mtmd model (and others) in your local device. Based on [llama.cpp](https://github.com/ggerganov/llama.cpp), inference with LLamaSharp is efficient on both CPU and GPU. With the higher-level APIs and RAG support, it's convenient to deploy LLM (Large Language Model) in your application with LLamaSharp. 
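Putting the pieces together, a minimal sketch of the new multimodal interactive path. This is illustrative only: it assumes SafeMtmdWeights exposes a file loader and a NativeHandle property mirroring the old LLavaWeights surface (that type is added elsewhere in this patch), and all paths are placeholders.

    var parameters = new ModelParams("model.gguf");
    using var weights = await LLamaWeights.LoadFromFileAsync(parameters);
    using var context = weights.CreateContext(parameters);
    // Assumption: loader analogous to the old LLavaWeights.LoadFromFileAsync.
    using var mtmd = await SafeMtmdWeights.LoadFromFileAsync("mmproj.gguf");
    var executor = new InteractiveExecutor(context, mtmd);

    // Queue media, then reference it with a legacy <image> tag (rewritten to the MTMD
    // marker by PreprocessMtmd), or let one marker be appended per queued embed.
    executor.Embeds.Add(SafeMtmdEmbed.FromMediaFile(mtmd.NativeHandle, "cat.jpg")!); // NativeHandle is assumed.
    await foreach (var piece in executor.InferAsync("<image>\nDescribe the image.", new InferenceParams()))
        Console.Write(piece);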
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 8f9b40cc3..94bc60830 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -28,10 +28,10 @@ public class StatelessExecutor
        public bool IsMultiModal => false;

        /// <inheritdoc />
-        public LLavaWeights? ClipModel => default;
+        public SafeMtmdWeights? ClipModel => default;

        /// <inheritdoc />
-        public List<byte[]> Images { get; }
+        public List<SafeMtmdEmbed> Embeds { get; }

        /// <summary>
        /// The context used by the executor when running the inference.
@@ -57,7 +57,7 @@ public class StatelessExecutor
        /// </summary>
        public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
-            Images = [ ];
+            Embeds = [ ];
            _weights = weights;
            _params = @params;
            _logger = logger;
diff --git a/LLama/LLavaWeights.cs b/LLama/LLavaWeights.cs
deleted file mode 100644
index f2f9f6256..000000000
--- a/LLama/LLavaWeights.cs
+++ /dev/null
@@ -1,137 +0,0 @@
-
-using System;
-using System.Threading;
-using System.Threading.Tasks;
-using LLama.Native;
-
-namespace LLama;
-
-/// <summary>
-/// A set of llava model weights (mmproj), loaded into memory.
-/// </summary>
-public sealed class LLavaWeights
-    : IDisposable
-{
-    /// <summary>
-    /// The native handle, which is used in the native APIs
-    /// </summary>
-    /// <remarks>Be careful how you use this!</remarks>
-    public SafeLlavaModelHandle NativeHandle { get; }
-
-    private LLavaWeights(SafeLlavaModelHandle weights)
-    {
-        NativeHandle = weights;
-    }
-
-    #region load
-    /// <summary>
-    /// Load weights into memory
-    /// </summary>
-    /// <param name="mmProject">path to the "mmproj" model file</param>
-    /// <returns></returns>
-    public static LLavaWeights LoadFromFile(string mmProject)
-    {
-        var weights = SafeLlavaModelHandle.LoadFromFile(mmProject, 1);
-        return new LLavaWeights(weights);
-    }
-
-    /// <summary>
-    /// Load weights into memory
-    /// </summary>
-    /// <param name="mmProject">path to the "mmproj" model file</param>
-    /// <param name="token"></param>
-    /// <returns></returns>
-    public static Task<LLavaWeights> LoadFromFileAsync(string mmProject, CancellationToken token = default)
-    {
-        return Task.Run(() => LoadFromFile(mmProject), token);
-    }
-    #endregion
-
-    #region embed
-    /// <summary>
-    /// Create the Image Embeddings from the bytes of an image.
-    /// </summary>
-    /// <param name="ctxLlama"></param>
-    /// <param name="image">
-    /// Image bytes. Supported formats:
-    /// JPG, PNG, BMP, TGA
-    /// </param>
-    /// <returns></returns>
-    public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
-    {
-        return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
-    }
-
-    /// <summary>
-    /// Create the Image Embeddings.
-    /// </summary>
-    /// <param name="image">Image in binary format (it supports jpeg format only)</param>
-    /// <param name="threads">Number of threads to use</param>
-    /// <returns>return the SafeHandle of these embeddings</returns>
-    public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
-    {
-        return NativeHandle.CreateImageEmbeddings(image, threads);
-    }
-
-    /// <summary>
-    /// Create the Image Embeddings from the bytes of an image.
-    /// </summary>
-    /// <param name="ctxLlama"></param>
-    /// <param name="image">
-    /// Path to the image file. 
Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1) - { - return NativeHandle.CreateImageEmbeddings(image, threads); - } - #endregion - - /// - /// Eval the image embeddings - /// - /// - /// - /// - /// - public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past) - { - return NativeHandle.EvalImageEmbed( ctxLlama, imageEmbed, ref n_past ); - } - - /// - public void Dispose() - { - NativeHandle.Dispose(); - } - -} \ No newline at end of file diff --git a/LLama/Native/LLavaImageEmbed.cs b/LLama/Native/LLavaImageEmbed.cs deleted file mode 100644 index 65eba230c..000000000 --- a/LLama/Native/LLavaImageEmbed.cs +++ /dev/null @@ -1,19 +0,0 @@ -namespace LLama.Native; - -/// -/// LLaVa Image embeddings -/// -/// llava_image_embed -[StructLayout(LayoutKind.Sequential)] -public unsafe struct LLavaImageEmbed -{ - /// - /// The embeddings of the embedded image. - /// - public float* embed; - - /// - /// The position of the image's tokens. - /// - public int n_image_pos; -} \ No newline at end of file diff --git a/LLama/Native/Load/NativeLibraryConfig.cs b/LLama/Native/Load/NativeLibraryConfig.cs index 652b1da48..7f250d1c1 100644 --- a/LLama/Native/Load/NativeLibraryConfig.cs +++ b/LLama/Native/Load/NativeLibraryConfig.cs @@ -299,15 +299,15 @@ public sealed partial class NativeLibraryConfig public static NativeLibraryConfig LLama { get; } /// - /// Configuration for LLava native library + /// Configuration for Mtmd native library /// - public static NativeLibraryConfig LLava { get; } + public static NativeLibraryConfig Mtmd { get; } static NativeLibraryConfig() { LLama = new(NativeLibraryName.LLama); - LLava = new(NativeLibraryName.LLava); - All = new(LLama, LLava); + Mtmd = new(NativeLibraryName.Mtmd); + All = new(LLama, Mtmd); } #if NETSTANDARD2_0 @@ -413,9 +413,9 @@ public void ForEach(Action action) /// When this method is called, all the other configurations will be ignored. /// /// The full path to the llama library to load. - /// The full path to the llava library to load. + /// The full path to the mtmd library to load. /// Thrown if `LibraryHasLoaded` is true. - public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llavaPath) + public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? mtmdPath) { foreach(var config in _configs) { @@ -423,9 +423,9 @@ public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llava { config.WithLibrary(llamaPath); } - if(config.NativeLibraryName == NativeLibraryName.LLava && llavaPath is not null) + if(config.NativeLibraryName == NativeLibraryName.Mtmd && mtmdPath is not null) { - config.WithLibrary(llavaPath); + config.WithLibrary(mtmdPath); } } @@ -594,7 +594,7 @@ public NativeLibraryConfigContainer WithLogCallback(ILogger? logger) /// You can still modify the configuration after this calling but only before any call from . /// /// Whether the running is successful. - public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedLLavaNativeLibrary) + public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedMtmdNativeLibrary) { bool success = true; foreach(var config in _configs) @@ -604,16 +604,16 @@ public bool DryRun(out INativeLibrary? 
loadedLLamaNativeLibrary, out INativeLibr { loadedLLamaNativeLibrary = loadedLibrary; } - else if(config.NativeLibraryName == NativeLibraryName.LLava) + else if(config.NativeLibraryName == NativeLibraryName.Mtmd) { - loadedLLavaNativeLibrary = loadedLibrary; + loadedMtmdNativeLibrary = loadedLibrary; } else { throw new Exception("Unknown native library config during the dry run."); } } - loadedLLamaNativeLibrary = loadedLLavaNativeLibrary = null; + loadedLLamaNativeLibrary = loadedMtmdNativeLibrary = null; return success; } } @@ -628,9 +628,9 @@ public enum NativeLibraryName /// LLama, /// - /// The native library compiled from the LLaVA example of llama.cpp. + /// The native library compiled from the MTMD library of llama.cpp. /// - LLava + Mtmd } internal static class LibraryNameExtensions @@ -641,8 +641,8 @@ public static string GetLibraryName(this NativeLibraryName name) { case NativeLibraryName.LLama: return NativeApi.libraryName; - case NativeLibraryName.LLava: - return NativeApi.llavaLibraryName; + case NativeLibraryName.Mtmd: + return NativeApi.mtmdLibraryName; default: throw new ArgumentOutOfRangeException(nameof(name), name, null); } diff --git a/LLama/Native/Load/NativeLibraryUtils.cs b/LLama/Native/Load/NativeLibraryUtils.cs index 9f6457cd1..84ababc60 100644 --- a/LLama/Native/Load/NativeLibraryUtils.cs +++ b/LLama/Native/Load/NativeLibraryUtils.cs @@ -9,7 +9,7 @@ namespace LLama.Native internal static class NativeLibraryUtils { /// - /// Try to load libllama/llava_shared, using CPU feature detection to try and load a more specialised DLL if possible + /// Try to load libllama/mtmd, using CPU feature detection to try and load a more specialised DLL if possible /// /// The library handle to unload later, or IntPtr.Zero if no library was loaded internal static IntPtr TryLoadLibrary(NativeLibraryConfig config, out INativeLibrary? loadedLibrary) diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs new file mode 100644 index 000000000..d83831d85 --- /dev/null +++ b/LLama/Native/MtmdContextParams.cs @@ -0,0 +1,148 @@ +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama.Native; + +/// +/// Managed representation of the native mtmd_context_params structure used to configure multimodal helpers. +/// +public class MtmdContextParams +{ + /// + /// Whether GPU acceleration should be requested when available. + /// + public bool UseGpu { get; set; } + + /// + /// Whether timing information should be emitted by the native helper. + /// + public bool PrintTimings { get; set; } + + /// + /// Number of worker threads to dedicate to preprocessing and tokenization. + /// + public int NThreads { get; set; } + + /// + /// Verbosity level forwarded to llama.cpp logging (matches ggml_log_level). + /// + public int Verbosity { get; set; } + + /// + /// Marker token inserted into the text stream to reference an image embedding. + /// + public string? ImageMarker { get; set; } + + /// + /// Marker token inserted into the text stream to reference a generic media embedding. + /// + public string? MediaMarker { get; set; } + + /// + /// Create a managed copy of the native defaults returned by . 
+ /// + public static MtmdContextParams Default() + { + var native = NativeApi.mtmd_context_params_default(); + return new MtmdContextParams + { + UseGpu = native.use_gpu, + PrintTimings = native.print_timings, + NThreads = native.n_threads, + Verbosity = native.verbosity, + ImageMarker = PtrToString(native.image_marker), + MediaMarker = PtrToString(native.media_marker) + }; + } + + private static string? PtrToString(IntPtr ptr) + { + if (ptr == IntPtr.Zero) + return null; + +#if NETSTANDARD2_0 + unsafe + { + var length = 0; + var current = (byte*)ptr; + while (current[length] != 0) + length++; + + if (length == 0) + return string.Empty; + + var buffer = new byte[length]; + Marshal.Copy(ptr, buffer, 0, length); + return Encoding.UTF8.GetString(buffer); + } +#else + return Marshal.PtrToStringUTF8(ptr); +#endif + } + + /// + /// Convert the managed representation to a native structure, pinning strings for the duration of the scope. + /// + internal NativeScope ToNativeScope() => new(this); + + internal readonly struct NativeScope : IDisposable + { + public NativeApi.mtmd_context_params Value { get; } + + private readonly PinnedUtf8String? _imageMarker; + private readonly PinnedUtf8String? _mediaMarker; + + public NativeScope(MtmdContextParams managed) + { + _imageMarker = PinnedUtf8String.Create(managed.ImageMarker); + _mediaMarker = PinnedUtf8String.Create(managed.MediaMarker); + + var native = NativeApi.mtmd_context_params_default(); + native.use_gpu = managed.UseGpu; + native.print_timings = managed.PrintTimings; + native.n_threads = managed.NThreads; + native.verbosity = managed.Verbosity; + + if (_imageMarker is not null) + native.image_marker = _imageMarker.Pointer; + if (_mediaMarker is not null) + native.media_marker = _mediaMarker.Pointer; + + Value = native; + } + + public void Dispose() + { + _imageMarker?.Dispose(); + _mediaMarker?.Dispose(); + } + } +} + +/// +/// Helper that pins a managed string as UTF-8 for the lifetime of the instance. +/// +internal sealed class PinnedUtf8String : IDisposable +{ + private readonly byte[]? _buffer; + private readonly GCHandle _handle; + + private PinnedUtf8String(string value) + { + var bytes = Encoding.UTF8.GetBytes(value); + _buffer = new byte[bytes.Length + 1]; + Buffer.BlockCopy(bytes, 0, _buffer, 0, bytes.Length); + _handle = GCHandle.Alloc(_buffer, GCHandleType.Pinned); + } + + public static PinnedUtf8String? Create(string? value) => value is null ? null : new PinnedUtf8String(value); + + public IntPtr Pointer => _buffer is null ? IntPtr.Zero : _handle.AddrOfPinnedObject(); + + public void Dispose() + { + if (_buffer is not null && _handle.IsAllocated) + _handle.Free(); + } +} diff --git a/LLama/Native/MtmdImageEmbed.cs b/LLama/Native/MtmdImageEmbed.cs new file mode 100644 index 000000000..7341b8563 --- /dev/null +++ b/LLama/Native/MtmdImageEmbed.cs @@ -0,0 +1,20 @@ +using System.Runtime.InteropServices; + +namespace LLama.Native; + +/// +/// Representation of the native llava_image_embed structure used to return image embeddings. +/// +[StructLayout(LayoutKind.Sequential)] +public unsafe struct MtmdImageEmbed +{ + /// + /// Pointer to the embedding buffer for the decoded image. + /// + public float* embed; + + /// + /// Number of sequence positions consumed by the image tokens associated with the embedding. 
+ /// + public int n_image_pos; +} diff --git a/LLama/Native/NativeApi.LLava.cs b/LLama/Native/NativeApi.LLava.cs deleted file mode 100644 index 692e3f0ad..000000000 --- a/LLama/Native/NativeApi.LLava.cs +++ /dev/null @@ -1,63 +0,0 @@ -using System; - -namespace LLama.Native; - -public static partial class NativeApi -{ - /// - /// Sanity check for clip <-> llava embed size match - /// - /// LLama Context - /// Llava Model - /// True if validate successfully - [DllImport(llavaLibraryName, EntryPoint = "llava_validate_embed_size", CallingConvention = CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llava_validate_embed_size( SafeLLamaContextHandle ctxLlama, SafeLlavaModelHandle ctxClip); - - /// - /// Build an image embed from image file bytes - /// - /// SafeHandle to the Clip Model - /// Number of threads - /// Binary image in jpeg format - /// Bytes length of the image - /// SafeHandle to the Embeddings - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_bytes", - CallingConvention = CallingConvention.Cdecl)] - public static extern - SafeLlavaImageEmbedHandle llava_image_embed_make_with_bytes(SafeLlavaModelHandle ctx_clip, int n_threads, - byte[] image_bytes, int image_bytes_length); - - /// - /// Build an image embed from a path to an image filename - /// - /// SafeHandle to the Clip Model - /// Number of threads - /// Image filename (jpeg) to generate embeddings - /// SafeHandle to the embeddings - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_filename", CallingConvention = CallingConvention.Cdecl)] - public static extern - SafeLlavaImageEmbedHandle llava_image_embed_make_with_filename(SafeLlavaModelHandle ctx_clip, int n_threads, - [MarshalAs(UnmanagedType.LPStr)] string image_path); - - /// - /// Free an embedding made with llava_image_embed_make_* - /// - /// Embeddings to release - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_free", CallingConvention = CallingConvention.Cdecl)] - public static extern void llava_image_embed_free(IntPtr embed); - - /// - /// Write the image represented by embed into the llama context with batch size n_batch, starting at context - /// pos n_past. on completion, n_past points to the next position in the context after the image embed. - /// - /// Llama Context - /// Embedding handle - /// - /// - /// True on success - [DllImport(llavaLibraryName, EntryPoint = "llava_eval_image_embed", CallingConvention = CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed, int n_batch, ref int n_past); - -} \ No newline at end of file diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs index 4555ed0d2..57bb2d146 100644 --- a/LLama/Native/NativeApi.Load.cs +++ b/LLama/Native/NativeApi.Load.cs @@ -16,7 +16,7 @@ static NativeApi() // Set flag to indicate that this point has been passed. No native library config can be done after this point. NativeLibraryConfig.LLama.LibraryHasLoaded = true; - NativeLibraryConfig.LLava.LibraryHasLoaded = true; + NativeLibraryConfig.Mtmd.LibraryHasLoaded = true; // Immediately make a call which requires loading the llama DLL. This method call // can't fail unless the DLL hasn't been loaded. 
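Because the multimodal backend now ships as "mtmd" instead of "llava_shared", applications that pin native binaries explicitly must target the renamed library. A minimal sketch (the paths are placeholders; this must run before the first native call, since LibraryHasLoaded forbids later changes):

    NativeLibraryConfig.All.WithLibrary(
        llamaPath: "/opt/llamasharp/libllama.so",
        mtmdPath: "/opt/llamasharp/libmtmd.so");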
@@ -45,7 +45,7 @@ static NativeApi() #if NET5_0_OR_GREATER private static IntPtr _loadedLlamaHandle; - private static IntPtr _loadedLlavaSharedHandle; + private static IntPtr _loadedMtmdHandle; #endif private static void SetDllImportResolver() @@ -72,15 +72,15 @@ private static void SetDllImportResolver() return _loadedLlamaHandle; } - if (name == "llava_shared") + if (name == "mtmd") { - // If we've already loaded llava return the handle that was loaded last time. - if (_loadedLlavaSharedHandle != IntPtr.Zero) - return _loadedLlavaSharedHandle; + // If we've already loaded Mtmd return the handle that was loaded last time. + if (_loadedMtmdHandle != IntPtr.Zero) + return _loadedMtmdHandle; // Try to load a preferred library, based on CPU feature detection - _loadedLlavaSharedHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.LLava, out _loadedLLavaLibrary); - return _loadedLlavaSharedHandle; + _loadedMtmdHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.Mtmd, out _loadedMtmdLibrary); + return _loadedMtmdHandle; } // Return null pointer to indicate that nothing was loaded. @@ -100,17 +100,17 @@ private static void SetDllImportResolver() return name switch { NativeLibraryName.LLama => _loadedLLamaLibrary, - NativeLibraryName.LLava => _loadedLLavaLibrary, + NativeLibraryName.Mtmd => _loadedMtmdLibrary, _ => throw new ArgumentException($"Library name {name} is not found.") }; } internal const string libraryName = "llama"; - internal const string llavaLibraryName = "llava_shared"; + internal const string mtmdLibraryName = "mtmd"; internal const string ggmlLibraryName = "ggml"; internal const string ggmlBaseLibraryName = "ggml-base"; private static INativeLibrary? _loadedLLamaLibrary = null; - private static INativeLibrary? _loadedLLavaLibrary = null; + private static INativeLibrary? _loadedMtmdLibrary = null; } } diff --git a/LLama/Native/NativeApi.Mtmd.cs b/LLama/Native/NativeApi.Mtmd.cs new file mode 100644 index 000000000..bfd6193c2 --- /dev/null +++ b/LLama/Native/NativeApi.Mtmd.cs @@ -0,0 +1,312 @@ +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama.Native; + +/// +/// P/Invoke surface for MTMD (multimodal) helpers exposed by llama.cpp. +/// +public static partial class NativeApi +{ + /// + /// Convert a UTF-8 encoded native string pointer into a managed . + /// Returns null when the pointer is zero. + /// + public static string? PtrToStringUtf8(IntPtr ptr) + { + if (ptr == IntPtr.Zero) + return null; + +#if NETSTANDARD2_0 + unsafe + { + var current = (byte*)ptr; + var length = 0; + while (current[length] != 0) + length++; + + if (length == 0) + return string.Empty; + + var buffer = new byte[length]; + Marshal.Copy(ptr, buffer, 0, length); + return Encoding.UTF8.GetString(buffer); + } +#else + return Marshal.PtrToStringUTF8(ptr); +#endif + } + + /// + /// Native context parameters returned by . + /// + [StructLayout(LayoutKind.Sequential)] + internal struct mtmd_context_params + { + [MarshalAs(UnmanagedType.I1)] public bool use_gpu; + [MarshalAs(UnmanagedType.I1)] public bool print_timings; + public int n_threads; + public int verbosity; + public IntPtr image_marker; + public IntPtr media_marker; + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_default_marker", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_default_marker(); + + /// + /// Retrieve the default multimodal marker text. + /// + public static string? 
MtmdDefaultMarker() + => PtrToStringUtf8(mtmd_default_marker()); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_context_params_default", CallingConvention = CallingConvention.Cdecl)] + internal static extern mtmd_context_params mtmd_context_params_default(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_decode_use_non_causal", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_decode_use_non_causal(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_decode_use_mrope", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_decode_use_mrope(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_support_vision", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_support_vision(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_support_audio", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_support_audio(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_get_audio_bitrate", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_get_audio_bitrate(IntPtr ctx); + + // bitmap ------------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_init", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_init(uint nx, uint ny, IntPtr data); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_init_from_audio", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_init_from_audio(ulong n_samples, IntPtr data); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_nx", CallingConvention = CallingConvention.Cdecl)] + internal static extern uint mtmd_bitmap_get_nx(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_ny", CallingConvention = CallingConvention.Cdecl)] + internal static extern uint mtmd_bitmap_get_ny(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_data", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_get_data(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_n_bytes", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_bitmap_get_n_bytes(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_is_audio", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_bitmap_is_audio(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_bitmap_free(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_get_id(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_set_id", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe void mtmd_bitmap_set_id_native(IntPtr bitmap, byte* id); + + /// + /// Assign an identifier to a bitmap using a UTF-8 encoded string. + /// + internal static unsafe void mtmd_bitmap_set_id(IntPtr bitmap, string? 
id) + { + if (bitmap == IntPtr.Zero) + throw new ArgumentNullException(nameof(bitmap)); + + if (id is null) + { + mtmd_bitmap_set_id_native(bitmap, null); + return; + } + + using var pinned = PinnedUtf8String.Create(id) ?? throw new ArgumentNullException(nameof(id)); + mtmd_bitmap_set_id_native(bitmap, (byte*)pinned.Pointer); + } + + // input_chunks ------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_init", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunks_init(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_size", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_input_chunks_size(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_get", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunks_get(IntPtr chunks, UIntPtr idx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_input_chunks_free(IntPtr chunks); + + // input_chunk ------------------------------------------------------- + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_type", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_input_chunk_get_type(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_tokens_text", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_tokens_text(IntPtr chunk, out UIntPtr n_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_tokens_image", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_tokens_image(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_input_chunk_get_n_tokens(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_id(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_input_chunk_get_n_pos(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_copy", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_copy(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_input_chunk_free(IntPtr chunk); + + // image_tokens ------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_n_tokens(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_nx", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_nx(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_ny", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_ny(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = 
"mtmd_image_tokens_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_image_tokens_get_id(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_image_tokens_get_n_pos(IntPtr image_tokens); + + // tokenize ---------------------------------------------------------- + + /// + /// Native text structure consumed by . + /// + internal unsafe struct mtmd_input_text_native + { + public byte* text; + [MarshalAs(UnmanagedType.I1)] public bool add_special; + [MarshalAs(UnmanagedType.I1)] public bool parse_special; + } + + /// + /// Utility scope that pins managed text while invoking the native tokenizer. + /// + internal readonly unsafe ref struct MtmdInputTextScope + { + public readonly mtmd_input_text_native Value; + private readonly PinnedUtf8String _text; + + public MtmdInputTextScope(string text, bool addSpecial, bool parseSpecial) + { + _text = PinnedUtf8String.Create(text) ?? throw new ArgumentNullException(nameof(text)); + Value = new mtmd_input_text_native + { + text = (byte*)_text.Pointer, + add_special = addSpecial, + parse_special = parseSpecial + }; + } + + public void Dispose() => _text.Dispose(); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_tokenize", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe int mtmd_tokenize_native( + IntPtr ctx, + IntPtr output, + mtmd_input_text_native* text, + IntPtr[] bitmaps, + UIntPtr n_bitmaps); + + internal static unsafe int mtmd_tokenize(IntPtr ctx, IntPtr output, in mtmd_input_text_native text, IntPtr[] bitmaps, UIntPtr n_bitmaps) + { + var temp = text; + return mtmd_tokenize_native(ctx, output, &temp, bitmaps, n_bitmaps); + } + + internal static unsafe int mtmd_tokenize(IntPtr ctx, IntPtr output, string text, bool addSpecial, bool parseSpecial, IntPtr[] bitmaps, UIntPtr n_bitmaps) + { + using var scope = new MtmdInputTextScope(text, addSpecial, parseSpecial); + return mtmd_tokenize_native(ctx, output, &scope.Value, bitmaps, n_bitmaps); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_encode", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_encode(IntPtr ctx, IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_encode_chunk", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_encode_chunk(IntPtr ctx, IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_get_output_embd", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_get_output_embd(IntPtr ctx); + + // helper ------------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_test_create_input_chunks", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_test_create_input_chunks(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_bitmap_init_from_file", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe IntPtr mtmd_helper_bitmap_init_from_file_native(IntPtr ctx, byte* fname); + + internal static unsafe IntPtr mtmd_helper_bitmap_init_from_file(IntPtr ctx, string fname) + { + using var pinned = PinnedUtf8String.Create(fname) ?? 
throw new ArgumentNullException(nameof(fname)); + return mtmd_helper_bitmap_init_from_file_native(ctx, (byte*)pinned.Pointer); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_bitmap_init_from_buf", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_helper_bitmap_init_from_buf(IntPtr ctx, IntPtr buf, UIntPtr len); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_helper_get_n_tokens(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_helper_get_n_pos(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_eval_chunks", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_eval_chunks( + IntPtr ctx, + IntPtr lctx, + IntPtr chunks, + long n_past, + int seq_id, + int n_batch, + [MarshalAs(UnmanagedType.I1)] bool logits_last, + ref long new_n_past); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_eval_chunk_single", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_eval_chunk_single( + IntPtr ctx, + IntPtr lctx, + IntPtr chunk, + long n_past, + int seq_id, + int n_batch, + [MarshalAs(UnmanagedType.I1)] bool logits_last, + ref long new_n_past); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_decode_image_chunk", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_decode_image_chunk( + IntPtr ctx, + IntPtr lctx, + IntPtr chunk, + IntPtr encoded_embd, + long n_past, + int seq_id, + int n_batch, + ref long new_n_past); +} diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index db9e928bd..3123674fc 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -1,4 +1,5 @@ using System; +using System.Text; #pragma warning disable IDE1006 // Naming Styles @@ -323,21 +324,115 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) /// /// /// Returns the split_path length. - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count); + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_split_path")] + private static extern unsafe int llama_split_path_native(byte* split_path, nuint maxlen, byte* path_prefix, int split_no, int split_count); + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_split_prefix")] + private static extern unsafe int llama_split_prefix_native(byte* split_prefix, nuint maxlen, byte* split_path, int split_no, int split_count); + + private static byte[] EncodeNullTerminatedUtf8(string value, string paramName) + { + if (value is null) + throw new ArgumentNullException(paramName); + + var bytes = Encoding.UTF8.GetBytes(value); + var buffer = new byte[bytes.Length + 1]; + Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length); + // buffer[^1] = 0; + return buffer; + } /// - /// Extract the path prefix from the split_path if and only if the split_no and split_count match. - /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" + /// Build the fully-qualified path for a specific split file in a GGUF shard set. 
        /// </summary>
-        /// <param name="split_prefix"></param>
-        /// <param name="maxlen"></param>
-        /// <param name="split_path"></param>
-        /// <param name="split_no"></param>
-        /// <param name="split_count"></param>
-        /// <returns>Returns the split_prefix length.</returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
+        /// <param name="splitPathBuffer">Writable buffer that receives the UTF-8 encoded path.</param>
+        /// <param name="pathPrefix">Base path (e.g. "/models/ggml-model-q4_0").</param>
+        /// <param name="splitNo">Zero-based split index.</param>
+        /// <param name="splitCount">Total number of splits.</param>
+        /// <returns>Number of bytes written to <paramref name="splitPathBuffer"/>.</returns>
+        public static int llama_split_path(Span<byte> splitPathBuffer, string pathPrefix, int splitNo, int splitCount)
+        {
+            if (splitPathBuffer.Length == 0)
+                throw new ArgumentException("Buffer must not be empty.", nameof(splitPathBuffer));
+
+            var pathPrefixBytes = EncodeNullTerminatedUtf8(pathPrefix, nameof(pathPrefix));
+
+            unsafe
+            {
+                fixed (byte* splitPtr = splitPathBuffer)
+                fixed (byte* prefixPtr = pathPrefixBytes)
+                {
+                    return llama_split_path_native(splitPtr, (nuint)splitPathBuffer.Length, prefixPtr, splitNo, splitCount);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Build the fully-qualified path for a specific split file in a GGUF shard set.
+        /// </summary>
+        /// <param name="pathPrefix">Base path (e.g. "/models/ggml-model-q4_0").</param>
+        /// <param name="splitNo">Zero-based split index.</param>
+        /// <param name="splitCount">Total number of splits.</param>
+        /// <param name="maxLength">Maximum number of bytes to allocate for the resulting UTF-8 string.</param>
+        /// <returns>UTF-8 decoded split path.</returns>
+        public static string llama_split_path(string pathPrefix, int splitNo, int splitCount, int maxLength = 1024)
+        {
+            if (maxLength <= 0)
+                throw new ArgumentOutOfRangeException(nameof(maxLength));
+
+            var buffer = new byte[maxLength];
+            var written = llama_split_path((Span<byte>)buffer, pathPrefix, splitNo, splitCount);
+            if (written <= 0)
+                throw new InvalidOperationException("Failed to build split path using llama_split_path.");
+
+            return Encoding.UTF8.GetString(buffer, 0, written);
+        }
+
+        /// <summary>
+        /// Extract the shard prefix from a GGUF split path when the split metadata matches.
+        /// </summary>
+        /// <param name="splitPrefixBuffer">Writable buffer that receives the UTF-8 encoded prefix.</param>
+        /// <param name="splitPath">Full path to a shard file.</param>
+        /// <param name="splitNo">Zero-based split index.</param>
+        /// <param name="splitCount">Total number of splits.</param>
+        /// <returns>Number of bytes written to <paramref name="splitPrefixBuffer"/>.</returns>
+        public static int llama_split_prefix(Span<byte> splitPrefixBuffer, string splitPath, int splitNo, int splitCount)
+        {
+            if (splitPrefixBuffer.Length == 0)
+                throw new ArgumentException("Buffer must not be empty.", nameof(splitPrefixBuffer));
+
+            var splitPathBytes = EncodeNullTerminatedUtf8(splitPath, nameof(splitPath));
+
+            unsafe
+            {
+                fixed (byte* prefixPtr = splitPrefixBuffer)
+                fixed (byte* pathPtr = splitPathBytes)
+                {
+                    return llama_split_prefix_native(prefixPtr, (nuint)splitPrefixBuffer.Length, pathPtr, splitNo, splitCount);
+                }
+            }
+        }
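+        // Example (sketch), mirroring the upstream llama.h documentation:
+        //   llama_split_prefix("/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4)
+        //     => "/models/ggml-model-q4_0"
+        //   llama_split_path("/models/ggml-model-q4_0", 2, 4)
+        //     => the matching shard path in the same "-%05d-of-%05d.gguf" naming scheme.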
+        /// <summary>
+        /// Extract the shard prefix from a GGUF split path when the split metadata matches.
+        /// </summary>
+        /// <param name="splitPath">Full path to a shard file.</param>
+        /// <param name="splitNo">Zero-based split index.</param>
+        /// <param name="splitCount">Total number of splits.</param>
+        /// <param name="maxLength">Maximum number of bytes to allocate for the resulting UTF-8 string.</param>
+        /// <returns>UTF-8 decoded split prefix.</returns>
+        public static string llama_split_prefix(string splitPath, int splitNo, int splitCount, int maxLength = 1024)
+        {
+            if (maxLength <= 0)
+                throw new ArgumentOutOfRangeException(nameof(maxLength));
+
+            var buffer = new byte[maxLength];
+            var written = llama_split_prefix((Span<byte>)buffer, splitPath, splitNo, splitCount);
+            if (written <= 0)
+                throw new InvalidOperationException("Failed to extract split prefix using llama_split_prefix.");
+
+            return Encoding.UTF8.GetString(buffer, 0, written);
+        }

        //[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        //todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch);
diff --git a/LLama/Native/SafeLlavaImageEmbedHandle.cs b/LLama/Native/SafeLlavaImageEmbedHandle.cs
deleted file mode 100644
index 102c4b93f..000000000
--- a/LLama/Native/SafeLlavaImageEmbedHandle.cs
+++ /dev/null
@@ -1,162 +0,0 @@
-using System;
-using System.IO;
-
-
-namespace LLama.Native
-{
-    /// <summary>
-    /// A Reference to a llava Image Embed handle
-    /// </summary>
-    public sealed class SafeLlavaImageEmbedHandle
-        : SafeLLamaHandleBase
-    {
-        /// <summary>
-        /// Get the model used to create this image embedding
-        /// </summary>
-        public SafeLlavaModelHandle Model { get; private set; } = null!;
-
-        /// <summary>
-        /// Get the number of dimensions in an embedding
-        /// </summary>
-        public int EmbeddingDimensions => Model.EmbeddingDimensions;
-
-        /// <summary>
-        /// Get the number of "patches" in an image embedding
-        /// </summary>
-        public int PatchCount => Model.PatchCount;
-
-        #region embed
-        /// <summary>
-        /// Create an image embed from an image file
-        /// </summary>
-        /// <param name="clip"></param>
-        /// <param name="ctx"></param>
-        /// <param name="image">
-        /// Path to the image file. Supported formats:
-        /// JPG, PNG, BMP, TGA
-        /// </param>
-        /// <exception cref="InvalidOperationException"></exception>
-        public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, LLamaContext ctx, string image)
-        {
-            if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
-                throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");
-
-            return CreateFromFileName(clip, image, (int)ctx.BatchThreads);
-        }
-
-        /// <summary>
-        /// Create an image embed from an image file
-        /// </summary>
-        /// <param name="clip"></param>
-        /// <param name="image">
-        /// Path to the image file. Supported formats:
-        /// JPG, PNG, BMP, TGA
-        /// </param>
-        /// <param name="threads"></param>
-        /// <exception cref="InvalidOperationException"></exception>
-        public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, string image, int threads = -1)
-        {
-            if (threads <= 0)
-                threads = Environment.ProcessorCount / 2;
-
-            // Try to open the image file, this will check:
-            // - File exists (automatically throws FileNotFoundException)
-            // - File is readable (explicit check)
-            // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases.
-            using (var fs = new FileStream(image, FileMode.Open))
-                if (!fs.CanRead)
-                    throw new InvalidOperationException($"Llava image file '{image}' is not readable");
-
-            var embed = NativeApi.llava_image_embed_make_with_filename(clip, threads, image);
-            embed.Model = clip;
-            return embed;
-        }
-
-        /// <summary>
-        /// Create an image embed from the bytes of an image.
-        /// </summary>
-        /// <param name="clip"></param>
-        /// <param name="ctx"></param>
-        /// <param name="image">
-        /// Image bytes. 
Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, LLamaContext ctx, byte[] image) - { - if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip)) - throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})"); - - return CreateFromMemory(clip, image, (int)ctx.BatchThreads); - } - - /// - /// Create an image embed from the bytes of an image. - /// - /// - /// Image bytes. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, byte[] image, int threads = -1) - { - if (threads <= 0) - threads = Environment.ProcessorCount / 2; - - var embed = NativeApi.llava_image_embed_make_with_bytes(clip, threads, image, image.Length); - embed.Model = clip; - return embed; - } - #endregion - - /// - protected override bool ReleaseHandle() - { - NativeApi.llava_image_embed_free(DangerousGetHandle()); - SetHandle(IntPtr.Zero); - return true; - } - - /// - /// Copy the embeddings data to the destination span - /// - /// - /// - public void GetEmbedding(Span dest, int index) - { - if (index < 0) - throw new ArgumentOutOfRangeException(nameof(index), "index must be >= 0"); - if (index >= Model.PatchCount) - throw new ArgumentOutOfRangeException(nameof(index), "index must be < Model.PatchCount"); - - unsafe - { - var embed = (LLavaImageEmbed*)DangerousGetHandle(); - new Span( - embed->embed + Model.EmbeddingDimensions * index, - Model.EmbeddingDimensions - ).CopyTo(dest); - } - } - } -} diff --git a/LLama/Native/SafeLlavaModelHandle.cs b/LLama/Native/SafeLlavaModelHandle.cs deleted file mode 100644 index 5b3a910e9..000000000 --- a/LLama/Native/SafeLlavaModelHandle.cs +++ /dev/null @@ -1,137 +0,0 @@ -using System; -using System.IO; -using LLama.Exceptions; - - -namespace LLama.Native -{ - /// - /// A reference to a set of llava model weights. - /// - public sealed class SafeLlavaModelHandle - : SafeLLamaHandleBase - { - /// - /// Get the number of dimensions in an embedding - /// - public int EmbeddingDimensions => clip_n_mmproj_embd(this); - - /// - /// Get the number of "patches" in an image embedding - /// - public int PatchCount => clip_n_patches(this); - - /// - protected override bool ReleaseHandle() - { - clip_free(DangerousGetHandle()); - SetHandle(IntPtr.Zero); - return true; - } - - /// - /// Load a model from the given file path into memory - /// - /// MMP File (Multi-Modal Projections) - /// Verbosity level - /// SafeHandle of the Clip Model - /// - /// - public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity ) - { - // Try to open the model file, this will check: - // - File exists (automatically throws FileNotFoundException) - // - File is readable (explicit check) - // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. - using (var fs = new FileStream(modelPath, FileMode.Open)) - if (!fs.CanRead) - throw new InvalidOperationException($"Llava MMP Model file '{modelPath}' is not readable"); - - var handle = clip_model_load(modelPath, verbosity); - if (handle.IsInvalid) - throw new LoadWeightsFailedException(modelPath); - - return handle; - } - - /// - /// Create the Image Embeddings. 
- /// - /// LLama Context - /// Image filename (it supports jpeg format only) - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image) - { - return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image); - } - - /// - /// Create the Image Embeddings. - /// - /// Image in binary format (it supports jpeg format only) - /// Number of threads to use - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1) - { - return SafeLlavaImageEmbedHandle.CreateFromFileName(this, image, threads); - } - - /// - /// Create the Image Embeddings. - /// - /// LLama Context - /// Image in binary format (it supports jpeg format only) - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image) - { - return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image ); - } - - /// - /// Create the Image Embeddings. - /// - /// Image in binary format (it supports jpeg format only) - /// Number of threads to use - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1) - { - return SafeLlavaImageEmbedHandle.CreateFromMemory(this, image, threads); - } - - /// - /// Evaluates the image embeddings. - /// - /// Llama Context - /// The current embeddings to evaluate - /// - /// True on success - public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past) - { - return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.BatchSize, ref n_past ); - } - - #region native API - /// - /// Load MULTI MODAL PROJECTIONS model / Clip Model - /// - /// Model path/file - /// Verbosity level - /// SafeLlavaModelHandle - [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_model_load", CallingConvention = CallingConvention.Cdecl)] - private static extern SafeLlavaModelHandle clip_model_load(string mmProj, int verbosity); - - /// - /// Frees MULTI MODAL PROJECTIONS model / Clip Model - /// - /// Internal Pointer to the model - [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)] - private static extern void clip_free(IntPtr ctx); - - [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern int clip_n_mmproj_embd(SafeLlavaModelHandle ctx); - - [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern int clip_n_patches(SafeLlavaModelHandle ctx); - #endregion - } -} diff --git a/LLama/Native/SafeMtmdEmbed.cs b/LLama/Native/SafeMtmdEmbed.cs new file mode 100644 index 000000000..c651db102 --- /dev/null +++ b/LLama/Native/SafeMtmdEmbed.cs @@ -0,0 +1,247 @@ +using System; +using System.IO; +using System.Runtime.InteropServices; + +namespace LLama.Native +{ + /// + /// Managed wrapper around mtmd_bitmap* resources. Instances own the native pointer + /// and ensure proper cleanup when disposed. + /// + public sealed class SafeMtmdEmbed : IDisposable + { + /// + /// Raw pointer to the native bitmap structure. Internal so other wrappers can interop. + /// + internal IntPtr NativePtr { get; private set; } + + private bool _disposed; + + private SafeMtmdEmbed(IntPtr ptr) + { + NativePtr = ptr != IntPtr.Zero + ? 
ptr + : throw new InvalidOperationException("Failed to create MTMD bitmap."); + } + + /// + /// Create an embedding from raw RGB bytes. + /// + /// Width of the bitmap in pixels. + /// Height of the bitmap in pixels. + /// Packed RGB data (3 bytes per pixel). + /// Managed wrapper when initialization succeeds; otherwise null. + /// The RGB buffer is null. + public static SafeMtmdEmbed? FromRgbBytes(uint nx, uint ny, byte[] rgbData) + { + if (rgbData == null) + throw new ArgumentNullException(nameof(rgbData)); + + var handle = GCHandle.Alloc(rgbData, GCHandleType.Pinned); + try + { + var native = NativeApi.mtmd_bitmap_init(nx, ny, handle.AddrOfPinnedObject()); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (handle.IsAllocated) + handle.Free(); + } + } + + /// + /// Create an embedding from PCM audio samples. + /// + /// Array of mono PCM samples in float format. + /// Managed wrapper when initialization succeeds; otherwise null. + /// The audio buffer is null. + public static SafeMtmdEmbed? FromAudioSamples(float[] samples) + { + if (samples == null) + throw new ArgumentNullException(nameof(samples)); + + var handle = GCHandle.Alloc(samples, GCHandleType.Pinned); + try + { + var native = NativeApi.mtmd_bitmap_init_from_audio((ulong)samples.Length, handle.AddrOfPinnedObject()); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (handle.IsAllocated) + handle.Free(); + } + } + + /// + /// Create an embedding by decoding a media file using libmtmd helpers. + /// + /// Model context that provides the decoder configuration. + /// Path to the media file on disk. + /// Managed wrapper when decoding succeeds; otherwise null. + /// The context is null. + /// The path is null or whitespace. + /// The supplied file does not exist. + public static SafeMtmdEmbed? FromMediaFile(SafeMtmdModelHandle mtmdContext, string path) + { + if (mtmdContext == null) + throw new ArgumentNullException(nameof(mtmdContext)); + if (string.IsNullOrWhiteSpace(path)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(path)); + + var fullPath = Path.GetFullPath(path); + if (!File.Exists(fullPath)) + throw new FileNotFoundException("Media file not found.", fullPath); + + bool added = false; + var ctxPtr = IntPtr.Zero; + try + { + // Hold a strong reference to the native context while the helper decodes the media file. + mtmdContext.DangerousAddRef(ref added); + ctxPtr = mtmdContext.DangerousGetHandle(); + var native = NativeApi.mtmd_helper_bitmap_init_from_file(ctxPtr, fullPath); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (added) + mtmdContext.DangerousRelease(); + } + } + + /// + /// Create an embedding from an in-memory media buffer (image/audio/video). + /// + /// Model context that provides the decoder configuration. + /// Binary buffer containing the encoded media. + /// Managed wrapper when decoding succeeds; otherwise null. + /// The context is null. + /// The buffer is empty. + public static unsafe SafeMtmdEmbed? FromMediaBuffer(SafeMtmdModelHandle mtmdContext, ReadOnlySpan data) + { + if (mtmdContext == null) + throw new ArgumentNullException(nameof(mtmdContext)); + if (data.IsEmpty) + throw new ArgumentException("Buffer must not be empty.", nameof(data)); + + bool added = false; + var ctxPtr = IntPtr.Zero; + try + { + // Keep the context alive while the native helper processes the buffer. 
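+                // Note on the SafeHandle pattern used here (added for clarity):
+                // DangerousAddRef marks the handle as referenced (added == true) so the
+                // native context cannot be freed by a concurrent Dispose while the helper
+                // decodes; the finally block releases that reference exactly once.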
+ mtmdContext.DangerousAddRef(ref added); + ctxPtr = mtmdContext.DangerousGetHandle(); + + fixed (byte* bufferPtr = data) + { + var native = NativeApi.mtmd_helper_bitmap_init_from_buf(ctxPtr, new IntPtr(bufferPtr), (UIntPtr)data.Length); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + } + finally + { + if (added) + mtmdContext.DangerousRelease(); + } + } + + /// + /// Width of the bitmap in pixels (or number of samples for audio embeddings). + /// + public uint Nx + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_get_nx(NativePtr); + } + } + + /// + /// Height of the bitmap in pixels. For audio embeddings this is typically 1. + /// + public uint Ny + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_get_ny(NativePtr); + } + } + + /// + /// Indicates whether the embedding stores audio data instead of image pixels. + /// + public bool IsAudio + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_is_audio(NativePtr); + } + } + + /// + /// Optional identifier assigned to this embedding. + /// + public string? Id + { + get + { + EnsureNotDisposed(); + var ptr = NativeApi.mtmd_bitmap_get_id(NativePtr); + return NativeApi.PtrToStringUtf8(ptr); + } + set + { + EnsureNotDisposed(); + NativeApi.mtmd_bitmap_set_id(NativePtr, value); + } + } + + /// + /// Zero-copy access to the underlying bitmap bytes. The span remains valid while this wrapper is alive. + /// + /// Read-only span exposing the native data buffer. + /// The embedding has been disposed. + public unsafe ReadOnlySpan GetDataSpan() + { + EnsureNotDisposed(); + + var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(NativePtr); + var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(NativePtr).ToUInt64()); + return dataPtr == null || length == 0 ? ReadOnlySpan.Empty : new ReadOnlySpan(dataPtr, length); + } + + /// + /// Release the underlying native bitmap. + /// + public void Dispose() + { + if (_disposed) + return; + + if (NativePtr != IntPtr.Zero) + { + NativeApi.mtmd_bitmap_free(NativePtr); + NativePtr = IntPtr.Zero; + } + + _disposed = true; + GC.SuppressFinalize(this); + } + + /// + /// Finalizer to ensure native resources are reclaimed when Dispose is not invoked. + /// + ~SafeMtmdEmbed() => Dispose(); + + private void EnsureNotDisposed() + { + if (_disposed || NativePtr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdEmbed)); + } + } +} diff --git a/LLama/Native/SafeMtmdInputChunk.cs b/LLama/Native/SafeMtmdInputChunk.cs new file mode 100644 index 000000000..59d1897ef --- /dev/null +++ b/LLama/Native/SafeMtmdInputChunk.cs @@ -0,0 +1,150 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native; + +/// +/// Managed wrapper around a single mtmd_input_chunk. Instances can either own the +/// underlying native pointer (when created via ) or act as non-owning views +/// produced by the tokenizer. +/// +public sealed class SafeMtmdInputChunk : IDisposable +{ + /// + /// Chunk modality returned by the native tokenizer. + /// + public enum SafeMtmdInputChunkType + { + Text = 0, + Image = 1, + Audio = 2 + } + + /// + /// Raw pointer to the native chunk structure. + /// + public IntPtr NativePtr { get; private set; } + + private bool _ownsPtr; + private bool _disposed; + + private SafeMtmdInputChunk(IntPtr ptr, bool owns) + { + NativePtr = ptr; + _ownsPtr = owns; + } + + /// + /// Wrap an existing chunk pointer without taking ownership. + /// + /// Pointer returned by the native tokenizer. 
+ /// Managed wrapper, or null when the pointer is null. + public static SafeMtmdInputChunk Wrap(IntPtr ptr) + => ptr == IntPtr.Zero ? null : new SafeMtmdInputChunk(ptr, false); + + /// + /// Create an owning copy of the current chunk. The caller becomes responsible for disposal. + /// + /// Owning managed wrapper, or null if the native copy failed. + /// Thrown when the current wrapper has been disposed. + public SafeMtmdInputChunk Copy() + { + EnsureNotDisposed(); + + var p = NativeApi.mtmd_input_chunk_copy(NativePtr); + return p == IntPtr.Zero ? null : new SafeMtmdInputChunk(p, true); + } + + /// + /// Chunk modality reported by the native helper. + /// + public SafeMtmdInputChunkType Type + { + get + { + EnsureNotDisposed(); + return (SafeMtmdInputChunkType)NativeApi.mtmd_input_chunk_get_type(NativePtr); + } + } + + /// + /// Number of tokens contained in this chunk. + /// + public ulong NTokens + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_input_chunk_get_n_tokens(NativePtr).ToUInt64(); + } + } + + /// + /// Identifier assigned by the tokenizer (if any). + /// + public string Id + { + get + { + EnsureNotDisposed(); + return Marshal.PtrToStringAnsi(NativeApi.mtmd_input_chunk_get_id(NativePtr)) ?? string.Empty; + } + } + + /// + /// Number of positional slots consumed by this chunk. + /// + public long NPos + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_input_chunk_get_n_pos(NativePtr); + } + } + + /// + /// Zero-copy view over the chunk's token buffer. The span remains valid only while the native chunk is alive. + /// + /// Read-only span exposing the chunk's tokens. + /// Thrown when the wrapper has been disposed. + public unsafe ReadOnlySpan GetTextTokensSpan() + { + EnsureNotDisposed(); + + UIntPtr n; + var p = (uint*)NativeApi.mtmd_input_chunk_get_tokens_text(NativePtr, out n); + return p == null ? ReadOnlySpan.Empty : new ReadOnlySpan(p, checked((int)n.ToUInt64())); + } + + /// + /// Release the underlying native resources if this instance owns them. + /// + public void Dispose() + { + if (_disposed) + return; + + if (_ownsPtr && NativePtr != IntPtr.Zero) + { + NativeApi.mtmd_input_chunk_free(NativePtr); + } + + NativePtr = IntPtr.Zero; + _ownsPtr = false; + _disposed = true; + + GC.SuppressFinalize(this); + } + + /// + /// Finalizer to ensure native memory is reclaimed when Dispose is not called by owners. + /// + ~SafeMtmdInputChunk() => Dispose(); + + private void EnsureNotDisposed() + { + if (_disposed || NativePtr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdInputChunk)); + } +} diff --git a/LLama/Native/SafeMtmdInputChunks.cs b/LLama/Native/SafeMtmdInputChunks.cs new file mode 100644 index 000000000..2081cd0a6 --- /dev/null +++ b/LLama/Native/SafeMtmdInputChunks.cs @@ -0,0 +1,103 @@ +using System; +using System.Collections.Generic; + +namespace LLama.Native; + +/// +/// Managed lifetime wrapper around a native mtmd_input_chunks collection returned by the tokenizer. +/// +public sealed class SafeMtmdInputChunks : IDisposable +{ + /// + /// Raw pointer to the native chunk collection. Internal to allow other wrappers to interop safely. + /// + internal IntPtr NativePtr { get; private set; } + + private bool _disposed; + + internal SafeMtmdInputChunks(IntPtr ptr) + { + NativePtr = ptr; + } + + /// + /// Releases the native chunk collection and suppresses finalization. 
+ /// + public void Dispose() + { + if (_disposed) + return; + + if (NativePtr != IntPtr.Zero) + { + NativeApi.mtmd_input_chunks_free(NativePtr); + NativePtr = IntPtr.Zero; + } + + _disposed = true; + GC.SuppressFinalize(this); + } + + /// + /// Finalizer to ensure native memory is reclaimed if Dispose is not called. + /// + ~SafeMtmdInputChunks() + { + Dispose(); + } + + /// + /// Number of chunks currently held by the native collection. + /// + public ulong Size + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_input_chunks_size(NativePtr).ToUInt64(); + } + } + + /// + /// Get a raw pointer to a chunk. The returned is the mtmd_input_chunk*. + /// Use to create a managed wrapper if desired. + /// + /// Zero-based index of the chunk to retrieve. + /// Pointer to the requested chunk. + /// The collection has already been disposed. + /// The requested index is outside of the valid range. + public IntPtr GetChunkPtr(ulong index) + { + EnsureNotDisposed(); + + if (index >= Size) throw new IndexOutOfRangeException(); + return NativeApi.mtmd_input_chunks_get(NativePtr, (UIntPtr)index); + } + + /// + /// Enumerate the contained chunks as non-owning wrappers. Callers should dispose the returned chunk + /// if they create a copy. + /// + /// Enumeration of chunk wrappers backed by the native collection. + /// The collection has already been disposed. + public IEnumerable Enumerate() + { + EnsureNotDisposed(); + + for (ulong i = 0; i < Size; i++) + { + var chunk = SafeMtmdInputChunk.Wrap(GetChunkPtr(i)); + if (chunk != null) + { + // Yield a lightweight wrapper; ownership remains with the native collection. + yield return chunk; + } + } + } + + private void EnsureNotDisposed() + { + if (_disposed || NativePtr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdInputChunks)); + } +} diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs new file mode 100644 index 000000000..236a22011 --- /dev/null +++ b/LLama/Native/SafeMtmdModelHandle.cs @@ -0,0 +1,349 @@ +using System; +using System.Collections.Generic; +using System.IO; +using LLama.Exceptions; + + +namespace LLama.Native +{ + /// + /// Wrapper to the Multi Modal Weights handle. This wrapper manages the low level + /// operations. + /// + public sealed class SafeMtmdModelHandle : SafeLLamaHandleBase + { + // Pending media embeddings queued for the next call to Tokenize. + private readonly List _pendingMedia = new(); + + /// + protected override bool ReleaseHandle() + { + mtmd_free(DangerousGetHandle()); + SetHandle(IntPtr.Zero); + return true; + } + + /// + /// Load a multimodal projection model from disk and bind it to the supplied text model. + /// + /// Path to the MMP (Multi-Modal Projections) file. + /// Text model that provides tokenizer weights for the multimodal helper. + /// Optional context parameters; defaults are used when null. + /// Safe handle for the MTMD model. + /// The file exists but is not readable by the current process. + /// The native loader failed to initialize the MTMD model. + public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) + { + // Try to open the model file, this will check: + // - File exists (automatically throws FileNotFoundException) + // - File is readable (explicit check) + // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. 
+ using (var fs = new FileStream(modelPath, FileMode.Open)) + if (!fs.CanRead) + throw new InvalidOperationException($"Mtmd MMP Model file '{modelPath}' is not readable"); + + using var pathUtf8 = PinnedUtf8String.Create(modelPath) ?? throw new ArgumentNullException(nameof(modelPath)); + + unsafe + { + SafeMtmdModelHandle handle; + if (mtmdCtxParams is null) + { + var nativeParams = NativeApi.mtmd_context_params_default(); + handle = mtmd_init_from_file((byte*)pathUtf8.Pointer, textModel.NativeHandle, nativeParams); + } + else + { + using var nativeParamsScope = mtmdCtxParams.ToNativeScope(); + handle = mtmd_init_from_file((byte*)pathUtf8.Pointer, textModel.NativeHandle, nativeParamsScope.Value); + } + + if (handle.IsInvalid) + throw new LoadWeightsFailedException(modelPath); + + return handle; + } + } + + /// + /// Load media from disk and queue it for the next tokenize call. + /// + /// Absolute or relative path to the media asset. + /// Safe handle to the media embedding. + /// The model handle has been disposed. + /// The native loader failed to ingest the file. + public SafeMtmdEmbed LoadMediaFromFile(string path) + { + EnsureNotDisposed(); + + var embed = SafeMtmdEmbed.FromMediaFile(this, path) + ?? throw new RuntimeError($"Failed to load media '{path}'."); + _pendingMedia.Add(embed); + return embed; + } + + /// + /// Load media from an in-memory buffer and queue it for the next tokenize call. + /// + /// Binary buffer containing the encoded media data. + /// Safe handle to the media embedding. + /// The model handle has been disposed. + /// The native loader failed to ingest the buffer contents. + public SafeMtmdEmbed LoadMediaFromBuffer(ReadOnlySpan buffer) + { + EnsureNotDisposed(); + + var embed = SafeMtmdEmbed.FromMediaBuffer(this, buffer) + ?? throw new RuntimeError("Failed to load media from buffer."); + _pendingMedia.Add(embed); + return embed; + } + + /// + /// Disposes and clears any media buffers currently queued for tokenization. + /// + public void ClearMedia() + { + foreach (var media in _pendingMedia) + media.Dispose(); + _pendingMedia.Clear(); + } + + /// + /// Tokenize a prompt alongside the pending media buffers. Pending media is cleared on success. + /// + /// Prompt text to tokenize. + /// Whether to append special tokens automatically. + /// Whether special tokens should be treated as user-provided text. + /// Receives the native chunk collection when tokenization succeeds. + /// Zero on success; otherwise the native mtmd tokenize error code. + /// The model handle has been disposed. + public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks) + { + EnsureNotDisposed(); + + chunks = null; + // Allocate the chunk container before invoking the native tokenizer. + var output = NativeApi.mtmd_input_chunks_init(); + if (output == IntPtr.Zero) + throw new RuntimeError("Failed to allocate mtmd_input_chunks."); + + // Collect native pointers to the queued media embeddings. 
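+            // mtmd_tokenize matches these bitmaps positionally with the media markers in
+            // the prompt text, so callers should queue one media item per marker.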
+ var bitmapHandles = new IntPtr[_pendingMedia.Count]; + for (var i = 0; i < _pendingMedia.Count; i++) + bitmapHandles[i] = _pendingMedia[i].NativePtr; + + var result = NativeApi.mtmd_tokenize(DangerousGetHandle(), output, text, addSpecial, parseSpecial, bitmapHandles, (UIntPtr)bitmapHandles.Length); + + if (result == 0) + { + chunks = new SafeMtmdInputChunks(output); + foreach (var media in _pendingMedia) + media.Dispose(); + _pendingMedia.Clear(); + } + else + { + NativeApi.mtmd_input_chunks_free(output); + } + + if (result != 0) + { + foreach (var media in _pendingMedia) + media.Dispose(); + _pendingMedia.Clear(); + } + + return result; + } + + /// + /// Evaluate a batch of chunks using the helper (mirrors mtmd-helper eval logic). + /// + /// Chunk collection produced by . + /// Context handle that receives the evaluated tokens. + /// Number of past tokens; updated when evaluation succeeds. + /// Sequence identifier used for KV cache management. + /// Maximum number of tokens to evaluate in a single batch. + /// Whether to request logits for the last token only. + /// Zero on success; otherwise the native helper error code. + /// Thrown when required handles are null. + public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + { + EnsureNotDisposed(); + + if (chunks == null) + throw new ArgumentNullException(nameof(chunks)); + if (llamaContext == null) + throw new ArgumentNullException(nameof(llamaContext)); + + var newNPast = nPast; + var result = NativeApi.mtmd_helper_eval_chunks( + DangerousGetHandle(), + llamaContext.DangerousGetHandle(), + chunks.NativePtr, + nPast, + seqId, + nBatch, + logitsLast, + ref newNPast); + + if (result == 0) + nPast = newNPast; + + return result; + } + + /// + /// Evaluate a single chunk helper. + /// + /// Pointer to the chunk to evaluate. + /// Context handle that receives the evaluated tokens. + /// Number of past tokens; updated when evaluation succeeds. + /// Sequence identifier used for KV cache management. + /// Maximum number of tokens to evaluate in a single batch. + /// Whether to request logits for the last token only. + /// Zero on success; otherwise the native helper error code. + /// Thrown when required handles are null. + public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + { + EnsureNotDisposed(); + + if (chunkPtr == IntPtr.Zero) + throw new ArgumentNullException(nameof(chunkPtr)); + if (llamaContext == null) + throw new ArgumentNullException(nameof(llamaContext)); + + var newNPast = nPast; + var result = NativeApi.mtmd_helper_eval_chunk_single( + DangerousGetHandle(), + llamaContext.DangerousGetHandle(), + chunkPtr, + nPast, + seqId, + nBatch, + logitsLast, + ref newNPast); + + if (result == 0) + nPast = newNPast; + + return result; + } + + /// + /// Decode a prepared image chunk whose embedding is already computed. + /// + /// Pointer to the chunk whose embedding should be decoded. + /// Context handle used for decoding. + /// Pointer to the pre-computed embedding data. + /// Number of past tokens; updated when evaluation succeeds. + /// Sequence identifier used for KV cache management. + /// Maximum number of tokens to evaluate in a single batch. + /// Zero on success; otherwise the native helper error code. + /// Thrown when required handles are null. 
+ public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref long nPast, int seqId, int nBatch) + { + EnsureNotDisposed(); + + if (chunkPtr == IntPtr.Zero) + throw new ArgumentNullException(nameof(chunkPtr)); + + var newNPast = nPast; + var result = NativeApi.mtmd_helper_decode_image_chunk( + DangerousGetHandle(), + llamaContext?.DangerousGetHandle() ?? throw new ArgumentNullException(nameof(llamaContext)), + chunkPtr, + encodedEmbeddings, + nPast, + seqId, + nBatch, + ref newNPast); + + if (result == 0) + nPast = newNPast; + + return result; + } + + /// + /// Get the number of tokens contained in the provided chunk collection. + /// + /// Chunk collection produced by . + /// Total token count. + public ulong CountTokens(SafeMtmdInputChunks chunks) + { + if (chunks == null) + throw new ArgumentNullException(nameof(chunks)); + return NativeApi.mtmd_helper_get_n_tokens(chunks.NativePtr).ToUInt64(); + } + + /// + /// Get the number of positions contained in the provided chunk collection. + /// + /// Chunk collection produced by . + /// Total number of positional slots consumed. + public long CountPositions(SafeMtmdInputChunks chunks) + { + if (chunks == null) + throw new ArgumentNullException(nameof(chunks)); + return NativeApi.mtmd_helper_get_n_pos(chunks.NativePtr); + } + + #region native API + + // mtmd_init_from_file(const char * mmproj_fname, const struct llama_model * text_model, const struct mtmd_context_params ctx_params); + // The llama_model layout is opaque; expose it via SafeLlamaModelHandle to match the managed wrapper. + [DllImport(NativeApi.mtmdLibraryName, EntryPoint = "mtmd_init_from_file", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe SafeMtmdModelHandle mtmd_init_from_file( + byte* mmproj_fname, + SafeLlamaModelHandle text_model, + NativeApi.mtmd_context_params @ctx_params); + + [DllImport(NativeApi.mtmdLibraryName, EntryPoint = "mtmd_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_free(IntPtr ctx); + + #endregion + + + + /// + /// Finalizer to ensure native resources are released if Dispose was not called. + /// + ~SafeMtmdModelHandle() + { + Dispose(); + } + + /// + /// Indicates whether the model decodes using the non-causal path. + /// + public bool DecodeUseNonCausal() => NativeApi.mtmd_decode_use_non_causal(handle); + + /// + /// Indicates whether the model decodes using multi-scale RoPE. + /// + public bool DecodeUseMRope() => NativeApi.mtmd_decode_use_mrope(handle); + + /// + /// Indicates whether the model supports vision inputs. + /// + public bool SupportVision() => NativeApi.mtmd_support_vision(handle); + + /// + /// Indicates whether the model supports audio inputs. + /// + public bool SupportAudio() => NativeApi.mtmd_support_audio(handle); + + /// + /// Gets the audio bitrate advertised by the model. 
+ /// + public int GetAudioBitrate() => NativeApi.mtmd_get_audio_bitrate(handle); + + private void EnsureNotDisposed() + { + if (IsInvalid || IsClosed) + throw new ObjectDisposedException(nameof(SafeMtmdModelHandle)); + } + } +} diff --git a/LLama/Properties/InternalsVisibleTo.cs b/LLama/Properties/InternalsVisibleTo.cs new file mode 100644 index 000000000..b0a1ac4be --- /dev/null +++ b/LLama/Properties/InternalsVisibleTo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("LLama.Unittest")] diff --git a/LLama/SafeMtmdWeights.cs b/LLama/SafeMtmdWeights.cs new file mode 100644 index 000000000..e490049b4 --- /dev/null +++ b/LLama/SafeMtmdWeights.cs @@ -0,0 +1,80 @@ + +using System; +using System.Threading; +using System.Threading.Tasks; +using LLama.Native; + +namespace LLama; + +/// +/// Lightweight wrapper around the MTMD native context and its helpers. +/// +public sealed class SafeMtmdWeights : IDisposable +{ + public SafeMtmdModelHandle NativeHandle { get; } + + private SafeMtmdWeights(SafeMtmdModelHandle handle) + { + NativeHandle = handle ?? throw new ArgumentNullException(nameof(handle)); + } + + public static SafeMtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) + { + if (mmProject == null) throw new ArgumentNullException(nameof(mmProject)); + if (textModel == null) throw new ArgumentNullException(nameof(textModel)); + if (mtmdCtxParams == null) throw new ArgumentNullException(nameof(mtmdCtxParams)); + + var handle = SafeMtmdModelHandle.LoadFromFile(mmProject, textModel, mtmdCtxParams); + return new SafeMtmdWeights(handle); + } + + public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default) + { + return Task.Run(() => LoadFromFile(mmProject, textModel, mtmdCtxParams), token); + } + + /// + /// Load media from disk and keep it pending for the next tokenize call. + /// + public SafeMtmdEmbed LoadMedia(string path) => NativeHandle.LoadMediaFromFile(path); + + /// + /// Load media from an in-memory buffer and keep it pending for the next tokenize call. + /// + public SafeMtmdEmbed LoadMedia(ReadOnlySpan data) => NativeHandle.LoadMediaFromBuffer(data); + + /// + /// Clear any pending media buffers before or after tokenization. + /// + public void ClearMedia() => NativeHandle.ClearMedia(); + + /// + /// Tokenize text (with optional special tokens) against the pending media buffers. + /// + public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks) + => NativeHandle.Tokenize(text, addSpecial, parseSpecial, out chunks); + + /// + /// Evaluate a chunk batch using the helper that performs mtmd encode + llama decode. 
+ /// + public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + => NativeHandle.EvaluateChunks(chunks, llamaContext, ref nPast, seqId, nBatch, logitsLast); + + public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + => NativeHandle.EvaluateChunk(chunkPtr, llamaContext, ref nPast, seqId, nBatch, logitsLast); + + public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref long nPast, int seqId, int nBatch) + => NativeHandle.DecodeImageChunk(chunkPtr, llamaContext, encodedEmbeddings, ref nPast, seqId, nBatch); + + public ulong CountTokens(SafeMtmdInputChunks chunks) => NativeHandle.CountTokens(chunks); + + public long CountPositions(SafeMtmdInputChunks chunks) => NativeHandle.CountPositions(chunks); + + public bool SupportsVision => NativeHandle.SupportVision(); + public bool SupportsAudio => NativeHandle.SupportAudio(); + public bool UsesNonCausalAttention => NativeHandle.DecodeUseNonCausal(); + public bool UsesMRope => NativeHandle.DecodeUseMRope(); + public int AudioBitrate => NativeHandle.GetAudioBitrate(); + + public void Dispose() => NativeHandle.Dispose(); +} diff --git a/docs/Examples/LLavaInteractiveModeExecute.md b/docs/Examples/LLavaInteractiveModeExecute.md deleted file mode 100644 index 2bfbbea1d..000000000 --- a/docs/Examples/LLavaInteractiveModeExecute.md +++ /dev/null @@ -1,129 +0,0 @@ -# LLaVA - basic - -```cs -using System.Text.RegularExpressions; -using LLama.Common; -using Spectre.Console; -using LLama.Native; - -namespace LLama.Examples.Examples -{ - // This example shows how to chat with LLaVA model with both image and text as input. - // It uses the interactive executor to inference. - public class LlavaInteractiveModeExecute - { - public static async Task Run() - { - string multiModalProj = UserSettings.GetMMProjPath(); - string modelPath = UserSettings.GetModelPath(); - string modelImage = UserSettings.GetImagePath(); - const int maxTokens = 1024; - - var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; - - var parameters = new ModelParams(modelPath); - - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - - // Llava Init - using var clipModel = LLavaWeights.LoadFromFile(multiModalProj); - - var ex = new InteractiveExecutor(context, clipModel ); - - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The executor has been enabled. 
In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize ); - Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}."); - - var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List { "\nUSER:" }, MaxTokens = maxTokens }; - - do - { - - // Evaluate if we have images - // - var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imageCount = imageMatches.Count(); - var hasImages = imageCount > 0; - - if (hasImages) - { - var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); - - List imageBytes; - try - { - imageBytes = imagePaths.Select(File.ReadAllBytes).ToList(); - } - catch (IOException exception) - { - Console.ForegroundColor = ConsoleColor.Red; - Console.Write( - $"Could not load your {(imageCount == 1 ? "image" : "images")}:"); - Console.Write($"{exception.Message}"); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("Please try again."); - break; - } - - // Each prompt with images we clear cache - // When the prompt contains images we clear KV_CACHE to restart conversation - // See: - // https://github.com/ggerganov/llama.cpp/discussions/3620 - ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 ); - - int index = 0; - foreach (var path in imagePathsWithCurlyBraces) - { - // First image replace to tag " : ""); - } - - - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message."); - Console.WriteLine(); - - foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes))) - { - consoleImage.MaxWidth = 50; - AnsiConsole.Write(consoleImage); - } - - Console.WriteLine(); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine($"The images were scaled down for the console only, the model gets full versions."); - Console.WriteLine($"Write /exit or press Ctrl+c to return to main menu."); - Console.WriteLine(); - - - // Initialize Images in executor - // - foreach (var image in imagePaths) - { - ex.Images.Add(await File.ReadAllBytesAsync(image)); - } - } - - Console.ForegroundColor = Color.White; - await foreach (var text in ex.InferAsync(prompt, inferenceParams)) - { - Console.Write(text); - } - Console.Write(" "); - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); - Console.WriteLine(); - - // let the user finish with exit - // - if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase)) - break; - - } - while(true); - } - } -} -``` \ No newline at end of file diff --git a/docs/Examples/MtmdInteractiveModeExecute.md b/docs/Examples/MtmdInteractiveModeExecute.md new file mode 100644 index 000000000..378c93a1b --- /dev/null +++ b/docs/Examples/MtmdInteractiveModeExecute.md @@ -0,0 +1,41 @@ +# MTMD interactive mode + +`MtmdInteractiveModeExecute` shows how to pair a multimodal projection with a text model so the chat loop can reason over images supplied at runtime. The sample lives in `LLama.Examples/Examples/MtmdInteractiveModeExecute.cs` and reuses the interactive executor provided by LLamaSharp. + +## Workflow +- Resolve the model, multimodal projection, and sample image paths via `UserSettings`. 
+- Create `ModelParams` for the text model and capture the MTMD defaults with `MtmdContextParams.Default()`.
+- Load the base model and context, then initialize `SafeMtmdWeights` with the multimodal projection file.
+- Ask the helper for the media marker (`mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""`) that stands in for media in the prompt, then construct an `InteractiveExecutor` from the context and clip model, as shown below.
+
+```cs
+var mtmdParameters = MtmdContextParams.Default();
+
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
+using var context = model.CreateContext(parameters);
+
+// Mtmd Init
+using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(
+    multiModalProj,
+    model,
+    mtmdParameters);
+
+var mediaMarker = mtmdParameters.MediaMarker
+    ?? NativeApi.MtmdDefaultMarker()
+    ?? "";
+
+var ex = new InteractiveExecutor(context, clipModel);
+```
+
+## Handling user input
+- Prompts can include image paths wrapped in braces (for example `{c:/image.jpg}`); the loop finds those placeholders with regular expressions.
+- Every referenced file is loaded through `SafeMtmdWeights.LoadMedia`, producing `SafeMtmdEmbed` instances that are queued for the next tokenization call.
+- When the user provides images, the executor clears its KV cache (`MemorySequenceRemove`) before replacing each brace-wrapped path in the prompt with the multimodal marker.
+- The embeds collected for the current turn are copied into `ex.Embeds`, so the executor submits both the text prompt and the pending media to the helper before generation.
+
+## Running the sample
+1. Ensure the model and projection paths returned by `UserSettings` exist locally.
+2. Start the example (for instance from the examples host application) and observe the initial description printed to the console.
+3. Type text normally, or reference new images by including their path inside braces. Type `/exit` to end the conversation.
+
+This walkthrough mirrors the logic in the sample so you can adapt it for your own multimodal workflows.
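+
+## Putting it together
+
+As a starting point for such an adaptation, here is a condensed sketch of one chat turn. It is not code from the sample: it assumes the `ex`, `clipModel`, and `mediaMarker` variables from the setup above, a configured `inferenceParams` instance, and `using System.Text.RegularExpressions;` at the top of the file.
+
+```cs
+// Queue each brace-wrapped image for the next tokenize call, then swap the
+// braces for the MTMD media marker before prompting the executor.
+var prompt = "{c:/image.jpg}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
+
+foreach (Match match in Regex.Matches(prompt, "{([^}]*)}"))
+    ex.Embeds.Add(clipModel.LoadMedia(match.Groups[1].Value));
+
+prompt = Regex.Replace(prompt, "{([^}]*)}", mediaMarker);
+
+await foreach (var token in ex.InferAsync(prompt, inferenceParams))
+    Console.Write(token);
+```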
diff --git a/mkdocs.yml b/mkdocs.yml index 09cb3b96b..fbffdbba7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -38,7 +38,7 @@ nav: - Interactive executor - basic: Examples/InteractiveModeExecute.md - Kernel memory integration - basic: Examples/KernelMemory.md - Kernel-memory - save & load: Examples/KernelMemorySaveAndLoad.md - - LLaVA - basic: Examples/LLavaInteractiveModeExecute.md + - MTMD interactive: Examples/MtmdInteractiveModeExecute.md - ChatSession - load & save: Examples/LoadAndSaveSession.md - Executor - save/load state: Examples/LoadAndSaveState.md - Quantization: Examples/QuantizeModel.md @@ -254,4 +254,4 @@ markdown_extensions: custom_checkbox: true - pymdownx.tilde - pymdownx.tabbed: - alternate_style: true \ No newline at end of file + alternate_style: true From c30700638e8d2c14ac81e83e8629d2e6e3fe53cd Mon Sep 17 00:00:00 2001 From: jlsantiago Date: Mon, 29 Sep 2025 21:56:58 +0200 Subject: [PATCH 2/5] Update LLama/Native/NativeApi.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- LLama/Native/NativeApi.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 3123674fc..0ea46a600 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -338,7 +338,6 @@ private static byte[] EncodeNullTerminatedUtf8(string value, string paramName) var bytes = Encoding.UTF8.GetBytes(value); var buffer = new byte[bytes.Length + 1]; Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length); - // buffer[^1] = 0; return buffer; } From 3c92b0704915d8c96d1988b133dd12ed09e135d9 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Mon, 29 Sep 2025 22:57:09 +0200 Subject: [PATCH 3/5] Resolve comment: https://github.com/SciSharp/LLamaSharp/pull/1261#discussion_r2386165308 --- LLama/Native/MtmdContextParams.cs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs index d83831d85..5b282d802 100644 --- a/LLama/Native/MtmdContextParams.cs +++ b/LLama/Native/MtmdContextParams.cs @@ -138,7 +138,16 @@ private PinnedUtf8String(string value) public static PinnedUtf8String? Create(string? value) => value is null ? null : new PinnedUtf8String(value); - public IntPtr Pointer => _buffer is null ? IntPtr.Zero : _handle.AddrOfPinnedObject(); + public IntPtr Pointer + { + get + { + if (_buffer is null || !_handle.IsAllocated) + return IntPtr.Zero; + + return _handle.AddrOfPinnedObject(); + } + } public void Dispose() { From 384ec34d4c30f8367de14c6dc872f2ce7f64e18b Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sun, 5 Oct 2025 13:47:51 +0200 Subject: [PATCH 4/5] Remove duplicate code --- LLama/Native/SafeMtmdModelHandle.cs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs index 236a22011..86abf8c6c 100644 --- a/LLama/Native/SafeMtmdModelHandle.cs +++ b/LLama/Native/SafeMtmdModelHandle.cs @@ -40,7 +40,7 @@ public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights te // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. using (var fs = new FileStream(modelPath, FileMode.Open)) if (!fs.CanRead) - throw new InvalidOperationException($"Mtmd MMP Model file '{modelPath}' is not readable"); + throw new InvalidOperationException($"Mtmd Model file '{modelPath}' is not readable"); using var pathUtf8 = PinnedUtf8String.Create(modelPath) ?? 
throw new ArgumentNullException(nameof(modelPath)); @@ -138,21 +138,13 @@ public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtm if (result == 0) { chunks = new SafeMtmdInputChunks(output); - foreach (var media in _pendingMedia) - media.Dispose(); - _pendingMedia.Clear(); } else { NativeApi.mtmd_input_chunks_free(output); } - if (result != 0) - { - foreach (var media in _pendingMedia) - media.Dispose(); - _pendingMedia.Clear(); - } + ClearMedia(); return result; } From d5aab128ed3635848ca44e5d673e5bd3df9990f9 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sun, 5 Oct 2025 14:16:49 +0200 Subject: [PATCH 5/5] Move common logic to LlamaExecutorBase --- LLama/LLamaExecutorBase.cs | 195 ++++++++++++++++++++++++++++++ LLama/LLamaInstructExecutor.cs | 154 +----------------------- LLama/LLamaInteractExecutor.cs | 214 ++------------------------------- 3 files changed, 204 insertions(+), 359 deletions(-) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 212194bea..5678945fc 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -81,6 +81,13 @@ public bool IsMultiModal /// public List Embeds { get; } + /// + /// Pending multimodal chunks produced by the MTMD tokenizer. + /// + protected SafeMtmdInputChunks? MtmdChunks { get; set; } + + private string? _mtmdMarker; + private readonly StreamingTokenDecoder _decoder; /// @@ -235,6 +242,194 @@ protected virtual void TryReuseMatchingPrefix() } } + /// + /// Dispose and clear any queued multimodal chunk collection. + /// + protected void DisposeMtmdChunks() + { + MtmdChunks?.Dispose(); + MtmdChunks = null; + } + + /// + /// Dispose and clear any pending multimodal embeddings. + /// + protected void DisposeEmbeds() + { + if (Embeds.Count == 0) + return; + + foreach (var embed in Embeds) + embed.Dispose(); + + Embeds.Clear(); + } + + /// + /// Retrieve the marker token used to signal media segments to the tokenizer. + /// + protected string GetMtmdMarker() + { + if (_mtmdMarker is not null) + return _mtmdMarker; + + _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; + return _mtmdMarker; + } + + /// + /// Ensure the token list fills all positional slots reported by the MTMD helper. + /// + protected static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) + { + if (totalPositions <= tokens.Count) + return new List(tokens); + + var result = new List(totalPositions); + result.AddRange(tokens); + result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); + return result; + } + + /// + /// Resolve the fallback token inserted when the tokenizer emits fewer tokens than positions. + /// + protected LLamaToken GetFillerToken(string marker) + { + var markerTokens = Context.Tokenize(marker, false, true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default; + } + + /// + /// Prepare multimodal inputs by invoking the MTMD tokenizer and aligning filler tokens. 
+ /// + protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting) + { + if (ClipModel is null) + throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); + + DisposeMtmdChunks(); + + var marker = GetMtmdMarker(); + var prompt = text; + + if (Embeds.Count > 0) + { + if (prompt.Contains("")) + prompt = prompt.Replace("", marker); + + if (!prompt.Contains(marker)) + { + var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); + prompt = string.Concat(prompt, suffix); + } + } + + SafeMtmdInputChunks? chunks = null; + try + { + var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); + if (status != 0 || chunks is null) + { + ClipModel.ClearMedia(); + throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); + } + + MtmdChunks = chunks; + + var tokens = new List(); + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + continue; + + foreach (var token in scopedChunk.GetTextTokensSpan()) + tokens.Add(unchecked((int)token)); + } + + var totalPositions = (int)ClipModel.CountPositions(chunks); + var fillerToken = GetFillerToken(marker); + + if (replaceExisting) + { + _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); + _consumedTokensCount = 0; + } + else + { + if (_embed_inps.Count == 0) + _embed_inps = new List(); + + _embed_inps.AddRange(tokens); + var fillerCount = totalPositions - tokens.Count; + if (fillerCount > 0) + _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); + + args.RemainedTokens -= tokens.Count; + } + } + catch + { + chunks?.Dispose(); + MtmdChunks = null; + throw; + } + finally + { + DisposeEmbeds(); + } + + return Task.CompletedTask; + } + + /// + /// Apply bookkeeping after successfully evaluating multimodal chunks. + /// + protected void FinalizeMtmdEvaluation(long newNPast, int previousConsumed) + { + _pastTokensCount = checked((int)newNPast); + DisposeMtmdChunks(); + + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + { + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); + _n_session_consumed = _session_tokens.Count; + } + + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); + } + + /// + /// Evaluate the queued MTMD chunks and update executor state. + /// + protected void EvaluateMtmdChunks(ref long nPast, int previousConsumed, string executorName) + { + if (ClipModel is null) + throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); + if (MtmdChunks is null) + throw new InvalidOperationException("No MTMD chunks are queued for evaluation."); + + var evalStatus = ClipModel.EvaluateChunks(MtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, + nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[{Executor}] Failed to evaluate multimodal chunks. Status: {Status}", executorName, evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); + } + + FinalizeMtmdEvaluation(nPast, previousConsumed); + } + /// /// Determine whether the inference loop should continue processing tokens. 
/// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 2069061d5..35f20b776 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -25,8 +25,6 @@ public class InstructExecutor private readonly string _instructionPrefix; private LLamaToken[] _inp_pfx; private LLamaToken[] _inp_sfx; - private SafeMtmdInputChunks? _mtmdChunks; - private string? _mtmdMarker; private readonly string _instructionSuffix; /// @@ -190,136 +188,6 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args) return Task.CompletedTask; } - private void DisposeMtmdChunks() - { - _mtmdChunks?.Dispose(); - _mtmdChunks = null; - } - - private void DisposeEmbeds() - { - if (Embeds.Count == 0) - return; - - foreach (var embed in Embeds) - embed.Dispose(); - - Embeds.Clear(); - } - - private string GetMtmdMarker() - { - if (_mtmdMarker is not null) - return _mtmdMarker; - - _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; - return _mtmdMarker; - } - - private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) - { - if (totalPositions <= tokens.Count) - return new List(tokens); - - var result = new List(totalPositions); - result.AddRange(tokens); - result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); - return result; - } - - private LLamaToken GetFillerToken(string marker) - { - var markerTokens = Context.Tokenize(marker, false, true); - if (markerTokens.Length > 0) - return markerTokens[markerTokens.Length - 1]; - - var eos = Context.Vocab.EOS; - if (eos.HasValue) - return eos.Value; - - return default(LLamaToken); - } - - private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting) - { - if (ClipModel is null) - throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); - - DisposeMtmdChunks(); - - var marker = GetMtmdMarker(); - var prompt = text; - - if (Embeds.Count > 0) - { - if (prompt.Contains("")) - prompt = prompt.Replace("", marker); - - if (!prompt.Contains(marker)) - { - var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); - prompt = string.Concat(prompt, suffix); - } - } - - SafeMtmdInputChunks? chunks = null; - try - { - var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); - if (status != 0 || chunks is null) - { - ClipModel.ClearMedia(); - throw new RuntimeError($"Failed to tokenize multimodal prompt. 
Status: {status}."); - } - - _mtmdChunks = chunks; - - var tokens = new List(); - foreach (var chunk in chunks.Enumerate()) - { - using var scopedChunk = chunk; - if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) - continue; - - foreach (var token in scopedChunk.GetTextTokensSpan()) - tokens.Add(unchecked((int)token)); - } - - var totalPositions = (int)ClipModel.CountPositions(chunks); - var fillerToken = GetFillerToken(marker); - - if (replaceExisting) - { - _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); - _consumedTokensCount = 0; - } - else - { - if (_embed_inps.Count == 0) - _embed_inps = new List(); - - _embed_inps.AddRange(tokens); - var fillerCount = totalPositions - tokens.Count; - if (fillerCount > 0) - _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); - - args.RemainedTokens -= tokens.Count; - } - } - catch - { - chunks?.Dispose(); - _mtmdChunks = null; - throw; - } - finally - { - DisposeEmbeds(); - } - - return Task.CompletedTask; - } - /// protected override async Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args) { @@ -380,30 +248,12 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In _n_session_consumed = _session_tokens.Count; } } - else if (IsMultiModal && _mtmdChunks is not null) + else if (IsMultiModal && MtmdChunks is not null) { _is_prompt_run = false; var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InstructExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InstructExecutor)); } _embeds.Clear(); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index a6ead60fa..392be783c 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -24,10 +24,6 @@ public class InteractiveExecutor : StatefulExecutorBase // Indicates whether the executor is currently evaluating the initial prompt or a follow-up turn. private bool _is_prompt_run = true; - // MTMD multimodal state - private SafeMtmdInputChunks? _mtmdChunks; // Pending chunk collection produced by the multimodal tokenizer. - private string? _mtmdMarker; // Cached multimodal marker returned by the native helper. - /// /// Create an interactive executor for text-only inference. /// @@ -136,7 +132,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args) } else { - PreprocessMtmd(text, args, true); + return PreprocessMtmd(text, args, addBos: true, replaceExisting: true); } } else @@ -157,168 +153,9 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args) } else { - PreprocessMtmd(text, args, false); - } - } - } - - return Task.CompletedTask; - } - - /// - /// Release any queued multimodal chunks and reset state. 
- /// - private void DisposeMtmdChunks() - { - _mtmdChunks?.Dispose(); - _mtmdChunks = null; - } - - /// - /// Dispose and clear any pending multimodal embeddings queued for evaluation. - /// - private void DisposeEmbeds() - { - if (Embeds.Count == 0) - { - return; - } - - foreach (var embed in Embeds) - { - embed.Dispose(); - } - - Embeds.Clear(); - } - - /// - /// Retrieve the marker token used to signal media segments to the tokenizer. - /// - private string GetMtmdMarker() - { - if (_mtmdMarker is not null) - { - return _mtmdMarker; - } - - _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; - return _mtmdMarker; - } - - private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) - { - if (totalPositions <= tokens.Count) - return new List(tokens); - - var result = new List(totalPositions); - result.AddRange(tokens); - result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); - return result; - } - - private LLamaToken GetFillerToken(string marker) - { - var markerTokens = Context.Tokenize(marker, false, true); - if (markerTokens.Length > 0) - return markerTokens[markerTokens.Length - 1]; - - var eos = Context.Vocab.EOS; - if (eos.HasValue) - return eos.Value; - - return default(LLamaToken); - } - - /// - /// Preprocess multimodal prompts by aligning media markers and tokenizing via MTMD helpers. - /// - /// Prompt text containing optional media markers. - /// Mutable inference state. - /// Whether to treat the prompt as a fresh run and add the BOS token. - private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true) - { - if (ClipModel is null) - { - throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); - } - - DisposeMtmdChunks(); - - var marker = GetMtmdMarker(); - var prompt = text; - - if (Embeds.Count > 0) - { - if (prompt.Contains("")) - { - prompt = prompt.Replace("", marker); - } - - if (!prompt.Contains(marker)) - { - var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); // Ensure tokenizer sees one marker per embed. - prompt = string.Concat(prompt, suffix); - } - } - - SafeMtmdInputChunks? chunks = null; - try - { - var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); - if (status != 0 || chunks is null) - { - ClipModel.ClearMedia(); - throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); - } - - _mtmdChunks = chunks; // Own the chunk collection until evaluation completes. 
- - var tokens = new List(); - foreach (var chunk in chunks.Enumerate()) - { - using var scopedChunk = chunk; - if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) - { - continue; - } - - foreach (var token in scopedChunk.GetTextTokensSpan()) - { - tokens.Add(unchecked((int)token)); + return PreprocessMtmd(text, args, addBos: false, replaceExisting: false); } } - - var totalPositions = (int)ClipModel.CountPositions(chunks); - var fillerToken = GetFillerToken(marker); - - if (addBos) - { - _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); - _consumedTokensCount = 0; - } - else - { - if (_embed_inps.Count == 0) - _embed_inps = new List(); - - _embed_inps.AddRange(tokens); - var fillerCount = totalPositions - tokens.Count; - if (fillerCount > 0) - _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); - - args.RemainedTokens -= tokens.Count; - } - } - catch - { - chunks?.Dispose(); - _mtmdChunks = null; - throw; - } - finally - { - DisposeEmbeds(); // Flush any embeds decoded in prior step; MTMD replays them via chunk eval. } return Task.CompletedTask; @@ -380,35 +217,16 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In HandleRunOutOfContext(tokensToKeep); } - if (_mtmdChunks is null) + if (MtmdChunks is null) { TryReuseMatchingPrefix(); } - if (IsMultiModal && _mtmdChunks is not null) + if (IsMultiModal && MtmdChunks is not null) { var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, - nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor)); } else { @@ -424,30 +242,12 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In } } } - else if (IsMultiModal && _mtmdChunks is not null) + else if (IsMultiModal && MtmdChunks is not null) { _is_prompt_run = false; var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor)); } _embeds.Clear();