In [None]:
import tqdm

def nop(it, *_ignored, **options):
    """Stand-in for ``tqdm.tqdm`` that draws no progress bar.

    If a ``prefix`` keyword is supplied its value is printed once; every other
    positional or keyword argument is ignored. The iterable is returned as-is.
    """
    try:
        label = options['prefix']
    except KeyError:
        return it
    print(label)
    return it
def noprange(*bounds, **options):
    """Stand-in for ``tqdm.trange``: a plain ``range`` with no progress bar.

    Positional arguments are forwarded to ``range``; a ``prefix`` keyword, if
    present, is printed first. All other keywords are ignored.
    """
    if 'prefix' not in options:
        return range(*bounds)
    print(options['prefix'])
    return range(*bounds)

# Silence progress bars by rebinding tqdm's entry points to the pass-through
# stand-ins `nop` / `noprange`.
# NOTE(review): only the attributes on the top-level `tqdm` package are rebound;
# code that imports from `tqdm.auto` or `tqdm.std` may still get the real bar — confirm.
tqdm.tqdm, tqdm.trange = nop, noprange

import os

import torch
from diffusers import StableDiffusionPipeline, KarrasVeScheduler
# from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
from PIL import Image

# Karras-style scheduler used for generation in the last cell.
# NOTE(review): sigma_max/sigma_min presumably bracket the noise range of SD v1's
# "scaled_linear" beta schedule (0.00085 -> 0.012) — confirm before changing them.
kve = KarrasVeScheduler(
    sigma_max=14.6146,
    # sigma_min=0.0936,
    sigma_min=0.0292,
    # s_churn=0 presumably disables the stochastic "churn" noise step — confirm.
    s_churn=0.
)
# Alternative scheduler, kept for comparison:
# lms = LMSDiscreteScheduler(
#   beta_start=0.00085,
#   beta_end=0.012,
#   beta_schedule="scaled_linear"
# )

# Local checkpoint directory. Set the SD_MODEL_DIR environment variable to point
# elsewhere; the default is the original author's path.
MODEL_DIR = os.environ.get("SD_MODEL_DIR", "/Users/birch/git/stable-diffusion-v1-4")
# Safety checker disabled. For half precision, add: torch_dtype=torch.float16, revision="fp16"
pipe = StableDiffusionPipeline.from_pretrained(MODEL_DIR, safety_checker=None)
# pipe = pipe.to("mps")

In [None]:
import coremltools as ct
from pathlib import Path
import torch as th
import diffusers
from coremltools.models import MLModel

class Undictifier(th.nn.Module):
    """Wrap a module whose output is indexed by key and return only its ``"sample"`` entry.

    Used before ``torch.jit.trace`` so the traced callable yields the bare
    value instead of a dict-like output object.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, *args, **kwargs):
        """Call the wrapped module with all arguments unchanged and pick ``"sample"``."""
        result = self.m(*args, **kwargs)
        sample = result["sample"]
        return sample

class CLIPUndictifier(th.nn.Module):
    """Wrap a module whose output is indexable by position and return only element ``0``.

    Used before ``torch.jit.trace`` on the CLIP text encoder so the traced
    callable yields a single value rather than a tuple-like output object.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, *args, **kwargs):
        """Call the wrapped module with all arguments unchanged and pick element 0."""
        result = self.m(*args, **kwargs)
        first = result[0]
        return first

