diff --git a/Mochi Diffusion.xcodeproj/project.pbxproj b/Mochi Diffusion.xcodeproj/project.pbxproj index bdd56573..568b6a7d 100644 --- a/Mochi Diffusion.xcodeproj/project.pbxproj +++ b/Mochi Diffusion.xcodeproj/project.pbxproj @@ -16,7 +16,6 @@ 03173C132999E2B500B03456 /* SDModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 03173C122999E2B500B03456 /* SDModel.swift */; }; 03173C152999F5C700B03456 /* ImageController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 03173C142999F5C700B03456 /* ImageController.swift */; }; 03173C18299B3C1300B03456 /* ImageStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = 03173C17299B3C1300B03456 /* ImageStore.swift */; }; - 0352E2A5294E2591003FBF25 /* Sparkle in Frameworks */ = {isa = PBXBuildFile; productRef = 0352E2A4294E2591003FBF25 /* Sparkle */; }; 0352E2A7294E3148003FBF25 /* HelpCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0352E2A6294E3148003FBF25 /* HelpCommands.swift */; }; 0352E2AA294EA0E1003FBF25 /* AppView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0352E2A9294EA0E1003FBF25 /* AppView.swift */; }; 0352E2AE294EA2B4003FBF25 /* MessageBanner.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0352E2AD294EA2B4003FBF25 /* MessageBanner.swift */; }; @@ -40,7 +39,6 @@ 0386DF8D2950F25500CA4CEB /* GalleryToolbarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0386DF8C2950F25500CA4CEB /* GalleryToolbarView.swift */; }; 0388DEB0297B00FC008B1C1C /* CompactSlider in Frameworks */ = {isa = PBXBuildFile; productRef = 0388DEAF297B00FC008B1C1C /* CompactSlider */; }; 0395B3112995C70400465B73 /* Scheduler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0395B3102995C70400465B73 /* Scheduler.swift */; }; - 03ADC8B9299581AF00B2843F /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = 03ADC8B8299581AF00B2843F /* StableDiffusion */; }; 03B1ACA8295167B900302F54 /* SettingsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 03B1ACA7295167B900302F54 /* SettingsView.swift */; }; 03CB9D18297399AE0041A4FA /* ImageCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = 03CB9D17297399AE0041A4FA /* ImageCommands.swift */; }; 03D280F2294FD60E00C7D184 /* PromptView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 03D280F1294FD60E00C7D184 /* PromptView.swift */; }; @@ -61,6 +59,10 @@ C1220FC02AFB122F007E5055 /* ImageWellView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1220FBF2AFB122F007E5055 /* ImageWellView.swift */; }; C1ADEC2C2B16957800E142CA /* FolderMonitor.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1ADEC2B2B16957800E142CA /* FolderMonitor.swift */; }; D7B03F2029D42F9900DF89DD /* SDModelAttentionType.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7B03F1F29D42F9900DF89DD /* SDModelAttentionType.swift */; }; + E45FA5822B7F7E4B009E90F0 /* GuernikaKit in Frameworks */ = {isa = PBXBuildFile; productRef = E45FA5812B7F7E4B009E90F0 /* GuernikaKit */; }; + E4EC46322B890B6D00351E8C /* Sparkle in Frameworks */ = {isa = PBXBuildFile; productRef = E4EC46312B890B6D00351E8C /* Sparkle */; }; + E4F15B1D2B86720F00E1EC3C /* de-coremldata.bin in Resources */ = {isa = PBXBuildFile; fileRef = E4F15B1C2B86720F00E1EC3C /* de-coremldata.bin */; }; + E4F15B1F2B8673DB00E1EC3C /* en-coremldata.bin in Resources */ = {isa = PBXBuildFile; fileRef = E4F15B1E2B8673DB00E1EC3C /* en-coremldata.bin */; }; /* End PBXBuildFile section */ /* Begin PBXFileReference section */ @@ -132,6 +134,8 @@ C1220FBF2AFB122F007E5055 /* ImageWellView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageWellView.swift; sourceTree = ""; }; C1ADEC2B2B16957800E142CA /* FolderMonitor.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FolderMonitor.swift; sourceTree = ""; }; D7B03F1F29D42F9900DF89DD /* SDModelAttentionType.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SDModelAttentionType.swift; sourceTree = ""; }; + E4F15B1C2B86720F00E1EC3C /* de-coremldata.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; path = "de-coremldata.bin"; sourceTree = ""; }; + E4F15B1E2B8673DB00E1EC3C /* en-coremldata.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; path = "en-coremldata.bin"; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -140,8 +144,8 @@ buildActionMask = 2147483647; files = ( 036BFC4E294B9FDB00D8AD04 /* Path in Frameworks */, - 03ADC8B9299581AF00B2843F /* StableDiffusion in Frameworks */, - 0352E2A5294E2591003FBF25 /* Sparkle in Frameworks */, + E4EC46322B890B6D00351E8C /* Sparkle in Frameworks */, + E45FA5822B7F7E4B009E90F0 /* GuernikaKit in Frameworks */, 0388DEB0297B00FC008B1C1C /* CompactSlider in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; @@ -275,6 +279,8 @@ isa = PBXGroup; children = ( 036BFC25294B9F7600D8AD04 /* Assets.xcassets */, + E4F15B1E2B8673DB00E1EC3C /* en-coremldata.bin */, + E4F15B1C2B86720F00E1EC3C /* de-coremldata.bin */, 03FD2318295F74B6006EEEE2 /* RealESRGAN.mlmodel */, ); path = Resources; @@ -299,9 +305,9 @@ name = "Mochi Diffusion"; packageProductDependencies = ( 036BFC4D294B9FDB00D8AD04 /* Path */, - 0352E2A4294E2591003FBF25 /* Sparkle */, 0388DEAF297B00FC008B1C1C /* CompactSlider */, - 03ADC8B8299581AF00B2843F /* StableDiffusion */, + E45FA5812B7F7E4B009E90F0 /* GuernikaKit */, + E4EC46312B890B6D00351E8C /* Sparkle */, ); productName = "Mochi Diffusion"; productReference = 036BFC1E294B9F7500D8AD04 /* Mochi Diffusion.app */; @@ -315,7 +321,7 @@ attributes = { BuildIndependentTargetsInParallel = 1; LastSwiftUpdateCheck = 1420; - LastUpgradeCheck = 1430; + LastUpgradeCheck = 1510; TargetAttributes = { 036BFC1D294B9F7500D8AD04 = { CreatedOnToolsVersion = 14.2; @@ -368,9 +374,9 @@ mainGroup = 036BFC15294B9F7500D8AD04; packageReferences = ( 036BFC4C294B9FDB00D8AD04 /* XCRemoteSwiftPackageReference "Path.swift" */, - 0352E2A3294E2591003FBF25 /* XCRemoteSwiftPackageReference "Sparkle" */, 0388DEAE297B00FC008B1C1C /* XCRemoteSwiftPackageReference "CompactSlider" */, - 03ADC8B7299581AF00B2843F /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */, + E45FA5802B7F7E4B009E90F0 /* XCRemoteSwiftPackageReference "GuernikaKit" */, + E4EC46302B890B6D00351E8C /* XCRemoteSwiftPackageReference "Sparkle" */, ); productRefGroup = 036BFC1F294B9F7500D8AD04 /* Products */; projectDirPath = ""; @@ -390,6 +396,8 @@ 03F578642967CE9A003A815F /* Localizable.strings in Resources */, 0311C6BD2989E3EF0074BCAE /* Localizable.stringsdict in Resources */, 036BFC26294B9F7600D8AD04 /* Assets.xcassets in Resources */, + E4F15B1F2B8673DB00E1EC3C /* en-coremldata.bin in Resources */, + E4F15B1D2B86720F00E1EC3C /* de-coremldata.bin in Resources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -512,7 +520,6 @@ isa = XCBuildConfiguration; buildSettings = { ALWAYS_SEARCH_USER_PATHS = NO; - ARCHS = arm64; CLANG_ANALYZER_LOCALIZABILITY_NONLOCALIZED = YES; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; @@ -547,6 +554,7 @@ DEBUG_INFORMATION_FORMAT = dwarf; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = NO; GCC_C_LANGUAGE_STANDARD = gnu11; GCC_DYNAMIC_NO_PIC = NO; GCC_NO_COMMON_BLOCKS = YES; @@ -576,7 +584,6 @@ isa = XCBuildConfiguration; buildSettings = { ALWAYS_SEARCH_USER_PATHS = NO; - ARCHS = arm64; CLANG_ANALYZER_LOCALIZABILITY_NONLOCALIZED = YES; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; @@ -611,6 +618,7 @@ DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; ENABLE_NS_ASSERTIONS = NO; ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = NO; GCC_C_LANGUAGE_STANDARD = gnu11; GCC_NO_COMMON_BLOCKS = YES; GCC_WARN_64_TO_32_BIT_CONVERSION = YES; @@ -622,6 +630,7 @@ MACOSX_DEPLOYMENT_TARGET = 14.0; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; SDKROOT = macosx; SWIFT_COMPILATION_MODE = wholemodule; SWIFT_EMIT_LOC_STRINGS = YES; @@ -638,10 +647,10 @@ CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; COMBINE_HIDPI_IMAGES = YES; - CURRENT_PROJECT_VERSION = 5.0; + CURRENT_PROJECT_VERSION = 6.0; DEAD_CODE_STRIPPING = YES; DEVELOPMENT_ASSET_PATHS = "\"Mochi Diffusion/Preview Content\""; - DEVELOPMENT_TEAM = TCQ6328PP6; + DEVELOPMENT_TEAM = "TCQ6328PP6"; ENABLE_HARDENED_RUNTIME = YES; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; @@ -653,7 +662,7 @@ "$(inherited)", "@executable_path/../Frameworks", ); - MARKETING_VERSION = 5.0; + MARKETING_VERSION = 6.0; PRODUCT_BUNDLE_IDENTIFIER = "com.joshua-park.Mochi-Diffusion"; PRODUCT_NAME = "$(TARGET_NAME)"; PROVISIONING_PROFILE_SPECIFIER = ""; @@ -671,10 +680,10 @@ CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; COMBINE_HIDPI_IMAGES = YES; - CURRENT_PROJECT_VERSION = 5.0; + CURRENT_PROJECT_VERSION = 6.0; DEAD_CODE_STRIPPING = YES; DEVELOPMENT_ASSET_PATHS = "\"Mochi Diffusion/Preview Content\""; - DEVELOPMENT_TEAM = TCQ6328PP6; + DEVELOPMENT_TEAM = "TCQ6328PP6"; ENABLE_HARDENED_RUNTIME = YES; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; @@ -686,7 +695,7 @@ "$(inherited)", "@executable_path/../Frameworks", ); - MARKETING_VERSION = 5.0; + MARKETING_VERSION = 6.0; PRODUCT_BUNDLE_IDENTIFIER = "com.joshua-park.Mochi-Diffusion"; PRODUCT_NAME = "$(TARGET_NAME)"; PROVISIONING_PROFILE_SPECIFIER = ""; @@ -719,14 +728,6 @@ /* End XCConfigurationList section */ /* Begin XCRemoteSwiftPackageReference section */ - 0352E2A3294E2591003FBF25 /* XCRemoteSwiftPackageReference "Sparkle" */ = { - isa = XCRemoteSwiftPackageReference; - repositoryURL = "https://github.com/sparkle-project/Sparkle"; - requirement = { - kind = upToNextMajorVersion; - minimumVersion = 2.0.0; - }; - }; 036BFC4C294B9FDB00D8AD04 /* XCRemoteSwiftPackageReference "Path.swift" */ = { isa = XCRemoteSwiftPackageReference; repositoryURL = "https://github.com/mxcl/Path.swift.git"; @@ -743,22 +744,25 @@ minimumVersion = 1.0.0; }; }; - 03ADC8B7299581AF00B2843F /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */ = { + E45FA5802B7F7E4B009E90F0 /* XCRemoteSwiftPackageReference "GuernikaKit" */ = { isa = XCRemoteSwiftPackageReference; - repositoryURL = "https://github.com/apple/ml-stable-diffusion"; + repositoryURL = "https://github.com/GuernikaCore/GuernikaKit.git"; requirement = { - branch = main; - kind = branch; + kind = exactVersion; + version = 1.6.1; + }; + }; + E4EC46302B890B6D00351E8C /* XCRemoteSwiftPackageReference "Sparkle" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/sparkle-project/Sparkle.git"; + requirement = { + kind = upToNextMajorVersion; + minimumVersion = 2.5.2; }; }; /* End XCRemoteSwiftPackageReference section */ /* Begin XCSwiftPackageProductDependency section */ - 0352E2A4294E2591003FBF25 /* Sparkle */ = { - isa = XCSwiftPackageProductDependency; - package = 0352E2A3294E2591003FBF25 /* XCRemoteSwiftPackageReference "Sparkle" */; - productName = Sparkle; - }; 036BFC4D294B9FDB00D8AD04 /* Path */ = { isa = XCSwiftPackageProductDependency; package = 036BFC4C294B9FDB00D8AD04 /* XCRemoteSwiftPackageReference "Path.swift" */; @@ -769,10 +773,15 @@ package = 0388DEAE297B00FC008B1C1C /* XCRemoteSwiftPackageReference "CompactSlider" */; productName = CompactSlider; }; - 03ADC8B8299581AF00B2843F /* StableDiffusion */ = { + E45FA5812B7F7E4B009E90F0 /* GuernikaKit */ = { + isa = XCSwiftPackageProductDependency; + package = E45FA5802B7F7E4B009E90F0 /* XCRemoteSwiftPackageReference "GuernikaKit" */; + productName = GuernikaKit; + }; + E4EC46312B890B6D00351E8C /* Sparkle */ = { isa = XCSwiftPackageProductDependency; - package = 03ADC8B7299581AF00B2843F /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */; - productName = StableDiffusion; + package = E4EC46302B890B6D00351E8C /* XCRemoteSwiftPackageReference "Sparkle" */; + productName = Sparkle; }; /* End XCSwiftPackageProductDependency section */ }; diff --git a/Mochi Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Mochi Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 983c2e46..0133f142 100644 --- a/Mochi Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Mochi Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -10,12 +10,12 @@ } }, { - "identity" : "ml-stable-diffusion", + "identity" : "guernikakit", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/ml-stable-diffusion", + "location" : "https://github.com/GuernikaCore/GuernikaKit.git", "state" : { - "branch" : "main", - "revision" : "d456a972cd7d84cab2ec353a29896d59b8602248" + "revision" : "500b1f80ca14159afe7a9d14b7ce100413a568fc", + "version" : "1.6.1" } }, { @@ -28,21 +28,30 @@ } }, { - "identity" : "sparkle", + "identity" : "randomgenerator", "kind" : "remoteSourceControl", - "location" : "https://github.com/sparkle-project/Sparkle", + "location" : "https://github.com/GuernikaCore/RandomGenerator.git", "state" : { - "revision" : "47d3d90aee3c52b6f61d04ceae426e607df62347", - "version" : "2.5.2" + "revision" : "7c91e2f454ecc753075a526cf5ffc34d3c4a10c5", + "version" : "1.0.0" } }, { - "identity" : "swift-argument-parser", + "identity" : "schedulers", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-argument-parser.git", + "location" : "https://github.com/GuernikaCore/Schedulers.git", "state" : { - "revision" : "fddd1c00396eed152c45a46bea9f47b98e59301d", - "version" : "1.2.0" + "revision" : "1f517514d679e38bb9915c3a74bf04f75d5b5875", + "version" : "1.4.0" + } + }, + { + "identity" : "sparkle", + "kind" : "remoteSourceControl", + "location" : "https://github.com/sparkle-project/Sparkle.git", + "state" : { + "revision" : "47d3d90aee3c52b6f61d04ceae426e607df62347", + "version" : "2.5.2" } } ], diff --git a/Mochi Diffusion/App.swift b/Mochi Diffusion/App.swift index 31fceb81..58542e35 100644 --- a/Mochi Diffusion/App.swift +++ b/Mochi Diffusion/App.swift @@ -56,6 +56,10 @@ struct MochiDiffusionApp: App { /// cleanup MPS temp folder let mpsURL = FileManager.default.temporaryDirectory.appendingPathComponent("com.apple.MetalPerformanceShadersGraph", isDirectory: true) try? FileManager.default.removeItem(at: mpsURL) + /// cleanup modified variable size models + if let variableSizeModelDir = controller.variableSizeModelDir { + try? FileManager.default.removeItem(at: variableSizeModelDir) + } } } .commands { diff --git a/Mochi Diffusion/Model/SDControlNet.swift b/Mochi Diffusion/Model/SDControlNet.swift index cc5a652d..e79f1f14 100644 --- a/Mochi Diffusion/Model/SDControlNet.swift +++ b/Mochi Diffusion/Model/SDControlNet.swift @@ -8,21 +8,28 @@ import CoreGraphics import Foundation +enum ControlType { + case controlNet + case t2IAdapter + case all +} + struct SDControlNet { let name: String let url: URL let size: CGSize let attention: SDModelAttentionType + let controltype: ControlType init?(url: URL) { - guard let size = identifyControlNetSize(url), let attention = identifyControlNetAttentionType(url) else { + guard let size = identifyControlNetSize(url), let attention = identifyControlNetAttentionType(url), let type = identifyControlNetType(url) else { return nil } - self.name = url.deletingPathExtension().lastPathComponent self.url = url self.size = size self.attention = attention + self.controltype = type } } @@ -49,7 +56,7 @@ private func identifyControlNetSize(_ url: URL) -> CGSize? { return nil } - guard let controlnetCond = inputSchema.first(where: { ($0["name"] as? String) == "controlnet_cond" }) else { + guard let controlnetCond = inputSchema.first(where: { ($0["name"] as? String) == "controlnet_cond" }) ?? inputSchema.first(where: { ($0["name"] as? String) == "input" }) else { print("Error: 'controlnet_cond' not found in 'inputSchema'") return nil } @@ -96,3 +103,33 @@ private func identifyControlNetAttentionType(_ url: URL) -> SDModelAttentionType return nil } } + +private func identifyControlNetType(_ url: URL) -> ControlType? { + let metadataURL = url.appendingPathComponent("metadata.json") + + guard let jsonData = try? Data(contentsOf: metadataURL) else { + print("Error: Could not read data from \(metadataURL)") + return nil + } + + guard let jsonArray = (try? JSONSerialization.jsonObject(with: jsonData)) as? [[String: Any]] else { + print("Error: Could not parse JSON data") + return nil + } + + guard let jsonItem = jsonArray.first else { + print("Error: JSON array is empty") + return nil + } + + guard let inputSchema = jsonItem["inputSchema"] as? [[String: Any]] else { + print("Error: Missing 'inputSchema' in JSON") + return nil + } + + if inputSchema.first(where: { ($0["name"] as? String) == "controlnet_cond" }) != nil { + return .controlNet + } else { + return .t2IAdapter + } +} diff --git a/Mochi Diffusion/Model/SDImage.swift b/Mochi Diffusion/Model/SDImage.swift index 529cdc5a..37f8d29f 100644 --- a/Mochi Diffusion/Model/SDImage.swift +++ b/Mochi Diffusion/Model/SDImage.swift @@ -9,7 +9,7 @@ import AppKit import CoreGraphics import CoreML import Foundation -import StableDiffusion +import GuernikaKit import UniformTypeIdentifiers struct SDImage: Identifiable, Hashable { @@ -21,11 +21,11 @@ struct SDImage: Identifiable, Hashable { var height: Int { self.image?.height ?? 0 } var aspectRatio: CGFloat = 0.0 var model = "" - var scheduler = Scheduler.dpmSolverMultistepScheduler + var scheduler = Scheduler.dpmSolverMultistepKarras var mlComputeUnit: MLComputeUnits? var seed: UInt32 = 0 - var steps = 28 - var guidanceScale = 11.0 + var steps = 15 + var guidanceScale = 5.0 var generatedDate = Date() var upscaler = "" var isUpscaling = false diff --git a/Mochi Diffusion/Model/SDModel.swift b/Mochi Diffusion/Model/SDModel.swift index ab535638..1d3b924e 100644 --- a/Mochi Diffusion/Model/SDModel.swift +++ b/Mochi Diffusion/Model/SDModel.swift @@ -17,28 +17,35 @@ struct SDModel: Identifiable { let attention: SDModelAttentionType let controlNet: [String] let isXL: Bool - let inputSize: CGSize? + var inputSize: CGSize? + let controltype: ControlType? + let allowsVariableSize: Bool + private let vaeAllowsVariableSize: Bool var id: URL { url } init?(url: URL, name: String, controlNet: [SDControlNet]) { - guard let attention = identifyAttentionType(url) else { + guard + let attention = identifyAttentionType(url), + let allowsVariableSize = identifyAllowsVariableSize(url), + let vaeAllowsVariableSize = identifyVaeAllowsVariableSize(url), + let size = identifyInputSize(url) + else { return nil } let isXL = identifyIfXL(url) - let size = identifyInputSize(url) + let controltype = identifyControlNetType(url) self.url = url self.name = name self.attention = attention - if let size = size { - self.controlNet = controlNet.filter { $0.size == size && $0.attention == attention }.map { $0.name } - } else { - self.controlNet = [] - } + self.controlNet = controlNet.filter { $0.size == size && $0.attention == attention && $0.controltype == controltype ?? .all }.map { $0.name } self.isXL = isXL self.inputSize = size + self.controltype = controltype + self.allowsVariableSize = allowsVariableSize + self.vaeAllowsVariableSize = vaeAllowsVariableSize } } @@ -48,6 +55,206 @@ extension SDModel: Hashable { } } +extension SDModel { + /// replace VAEEncoder.mlmodelc/coremldata.bin with en-coremldata.bin + /// replace VAEDecoder.mlmodelc/coremldata.bin with de-coremldata.bin + func resizeableCopy(target: URL, controlNet: [SDControlNet] = []) -> SDModel? { + guard allowsVariableSize && !vaeAllowsVariableSize else { + return nil + } + do { + if !FileManager.default.fileExists(atPath: target.path(percentEncoded: false)) { + try recursiveHardLink(source: url, target: target) + + let encoderBinURL = target.appending(components: "VAEEncoder.mlmodelc", "coremldata.bin") + try? FileManager.default.removeItem(at: encoderBinURL) + try FileManager.default.copyItem(at: Bundle.main.url(forResource: "en-coremldata", withExtension: "bin")!, to: encoderBinURL) + + let decoderBinURL = target.appending(components: "VAEDecoder.mlmodelc", "coremldata.bin") + try? FileManager.default.removeItem(at: decoderBinURL) + try FileManager.default.copyItem(at: Bundle.main.url(forResource: "de-coremldata", withExtension: "bin")!, to: decoderBinURL) + + let encoderMilURL = target.appending(components: "VAEEncoder.mlmodelc", "model.mil") + let encoderMilBakURL = target.appending(components: "VAEEncoder.mlmodelc", "model.mil.bak") + try? FileManager.default.removeItem(at: encoderMilURL) + try? FileManager.default.removeItem(at: encoderMilBakURL) + try FileManager.default.copyItem(at: url.appending(components: "VAEEncoder.mlmodelc", "model.mil"), to: encoderMilBakURL) + + let decoderMilURL = target.appending(components: "VAEDecoder.mlmodelc", "model.mil") + let decoderMilBakURL = target.appending(components: "VAEDecoder.mlmodelc", "model.mil.bak") + try? FileManager.default.removeItem(at: decoderMilURL) + try? FileManager.default.removeItem(at: decoderMilBakURL) + try FileManager.default.copyItem(at: url.appending(components: "VAEDecoder.mlmodelc", "model.mil"), to: decoderMilBakURL) + + let encoderMetadataURL = target.appending(components: "VAEEncoder.mlmodelc", "metadata.json") + try? FileManager.default.removeItem(at: encoderMetadataURL) + try FileManager.default.copyItem(at: url.appending(components: "VAEEncoder.mlmodelc", "metadata.json"), to: encoderMetadataURL) + } + + return SDModel(url: target, name: name, controlNet: controlNet) + } catch { + print("ERROR: Unable to create resizeable copy of SDModel \(name) \(error)") + return nil + } + } + + /// overwrite shape data in VAEEncoder.mlmodelc/model.mil + func modifyEncoderMil(width: Int, height: Int) { + let milURL = url.appending(components: "VAEEncoder.mlmodelc", "model.mil") + let milBakUrl = url.appending(components: "VAEEncoder.mlmodelc", "model.mil.bak") + do { + var fileContent = try String(contentsOf: milBakUrl, encoding: .utf8) + if isXL { + fileContent = fileContent.replacingOccurrences(of: "[1, 8, 128, 128]", with: "[1, 8, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 1, 16384, 512]", with: "[1, 1, \(height / 8 * width / 8), 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 1, 16384, 16384]", with: "[1, 1, \(height / 8 * width / 8), \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 16384, 1, 512]", with: "[1, \(height / 8 * width / 8), 1, 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 16384, 512]", with: "[1, \(height / 8 * width / 8), 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 16384]", with: "[1, 512, \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 16384]", with: "[1, 32, 16, \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 128, 128]", with: "[1, 32, 16, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 128, 128]", with: "[1, 512, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 257, 257]", with: "[1, 512, \(height / 4 + 1), \(width / 4 + 1)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 256, 256]", with: "[1, 32, 16, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 256, 256]", with: "[1, 512, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 8, 256, 256]", with: "[1, 32, 8, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 256, 256]", with: "[1, 256, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 513, 513]", with: "[1, 256, \(height / 2 + 1), \(width / 2 + 1)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 8, 512, 512]", with: "[1, 32, 8, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 512, 512]", with: "[1, 256, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 4, 512, 512]", with: "[1, 32, 4, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 128, 512, 512]", with: "[1, 128, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 128, 1025, 1025]", with: "[1, 128, \(height + 1), \(width + 1)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 128, 1024, 1024]", with: "[1, 128, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 4, 1024, 1024]", with: "[1, 32, 4, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 3, 1024, 1024]", with: "[1, 3, \(height), \(width)]") + } else { + fileContent = fileContent.replacingOccurrences(of: "[1, 8, 64, 64]", with: "[1, 8, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 1, 4096, 512]", with: "[1, 1, \(height / 8 * width / 8), 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 1, 4096, 4096]", with: "[1, 1, \(height / 8 * width / 8), \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 4096, 1, 512]", with: "[1, \(height / 8 * width / 8), 1, 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 4096, 512]", with: "[1, \(height / 8 * width / 8), 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 4096]", with: "[1, 512, \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 4096]", with: "[1, 32, 16, \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 64, 64]", with: "[1, 32, 16, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 64, 64]", with: "[1, 512, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 129, 129]", with: "[1, 512, \(height / 4 + 1), \(width / 4 + 1)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 128, 128]", with: "[1, 32, 16, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 128, 128]", with: "[1, 512, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 8, 128, 128]", with: "[1, 32, 8, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 128, 128]", with: "[1, 256, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 257, 257]", with: "[1, 256, \(height / 2 + 1), \(width / 2 + 1)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 8, 256, 256]", with: "[1, 32, 8, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 256, 256]", with: "[1, 256, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 4, 256, 256]", with: "[1, 32, 4, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 128, 256, 256]", with: "[1, 128, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 128, 513, 513]", with: "[1, 128, \(height + 1), \(width + 1)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 128, 512, 512]", with: "[1, 128, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 4, 512, 512]", with: "[1, 32, 4, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 3, 512, 512]", with: "[1, 3, \(height), \(width)]") + } + try fileContent.write(to: milURL, atomically: false, encoding: .utf8) + } catch { + print("Error: Unable to modify \(milURL.path(percentEncoded: false))") + } + } + + /// overwrite shape data in VAEDecoder.mlmodelc/model.mil + func modifyDecoderMil(width: Int, height: Int) { + let milURL = url.appending(components: "VAEDecoder.mlmodelc", "model.mil") + let milBakURL = url.appending(components: "VAEDecoder.mlmodelc", "model.mil.bak") + do { + var fileContent = try String(contentsOf: milBakURL, encoding: .utf8) + if isXL { + fileContent = fileContent.replacingOccurrences(of: "[1, 4, 128, 128]", with: "[1, 4, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 128, 128]", with: "[1, 512, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 128, 128]", with: "[1, 32, 16, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 16384]", with: "[1, 32, 16, \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 16384]", with: "[1, 512, \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 16384, 512]", with: "[1, \(height / 8 * width / 8), 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 16384, 1, 512]", with: "[1, \(height / 8 * width / 8), 1, 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 1, 16384, 512]", with: "[1, 1, \(height / 8 * width / 8), 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 1, 16384, 16384]", with: "[1, 1, \(height / 8 * width / 8), \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 256, 256]", with: "[1, 512, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 256, 256]", with: "[1, 32, 16, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 512, 512]", with: "[1, 512, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 512, 512]", with: "[1, 32, 16, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 512, 512]", with: "[1, 256, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 8, 512, 512]", with: "[1, 32, 8, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 1024, 1024]", with: "[1, 256, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 8, 1024, 1024]", with: "[1, 32, 8, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 128, 1024, 1024]", with: "[1, 128, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 4, 1024, 1024]", with: "[1, 32, 4, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 3, 1024, 1024]", with: "[1, 3, \(height), \(width)]") + } else { + fileContent = fileContent.replacingOccurrences(of: "[1, 4, 64, 64]", with: "[1, 4, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 64, 64]", with: "[1, 512, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 64, 64]", with: "[1, 32, 16, \(height / 8), \(width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 4096]", with: "[1, 32, 16, \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 4096]", with: "[1, 512, \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 4096, 512]", with: "[1, \(height / 8 * width / 8), 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 4096, 1, 512]", with: "[1, \(height / 8 * width / 8), 1, 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 1, 4096, 512]", with: "[1, 1, \(height / 8 * width / 8), 512]") + fileContent = fileContent.replacingOccurrences(of: "[1, 1, 4096, 4096]", with: "[1, 1, \(height / 8 * width / 8), \(height / 8 * width / 8)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 128, 128]", with: "[1, 512, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 128, 128]", with: "[1, 32, 16, \(height / 4), \(width / 4)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 512, 256, 256]", with: "[1, 512, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 16, 256, 256]", with: "[1, 32, 16, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 256, 256]", with: "[1, 256, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 8, 256, 256]", with: "[1, 32, 8, \(height / 2), \(width / 2)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 256, 512, 512]", with: "[1, 256, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 8, 512, 512]", with: "[1, 32, 8, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 128, 512, 512]", with: "[1, 128, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 32, 4, 512, 512]", with: "[1, 32, 4, \(height), \(width)]") + fileContent = fileContent.replacingOccurrences(of: "[1, 3, 512, 512]", with: "[1, 3, \(height), \(width)]") + } + try fileContent.write(to: milURL, atomically: false, encoding: .utf8) + } catch { + print("Error: Unable to modify \(milURL.path(percentEncoded: false))") + } + } + + /// Writes desired size value to inputSchema["shape"] of VAEEncoder.mlmodelc/metadata.json + public func modifyInputSize(width: Int, height: Int) { + guard allowsVariableSize && !vaeAllowsVariableSize else { + print("ERROR: model \(name) cannot modify input size") + return + } + + let encoderMetadataURL = url.appendingPathComponent("VAEEncoder.mlmodelc").appendingPathComponent("metadata.json") + guard + let jsonData = try? Data(contentsOf: encoderMetadataURL), + var jsonArray = try? JSONSerialization.jsonObject(with: jsonData) as? [[String: Any]], + var jsonItem = jsonArray.first, + var inputSchema = jsonItem["inputSchema"] as? [[String: Any]], + var controlnetCond = inputSchema.first, + var shapeString = controlnetCond["shape"] as? String + else { + return + } + + var shapeIntArray = shapeString.trimmingCharacters(in: CharacterSet(charactersIn: "[]")) + .components(separatedBy: ", ") + .compactMap { Int($0.trimmingCharacters(in: .whitespaces)) } + + shapeIntArray[3] = width + shapeIntArray[2] = height + shapeString = "[\(shapeIntArray.map { String($0) }.joined(separator: ", "))]" + + controlnetCond["shape"] = shapeString + inputSchema[0] = controlnetCond + jsonItem["inputSchema"] = inputSchema + jsonArray[0] = jsonItem + + if let updatedJsonData = try? JSONSerialization.data(withJSONObject: jsonArray, options: .prettyPrinted) { + try? updatedJsonData.write(to: encoderMetadataURL) + print("update metadata.") + } else { + print("Failed to update metadata.") + } + } +} + private func identifyAttentionType(_ url: URL) -> SDModelAttentionType? { guard let metadataURL = unetMetadataURL(from: url) else { logger.warning("No model metadata found at '\(url)'") @@ -103,8 +310,7 @@ private func identifyIfXL(_ url: URL) -> Bool { private func unetMetadataURL(from url: URL) -> URL? { let potentialMetadataURLs = [ url.appending(components: "Unet.mlmodelc", "metadata.json"), - url.appending(components: "UnetChunk1.mlmodelc", "metadata.json"), - url.appending(components: "ControlledUnet.mlmodelc", "metadata.json") + url.appending(components: "UnetChunk1.mlmodelc", "metadata.json") ] return potentialMetadataURLs.first { @@ -113,13 +319,16 @@ private func unetMetadataURL(from url: URL) -> URL? { } private func identifyInputSize(_ url: URL) -> CGSize? { - let encoderMetadataURL = url.appending(path: "VAEEncoder.mlmodelc").appending(path: "metadata.json") + let encoderMetadataURL = url.appending(path: "VAEDecoder.mlmodelc").appending(path: "metadata.json") if let jsonData = try? Data(contentsOf: encoderMetadataURL), let jsonArray = try? JSONSerialization.jsonObject(with: jsonData) as? [[String: Any]], let jsonItem = jsonArray.first, - let inputSchema = jsonItem["inputSchema"] as? [[String: Any]], + let inputSchema = jsonItem["outputSchema"] as? [[String: Any]], let controlnetCond = inputSchema.first, let shapeString = controlnetCond["shape"] as? String { + if shapeString == "[]"{ + return nil + } let shapeIntArray = shapeString.trimmingCharacters(in: CharacterSet(charactersIn: "[]")) .components(separatedBy: ", ") .compactMap { Int($0.trimmingCharacters(in: .whitespaces)) } @@ -130,3 +339,89 @@ private func identifyInputSize(_ url: URL) -> CGSize? { return nil } } + +private func identifyControlNetType(_ url: URL) -> ControlType? { + let metadataURL = url.appending(path: "Unet.mlmodelc").appending(path: "metadata.json") + + guard let jsonData = try? Data(contentsOf: metadataURL) else { + print("Error: Could not read data from \(metadataURL)") + return nil + } + + guard let jsonArray = (try? JSONSerialization.jsonObject(with: jsonData)) as? [[String: Any]] else { + print("Error: Could not parse JSON data") + return nil + } + + guard let jsonItem = jsonArray.first else { + print("Error: JSON array is empty") + return nil + } + + guard let inputSchema = jsonItem["inputSchema"] as? [[String: Any]] else { + print("Error: Missing 'inputSchema' in JSON") + return nil + } + + if inputSchema.first(where: { ($0["name"] as? String) == "adapter_res_samples_00" }) != nil && inputSchema.first(where: { ($0["name"] as? String) == "down_block_res_samples_00" }) != nil { + return .all + } else if inputSchema.first(where: { ($0["name"] as? String) == "adapter_res_samples_00" }) != nil { + return .t2IAdapter + } else { + return .controlNet + } +} + +// swiftlint:disable discouraged_optional_boolean +private func identifyAllowsVariableSize(_ url: URL) -> Bool? { + let metadataURL = url.appending(path: "Unet.mlmodelc").appending(path: "metadata.json") + + guard let jsonData = try? Data(contentsOf: metadataURL) else { + print("Error: Could not read data from \(metadataURL)") + return nil + } + + guard let jsonArray = (try? JSONSerialization.jsonObject(with: jsonData)) as? [[String: Any]] else { + print("Error: Could not parse JSON data") + return nil + } + + guard let jsonItem = jsonArray.first else { + print("Error: JSON array is empty") + return nil + } + + guard let inputSchema = jsonItem["inputSchema"] as? [[String: Any]] else { + print("Error: Missing 'inputSchema' in JSON") + return nil + } + + return inputSchema.first { ($0["hasShapeFlexibility"] as? String) == "1" } != nil +} + +private func identifyVaeAllowsVariableSize(_ url: URL) -> Bool? { + let metadataURL = url.appending(path: "VAEDecoder.mlmodelc").appending(path: "metadata.json") + + guard let jsonData = try? Data(contentsOf: metadataURL) else { + print("Error: Could not read data from \(metadataURL)") + return nil + } + + guard let jsonArray = (try? JSONSerialization.jsonObject(with: jsonData)) as? [[String: Any]] else { + print("Error: Could not parse JSON data") + return nil + } + + guard let jsonItem = jsonArray.first else { + print("Error: JSON array is empty") + return nil + } + + guard let inputSchema = jsonItem["inputSchema"] as? [[String: Any]] else { + print("Error: Missing 'inputSchema' in JSON") + return nil + } + + return inputSchema.first { ($0["hasShapeFlexibility"] as? String) == "1" } != nil +} +// swiftlint:enable discouraged_optional_boolean diff --git a/Mochi Diffusion/Model/Scheduler.swift b/Mochi Diffusion/Model/Scheduler.swift index 851dca8f..926d4f31 100644 --- a/Mochi Diffusion/Model/Scheduler.swift +++ b/Mochi Diffusion/Model/Scheduler.swift @@ -5,21 +5,60 @@ // Created by Joshua Park on 2/9/23. // -import StableDiffusion +import Schedulers /// Schedulers compatible with StableDiffusionPipeline enum Scheduler: String, CaseIterable { - /// Scheduler that uses a pseudo-linear multi-step (PLMS) method - case pndmScheduler = "PNDM" - /// Scheduler that uses a second order DPM-Solver++ algorithm - case dpmSolverMultistepScheduler = "DPM-Solver++" + case ddim = "DDIM" + + case dpmSolverMultistep = "DPM++ 2M" + + case dpmSolverMultistepKarras = "DPM++ 2M Karras" + + case dpmSolverSinglestep = "DPM++ SDE" + + case dpmSolverSinglestepKarras = "DPM++ SDE Karras" + + case dpm2 = "DPM2" + + case dpm2Karras = "DPM2 Karras" + + case eulerDiscrete = "Euler" + + case eulerDiscreteKarras = "Euler Karras" + + case eulerAncestralDiscrete = "Euler Ancestral" + + case lcm = "LCM" + + case pndm = "PNDM" } -func convertScheduler(_ scheduler: Scheduler) -> StableDiffusionScheduler { +func convertScheduler(_ scheduler: Scheduler) -> Schedulers { switch scheduler { - case .pndmScheduler: - return StableDiffusionScheduler.pndmScheduler - case .dpmSolverMultistepScheduler: - return StableDiffusionScheduler.dpmSolverMultistepScheduler + case .ddim: + return .ddim + case .dpmSolverMultistep: + return .dpmSolverMultistep + case .dpmSolverMultistepKarras: + return .dpmSolverMultistepKarras + case .dpmSolverSinglestep: + return .dpmSolverSinglestep + case .dpmSolverSinglestepKarras: + return .dpmSolverSinglestepKarras + case .dpm2: + return .dpm2 + case .dpm2Karras: + return .dpm2Karras + case .eulerDiscrete: + return .eulerDiscrete + case .eulerDiscreteKarras: + return .eulerDiscreteKarras + case .eulerAncestralDiscrete: + return .eulerAncenstralDiscrete + case .lcm: + return .lcm + case .pndm: + return .pndm } } diff --git a/Mochi Diffusion/Resources/de-coremldata.bin b/Mochi Diffusion/Resources/de-coremldata.bin new file mode 100644 index 00000000..c6028cb3 Binary files /dev/null and b/Mochi Diffusion/Resources/de-coremldata.bin differ diff --git a/Mochi Diffusion/Resources/en-coremldata.bin b/Mochi Diffusion/Resources/en-coremldata.bin new file mode 100644 index 00000000..be5ebd2c Binary files /dev/null and b/Mochi Diffusion/Resources/en-coremldata.bin differ diff --git a/Mochi Diffusion/Support/Extensions.swift b/Mochi Diffusion/Support/Extensions.swift index eb5a7688..c5b0d335 100644 --- a/Mochi Diffusion/Support/Extensions.swift +++ b/Mochi Diffusion/Support/Extensions.swift @@ -7,7 +7,6 @@ import CompactSlider import CoreML -import StableDiffusion import SwiftUI import UniformTypeIdentifiers @@ -228,3 +227,14 @@ extension MLComputeUnits { } } } + +extension FileManager { + func temporaryDirectoryInSameVolume(as url: URL, name: String = "temp") throws -> URL { + let defaultTempDir = FileManager.default.temporaryDirectory + if (try defaultTempDir.resourceValues(forKeys: [.volumeURLKey]).volume) == (try url.resourceValues(forKeys: [.volumeURLKey]).volume) { + return defaultTempDir.appending(path: name) + } else { + return url.deletingLastPathComponent().appending(path: name) + } + } +} diff --git a/Mochi Diffusion/Support/Functions.swift b/Mochi Diffusion/Support/Functions.swift index 7a86d27c..643c28bb 100644 --- a/Mochi Diffusion/Support/Functions.swift +++ b/Mochi Diffusion/Support/Functions.swift @@ -9,7 +9,6 @@ import AppKit import CoreGraphics import CoreML import Foundation -import StableDiffusion import UniformTypeIdentifiers func compareVersion(_ thisVersion: String, _ compareTo: String) -> ComparisonResult { @@ -58,7 +57,7 @@ func createSDImageFromURL(_ url: URL) -> SDImage? { case Metadata.upscaler: sdi.upscaler = String(value) case Metadata.scheduler: - sdi.scheduler = Scheduler(rawValue: String(value))! + sdi.scheduler = Scheduler(rawValue: String(value)) ?? Scheduler.dpmSolverMultistepKarras case Metadata.mlComputeUnit: sdi.mlComputeUnit = MLComputeUnits.fromString(value) case Metadata.generator: @@ -86,3 +85,26 @@ func formatTimeRemaining(_ interval: Double?, stepsLeft: Int) -> String { return formattedString ?? "-" } + +func recursiveHardLink(source: URL, target: URL) throws { + let fileManager = FileManager.default + + guard fileManager.fileExists(atPath: source.path) else { + throw NSError(domain: "Source does not exist", code: 1, userInfo: nil) + } + + guard !fileManager.fileExists(atPath: target.path) else { + throw NSError(domain: "Target already exists", code: 1, userInfo: nil) + } + + if let isDirectory = try? source.resourceValues(forKeys: [.isDirectoryKey]).isDirectory, isDirectory { + try fileManager.createDirectory(at: target, withIntermediateDirectories: true, attributes: nil) + let contents = try fileManager.contentsOfDirectory(at: source, includingPropertiesForKeys: nil, options: []) + for item in contents { + let itemTarget = target.appending(component: item.lastPathComponent) + try recursiveHardLink(source: item, target: itemTarget) + } + } else { + try fileManager.linkItem(at: source, to: target) + } +} diff --git a/Mochi Diffusion/Support/ImageController.swift b/Mochi Diffusion/Support/ImageController.swift index 3e8e6cae..a1d5ddbd 100644 --- a/Mochi Diffusion/Support/ImageController.swift +++ b/Mochi Diffusion/Support/ImageController.swift @@ -7,12 +7,12 @@ import CoreML import Foundation +import GuernikaKit import os -import StableDiffusion import SwiftUI import UniformTypeIdentifiers -typealias StableDiffusionProgress = StableDiffusionPipeline.Progress +typealias StableDiffusionProgress = DiffusionProgress enum ComputeUnitPreference: String { case auto @@ -59,6 +59,9 @@ final class ImageController: ObservableObject { @Published var startingImage: CGImage? + @Published + var maskImage: CGImage? + @Published var numberOfImages = 1.0 @@ -90,7 +93,13 @@ final class ImageController: ObservableObject { guard let model = currentModel else { return } + if let modelHeight = currentModel?.inputSize?.height { + self.height = Int(modelHeight) + } + if let modelWidth = currentModel?.inputSize?.width { + self.width = Int(modelWidth) + } modelName = model.name controlNet = model.controlNet currentControlNets = [] @@ -100,7 +109,13 @@ final class ImageController: ObservableObject { @Published private(set) var currentControlNets: [(name: String?, image: CGImage?)] = [] - @AppStorage("ModelDir") var modelDir = "" + @AppStorage("ModelDir") var modelDir = "" { + didSet { + Task { + await reloadModels() + } + } + } @AppStorage("ControlNetDir") var controlNetDir = "" @AppStorage("Model") private(set) var modelName = "" @AppStorage("AutosaveImages") var autosaveImages = true @@ -109,18 +124,20 @@ final class ImageController: ObservableObject { @AppStorage("Prompt") var prompt = "" @AppStorage("NegativePrompt") var negativePrompt = "" @AppStorage("ImageStrength") var strength = 0.75 - @AppStorage("Steps") var steps = 12.0 - @AppStorage("Scale") var guidanceScale = 11.0 + @AppStorage("Steps") var steps = 15.0 + @AppStorage("Scale") var guidanceScale = 5.0 @AppStorage("ImageWidth") var width = 512 @AppStorage("ImageHeight") var height = 512 - @AppStorage("Scheduler") var scheduler: Scheduler = .dpmSolverMultistepScheduler + @AppStorage("Scheduler") var scheduler: Scheduler = .dpmSolverMultistepKarras @AppStorage("UpscaleGeneratedImages") var upscaleGeneratedImages = false - @AppStorage("ShowGenerationPreview") var showGenerationPreview = true + @AppStorage("ShowHighqualityPreview") var showHighqualityPreview = false @AppStorage("MLComputeUnitPreference") var mlComputeUnitPreference: ComputeUnitPreference = .auto @AppStorage("ReduceMemory") var reduceMemory = false @AppStorage("SafetyChecker") var safetyChecker = false @AppStorage("UseTrash") var useTrash = true + var variableSizeModelDir: URL? + private var imageFolderMonitor: FolderMonitor? private var modelFolderMonitor: FolderMonitor? private var controlNetFolderMonitor: FolderMonitor? @@ -159,12 +176,12 @@ final class ImageController: ObservableObject { } self.modelFolderMonitor = FolderMonitor(path: modelDir) { Task { - await self.loadModels() + await self.reloadModels() } } self.controlNetFolderMonitor = FolderMonitor(path: controlNetDir) { Task { - await self.loadModels() + await self.reloadModels() } } } @@ -215,17 +232,39 @@ final class ImageController: ObservableObject { return finalModelDirURL } - func loadModels() async { + private func loadModels() async { + let modelDirectoryURL = directoryURL(fromPath: modelDir, defaultingTo: "MochiDiffusion/models/") + self.modelDir = modelDirectoryURL.path(percentEncoded: false) + + let controlNetDirectoryURL = directoryURL(fromPath: controlNetDir, defaultingTo: "MochiDiffusion/controlnet/") + self.controlNetDir = controlNetDirectoryURL.path(percentEncoded: false) + + await reloadModels() + } + + private func reloadModels() async { models = [] logger.info("Started loading model directory at: \"\(self.modelDir)\"") - do { - let modelDirectoryURL = directoryURL(fromPath: modelDir, defaultingTo: "MochiDiffusion/models/") - self.modelDir = modelDirectoryURL.path(percentEncoded: false) - - let controlNetDirectoryURL = directoryURL(fromPath: controlNetDir, defaultingTo: "MochiDiffusion/controlnet/") - self.controlNetDir = controlNetDirectoryURL.path(percentEncoded: false) - await self.models = try ImageGenerator.shared.getModels(modelDirectoryURL: modelDirectoryURL, controlNetDirectoryURL: controlNetDirectoryURL) + do { + if let variableSizeModelDir { + try? FileManager.default.removeItem(at: variableSizeModelDir) + } + self.variableSizeModelDir = try FileManager.default.temporaryDirectoryInSameVolume(as: URL(fileURLWithPath: modelDir), name: "variable size models") + let modelDirectoryURL = URL(filePath: self.modelDir) + let controlNetDirectoryURL = URL(filePath: self.controlNetDir) + + var vmodels = [SDModel]() + for model in try await ImageGenerator.shared.getModels(modelDirectoryURL: modelDirectoryURL, controlNetDirectoryURL: controlNetDirectoryURL) { + if model.allowsVariableSize { + if let resizeableCopy = model.resizeableCopy(target: variableSizeModelDir!.appending(component: model.name)) { + vmodels.append(resizeableCopy) + } + } else { + vmodels.append(model) + } + } + self.models = vmodels logger.info("Found \(self.models.count) model(s)") @@ -255,23 +294,50 @@ final class ImageController: ObservableObject { return } - var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: prompt) + var pipelineConfig = SampleInput(prompt: prompt) pipelineConfig.negativePrompt = negativePrompt - if let size = currentModel?.inputSize { - pipelineConfig.startingImage = startingImage?.scaledAndCroppedTo(size: size) + + if model.allowsVariableSize { + pipelineConfig.size = CGSize(width: self.width, height: self.height) + } else { + pipelineConfig.size = model.inputSize } + + if let size = pipelineConfig.size, startingImage != nil { + pipelineConfig.initImage = startingImage?.scaledAndCroppedTo(size: size) + pipelineConfig.inpaintMask = maskImage?.scaledAndCroppedTo(size: size) + } + let strength = startingImage == nil && currentControlNets.isEmpty ? 1.0 : self.strength + pipelineConfig.strength = Float(strength) pipelineConfig.stepCount = Int(steps) pipelineConfig.seed = seed + pipelineConfig.originalStepCount = 50 pipelineConfig.guidanceScale = Float(guidanceScale) - pipelineConfig.disableSafety = !safetyChecker - pipelineConfig.schedulerType = convertScheduler(scheduler) + pipelineConfig.scheduler = convertScheduler(scheduler) + for controlNet in currentControlNets { - if controlNet.name != nil, let size = currentModel?.inputSize, let image = controlNet.image?.scaledAndCroppedTo(size: size) { - pipelineConfig.controlNetInputs.append(image) + if controlNet.name != nil, let size = pipelineConfig.size, let image = controlNet.image?.scaledAndCroppedTo(size: size) { + let control = SDControlNet(url: URL(fileURLWithPath: controlNetDir + controlNet.name! + ".mlmodelc")) + if (model.controltype == .controlNet || model.controltype == .all) && control?.controltype == .controlNet { + guard let c = try? ControlNet(modelAt: URL(fileURLWithPath: controlNetDir + controlNet.name! + ".mlmodelc")) else { + self.logger.error("Couldn't load ControlNet \(controlNet.name!)") + continue + } + let cinput = ConditioningInput(module: c) + cinput.image = image +// ImageGenerator.shared.pipeline?.conditioningInput = [cinput] + } else if (model.controltype == .t2IAdapter || model.controltype == .all) && control?.controltype == .t2IAdapter { + guard let a = try? T2IAdapter(modelAt: URL(fileURLWithPath: controlNetDir + controlNet.name! + ".mlmodelc")) else { + self.logger.error("Couldn't load T2IAdapter \(controlNet.name!)") + continue + } + let ainput = ConditioningInput(module: a) + ainput.image = image +// ImageGenerator.shared.pipeline?.conditioningInput = [ainput] + } } } - pipelineConfig.useDenoisedIntermediates = showGenerationPreview let genConfig = GenerationConfig( pipelineConfig: pipelineConfig, @@ -293,6 +359,8 @@ final class ImageController: ObservableObject { } } + private var prevPipeline: Int? + private func runGenerationJobs() async { guard case .ready = ImageGenerator.shared.state else { return } @@ -300,12 +368,29 @@ final class ImageController: ObservableObject { let genConfig = generationQueue.removeFirst() self.currentGeneration = genConfig do { - try await ImageGenerator.shared.loadPipeline( - model: genConfig.model, - controlNet: genConfig.controlNets, - computeUnit: genConfig.mlComputeUnit, - reduceMemory: self.reduceMemory - ) + if prevPipeline != genConfig.pipelineHash() { + guard let size = genConfig.pipelineConfig.size else { + break + } + let width = Int(size.width) + let height = Int(size.height) + + genConfig.model.modifyEncoderMil(width: width, height: height) + genConfig.model.modifyDecoderMil(width: width, height: height) + + var reduceMemoryOrUpdateInputShape = self.reduceMemory + if genConfig.pipelineConfig.initImage != nil { + reduceMemoryOrUpdateInputShape = true + genConfig.model.modifyInputSize(width: width, height: height) + } + try await ImageGenerator.shared.loadPipeline( + model: genConfig.model, + controlNet: genConfig.controlNets, + computeUnit: genConfig.mlComputeUnit, + reduceMemory: reduceMemoryOrUpdateInputShape + ) + prevPipeline = genConfig.pipelineHash() + } try await ImageGenerator.shared.generate(genConfig) } catch ImageGenerator.GeneratorError.requestedModelNotFound { self.logger.error("Couldn't load \(genConfig.model.name) because it doesn't exist.") @@ -313,12 +398,6 @@ final class ImageController: ObservableObject { } catch ImageGenerator.GeneratorError.pipelineNotAvailable { self.logger.error("Pipeline is not available.") await ImageGenerator.shared.updateState(.ready("There was a problem loading pipeline.")) - } catch PipelineError.startingImageProvidedWithoutEncoder { - self.logger.error("The selected model does not support setting a starting image.") - await ImageGenerator.shared.updateState(.ready("The selected model does not support setting a starting image.")) - } catch Encoder.Error.sampleInputShapeNotCorrect { - self.logger.error("The starting image size doesn't match the size of the image that will be generated.") - await ImageGenerator.shared.updateState(.ready("The starting image size doesn't match the size of the image that will be generated.")) } catch { self.logger.error("There was a problem generating images: \(error)") await ImageGenerator.shared.updateState(.error("There was a problem generating images: \(error)")) diff --git a/Mochi Diffusion/Support/ImageGenerator.swift b/Mochi Diffusion/Support/ImageGenerator.swift index c5609a73..529e7c48 100644 --- a/Mochi Diffusion/Support/ImageGenerator.swift +++ b/Mochi Diffusion/Support/ImageGenerator.swift @@ -7,13 +7,13 @@ import Combine import CoreML +@preconcurrency import GuernikaKit import OSLog -@preconcurrency import StableDiffusion import UniformTypeIdentifiers struct GenerationConfig: Sendable, Identifiable { let id = UUID() - var pipelineConfig: StableDiffusionPipeline.Configuration + var pipelineConfig: SampleInput var isXL: Bool var autosaveImages: Bool var imageDir: String @@ -24,6 +24,16 @@ struct GenerationConfig: Sendable, Identifiable { var scheduler: Scheduler var upscaleGeneratedImages: Bool var controlNets: [String] + + func pipelineHash() -> Int { + var hasher = Hasher() + hasher.combine(model) + hasher.combine(controlNets) + hasher.combine(mlComputeUnit) + hasher.combine(pipelineConfig.size) + hasher.combine(pipelineConfig.initImage == nil) + return hasher.finalize() + } } @Observable public final class ImageGenerator { @@ -55,7 +65,7 @@ struct GenerationConfig: Sendable, Identifiable { private(set) var queueProgress = QueueProgress(index: 0, total: 0) - private var pipeline: (any StableDiffusionPipelineProtocol)? + private var pipeline: (any StableDiffusionPipeline)? private(set) var tokenizer: Tokenizer? @@ -65,8 +75,6 @@ struct GenerationConfig: Sendable, Identifiable { private var generationStartTime: DispatchTime? - private var currentPipelineHash: Int? - func loadImages(imageDir: String) async throws -> ([SDImage], URL) { var finalImageDirURL: URL let fm = FileManager.default @@ -122,9 +130,8 @@ struct GenerationConfig: Sendable, Identifiable { models = subDirs .sorted { $0.lastPathComponent.compare($1.lastPathComponent, options: [.caseInsensitive, .diacriticInsensitive]) == .orderedAscending } .compactMap { url in - let controlledUnetMetadataPath = url.appending(components: "ControlledUnet.mlmodelc", "metadata.json").path(percentEncoded: false) - let hasControlNet = fm.fileExists(atPath: controlledUnetMetadataPath) - + let unetMetadataPath = url.appending(components: "Unet.mlmodelc", "metadata.json").path(percentEncoded: false) + let hasControlNet = fm.fileExists(atPath: unetMetadataPath) if hasControlNet { let controlNetSymLinkPath = url.appending(component: "controlnet").path(percentEncoded: false) @@ -153,35 +160,24 @@ struct GenerationConfig: Sendable, Identifiable { throw GeneratorError.requestedModelNotFound } - var hasher = Hasher() - hasher.combine(model) - hasher.combine(controlNet) - hasher.combine(computeUnit) - hasher.combine(reduceMemory) - let hash = hasher.finalize() - guard hash != self.currentPipelineHash else { return } - await updateState(.loading) let config = MLModelConfiguration() config.computeUnits = computeUnit - if model.isXL { - self.pipeline = try StableDiffusionXLPipeline( - resourcesAt: model.url, - configuration: config, - reduceMemory: reduceMemory - ) - } else { - self.pipeline = try StableDiffusionPipeline( - resourcesAt: model.url, - controlNet: controlNet, - configuration: config, - disableSafety: true, - reduceMemory: reduceMemory - ) + let modelresource = try GuernikaKit.load(at: model.url) + + switch modelresource { + case is StableDiffusionXLPipeline: + self.pipeline = modelresource as? StableDiffusionXLPipeline + case is StableDiffusionXLRefinerPipeline: + self.pipeline = modelresource as? StableDiffusionXLRefinerPipeline + case is StableDiffusionPix2PixPipeline: + self.pipeline = modelresource as? StableDiffusionPix2PixPipeline + default: + self.pipeline = modelresource as? StableDiffusionMainPipeline } - self.currentPipelineHash = hash + self.pipeline?.reduceMemory = reduceMemory self.tokenizer = Tokenizer(modelDir: model.url) await updateState(.ready(nil)) } @@ -193,15 +189,10 @@ struct GenerationConfig: Sendable, Identifiable { } await updateState(.loading) generationStopped = false + var config = inputConfig config.pipelineConfig.seed = config.pipelineConfig.seed == 0 ? UInt32.random(in: 0 ..< UInt32.max) : config.pipelineConfig.seed - if config.isXL { - config.pipelineConfig.encoderScaleFactor = 0.13025 - config.pipelineConfig.decoderScaleFactor = 0.13025 - config.pipelineConfig.schedulerTimestepSpacing = .karras - } - var sdi = SDImage() sdi.prompt = config.pipelineConfig.prompt sdi.negativePrompt = config.pipelineConfig.negativePrompt @@ -214,7 +205,8 @@ struct GenerationConfig: Sendable, Identifiable { for index in 0 ..< config.numberOfImages { await updateQueueProgress(QueueProgress(index: index, total: inputConfig.numberOfImages)) generationStartTime = DispatchTime.now() - let images = try pipeline.generateImages(configuration: config.pipelineConfig) { [config] progress in + + let image = try pipeline.generateImages(input: config.pipelineConfig) { progress in Task { @MainActor in state = .running(progress) @@ -224,10 +216,11 @@ struct GenerationConfig: Sendable, Identifiable { } Task { - if config.pipelineConfig.useDenoisedIntermediates, let currentImage = progress.currentImages.last { - ImageStore.shared.setCurrentGenerating(image: currentImage) + let currentImage = progress.currentLatentSample + if await ImageController.shared.showHighqualityPreview { + ImageStore.shared.setCurrentGenerating(image: try pipeline.decodeToImage(currentImage)) } else { - ImageStore.shared.setCurrentGenerating(image: nil) + ImageStore.shared.setCurrentGenerating(image: pipeline.latentToImage(currentImage)) } } @@ -236,33 +229,32 @@ struct GenerationConfig: Sendable, Identifiable { if generationStopped { break } - for image in images { - guard let image = image else { continue } - if config.upscaleGeneratedImages { - guard let upscaledImg = await Upscaler.shared.upscale(cgImage: image) else { continue } - sdi.image = upscaledImg - sdi.aspectRatio = CGFloat(Double(upscaledImg.width) / Double(upscaledImg.height)) - sdi.upscaler = "RealESRGAN" - } else { - sdi.image = image - sdi.aspectRatio = CGFloat(Double(image.width) / Double(image.height)) - } - sdi.id = UUID() - sdi.seed = config.pipelineConfig.seed - sdi.generatedDate = Date.now - sdi.path = "" - - if config.autosaveImages && !config.imageDir.isEmpty { - var pathURL = URL(fileURLWithPath: config.imageDir, isDirectory: true) - let count = ImageStore.shared.images.endIndex + 1 - pathURL.append(path: sdi.filenameWithoutExtension(count: count)) - - let type = UTType.fromString(config.imageType) - guard let path = await sdi.save(pathURL, type: type) else { continue } - sdi.path = path.path(percentEncoded: false) - } - ImageStore.shared.add(sdi) + + guard let image = image else { continue } + if config.upscaleGeneratedImages { + guard let upscaledImg = await Upscaler.shared.upscale(cgImage: image) else { continue } + sdi.image = upscaledImg + sdi.aspectRatio = CGFloat(Double(upscaledImg.width) / Double(upscaledImg.height)) + sdi.upscaler = "RealESRGAN" + } else { + sdi.image = image + sdi.aspectRatio = CGFloat(Double(image.width) / Double(image.height)) + } + sdi.id = UUID() + sdi.seed = config.pipelineConfig.seed + sdi.generatedDate = Date.now + sdi.path = "" + + if config.autosaveImages && !config.imageDir.isEmpty { + var pathURL = URL(fileURLWithPath: config.imageDir, isDirectory: true) + let count = ImageStore.shared.images.endIndex + 1 + pathURL.append(path: sdi.filenameWithoutExtension(count: count)) + + let type = UTType.fromString(config.imageType) + guard let path = await sdi.save(pathURL, type: type) else { continue } + sdi.path = path.path(percentEncoded: false) } + ImageStore.shared.add(sdi) config.pipelineConfig.seed += 1 } await updateState(.ready(nil)) diff --git a/Mochi Diffusion/Support/Tokenizer.swift b/Mochi Diffusion/Support/Tokenizer.swift index cb3f2b8d..87bc33d4 100644 --- a/Mochi Diffusion/Support/Tokenizer.swift +++ b/Mochi Diffusion/Support/Tokenizer.swift @@ -6,23 +6,23 @@ // import Foundation -import StableDiffusion +import GuernikaKit final class Tokenizer { private let bpeTokenizer: BPETokenizer init?(modelDir: URL) { - let mergesURL = modelDir.appendingPathComponent("merges.txt", conformingTo: .url) - let vocabURL = modelDir.appendingPathComponent("vocab.json", conformingTo: .url) + let mergesURL = modelDir.appendingPathComponent("TextEncoder.mlmodelc/merges.txt", conformingTo: .url) + let vocabURL = modelDir.appendingPathComponent("TextEncoder.mlmodelc/vocab.json", conformingTo: .url) do { - try self.bpeTokenizer = BPETokenizer(mergesAt: mergesURL, vocabularyAt: vocabURL) + try self.bpeTokenizer = BPETokenizer(mergesUrl: mergesURL, vocabularyUrl: vocabURL, addedVocabUrl: nil) } catch { return nil } } func countTokens(_ inString: String) -> Int { - bpeTokenizer.tokenize(input: inString).0.count + bpeTokenizer.tokenize(inString).0.count } } diff --git a/Mochi Diffusion/Views/GalleryPreviewView.swift b/Mochi Diffusion/Views/GalleryPreviewView.swift index 66535fab..441c6ba9 100644 --- a/Mochi Diffusion/Views/GalleryPreviewView.swift +++ b/Mochi Diffusion/Views/GalleryPreviewView.swift @@ -18,17 +18,17 @@ struct GalleryPreviewView: View { .aspectRatio(contentMode: .fit) if case let .running(progress) = generator.state, let progress = progress, progress.stepCount > 0 { let step = progress.step + 1 - let stepValue = Double(step) / Double(progress.stepCount) + let stepValue = Double(step) / Double(progress.stepCount + 1) let progressLabel = String( - localized: "About \(formatTimeRemaining(generator.lastStepGenerationElapsedTime, stepsLeft: progress.stepCount - step))", + localized: "About \(formatTimeRemaining(generator.lastStepGenerationElapsedTime, stepsLeft: progress.stepCount - step + 1))", comment: "Text displaying the current time remaining" ) VStack(alignment: .leading) { HStack { Spacer() - Text(verbatim: "\(step)/\(progress.stepCount)") + Text(verbatim: "\(step)/\(progress.stepCount + 1)") .padding(6) .background(.ultraThinMaterial, in: RoundedRectangle(cornerRadius: 4)) } @@ -37,7 +37,6 @@ struct GalleryPreviewView: View { .padding(8) .background(.ultraThinMaterial, in: RoundedRectangle(cornerRadius: 8)) } - .aspectRatio(CGFloat(image.width / image.height), contentMode: .fit) .padding(8) } } diff --git a/Mochi Diffusion/Views/GalleryToolbarView.swift b/Mochi Diffusion/Views/GalleryToolbarView.swift index 08414cd3..fca67151 100644 --- a/Mochi Diffusion/Views/GalleryToolbarView.swift +++ b/Mochi Diffusion/Views/GalleryToolbarView.swift @@ -19,7 +19,7 @@ struct GalleryToolbarView: View { ZStack { if case let .running(progress) = generator.state, let progress = progress, progress.stepCount > 0 { let step = progress.step + 1 - let stepValue = Double(step) / Double(progress.stepCount) + let stepValue = Double(step) / Double(progress.stepCount + 1) Button { self.isStatusPopoverShown.toggle() diff --git a/Mochi Diffusion/Views/InspectorView.swift b/Mochi Diffusion/Views/InspectorView.swift index ff4205d3..715af10c 100644 --- a/Mochi Diffusion/Views/InspectorView.swift +++ b/Mochi Diffusion/Views/InspectorView.swift @@ -6,7 +6,6 @@ // import CoreML -import StableDiffusion import SwiftUI struct InfoGridRow: View { diff --git a/Mochi Diffusion/Views/JobQueueView.swift b/Mochi Diffusion/Views/JobQueueView.swift index b5449416..5470ce17 100644 --- a/Mochi Diffusion/Views/JobQueueView.swift +++ b/Mochi Diffusion/Views/JobQueueView.swift @@ -16,7 +16,7 @@ struct JobQueueView: View { private func updateProgressData() { if case let .running(progress) = generator.state, let progress = progress, progress.stepCount > 0 { - let step = progress.step + 1 + let step = progress.step let stepValue = Double(step) / Double(progress.stepCount) let progressLabel = String( @@ -143,19 +143,19 @@ private struct InfoPopoverView: View { controller.currentModel = config.model - if let startingImage = config.pipelineConfig.startingImage { + if let startingImage = config.pipelineConfig.initImage, let strength = config.pipelineConfig.strength { controller.startingImage = startingImage - controller.strength = Double(config.pipelineConfig.strength) + controller.strength = Double(strength) } else { await controller.unsetStartingImage() } - if let controlNetName = config.controlNets.first, let controlNetImage = config.pipelineConfig.controlNetInputs.first { - await controller.setControlNet(name: controlNetName) - await controller.setControlNet(image: controlNetImage) - } else { - await controller.unsetControlNet() - } +// if let controlNetName = config.controlNets.first, let controlNetImage = config.pipelineConfig.controlNetInputs.first { +// await controller.setControlNet(name: controlNetName) +// await controller.setControlNet(image: controlNetImage) +// } else { +// await controller.unsetControlNet() +// } } } @@ -220,31 +220,33 @@ private struct InfoPopoverView: View { text: MLComputeUnits.toString(config.mlComputeUnit), showCopyToPromptOption: false ) - if let startingImage = config.pipelineConfig.startingImage { + if let startingImage = config.pipelineConfig.initImage { InfoGridRow( type: LocalizedStringKey("Starting Image"), image: startingImage, showCopyToPromptOption: false ) - InfoGridRow( - type: LocalizedStringKey("Strength"), - text: config.pipelineConfig.strength.formatted(.number.precision(.fractionLength(2))), - showCopyToPromptOption: true, - callback: { controller.strength = Double(config.pipelineConfig.strength) } - ) - } - if let controlNetName = config.controlNets.first, let controlNetImage = config.pipelineConfig.controlNetInputs.first { - InfoGridRow( - type: LocalizedStringKey("ControlNet"), - text: controlNetName, - showCopyToPromptOption: false - ) - InfoGridRow( - type: LocalizedStringKey("ControlNet Image"), - image: controlNetImage, - showCopyToPromptOption: false - ) + if let strength = config.pipelineConfig.strength { + InfoGridRow( + type: LocalizedStringKey("Strength"), + text: strength.formatted(.number.precision(.fractionLength(2))), + showCopyToPromptOption: true, + callback: { controller.strength = Double(config.pipelineConfig.strength!) } + ) + } } +// if let controlNetName = config.controlNets.first, let controlNetImage = config.pipelineConfig.controlNetInputs.first { +// InfoGridRow( +// type: LocalizedStringKey("ControlNet"), +// text: controlNetName, +// showCopyToPromptOption: false +// ) +// InfoGridRow( +// type: LocalizedStringKey("ControlNet Image"), +// image: controlNetImage, +// showCopyToPromptOption: false +// ) +// } // swiftlint:enable trailing_closure } diff --git a/Mochi Diffusion/Views/SettingsView.swift b/Mochi Diffusion/Views/SettingsView.swift index b0fe28a3..b43b7251 100644 --- a/Mochi Diffusion/Views/SettingsView.swift +++ b/Mochi Diffusion/Views/SettingsView.swift @@ -6,7 +6,6 @@ // import CoreML -import StableDiffusion import SwiftUI import UniformTypeIdentifiers import UserNotifications @@ -210,18 +209,18 @@ struct SettingsView: View { GroupBox { VStack(alignment: .leading) { HStack { - Text("Show Image Preview") + Text("Show High Quality Image Preview") Spacer() - Toggle("", isOn: $controller.showGenerationPreview) + Toggle("", isOn: $controller.showHighqualityPreview) .labelsHidden() .toggleStyle(.switch) .controlSize(.small) } Text( - "Show the image as its being generated.", + "A high-quality preview, which will affect the speed.", comment: "Help text for Show Image Preview setting" ) .helpTextFormat() diff --git a/Mochi Diffusion/Views/SidebarControls/ControlNetView.swift b/Mochi Diffusion/Views/SidebarControls/ControlNetView.swift index cb7164b1..acc8c2bf 100644 --- a/Mochi Diffusion/Views/SidebarControls/ControlNetView.swift +++ b/Mochi Diffusion/Views/SidebarControls/ControlNetView.swift @@ -16,7 +16,7 @@ struct ControlNetView: View { .sidebarLabelFormat() HStack(alignment: .top) { - ImageWellView(image: controller.currentControlNets.first?.image, size: controller.currentModel?.inputSize) { image in + ImageWellView(image: controller.currentControlNets.first?.image, size: CGSize(width: controller.width, height: controller.height)) { image in if let image { await ImageController.shared.setControlNet(image: image) } else { diff --git a/Mochi Diffusion/Views/SidebarControls/SizeView.swift b/Mochi Diffusion/Views/SidebarControls/SizeView.swift index 73ec14ed..925bdf62 100644 --- a/Mochi Diffusion/Views/SidebarControls/SizeView.swift +++ b/Mochi Diffusion/Views/SidebarControls/SizeView.swift @@ -9,12 +9,16 @@ import SwiftUI struct SizeView: View { @EnvironmentObject private var controller: ImageController - private let imageSizes = [ - 256, 320, 384, 448, 512, 576, 640, 704, 768 + private let sdimageSizes = [ + 512, 576, 640, 768, 832, 896 + ] + private let sdxlimageSizes = [ + 512, 576, 640, 768, 832, 896, 1_024, 1_152, 1_216, 1_280, 1_344, 1_536 ] var body: some View { HStack { + let imageSizes: [Int] = ImageController.shared.currentModel?.isXL ?? false ? sdxlimageSizes : sdimageSizes VStack(alignment: .leading) { Text( "Width:", @@ -26,7 +30,21 @@ struct SizeView: View { } } .labelsHidden() + .disabled(!(controller.currentModel?.allowsVariableSize ?? false)) } + VStack(alignment: .leading) { + Spacer() + Button { + let w = controller.width, h = controller.height + controller.width = h + controller.height = w + } label: { + Image(systemName: "arrow.right.arrow.left.circle.fill") + } + .buttonBorderShape(.circle) + .disabled(!(controller.currentModel?.allowsVariableSize ?? false)) + } + VStack(alignment: .leading) { Text( "Height:", @@ -38,6 +56,7 @@ struct SizeView: View { } } .labelsHidden() + .disabled(!(controller.currentModel?.allowsVariableSize ?? false)) } } } diff --git a/Mochi Diffusion/Views/SidebarControls/StartingImageView.swift b/Mochi Diffusion/Views/SidebarControls/StartingImageView.swift index e4efa15a..7ef79db6 100644 --- a/Mochi Diffusion/Views/SidebarControls/StartingImageView.swift +++ b/Mochi Diffusion/Views/SidebarControls/StartingImageView.swift @@ -10,6 +10,7 @@ import SwiftUI struct StartingImageView: View { @EnvironmentObject private var controller: ImageController @State private var isInfoPopoverShown = false + @State private var isMaskPopoverShown = false var body: some View { Text( @@ -19,7 +20,8 @@ struct StartingImageView: View { .sidebarLabelFormat() HStack(alignment: .top) { - ImageWellView(image: controller.startingImage, size: controller.currentModel?.inputSize) { image in + ImageWellView(image: controller.startingImage, size: CGSize(width: controller.width, height: controller.height)) { image in + controller.maskImage = nil if let image { ImageController.shared.setStartingImage(image: image) } else { @@ -27,18 +29,41 @@ struct StartingImageView: View { } } .frame(width: 90, height: 90) - + .overlay( + Image(nsImage: controller.maskImage.map { NSImage(cgImage: $0, size: NSSize(width: $0.width, height: $0.height)) } ?? NSImage()) + .resizable() + .aspectRatio(contentMode: .fit) + ) + .popover(isPresented: self.$isMaskPopoverShown, arrowEdge: .top) { + let screenHeight = NSScreen.main?.frame.height ?? 0 + let screenWidth = NSScreen.main?.frame.width ?? 0 + let aspectRatio = CGSize(width: controller.width, height: controller.height).aspectRatio + if aspectRatio <= 1 { + MaskEditorView(startingImage: controller.startingImage?.scaledAndCroppedTo(size: CGSize(width: (screenHeight * aspectRatio * 0.6).rounded(), height: (screenHeight * 0.6).rounded())), maskImage: $controller.maskImage) + } else { + MaskEditorView(startingImage: controller.startingImage?.scaledAndCroppedTo(size: CGSize(width: (screenWidth * 0.5).rounded(), height: (screenWidth / aspectRatio * 0.5).rounded())), maskImage: $controller.maskImage) + } + } Spacer() VStack(alignment: .trailing) { HStack { Button { + controller.maskImage = nil Task { await ImageController.shared.selectStartingImage() } } label: { Image(systemName: "photo") } Button { + self.isMaskPopoverShown.toggle() + } label: { + Image(systemName: "paintbrush") + } + .disabled(controller.startingImage == nil) + + Button { + controller.maskImage = nil Task { await ImageController.shared.unsetStartingImage() } } label: { Image(systemName: "xmark") @@ -65,13 +90,13 @@ struct StartingImageView: View { .buttonStyle(PlainButtonStyle()) .popover(isPresented: self.$isInfoPopoverShown, arrowEdge: .top) { Text( - """ - Strength controls how closely the generated image resembles the starting image. - Use lower values to generate images that look similar to the starting image. - Use higher values to allow more creative freedom. + """ + Strength controls how closely the generated image resembles the starting image. + Use lower values to generate images that look similar to the starting image. + Use higher values to allow more creative freedom. - The size of the starting image must match the output image size of the current model. - """ + The size of the starting image must match the output image size of the current model. + """ ) .padding() } @@ -80,6 +105,107 @@ struct StartingImageView: View { } } +struct PathWrapper: Identifiable { + let id = UUID() + let path: Path +} + +struct MaskEditorView: View { + let startingImage: CGImage? + @Binding var maskImage: CGImage? + @State private var startPoint: CGPoint? + @State private var endPoint: CGPoint? + @State private var paths: [PathWrapper] = [] + @State private var radius: CGFloat = 80 + + var body: some View { + VStack(alignment: .trailing) { + ZStack { + if let startingImage { + Image(nsImage: NSImage(cgImage: startingImage, size: NSSize(width: startingImage.width, height: startingImage.height))) + .aspectRatio(contentMode: .fit) + if let maskImage { + Image(nsImage: NSImage(cgImage: maskImage, size: NSSize(width: startingImage.width, height: startingImage.height))) + .aspectRatio(contentMode: .fit) + } + } + } + .frame(width: CGFloat(startingImage?.width ?? 0), height: CGFloat(startingImage?.height ?? 0)) + + HStack { + Spacer() + Button { + paths.removeAll() + maskImage = nil + } label: { + Image(systemName: "arrow.clockwise.circle") + .font(.system(size: 30)) + } + .buttonBorderShape(.circle) + .padding(20) + + Slider(value: $radius, in: 30...150, step: 5) + Spacer() + } + Spacer() + } + + .gesture( + DragGesture(minimumDistance: 0) + .onChanged { value in + if startPoint == nil { + startPoint = value.location + } else { + endPoint = value.location + updateMaskImage() + } + } + .onEnded { _ in + startPoint = nil + endPoint = nil + } + ) + } + + func updateMaskImage() { + guard let endPoint = endPoint else { return } + + let center = CGPoint(x: (endPoint.x), y: (endPoint.y)) + let adjustedY = CGFloat(startingImage?.height ?? 0) - center.y + + let path = Path { path in + path.addEllipse(in: CGRect(x: center.x - radius, y: adjustedY - radius, width: radius * 2, height: radius * 2)) + } + paths.append(PathWrapper(path: path)) + drawPaths() + } + + func drawPaths() { + guard let startingImage = startingImage else { return } + let imageSize = CGSize(width: startingImage.width, height: startingImage.height) + let imageRect = CGRect(origin: .zero, size: imageSize) + + let width = Int(imageSize.width) + let height = Int(imageSize.height) + + guard let maskContext = CGContext(data: nil, width: width, height: height, bitsPerComponent: 8, bytesPerRow: 0, space: CGColorSpaceCreateDeviceRGB(), bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue) else { return } + + maskContext.clear(imageRect) + + maskContext.setFillColor(CGColor.black) + maskContext.setBlendMode(.normal) + + for pathWrapper in paths { + maskContext.addPath(pathWrapper.path.cgPath) + maskContext.fillPath() + } + + if let maskImage = maskContext.makeImage() { + self.maskImage = maskImage + } + } +} + #Preview { StartingImageView() .environmentObject(ImageController.shared) diff --git a/Mochi Diffusion/Views/SidebarView.swift b/Mochi Diffusion/Views/SidebarView.swift index e0fefa62..3bcf02f2 100644 --- a/Mochi Diffusion/Views/SidebarView.swift +++ b/Mochi Diffusion/Views/SidebarView.swift @@ -23,6 +23,10 @@ struct SidebarView: View { ModelView() Spacer().frame(height: 6) } + Group { + SizeView() + Spacer().frame(height: 6) + } Group { NumberOfImagesView() Spacer().frame(height: 6)