Skip to content

Commit 1047b11

Browse files
committed
Nitro TextEncoder
1 parent c084cde commit 1047b11

File tree

4 files changed

+39
-45
lines changed

4 files changed

+39
-45
lines changed

TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
33
using Microsoft.Extensions.Logging;
4-
using Microsoft.ML.Tokenizers;
54
using System;
65
using System.Collections.Generic;
76
using System.Diagnostics;
@@ -21,7 +20,7 @@ namespace TensorStack.StableDiffusion.Pipelines.Nitro
2120
public abstract class NitroBase : PipelineBase
2221
{
2322
/// <summary>
24-
/// Initializes a new instance of the <see cref="vBase"/> class.
23+
/// Initializes a new instance of the <see cref="NitroBase"/> class.
2524
/// </summary>
2625
/// <param name="transformer">The transformer.</param>
2726
/// <param name="textEncoder">The text encoder.</param>
@@ -46,6 +45,7 @@ public NitroBase(NitroConfig configuration, ILogger logger = default) : this(
4645
new TransformerNitroModel(configuration.Transformer),
4746
new LlamaPipeline(new LlamaConfig
4847
{
48+
OutputLastHiddenStates = true,
4949
DecoderConfig = configuration.TextEncoder,
5050
Tokenizer = new BPETokenizer(configuration.Tokenizer),
5151
}),
@@ -128,46 +128,28 @@ protected override void ValidateOptions(GenerateOptions options)
128128
/// <param name="cancellationToken">The cancellation token.</param>
129129
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
130130
{
131-
//// Tokenize2
132-
//var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
133-
//var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
134-
//var maxTokenLength = (int)Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);
135-
136-
//// Tokenizer2
137-
//var prompt2Tokens = await TokenizePrompt2Async(options.Prompt, cancellationToken);
138-
//var negativePrompt2Tokens = await TokenizePrompt2Async(options.NegativePrompt, cancellationToken);
139-
140-
//// TextEncoder
141-
//var promptEmbeddings = await EncodePromptAsync(promptTokens, maxTokenLength, cancellationToken);
142-
//var negativePromptEmbeddings = await EncodePromptAsync(negativePromptTokens, maxTokenLength, cancellationToken);
143-
//if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
144-
// await TextEncoder.UnloadAsync();
145-
146-
147-
148-
//// Prompt
149-
//var promptEmbeds = prompt2Embeddings.HiddenStates;
150-
//var promptPooledEmbeds = promptEmbeddings.TextEmbeds;
151-
//promptPooledEmbeds = promptPooledEmbeds.Reshape([promptPooledEmbeds.Dimensions[^2], promptPooledEmbeds.Dimensions[^1]]).FirstBatch();
152-
153-
//// Negative promt
154-
//var negativePromptEmbeds = negativePrompt2Embeddings.HiddenStates;
155-
//var negativePromptPooledEmbeds = negativePromptEmbeddings.TextEmbeds;
156-
//negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch();
157-
158-
//return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds);
159-
160-
161-
var result = TextEncoder.RunAsync(new TextGeneration.Common.GenerateOptions
131+
// Conditional Prompt
132+
var promptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
162133
{
163134
Seed = options.Seed,
164-
Prompt = options.Prompt
165-
});
166-
167-
return default;
168-
}
135+
Prompt = options.Prompt,
136+
MaxLength = 128
137+
}, cancellationToken);
169138

139+
// Unconditional prompt
140+
var negativePromptEmbeds = default(Tensor<float>);
141+
if (!string.IsNullOrEmpty(options.NegativePrompt))
142+
{
143+
negativePromptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
144+
{
145+
Seed = options.Seed,
146+
Prompt = options.NegativePrompt,
147+
MaxLength = 128
148+
}, cancellationToken);
149+
}
170150

151+
return new PromptResult(promptEmbeds, default, negativePromptEmbeds, default);
152+
}
171153

172154

173155
/// <summary>
@@ -372,7 +354,7 @@ protected override GenerateOptions ConfigureDefaultOptions()
372354
Scheduler = SchedulerType.FlowMatchEulerDiscrete
373355
};
374356

375-
// SD3-Turbo Models , 4 Steps, No Guidance
357+
// Nitro-Distilled Models ,4 Steps, No Guidance
376358
if (Transformer.ModelType == ModelType.Turbo)
377359
{
378360
return options with

TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using TensorStack.StableDiffusion.Config;
66
using TensorStack.StableDiffusion.Enums;
77
using TensorStack.TextGeneration.Common;
8-
using TensorStack.TextGeneration.Pipelines.Llama;
98
using TensorStack.TextGeneration.Tokenizers;
109

1110
namespace TensorStack.StableDiffusion.Pipelines.Nitro
@@ -26,8 +25,8 @@ public NitroConfig()
2625
{
2726
NumHeads = 32,
2827
NumLayers = 16,
29-
HiddenSize = 2048,
3028
NumKVHeads = 8,
29+
HiddenSize = 2048,
3130
VocabSize = 128256
3231
};
3332
Transformer = new TransformerModelConfig
@@ -40,8 +39,7 @@ public NitroConfig()
4039
AutoEncoder = new AutoEncoderModelConfig
4140
{
4241
LatentChannels = 32,
43-
ScaleFactor = 0.3611f,
44-
ShiftFactor = 0.1159f
42+
ScaleFactor = 0.41407f
4543
};
4644
}
4745

@@ -116,7 +114,7 @@ public static NitroConfig FromFile(string configFile, ExecutionProvider executio
116114
public static NitroConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
117115
{
118116
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
119-
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json");
117+
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer");
120118
config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
121119
config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx");
122120
config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");

TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ public NitroPipeline(NitroConfig configuration, ILogger logger = null)
4444
public async Task<ImageTensor> RunAsync(GenerateOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
4545
{
4646
ValidateOptions(options);
47-
4847
var prompt = await CreatePromptAsync(options, cancellationToken);
4948
using (var scheduler = CreateScheduler(options))
5049
{

TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,21 @@ public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<Ge
8989
}
9090

9191

92+
/// <summary>
93+
/// Gets the LastHiddenState.
94+
/// </summary>
95+
/// <param name="options">The options.</param>
96+
/// <param name="cancellationToken">The cancellation token.</param>
97+
public async Task<Tensor<float>> GetLastHiddenState(GenerateOptions options, CancellationToken cancellationToken = default)
98+
{
99+
await TokenizePromptAsync(options);
100+
using (var sequence = await InitializeAsync(options))
101+
{
102+
return sequence.LastHiddenState;
103+
}
104+
}
105+
106+
92107
/// <summary>
93108
/// Gets the token processors.
94109
/// </summary>

0 commit comments

Comments
 (0)