Skip to content

Commit f6eb0aa

Browse files
authored
Merge pull request #2 from TensorStack-AI/AMDNitro
AMD Nitro Pipeline
2 parents e1a07cb + 1047b11 commit f6eb0aa

File tree

10 files changed

+696
-6
lines changed

10 files changed

+696
-6
lines changed

TensorStack.StableDiffusion/Enums/PipelineType.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ public enum PipelineType
1111
StableCascade = 10,
1212
LatentConsistency = 20,
1313
Flux = 30,
14+
Nitro = 40
1415
}
1516
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
using System.Threading;
4+
using System.Threading.Tasks;
5+
using TensorStack.Common;
6+
using TensorStack.Common.Tensor;
7+
using TensorStack.StableDiffusion.Config;
8+
9+
namespace TensorStack.StableDiffusion.Models
10+
{
11+
/// <summary>
12+
/// TransformerModel: Nitro Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
13+
/// </summary>
14+
public class TransformerNitroModel : TransformerModel
15+
{
16+
/// <summary>
17+
/// Initializes a new instance of the <see cref="TransformerNitroModel"/> class.
18+
/// </summary>
19+
/// <param name="configuration">The configuration.</param>
20+
public TransformerNitroModel(TransformerModelConfig configuration)
21+
: base(configuration) { }
22+
23+
24+
/// <summary>
25+
/// Runs the Transformer model with the specified inputs
26+
/// </summary>
27+
/// <param name="timestep">The timestep.</param>
28+
/// <param name="hiddenStates">The hidden states.</param>
29+
/// <param name="encoderHiddenStates">The encoder hidden states.</param>
30+
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
31+
/// <returns>A Task&lt;Tensor`1&gt; representing the asynchronous operation.</returns>
32+
public async Task<Tensor<float>> RunAsync(int timestep, Tensor<float> hiddenStates, Tensor<float> encoderHiddenStates, CancellationToken cancellationToken = default)
33+
{
34+
if (!Transformer.IsLoaded())
35+
await Transformer.LoadAsync(cancellationToken: cancellationToken);
36+
37+
using (var transformerParams = new ModelParameters(Transformer.Metadata, cancellationToken))
38+
{
39+
// Inputs
40+
transformerParams.AddInput(hiddenStates.AsTensorSpan());
41+
transformerParams.AddInput(encoderHiddenStates.AsTensorSpan());
42+
transformerParams.AddScalarInput(timestep);
43+
44+
// Outputs
45+
transformerParams.AddOutput(hiddenStates.Dimensions);
46+
47+
// Inference
48+
using (var results = await Transformer.RunInferenceAsync(transformerParams))
49+
{
50+
return results[0].ToTensor();
51+
}
52+
}
53+
}
54+
55+
}
56+
}

0 commit comments

Comments
 (0)