
ADD LatentDiffusionModel #830

Closed
kyakuno opened this issue Jun 1, 2022 · 10 comments

kyakuno (Collaborator) commented Jun 1, 2022

https://github.com/CompVis/latent-diffusion

kyakuno (Collaborator, Author) commented Jun 3, 2022

The license is MIT.

ooe1123 (Contributor) commented Jun 26, 2022

Modifications for export
○ ldm/modules/diffusionmodules/model.py

Before:

class AttnBlock(nn.Module):
    ...
    def forward(self, x):
        ...
        w_ = w_ * (int(c)**(-0.5))

After (drop the Python int() cast; the value is unchanged, but the cast would freeze the channel count into the ONNX trace as a constant):

class AttnBlock(nn.Module):
    ...
    def forward(self, x):
        ...
        w_ = w_ * (c**(-0.5))

○ ldm/modules/diffusionmodules/util.py

Before:

def checkpoint(func, inputs, params, flag):
    ...
    if flag:
        ...

After (force-disable gradient checkpointing, which cannot be traced during ONNX export):

def checkpoint(func, inputs, params, flag):
    ...
    flag = False  # always take the plain, non-checkpointed path
    if flag:
        ...
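
The patches in the following comments embed torch.onnx.export calls inside ordinary forward passes, so one sampling run per task walks through them all. A hedged sketch for the txt2img exports, using the script name and flags from the CompVis/latent-diffusion README (adjust the arguments to your setup):

import subprocess

# one sampling run exercises every patched forward and writes the .onnx files
subprocess.run([
    "python", "scripts/txt2img.py",
    "--prompt", "a virus monster is playing guitar",
    "--ddim_eta", "0.0", "--n_samples", "1", "--n_iter", "1",
    "--scale", "5.0", "--ddim_steps", "50",
], check=True)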

ooe1123 (Contributor) commented Jun 26, 2022

  • latent-diffusion-txt2img: transformer_emb.onnx, transformer_attn.onnx
    A large ONNX model gets split across multiple files on export (the 2 GB protobuf limit), so the model itself is split and the parts are exported separately.

○ ldm/modules/encoders/modules.py

Before:

class BERTEmbedder(AbstractEncoder):
    ...
    def forward(self, text):
        if self.use_tknz_fn:
            ...
        else:
            ...
        z = self.transformer(tokens, return_embeddings=True)

After (step 1, export the embedding stage as transformer_emb.onnx):

class BERTEmbedder(AbstractEncoder):
    ...
    def forward(self, text):
        if self.use_tknz_fn:
            ...
        else:
            ...
        print("------>")
        import functools
        from torch.autograd import Variable
        # pin return_embeddings=True so the traced graph returns embeddings
        self.transformer.forward = functools.partial(self.transformer.forward, return_embeddings=True)
        self.transformer.cpu()
        x = Variable(tokens.cpu())
        torch.onnx.export(
            self.transformer, x, 'transformer_emb.onnx',
            input_names=["x"],
            output_names=["out"],
            dynamic_axes={'x': {0: 'n'}, 'out': {0: 'n'}},
            verbose=False, opset_version=12
        )
        print("<------")
After (step 2, export the attention stack as transformer_attn.onnx; forward2 is added to TransformerWrapper below):

class BERTEmbedder(AbstractEncoder):
    ...
    def forward(self, text):
        if self.use_tknz_fn:
            ...
        else:
            ...
        x = self.transformer(tokens, return_embeddings=True)

        print("------>")
        from torch.autograd import Variable
        # swap in forward2, which runs only the attention layers
        self.transformer.forward = self.transformer.forward2
        self.transformer.cpu()
        x = Variable(x.cpu())
        torch.onnx.export(
            self.transformer, x, 'transformer_attn.onnx',
            input_names=["x"],
            output_names=["out"],
            dynamic_axes={'x': {0: 'n'}, 'out': {0: 'n'}},
            verbose=False, opset_version=12
        )
        print("<------")

○ ldm/modules/x_transformer.py

Before:

class TransformerWrapper(nn.Module):
    def forward(
      ...
    ):
        ...
        x = self.project_emb(x)
        ...
        if num_mem > 0:
            ...

After (cut forward short right after the embedding projection, and move the remaining attention stack into a new forward2):

class TransformerWrapper(nn.Module):
    def forward(
      ...
    ):
        ...
        x = self.project_emb(x)

        return x  # stop here for transformer_emb.onnx; the rest moves to forward2

        if num_mem > 0:
            ...

    def forward2(self, x, mask=None, mems=None, **kwargs):
        num_mem = 0
        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        mem, x = x[:, :num_mem], x[:, num_mem:]

        return_embeddings = True
        out = self.to_logits(x) if not return_embeddings else x

        return out
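
For reference, a minimal sketch (with assumed shapes) of chaining the two exported parts with onnxruntime: transformer_emb.onnx embeds the token ids, and transformer_attn.onnx runs the attention stack on those embeddings.

import numpy as np
import onnxruntime

emb = onnxruntime.InferenceSession("transformer_emb.onnx")
attn = onnxruntime.InferenceSession("transformer_attn.onnx")

tokens = np.zeros((1, 77), dtype=np.int64)  # hypothetical token ids; 77 = assumed max_seq_len
x = emb.run(["out"], {"x": tokens})[0]      # token + positional embeddings
c = attn.run(["out"], {"x": x})[0]          # text conditioning for the diffusion model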

ooe1123 (Contributor) commented Jun 26, 2022

  • latent-diffusion-txt2img: diffusion_emb.onnx, diffusion_mid.onnx, diffusion_out.onnx
    As above, the model is too large for a single ONNX file, so the UNet is split into three parts and each part is exported separately.

○ ldm/models/diffusion/ddpm.py

Before:

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)

After (step 1, export the input and middle blocks as diffusion_emb.onnx; UNetModel.forward is patched below to return h, the time embedding, and the twelve skip activations):

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            print("------>")
            from torch.autograd import Variable
            xx = (Variable(x), Variable(t), Variable(cc))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_emb.onnx',
                input_names=["x", "timesteps", "context"],
                output_names=["h", "emb", "h0", "h1", "h2", "h3", "h4", "h5", "h6", "h7", "h8", "h9", "h10", "h11"],
                dynamic_axes={
                    'x': {0: 'n', 2: 'h', 3: 'w'}, 'timesteps': {0: 'n'}, 'context': {0: 'n'},
                    'h': {0: 'n', 2: 'h1', 3: 'w1'}, 'emb': {0: 'n'},
                    'h0': {0: 'n', 2: 'h1', 3: 'w1'}, 'h1': {0: 'n', 2: 'h1', 3: 'w1'}, 'h2': {0: 'n', 2: 'h1', 3: 'w1'},
                    'h3': {0: 'n', 2: 'h2', 3: 'w2'}, 'h4': {0: 'n', 2: 'h2', 3: 'w2'}, 'h5': {0: 'n', 2: 'h2', 3: 'w2'},
                    'h6': {0: 'n', 2: 'h3', 3: 'w3'}, 'h7': {0: 'n', 2: 'h3', 3: 'w3'}, 'h8': {0: 'n', 2: 'h3', 3: 'w3'},
                    'h9': {0: 'n', 2: 'h4', 3: 'w4'}, 'h10': {0: 'n', 2: 'h4', 3: 'w4'}, 'h11': {0: 'n', 2: 'h4', 3: 'w4'},
                },
                verbose=False, opset_version=12
            )
            print("<------")
After (step 2, export the first six output blocks as diffusion_mid.onnx):

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
            h = out[0]
            emb = out[1]
            hs = out[2:]

            print("------>")
            from torch.autograd import Variable
            self.diffusion_model.forward = self.diffusion_model.forward2
            xx = (
                Variable(h), Variable(emb), Variable(cc),
                Variable(hs[6]), Variable(hs[7]), Variable(hs[8]), Variable(hs[9]),
                Variable(hs[10]), Variable(hs[11]))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_mid.onnx',
                input_names=[
                    "h", "emb", "context", "h6", "h7", "h8", "h9", "h10", "h11"],
                output_names=["out"],
                dynamic_axes={
                    'h': {0: 'n', 2: 'h4', 3: 'w4'}, 'emb': {0: 'n'}, 'context': {0: 'n'},
                    'h6': {0: 'n', 2: 'h3', 3: 'w3'}, 'h7': {0: 'n', 2: 'h3', 3: 'w3'}, 'h8': {0: 'n', 2: 'h3', 3: 'w3'},
                    'h9': {0: 'n', 2: 'h4', 3: 'w4'}, 'h10': {0: 'n', 2: 'h4', 3: 'w4'}, 'h11': {0: 'n', 2: 'h4', 3: 'w4'},
                    'out': {0: 'n', 2: 'h2', 3: 'w2'},
                },
                verbose=False, opset_version=12
            )
            print("<------")
After (step 3, export the last six output blocks and the final projection as diffusion_out.onnx):

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
            h = out[0]
            emb = out[1]
            hs = out[2:]
            h = self.diffusion_model.forward2(
                h, emb, cc,
                hs[6], hs[7], hs[8], hs[9], hs[10], hs[11])

            print("------>")
            from torch.autograd import Variable
            self.diffusion_model.forward = self.diffusion_model.forward3
            xx = (
                Variable(h), Variable(emb), Variable(cc),
                Variable(hs[0]), Variable(hs[1]), Variable(hs[2]), Variable(hs[3]),
                Variable(hs[4]), Variable(hs[5]))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_out.onnx',
                input_names=[
                    "h", "emb", "context", "h0", "h1", "h2", "h3", "h4", "h5"],
                output_names=["out"],
                dynamic_axes={
                    'h': {0: 'n', 2: 'h2', 3: 'w2'}, 'emb': {0: 'n'}, 'context': {0: 'n'},
                    'h0': {0: 'n', 2: 'h1', 3: 'w1'}, 'h1': {0: 'n', 2: 'h1', 3: 'w1'}, 'h2': {0: 'n', 2: 'h1', 3: 'w1'},
                    'h3': {0: 'n', 2: 'h2', 3: 'w2'}, 'h4': {0: 'n', 2: 'h2', 3: 'w2'}, 'h5': {0: 'n', 2: 'h2', 3: 'w2'},
                    'out': {0: 'n', 2: 'h', 3: 'w'},
                },
                verbose=False, opset_version=12
            )
            print("<------")

○ ldm/modules/diffusionmodules/openaimodel.py

Before:

class UNetModel(nn.Module):
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        ...
        h = self.middle_block(h, emb, context)

After (forward now stops after the middle block and returns the skip activations; forward2 and forward3 each run half of the output blocks):

class UNetModel(nn.Module):
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        ...
        h = self.middle_block(h, emb, context)
        return h, emb, hs[0], hs[1], hs[2], hs[3], hs[4], hs[5], hs[6], hs[7], hs[8], hs[9], hs[10], hs[11]

    def forward2(self, h, emb, context, h6, h7, h8, h9, h10, h11):
        ...
        # first six output blocks, consuming the deepest skip connections
        hs = [h6, h7, h8, h9, h10, h11]
        for i, module in enumerate(self.output_blocks[:6]):
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        return h

    def forward3(self, h, emb, context, h0, h1, h2, h3, h4, h5):
        # last six output blocks plus the final projection
        hs = [h0, h1, h2, h3, h4, h5]
        for i, module in enumerate(self.output_blocks[6:]):
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)

        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
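
A hedged sketch of stitching the three parts back together for one denoising step with onnxruntime; all shapes and the timestep dtype here are assumptions, so check them against the traced model:

import numpy as np
import onnxruntime

emb_net = onnxruntime.InferenceSession("diffusion_emb.onnx")
mid_net = onnxruntime.InferenceSession("diffusion_mid.onnx")
out_net = onnxruntime.InferenceSession("diffusion_out.onnx")

x = np.zeros((1, 4, 32, 32), dtype=np.float32)   # hypothetical latent
t = np.array([981], dtype=np.int64)              # hypothetical timestep
ctx = np.zeros((1, 77, 1280), dtype=np.float32)  # hypothetical text conditioning

# input/middle blocks: returns h, emb, and the twelve skip activations
res = emb_net.run(None, {"x": x, "timesteps": t, "context": ctx})
h, emb, hs = res[0], res[1], res[2:]

# first six output blocks consume h6..h11
h = mid_net.run(["out"], {"h": h, "emb": emb, "context": ctx,
                          **{f"h{i}": hs[i] for i in range(6, 12)}})[0]

# last six output blocks consume h0..h5 and produce the noise prediction
eps = out_net.run(["out"], {"h": h, "emb": emb, "context": ctx,
                            **{f"h{i}": hs[i] for i in range(6)}})[0]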

ooe1123 (Contributor) commented Jun 26, 2022

  • latent-diffusion-txt2img: autoencoder.onnx

○ ldm/models/diffusion/ddpm.py

Before:

class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if hasattr(self, "split_input_params"):
            ...
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                ...
            else:
                return self.first_stage_model.decode(z)

After (export the first-stage decoder as autoencoder.onnx by pointing forward at decode):

class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if hasattr(self, "split_input_params"):
            ...
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                ...
            else:
                print("------>")
                self.first_stage_model.forward = self.first_stage_model.decode
                from torch.autograd import Variable
                x = Variable(z)
                torch.onnx.export(
                    self.first_stage_model, x, 'autoencoder.onnx',
                    input_names=["input"],
                    output_names=["output"],
                    dynamic_axes={'input': {0: 'n', 2: 'h', 3: 'w'}, 'output': {0: 'n', 2: 'ho', 3: 'wo'}},
                    verbose=False, opset_version=11
                )
                print("<------")
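
A sketch of running the exported decoder (tensor names as in the export call above); note that decode_first_stage divides z by model.scale_factor before decoding, and that rescaling is not part of the exported graph:

import numpy as np
import onnxruntime

dec = onnxruntime.InferenceSession("autoencoder.onnx")
z = np.zeros((1, 4, 32, 32), dtype=np.float32)  # hypothetical latent shape
img = dec.run(["output"], {"input": z})[0]      # NCHW image, roughly in [-1, 1]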

ooe1123 (Contributor) commented Jun 26, 2022

  • latent-diffusion-inpainting: cond_stage_model.onnx

○ scripts/inpaint.py

Before:

    with torch.no_grad():
        with model.ema_scope():
            for image, mask in tqdm(zip(images, masks)):
                ...
                c = model.cond_stage_model.encode(batch["masked_image"])

After (export the conditioning encoder as cond_stage_model.onnx):

    with torch.no_grad():
        with model.ema_scope():
            for image, mask in tqdm(zip(images, masks)):
                ...
                print("------>")
                model.cond_stage_model.forward = model.cond_stage_model.encode
                from torch.autograd import Variable
                x = Variable(batch["masked_image"])
                torch.onnx.export(
                    model.cond_stage_model, x, 'cond_stage_model.onnx',
                    input_names=["masked_image"],
                    output_names=["out"],
                    dynamic_axes={'masked_image': {2: 'h', 3: 'w'}, 'out': {2: 'oh', 3: 'ow'}},
                    verbose=False, opset_version=11
                )
                print("<------")
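
A sketch of the exported conditioning encoder; the input layout (masked RGB image, NCHW, values roughly in [-1, 1]) and the 512x512 size are assumptions:

import numpy as np
import onnxruntime

enc = onnxruntime.InferenceSession("cond_stage_model.onnx")
masked_image = np.zeros((1, 3, 512, 512), dtype=np.float32)  # hypothetical masked input
c = enc.run(["out"], {"masked_image": masked_image})[0]      # spatial conditioning map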

ooe1123 (Contributor) commented Jun 26, 2022

  • latent-diffusion-inpainting: autoencoder.onnx

○ ldm/models/diffusion/ddpm.py

Before:

class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        ...
        if hasattr(self, "split_input_params"):
            ...
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)

After (export the VQ decoder as autoencoder.onnx):

class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        ...
        if hasattr(self, "split_input_params"):
            ...
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                print("------>")
                from torch.autograd import Variable
                x = Variable(z)
                self.first_stage_model.forward = self.first_stage_model.decode
                torch.onnx.export(
                    self.first_stage_model, x, 'autoencoder.onnx',
                    input_names=["z"],
                    output_names=["dec"],
                    dynamic_axes={'z': {2: 'h', 3: 'w'}, 'dec': {2: 'oh', 3: 'ow'}},
                    verbose=False, opset_version=12
                )
                print("<------")

ooe1123 (Contributor) commented Jun 26, 2022

  • latent-diffusion-inpainting: diffusion_model.onnx

○ ldm/models/diffusion/ddpm.py

Before:

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)

After (export the concat-conditioned UNet as diffusion_model.onnx; the xc assignment must stay, since the export call uses it):

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            print("------>")
            from torch.autograd import Variable
            xx = (Variable(xc), Variable(t))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_model.onnx',
                input_names=["xc", "t"],
                output_names=["out"],
                dynamic_axes={'xc': {2: 'h', 3: 'w'}, 'out': {2: 'h', 3: 'w'}},
                verbose=False, opset_version=12
            )
            print("<------")
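
A hedged sketch of one denoising call on the exported inpainting UNet; the channel count of xc (latent plus encoded masked image plus downsampled mask) and the timestep dtype are assumptions:

import numpy as np
import onnxruntime

unet = onnxruntime.InferenceSession("diffusion_model.onnx")
xc = np.zeros((1, 7, 64, 64), dtype=np.float32)  # hypothetical concatenated input
t = np.array([500], dtype=np.int64)              # hypothetical timestep
eps = unet.run(["out"], {"xc": xc, "t": t})[0]   # predicted noise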

ooe1123 (Contributor) commented Jun 26, 2022

  • latent-diffusion-superresolution: first_stage_decode.onnx

○ ldm/models/diffusion/ddpm.py

Before:

class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        ...
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ...
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]

After (export the per-patch decoder as first_stage_decode.onnx, tracing it on the first patch):

class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        ...
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ...
                if isinstance(self.first_stage_model, VQModelInterface):
                    print("------>")
                    from torch.autograd import Variable
                    x = Variable(z[:, :, :, :, 0])
                    self.first_stage_model.forward = self.first_stage_model.decode
                    torch.onnx.export(
                        self.first_stage_model, x, 'first_stage_decode.onnx',
                        input_names=["x"],
                        output_names=["out"],
                        dynamic_axes={'x': {2: 'h', 3: 'w'}, 'out': {2: 'oh', 3: 'ow'}},
                        verbose=False, opset_version=12
                    )
                    print("<------")
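
A sketch of the per-patch decode that mirrors the original loop: in the split-input path z carries a trailing patch axis, and each slice goes through first_stage_decode.onnx separately (all shapes here are assumptions):

import numpy as np
import onnxruntime

dec = onnxruntime.InferenceSession("first_stage_decode.onnx")
z = np.zeros((1, 3, 32, 32, 4), dtype=np.float32)  # hypothetical (n, c, h, w, n_patches)
output_list = [dec.run(["out"], {"x": z[..., i]})[0] for i in range(z.shape[-1])]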

ooe1123 (Contributor) commented Jun 26, 2022

  • latent-diffusion-superresolution: diffusion_model.onnx

○ ldm/modules/diffusionmodules/openaimodel.py

Before:

class QKVAttentionLegacy(nn.Module):
    ...
    def forward(self, qkv):
        ...
        scale = 1 / math.sqrt(math.sqrt(ch))

After (compute the same value with power operators, presumably because math.sqrt on a traced value would bake in a constant or fail during export):

class QKVAttentionLegacy(nn.Module):
    ...
    def forward(self, qkv):
        ...
        scale = 1 / ((ch**0.5)**0.5)  # == 1 / ch**0.25, identical to the original scale
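
A quick sanity check that the rewrite is numerically equivalent (ch = 64 is an arbitrary example):

import math

ch = 64
assert abs(1 / ((ch**0.5)**0.5) - 1 / math.sqrt(math.sqrt(ch))) < 1e-12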

○ ldm/models/diffusion/ddpm.py

Before:

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)

After (same export as for inpainting, keeping the xc assignment that the export call needs):

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            print("------>")
            from torch.autograd import Variable
            xx = (Variable(xc), Variable(t))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_model.onnx',
                input_names=["xc", "t"],
                output_names=["out"],
                dynamic_axes={'xc': {2: 'h', 3: 'w'}, 'out': {2: 'h', 3: 'w'}},
                verbose=False, opset_version=12
            )
            print("<------")
