In [1]:
import yaml
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ
from PIL import Image
import PIL
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import cv2
import onnx

def load_config(config_path, display=False):
  """Load an OmegaConf configuration file.

  Args:
    config_path: path of the config file handed to ``OmegaConf.load``.
    display: when true, also print the config (converted to plain
      containers) as YAML.

  Returns:
    The config object produced by ``OmegaConf.load``.
  """
  cfg = OmegaConf.load(config_path)
  if not display:
    return cfg
  plain = OmegaConf.to_container(cfg)
  print(yaml.dump(plain))
  return cfg

def load_vqgan(config, ckpt_path=None, is_gumbel=False, map_location="cpu"):
  """Build a taming VQGAN model from ``config`` and optionally load weights.

  Args:
    config: config object; ``config.model.params`` is unpacked as keyword
      arguments into the model constructor.
    ckpt_path: optional checkpoint path. Its ``"state_dict"`` entry is loaded
      into the model. When None, the freshly constructed model is returned.
    is_gumbel: build ``GumbelVQ`` instead of ``VQModel``.
    map_location: device that checkpoint tensors are materialised on before
      being copied into the model. Defaults to "cpu" so the function also works
      on machines without CUDA. It does not move the model; call
      ``.to(device)`` on the result for that.

  Returns:
    The constructed model. It is not moved to any device and is not switched
    to eval mode here.
  """
  model_cls = GumbelVQ if is_gumbel else VQModel
  model = model_cls(**config.model.params)
  if ckpt_path is not None:
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    sd = torch.load(ckpt_path, map_location=map_location)["state_dict"]
    # strict=False tolerates key mismatches; report them rather than silently
    # discarding them, so a wrong or partial checkpoint is noticed.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing or unexpected:
      print(f"load_vqgan: {len(missing)} missing and {len(unexpected)} unexpected state_dict keys")
  return model

In [2]:
# NOTE(review): this checkpoint lives under configs/, while earlier runs loaded
# ckpts/last.ckpt and ckpts/last_full_train.ckpt; confirm which file is intended.
CKPT_PATH = "configs/last.ckpt"

config1024 = load_config("configs/model.yaml", display=False)
# load_vqgan does not call eval(), so switch to inference mode here before the
# ONNX exports below trace the model (presumably dynamo_export does not do this
# itself; confirm for the installed torch version).
model1024 = load_vqgan(config1024, ckpt_path=CKPT_PATH).eval()

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.




loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips\vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


In [3]:
# Random example input for tracing the export. Presumably (batch, channels,
# height, width); confirm against the model config.
dummy_shape = (1, 3, 160, 256)
torch_input = torch.randn(*dummy_shape)
onnx_program = torch.onnx.dynamo_export(model1024, torch_input)



In [4]:
# Write the program produced by torch.onnx.dynamo_export to vqgan.onnx.
onnx_program.save("vqgan.onnx")

In [6]:
# Export the full model again, this time with the TorchScript-based
# torch.onnx.export (instead of torch.onnx.dynamo_export), writing rtt.onnx.
torch.onnx.export(model1024, torch_input, "rtt.onnx")

  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [25]:
# Handle to only the encoder submodule of model1024 (no quant_conv / quantize).
p = model1024.encoder

In [26]:
# Export the bare encoder to rencode.onnx, traced with the same example input.
torch.onnx.export(p, torch_input, "rencode.onnx")

  w_ = w_ * (int(c)**(-0.5))
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [9]:
# Reload the exported encoder graph so it can be validated below.
enc = onnx.load('rencode.onnx')

In [10]:
# Structural validation of the ONNX model; raises if the graph is invalid.
onnx.checker.check_model(enc)

In [22]:
import torch.nn as nn
class EncoderManual(nn.Module):
    """Encoder half of a VQ model, wrapped as a module for standalone export.

    The ``encoder``, ``quant_conv`` and ``quantize`` submodules of the model
    passed in are stored by reference (not copied), so this wrapper shares
    their weights.
    """

    def __init__(self, model1024):
        super().__init__()
        # Borrow the existing submodules; no layers are created here.
        self.encoder = model1024.encoder
        self.quant_conv = model1024.quant_conv
        self.quantize = model1024.quantize

    def forward(self, x):
        """Run encoder -> quant_conv -> quantize on ``x``.

        Returns the quantizer's three outputs as a list
        ``[quant, emb_loss, info]``; their exact contents depend on the
        quantizer class in use.
        """
        latent = self.quant_conv(self.encoder(x))
        quant, emb_loss, info = self.quantize(latent)
        return [quant, emb_loss, info]


In [23]:
# Wrap model1024's encoding path (encoder -> quant_conv -> quantize) for export.
a=EncoderManual(model1024)

In [24]:
# Export the encoder+quantizer wrapper to quant3.onnx. Only the input is named;
# output names are left to the exporter's defaults.
torch.onnx.export(a, torch_input, 'quant3.onnx', input_names=['input'])

  w_ = w_ * (int(c)**(-0.5))
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
