diff --git a/TensorStack.StableDiffusion/Enums/PipelineType.cs b/TensorStack.StableDiffusion/Enums/PipelineType.cs
index dab692d..c7aa346 100644
--- a/TensorStack.StableDiffusion/Enums/PipelineType.cs
+++ b/TensorStack.StableDiffusion/Enums/PipelineType.cs
@@ -11,5 +11,6 @@ public enum PipelineType
StableCascade = 10,
LatentConsistency = 20,
Flux = 30,
+ Nitro = 40
}
}
diff --git a/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs b/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs
new file mode 100644
index 0000000..27b875c
--- /dev/null
+++ b/TensorStack.StableDiffusion/Models/TransformerNitroModel.cs
@@ -0,0 +1,56 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+using System.Threading;
+using System.Threading.Tasks;
+using TensorStack.Common;
+using TensorStack.Common.Tensor;
+using TensorStack.StableDiffusion.Config;
+
+namespace TensorStack.StableDiffusion.Models
+{
+ /// <summary>
+ /// TransformerModel: Nitro Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ /// </summary>
+ public class TransformerNitroModel : TransformerModel
+ {
+ /// <summary>
+ /// Initializes a new instance of the <see cref="TransformerNitroModel"/> class.
+ /// </summary>
+ /// <param name="configuration">The configuration.</param>
+ public TransformerNitroModel(TransformerModelConfig configuration)
+ : base(configuration) { }
+
+
+ /// <summary>
+ /// Runs the Transformer model with the specified inputs
+ /// </summary>
+ /// <param name="timestep">The timestep.</param>
+ /// <param name="hiddenStates">The hidden states.</param>
+ /// <param name="encoderHiddenStates">The encoder hidden states.</param>
+ /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
+ /// <returns>A Task&lt;Tensor`1&gt; representing the asynchronous operation.</returns>
+ public async Task<Tensor<float>> RunAsync(int timestep, Tensor<float> hiddenStates, Tensor<float> encoderHiddenStates, CancellationToken cancellationToken = default)
+ {
+ if (!Transformer.IsLoaded())
+ await Transformer.LoadAsync(cancellationToken: cancellationToken);
+
+ using (var transformerParams = new ModelParameters(Transformer.Metadata, cancellationToken))
+ {
+ // Inputs
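+ // Positional binding: input order must match the ONNX model's binding order
+ // (hidden_states, encoder_hidden_states, timestep), as assumed for this export.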
+ transformerParams.AddInput(hiddenStates.AsTensorSpan());
+ transformerParams.AddInput(encoderHiddenStates.AsTensorSpan());
+ transformerParams.AddScalarInput(timestep);
+
+ // Outputs
+ transformerParams.AddOutput(hiddenStates.Dimensions);
+
+ // Inference
+ using (var results = await Transformer.RunInferenceAsync(transformerParams))
+ {
+ return results[0].ToTensor();
+ }
+ }
+ }
+
+ }
+}
diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs
new file mode 100644
index 0000000..8081d5d
--- /dev/null
+++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs
@@ -0,0 +1,392 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+using Microsoft.Extensions.Logging;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
+using TensorStack.Common;
+using TensorStack.Common.Tensor;
+using TensorStack.StableDiffusion.Common;
+using TensorStack.StableDiffusion.Enums;
+using TensorStack.StableDiffusion.Models;
+using TensorStack.StableDiffusion.Schedulers;
+using TensorStack.TextGeneration.Pipelines.Llama;
+using TensorStack.TextGeneration.Tokenizers;
+
+namespace TensorStack.StableDiffusion.Pipelines.Nitro
+{
+ public abstract class NitroBase : PipelineBase
+ {
+ /// <summary>
+ /// Initializes a new instance of the <see cref="NitroBase"/> class.
+ /// </summary>
+ /// <param name="transformer">The transformer.</param>
+ /// <param name="textEncoder">The text encoder.</param>
+ /// <param name="autoEncoder">The automatic encoder.</param>
+ /// <param name="logger">The logger.</param>
+ public NitroBase(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, ILogger logger = default) : base(logger)
+ {
+ Transformer = transformer;
+ AutoEncoder = autoEncoder;
+ TextEncoder = textEncoder;
+ Initialize();
+ Logger?.LogInformation("[NitroPipeline] Name: {Name}", Name);
+ }
+
+
+ /// <summary>
+ /// Initializes a new instance of the <see cref="NitroBase"/> class.
+ /// </summary>
+ /// <param name="configuration">The configuration.</param>
+ /// <param name="logger">The logger.</param>
+ public NitroBase(NitroConfig configuration, ILogger logger = default) : this(
+ new TransformerNitroModel(configuration.Transformer),
+ new LlamaPipeline(new LlamaConfig
+ {
+ OutputLastHiddenStates = true,
+ DecoderConfig = configuration.TextEncoder,
+ Tokenizer = new BPETokenizer(configuration.Tokenizer),
+ }),
+ new AutoEncoderModel(configuration.AutoEncoder),
+ logger)
+ {
+ Name = configuration.Name;
+ }
+
+ /// <summary>
+ /// Gets the type of the pipeline.
+ /// </summary>
+ public override PipelineType PipelineType => PipelineType.Nitro;
+
+ /// <summary>
+ /// Gets the friendly name.
+ /// </summary>
+ public override string Name { get; init; } = nameof(PipelineType.Nitro);
+
+ /// <summary>
+ /// Gets the TextEncoder.
+ /// </summary>
+ public LlamaPipeline TextEncoder { get; init; }
+
+ /// <summary>
+ /// Gets the transformer.
+ /// </summary>
+ public TransformerNitroModel Transformer { get; init; }
+
+ /// <summary>
+ /// Gets the automatic encoder.
+ /// </summary>
+ public AutoEncoderModel AutoEncoder { get; init; }
+
+
+ /// <summary>
+ /// Loads the pipeline.
+ /// </summary>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ public Task LoadAsync(CancellationToken cancellationToken = default)
+ {
+ // Nitro pipelines are lazy loaded on first run
+ return Task.CompletedTask;
+ }
+
+
+ /// <summary>
+ /// Unloads the pipeline.
+ /// </summary>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ public async Task UnloadAsync(CancellationToken cancellationToken = default)
+ {
+ await Task.WhenAll
+ (
+ Transformer.UnloadAsync(),
+ TextEncoder.UnloadAsync(cancellationToken),
+ AutoEncoder.EncoderUnloadAsync(),
+ AutoEncoder.DecoderUnloadAsync()
+ );
+ Logger?.LogInformation("[{PipeLineType}] Pipeline Unloaded", PipelineType);
+ }
+
+
+ /// <summary>
+ /// Validates the options.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ protected override void ValidateOptions(GenerateOptions options)
+ {
+ base.ValidateOptions(options);
+ if (!Transformer.HasControlNet && options.HasControlNet)
+ throw new ArgumentException("Model does not support ControlNet");
+ }
+
+
+ /// <summary>
+ /// Creates the prompt input embeddings.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
+ {
+ // Conditional Prompt
+ var promptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
+ {
+ Seed = options.Seed,
+ Prompt = options.Prompt,
+ MaxLength = 128
+ }, cancellationToken);
+
+ // Unconditional prompt
+ var negativePromptEmbeds = default(Tensor<float>);
+ if (!string.IsNullOrEmpty(options.NegativePrompt))
+ {
+ negativePromptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
+ {
+ Seed = options.Seed,
+ Prompt = options.NegativePrompt,
+ MaxLength = 128
+ }, cancellationToken);
+ }
+
+ return new PromptResult(promptEmbeds, default, negativePromptEmbeds, default);
+ }
+
+
+ /// <summary>
+ /// Decode the model latents to image
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="latents">The latents.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, Tensor<float> latents, CancellationToken cancellationToken = default)
+ {
+ var timestamp = Logger.LogBegin(LogLevel.Debug, "[DecodeLatentsAsync] Begin AutoEncoder Decode");
+ var decoderResult = await AutoEncoder.DecodeAsync(latents, cancellationToken: cancellationToken);
+ if (options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled)
+ await AutoEncoder.DecoderUnloadAsync();
+
+ Logger.LogEnd(LogLevel.Debug, timestamp, "[DecodeLatentsAsync] AutoEncoder Decode Complete");
+ return decoderResult.AsImageTensor();
+ }
+
+
+ /// <summary>
+ /// Encode the image to model latents
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="image">The image.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
+ {
+ var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
+ var inputTensor = image.ResizeImage(options.Width, options.Height);
+ var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
+ if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
+ await AutoEncoder.EncoderUnloadAsync();
+
+ Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
+ return encoderResult;
+ }
+
+
+ protected async Task<Tensor<float>> RunInferenceAsync(IPipelineOptions options, IScheduler scheduler, PromptResult prompt, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
+ {
+ var timestamp = Logger.LogBegin(LogLevel.Debug, "[RunInferenceAsync] Begin Transformer Inference");
+
+ // Prompt
+ var isGuidanceEnabled = IsGuidanceEnabled(options);
+ var promptEmbedsCond = prompt.PromptEmbeds;
+ var promptEmbedsUncond = prompt.NegativePromptEmbeds;
+
+ // Latents
+ var latents = await CreateLatentInputAsync(options, scheduler, cancellationToken);
+
+ // Load Model
+ await LoadTransformerAsync(options, progressCallback, cancellationToken);
+
+ // Timesteps
+ var timesteps = scheduler.GetTimesteps();
+ for (int i = 0; i < timesteps.Count; i++)
+ {
+ var timestep = timesteps[i];
+ var steptime = Stopwatch.GetTimestamp();
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Inputs.
+ var latentInput = scheduler.ScaleInput(timestep, latents);
+
+ // Inference
+ var conditional = await Transformer.RunAsync(timestep, latentInput, promptEmbedsCond, cancellationToken: cancellationToken);
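+ // Classifier-free guidance: a second, unconditional pass is blended with the
+ // conditional prediction in the standard uncond + scale * (cond - uncond) form
+ // (assumed to be what ApplyGuidance implements).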
+ if (isGuidanceEnabled)
+ {
+ var unconditional = await Transformer.RunAsync(timestep, latentInput, promptEmbedsUncond, cancellationToken: cancellationToken);
+ conditional = ApplyGuidance(conditional, unconditional, options.GuidanceScale);
+ }
+
+ // Scheduler
+ var stepResult = scheduler.Step(timestep, conditional, latents);
+
+ // Result
+ latents = stepResult.Sample;
+
+ // Progress
+ if (scheduler.IsFinalOrder)
+ progressCallback.Notify(scheduler.CurrentStep, scheduler.TotalSteps, latents, steptime);
+
+ Logger.LogEnd(LogLevel.Debug, steptime, $"[RunInferenceAsync] Step: {i + 1}/{timesteps.Count}");
+ }
+
+ // Unload
+ if (options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled)
+ await Transformer.UnloadAsync();
+
+ Logger.LogEnd(LogLevel.Debug, timestamp, "[RunInferenceAsync] Transformer Inference Complete");
+ return latents;
+ }
+
+
+ /// <summary>
+ /// Create latent input.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions options, IScheduler scheduler, CancellationToken cancellationToken = default)
+ {
+ var dimensions = new int[] { 1, AutoEncoder.LatentChannels, options.Height / AutoEncoder.Scale, options.Width / AutoEncoder.Scale };
+ var noiseTensor = scheduler.CreateRandomSample(dimensions);
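+ // Image-to-image starts from the encoded input image noised to the start
+ // timestep; text-to-image starts from pure noise.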
+ if (options.HasInputImage)
+ {
+ var timestep = scheduler.GetStartTimestep();
+ var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
+ return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor);
+ }
+ return noiseTensor;
+ }
+
+
+ /// <summary>
+ /// Gets the model optimizations.
+ /// </summary>
+ /// <param name="generateOptions">The generate options.</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ private ModelOptimization GetOptimizations(IPipelineOptions generateOptions, IProgress<RunProgress> progressCallback = null)
+ {
+ var optimizations = new ModelOptimization(Optimization.None);
+ if (Transformer.HasOptimizationsChanged(optimizations))
+ {
+ progressCallback.Notify("Optimizing Pipeline...");
+ }
+ return optimizations;
+ }
+
+
+ /// <summary>
+ /// Determines whether classifier-free guidance is enabled
+ /// </summary>
+ /// <param name="options">The options.</param>
+ private bool IsGuidanceEnabled(IPipelineOptions options)
+ {
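+ // Note: guidance uses the negative prompt embeddings, and CreatePromptAsync
+ // leaves them null when no negative prompt is supplied, so callers should
+ // pair GuidanceScale > 1 with a negative prompt.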
+ return options.GuidanceScale > 1;
+ }
+
+
+ /// <summary>
+ /// Load Transformer with optimizations
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ private async Task LoadTransformerAsync(IPipelineOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
+ {
+ var optimizations = GetOptimizations(options, progressCallback);
+ await Transformer.LoadAsync(optimizations, cancellationToken);
+ }
+
+
+ /// <summary>
+ /// Checks the state of the pipeline.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ protected override async Task CheckPipelineState(IPipelineOptions options)
+ {
+ // Check Transformer/ControlNet status
+ if (options.HasControlNet && Transformer.IsLoaded())
+ await Transformer.UnloadAsync();
+ if (!options.HasControlNet && Transformer.IsControlNetLoaded())
+ await Transformer.UnloadControlNetAsync();
+
+ // Check LowMemory status
+ if ((options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)) // TODO
+ await TextEncoder.UnloadAsync();
+ if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && Transformer.IsLoaded())
+ await Transformer.UnloadAsync();
+ if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && Transformer.IsControlNetLoaded())
+ await Transformer.UnloadControlNetAsync();
+ if ((options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) && AutoEncoder.IsEncoderLoaded())
+ await AutoEncoder.EncoderUnloadAsync();
+ if ((options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled) && AutoEncoder.IsDecoderLoaded())
+ await AutoEncoder.DecoderUnloadAsync();
+ }
+
+
+ /// <summary>
+ /// Configures the supported schedulers.
+ /// </summary>
+ protected override IReadOnlyList<SchedulerType> ConfigureSchedulers()
+ {
+ return [SchedulerType.FlowMatchEulerDiscrete, SchedulerType.FlowMatchEulerDynamic];
+ }
+
+
+ /// <summary>
+ /// Configures the default GenerateOptions.
+ /// </summary>
+ protected override GenerateOptions ConfigureDefaultOptions()
+ {
+ var options = new GenerateOptions
+ {
+ Steps = 20,
+ Shift = 1f,
+ Width = 1024,
+ Height = 1024,
+ GuidanceScale = 0f,
+ Scheduler = SchedulerType.FlowMatchEulerDiscrete
+ };
+
+ // Nitro-Distilled Models, 4 Steps, No Guidance
+ if (Transformer.ModelType == ModelType.Turbo)
+ {
+ return options with
+ {
+ Steps = 4,
+ Shift = 1f,
+ Width = 1024,
+ Height = 1024,
+ GuidanceScale = 0,
+ Scheduler = SchedulerType.FlowMatchEulerDiscrete
+ };
+ }
+
+ return options;
+ }
+
+
+ private bool _disposed;
+
+ /// <summary>
+ /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+ /// </summary>
+ protected override void Dispose(bool disposing)
+ {
+ if (_disposed)
+ return;
+ if (disposing)
+ {
+ TextEncoder?.Dispose();
+ Transformer?.Dispose();
+ AutoEncoder?.Dispose();
+ }
+ _disposed = true;
+ }
+ }
+}
diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs
new file mode 100644
index 0000000..e761a3c
--- /dev/null
+++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs
@@ -0,0 +1,126 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+using System.IO;
+using TensorStack.Common;
+using TensorStack.StableDiffusion.Config;
+using TensorStack.StableDiffusion.Enums;
+using TensorStack.TextGeneration.Common;
+using TensorStack.TextGeneration.Tokenizers;
+
+namespace TensorStack.StableDiffusion.Pipelines.Nitro
+{
+ public record NitroConfig : PipelineConfig
+ {
+ /// <summary>
+ /// Initializes a new instance of the <see cref="NitroConfig"/> class.
+ /// </summary>
+ public NitroConfig()
+ {
+ Tokenizer = new TokenizerConfig
+ {
+ BOS = 128000,
+ EOS = 128001
+ };
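+ // Text encoder defaults correspond to a Llama 3.2 1B-class decoder
+ // (vocab 128256, 16 layers, grouped-query attention with 8 KV heads).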
+ TextEncoder = new DecoderConfig
+ {
+ NumHeads = 32,
+ NumLayers = 16,
+ NumKVHeads = 8,
+ HiddenSize = 2048,
+ VocabSize = 128256
+ };
+ Transformer = new TransformerModelConfig
+ {
+ InChannels = 32,
+ OutChannels = 32,
+ JointAttention = 2048,
+ IsOptimizationSupported = true
+ };
+ AutoEncoder = new AutoEncoderModelConfig
+ {
+ LatentChannels = 32,
+ ScaleFactor = 0.41407f
+ };
+ }
+
+ public string Name { get; init; } = "Nitro";
+ public override PipelineType Pipeline { get; } = PipelineType.Nitro;
+ public TokenizerConfig Tokenizer { get; init; }
+ public DecoderConfig TextEncoder { get; init; }
+ public TransformerModelConfig Transformer { get; init; }
+ public AutoEncoderModelConfig AutoEncoder { get; init; }
+
+
+ /// <summary>
+ /// Sets the execution provider for all models.
+ /// </summary>
+ /// <param name="executionProvider">The execution provider.</param>
+ public override void SetProvider(ExecutionProvider executionProvider)
+ {
+ TextEncoder.SetProvider(executionProvider);
+ Transformer.SetProvider(executionProvider);
+ AutoEncoder.SetProvider(executionProvider);
+ }
+
+
+ /// <summary>
+ /// Saves the configuration to file.
+ /// </summary>
+ /// <param name="configFile">The configuration file.</param>
+ /// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param>
+ public override void Save(string configFile, bool useRelativePaths = true)
+ {
+ ConfigService.Serialize(configFile, this, useRelativePaths);
+ }
+
+
+ /// <summary>
+ /// Create Nitro configuration from default values
+ /// </summary>
+ /// <param name="name">The name.</param>
+ /// <param name="modelType">Type of the model.</param>
+ /// <param name="executionProvider">The execution provider.</param>
+ /// <returns>NitroConfig.</returns>
+ public static NitroConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default)
+ {
+ var config = new NitroConfig { Name = name };
+ config.Transformer.ModelType = modelType;
+ config.SetProvider(executionProvider);
+ return config;
+ }
+
+
+ /// <summary>
+ /// Create Nitro configuration from json file
+ /// </summary>
+ /// <param name="configFile">The configuration file.</param>
+ /// <param name="executionProvider">The execution provider.</param>
+ /// <returns>NitroConfig.</returns>
+ public static NitroConfig FromFile(string configFile, ExecutionProvider executionProvider = default)
+ {
+ var config = ConfigService.Deserialize<NitroConfig>(configFile);
+ config.SetProvider(executionProvider);
+ return config;
+ }
+
+
+ /// <summary>
+ /// Create Nitro configuration from folder structure
+ /// </summary>
+ /// <param name="modelFolder">The model folder.</param>
+ /// <param name="modelType">Type of the model.</param>
+ /// <param name="executionProvider">The execution provider.</param>
+ /// <returns>NitroConfig.</returns>
+ public static NitroConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
+ {
+ var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
+ config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer");
+ config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
+ config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx");
+ config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
+ config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
+ return config;
+ }
+
+ }
+}
diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs
new file mode 100644
index 0000000..28bd624
--- /dev/null
+++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs
@@ -0,0 +1,82 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+using Microsoft.Extensions.Logging;
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using TensorStack.Common;
+using TensorStack.Common.Pipeline;
+using TensorStack.Common.Tensor;
+using TensorStack.StableDiffusion.Common;
+using TensorStack.StableDiffusion.Enums;
+using TensorStack.StableDiffusion.Models;
+using TensorStack.TextGeneration.Pipelines.Llama;
+
+namespace TensorStack.StableDiffusion.Pipelines.Nitro
+{
+ public class NitroPipeline : NitroBase, IPipeline
+ {
+ /// <summary>
+ /// Initializes a new instance of the <see cref="NitroPipeline"/> class.
+ /// </summary>
+ /// <param name="transformer">The transformer.</param>
+ /// <param name="textEncoder">The text encoder.</param>
+ /// <param name="autoEncoder">The automatic encoder.</param>
+ /// <param name="logger">The logger.</param>
+ public NitroPipeline(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, ILogger logger = null)
+ : base(transformer, textEncoder, autoEncoder, logger) { }
+
+ /// <summary>
+ /// Initializes a new instance of the <see cref="NitroPipeline"/> class.
+ /// </summary>
+ /// <param name="configuration">The configuration.</param>
+ /// <param name="logger">The logger.</param>
+ public NitroPipeline(NitroConfig configuration, ILogger logger = null)
+ : base(configuration, logger) { }
+
+
+ /// <summary>
+ /// Run ImageTensor pipeline.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ public async Task<ImageTensor> RunAsync(GenerateOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
+ {
+ ValidateOptions(options);
+ var prompt = await CreatePromptAsync(options, cancellationToken);
+ using (var scheduler = CreateScheduler(options))
+ {
+ var latents = await RunInferenceAsync(options, scheduler, prompt, progressCallback, cancellationToken);
+ return await DecodeLatentsAsync(options, latents, cancellationToken);
+ }
+ }
+
+
+ /// <summary>
+ /// Create Nitro pipeline from NitroConfig file
+ /// </summary>
+ /// <param name="configFile">The configuration file.</param>
+ /// <param name="executionProvider">The execution provider.</param>
+ /// <param name="logger">The logger.</param>
+ /// <returns>NitroPipeline.</returns>
+ public static NitroPipeline FromConfig(string configFile, ExecutionProvider executionProvider, ILogger logger = default)
+ {
+ return new NitroPipeline(NitroConfig.FromFile(configFile, executionProvider), logger);
+ }
+
+
+ /// <summary>
+ /// Create Nitro pipeline from folder structure
+ /// </summary>
+ /// <param name="modelFolder">The model folder.</param>
+ /// <param name="modelType">Type of the model.</param>
+ /// <param name="executionProvider">The execution provider.</param>
+ /// <param name="logger">The logger.</param>
+ /// <returns>NitroPipeline.</returns>
+ public static NitroPipeline FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default)
+ {
+ return new NitroPipeline(NitroConfig.FromFolder(modelFolder, modelType, executionProvider), logger);
+ }
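+
+ // Usage sketch (folder path, model type and prompt are illustrative only):
+ //   var config = NitroConfig.FromFolder("models/nitro", ModelType.Turbo, provider);
+ //   using var pipeline = new NitroPipeline(config);
+ //   var image = await pipeline.RunAsync(new GenerateOptions { Prompt = "a red fox in the snow" });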
+ }
+}
diff --git a/TensorStack.TextGeneration/Common/GenerateOptions.cs b/TensorStack.TextGeneration/Common/GenerateOptions.cs
index 2ec2d21..c977291 100644
--- a/TensorStack.TextGeneration/Common/GenerateOptions.cs
+++ b/TensorStack.TextGeneration/Common/GenerateOptions.cs
@@ -19,5 +19,6 @@ public record GenerateOptions : IRunOptions
public float LengthPenalty { get; set; } = 1.0f;
public EarlyStopping EarlyStopping { get; set; }
public int DiversityLength { get; set; } = 20;
+ public bool OutputLastHiddenStates { get; set; }
}
}
diff --git a/TensorStack.TextGeneration/Common/GenerateResult.cs b/TensorStack.TextGeneration/Common/GenerateResult.cs
index ca98498..46a379d 100644
--- a/TensorStack.TextGeneration/Common/GenerateResult.cs
+++ b/TensorStack.TextGeneration/Common/GenerateResult.cs
@@ -1,6 +1,7 @@
// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.
using System.Collections.Generic;
+using TensorStack.Common.Tensor;
using TensorStack.Common.Vision;
namespace TensorStack.TextGeneration.Common
@@ -12,5 +13,7 @@ public class GenerateResult
public string Result { get; set; }
public float PenaltyScore { get; set; }
public List<CoordinateResult> CoordinateResults { get; set; }
+ public Tensor<float> LastHiddenState { get; set; }
+ public IReadOnlyList<int> Tokens { get; set; }
}
}
diff --git a/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs b/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs
index 4cd27ef..1ad1281 100644
--- a/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs
+++ b/TensorStack.TextGeneration/Pipelines/Llama/LlamaConfig.cs
@@ -4,5 +4,6 @@ namespace TensorStack.TextGeneration.Pipelines.Llama
{
public record LlamaConfig : TransformerConfig
{
+ public bool OutputLastHiddenStates { get; set; }
}
}
diff --git a/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs b/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs
index dded84f..1623fd1 100644
--- a/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs
+++ b/TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs
@@ -49,7 +49,9 @@ public virtual async Task RunAsync(GenerateOptions options, IPro
return new GenerateResult
{
Score = sequence.Score,
- Result = Tokenizer.Decode(sequence.Tokens)
+ Result = Tokenizer.Decode(sequence.Tokens),
+ Tokens = sequence.Tokens,
+ LastHiddenState = sequence.LastHiddenState
};
}
}
@@ -77,7 +79,9 @@ public async Task RunAsync(SearchOptions options, IProgress
return new GenerateResult
{
Score = sequence.Score,
- Result = Tokenizer.Decode(sequence.Tokens)
+ Result = Tokenizer.Decode(sequence.Tokens),
+ Tokens = sequence.Tokens,
+ LastHiddenState = sequence.LastHiddenState
};
}
@@ ... @@
+ /// <summary>
+ /// Gets the LastHiddenState.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ public async Task<Tensor<float>> GetLastHiddenState(GenerateOptions options, CancellationToken cancellationToken = default)
+ {
+ await TokenizePromptAsync(options);
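+ // Assumes InitializeAsync performs a single prefill pass over the prompt,
+ // which is what populates Sequence.LastHiddenState; no tokens are generated here.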
+ using (var sequence = await InitializeAsync(options))
+ {
+ return sequence.LastHiddenState;
+ }
+ }
+
+
/// <summary>
/// Gets the token processors.
/// </summary>
@@ -173,9 +192,9 @@ private Tensor RunDecoderInternal(ModelMetadata modelMetadata, Sequence s
{
var dimension = logitsResult.GetDimensions();
var logits = logitsResult.ToTensor(dimension[1..]);
- var presentKeyValues = modelResult.ToArray()[1..];
-
- sequence.UpdateCache(presentKeyValues, useBranchCache);
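+ // With OutputLastHiddenStates enabled the model emits [logits, presentKV..., lastHiddenState],
+ // so the KV slice excludes the first and last outputs.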
+ var lastHiddenState = Configuration.OutputLastHiddenStates ? modelResult[^1].ToTensor() : default;
+ var presentKeyValues = Configuration.OutputLastHiddenStates ? modelResult.ToArray()[1..^1] : modelResult.ToArray()[1..];
+ sequence.UpdateCache(presentKeyValues, useBranchCache, lastHiddenState);
return logits;
}
}
@@ -200,6 +219,7 @@ public static LlamaPipeline Create(ExecutionProvider provider, string modelPath,
var vocabSize = 128256;
var config = new LlamaConfig
{
+ OutputLastHiddenStates = true,
Tokenizer = new BPETokenizer(new TokenizerConfig
{
BOS = 128000,
diff --git a/TensorStack.TextGeneration/Sequence.cs b/TensorStack.TextGeneration/Sequence.cs
index cfc321c..524f4cc 100644
--- a/TensorStack.TextGeneration/Sequence.cs
+++ b/TensorStack.TextGeneration/Sequence.cs
@@ -3,6 +3,7 @@
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
+using TensorStack.Common.Tensor;
using TensorStack.TextGeneration.Cache;
namespace TensorStack.TextGeneration
@@ -10,6 +11,7 @@ namespace TensorStack.TextGeneration
public sealed class Sequence : IDisposable
{
private IKVCache _cache;
+ private Tensor<float> _lastHiddenState;
///
/// Initializes a new instance of the class.
@@ -55,6 +57,11 @@ private Sequence(List tokens, float score, IKVCache cache)
///
public OrtValue[] Cache => _cache.Values;
+ /// <summary>
+ /// Gets the LastHiddenState.
+ /// </summary>
+ public Tensor<float> LastHiddenState => _lastHiddenState;
+
/// <summary>
/// Gets or sets the sequence score.
/// </summary>
@@ -89,8 +96,9 @@ public bool Initialize(int initialLength)
/// </summary>
/// <param name="currentValues">The current values.</param>
/// <param name="useBranchCache">if set to <c>true</c> use branch cache.</param>
- public void UpdateCache(OrtValue[] currentValues, bool useBranchCache)
+ public void UpdateCache(OrtValue[] currentValues, bool useBranchCache, Tensor<float> lastHiddenState = default)
{
+ _lastHiddenState = lastHiddenState;
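+ // Each decode step overwrites the previous value; only the latest hidden state is kept.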
_cache.Update(currentValues, useBranchCache);
}