In [1]:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip connection.

    The residual branch is ``pad -> conv -> instance-norm -> ReLU -> pad -> conv ->
    instance-norm``; its output is added to the input, so the channel count
    ``dim`` is unchanged. The ordering inside ``conv_block`` must stay fixed:
    pretrained checkpoints address the two convolutions by their
    ``nn.Sequential`` indices (``conv_block.1`` and ``conv_block.5``).
    """

    def __init__(self, dim):
        super().__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Return the ``nn.Sequential`` residual branch for ``dim`` channels."""

        def padded_conv(with_relu):
            # A reflection pad of 1 keeps H/W unchanged for the unpadded 3x3 conv.
            stage = [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(dim),
            ]
            if with_relu:
                stage.append(nn.ReLU(inplace=True))
            return stage

        # First stage ends in ReLU, second does not (the sum is left un-activated).
        return nn.Sequential(*padded_conv(with_relu=True), *padded_conv(with_relu=False))

    def forward(self, x):
        """Add the residual branch to the input."""
        residual = self.conv_block(x)
        return x + residual


class ResNetGenerator(nn.Module):
    """ResNet-based image-to-image generator (encoder -> residual blocks -> decoder).

    Layout of ``self.model`` (an ``nn.Sequential``):
      * reflection-padded 7x7 stem convolution, ``input_nc -> ngf`` channels;
      * two stride-2 3x3 convolutions, each doubling the channel count;
      * ``n_blocks`` ``ResNetBlock`` modules at ``4 * ngf`` channels;
      * two stride-2 transposed convolutions, each halving the channel count;
      * reflection-padded 7x7 output convolution to ``output_nc`` and ``tanh``,
        so outputs lie in [-1, 1].

    The layer order must not change: pretrained state dicts address parameters
    by Sequential index (e.g. ``model.1``, ``model.4``, ``model.10.conv_block.1``).

    Args:
        input_nc: number of input image channels.
        output_nc: number of output image channels.
        ngf: channel width of the stem convolution.
        n_blocks: number of residual blocks; must be non-negative.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
        assert n_blocks >= 0
        super().__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        n_downsampling = 2

        # Stem: the reflection pad of 3 keeps H/W unchanged for the 7x7 conv.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]

        # Encoder: each step maps a size n to ceil(n / 2) and doubles channels.
        channels = ngf
        for _ in range(n_downsampling):
            layers.extend([
                nn.Conv2d(channels, channels * 2, kernel_size=3,
                          stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(True),
            ])
            channels *= 2

        # Bottleneck: residual blocks at ngf * 2**n_downsampling channels.
        layers.extend(ResNetBlock(channels) for _ in range(n_blocks))

        # Decoder: each step exactly doubles H/W and halves channels.
        for _ in range(n_downsampling):
            layers.extend([
                nn.ConvTranspose2d(channels, channels // 2,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=True),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(True),
            ])
            channels //= 2

        # Head: back to output_nc channels, squashed into [-1, 1].
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run the full generator on a batch of images."""
        return self.model(input)

In [2]:
import torch
from PIL import Image
from torchvision import transforms

# Instantiate `ResNetGenerator` and load the pretrained horse->zebra parameters.
# `map_location='cpu'` lets the checkpoint load on machines without a GPU,
# regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
netG = ResNetGenerator()
model_path = 'horse2zebra_0.4.0.pth'
model_data = torch.load(model_path, map_location='cpu')
netG.load_state_dict(model_data)

# Put the model in eval mode
netG.eval()

# Load an image and preprocess it. `.convert('RGB')` guarantees 3 channels:
# images with alpha or grayscale images would otherwise give 4- or 1-channel
# tensors that the generator's first convolution (input_nc=3) rejects.
# The `with` block closes the underlying file once the pixels are copied.
with Image.open('horse.jpg') as img_file:
    img = img_file.convert('RGB') # img.size: (1500, 1220)
preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.ToTensor()])
img_t = preprocess(img) # img_t.size(): torch.Size([3, 256, 314])
batch_t = torch.unsqueeze(img_t, 0) # batch_t.size(): torch.Size([1, 3, 256, 314])

# Send the preprocessed image to the generator. Inference needs no gradients,
# so skip building the autograd graph.
with torch.no_grad():
    batch_out = netG(batch_t) # batch_out.size(): torch.Size([1, 3, 256, 316])
# The generator downsamples twice with stride 2 (rounding up) and then upsamples
# twice by exactly 2, so a spatial size that is not a multiple of 4 comes back
# rounded up to the next multiple of 4 (314 -> 316 here).

# Convert back to an image: map the tanh output from [-1, 1] to [0, 1].
out_t = (batch_out.squeeze(0) + 1.0) / 2.0 # out_t.size(): torch.Size([3, 256, 316])
out_img = transforms.ToPILImage()(out_t)
out_img.save('zebra.jpg')

In [3]:
# List the parameter names stored in the pretrained checkpoint. Their indices
# (model.1, model.4, ..., model.10.conv_block.1, ...) are the layer positions
# inside ResNetGenerator's nn.Sequential, which load_state_dict matches by name.
for param_name in model_data:
    print(param_name)

model.1.weight
model.1.bias
model.4.weight
model.4.bias
model.7.weight
model.7.bias
model.10.conv_block.1.weight
model.10.conv_block.1.bias
model.10.conv_block.5.weight
model.10.conv_block.5.bias
model.11.conv_block.1.weight
model.11.conv_block.1.bias
model.11.conv_block.5.weight
model.11.conv_block.5.bias
model.12.conv_block.1.weight
model.12.conv_block.1.bias
model.12.conv_block.5.weight
model.12.conv_block.5.bias
model.13.conv_block.1.weight
model.13.conv_block.1.bias
model.13.conv_block.5.weight
model.13.conv_block.5.bias
model.14.conv_block.1.weight
model.14.conv_block.1.bias
model.14.conv_block.5.weight
model.14.conv_block.5.bias
model.15.conv_block.1.weight
model.15.conv_block.1.bias
model.15.conv_block.5.weight
model.15.conv_block.5.bias
model.16.conv_block.1.weight
model.16.conv_block.1.bias
model.16.conv_block.5.weight
model.16.conv_block.5.bias
model.17.conv_block.1.weight
model.17.conv_block.1.bias
model.17.conv_block.5.weight
model.17.conv_block.5.bias
model.18.conv_block.

In [4]:
import netron
import torch.onnx

# Rebuild the generator with the pretrained weights so this cell runs on its own.
# `map_location='cpu'` avoids a failure on machines without a GPU.
netG = ResNetGenerator()
model_path = 'horse2zebra_0.4.0.pth'
model_data = torch.load(model_path, map_location='cpu')
netG.load_state_dict(model_data)
netG.eval()

# Export to ONNX by tracing one forward pass. The model has no data-dependent
# control flow, so the input values do not matter; without `dynamic_axes` the
# shape (1, 3, 256, 314) is fixed in the exported graph.
x = torch.randn(1, 3, 256, 314)
# The exported file is ONNX, not a PyTorch checkpoint, so give it the .onnx
# extension (Netron also uses the extension to pick the right parser).
modelFile = "netG.onnx"
torch.onnx.export(netG, x, modelFile,
                  input_names=['input'], output_names=['output'])

# Serve the exported graph in the Netron viewer; the returned (host, port) is displayed.
netron.start(modelFile)

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(


Serving 'netG.pth' at http://localhost:8080


  _C._jit_pass_onnx_graph_shape_type_inference(


('localhost', 8080)