22 changes: 22 additions & 0 deletions TensorStack.StableDiffusion/Common/EncoderCache.cs
@@ -0,0 +1,22 @@
using System;
using TensorStack.Common.Tensor;

namespace TensorStack.StableDiffusion.Common
{
    public record EncoderCache
    {
        public ImageTensor InputImage { get; init; }
        public Tensor<float> CacheResult { get; init; }

        public bool IsValid(ImageTensor input)
        {
            if (input is null || InputImage is null)
                return false;

            if (!InputImage.Span.SequenceEqual(input.Span))
                return false;

            return true;
        }
    }
}
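Cache validity here is a byte-for-byte comparison of the two pixel buffers, so a hit requires identical image content rather than the same reference, and the O(pixels) scan is still far cheaper than re-running the VAE encoder. One subtlety: validity is keyed to the input image alone, even though the cached latents also reflect the options.Width and options.Height in effect at encode time. A minimal usage sketch; the image and latent variables are assumed for illustration:

// Hypothetical values: 'image', 'identicalCopy' and 'editedImage' are assumed
// ImageTensor instances, and 'latents' a previously encoded Tensor<float>.
var cache = new EncoderCache { InputImage = image, CacheResult = latents };

bool hit = cache.IsValid(identicalCopy);  // true: SequenceEqual compares contents, not references
bool miss = cache.IsValid(editedImage);   // false: any pixel difference invalidates
bool none = cache.IsValid(null);          // false: null input never validates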
2 changes: 1 addition & 1 deletion TensorStack.StableDiffusion/Common/GenerateOptions.cs
@@ -45,7 +45,7 @@ public record GenerateOptions : IPipelineOptions, ISchedulerOptions
        public bool IsLowMemoryEncoderEnabled { get; set; }
        public bool IsLowMemoryDecoderEnabled { get; set; }
        public bool IsLowMemoryTextEncoderEnabled { get; set; }

+       public bool IsPipelineCacheEnabled { get; set; } = true;

        public bool HasControlNet => ControlNet is not null;
        public bool HasInputImage => InputImage is not null;
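Because IsPipelineCacheEnabled defaults to true, existing callers pick up caching without any changes; opting out is a one-line switch. A small sketch with illustrative values:

var options = new GenerateOptions
{
    Prompt = "an astronaut riding a horse",
    IsPipelineCacheEnabled = false  // always re-encode prompts and input images
};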
18 changes: 18 additions & 0 deletions TensorStack.StableDiffusion/Common/PromptCache.cs
@@ -0,0 +1,18 @@
using System;
using TensorStack.StableDiffusion.Pipelines;

namespace TensorStack.StableDiffusion.Common
{
    public record PromptCache
    {
        public string Conditional { get; init; }
        public string Unconditional { get; init; }
        public PromptResult CacheResult { get; init; }

        public bool IsValid(IPipelineOptions options)
        {
            return string.Equals(Conditional, options.Prompt, StringComparison.OrdinalIgnoreCase)
                && string.Equals(Unconditional, options.NegativePrompt, StringComparison.OrdinalIgnoreCase);
        }
    }
}
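Note the OrdinalIgnoreCase comparisons: prompts differing only in casing are treated as identical and reuse the cached embeddings, while any other textual change is a miss. A short sketch, where promptResult is an assumed PromptResult instance:

var cache = new PromptCache
{
    Conditional = "a cat wearing a hat",
    Unconditional = "blurry",
    CacheResult = promptResult
};

// true: casing is ignored for both prompt and negative prompt
cache.IsValid(new GenerateOptions { Prompt = "A CAT wearing a hat", NegativePrompt = "BLURRY" });

// false: any other textual change misses the cache
cache.IsValid(new GenerateOptions { Prompt = "a dog wearing a hat", NegativePrompt = "blurry" });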
3 changes: 0 additions & 3 deletions TensorStack.StableDiffusion/Common/TextEncoderResult.cs
@@ -1,6 +1,5 @@
// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.
-using System;
using TensorStack.Common.Tensor;

namespace TensorStack.StableDiffusion.Common
@@ -32,6 +31,4 @@ public Tensor<float> GetHiddenStates(int index)
            return _hiddenStates[0];
        }
    }

-   public record TextEncoderBatchedResult(Memory<float> PromptEmbeds, Memory<float> PromptPooledEmbeds);
}
1 change: 0 additions & 1 deletion TensorStack.StableDiffusion/Models/CLIPTextModel.cs
@@ -4,7 +4,6 @@
using System.Threading;
using System.Threading.Tasks;
using TensorStack.Common;
-using TensorStack.Common.Tensor;
using TensorStack.StableDiffusion.Common;
using TensorStack.StableDiffusion.Config;
using TensorStack.TextGeneration.Tokenizers;
25 changes: 18 additions & 7 deletions TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs
@@ -148,6 +148,10 @@ protected override void ValidateOptions(GenerateOptions options)
        /// <param name="cancellationToken">The cancellation token.</param>
        protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
        {
+           var cachedPrompt = GetPromptCache(options);
+           if (cachedPrompt is not null)
+               return cachedPrompt;
+
            // Tokenize2
            var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
            var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -179,7 +183,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
            var negativePromptPooledEmbeds = negativePromptEmbeddings.TextEmbeds;
            negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch();

-           return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds);
+           return SetPromptCache(options, new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds));
        }


@@ -264,16 +268,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
        /// <param name="options">The options.</param>
        /// <param name="image">The latents.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
-       private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
+       private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
        {
            var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
-           var inputTensor = image.ResizeImage(options.Width, options.Height);
+           var cacheResult = GetEncoderCache(options);
+           if (cacheResult is not null)
+           {
+               Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
+               return cacheResult;
+           }
+
+           var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height);
            var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
            if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
                await AutoEncoder.EncoderUnloadAsync();

            Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
-           return encoderResult;
+           return SetEncoderCache(options, encoderResult);
        }


@@ -392,7 +403,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
            if (options.HasInputImage)
            {
                var timestep = scheduler.GetStartTimestep();
-               var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
+               var encoderResult = await EncodeLatentsAsync(options, cancellationToken);
                var noiseTensor = scheduler.CreateRandomSample(encoderResult.Dimensions);
                return PackLatents(scheduler.ScaleNoise(timestep, encoderResult, noiseTensor));
            }
@@ -410,8 +421,8 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
        /// <returns></returns>
        protected Tensor<float> CreateLatentImageIds(IPipelineOptions options)
        {
-           var height = options.Height / AutoEncoder.LatentChannels;
-           var width = options.Width / AutoEncoder.LatentChannels;
+           var height = options.Height / AutoEncoder.LatentChannels;
+           var width = options.Width / AutoEncoder.LatentChannels;
            var latentIds = new Tensor<float>([height, width, 3]);

            for (int i = 0; i < height; i++)
2 changes: 1 addition & 1 deletion TensorStack.StableDiffusion/Pipelines/Flux/FluxConfig.cs
@@ -18,7 +18,7 @@ public record FluxConfig : PipelineConfig
        public FluxConfig()
        {
            Tokenizer = new TokenizerConfig();
-           Tokenizer2 = new TokenizerConfig{MaxLength = 512 };
+           Tokenizer2 = new TokenizerConfig { MaxLength = 512 };
            TextEncoder = new CLIPModelConfig();
            TextEncoder2 = new CLIPModelConfig
            {
4 changes: 2 additions & 2 deletions TensorStack.StableDiffusion/Pipelines/IPipelineOptions.cs
@@ -27,7 +27,7 @@ public interface IPipelineOptions : IRunOptions
        float ControlNetStrength { get; set; }
        ImageTensor InputControlImage { get; set; }

-       int ClipSkip{ get; set; }
+       int ClipSkip { get; set; }
        float AestheticScore { get; set; }
        float AestheticNegativeScore { get; set; }
@@ -36,7 +36,7 @@ public interface IPipelineOptions : IRunOptions
        bool IsLowMemoryEncoderEnabled { get; set; }
        bool IsLowMemoryDecoderEnabled { get; set; }
        bool IsLowMemoryTextEncoderEnabled { get; set; }

+       bool IsPipelineCacheEnabled { get; set; }

        bool HasControlNet => ControlNet is not null;
        bool HasInputImage => InputImage is not null;
21 changes: 16 additions & 5 deletions TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs
@@ -137,6 +137,10 @@ protected override void ValidateOptions(GenerateOptions options)
        /// <param name="cancellationToken">The cancellation token.</param>
        protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
        {
+           var cachedPrompt = GetPromptCache(options);
+           if (cachedPrompt is not null)
+               return cachedPrompt;
+
            // Conditional Prompt
            var promptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
            {
@@ -159,7 +163,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
                }, cancellationToken);
            }

-           return new PromptResult(promptEmbeds, default, negativePromptEmbeds, default);
+           return SetPromptCache(options, new PromptResult(promptEmbeds, default, negativePromptEmbeds, default));
        }


@@ -187,16 +191,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
        /// <param name="options">The options.</param>
        /// <param name="image">The latents.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
-       private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
+       private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
        {
            var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
-           var inputTensor = image.ResizeImage(options.Width, options.Height);
+           var cacheResult = GetEncoderCache(options);
+           if (cacheResult is not null)
+           {
+               Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
+               return cacheResult;
+           }
+
+           var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height);
            var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
            if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
                await AutoEncoder.EncoderUnloadAsync();

            Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
-           return encoderResult;
+           return SetEncoderCache(options, encoderResult);
        }


@@ -270,7 +281,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
            if (options.HasInputImage)
            {
                var timestep = scheduler.GetStartTimestep();
-               var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
+               var encoderResult = await EncodeLatentsAsync(options, cancellationToken);
                return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor);
            }
            return noiseTensor;
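With the encoder cache wired in, repeated image-to-image generations over the same input pay the VAE encode cost only once; subsequent calls return the cached latents before the AutoEncoder is touched. A hedged sketch; the pipeline instance and its GenerateAsync method are assumed, not part of this diff:

var options = new GenerateOptions
{
    Prompt = "cinematic portrait",
    InputImage = sourceImage  // identical pixel content on every call
};

var first = await pipeline.GenerateAsync(options);   // AutoEncoder runs, result cached
var second = await pipeline.GenerateAsync(options);  // EncodeLatentsAsync returns the cached latents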
69 changes: 69 additions & 0 deletions TensorStack.StableDiffusion/Pipelines/PipelineBase.cs
@@ -15,6 +15,8 @@ namespace TensorStack.StableDiffusion.Pipelines
{
    public abstract class PipelineBase : IDisposable
    {
+       private PromptCache _promptCache;
+       private EncoderCache _encoderCache;
        private GenerateOptions _defaultOptions;
        private IReadOnlyList<SchedulerType> _schedulers;
@@ -149,11 +151,78 @@ protected Tensor<float> ApplyGuidance(Tensor<float> conditional, Tensor<float> u
        }


        /// <summary>
        /// Gets the prompt cache.
        /// </summary>
        /// <param name="options">The options.</param>
        protected PromptResult GetPromptCache(IPipelineOptions options)
        {
            if (!options.IsPipelineCacheEnabled)
                return default;

            if (_promptCache is null || !_promptCache.IsValid(options))
                return default;

            return _promptCache.CacheResult;
        }


        /// <summary>
        /// Sets the prompt cache.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="promptResult">The prompt result to cache.</param>
        protected PromptResult SetPromptCache(IPipelineOptions options, PromptResult promptResult)
        {
            _promptCache = new PromptCache
            {
                CacheResult = promptResult,
                Conditional = options.Prompt,
                Unconditional = options.NegativePrompt,
            };
            return promptResult;
        }


        /// <summary>
        /// Gets the encoder cache.
        /// </summary>
        /// <param name="options">The options.</param>
        protected Tensor<float> GetEncoderCache(IPipelineOptions options)
        {
            if (!options.IsPipelineCacheEnabled)
                return default;

            if (_encoderCache is null || !_encoderCache.IsValid(options.InputImage))
                return default;

            return _encoderCache.CacheResult;
        }


        /// <summary>
        /// Sets the encoder cache.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="encoded">The encoded latents to cache.</param>
        protected Tensor<float> SetEncoderCache(IPipelineOptions options, Tensor<float> encoded)
        {
            _encoderCache = new EncoderCache
            {
                InputImage = options.InputImage,
                CacheResult = encoded
            };
            return encoded;
        }


        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            _promptCache = null;
            _encoderCache = null;
            Dispose(disposing: true);
            GC.SuppressFinalize(this);
        }
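The caches are per-pipeline instance fields and are cleared on Dispose. Taken together, GetPromptCache and SetPromptCache turn CreatePromptAsync into a lookup whenever the prompt pair is unchanged, the common case when re-rolling seeds. An end-to-end sketch; the pipeline variable, GenerateAsync name, and Seed property are assumed for illustration:

var options = new GenerateOptions { Prompt = "a watercolor fox", Seed = 1 };
var first = await pipeline.GenerateAsync(options);   // text encoder runs, PromptResult cached

options.Seed = 2;                                    // same prompt pair, new seed
var second = await pipeline.GenerateAsync(options);  // GetPromptCache hits; the text encoder
                                                     // (and any low-memory reload) is skipped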
2 changes: 1 addition & 1 deletion TensorStack.StableDiffusion/Pipelines/PipelineConfig.cs
@@ -11,7 +11,7 @@ public abstract record PipelineConfig
        /// <summary>
        /// Gets or sets the type.
        /// </summary>
-       public abstract PipelineType Pipeline { get;}
+       public abstract PipelineType Pipeline { get; }

        /// <summary>
        /// Saves the configuration to file.
@@ -147,6 +147,10 @@ protected override void ValidateOptions(GenerateOptions options)
        /// <param name="cancellationToken">The cancellation token.</param>
        protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
        {
+           var cachedPrompt = GetPromptCache(options);
+           if (cachedPrompt is not null)
+               return cachedPrompt;
+
            // Tokenizer
            var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
            var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -167,7 +171,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
                ? negativePromptEmbeddings.TextEmbeds
                : negativePromptEmbeddings.TextEmbeds.Reshape([1, .. negativePromptEmbeddings.TextEmbeds.Dimensions]);

-           return new PromptResult(promptEmbeddings.HiddenStates, textEmbeds, negativePromptEmbeddings.HiddenStates, negativeTextEmbeds);
+           return SetPromptCache(options, new PromptResult(promptEmbeddings.HiddenStates, textEmbeds, negativePromptEmbeddings.HiddenStates, negativeTextEmbeds));
        }


@@ -217,7 +221,7 @@ protected async Task<Tensor<float>> RunPriorAsync(GenerateOptions options, Tenso
            // Create latent sample
            var latents = await CreatePriorLatentInputAsync(options, scheduler, cancellationToken);

-           var image = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
+           var image = await EncodeLatentsAsync(options, cancellationToken);

            // Get Model metadata
            var metadata = await PriorUnet.LoadAsync(cancellationToken: cancellationToken);
@@ -387,7 +391,7 @@ private Task<Tensor<float>> CreateDecoderLatentsAsync(GenerateOptions options, I
        /// <param name="options">The options.</param>
        /// <param name="image">The latents.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
-       private Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
+       private Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
        {
            return Task.FromResult(new Tensor<float>([1, 1, ImageEncoder.HiddenSize]));
        }
@@ -146,6 +146,10 @@ protected override void ValidateOptions(GenerateOptions options)
        /// <param name="cancellationToken">The cancellation token.</param>
        protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
        {
+           var cachedPrompt = GetPromptCache(options);
+           if (cachedPrompt is not null)
+               return cachedPrompt;
+
            // Tokenizer
            var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
            var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -157,7 +161,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
            if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
                await TextEncoder.UnloadAsync();

-           return new PromptResult(promptEmbeddings.HiddenStates, promptEmbeddings.TextEmbeds, negativePromptEmbeddings.HiddenStates, negativePromptEmbeddings.TextEmbeds);
+           return SetPromptCache(options, new PromptResult(promptEmbeddings.HiddenStates, promptEmbeddings.TextEmbeds, negativePromptEmbeddings.HiddenStates, negativePromptEmbeddings.TextEmbeds));
        }


@@ -217,13 +221,20 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
        private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
        {
            var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
+           var cacheResult = GetEncoderCache(options);
+           if (cacheResult is not null)
+           {
+               Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
+               return cacheResult;
+           }
+
            var inputTensor = image.ResizeImage(options.Width, options.Height);
            var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
            if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
                await AutoEncoder.EncoderUnloadAsync();

            Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
-           return encoderResult;
+           return SetEncoderCache(options, encoderResult);
        }

