From c084cde63c1cbc80f970285b9971e574af0d4d70 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Thu, 6 Nov 2025 19:23:31 +1300 Subject: [PATCH 1/2] Nitro Pipeline --- .../Enums/PipelineType.cs | 1 + .../Models/TransformerNitroModel.cs | 56 +++ .../Pipelines/Nitro/NitroBase.cs | 410 ++++++++++++++++++ .../Pipelines/Nitro/NitroConfig.cs | 128 ++++++ .../Pipelines/Nitro/NitroPipeline.cs | 83 ++++ .../Common/GenerateOptions.cs | 1 + .../Common/GenerateResult.cs | 3 + .../Pipelines/Llama/LlamaConfig.cs | 1 + .../Pipelines/Llama/LlamaPipeline.cs | 15 +- TensorStack.TextGeneration/Sequence.cs | 10 +- 10 files changed, 702 insertions(+), 6 deletions(-) create mode 100644 TensorStack.StableDiffusion/Models/TransformerNitroModel.cs create mode 100644 TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs create mode 100644 TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs create mode 100644 TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs diff --git a/TensorStack.StableDiffusion/Enums/PipelineType.cs b/TensorStack.StableDiffusion/Enums/PipelineType.cs index dab692d..c7aa346 100644 --- a/TensorStack.StableDiffusion/Enums/PipelineType.cs +++ b/TensorStack.StableDiffusion/Enums/PipelineType.cs @@ -11,5 +11,6 @@ public enum PipelineType StableCascade = 10, LatentConsistency = 20, Flux = 30, + Nitro = 40 } } diff --git a/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs b/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs new file mode 100644 index 0000000..27b875c --- /dev/null +++ b/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs @@ -0,0 +1,56 @@ +// Copyright (c) TensorStack. All rights reserved. +// Licensed under the Apache 2.0 License. +using System.Threading; +using System.Threading.Tasks; +using TensorStack.Common; +using TensorStack.Common.Tensor; +using TensorStack.StableDiffusion.Config; + +namespace TensorStack.StableDiffusion.Models +{ + /// + /// TransformerModel: Nitro Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + /// + public class TransformerNitroModel : TransformerModel + { + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + public TransformerNitroModel(TransformerModelConfig configuration) + : base(configuration) { } + + + /// + /// Runs the Transformer model with the specified inputs + /// + /// The timestep. + /// The hidden states. + /// The encoder hidden states. + /// The cancellation token that can be used by other objects or threads to receive notice of cancellation. + /// A Task<Tensor`1> representing the asynchronous operation. + public async Task> RunAsync(int timestep, Tensor hiddenStates, Tensor encoderHiddenStates, CancellationToken cancellationToken = default) + { + if (!Transformer.IsLoaded()) + await Transformer.LoadAsync(cancellationToken: cancellationToken); + + using (var transformerParams = new ModelParameters(Transformer.Metadata, cancellationToken)) + { + // Inputs + transformerParams.AddInput(hiddenStates.AsTensorSpan()); + transformerParams.AddInput(encoderHiddenStates.AsTensorSpan()); + transformerParams.AddScalarInput(timestep); + + // Outputs + transformerParams.AddOutput(hiddenStates.Dimensions); + + // Inference + using (var results = await Transformer.RunInferenceAsync(transformerParams)) + { + return results[0].ToTensor(); + } + } + } + + } +} diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs new file mode 100644 index 0000000..dd93e24 --- /dev/null +++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs @@ -0,0 +1,410 @@ +// Copyright (c) TensorStack. All rights reserved. +// Licensed under the Apache 2.0 License. +using Microsoft.Extensions.Logging; +using Microsoft.ML.Tokenizers; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using TensorStack.Common; +using TensorStack.Common.Tensor; +using TensorStack.StableDiffusion.Common; +using TensorStack.StableDiffusion.Enums; +using TensorStack.StableDiffusion.Models; +using TensorStack.StableDiffusion.Schedulers; +using TensorStack.TextGeneration.Pipelines.Llama; +using TensorStack.TextGeneration.Tokenizers; + +namespace TensorStack.StableDiffusion.Pipelines.Nitro +{ + public abstract class NitroBase : PipelineBase + { + /// + /// Initializes a new instance of the class. + /// + /// The transformer. + /// The text encoder. + /// The automatic encoder. + /// The logger. + public NitroBase(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, ILogger logger = default) : base(logger) + { + Transformer = transformer; + AutoEncoder = autoEncoder; + TextEncoder = textEncoder; + Initialize(); + Logger?.LogInformation("[NitroPipeline] Name: {Name}", Name); + } + + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The logger. + public NitroBase(NitroConfig configuration, ILogger logger = default) : this( + new TransformerNitroModel(configuration.Transformer), + new LlamaPipeline(new LlamaConfig + { + DecoderConfig = configuration.TextEncoder, + Tokenizer = new BPETokenizer(configuration.Tokenizer), + }), + new AutoEncoderModel(configuration.AutoEncoder), + logger) + { + Name = configuration.Name; + } + + /// + /// Gets the type of the pipeline. + /// + public override PipelineType PipelineType => PipelineType.Nitro; + + /// + /// Gets the friendly name. + /// + public override string Name { get; init; } = nameof(PipelineType.Nitro); + + /// + /// Gets the TextEncoder. + /// + public LlamaPipeline TextEncoder { get; init; } + + /// + /// Gets the transformer. + /// + public TransformerNitroModel Transformer { get; init; } + + /// + /// Gets the automatic encoder. + /// + public AutoEncoderModel AutoEncoder { get; init; } + + + /// + /// Loads the pipeline. + /// + /// The cancellation token. + public Task LoadAsync(CancellationToken cancellationToken = default) + { + // Nitro pipelines are lazy loaded on first run + return Task.CompletedTask; + } + + + /// + /// Unloads the pipeline. + /// + /// The cancellation token. + public async Task UnloadAsync(CancellationToken cancellationToken = default) + { + await Task.WhenAll + ( + Transformer.UnloadAsync(), + TextEncoder.UnloadAsync(cancellationToken), + AutoEncoder.EncoderUnloadAsync(), + AutoEncoder.DecoderUnloadAsync() + ); + Logger?.LogInformation("[{PipeLineType}] Pipeline Unloaded", PipelineType); + } + + + /// + /// Validates the options. + /// + /// The options. + protected override void ValidateOptions(GenerateOptions options) + { + base.ValidateOptions(options); + if (!Transformer.HasControlNet && options.HasControlNet) + throw new ArgumentException("Model does not support ControlNet"); + } + + + /// + /// Creates the prompt input embeddings. + /// + /// The options. + /// The cancellation token. + protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) + { + //// Tokenize2 + //var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken); + //var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken); + //var maxTokenLength = (int)Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length); + + //// Tokenizer2 + //var prompt2Tokens = await TokenizePrompt2Async(options.Prompt, cancellationToken); + //var negativePrompt2Tokens = await TokenizePrompt2Async(options.NegativePrompt, cancellationToken); + + //// TextEncoder + //var promptEmbeddings = await EncodePromptAsync(promptTokens, maxTokenLength, cancellationToken); + //var negativePromptEmbeddings = await EncodePromptAsync(negativePromptTokens, maxTokenLength, cancellationToken); + //if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) + // await TextEncoder.UnloadAsync(); + + + + //// Prompt + //var promptEmbeds = prompt2Embeddings.HiddenStates; + //var promptPooledEmbeds = promptEmbeddings.TextEmbeds; + //promptPooledEmbeds = promptPooledEmbeds.Reshape([promptPooledEmbeds.Dimensions[^2], promptPooledEmbeds.Dimensions[^1]]).FirstBatch(); + + //// Negative promt + //var negativePromptEmbeds = negativePrompt2Embeddings.HiddenStates; + //var negativePromptPooledEmbeds = negativePromptEmbeddings.TextEmbeds; + //negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch(); + + //return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds); + + + var result = TextEncoder.RunAsync(new TextGeneration.Common.GenerateOptions + { + Seed = options.Seed, + Prompt = options.Prompt + }); + + return default; + } + + + + + /// + /// Decode the model latents to image + /// + /// The options. + /// The latents. + /// The cancellation token. + protected async Task DecodeLatentsAsync(IPipelineOptions options, Tensor latents, CancellationToken cancellationToken = default) + { + var timestamp = Logger.LogBegin(LogLevel.Debug, "[DecodeLatentsAsync] Begin AutoEncoder Decode"); + var decoderResult = await AutoEncoder.DecodeAsync(latents, cancellationToken: cancellationToken); + if (options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled) + await AutoEncoder.DecoderUnloadAsync(); + + Logger.LogEnd(LogLevel.Debug, timestamp, "[DecodeLatentsAsync] AutoEncoder Decode Complete"); + return decoderResult.AsImageTensor(); + } + + + /// + /// Encode the image to model latents + /// + /// The options. + /// The latents. + /// The cancellation token. + private async Task> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default) + { + var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode"); + var inputTensor = image.ResizeImage(options.Width, options.Height); + var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken); + if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) + await AutoEncoder.EncoderUnloadAsync(); + + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete"); + return encoderResult; + } + + + protected async Task> RunInferenceAsync(IPipelineOptions options, IScheduler scheduler, PromptResult prompt, IProgress progressCallback = null, CancellationToken cancellationToken = default) + { + var timestamp = Logger.LogBegin(LogLevel.Debug, "[RunInferenceAsync] Begin Transformer Inference"); + + // Prompt + var isGuidanceEnabled = IsGuidanceEnabled(options); + var promptEmbedsCond = prompt.PromptEmbeds; + var promptEmbedsUncond = prompt.NegativePromptEmbeds; + + // Latents + var latents = await CreateLatentInputAsync(options, scheduler, cancellationToken); + + // Load Model + await LoadTransformerAsync(options, progressCallback, cancellationToken); + + // Timesteps + var timesteps = scheduler.GetTimesteps(); + for (int i = 0; i < timesteps.Count; i++) + { + var timestep = timesteps[i]; + var steptime = Stopwatch.GetTimestamp(); + cancellationToken.ThrowIfCancellationRequested(); + + // Inputs. + var latentInput = scheduler.ScaleInput(timestep, latents); + + // Inference + var conditional = await Transformer.RunAsync(timestep, latentInput, promptEmbedsCond, cancellationToken: cancellationToken); + if (isGuidanceEnabled) + { + var unconditional = await Transformer.RunAsync(timestep, latentInput, promptEmbedsUncond, cancellationToken: cancellationToken); + conditional = ApplyGuidance(conditional, unconditional, options.GuidanceScale); + } + + // Scheduler + var stepResult = scheduler.Step(timestep, conditional, latents); + + // Result + latents = stepResult.Sample; + + // Progress + if (scheduler.IsFinalOrder) + progressCallback.Notify(scheduler.CurrentStep, scheduler.TotalSteps, latents, steptime); + + Logger.LogEnd(LogLevel.Debug, steptime, $"[RunInferenceAsync] Step: {i + 1}/{timesteps.Count}"); + } + + // Unload + if (options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) + await Transformer.UnloadAsync(); + + Logger.LogEnd(LogLevel.Debug, timestamp, "[RunInferenceAsync] Transformer Inference Complete"); + return latents; + } + + + /// + /// Create latent input. + /// + /// The options. + /// The scheduler. + /// The cancellation token. + private async Task> CreateLatentInputAsync(IPipelineOptions options, IScheduler scheduler, CancellationToken cancellationToken = default) + { + var dimensions = new int[] { 1, AutoEncoder.LatentChannels, options.Height / AutoEncoder.Scale, options.Width / AutoEncoder.Scale }; + var noiseTensor = scheduler.CreateRandomSample(dimensions); + if (options.HasInputImage) + { + var timestep = scheduler.GetStartTimestep(); + var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken); + return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor); + } + return noiseTensor; + } + + + /// + /// Gets the model optimizations. + /// + /// The generate options. + /// The progress callback. + private ModelOptimization GetOptimizations(IPipelineOptions generateOptions, IProgress progressCallback = null) + { + var optimizations = new ModelOptimization(Optimization.None); + if (Transformer.HasOptimizationsChanged(optimizations)) + { + progressCallback.Notify("Optimizing Pipeline..."); + } + return optimizations; + } + + + /// + /// Determines whether classifier-free guidance is enabled + /// + /// The options. + private bool IsGuidanceEnabled(IPipelineOptions options) + { + return options.GuidanceScale > 1; + } + + + /// + /// Load Transformer with optimizations + /// + /// The options. + /// The progress callback. + /// The cancellation token. + private async Task LoadTransformerAsync(IPipelineOptions options, IProgress progressCallback = null, CancellationToken cancellationToken = default) + { + var optimizations = GetOptimizations(options, progressCallback); + return await Transformer.LoadAsync(optimizations, cancellationToken); + } + + + /// + /// Checks the state of the pipeline. + /// + /// The options. + protected override async Task CheckPipelineState(IPipelineOptions options) + { + // Check Transformer/ControlNet status + if (options.HasControlNet && Transformer.IsLoaded()) + await Transformer.UnloadAsync(); + if (!options.HasControlNet && Transformer.IsControlNetLoaded()) + await Transformer.UnloadControlNetAsync(); + + // Check LowMemory status + if ((options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)) // TODO + await TextEncoder.UnloadAsync(); + if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && Transformer.IsLoaded()) + await Transformer.UnloadAsync(); + if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && Transformer.IsControlNetLoaded()) + await Transformer.UnloadControlNetAsync(); + if ((options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) && AutoEncoder.IsEncoderLoaded()) + await AutoEncoder.EncoderUnloadAsync(); + if ((options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled) && AutoEncoder.IsDecoderLoaded()) + await AutoEncoder.DecoderUnloadAsync(); + } + + + /// + /// Configures the supported schedulers. + /// + protected override IReadOnlyList ConfigureSchedulers() + { + return [SchedulerType.FlowMatchEulerDiscrete, SchedulerType.FlowMatchEulerDynamic]; + } + + + /// + /// Configures the default SchedulerOptions. + /// + protected override GenerateOptions ConfigureDefaultOptions() + { + var options = new GenerateOptions + { + Steps = 20, + Shift = 1f, + Width = 1024, + Height = 1024, + GuidanceScale = 0f, + Scheduler = SchedulerType.FlowMatchEulerDiscrete + }; + + // SD3-Turbo Models , 4 Steps, No Guidance + if (Transformer.ModelType == ModelType.Turbo) + { + return options with + { + Steps = 4, + Shift = 1f, + Width = 1024, + Height = 1024, + GuidanceScale = 0, + Scheduler = SchedulerType.FlowMatchEulerDiscrete + }; + } + + return options; + } + + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + private bool _disposed; + protected override void Dispose(bool disposing) + { + if (_disposed) + return; + if (disposing) + { + TextEncoder?.Dispose(); + Transformer?.Dispose(); + AutoEncoder?.Dispose(); + } + _disposed = true; + } + } +} diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs new file mode 100644 index 0000000..387faec --- /dev/null +++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs @@ -0,0 +1,128 @@ +// Copyright (c) TensorStack. All rights reserved. +// Licensed under the Apache 2.0 License. +using System.IO; +using TensorStack.Common; +using TensorStack.StableDiffusion.Config; +using TensorStack.StableDiffusion.Enums; +using TensorStack.TextGeneration.Common; +using TensorStack.TextGeneration.Pipelines.Llama; +using TensorStack.TextGeneration.Tokenizers; + +namespace TensorStack.StableDiffusion.Pipelines.Nitro +{ + public record NitroConfig : PipelineConfig + { + /// + /// Initializes a new instance of the class. + /// + public NitroConfig() + { + Tokenizer = new TokenizerConfig + { + BOS = 128000, + EOS = 128001 + }; + TextEncoder = new DecoderConfig + { + NumHeads = 32, + NumLayers = 16, + HiddenSize = 2048, + NumKVHeads = 8, + VocabSize = 128256 + }; + Transformer = new TransformerModelConfig + { + InChannels = 32, + OutChannels = 32, + JointAttention = 2048, + IsOptimizationSupported = true + }; + AutoEncoder = new AutoEncoderModelConfig + { + LatentChannels = 32, + ScaleFactor = 0.3611f, + ShiftFactor = 0.1159f + }; + } + + public string Name { get; init; } = "Nitro"; + public override PipelineType Pipeline { get; } = PipelineType.Nitro; + public TokenizerConfig Tokenizer { get; init; } + public DecoderConfig TextEncoder { get; init; } + public TransformerModelConfig Transformer { get; init; } + public AutoEncoderModelConfig AutoEncoder { get; init; } + + + /// + /// Sets the execution provider for all models. + /// + /// The execution provider. + public override void SetProvider(ExecutionProvider executionProvider) + { + TextEncoder.SetProvider(executionProvider); + Transformer.SetProvider(executionProvider); + AutoEncoder.SetProvider(executionProvider); + } + + + /// + /// Saves the configuration to file. + /// + /// The configuration file. + /// if set to true use relative paths. + public override void Save(string configFile, bool useRelativePaths = true) + { + ConfigService.Serialize(configFile, this, useRelativePaths); + } + + + /// + /// Create Nitro configuration from default values + /// + /// The name. + /// Type of the model. + /// The execution provider. + /// NitroConfig. + public static NitroConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default) + { + var config = new NitroConfig { Name = name }; + config.Transformer.ModelType = modelType; + config.SetProvider(executionProvider); + return config; + } + + + /// + /// Create StableDiffusionv configuration from json file + /// + /// The configuration file. + /// The execution provider. + /// NitroConfig. + public static NitroConfig FromFile(string configFile, ExecutionProvider executionProvider = default) + { + var config = ConfigService.Deserialize(configFile); + config.SetProvider(executionProvider); + return config; + } + + + /// + /// Create Nitro configuration from folder structure + /// + /// The model folder. + /// Type of the model. + /// The execution provider. + /// NitroConfig. + public static NitroConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default) + { + var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider); + config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json"); + config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx"); + config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx"); + config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx"); + config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx"); + return config; + } + + } +} diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs new file mode 100644 index 0000000..de5e053 --- /dev/null +++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs @@ -0,0 +1,83 @@ +// Copyright (c) TensorStack. All rights reserved. +// Licensed under the Apache 2.0 License. +using Microsoft.Extensions.Logging; +using System; +using System.Threading; +using System.Threading.Tasks; +using TensorStack.Common; +using TensorStack.Common.Pipeline; +using TensorStack.Common.Tensor; +using TensorStack.StableDiffusion.Common; +using TensorStack.StableDiffusion.Enums; +using TensorStack.StableDiffusion.Models; +using TensorStack.TextGeneration.Pipelines.Llama; + +namespace TensorStack.StableDiffusion.Pipelines.Nitro +{ + public class NitroPipeline : NitroBase, IPipeline + { + /// + /// Initializes a new instance of the class. + /// + /// The transformer. + /// The text encoder. + /// The automatic encoder. + /// The logger. + public NitroPipeline(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, ILogger logger = null) + : base(transformer, textEncoder, autoEncoder, logger) { } + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The logger. + public NitroPipeline(NitroConfig configuration, ILogger logger = null) + : base(configuration, logger) { } + + + /// + /// Run ImageTensor pipeline. + /// + /// The options. + /// The progress callback. + /// The cancellation token. + public async Task RunAsync(GenerateOptions options, IProgress progressCallback = null, CancellationToken cancellationToken = default) + { + ValidateOptions(options); + + var prompt = await CreatePromptAsync(options, cancellationToken); + using (var scheduler = CreateScheduler(options)) + { + var latents = await RunInferenceAsync(options, scheduler, prompt, progressCallback, cancellationToken); + return await DecodeLatentsAsync(options, latents, cancellationToken); + } + } + + + /// + /// Create Nitro pipeline from StableDiffusionConfig file + /// + /// The configuration file. + /// The execution provider. + /// The logger. + /// NitroPipeline. + public static NitroPipeline FromConfig(string configFile, ExecutionProvider executionProvider, ILogger logger = default) + { + return new NitroPipeline(NitroConfig.FromFile(configFile, executionProvider), logger); + } + + + /// + /// Create Nitro pipeline from folder structure + /// + /// The model folder. + /// Type of the model. + /// The execution provider. + /// The logger. + /// NitroPipeline. + public static NitroPipeline FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default) + { + return new NitroPipeline(NitroConfig.FromFolder(modelFolder, modelType, executionProvider), logger); + } + } +} diff --git a/TensorStack.TextGeneration/Common/GenerateOptions.cs b/TensorStack.TextGeneration/Common/GenerateOptions.cs index 2ec2d21..c977291 100644 --- a/TensorStack.TextGeneration/Common/GenerateOptions.cs +++ b/TensorStack.TextGeneration/Common/GenerateOptions.cs @@ -19,5 +19,6 @@ public record GenerateOptions : IRunOptions public float LengthPenalty { get; set; } = 1.0f; public EarlyStopping EarlyStopping { get; set; } public int DiversityLength { get; set; } = 20; + public bool OutputLastHiddenStates { get; set; } } } diff --git a/TensorStack.TextGeneration/Common/GenerateResult.cs b/TensorStack.TextGeneration/Common/GenerateResult.cs index ca98498..46a379d 100644 --- a/TensorStack.TextGeneration/Common/GenerateResult.cs +++ b/TensorStack.TextGeneration/Common/GenerateResult.cs @@ -1,6 +1,7 @@ // Copyright (c) TensorStack. All rights reserved. // Licensed under the Apache 2.0 License. using System.Collections.Generic; +using TensorStack.Common.Tensor; using TensorStack.Common.Vision; namespace TensorStack.TextGeneration.Common @@ -12,5 +13,7 @@ public class GenerateResult public string Result { get; set; } public float PenaltyScore { get; set; } public List CoordinateResults { get; set; } + public Tensor LastHiddenState { get; set; } + public IReadOnlyList Tokens { get; set; } } } diff --git a/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs b/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs index 4cd27ef..1ad1281 100644 --- a/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs +++ b/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs @@ -4,5 +4,6 @@ namespace TensorStack.TextGeneration.Pipelines.Llama { public record LlamaConfig : TransformerConfig { + public bool OutputLastHiddenStates { get; set; } } } diff --git a/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs b/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs index dded84f..aae9cd6 100644 --- a/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs +++ b/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs @@ -49,7 +49,9 @@ public virtual async Task RunAsync(GenerateOptions options, IPro return new GenerateResult { Score = sequence.Score, - Result = Tokenizer.Decode(sequence.Tokens) + Result = Tokenizer.Decode(sequence.Tokens), + Tokens = sequence.Tokens, + LastHiddenState = sequence.LastHiddenState }; } } @@ -77,7 +79,9 @@ public async Task RunAsync(SearchOptions options, IProgress RunDecoderInternal(ModelMetadata modelMetadata, Sequence s { var dimension = logitsResult.GetDimensions(); var logits = logitsResult.ToTensor(dimension[1..]); - var presentKeyValues = modelResult.ToArray()[1..]; - - sequence.UpdateCache(presentKeyValues, useBranchCache); + var lastHiddenState = Configuration.OutputLastHiddenStates ? modelResult[^1].ToTensor() : default; + var presentKeyValues = Configuration.OutputLastHiddenStates ? modelResult.ToArray()[1..^1] : modelResult.ToArray()[1..]; + sequence.UpdateCache(presentKeyValues, useBranchCache, lastHiddenState); return logits; } } @@ -200,6 +204,7 @@ public static LlamaPipeline Create(ExecutionProvider provider, string modelPath, var vocabSize = 128256; var config = new LlamaConfig { + OutputLastHiddenStates = true, Tokenizer = new BPETokenizer(new TokenizerConfig { BOS = 128000, diff --git a/TensorStack.TextGeneration/Sequence.cs b/TensorStack.TextGeneration/Sequence.cs index cfc321c..524f4cc 100644 --- a/TensorStack.TextGeneration/Sequence.cs +++ b/TensorStack.TextGeneration/Sequence.cs @@ -3,6 +3,7 @@ using Microsoft.ML.OnnxRuntime; using System; using System.Collections.Generic; +using TensorStack.Common.Tensor; using TensorStack.TextGeneration.Cache; namespace TensorStack.TextGeneration @@ -10,6 +11,7 @@ namespace TensorStack.TextGeneration public sealed class Sequence : IDisposable { private IKVCache _cache; + private Tensor _lastHiddenState; /// /// Initializes a new instance of the class. @@ -55,6 +57,11 @@ private Sequence(List tokens, float score, IKVCache cache) /// public OrtValue[] Cache => _cache.Values; + /// + /// Gets the LastHiddenState. + /// + public Tensor LastHiddenState => _lastHiddenState; + /// /// Gets or sets the sequnece score. /// @@ -89,8 +96,9 @@ public bool Initialize(int initialLength) /// /// The current values. /// if set to true use branch cache. - public void UpdateCache(OrtValue[] currentValues, bool useBranchCache) + public void UpdateCache(OrtValue[] currentValues, bool useBranchCache, Tensor lastHiddenState = default) { + _lastHiddenState = lastHiddenState; _cache.Update(currentValues, useBranchCache); } From 1047b11d21c2f2ebf090ab0e9281f5c5a6414c16 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Fri, 7 Nov 2025 09:32:36 +1300 Subject: [PATCH 2/2] Nitro TextEncoder --- .../Pipelines/Nitro/NitroBase.cs | 60 +++++++------------ .../Pipelines/Nitro/NitroConfig.cs | 8 +-- .../Pipelines/Nitro/NitroPipeline.cs | 1 - .../Pipelines/Llama/LlamaPipeline.cs | 15 +++++ 4 files changed, 39 insertions(+), 45 deletions(-) diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs index dd93e24..8081d5d 100644 --- a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs +++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs @@ -1,7 +1,6 @@ // Copyright (c) TensorStack. All rights reserved. // Licensed under the Apache 2.0 License. using Microsoft.Extensions.Logging; -using Microsoft.ML.Tokenizers; using System; using System.Collections.Generic; using System.Diagnostics; @@ -21,7 +20,7 @@ namespace TensorStack.StableDiffusion.Pipelines.Nitro public abstract class NitroBase : PipelineBase { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The transformer. /// The text encoder. @@ -46,6 +45,7 @@ public NitroBase(NitroConfig configuration, ILogger logger = default) : this( new TransformerNitroModel(configuration.Transformer), new LlamaPipeline(new LlamaConfig { + OutputLastHiddenStates = true, DecoderConfig = configuration.TextEncoder, Tokenizer = new BPETokenizer(configuration.Tokenizer), }), @@ -128,46 +128,28 @@ protected override void ValidateOptions(GenerateOptions options) /// The cancellation token. protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { - //// Tokenize2 - //var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken); - //var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken); - //var maxTokenLength = (int)Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length); - - //// Tokenizer2 - //var prompt2Tokens = await TokenizePrompt2Async(options.Prompt, cancellationToken); - //var negativePrompt2Tokens = await TokenizePrompt2Async(options.NegativePrompt, cancellationToken); - - //// TextEncoder - //var promptEmbeddings = await EncodePromptAsync(promptTokens, maxTokenLength, cancellationToken); - //var negativePromptEmbeddings = await EncodePromptAsync(negativePromptTokens, maxTokenLength, cancellationToken); - //if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) - // await TextEncoder.UnloadAsync(); - - - - //// Prompt - //var promptEmbeds = prompt2Embeddings.HiddenStates; - //var promptPooledEmbeds = promptEmbeddings.TextEmbeds; - //promptPooledEmbeds = promptPooledEmbeds.Reshape([promptPooledEmbeds.Dimensions[^2], promptPooledEmbeds.Dimensions[^1]]).FirstBatch(); - - //// Negative promt - //var negativePromptEmbeds = negativePrompt2Embeddings.HiddenStates; - //var negativePromptPooledEmbeds = negativePromptEmbeddings.TextEmbeds; - //negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch(); - - //return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds); - - - var result = TextEncoder.RunAsync(new TextGeneration.Common.GenerateOptions + // Conditional Prompt + var promptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions { Seed = options.Seed, - Prompt = options.Prompt - }); - - return default; - } + Prompt = options.Prompt, + MaxLength = 128 + }, cancellationToken); + // Unconditional prompt + var negativePromptEmbeds = default(Tensor); + if (!string.IsNullOrEmpty(options.NegativePrompt)) + { + negativePromptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions + { + Seed = options.Seed, + Prompt = options.NegativePrompt, + MaxLength = 128 + }, cancellationToken); + } + return new PromptResult(promptEmbeds, default, negativePromptEmbeds, default); + } /// @@ -372,7 +354,7 @@ protected override GenerateOptions ConfigureDefaultOptions() Scheduler = SchedulerType.FlowMatchEulerDiscrete }; - // SD3-Turbo Models , 4 Steps, No Guidance + // Nitro-Distilled Models ,4 Steps, No Guidance if (Transformer.ModelType == ModelType.Turbo) { return options with diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs index 387faec..e761a3c 100644 --- a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs +++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs @@ -5,7 +5,6 @@ using TensorStack.StableDiffusion.Config; using TensorStack.StableDiffusion.Enums; using TensorStack.TextGeneration.Common; -using TensorStack.TextGeneration.Pipelines.Llama; using TensorStack.TextGeneration.Tokenizers; namespace TensorStack.StableDiffusion.Pipelines.Nitro @@ -26,8 +25,8 @@ public NitroConfig() { NumHeads = 32, NumLayers = 16, - HiddenSize = 2048, NumKVHeads = 8, + HiddenSize = 2048, VocabSize = 128256 }; Transformer = new TransformerModelConfig @@ -40,8 +39,7 @@ public NitroConfig() AutoEncoder = new AutoEncoderModelConfig { LatentChannels = 32, - ScaleFactor = 0.3611f, - ShiftFactor = 0.1159f + ScaleFactor = 0.41407f }; } @@ -116,7 +114,7 @@ public static NitroConfig FromFile(string configFile, ExecutionProvider executio public static NitroConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default) { var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider); - config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json"); + config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer"); config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx"); config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx"); config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx"); diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs index de5e053..28bd624 100644 --- a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs +++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs @@ -44,7 +44,6 @@ public NitroPipeline(NitroConfig configuration, ILogger logger = null) public async Task RunAsync(GenerateOptions options, IProgress progressCallback = null, CancellationToken cancellationToken = default) { ValidateOptions(options); - var prompt = await CreatePromptAsync(options, cancellationToken); using (var scheduler = CreateScheduler(options)) { diff --git a/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs b/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs index aae9cd6..1623fd1 100644 --- a/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs +++ b/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs @@ -89,6 +89,21 @@ public async Task RunAsync(SearchOptions options, IProgress + /// Gets the LastHiddenState. + /// + /// The options. + /// The cancellation token. + public async Task> GetLastHiddenState(GenerateOptions options, CancellationToken cancellationToken = default) + { + await TokenizePromptAsync(options); + using (var sequence = await InitializeAsync(options)) + { + return sequence.LastHiddenState; + } + } + + /// /// Gets the token processors. ///