diff --git a/TensorStack.StableDiffusion/Common/EncoderCache.cs b/TensorStack.StableDiffusion/Common/EncoderCache.cs new file mode 100644 index 0000000..590f338 --- /dev/null +++ b/TensorStack.StableDiffusion/Common/EncoderCache.cs @@ -0,0 +1,22 @@ +using System; +using TensorStack.Common.Tensor; + +namespace TensorStack.StableDiffusion.Common +{ + public record EncoderCache + { + public ImageTensor InputImage { get; init; } + public Tensor CacheResult { get; init; } + + public bool IsValid(ImageTensor input) + { + if (input is null || InputImage is null) + return false; + + if (!InputImage.Span.SequenceEqual(input.Span)) + return false; + + return true; + } + } +} diff --git a/TensorStack.StableDiffusion/Common/GenerateOptions.cs b/TensorStack.StableDiffusion/Common/GenerateOptions.cs index 8a0f0a5..b981f82 100644 --- a/TensorStack.StableDiffusion/Common/GenerateOptions.cs +++ b/TensorStack.StableDiffusion/Common/GenerateOptions.cs @@ -45,7 +45,7 @@ public record GenerateOptions : IPipelineOptions, ISchedulerOptions public bool IsLowMemoryEncoderEnabled { get; set; } public bool IsLowMemoryDecoderEnabled { get; set; } public bool IsLowMemoryTextEncoderEnabled { get; set; } - + public bool IsPipelineCacheEnabled { get; set; } = true; public bool HasControlNet => ControlNet is not null; public bool HasInputImage => InputImage is not null; diff --git a/TensorStack.StableDiffusion/Common/PromptCache.cs b/TensorStack.StableDiffusion/Common/PromptCache.cs new file mode 100644 index 0000000..4b8da96 --- /dev/null +++ b/TensorStack.StableDiffusion/Common/PromptCache.cs @@ -0,0 +1,18 @@ +using System; +using TensorStack.StableDiffusion.Pipelines; + +namespace TensorStack.StableDiffusion.Common +{ + public record PromptCache + { + public string Conditional { get; init; } + public string Unconditional { get; init; } + public PromptResult CacheResult { get; init; } + + public bool IsValid(IPipelineOptions options) + { + return string.Equals(Conditional, options.Prompt, StringComparison.OrdinalIgnoreCase) + && string.Equals(Unconditional, options.NegativePrompt, StringComparison.OrdinalIgnoreCase); + } + } +} diff --git a/TensorStack.StableDiffusion/Common/TextEncoderResult.cs b/TensorStack.StableDiffusion/Common/TextEncoderResult.cs index f1add21..385a43c 100644 --- a/TensorStack.StableDiffusion/Common/TextEncoderResult.cs +++ b/TensorStack.StableDiffusion/Common/TextEncoderResult.cs @@ -1,6 +1,5 @@ // Copyright (c) TensorStack. All rights reserved. // Licensed under the Apache 2.0 License. -using System; using TensorStack.Common.Tensor; namespace TensorStack.StableDiffusion.Common @@ -32,6 +31,4 @@ public Tensor GetHiddenStates(int index) return _hiddenStates[0]; } } - - public record TextEncoderBatchedResult(Memory PromptEmbeds, Memory PromptPooledEmbeds); } diff --git a/TensorStack.StableDiffusion/Models/CLIPTextModel.cs b/TensorStack.StableDiffusion/Models/CLIPTextModel.cs index ce3c153..2967775 100644 --- a/TensorStack.StableDiffusion/Models/CLIPTextModel.cs +++ b/TensorStack.StableDiffusion/Models/CLIPTextModel.cs @@ -4,7 +4,6 @@ using System.Threading; using System.Threading.Tasks; using TensorStack.Common; -using TensorStack.Common.Tensor; using TensorStack.StableDiffusion.Common; using TensorStack.StableDiffusion.Config; using TensorStack.TextGeneration.Tokenizers; diff --git a/TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs b/TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs index 0163155..0809e3e 100644 --- a/TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs +++ b/TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs @@ -148,6 +148,10 @@ protected override void ValidateOptions(GenerateOptions options) /// The cancellation token. protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { + var cachedPrompt = GetPromptCache(options); + if (cachedPrompt is not null) + return cachedPrompt; + // Tokenize2 var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken); var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken); @@ -179,7 +183,7 @@ protected async Task CreatePromptAsync(IPipelineOptions options, C var negativePromptPooledEmbeds = negativePromptEmbeddings.TextEmbeds; negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch(); - return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds); + return SetPromptCache(options, new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds)); } @@ -264,16 +268,23 @@ protected async Task DecodeLatentsAsync(IPipelineOptions options, T /// The options. /// The latents. /// The cancellation token. - private async Task> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default) + private async Task> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode"); - var inputTensor = image.ResizeImage(options.Width, options.Height); + var cacheResult = GetEncoderCache(options); + if (cacheResult is not null) + { + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result."); + return cacheResult; + } + + var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height); var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken); if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) await AutoEncoder.EncoderUnloadAsync(); Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete"); - return encoderResult; + return SetEncoderCache(options, encoderResult); } @@ -392,7 +403,7 @@ private async Task> CreateLatentInputAsync(IPipelineOptions option if (options.HasInputImage) { var timestep = scheduler.GetStartTimestep(); - var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken); + var encoderResult = await EncodeLatentsAsync(options, cancellationToken); var noiseTensor = scheduler.CreateRandomSample(encoderResult.Dimensions); return PackLatents(scheduler.ScaleNoise(timestep, encoderResult, noiseTensor)); } @@ -410,8 +421,8 @@ private async Task> CreateLatentInputAsync(IPipelineOptions option /// protected Tensor CreateLatentImageIds(IPipelineOptions options) { - var height = options.Height / AutoEncoder.LatentChannels; - var width = options.Width / AutoEncoder.LatentChannels; + var height = options.Height / AutoEncoder.LatentChannels; + var width = options.Width / AutoEncoder.LatentChannels; var latentIds = new Tensor([height, width, 3]); for (int i = 0; i < height; i++) diff --git a/TensorStack.StableDiffusion/Pipelines/Flux/FluxConfig.cs b/TensorStack.StableDiffusion/Pipelines/Flux/FluxConfig.cs index 2ed0d7c..4e889b5 100644 --- a/TensorStack.StableDiffusion/Pipelines/Flux/FluxConfig.cs +++ b/TensorStack.StableDiffusion/Pipelines/Flux/FluxConfig.cs @@ -18,7 +18,7 @@ public record FluxConfig : PipelineConfig public FluxConfig() { Tokenizer = new TokenizerConfig(); - Tokenizer2 = new TokenizerConfig{MaxLength = 512 }; + Tokenizer2 = new TokenizerConfig { MaxLength = 512 }; TextEncoder = new CLIPModelConfig(); TextEncoder2 = new CLIPModelConfig { diff --git a/TensorStack.StableDiffusion/Pipelines/IPipelineOptions.cs b/TensorStack.StableDiffusion/Pipelines/IPipelineOptions.cs index 79d3f9e..e1c1192 100644 --- a/TensorStack.StableDiffusion/Pipelines/IPipelineOptions.cs +++ b/TensorStack.StableDiffusion/Pipelines/IPipelineOptions.cs @@ -27,7 +27,7 @@ public interface IPipelineOptions : IRunOptions float ControlNetStrength { get; set; } ImageTensor InputControlImage { get; set; } - int ClipSkip{ get; set; } + int ClipSkip { get; set; } float AestheticScore { get; set; } float AestheticNegativeScore { get; set; } @@ -36,7 +36,7 @@ public interface IPipelineOptions : IRunOptions bool IsLowMemoryEncoderEnabled { get; set; } bool IsLowMemoryDecoderEnabled { get; set; } bool IsLowMemoryTextEncoderEnabled { get; set; } - + bool IsPipelineCacheEnabled { get; set; } bool HasControlNet => ControlNet is not null; bool HasInputImage => InputImage is not null; diff --git a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs index dbb1b03..ccde5a3 100644 --- a/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs +++ b/TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs @@ -137,6 +137,10 @@ protected override void ValidateOptions(GenerateOptions options) /// The cancellation token. protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { + var cachedPrompt = GetPromptCache(options); + if (cachedPrompt is not null) + return cachedPrompt; + // Conditional Prompt var promptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions { @@ -159,7 +163,7 @@ protected async Task CreatePromptAsync(IPipelineOptions options, C }, cancellationToken); } - return new PromptResult(promptEmbeds, default, negativePromptEmbeds, default); + return SetPromptCache(options, new PromptResult(promptEmbeds, default, negativePromptEmbeds, default)); } @@ -187,16 +191,23 @@ protected async Task DecodeLatentsAsync(IPipelineOptions options, T /// The options. /// The latents. /// The cancellation token. - private async Task> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default) + private async Task> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode"); - var inputTensor = image.ResizeImage(options.Width, options.Height); + var cacheResult = GetEncoderCache(options); + if (cacheResult is not null) + { + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result."); + return cacheResult; + } + + var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height); var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken); if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) await AutoEncoder.EncoderUnloadAsync(); Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete"); - return encoderResult; + return SetEncoderCache(options, encoderResult); } @@ -270,7 +281,7 @@ private async Task> CreateLatentInputAsync(IPipelineOptions option if (options.HasInputImage) { var timestep = scheduler.GetStartTimestep(); - var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken); + var encoderResult = await EncodeLatentsAsync(options, cancellationToken); return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor); } return noiseTensor; diff --git a/TensorStack.StableDiffusion/Pipelines/PipelineBase.cs b/TensorStack.StableDiffusion/Pipelines/PipelineBase.cs index 10faac0..49054b1 100644 --- a/TensorStack.StableDiffusion/Pipelines/PipelineBase.cs +++ b/TensorStack.StableDiffusion/Pipelines/PipelineBase.cs @@ -15,6 +15,8 @@ namespace TensorStack.StableDiffusion.Pipelines { public abstract class PipelineBase : IDisposable { + private PromptCache _promptCache; + private EncoderCache _encoderCache; private GenerateOptions _defaultOptions; private IReadOnlyList _schedulers; @@ -149,11 +151,78 @@ protected Tensor ApplyGuidance(Tensor conditional, Tensor u } + /// + /// Gets the prompt cache. + /// + /// The options. + protected PromptResult GetPromptCache(IPipelineOptions options) + { + if (!options.IsPipelineCacheEnabled) + return default; + + if (_promptCache is null || !_promptCache.IsValid(options)) + return default; + + return _promptCache.CacheResult; + } + + + /// + /// Sets the prompt cache. + /// + /// The options. + /// The prompt result to cache. + protected PromptResult SetPromptCache(IPipelineOptions options, PromptResult promptResult) + { + _promptCache = new PromptCache + { + CacheResult = promptResult, + Conditional = options.Prompt, + Unconditional = options.NegativePrompt, + }; + return promptResult; + } + + + /// + /// Gets the encoder cache. + /// + /// The options. + protected Tensor GetEncoderCache(IPipelineOptions options) + { + if (!options.IsPipelineCacheEnabled) + return default; + + if (_encoderCache is null || !_encoderCache.IsValid(options.InputImage)) + return default; + + return _encoderCache.CacheResult; + } + + + /// + /// Sets the encoder cache. + /// + /// The options. + /// The encoded. + protected Tensor SetEncoderCache(IPipelineOptions options, Tensor encoded) + { + _encoderCache = new EncoderCache + { + InputImage = options.InputImage, + CacheResult = encoded + }; + return encoded; + } + + /// /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. /// public void Dispose() { + _promptCache = null; + _encoderCache = null; Dispose(disposing: true); GC.SuppressFinalize(this); } diff --git a/TensorStack.StableDiffusion/Pipelines/PipelineConfig.cs b/TensorStack.StableDiffusion/Pipelines/PipelineConfig.cs index 63119b8..18e50f1 100644 --- a/TensorStack.StableDiffusion/Pipelines/PipelineConfig.cs +++ b/TensorStack.StableDiffusion/Pipelines/PipelineConfig.cs @@ -11,7 +11,7 @@ public abstract record PipelineConfig /// /// Gets or sets the type. /// - public abstract PipelineType Pipeline { get;} + public abstract PipelineType Pipeline { get; } /// /// Saves the configuration to file. diff --git a/TensorStack.StableDiffusion/Pipelines/StableCascade/StableCascadeBase.cs b/TensorStack.StableDiffusion/Pipelines/StableCascade/StableCascadeBase.cs index 2300290..1c6a32c 100644 --- a/TensorStack.StableDiffusion/Pipelines/StableCascade/StableCascadeBase.cs +++ b/TensorStack.StableDiffusion/Pipelines/StableCascade/StableCascadeBase.cs @@ -147,6 +147,10 @@ protected override void ValidateOptions(GenerateOptions options) /// The cancellation token. protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { + var cachedPrompt = GetPromptCache(options); + if (cachedPrompt is not null) + return cachedPrompt; + // Tokenizer var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken); var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken); @@ -167,7 +171,7 @@ protected async Task CreatePromptAsync(IPipelineOptions options, C ? negativePromptEmbeddings.TextEmbeds : negativePromptEmbeddings.TextEmbeds.Reshape([1, .. negativePromptEmbeddings.TextEmbeds.Dimensions]); - return new PromptResult(promptEmbeddings.HiddenStates, textEmbeds, negativePromptEmbeddings.HiddenStates, negativeTextEmbeds); + return SetPromptCache(options, new PromptResult(promptEmbeddings.HiddenStates, textEmbeds, negativePromptEmbeddings.HiddenStates, negativeTextEmbeds)); } @@ -217,7 +221,7 @@ protected async Task> RunPriorAsync(GenerateOptions options, Tenso // Create latent sample var latents = await CreatePriorLatentInputAsync(options, scheduler, cancellationToken); - var image = await EncodeLatentsAsync(options, options.InputImage, cancellationToken); + var image = await EncodeLatentsAsync(options, cancellationToken); // Get Model metadata var metadata = await PriorUnet.LoadAsync(cancellationToken: cancellationToken); @@ -387,7 +391,7 @@ private Task> CreateDecoderLatentsAsync(GenerateOptions options, I /// The options. /// The latents. /// The cancellation token. - private Task> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default) + private Task> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { return Task.FromResult(new Tensor([1, 1, ImageEncoder.HiddenSize])); } diff --git a/TensorStack.StableDiffusion/Pipelines/StableDiffusion/StableDiffusionBase.cs b/TensorStack.StableDiffusion/Pipelines/StableDiffusion/StableDiffusionBase.cs index 7a3ec5f..ff01175 100644 --- a/TensorStack.StableDiffusion/Pipelines/StableDiffusion/StableDiffusionBase.cs +++ b/TensorStack.StableDiffusion/Pipelines/StableDiffusion/StableDiffusionBase.cs @@ -146,6 +146,10 @@ protected override void ValidateOptions(GenerateOptions options) /// The cancellation token. protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { + var cachedPrompt = GetPromptCache(options); + if (cachedPrompt is not null) + return cachedPrompt; + // Tokenizer var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken); var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken); @@ -157,7 +161,7 @@ protected async Task CreatePromptAsync(IPipelineOptions options, C if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) await TextEncoder.UnloadAsync(); - return new PromptResult(promptEmbeddings.HiddenStates, promptEmbeddings.TextEmbeds, negativePromptEmbeddings.HiddenStates, negativePromptEmbeddings.TextEmbeds); + return SetPromptCache(options, new PromptResult(promptEmbeddings.HiddenStates, promptEmbeddings.TextEmbeds, negativePromptEmbeddings.HiddenStates, negativePromptEmbeddings.TextEmbeds)); } @@ -217,13 +221,20 @@ protected async Task DecodeLatentsAsync(IPipelineOptions options, T private async Task> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default) { var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode"); + var cacheResult = GetEncoderCache(options); + if (cacheResult is not null) + { + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result."); + return cacheResult; + } + var inputTensor = image.ResizeImage(options.Width, options.Height); var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken); if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) await AutoEncoder.EncoderUnloadAsync(); Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete"); - return encoderResult; + return SetEncoderCache(options, encoderResult); } diff --git a/TensorStack.StableDiffusion/Pipelines/StableDiffusion3/StableDiffusion3Base.cs b/TensorStack.StableDiffusion/Pipelines/StableDiffusion3/StableDiffusion3Base.cs index 5f65d9b..9cf5591 100644 --- a/TensorStack.StableDiffusion/Pipelines/StableDiffusion3/StableDiffusion3Base.cs +++ b/TensorStack.StableDiffusion/Pipelines/StableDiffusion3/StableDiffusion3Base.cs @@ -165,6 +165,10 @@ protected override void ValidateOptions(GenerateOptions options) /// The cancellation token. protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { + var cachedPrompt = GetPromptCache(options); + if (cachedPrompt is not null) + return cachedPrompt; + // Tokenizer var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken); var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken); @@ -217,7 +221,7 @@ protected async Task CreatePromptAsync(IPipelineOptions options, C negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch(); negativePromptPooledEmbeds = negativePromptPooledEmbeds.Concatenate(negativePromptPooledEmbeds2, 1); - return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds); + return SetPromptCache(options, new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds)); } @@ -338,16 +342,23 @@ protected async Task DecodeLatentsAsync(IPipelineOptions options, T /// The options. /// The latents. /// The cancellation token. - private async Task> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default) + private async Task> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode"); - var inputTensor = image.ResizeImage(options.Width, options.Height); + var cacheResult = GetEncoderCache(options); + if (cacheResult is not null) + { + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result."); + return cacheResult; + } + + var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height); var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken); if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) await AutoEncoder.EncoderUnloadAsync(); Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete"); - return encoderResult; + return SetEncoderCache(options, encoderResult); } @@ -521,7 +532,7 @@ private async Task> CreateLatentInputAsync(IPipelineOptions option if (options.HasInputImage) { var timestep = scheduler.GetStartTimestep(); - var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken); + var encoderResult = await EncodeLatentsAsync(options, cancellationToken); return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor); } return noiseTensor; diff --git a/TensorStack.StableDiffusion/Pipelines/StableDiffusionXL/StableDiffusionXLBase.cs b/TensorStack.StableDiffusion/Pipelines/StableDiffusionXL/StableDiffusionXLBase.cs index a327878..e0b8d9f 100644 --- a/TensorStack.StableDiffusion/Pipelines/StableDiffusionXL/StableDiffusionXLBase.cs +++ b/TensorStack.StableDiffusion/Pipelines/StableDiffusionXL/StableDiffusionXLBase.cs @@ -146,6 +146,10 @@ protected override void ValidateOptions(GenerateOptions options) /// The cancellation token. protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { + var cachedPrompt = GetPromptCache(options); + if (cachedPrompt is not null) + return cachedPrompt; + // Tokenizer var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken); var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken); @@ -176,7 +180,7 @@ protected async Task CreatePromptAsync(IPipelineOptions options, C var pooledNegativePromptEmbeds = negativePrompt2Embeddings.TextEmbeds; var negativePromptEmbeddings = negativePrompt1Embeddings.HiddenStates.Concatenate(negativePrompt2Embeddings.HiddenStates, 2); - return new PromptResult(promptEmbeddings, pooledPromptEmbeds, negativePromptEmbeddings, pooledNegativePromptEmbeds); + return SetPromptCache(options, new PromptResult(promptEmbeddings, pooledPromptEmbeds, negativePromptEmbeddings, pooledNegativePromptEmbeds)); } @@ -267,16 +271,23 @@ protected async Task DecodeLatentsAsync(IPipelineOptions options, T /// The options. /// The latents. /// The cancellation token. - private async Task> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default) + private async Task> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default) { var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode"); - var inputTensor = image.ResizeImage(options.Width, options.Height); + var cacheResult = GetEncoderCache(options); + if (cacheResult is not null) + { + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result."); + return cacheResult; + } + + var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height); var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken); if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) await AutoEncoder.EncoderUnloadAsync(); Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete"); - return encoderResult; + return SetEncoderCache(options, encoderResult); } @@ -458,7 +469,7 @@ private async Task> CreateLatentInputAsync(IPipelineOptions option if (options.HasInputImage) { var timestep = scheduler.GetStartTimestep(); - var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken); + var encoderResult = await EncodeLatentsAsync(options, cancellationToken); return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor); } return noiseTensor.Multiply(scheduler.StartSigma);