def convert_text_encoder(text_encoder, outname):
    """Trace a CLIP text encoder and save it as an FP16 Core ML ``mlprogram``.

    During tracing, ``CLIPTextTransformer._build_causal_attention_mask`` is
    replaced by a stub that returns a mask precomputed once for batch size 1
    and sequence length 77, so the trace records that mask as a constant.
    The original method is restored afterwards, even if tracing fails. This
    keeps the PyTorch text encoder working normally and lets this function be
    called more than once.

    Args:
        text_encoder: the CLIP text model to convert; its output is reduced to
            element ``[0]`` via ``CLIPUndictifier``.
        outname: path the ``.mlpackage`` is saved to.
    """
    from transformers.models.clip.modeling_clip import CLIPTextTransformer

    orig_build_mask = CLIPTextTransformer._build_causal_attention_mask
    # Called with self=None: assumes this transformers version's method does not
    # read instance state (the previous code relied on the same assumption).
    causal_mask = orig_build_mask(None, 1, 77, th.float)

    def _fake_build_causal_mask(self, *args, **kwargs):
        return causal_mask

    CLIPTextTransformer._build_causal_attention_mask = _fake_build_causal_mask
    try:
        f_trace = th.jit.trace(CLIPUndictifier(text_encoder), (th.zeros(1, 77, dtype=th.long),),
                               strict=False, check_trace=False)
    finally:
        # Undo the class-level monkeypatch; the trace has already captured the mask.
        CLIPTextTransformer._build_causal_attention_mask = orig_build_mask

    # NOTE(review): no dtype is given for the Core ML input although the trace used
    # int64 token ids; callers feed float ids — confirm this matches the model.
    f_coreml = ct.convert(f_trace,
               inputs=[ct.TensorType(shape=(1, 77))],
               convert_to="mlprogram", compute_precision=ct.precision.FLOAT16, skip_model_load=True)
    f_coreml.save(outname)

def _trace_and_convert_fp16(module, example_shape, outname):
    """Trace ``module`` on a zero tensor of ``example_shape`` and save it as an FP16 Core ML mlprogram at ``outname``."""
    example_shape = tuple(example_shape)
    f_trace = th.jit.trace(module, (th.zeros(*example_shape),), strict=False, check_trace=False)

    f_coreml = ct.convert(f_trace,
               inputs=[ct.TensorType(shape=example_shape)],
               convert_to="mlprogram", compute_precision=ct.precision.FLOAT16, skip_model_load=True)
    f_coreml.save(outname)


def convert_decoder(decoder, outname, latent_shape=(1, 4, 64, 64)):
    """Convert the VAE decoder to a fixed-shape FP16 Core ML package saved at ``outname``.

    ``latent_shape`` is the single input shape baked into the Core ML model; the
    default (1, 4, 64, 64) is what this notebook always used. NOTE(review): that
    presumably corresponds to 512x512 images — confirm before changing resolution.
    """
    _trace_and_convert_fp16(decoder, latent_shape, outname)


def convert_post_quant_conv(layer, outname, latent_shape=(1, 4, 64, 64)):
    """Convert the VAE ``post_quant_conv`` layer to a fixed-shape FP16 Core ML package saved at ``outname``.

    ``latent_shape`` is the single input shape baked into the Core ML model and
    should match the decoder's input shape.
    """
    _trace_and_convert_fp16(layer, latent_shape, outname)

def convert_unet(f, out_name):
    """Trace a UNet and save it as an FP16 Core ML mlprogram at ``out_name``.

    Conversion workarounds applied here:
      * ``torch.einsum`` is temporarily replaced by ``torch.bmm``-based code for
        the two attention equations it recognizes; any other equation raises
        ``ValueError``. The real ``einsum`` is restored in ``finally``, so a
        failed trace or conversion no longer leaves ``torch.einsum`` broken for
        the rest of the session.
      * The coremltools torch-frontend handlers for ``broadcast_to`` (lowered
        like ``expand``) and ``gelu`` (mapped to MIL ``gelu``) are re-registered.
        These registry changes intentionally persist after return.

    Example inputs used for tracing: sample (2, 4, 64, 64), timestep (1,),
    encoder hidden states (2, 77, 768); these shapes are fixed in the Core ML model.
    """
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op, _TORCH_OPS_REGISTRY
    import coremltools.converters.mil.frontend.torch.ops as cml_ops
    # def unsliced_attention(self, query, key, value, _sequence_length, _dim):
    #     attn = (torch.einsum("b i d, b j d -> b i j", query, key) * self.scale).softmax(dim=-1)
    #     attn = torch.einsum("b i j, b j d -> b i d", attn, value)
    #     return self.reshape_batch_dim_to_heads(attn)
    # diffusers.models.attention.CrossAttention._attention = unsliced_attention

    def fake_einsum(a, b, c):
        # q @ k^T
        if a == 'b i d, b j d -> b i j': return th.bmm(b, c.permute(0, 2, 1))
        # attn @ v
        if a == 'b i j, b j d -> b i d': return th.bmm(b, c)
        raise ValueError(f"unsupported einsum {a} on {b.shape} {c.shape}")

    # Remove any previous registration first so re-running this cell can register again.
    if "broadcast_to" in _TORCH_OPS_REGISTRY: del _TORCH_OPS_REGISTRY["broadcast_to"]
    @register_torch_op
    def broadcast_to(context, node): return cml_ops.expand(context, node)
    if "gelu" in _TORCH_OPS_REGISTRY: del _TORCH_OPS_REGISTRY["gelu"]
    @register_torch_op
    def gelu(context, node): context.add(mb.gelu(x=context[node.inputs[0]], name=node.name))

    orig_einsum = th.einsum
    # NOTE(review): assumes the model looks up `torch.einsum` at call time, so
    # rebinding the module attribute is enough to intercept it.
    th.einsum = fake_einsum
    try:
        print("tracing")
        f_trace = th.jit.trace(Undictifier(f), (th.zeros(2, 4, 64, 64), th.zeros(1), th.zeros(2, 77, 768)), strict=False, check_trace=False)
        print("converting")
        f_coreml_fp16 = ct.convert(f_trace,
                   inputs=[ct.TensorType(shape=(2, 4, 64, 64)), ct.TensorType(shape=(1,)), ct.TensorType(shape=(2, 77, 768))],
                   convert_to="mlprogram",  compute_precision=ct.precision.FLOAT16, skip_model_load=True)
        f_coreml_fp16.save(str(out_name))
    finally:
        th.einsum = orig_einsum
    
