diff --git a/TensorStack.StableDiffusion/Enums/PipelineType.cs b/TensorStack.StableDiffusion/Enums/PipelineType.cs
index dab692d..c7aa346 100644
--- a/TensorStack.StableDiffusion/Enums/PipelineType.cs
+++ b/TensorStack.StableDiffusion/Enums/PipelineType.cs
@@ -11,5 +11,6 @@ public enum PipelineType
         StableCascade = 10,
         LatentConsistency = 20,
         Flux = 30,
+        Nitro = 40
     }
 }
diff --git a/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs b/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs
new file mode 100644
index 0000000..27b875c
--- /dev/null
+++ b/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs
@@ -0,0 +1,56 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+using System.Threading;
+using System.Threading.Tasks;
+using TensorStack.Common;
+using TensorStack.Common.Tensor;
+using TensorStack.StableDiffusion.Config;
+
+namespace TensorStack.StableDiffusion.Models
+{
+    /// <summary>
+    /// TransformerModel: Nitro Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+    /// </summary>
+    public class TransformerNitroModel : TransformerModel
+    {
+        /// <summary>
+        /// Initializes a new instance of the <see cref="TransformerNitroModel"/> class.
+        /// </summary>
+        /// <param name="configuration">The configuration.</param>
+        public TransformerNitroModel(TransformerModelConfig configuration)
+            : base(configuration) { }
+
+
+        /// <summary>
+        /// Runs the Transformer model with the specified inputs.
+        /// </summary>
+        /// <param name="timestep">The timestep.</param>
+        /// <param name="hiddenStates">The hidden states.</param>
+        /// <param name="encoderHiddenStates">The encoder hidden states.</param>
+        /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
+        /// <returns>A Task&lt;Tensor&lt;float&gt;&gt; representing the asynchronous operation.</returns>
+        public async Task<Tensor<float>> RunAsync(int timestep, Tensor<float> hiddenStates, Tensor<float> encoderHiddenStates, CancellationToken cancellationToken = default)
+        {
+            if (!Transformer.IsLoaded())
+                await Transformer.LoadAsync(cancellationToken: cancellationToken);
+
+            using (var transformerParams = new ModelParameters(Transformer.Metadata, cancellationToken))
+            {
+                // Inputs
+                transformerParams.AddInput(hiddenStates.AsTensorSpan());
+                transformerParams.AddInput(encoderHiddenStates.AsTensorSpan());
+                transformerParams.AddScalarInput(timestep);
+
+                // Outputs
+                transformerParams.AddOutput(hiddenStates.Dimensions);
+
+                // Inference
+                using (var results = await Transformer.RunInferenceAsync(transformerParams))
+                {
+                    return results[0].ToTensor();
+                }
+            }
+        }
+
+    }
+}
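Note: the input order registered above (latents, prompt embeddings, then the scalar timestep) and the output shape (same as the latent input) must match the exported ONNX graph. A minimal single-step call is sketched below; the tensor shapes and the Tensor<float> dimension constructor are assumptions based on the NitroConfig defaults in this diff (32 latent channels, 2048-wide hidden states, 128-token prompts), not confirmed values:

    // Sketch: one denoise step (shapes are assumed, not taken from this diff)
    var transformer = new TransformerNitroModel(config.Transformer);
    var latents = new Tensor<float>(new[] { 1, 32, 128, 128 });    // [batch, LatentChannels, h, w]
    var promptEmbeds = new Tensor<float>(new[] { 1, 128, 2048 });  // Llama last hidden state
    var noisePrediction = await transformer.RunAsync(999, latents, promptEmbeds);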
diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs
new file mode 100644
index 0000000..8081d5d
--- /dev/null
+++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs
@@ -0,0 +1,392 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+using Microsoft.Extensions.Logging;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
+using TensorStack.Common;
+using TensorStack.Common.Tensor;
+using TensorStack.StableDiffusion.Common;
+using TensorStack.StableDiffusion.Enums;
+using TensorStack.StableDiffusion.Models;
+using TensorStack.StableDiffusion.Schedulers;
+using TensorStack.TextGeneration.Pipelines.Llama;
+using TensorStack.TextGeneration.Tokenizers;
+
+namespace TensorStack.StableDiffusion.Pipelines.Nitro
+{
+    public abstract class NitroBase : PipelineBase
+    {
+        /// <summary>
+        /// Initializes a new instance of the <see cref="NitroBase"/> class.
+        /// </summary>
+        /// <param name="transformer">The transformer.</param>
+        /// <param name="textEncoder">The text encoder.</param>
+        /// <param name="autoEncoder">The auto encoder.</param>
+        /// <param name="logger">The logger.</param>
+        public NitroBase(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, ILogger logger = default) : base(logger)
+        {
+            Transformer = transformer;
+            AutoEncoder = autoEncoder;
+            TextEncoder = textEncoder;
+            Initialize();
+            Logger?.LogInformation("[NitroPipeline] Name: {Name}", Name);
+        }
+
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="NitroBase"/> class.
+        /// </summary>
+        /// <param name="configuration">The configuration.</param>
+        /// <param name="logger">The logger.</param>
+        public NitroBase(NitroConfig configuration, ILogger logger = default) : this(
+            new TransformerNitroModel(configuration.Transformer),
+            new LlamaPipeline(new LlamaConfig
+            {
+                OutputLastHiddenStates = true,
+                DecoderConfig = configuration.TextEncoder,
+                Tokenizer = new BPETokenizer(configuration.Tokenizer),
+            }),
+            new AutoEncoderModel(configuration.AutoEncoder),
+            logger)
+        {
+            Name = configuration.Name;
+        }
+
+        /// <summary>
+        /// Gets the type of the pipeline.
+        /// </summary>
+        public override PipelineType PipelineType => PipelineType.Nitro;
+
+        /// <summary>
+        /// Gets the friendly name.
+        /// </summary>
+        public override string Name { get; init; } = nameof(PipelineType.Nitro);
+
+        /// <summary>
+        /// Gets the TextEncoder.
+        /// </summary>
+        public LlamaPipeline TextEncoder { get; init; }
+
+        /// <summary>
+        /// Gets the Transformer.
+        /// </summary>
+        public TransformerNitroModel Transformer { get; init; }
+
+        /// <summary>
+        /// Gets the AutoEncoder.
+        /// </summary>
+        public AutoEncoderModel AutoEncoder { get; init; }
+
+
+        /// <summary>
+        /// Loads the pipeline.
+        /// </summary>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        public Task LoadAsync(CancellationToken cancellationToken = default)
+        {
+            // Nitro pipelines are lazy loaded on first run
+            return Task.CompletedTask;
+        }
+
+
+        /// <summary>
+        /// Unloads the pipeline.
+        /// </summary>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        public async Task UnloadAsync(CancellationToken cancellationToken = default)
+        {
+            await Task.WhenAll
+            (
+                Transformer.UnloadAsync(),
+                TextEncoder.UnloadAsync(cancellationToken),
+                AutoEncoder.EncoderUnloadAsync(),
+                AutoEncoder.DecoderUnloadAsync()
+            );
+            Logger?.LogInformation("[{PipeLineType}] Pipeline Unloaded", PipelineType);
+        }
+
+
+        /// <summary>
+        /// Validates the options.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        protected override void ValidateOptions(GenerateOptions options)
+        {
+            base.ValidateOptions(options);
+            if (!Transformer.HasControlNet && options.HasControlNet)
+                throw new ArgumentException("Model does not support ControlNet");
+        }
+
+
+        /// <summary>
+        /// Creates the prompt input embeddings.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
+        {
+            // Conditional Prompt
+            var promptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
+            {
+                Seed = options.Seed,
+                Prompt = options.Prompt,
+                MaxLength = 128
+            }, cancellationToken);
+
+            // Unconditional prompt
+            var negativePromptEmbeds = default(Tensor<float>);
+            if (!string.IsNullOrEmpty(options.NegativePrompt))
+            {
+                negativePromptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
+                {
+                    Seed = options.Seed,
+                    Prompt = options.NegativePrompt,
+                    MaxLength = 128
+                }, cancellationToken);
+            }
+
+            return new PromptResult(promptEmbeds, default, negativePromptEmbeds, default);
+        }
+
+
+        /// <summary>
+        /// Decode the model latents to image.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="latents">The latents.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, Tensor<float> latents, CancellationToken cancellationToken = default)
+        {
+            var timestamp = Logger.LogBegin(LogLevel.Debug, "[DecodeLatentsAsync] Begin AutoEncoder Decode");
+            var decoderResult = await AutoEncoder.DecodeAsync(latents, cancellationToken: cancellationToken);
+            if (options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled)
+                await AutoEncoder.DecoderUnloadAsync();
+
+            Logger.LogEnd(LogLevel.Debug, timestamp, "[DecodeLatentsAsync] AutoEncoder Decode Complete");
+            return decoderResult.AsImageTensor();
+        }
+
+
+        /// <summary>
+        /// Encode the image to model latents.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="image">The input image.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
+        {
+            var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
+            var inputTensor = image.ResizeImage(options.Width, options.Height);
+            var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
+            if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
+                await AutoEncoder.EncoderUnloadAsync();
+
+            Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
+            return encoderResult;
+        }
+
+
+        /// <summary>
+        /// Runs the Transformer denoising loop over the scheduler timesteps.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="scheduler">The scheduler.</param>
+        /// <param name="prompt">The prompt embeddings.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        protected async Task<Tensor<float>> RunInferenceAsync(IPipelineOptions options, IScheduler scheduler, PromptResult prompt, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            var timestamp = Logger.LogBegin(LogLevel.Debug, "[RunInferenceAsync] Begin Transformer Inference");
+
+            // Prompt
+            var isGuidanceEnabled = IsGuidanceEnabled(options);
+            var promptEmbedsCond = prompt.PromptEmbeds;
+            var promptEmbedsUncond = prompt.NegativePromptEmbeds;
+
+            // Latents
+            var latents = await CreateLatentInputAsync(options, scheduler, cancellationToken);
+
+            // Load Model
+            await LoadTransformerAsync(options, progressCallback, cancellationToken);
+
+            // Timesteps
+            var timesteps = scheduler.GetTimesteps();
+            for (int i = 0; i < timesteps.Count; i++)
+            {
+                var timestep = timesteps[i];
+                var steptime = Stopwatch.GetTimestamp();
+                cancellationToken.ThrowIfCancellationRequested();
+
+                // Inputs
+                var latentInput = scheduler.ScaleInput(timestep, latents);
+
+                // Inference
+                var conditional = await Transformer.RunAsync(timestep, latentInput, promptEmbedsCond, cancellationToken: cancellationToken);
+                if (isGuidanceEnabled)
+                {
+                    var unconditional = await Transformer.RunAsync(timestep, latentInput, promptEmbedsUncond, cancellationToken: cancellationToken);
+                    conditional = ApplyGuidance(conditional, unconditional, options.GuidanceScale);
+                }
+
+                // Scheduler
+                var stepResult = scheduler.Step(timestep, conditional, latents);
+
+                // Result
+                latents = stepResult.Sample;
+
+                // Progress
+                if (scheduler.IsFinalOrder)
+                    progressCallback.Notify(scheduler.CurrentStep, scheduler.TotalSteps, latents, steptime);
+
+                Logger.LogEnd(LogLevel.Debug, steptime, $"[RunInferenceAsync] Step: {i + 1}/{timesteps.Count}");
+            }
+
+            // Unload
+            if (options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled)
+                await Transformer.UnloadAsync();
+
+            Logger.LogEnd(LogLevel.Debug, timestamp, "[RunInferenceAsync] Transformer Inference Complete");
+            return latents;
+        }
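+
+
+        // Note: ApplyGuidance is inherited from PipelineBase and is not part of this
+        // diff. The conventional classifier-free guidance blend it is expected to
+        // perform is sketched below (illustrative only, not the repository's code):
+        //
+        //     guided = uncond + guidanceScale * (cond - uncond)
+        //
+        //     static void BlendGuidance(Span<float> cond, ReadOnlySpan<float> uncond, float guidanceScale)
+        //     {
+        //         for (int i = 0; i < cond.Length; i++)
+        //             cond[i] = uncond[i] + guidanceScale * (cond[i] - uncond[i]);
+        //     }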
+
+
+        /// <summary>
+        /// Create latent input.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="scheduler">The scheduler.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions options, IScheduler scheduler, CancellationToken cancellationToken = default)
+        {
+            var dimensions = new int[] { 1, AutoEncoder.LatentChannels, options.Height / AutoEncoder.Scale, options.Width / AutoEncoder.Scale };
+            var noiseTensor = scheduler.CreateRandomSample(dimensions);
+            if (options.HasInputImage)
+            {
+                var timestep = scheduler.GetStartTimestep();
+                var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
+                return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor);
+            }
+            return noiseTensor;
+        }
+
+
+        /// <summary>
+        /// Gets the model optimizations.
+        /// </summary>
+        /// <param name="generateOptions">The generate options.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        private ModelOptimization GetOptimizations(IPipelineOptions generateOptions, IProgress<RunProgress> progressCallback = null)
+        {
+            var optimizations = new ModelOptimization(Optimization.None);
+            if (Transformer.HasOptimizationsChanged(optimizations))
+            {
+                progressCallback.Notify("Optimizing Pipeline...");
+            }
+            return optimizations;
+        }
+
+
+        /// <summary>
+        /// Determines whether classifier-free guidance is enabled.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        private bool IsGuidanceEnabled(IPipelineOptions options)
+        {
+            return options.GuidanceScale > 1;
+        }
+
+
+        /// <summary>
+        /// Load Transformer with optimizations.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        private async Task LoadTransformerAsync(IPipelineOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            var optimizations = GetOptimizations(options, progressCallback);
+            await Transformer.LoadAsync(optimizations, cancellationToken);
+        }
+
+
+        /// <summary>
+        /// Checks the state of the pipeline.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        protected override async Task CheckPipelineState(IPipelineOptions options)
+        {
+            // Check Transformer/ControlNet status
+            if (options.HasControlNet && Transformer.IsLoaded())
+                await Transformer.UnloadAsync();
+            if (!options.HasControlNet && Transformer.IsControlNetLoaded())
+                await Transformer.UnloadControlNetAsync();
+
+            // Check LowMemory status
+            if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) // TODO
+                await TextEncoder.UnloadAsync();
+            if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && Transformer.IsLoaded())
+                await Transformer.UnloadAsync();
+            if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && Transformer.IsControlNetLoaded())
+                await Transformer.UnloadControlNetAsync();
+            if ((options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) && AutoEncoder.IsEncoderLoaded())
+                await AutoEncoder.EncoderUnloadAsync();
+            if ((options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled) && AutoEncoder.IsDecoderLoaded())
+                await AutoEncoder.DecoderUnloadAsync();
+        }
+
+
+        /// <summary>
+        /// Configures the supported schedulers.
+        /// </summary>
+        protected override IReadOnlyList<SchedulerType> ConfigureSchedulers()
+        {
+            return [SchedulerType.FlowMatchEulerDiscrete, SchedulerType.FlowMatchEulerDynamic];
+        }
+
+
+        /// <summary>
+        /// Configures the default SchedulerOptions.
+        /// </summary>
+        protected override GenerateOptions ConfigureDefaultOptions()
+        {
+            var options = new GenerateOptions
+            {
+                Steps = 20,
+                Shift = 1f,
+                Width = 1024,
+                Height = 1024,
+                GuidanceScale = 0f,
+                Scheduler = SchedulerType.FlowMatchEulerDiscrete
+            };
+
+            // Nitro-Distilled Models: 4 Steps, No Guidance
+            if (Transformer.ModelType == ModelType.Turbo)
+            {
+                return options with
+                {
+                    Steps = 4,
+                    GuidanceScale = 0f
+                };
+            }
+
+            return options;
+        }
+
+
+        private bool _disposed;
+
+        /// <summary>
+        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// </summary>
+        protected override void Dispose(bool disposing)
+        {
+            if (_disposed)
+                return;
+            if (disposing)
+            {
+                TextEncoder?.Dispose();
+                Transformer?.Dispose();
+                AutoEncoder?.Dispose();
+            }
+            _disposed = true;
+        }
+    }
+}
diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs
new file mode 100644
index 0000000..e761a3c
--- /dev/null
+++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs
@@ -0,0 +1,126 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+using System.IO;
+using TensorStack.Common;
+using TensorStack.StableDiffusion.Config;
+using TensorStack.StableDiffusion.Enums;
+using TensorStack.TextGeneration.Common;
+using TensorStack.TextGeneration.Tokenizers;
+
+namespace TensorStack.StableDiffusion.Pipelines.Nitro
+{
+    public record NitroConfig : PipelineConfig
+    {
+        /// <summary>
+        /// Initializes a new instance of the <see cref="NitroConfig"/> class.
+        /// </summary>
+        public NitroConfig()
+        {
+            Tokenizer = new TokenizerConfig
+            {
+                BOS = 128000,
+                EOS = 128001
+            };
+            TextEncoder = new DecoderConfig
+            {
+                NumHeads = 32,
+                NumLayers = 16,
+                NumKVHeads = 8,
+                HiddenSize = 2048,
+                VocabSize = 128256
+            };
+            Transformer = new TransformerModelConfig
+            {
+                InChannels = 32,
+                OutChannels = 32,
+                JointAttention = 2048,
+                IsOptimizationSupported = true
+            };
+            AutoEncoder = new AutoEncoderModelConfig
+            {
+                LatentChannels = 32,
+                ScaleFactor = 0.41407f
+            };
+        }
+
+        public string Name { get; init; } = "Nitro";
+        public override PipelineType Pipeline { get; } = PipelineType.Nitro;
+        public TokenizerConfig Tokenizer { get; init; }
+        public DecoderConfig TextEncoder { get; init; }
+        public TransformerModelConfig Transformer { get; init; }
+        public AutoEncoderModelConfig AutoEncoder { get; init; }
+
+
+        /// <summary>
+        /// Sets the execution provider for all models.
+        /// </summary>
+        /// <param name="executionProvider">The execution provider.</param>
+        public override void SetProvider(ExecutionProvider executionProvider)
+        {
+            TextEncoder.SetProvider(executionProvider);
+            Transformer.SetProvider(executionProvider);
+            AutoEncoder.SetProvider(executionProvider);
+        }
+
+
+        /// <summary>
+        /// Saves the configuration to file.
+        /// </summary>
+        /// <param name="configFile">The configuration file.</param>
+        /// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param>
+        public override void Save(string configFile, bool useRelativePaths = true)
+        {
+            ConfigService.Serialize(configFile, this, useRelativePaths);
+        }
+
+
+        /// <summary>
+        /// Create Nitro configuration from default values.
+        /// </summary>
+        /// <param name="name">The name.</param>
+        /// <param name="modelType">Type of the model.</param>
+        /// <param name="executionProvider">The execution provider.</param>
+        /// <returns>NitroConfig.</returns>
+        public static NitroConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default)
+        {
+            var config = new NitroConfig { Name = name };
+            config.Transformer.ModelType = modelType;
+            config.SetProvider(executionProvider);
+            return config;
+        }
+
+
+        /// <summary>
+        /// Create Nitro configuration from JSON file.
+        /// </summary>
+        /// <param name="configFile">The configuration file.</param>
+        /// <param name="executionProvider">The execution provider.</param>
+        /// <returns>NitroConfig.</returns>
+        public static NitroConfig FromFile(string configFile, ExecutionProvider executionProvider = default)
+        {
+            var config = ConfigService.Deserialize<NitroConfig>(configFile);
+            config.SetProvider(executionProvider);
+            return config;
+        }
+
+
+        /// <summary>
+        /// Create Nitro configuration from folder structure.
+        /// </summary>
+        /// <param name="modelFolder">The model folder.</param>
+        /// <param name="modelType">Type of the model.</param>
+        /// <param name="executionProvider">The execution provider.</param>
+        /// <returns>NitroConfig.</returns>
+        public static NitroConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
+        {
+            var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
+            config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer");
+            config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
+            config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx");
+            config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
+            config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
+            return config;
+        }
+
+    }
+}
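For reference, FromFolder expects the ONNX export laid out as follows (derived directly from the paths above; the folder name itself is a placeholder). The folder name becomes the pipeline Name:

    models/nitro/
    ├── tokenizer/
    ├── text_encoder/model.onnx
    ├── transformer/model.onnx
    ├── vae_decoder/model.onnx
    └── vae_encoder/model.onnx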
diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs
new file mode 100644
index 0000000..28bd624
--- /dev/null
+++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs
@@ -0,0 +1,82 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+using Microsoft.Extensions.Logging;
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using TensorStack.Common;
+using TensorStack.Common.Pipeline;
+using TensorStack.Common.Tensor;
+using TensorStack.StableDiffusion.Common;
+using TensorStack.StableDiffusion.Enums;
+using TensorStack.StableDiffusion.Models;
+using TensorStack.TextGeneration.Pipelines.Llama;
+
+namespace TensorStack.StableDiffusion.Pipelines.Nitro
+{
+    public class NitroPipeline : NitroBase, IPipeline
+    {
+        /// <summary>
+        /// Initializes a new instance of the <see cref="NitroPipeline"/> class.
+        /// </summary>
+        /// <param name="transformer">The transformer.</param>
+        /// <param name="textEncoder">The text encoder.</param>
+        /// <param name="autoEncoder">The auto encoder.</param>
+        /// <param name="logger">The logger.</param>
+        public NitroPipeline(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, ILogger logger = null)
+            : base(transformer, textEncoder, autoEncoder, logger) { }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="NitroPipeline"/> class.
+        /// </summary>
+        /// <param name="configuration">The configuration.</param>
+        /// <param name="logger">The logger.</param>
+        public NitroPipeline(NitroConfig configuration, ILogger logger = null)
+            : base(configuration, logger) { }
+
+
+        /// <summary>
+        /// Run ImageTensor pipeline.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        public async Task<ImageTensor> RunAsync(GenerateOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            ValidateOptions(options);
+            var prompt = await CreatePromptAsync(options, cancellationToken);
+            using (var scheduler = CreateScheduler(options))
+            {
+                var latents = await RunInferenceAsync(options, scheduler, prompt, progressCallback, cancellationToken);
+                return await DecodeLatentsAsync(options, latents, cancellationToken);
+            }
+        }
+
+
+        /// <summary>
+        /// Create Nitro pipeline from NitroConfig file.
+        /// </summary>
+        /// <param name="configFile">The configuration file.</param>
+        /// <param name="executionProvider">The execution provider.</param>
+        /// <param name="logger">The logger.</param>
+        /// <returns>NitroPipeline.</returns>
+        public static NitroPipeline FromConfig(string configFile, ExecutionProvider executionProvider, ILogger logger = default)
+        {
+            return new NitroPipeline(NitroConfig.FromFile(configFile, executionProvider), logger);
+        }
+
+
+        /// <summary>
+        /// Create Nitro pipeline from folder structure.
+        /// </summary>
+        /// <param name="modelFolder">The model folder.</param>
+        /// <param name="modelType">Type of the model.</param>
+        /// <param name="executionProvider">The execution provider.</param>
+        /// <param name="logger">The logger.</param>
+        /// <returns>NitroPipeline.</returns>
+        public static NitroPipeline FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default)
+        {
+            return new NitroPipeline(NitroConfig.FromFolder(modelFolder, modelType, executionProvider), logger);
+        }
+    }
+}
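Taken together, a minimal text-to-image run against the new pipeline could look like the sketch below. The model path, the executionProvider variable, and the option values are placeholders; the API calls are the ones added in this diff:

    // Sketch: end-to-end Nitro text-to-image (paths and option values are placeholders)
    var pipeline = NitroPipeline.FromFolder("models/nitro", ModelType.Turbo, executionProvider);
    var options = new GenerateOptions
    {
        Prompt = "a photo of a cat",
        Steps = 4,           // Nitro-distilled default
        Width = 1024,
        Height = 1024,
        GuidanceScale = 0f
    };
    var image = await pipeline.RunAsync(options);
    await pipeline.UnloadAsync();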
diff --git a/TensorStack.TextGeneration/Common/GenerateOptions.cs b/TensorStack.TextGeneration/Common/GenerateOptions.cs
index 2ec2d21..c977291 100644
--- a/TensorStack.TextGeneration/Common/GenerateOptions.cs
+++ b/TensorStack.TextGeneration/Common/GenerateOptions.cs
@@ -19,5 +19,6 @@ public record GenerateOptions : IRunOptions
         public float LengthPenalty { get; set; } = 1.0f;
         public EarlyStopping EarlyStopping { get; set; }
         public int DiversityLength { get; set; } = 20;
+        public bool OutputLastHiddenStates { get; set; }
     }
 }
diff --git a/TensorStack.TextGeneration/Common/GenerateResult.cs b/TensorStack.TextGeneration/Common/GenerateResult.cs
index ca98498..46a379d 100644
--- a/TensorStack.TextGeneration/Common/GenerateResult.cs
+++ b/TensorStack.TextGeneration/Common/GenerateResult.cs
@@ -1,6 +1,7 @@
 // Copyright (c) TensorStack. All rights reserved.
 // Licensed under the Apache 2.0 License.
 using System.Collections.Generic;
+using TensorStack.Common.Tensor;
 using TensorStack.Common.Vision;
 
 namespace TensorStack.TextGeneration.Common
@@ -12,5 +13,7 @@ public class GenerateResult
         public string Result { get; set; }
         public float PenaltyScore { get; set; }
         public List<CoordinateResult> CoordinateResults { get; set; }
+        public Tensor<float> LastHiddenState { get; set; }
+        public IReadOnlyList<int> Tokens { get; set; }
     }
 }
diff --git a/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs b/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs
index 4cd27ef..1ad1281 100644
--- a/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs
+++ b/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs
@@ -4,5 +4,6 @@ namespace TensorStack.TextGeneration.Pipelines.Llama
 {
     public record LlamaConfig : TransformerConfig
     {
+        public bool OutputLastHiddenStates { get; set; }
     }
 }
diff --git a/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs b/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs
index dded84f..1623fd1 100644
--- a/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs
+++ b/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs
@@ -49,7 +49,9 @@ public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IPro
             return new GenerateResult
             {
                 Score = sequence.Score,
-                Result = Tokenizer.Decode(sequence.Tokens)
+                Result = Tokenizer.Decode(sequence.Tokens),
+                Tokens = sequence.Tokens,
+                LastHiddenState = sequence.LastHiddenState
             };
         }
     }
@@ -77,7 +79,9 @@ public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress
                 return new GenerateResult
                 {
                     Score = sequence.Score,
-                    Result = Tokenizer.Decode(sequence.Tokens)
+                    Result = Tokenizer.Decode(sequence.Tokens),
+                    Tokens = sequence.Tokens,
+                    LastHiddenState = sequence.LastHiddenState
                 };
@@ -95,6 +99,20 @@ public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress
             }
         }
 
+        /// <summary>
+        /// Gets the LastHiddenState.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        public async Task<Tensor<float>> GetLastHiddenState(GenerateOptions options, CancellationToken cancellationToken = default)
+        {
+            await TokenizePromptAsync(options);
+            using (var sequence = await InitializeAsync(options))
+            {
+                return sequence.LastHiddenState;
+            }
+        }
+
         /// <summary>
         /// Gets the token processors.
         /// </summary>
@@ -173,9 +192,9 @@ private Tensor<float> RunDecoderInternal(ModelMetadata modelMetadata, Sequence s
             var dimension = logitsResult.GetDimensions();
             var logits = logitsResult.ToTensor(dimension[1..]);
-            var presentKeyValues = modelResult.ToArray()[1..];
-
-            sequence.UpdateCache(presentKeyValues, useBranchCache);
+            var lastHiddenState = Configuration.OutputLastHiddenStates ? modelResult[^1].ToTensor() : default;
+            var presentKeyValues = Configuration.OutputLastHiddenStates ? modelResult.ToArray()[1..^1] : modelResult.ToArray()[1..];
+            sequence.UpdateCache(presentKeyValues, useBranchCache, lastHiddenState);
             return logits;
         }
     }
@@ -200,6 +219,7 @@ public static LlamaPipeline Create(ExecutionProvider provider, string modelPath,
             var vocabSize = 128256;
             var config = new LlamaConfig
             {
+                OutputLastHiddenStates = true,
                 Tokenizer = new BPETokenizer(new TokenizerConfig
                 {
                     BOS = 128000,
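Because RunDecoderInternal takes the last output slot as the hidden state, OutputLastHiddenStates only works if the exported decoder graph appends last_hidden_state after the logits and present key/values; that output ordering is assumed by the [^1] and [1..^1] slices above. With it enabled, retrieving prompt embeddings reduces to the sketch below (option values are illustrative):

    // Sketch: prompt embeddings via the Llama text encoder (values illustrative)
    var embeds = await llamaPipeline.GetLastHiddenState(new GenerateOptions
    {
        Prompt = "a photo of a cat",
        MaxLength = 128
    });
    // embeds: Tensor<float>, roughly [1, tokenCount, HiddenSize = 2048]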
diff --git a/TensorStack.TextGeneration/Sequence.cs b/TensorStack.TextGeneration/Sequence.cs
index cfc321c..524f4cc 100644
--- a/TensorStack.TextGeneration/Sequence.cs
+++ b/TensorStack.TextGeneration/Sequence.cs
@@ -3,6 +3,7 @@
 using Microsoft.ML.OnnxRuntime;
 using System;
 using System.Collections.Generic;
+using TensorStack.Common.Tensor;
 using TensorStack.TextGeneration.Cache;
 
 namespace TensorStack.TextGeneration
@@ -10,6 +11,7 @@ namespace TensorStack.TextGeneration
     public sealed class Sequence : IDisposable
     {
         private IKVCache _cache;
+        private Tensor<float> _lastHiddenState;
 
         /// <summary>
         /// Initializes a new instance of the <see cref="Sequence"/> class.
@@ -55,6 +57,11 @@ private Sequence(List<int> tokens, float score, IKVCache cache)
         /// </summary>
         public OrtValue[] Cache => _cache.Values;
 
+        /// <summary>
+        /// Gets the LastHiddenState.
+        /// </summary>
+        public Tensor<float> LastHiddenState => _lastHiddenState;
+
         /// <summary>
         /// Gets or sets the sequence score.
         /// </summary>
@@ -89,8 +96,9 @@ public bool Initialize(int initialLength)
         /// </summary>
         /// <param name="currentValues">The current values.</param>
         /// <param name="useBranchCache">if set to <c>true</c> use branch cache.</param>
-        public void UpdateCache(OrtValue[] currentValues, bool useBranchCache)
+        public void UpdateCache(OrtValue[] currentValues, bool useBranchCache, Tensor<float> lastHiddenState = default)
         {
+            _lastHiddenState = lastHiddenState;
             _cache.Update(currentValues, useBranchCache);
         }
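One behavioural note that follows from the code above, assuming InitializeAsync performs the prefill decode pass that calls UpdateCache: _lastHiddenState is overwritten on every UpdateCache call, so Sequence.LastHiddenState always reflects the most recent forward pass. That is what LlamaPipeline.GetLastHiddenState relies on for the Nitro text encoder path; it tokenizes, runs only the prefill, and reads back the full-prompt hidden state as the text embedding.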