Skip to content

Commit

Permalink
Update with configuration available in SwiftDiffusionCLI
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Feb 13, 2023
1 parent bf5dca8 commit 0960286
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import CoreML
/// [NumPy's older randomkit.c](https://github.com/numpy/numpy/blob/v1.0/numpy/random/mtrand/randomkit.c)
///
@available(iOS 16.2, macOS 13.1, *)
struct NumPyRandomSource: RandomNumberGenerator {
struct NumPyRandomSource: RandomNumberGenerator, RandomSource {

struct State {
var key = [UInt32](repeating: 0, count: 624)
Expand Down
6 changes: 6 additions & 0 deletions swift/StableDiffusion/pipeline/RandomSource.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import CoreML

@available(iOS 16.2, macOS 13.1, *)
public protocol RandomSource {
mutating func normalShapedArray(_ shape: [Int], mean: Double, stdev: Double) -> MLShapedArray<Double>
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ extension StableDiffusionPipeline {
public var disableSafety: Bool = false
/// The type of Scheduler to use.
public var schedulerType: StableDiffusionScheduler = .pndmScheduler
/// The type of RNG to use
public var rngType: StableDiffusionRNG = .numpyRNG

/// Given the configuration, what mode will be used for generation
public var mode: Mode {
Expand Down
30 changes: 23 additions & 7 deletions swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ public enum StableDiffusionScheduler {
case dpmSolverMultistepScheduler
}

/// RNG compatible with StableDiffusionPipeline
public enum StableDiffusionRNG {
/// RNG that matches numpy implementation
case numpyRNG
/// RNG that matches PyTorch CPU implementation.
case torchRNG
}

/// A pipeline used to generate image samples from text input using stable diffusion
///
/// This implementation matches:
Expand Down Expand Up @@ -157,7 +165,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
throw Error.startingImageProvidedWithoutEncoder
}

let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed)
let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, rng: config.rngType, stdev: 1, seed: config.seed)
latents = try noiseTuples.map({
try encoder.encode(
image: startingImage,
Expand All @@ -168,7 +176,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
} else {
timestepStrength = nil
// Generate random latent samples from specified seed
latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed)
latents = generateLatentSamples(config.imageCount, rng: config.rngType, stdev: stdev, seed: config.seed)
}

// De-noising loop
Expand Down Expand Up @@ -224,11 +232,19 @@ public struct StableDiffusionPipeline: ResourceManaging {
return try decodeToImages(latents, disableSafety: config.disableSafety)
}

func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
private func randomSource(from rng: StableDiffusionRNG, seed: UInt32) -> RandomSource {
switch rng {
case .numpyRNG:
return NumPyRandomSource(seed: seed)
case .torchRNG:
return TorchRandomSource(seed: seed)
}
}

func generateLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
var sampleShape = unet.latentSampleShape
sampleShape[0] = 1

var random = NumPyRandomSource(seed: seed)
var random = randomSource(from: rng, seed: seed)
let samples = (0..<count).map { _ in
MLShapedArray<Float32>(
converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
Expand All @@ -245,11 +261,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - diagonalAndLatentNoiseIsSame: Diffusions library does not seem to use the same noise for the `DiagonalGaussianDistribution` operation,
/// but I have seen implementations of pipelines where it is the same.
/// - Returns: An array of tuples of noise values with length of batch size.
func generateImage2ImageLatentSamples(_ count: Int, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray<Float32>, latentNoise: MLShapedArray<Float32>)] {
func generateImage2ImageLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray<Float32>, latentNoise: MLShapedArray<Float32>)] {
var sampleShape = unet.latentSampleShape
sampleShape[0] = 1

var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed))
var random = randomSource(from: rng, seed: seed)
let samples = (0..<count).map { _ in
if diagonalAndLatentNoiseIsSame {
let noise = MLShapedArray<Float32>(
Expand Down
2 changes: 1 addition & 1 deletion swift/StableDiffusion/pipeline/TorchRandomSource.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import CoreML
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/TransformationHelper.h
///
@available(iOS 16.2, macOS 13.1, *)
struct TorchRandomSource: RandomNumberGenerator {
struct TorchRandomSource: RandomNumberGenerator, RandomSource {

struct State {
var key = [UInt32](repeating: 0, count: 624)
Expand Down
14 changes: 14 additions & 0 deletions swift/StableDiffusionCLI/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ struct StableDiffusionSample: ParsableCommand {
@Option(help: "Scheduler to use, one of {pndm, dpmpp}")
var scheduler: SchedulerOption = .pndm

@Option(help: "Random number generator to use, one of {numpy, torch}")
var rng: RNGOption = .numpy

@Flag(help: "Disable safety checking")
var disableSafety: Bool = false

Expand Down Expand Up @@ -250,6 +253,17 @@ enum SchedulerOption: String, ExpressibleByArgument {
}
}

@available(iOS 16.2, macOS 13.1, *)
enum RNGOption: String, ExpressibleByArgument {
case numpy, torch
var stableDiffusionRNG: StableDiffusionRNG {
switch self {
case .numpy: return .numpyRNG
case .torch: return .torchRNG
}
}
}

if #available(iOS 16.2, macOS 13.1, *) {
StableDiffusionSample.main()
} else {
Expand Down

0 comments on commit 0960286

Please sign in to comment.