class UNetWrapper:
    """Callable stand-in for the pipeline's UNet, backed by a Core ML model.

    On construction, converts ``f`` to ``out_name`` with ``convert_unet`` when
    that package does not exist yet, then loads it for CPU+GPU execution.
    Calling the wrapper takes the UNet's (sample, timestep, encoder_hidden_states)
    arguments and returns a ``UNet2DConditionOutput``.
    """

    def __init__(self, f, out_name="unet.mlpackage"):
        # Copied from the wrapped UNet; presumably read by the pipeline — confirm.
        self.in_channels = f.in_channels
        package_missing = not Path(out_name).exists()
        if package_missing:
            print("generating coreml model")
            convert_unet(f, out_name)
            print("saved")
        # Load for CPU+GPU rather than the Neural Engine: per the original author, ANE
        # recompiles on every load (slow) and then fails at predict time with
        # NSLocalizedDescription = "Error computing NN outputs."
        print("loading saved coreml model")
        self.f = MLModel(out_name, compute_units=ct.ComputeUnit.CPU_AND_GPU)
        print("loaded")

    def __call__(self, sample, timestep, encoder_hidden_states):
        """Run one UNet step through Core ML; returns None if the model produced no outputs."""
        from diffusers.models.unet_2d_condition import UNet2DConditionOutput
        sample_np = sample.numpy()
        # The timestep is sent as a 1-element int32 array.
        timestep_np = th.tensor([timestep], dtype=th.int32).numpy()
        # "input_35" is the input name the converted model assigned to encoder_hidden_states.
        hidden_np = encoder_hidden_states.numpy()
        predictions = list(self.f.predict({"sample": sample_np, "timestep": timestep_np, "input_35": hidden_np}).values())
        if not predictions:
            return None
        return UNet2DConditionOutput(sample=th.tensor(predictions[0], dtype=th.float32))

# class TextEncoderWrapper:
#     def __init__(self, f, out_name="text_encoder.mlpackage"):
#         if not Path(out_name).exists():
#             print("generating coreml model"); convert_text_encoder(f, out_name); print("saved")
#         print("loading saved coreml model"); self.f = MLModel(out_name, compute_units=ct.ComputeUnit.CPU_AND_GPU); print("loaded")
    
#     def __call__(self, input):
#         args = args = {"input_ids_1": input.float().numpy()}
#         for v in self.f.predict(args).values():
#             return (th.tensor(v, dtype=th.float32),)

