In [1]:
import torch
import torch.nn as nn

In [2]:

class ResBlock(nn.Module):
    def __init__(self, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()

        m = []
        for i in range(2):
            m.append(nn.Conv2d(n_feat, n_feat, kernel_size, padding=(kernel_size // 2), bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

In [3]:
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    def __init__(self, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()

        m = []
        for i in range(2):
            m.append(nn.Conv2d(n_feat, n_feat, kernel_size, padding=(kernel_size // 2), bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):  # x(n_feat) -> res(n_feat)
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class EncodingBlock(nn.Module):
    def __init__(self, ch_in):
        super(EncodingBlock, self).__init__()

        body = [
            nn.Conv2d(ch_in, 64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
            nn.Conv2d(64, 128, kernel_size=3, padding=3 // 2)
        ]
        self.body = nn.Sequential(*body)
        self.down = nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=3 // 2)
        self.act = nn.ReLU()

    def forward(self, input):  # input -> f_e(128),down(64)
        f_e = self.body(input)
        down = self.act(self.down(f_e))
        return f_e, down


class EncodingBlockEnd(nn.Module):
    def __init__(self, ch_in):
        super(EncodingBlockEnd, self).__init__()

        head = [
            nn.Conv2d(in_channels=ch_in, out_channels=64, kernel_size=3, padding=3 // 2),
            nn.ReLU()
        ]
        body = [
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),

            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),

            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),

            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),

            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),

            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
        ]
        tail = [
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=3 // 2)
        ]
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
        self.tail = nn.Sequential(*tail)

    def forward(self, input):  # input -> f_e(128)
        out = self.head(input)
        f_e = self.body(out) + out
        f_e = self.tail(f_e)
        return f_e


class DecodingBlock(nn.Module):
    def __init__(self, ch_in):
        super(DecodingBlock, self).__init__()

        body = [
            nn.Conv2d(in_channels=ch_in, out_channels=64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, padding=1 // 2)
        ]

        self.up = nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.act = nn.ReLU()
        self.body = nn.Sequential(*body)

    def forward(self, input, map):  # input(128),map(128) -> out(256)
        # 保证逆向卷积出来的shape和map一致
        up = self.up(input, output_size=[input.shape[0], input.shape[1], map.shape[2], map.shape[3]])
        up = self.act(up)
        out = torch.cat((up, map), 1)  # 在channel 纬度上
        out = self.body(out)
        return out


class DecodingBlockEnd(nn.Module):
    def __init__(self, ch_in):
        super(DecodingBlockEnd, self).__init__()

        body = [
            nn.Conv2d(ch_in, 64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
            ResBlock(n_feat=64, kernel_size=3),
            ResBlock(n_feat=64, kernel_size=3),
        ]

        self.up = nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.act = nn.ReLU()
        self.body = nn.Sequential(*body)

    def forward(self, input, map):  # input(128),map(128) -> out(64)
        # 保证逆向卷积出来的shape和map一致
        up = self.up(input, output_size=[input.shape[0], input.shape[1], map.shape[2], map.shape[3]])
        out = self.act(up)
        out = torch.cat((out, map), 1)  # 在channel 纬度上
        out = self.body(out)
        return out


In [4]:
Encoding_block1 = EncodingBlock(64)

In [6]:
Encoding_block1.eval()

EncodingBlock(
  (body): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): ResBlock(
      (body): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (3): ResBlock(
      (body): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (down): Conv2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (act): ReLU()
)

In [5]:
x = torch.randn(1,64,128,128)

In [6]:
import numpy as np
import scipy
from scipy import signal as signal
import torch
import torch.nn as nn


class CandyNet(nn.Module):
    def __init__(self, threshold=10.0, use_cuda=False):
        super(CandyNet, self).__init__()

        self.threshold = threshold
        self.use_cuda = use_cuda

        filter_size = 5
        generated_filters = signal.gaussian(filter_size, std=1.0).reshape([1, filter_size])

        self.gaussian_filter_horizontal = nn.Conv2d(1, 1, kernel_size=(1, filter_size), padding=(0, filter_size // 2))
        self.gaussian_filter_horizontal.weight.data.copy_(torch.from_numpy(generated_filters))
        self.gaussian_filter_horizontal.bias.data.copy_(torch.from_numpy(np.array([0.0])))

        self.gaussian_filter_vertical = nn.Conv2d(1, 1, kernel_size=(filter_size, 1), padding=(filter_size // 2, 0))
        self.gaussian_filter_vertical.weight.data.copy_(torch.from_numpy(generated_filters.T))
        self.gaussian_filter_vertical.bias.data.copy_(torch.from_numpy(np.array([0.0])))

        sobel_filter = np.array([[1, 0, -1],
                                 [2, 0, -2],
                                 [1, 0, -1]])

        self.sobel_filter_horizontal = nn.Conv2d(1, 1, kernel_size=sobel_filter.shape,
                                                 padding=sobel_filter.shape[0] // 2)
        self.sobel_filter_horizontal.weight.data.copy_(torch.from_numpy(sobel_filter))
        self.sobel_filter_horizontal.bias.data.copy_(torch.from_numpy(np.array([0.0])))

        self.sobel_filter_vertical = nn.Conv2d(1, 1, kernel_size=sobel_filter.shape, padding=sobel_filter.shape[0] // 2)
        self.sobel_filter_vertical.weight.data.copy_(torch.from_numpy(sobel_filter.T))
        self.sobel_filter_vertical.bias.data.copy_(torch.from_numpy(np.array([0.0])))

        # filters were flipped manually
        filter_0 = np.array([[0, 0, 0],
                             [0, 1, -1],
                             [0, 0, 0]])

        filter_45 = np.array([[0, 0, 0],
                              [0, 1, 0],
                              [0, 0, -1]])

        filter_90 = np.array([[0, 0, 0],
                              [0, 1, 0],
                              [0, -1, 0]])

        filter_135 = np.array([[0, 0, 0],
                               [0, 1, 0],
                               [-1, 0, 0]])

        filter_180 = np.array([[0, 0, 0],
                               [-1, 1, 0],
                               [0, 0, 0]])

        filter_225 = np.array([[-1, 0, 0],
                               [0, 1, 0],
                               [0, 0, 0]])

        filter_270 = np.array([[0, -1, 0],
                               [0, 1, 0],
                               [0, 0, 0]])

        filter_315 = np.array([[0, 0, -1],
                               [0, 1, 0],
                               [0, 0, 0]])

        all_filters = np.stack(
            [filter_0, filter_45, filter_90, filter_135, filter_180, filter_225, filter_270, filter_315])

        self.directional_filter = nn.Conv2d(1, 8, kernel_size=filter_0.shape, padding=filter_0.shape[-1] // 2)
        self.directional_filter.weight.data.copy_(torch.from_numpy(all_filters[:, None, ...]))
        self.directional_filter.bias.data.copy_(torch.from_numpy(np.zeros(shape=(all_filters.shape[0],))))

    def forward(self, img):  # (batch,channel,height, width)

        batch = img.shape[0]
        img_r = img[:, 0:1]  # batch,1,height, width
        img_g = img[:, 1:2]  # batch,1,height, width
        img_b = img[:, 2:3]  # batch,1,height, width

        blur_horizontal = self.gaussian_filter_horizontal(img_r)  # batch,1,height,width
        blurred_img_r = self.gaussian_filter_vertical(blur_horizontal)  # batch,1,height,width
        blur_horizontal = self.gaussian_filter_horizontal(img_g)  # batch,1,height,width
        blurred_img_g = self.gaussian_filter_vertical(blur_horizontal)  # batch,1,height,width
        blur_horizontal = self.gaussian_filter_horizontal(img_b)  # batch,1,height,width
        blurred_img_b = self.gaussian_filter_vertical(blur_horizontal)  # batch,1,height,width

        blurred_img = torch.stack([blurred_img_r, blurred_img_g, blurred_img_b], dim=1)  # batch,1,height,width
        blurred_img = torch.stack([torch.squeeze(blurred_img)])  # batch,1,height,width

        grad_x_r = self.sobel_filter_horizontal(blurred_img_r)  # batch,1,height,width
        grad_y_r = self.sobel_filter_vertical(blurred_img_r)  # batch,1,height,width
        grad_x_g = self.sobel_filter_horizontal(blurred_img_g)  # batch,1,height,width
        grad_y_g = self.sobel_filter_vertical(blurred_img_g)  # batch,1,height,width
        grad_x_b = self.sobel_filter_horizontal(blurred_img_b)  # batch,1,height,width
        grad_y_b = self.sobel_filter_vertical(blurred_img_b)  # batch,1,height,width

        # COMPUTE THICK EDGES
        grad_mag = torch.sqrt(grad_x_r ** 2 + grad_y_r ** 2)  # batch,1,height,width
        grad_mag += torch.sqrt(grad_x_g ** 2 + grad_y_g ** 2)  # batch,1,height,width
        grad_mag += torch.sqrt(grad_x_b ** 2 + grad_y_b ** 2)  # batch,1,height,width
        grad_orientation = (  # batch,1,height,width
                torch.atan2(grad_y_r + grad_y_g + grad_y_b, grad_x_r + grad_x_g + grad_x_b) * (180.0 / 3.14159))
        grad_orientation += 180.0  # batch,1,height,width
        grad_orientation = torch.round(grad_orientation / 45.0) * 45.0  # batch,1,height,width

        # THIN EDGES (NON-MAX SUPPRESSION)

        all_filtered = self.directional_filter(grad_mag)  # batch,8,height,width
        inidices_positive = (grad_orientation / 45) % 8  # batch,1,height,width
        inidices_negative = ((grad_orientation / 45) + 4) % 8  # batch,1,height,width

        height = inidices_positive.size()[2]
        width = inidices_positive.size()[3]
        pixel_count = height * width

        pixel_range = torch.FloatTensor([range(pixel_count)])  # batch,pixel_range
        if self.use_cuda:
            pixel_range = torch.cuda.FloatTensor([range(pixel_count)])

        indices = (  # batch,pixel_range
                inidices_positive.view(
                    inidices_positive.shape[0],
                    pixel_count).data * pixel_count + pixel_range)

        channel_select_filtered_positive = torch.ones(batch, 1, height, width)  # batch, 1, height, width
        for i in range(batch):
            channel_select_filtered_positive_temp = all_filtered[i].view(-1)[indices[i].long()].view(1, height, width)
            channel_select_filtered_positive[i] = channel_select_filtered_positive_temp

        indices = (  # batch,pixel_range
                inidices_negative.view(
                    inidices_negative.shape[0],
                    pixel_count).data * pixel_count + pixel_range)

        channel_select_filtered_negative = torch.ones(batch, 1, height, width)  # batch, 1, height, width
        for i in range(batch):
            channel_select_filtered_negative_temp = all_filtered[i].view(-1)[indices[i].long()].view(1, height, width)
            channel_select_filtered_negative[i] = channel_select_filtered_negative_temp

        channel_select_filtered = torch.stack(  # batch, 2, height, width
            [channel_select_filtered_positive, channel_select_filtered_negative], dim=1)

        is_max = channel_select_filtered.min(dim=1)[0] > 0.0

        thin_edges = grad_mag.clone()
        thin_edges[is_max == 0] = 0.0

        # THRESHOLD

        thresholded = thin_edges.clone()
        thresholded[thin_edges < self.threshold] = 0.0

        early_threshold = grad_mag.clone()
        early_threshold[grad_mag < self.threshold] = 0.0

        assert grad_mag.size() == grad_orientation.size() == thin_edges.size() == thresholded.size() == early_threshold.size()

        return thresholded


if __name__ == '__main__':
    CandyNet()


In [8]:
with torch.no_grad():
    torch.onnx.export(
        Encoding_block1,
        x,
        'test.onnx',
        #opset_version=11,
        input_names=['input'],
        output_names=['output']
    )

In [9]:
ls

draw.py    out.log      test.onnx        Untitled2.ipynb
nohup.out  python.log3  Untitled1.ipynb  Untitled.ipynb


In [10]:
import onnx

ModuleNotFoundError: No module named 'onnx'

In [7]:
import torch.nn as nn


class ConvUp(nn.Module):

    def __init__(self, ch_in, up_factor):

        super(ConvUp, self).__init__()

        body = [
            nn.Conv2d(ch_in, 64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
        ]

        if up_factor == 2:
            modules_tail = [
                nn.ConvTranspose2d(64, 64, kernel_size=3, stride=up_factor, padding=1, output_padding=1),
                nn.Conv2d(64, ch_in, 3, padding=3 // 2, bias=True)
            ]
        elif up_factor == 3:
            modules_tail = [
                nn.ConvTranspose2d(64, 64, kernel_size=3, stride=up_factor, padding=0, output_padding=0),
                nn.Conv2d(64, ch_in, 3, padding=3 // 2, bias=True)
            ]

        elif up_factor == 4:
            modules_tail = [
                nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.Conv2d(64, ch_in, 3, padding=3 // 2, bias=True)
            ]

        self.body = nn.Sequential(*body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, input):

        output = self.body(input)
        output = self.tail(output)
        return output


class ConvDown(nn.Module):

    def __init__(self, ch_in, up_factor):

        super(ConvDown, self).__init__()

        body = [
            nn.Conv2d(ch_in, 64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=3 // 2),
            nn.ReLU(),
        ]

        if up_factor == 4:
            modules_tail = [
                nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=2),
                nn.Conv2d(64, ch_in, kernel_size=3, padding=3 // 2, bias=True)
            ]
        elif up_factor == 3:
            modules_tail = [
                nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=up_factor),
                nn.Conv2d(64, ch_in, kernel_size=3, padding=3 // 2, bias=True)
            ]
        elif up_factor == 2:
            modules_tail = [
                nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=up_factor),
                nn.Conv2d(64, ch_in, kernel_size=3, padding=3 // 2, bias=True)
            ]

        self.body = nn.Sequential(*body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, input):

        out = self.body(input)
        out = self.tail(out)
        return out


In [8]:
class E_DUN(nn.Module):
    def __init__(self, args):
        super(E_DUN, self).__init__()

        self.channel0 = args.n_colors
        self.up_factor = args.scale[0]
        self.patch_size = args.patch_size
        self.batch_size = int(args.batch_size / args.n_GPUs)

        self.Encoding_block1 = EncodingBlock(64)
        self.Encoding_block2 = EncodingBlock(64)
        self.Encoding_block3 = EncodingBlock(64)
        self.Encoding_block4 = EncodingBlock(64)

        self.Encoding_block_end = EncodingBlockEnd(64)

        self.Decoding_block1 = DecodingBlock(256)
        self.Decoding_block2 = DecodingBlock(256)
        self.Decoding_block3 = DecodingBlock(256)
        self.Decoding_block4 = DecodingBlock(256)

        self.feature_decoding_end = DecodingBlockEnd(256)

        self.act = nn.ReLU()

        self.construction = nn.Conv2d(64, 3, 3, padding=1)

        G0 = 64
        kSize = 3
        T = 4
        self.Fe_e = nn.ModuleList(
            [nn.Sequential(*[
                nn.Conv2d(3, G0, kSize, padding=(kSize - 1) // 2, stride=1),
                nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1)
            ]) for _ in range(T)])

        self.RNNF = nn.ModuleList(
            [nn.Sequential(*[
                nn.Conv2d((i + 2) * G0, G0, 1, padding=0, stride=1),
                nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1),
                self.act,
                nn.Conv2d(64, 3, 3, padding=1)
            ]) for i in range(T)])

        self.Fe_f = nn.ModuleList(
            [nn.Sequential(*[
                nn.Conv2d((2 * i + 3) * G0, G0, 1, padding=0, stride=1)
            ]) for i in range(T - 1)])

        self.eta = nn.ParameterList([nn.Parameter(torch.tensor(0.5)) for _ in range(T)])
        self.delta = nn.ParameterList([nn.Parameter(torch.tensor(0.1)) for _ in range(T)])

        self.conv_up = ConvUp(3, self.up_factor)
        self.conv_down = ConvDown(3, self.up_factor)

        self.candy = CandyNet(3).eval()  # candy算子不需要迭代内部系数

    def forward(self, y):  # [batch_size ,3 ,7 ,270 ,480] ;
        fea_list = []
        V_list = []
        outs = []

        x_texture = torch.nn.functional.interpolate(
            y, scale_factor=self.up_factor, mode='bilinear', align_corners=False)
        x_edge = self.candy(x_texture)
        x = x_edge + x_texture  # 这里可以增加一些倍数，直接相加可能会存在问题

        for i in range(len(self.Fe_e)):
            # --------------------denoising module------------------------
            fea = self.Fe_e[i](x_texture)
            fea_list.append(fea)
            if i != 0:
                fea = self.Fe_f[i - 1](torch.cat(fea_list, 1))
            encode0, down0 = self.Encoding_block1(fea)
            encode1, down1 = self.Encoding_block2(down0)
            encode2, down2 = self.Encoding_block3(down1)
            encode3, down3 = self.Encoding_block4(down2)

            media_end = self.Encoding_block_end(down3)

            decode3 = self.Decoding_block1(media_end, encode3)
            decode2 = self.Decoding_block2(decode3, encode2)
            decode1 = self.Decoding_block3(decode2, encode1)
            decode0 = self.feature_decoding_end(decode1, encode0)

            fea_list.append(decode0)
            V_list.append(decode0)
            if i == 0:
                decode0 = self.construction(self.act(decode0))
            else:
                decode0 = self.RNNF[i - 1](torch.cat(V_list, 1))
            v = x_texture + decode0

            # --------------------texture module--------------------------
            x_texture = x_texture - self.delta[i] * (
                    self.conv_up(self.conv_down(x) - y) + self.eta[i] * (x - v))

            # -----------------------edge module--------------------------
            x_edge = self.candy(x)
            x = x_edge + x_texture  # 这里可以增加一些倍数，直接相加可能会存在问题

            outs.append(x)

        return x


In [11]:
class test:
    def __init__(self):
        self.n_colors = 3
        self.scale = [2]
        self.patch_size = 2
        self.batch_size = 1
        self.n_GPUs = 1

In [12]:
args = test()

In [13]:
model = E_DUN(args)
model.eval()

E_DUN(
  (Encoding_block1): EncodingBlock(
    (body): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (3): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): Conv2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (act): ReLU()
  )
  (Encoding_block2): EncodingBlock(
    (body): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [None]:
x = torch.randn(1,3,256,256)
with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        'test.onnx',
        #opset_version=11,
        input_names=['input'],
        output_names=['output']
    )



In [14]:
candy = CandyNet(3)

In [15]:
candy

CandyNet(
  (gaussian_filter_horizontal): Conv2d(1, 1, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2))
  (gaussian_filter_vertical): Conv2d(1, 1, kernel_size=(5, 1), stride=(1, 1), padding=(2, 0))
  (sobel_filter_horizontal): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (sobel_filter_vertical): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (directional_filter): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [16]:
list(candy.parameters())

[Parameter containing:
 tensor([[[[0.1353, 0.6065, 1.0000, 0.6065, 0.1353]]]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True),
 Parameter containing:
 tensor([[[[0.1353],
           [0.6065],
           [1.0000],
           [0.6065],
           [0.1353]]]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True),
 Parameter containing:
 tensor([[[[ 1.,  0., -1.],
           [ 2.,  0., -2.],
           [ 1.,  0., -1.]]]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True),
 Parameter containing:
 tensor([[[[ 1.,  2.,  1.],
           [ 0.,  0.,  0.],
           [-1., -2., -1.]]]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True),
 Parameter containing:
 tensor([[[[ 0.,  0.,  0.],
           [ 0.,  1., -1.],
           [ 0.,  0.,  0.]]],
 
 
         [[[ 0.,  0.,  0.],
           [ 0.,  1.,  0.],
           [ 0.,  0., -1.]]],
 
 
         [[[ 0.,  0.,  0.],
           [ 0.,  1.,  

In [8]:
for para in candy.parameters():
    para.requires_grad = False

In [10]:
list(candy.parameters())

[Parameter containing:
 tensor([[[[0.1353, 0.6065, 1.0000, 0.6065, 0.1353]]]]),
 Parameter containing:
 tensor([0.]),
 Parameter containing:
 tensor([[[[0.1353],
           [0.6065],
           [1.0000],
           [0.6065],
           [0.1353]]]]),
 Parameter containing:
 tensor([0.]),
 Parameter containing:
 tensor([[[[ 1.,  0., -1.],
           [ 2.,  0., -2.],
           [ 1.,  0., -1.]]]]),
 Parameter containing:
 tensor([0.]),
 Parameter containing:
 tensor([[[[ 1.,  2.,  1.],
           [ 0.,  0.,  0.],
           [-1., -2., -1.]]]]),
 Parameter containing:
 tensor([0.]),
 Parameter containing:
 tensor([[[[ 0.,  0.,  0.],
           [ 0.,  1., -1.],
           [ 0.,  0.,  0.]]],
 
 
         [[[ 0.,  0.,  0.],
           [ 0.,  1.,  0.],
           [ 0.,  0., -1.]]],
 
 
         [[[ 0.,  0.,  0.],
           [ 0.,  1.,  0.],
           [ 0., -1.,  0.]]],
 
 
         [[[ 0.,  0.,  0.],
           [ 0.,  1.,  0.],
           [-1.,  0.,  0.]]],
 
 
         [[[ 0.,  0.,  0.],
   

In [17]:
candy.eval()

CandyNet(
  (gaussian_filter_horizontal): Conv2d(1, 1, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2))
  (gaussian_filter_vertical): Conv2d(1, 1, kernel_size=(5, 1), stride=(1, 1), padding=(2, 0))
  (sobel_filter_horizontal): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (sobel_filter_vertical): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (directional_filter): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [20]:
x = torch.randn(1,3,256,256)
with torch.no_grad():
    torch.onnx.export(
        candy,
        x,
        'candy.onnx',
        #opset_version=11,
        input_names=['input'],
        #output_names=['output']
    )

  .format(op_name, op_name))
  .format(op_name, op_name))
  .format(op_name, op_name))


RuntimeError: ONNX export failed: Couldn't export operator aten::atan2

Defined at:
<ipython-input-8-6c7c0dfadad2>(108): forward
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/nn/modules/module.py(477): _slow_forward
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/nn/modules/module.py(487): __call__
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/jit/__init__.py(252): forward
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/nn/modules/module.py(489): __call__
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/jit/__init__.py(197): get_trace_graph
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/onnx/utils.py(192): _trace_and_get_graph_from_model
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/onnx/utils.py(224): _model_to_graph
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/onnx/utils.py(281): _export
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/onnx/utils.py(104): export
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/torch/onnx/__init__.py(27): export
<ipython-input-20-ef10e6136362>(8): <module>
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3343): run_code
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3263): run_ast_nodes
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3072): run_cell_async
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/IPython/core/async_helpers.py(68): _pseudo_sync_runner
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2895): _run_cell
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2867): run_cell
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/ipykernel/zmqshell.py(536): run_cell
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/ipykernel/ipkernel.py(306): do_execute
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(162): _fake_ctx_run
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(234): wrapper
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/ipykernel/kernelbase.py(545): execute_request
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(162): _fake_ctx_run
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(234): wrapper
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/ipykernel/kernelbase.py(268): dispatch_shell
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(162): _fake_ctx_run
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(234): wrapper
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/ipykernel/kernelbase.py(365): process_one
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(775): run
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(162): _fake_ctx_run
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/gen.py(814): inner
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/ioloop.py(741): _run_callback
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/ioloop.py(688): <lambda>
/root/anaconda3/envs/aaa/lib/python3.6/asyncio/events.py(145): _run
/root/anaconda3/envs/aaa/lib/python3.6/asyncio/base_events.py(1462): _run_once
/root/anaconda3/envs/aaa/lib/python3.6/asyncio/base_events.py(442): run_forever
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/tornado/platform/asyncio.py(199): start
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/ipykernel/kernelapp.py(612): start
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/traitlets/config/application.py(664): launch_instance
/root/anaconda3/envs/aaa/lib/python3.6/site-packages/ipykernel_launcher.py(16): <module>
/root/anaconda3/envs/aaa/lib/python3.6/runpy.py(85): _run_code
/root/anaconda3/envs/aaa/lib/python3.6/runpy.py(193): _run_module_as_main


Graph we tried to export:
graph(%input : Float(1, 3, 256, 256)
      %1 : Float(1, 1, 1, 5)
      %2 : Float(1)
      %3 : Float(1, 1, 5, 1)
      %4 : Float(1)
      %5 : Float(1, 1, 3, 3)
      %6 : Float(1)
      %7 : Float(1, 1, 3, 3)
      %8 : Float(1)
      %9 : Float(8, 1, 3, 3)
      %10 : Float(8)) {
  %11 : Long() = onnx::Constant[value={0}](), scope: CandyNet
  %12 : Tensor = onnx::Shape(%input), scope: CandyNet
  %13 : Long() = onnx::Gather[axis=0](%12, %11), scope: CandyNet
  %14 : Float(1, 3, 256, 256) = onnx::Slice[axes=[0], ends=[9223372036854775807], starts=[0]](%input), scope: CandyNet
  %15 : Float(1!, 1, 256, 256) = onnx::Slice[axes=[1], ends=[1], starts=[0]](%14), scope: CandyNet
  %16 : Float(1, 3, 256, 256) = onnx::Slice[axes=[0], ends=[9223372036854775807], starts=[0]](%input), scope: CandyNet
  %17 : Float(1!, 1, 256, 256) = onnx::Slice[axes=[1], ends=[2], starts=[1]](%16), scope: CandyNet
  %18 : Float(1, 3, 256, 256) = onnx::Slice[axes=[0], ends=[9223372036854775807], starts=[0]](%input), scope: CandyNet
  %19 : Float(1!, 1, 256, 256) = onnx::Slice[axes=[1], ends=[3], starts=[2]](%18), scope: CandyNet
  %20 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 5], pads=[0, 2, 0, 2], strides=[1, 1]](%15, %1, %2), scope: CandyNet/Conv2d[gaussian_filter_horizontal]
  %21 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 1], pads=[2, 0, 2, 0], strides=[1, 1]](%20, %3, %4), scope: CandyNet/Conv2d[gaussian_filter_vertical]
  %22 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 5], pads=[0, 2, 0, 2], strides=[1, 1]](%17, %1, %2), scope: CandyNet/Conv2d[gaussian_filter_horizontal]
  %23 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 1], pads=[2, 0, 2, 0], strides=[1, 1]](%22, %3, %4), scope: CandyNet/Conv2d[gaussian_filter_vertical]
  %24 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 5], pads=[0, 2, 0, 2], strides=[1, 1]](%19, %1, %2), scope: CandyNet/Conv2d[gaussian_filter_horizontal]
  %25 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 1], pads=[2, 0, 2, 0], strides=[1, 1]](%24, %3, %4), scope: CandyNet/Conv2d[gaussian_filter_vertical]
  %26 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%21, %5, %6), scope: CandyNet/Conv2d[sobel_filter_horizontal]
  %27 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%21, %7, %8), scope: CandyNet/Conv2d[sobel_filter_vertical]
  %28 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%23, %5, %6), scope: CandyNet/Conv2d[sobel_filter_horizontal]
  %29 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%23, %7, %8), scope: CandyNet/Conv2d[sobel_filter_vertical]
  %30 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%25, %5, %6), scope: CandyNet/Conv2d[sobel_filter_horizontal]
  %31 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%25, %7, %8), scope: CandyNet/Conv2d[sobel_filter_vertical]
  %32 : Tensor = onnx::Constant[value={2}](), scope: CandyNet
  %33 : Float(1, 1, 256, 256) = onnx::Pow(%26, %32), scope: CandyNet
  %34 : Tensor = onnx::Constant[value={2}](), scope: CandyNet
  %35 : Float(1, 1, 256, 256) = onnx::Pow(%27, %34), scope: CandyNet
  %36 : Float(1, 1, 256, 256) = onnx::Add(%33, %35), scope: CandyNet
  %37 : Float(1, 1, 256, 256) = onnx::Sqrt(%36), scope: CandyNet
  %38 : Tensor = onnx::Constant[value={2}](), scope: CandyNet
  %39 : Float(1, 1, 256, 256) = onnx::Pow(%28, %38), scope: CandyNet
  %40 : Tensor = onnx::Constant[value={2}](), scope: CandyNet
  %41 : Float(1, 1, 256, 256) = onnx::Pow(%29, %40), scope: CandyNet
  %42 : Float(1, 1, 256, 256) = onnx::Add(%39, %41), scope: CandyNet
  %43 : Float(1, 1, 256, 256) = onnx::Sqrt(%42), scope: CandyNet
  %44 : Float(1, 1, 256, 256) = onnx::Add(%37, %43), scope: CandyNet
  %45 : Tensor = onnx::Constant[value={2}](), scope: CandyNet
  %46 : Float(1, 1, 256, 256) = onnx::Pow(%30, %45), scope: CandyNet
  %47 : Tensor = onnx::Constant[value={2}](), scope: CandyNet
  %48 : Float(1, 1, 256, 256) = onnx::Pow(%31, %47), scope: CandyNet
  %49 : Float(1, 1, 256, 256) = onnx::Add(%46, %48), scope: CandyNet
  %50 : Float(1, 1, 256, 256) = onnx::Sqrt(%49), scope: CandyNet
  %51 : Float(1, 1, 256, 256) = onnx::Add(%44, %50), scope: CandyNet
  %52 : Float(1, 1, 256, 256) = onnx::Add(%27, %29), scope: CandyNet
  %53 : Float(1, 1, 256, 256) = onnx::Add(%52, %31), scope: CandyNet
  %54 : Float(1, 1, 256, 256) = onnx::Add(%26, %28), scope: CandyNet
  %55 : Float(1, 1, 256, 256) = onnx::Add(%54, %30), scope: CandyNet
  %56 : Float(1, 1, 256, 256) = aten::atan2(%53, %55), scope: CandyNet
  %57 : Tensor = onnx::Constant[value={57.2958}]()
  %58 : Tensor = onnx::Mul(%56, %57)
  %59 : Tensor = onnx::Constant[value={180}]()
  %60 : Tensor = onnx::Add(%58, %59)
  %61 : Tensor = onnx::Constant[value={45}]()
  %62 : Tensor = onnx::Div(%60, %61)
  %63 : Float(1, 1, 256, 256) = aten::round(%62), scope: CandyNet
  %64 : Tensor = onnx::Constant[value={45}]()
  %65 : Tensor = onnx::Mul(%63, %64)
  %66 : Tensor = onnx::Constant[value={45}]()
  %67 : Tensor = onnx::Div(%65, %66)
  %68 : Long() = onnx::Constant[value={8}](), scope: CandyNet
  %69 : Float(1, 1, 256, 256) = aten::remainder(%67, %68), scope: CandyNet
  %70 : Long() = onnx::Constant[value={2}](), scope: CandyNet
  %71 : Tensor = onnx::Shape(%69), scope: CandyNet
  %72 : Long() = onnx::Gather[axis=0](%71, %70), scope: CandyNet
  %73 : Long() = onnx::Constant[value={3}](), scope: CandyNet
  %74 : Tensor = onnx::Shape(%69), scope: CandyNet
  %75 : Long() = onnx::Gather[axis=0](%74, %73), scope: CandyNet
  %76 : Long() = onnx::Constant[value={1}](), scope: CandyNet
  %77 : Tensor = onnx::Unsqueeze[axes=[0]](%13)
  %78 : Tensor = onnx::Unsqueeze[axes=[0]](%76)
  %79 : Tensor = onnx::Unsqueeze[axes=[0]](%72)
  %80 : Tensor = onnx::Unsqueeze[axes=[0]](%75)
  %81 : Tensor = onnx::Concat[axis=0](%77, %78, %79, %80)
  %82 : Float(1, 1, 256, 256) = onnx::ConstantFill[dtype=1, input_as_shape=1, value=1](%81), scope: CandyNet
  %83 : Long() = onnx::Constant[value={1}](), scope: CandyNet
  %84 : Tensor = onnx::Unsqueeze[axes=[0]](%13)
  %85 : Tensor = onnx::Unsqueeze[axes=[0]](%83)
  %86 : Tensor = onnx::Unsqueeze[axes=[0]](%72)
  %87 : Tensor = onnx::Unsqueeze[axes=[0]](%75)
  %88 : Tensor = onnx::Concat[axis=0](%84, %85, %86, %87)
  %89 : Float(1, 1, 256, 256) = onnx::ConstantFill[dtype=1, input_as_shape=1, value=1](%88), scope: CandyNet
  %90 : Tensor = onnx::Unsqueeze[axes=[1]](%82), scope: CandyNet
  %91 : Tensor = onnx::Unsqueeze[axes=[1]](%89), scope: CandyNet
  %92 : Float(1, 2, 1, 256, 256) = onnx::Concat[axis=1](%90, %91), scope: CandyNet
  %93 : Float(1, 1, 256, 256), %94 : Long(1, 1, 256, 256) = onnx::ATen[dim=1, keepdim=0, operator="min"](%92), scope: CandyNet
  %95 : Tensor = onnx::Constant[value={0}](), scope: CandyNet
  %96 : Byte(1, 1, 256, 256) = onnx::Greater(%93, %95), scope: CandyNet
  %97 : Long() = onnx::Constant[value={0}](), scope: CandyNet
  %98 : Byte(1, 1, 256, 256) = onnx::Equal(%96, %97), scope: CandyNet
  %99 : Float() = onnx::Constant[value={0}]()
  %100 : Byte(1, 1, 256, 256) = onnx::Cast[to=2](%98), scope: CandyNet
  %101 : Long() = onnx::Constant[value={0}](), scope: CandyNet
  %102 : Float(1, 1, 256, 256) = onnx::ATen[operator="index_put"](%51, %100, %99, %101), scope: CandyNet
  %103 : Tensor = onnx::Constant[value={3}](), scope: CandyNet
  %104 : Byte(1, 1, 256, 256) = onnx::Less(%102, %103), scope: CandyNet
  %105 : Float() = onnx::Constant[value={0}]()
  %106 : Byte(1, 1, 256, 256) = onnx::Cast[to=2](%104), scope: CandyNet
  %107 : Long() = onnx::Constant[value={0}](), scope: CandyNet
  %108 : Float(1, 1, 256, 256) = onnx::ATen[operator="index_put"](%102, %106, %105, %107), scope: CandyNet
  return (%108);
}


In [19]:
candy(x)

tensor([[[[ 0.0000, 25.9050,  0.0000,  ...,  0.0000,  0.0000, 17.5363],
          [ 0.0000,  0.0000, 31.3237,  ..., 27.4972,  0.0000, 19.5156],
          [30.4530, 40.9091,  0.0000,  ...,  0.0000, 17.6542,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 17.8266],
          [24.5876, 30.4273, 34.9075,  ...,  0.0000, 20.9378, 15.7724],
          [23.1771,  0.0000,  0.0000,  ...,  0.0000, 14.5815,  0.0000]]]],
       grad_fn=<IndexPutBackward>)

In [22]:
model = ConvDown(3,2)

In [23]:
x.shape

torch.Size([1, 3, 256, 256])

In [25]:
model.eval()

ConvDown(
  (body): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
  )
  (tail): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [27]:
with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        'ConvDown.onnx',
        #opset_version=11,
        input_names=['input'],
        output_names=['output']
    )

In [9]:
model = ConvUp(3,2)

In [10]:
x = torch.randn(1,3,128,128)

In [11]:
with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        'ConvUp.onnx',
        #opset_version=11,
        input_names=['input'],
        output_names=['output']
    )