11// Copyright (c) TensorStack. All rights reserved.
22// Licensed under the Apache 2.0 License.
33using Microsoft . Extensions . Logging ;
4- using Microsoft . ML . Tokenizers ;
54using System ;
65using System . Collections . Generic ;
76using System . Diagnostics ;
@@ -21,7 +20,7 @@ namespace TensorStack.StableDiffusion.Pipelines.Nitro
2120 public abstract class NitroBase : PipelineBase
2221 {
2322 /// <summary>
24- /// Initializes a new instance of the <see cref="vBase "/> class.
23+ /// Initializes a new instance of the <see cref="NitroBase "/> class.
2524 /// </summary>
2625 /// <param name="transformer">The transformer.</param>
2726 /// <param name="textEncoder">The text encoder.</param>
@@ -46,6 +45,7 @@ public NitroBase(NitroConfig configuration, ILogger logger = default) : this(
4645 new TransformerNitroModel ( configuration . Transformer ) ,
4746 new LlamaPipeline ( new LlamaConfig
4847 {
48+ OutputLastHiddenStates = true ,
4949 DecoderConfig = configuration . TextEncoder ,
5050 Tokenizer = new BPETokenizer ( configuration . Tokenizer ) ,
5151 } ) ,
@@ -128,46 +128,28 @@ protected override void ValidateOptions(GenerateOptions options)
128128 /// <param name="cancellationToken">The cancellation token.</param>
129129 protected async Task < PromptResult > CreatePromptAsync ( IPipelineOptions options , CancellationToken cancellationToken = default )
130130 {
131- //// Tokenize2
132- //var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
133- //var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
134- //var maxTokenLength = (int)Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);
135-
136- //// Tokenizer2
137- //var prompt2Tokens = await TokenizePrompt2Async(options.Prompt, cancellationToken);
138- //var negativePrompt2Tokens = await TokenizePrompt2Async(options.NegativePrompt, cancellationToken);
139-
140- //// TextEncoder
141- //var promptEmbeddings = await EncodePromptAsync(promptTokens, maxTokenLength, cancellationToken);
142- //var negativePromptEmbeddings = await EncodePromptAsync(negativePromptTokens, maxTokenLength, cancellationToken);
143- //if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
144- // await TextEncoder.UnloadAsync();
145-
146-
147-
148- //// Prompt
149- //var promptEmbeds = prompt2Embeddings.HiddenStates;
150- //var promptPooledEmbeds = promptEmbeddings.TextEmbeds;
151- //promptPooledEmbeds = promptPooledEmbeds.Reshape([promptPooledEmbeds.Dimensions[^2], promptPooledEmbeds.Dimensions[^1]]).FirstBatch();
152-
153- //// Negative promt
154- //var negativePromptEmbeds = negativePrompt2Embeddings.HiddenStates;
155- //var negativePromptPooledEmbeds = negativePromptEmbeddings.TextEmbeds;
156- //negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch();
157-
158- //return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds);
159-
160-
161- var result = TextEncoder . RunAsync ( new TextGeneration . Common . GenerateOptions
131+ // Conditional Prompt
132+ var promptEmbeds = await TextEncoder . GetLastHiddenState ( new TextGeneration . Common . GenerateOptions
162133 {
163134 Seed = options . Seed ,
164- Prompt = options . Prompt
165- } ) ;
166-
167- return default ;
168- }
135+ Prompt = options . Prompt ,
136+ MaxLength = 128
137+ } , cancellationToken ) ;
169138
139+ // Unconditional prompt
140+ var negativePromptEmbeds = default ( Tensor < float > ) ;
141+ if ( ! string . IsNullOrEmpty ( options . NegativePrompt ) )
142+ {
143+ negativePromptEmbeds = await TextEncoder . GetLastHiddenState ( new TextGeneration . Common . GenerateOptions
144+ {
145+ Seed = options . Seed ,
146+ Prompt = options . NegativePrompt ,
147+ MaxLength = 128
148+ } , cancellationToken ) ;
149+ }
170150
151+ return new PromptResult ( promptEmbeds , default , negativePromptEmbeds , default ) ;
152+ }
171153
172154
173155 /// <summary>
@@ -372,7 +354,7 @@ protected override GenerateOptions ConfigureDefaultOptions()
372354 Scheduler = SchedulerType . FlowMatchEulerDiscrete
373355 } ;
374356
375- // SD3-Turbo Models , 4 Steps, No Guidance
357+ // Nitro-Distilled Models ,4 Steps, No Guidance
376358 if ( Transformer . ModelType == ModelType . Turbo )
377359 {
378360 return options with
0 commit comments