class DecoderWrapper:
    """Runs the VAE decoder through a Core ML package instead of PyTorch.

    The package at ``out_name`` is produced with ``convert_decoder`` when it is
    missing, then loaded for CPU+GPU execution.
    """

    def __init__(self, f, out_name="vae_decoder.mlpackage"):
        if not Path(out_name).exists():
            print("generating coreml model")
            convert_decoder(f, out_name)
            print("saved")
        print("loading saved coreml model")
        self.f = MLModel(out_name, compute_units=ct.ComputeUnit.CPU_AND_GPU)
        print("loaded")

    def __call__(self, input):
        """Decode ``input`` (must support ``.numpy()``, i.e. a CPU tensor without grad).

        Returns the model's first output as a float32 tensor, or None if the
        model produced no outputs.
        """
        outputs = self.f.predict({"z": input.numpy()})
        for first in outputs.values():
            break
        else:
            return None
        return th.tensor(first, dtype=th.float32)

class PostQuantConvWrapper:
    """Runs the VAE ``post_quant_conv`` layer through a Core ML package.

    The package at ``out_name`` is produced with ``convert_post_quant_conv``
    when it is missing, then loaded for CPU+GPU execution.
    """

    def __init__(self, f, out_name="post_quant_conv.mlpackage"):
        needs_conversion = not Path(out_name).exists()
        if needs_conversion:
            print("generating coreml model")
            convert_post_quant_conv(f, out_name)
            print("saved")
        print("loading saved coreml model")
        self.f = MLModel(out_name, compute_units=ct.ComputeUnit.CPU_AND_GPU)
        print("loaded")

    def __call__(self, input):
        """Apply the layer to ``input`` (a CPU tensor without grad).

        Returns the model's first output as a float32 tensor, or None if the
        model produced no outputs.
        """
        missing = object()
        first = next(iter(self.f.predict({"input": input.numpy()}).values()), missing)
        if first is missing:
            return None
        return th.tensor(first, dtype=th.float32)

class VAEWrapper:
    """Minimal VAE replacement exposing only ``decode``.

    ``decode`` runs ``post_quant_conv`` then ``decoder`` (both plain callables,
    e.g. the Core ML wrappers) and packs the result in a ``DecoderOutput``.
    """

    def __init__(self, decoder, post_quant_conv):
        self.decoder = decoder
        self.post_quant_conv = post_quant_conv

    def decode(self, input):
        """Return ``DecoderOutput(sample=decoder(post_quant_conv(input)))``."""
        from diffusers.models.vae import DecoderOutput
        return DecoderOutput(sample=self.decoder(self.post_quant_conv(input)))

# pipe.text_encoder = TextEncoderWrapper(pipe.text_encoder)
# Swap the UNet and the VAE decode path for Core ML-backed wrappers. Each wrapper
# builds its .mlpackage if missing, otherwise loads the saved one. The text
# encoder stays in PyTorch.
pipe.unet = UNetWrapper(pipe.unet)
pipe.vae = VAEWrapper(
    decoder=DecoderWrapper(pipe.vae.decoder),
    post_quant_conv=PostQuantConvWrapper(pipe.vae.post_quant_conv),
)

In [None]:
prompt = "masterpiece character portrait of a blonde girl, full resolution, 4k, mizuryuu kei, akihiko. yoshida, Pixiv featured, baroque scenic, by artgerm, sylvain sarrailh, rossdraws, wlop, global illumination, vaporwave"
# Fixed seed so the same prompt and settings reproduce the same image.
generator = torch.Generator(device="cpu").manual_seed(68673924)
# NOTE(review): upstream StableDiffusionPipeline.__call__ may not accept a
# `scheduler=` override (extra kwargs can be silently ignored) — confirm that this
# diffusers version/fork actually samples with `kve`.
image: Image.Image = pipe(
    prompt,
    # guidance_scale=1.,
    generator=generator,
    # scheduler=lms,
    scheduler=kve,
    # num_inference_steps=30
    num_inference_steps=15,
).images[0]

sample_path = "../outputs/diffusers"
os.makedirs(sample_path, exist_ok=True)
# Start numbering from the current file count, but skip ahead past any name that
# already exists. A gap in the numbering (e.g. a deleted image) therefore never
# causes an existing file to be overwritten.
base_count = len(os.listdir(sample_path))
while os.path.exists(os.path.join(sample_path, f"{base_count:05}.png")):
    base_count += 1
image.save(os.path.join(sample_path, f"{base_count:05}.png"))