In [466]:
import torch

x = torch.randn(1, 3, 20, 20)

In [467]:
import torch
import torch.nn as nn

from models.denoisingModule import EncodingBlock, EncodingBlockEnd, DecodingBlock, DecodingBlockEnd
from models.textureReconstructionModule import ConvDown, ConvUp


def make_model(args, parent=False):
    return E_DUN(args)


class E_DUN(nn.Module):
    def __init__(self, args):
        super(E_DUN, self).__init__()

        self.channel0 = args.n_colors  # channel的数量
        self.up_factor = args.scale[0]  # 放大倍数
        self.patch_size = args.patch_size
        self.batch_size = int(args.batch_size / args.n_GPUs)

        self.Encoding_block1 = EncodingBlock(64)
        self.Encoding_block2 = EncodingBlock(64)
        self.Encoding_block3 = EncodingBlock(64)
        self.Encoding_block4 = EncodingBlock(64)

        self.Encoding_block_end = EncodingBlockEnd(64)

        self.Decoding_block1 = DecodingBlock(256)
        self.Decoding_block2 = DecodingBlock(256)
        self.Decoding_block3 = DecodingBlock(256)
        self.Decoding_block4 = DecodingBlock(256)

        self.feature_decoding_end = DecodingBlockEnd(256)

        self.act = nn.ReLU()

        self.construction = nn.Conv2d(64, 3, 3, padding=1)

        G0 = 64
        kSize = 3
        T = 4
        self.Fe_e = nn.ModuleList(
            [nn.Sequential(
                *[
                    nn.Conv2d(3, G0, kSize, padding=(kSize - 1) // 2, stride=1),
                    nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1)
                ]
            ) for _ in range(T)]
        )

        self.RNNF = nn.ModuleList(
            [nn.Sequential(
                *[
                    nn.Conv2d((i + 2) * G0, G0, 1, padding=0, stride=1),
                    nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1),
                    self.act,
                    nn.Conv2d(64, 3, 3, padding=1)
                ]
            ) for i in range(T)]
        )

        self.Fe_f = nn.ModuleList(
            [nn.Sequential(
                *[
                    nn.Conv2d((2 * i + 3) * G0, G0, 1, padding=0, stride=1)
                ]
            ) for i in range(T - 1)]
        )

        # 纹理重构模块
        self.eta = nn.ParameterList([nn.Parameter(torch.tensor(0.5)) for _ in range(T)])
        self.delta = nn.ParameterList([nn.Parameter(torch.tensor(0.1)) for _ in range(T)])
        self.conv_up = ConvUp(3, self.up_factor)
        self.conv_down = ConvDown(3, self.up_factor)

        # candy算子不需要迭代内部系数
        self.candy = CandyNet(3)
        for para in self.candy.parameters():
            para.requires_grad = False

    def forward(self, y):  # [batch_size ,3 ,7 ,270 ,480] ;

        fea_list = []
        V_list = []
        outs = []
        x_texture = []
        x_texture.append(torch.nn.functional.interpolate(
            y, scale_factor=self.up_factor, mode='bilinear', align_corners=False))
        x_edge = self.candy(x_texture[0])
        x = (x_edge + x_texture[0])  # 这里可以增加一些倍数，直接相加可能会存在问题

        for i in range(len(self.Fe_e)):
            # --------------------denoising module------------------------
            fea = self.Fe_e[i](x_texture[i])
            fea_list.append(fea)
            if i != 0:
                fea = self.Fe_f[i - 1](torch.cat(fea_list, 1))
            encode0, down0 = self.Encoding_block1(fea)
            encode1, down1 = self.Encoding_block2(down0)
            encode2, down2 = self.Encoding_block3(down1)
            encode3, down3 = self.Encoding_block4(down2)

            media_end = self.Encoding_block_end(down3)

            decode3 = self.Decoding_block1(media_end, encode3)
            decode2 = self.Decoding_block2(decode3, encode2)
            decode1 = self.Decoding_block3(decode2, encode1)
            decode0 = self.feature_decoding_end(decode1, encode0)

            fea_list.append(decode0)
            V_list.append(decode0)
            if i == 0:
                decode0 = self.construction(self.act(decode0))
            else:
                decode0 = self.RNNF[i - 1](torch.cat(V_list, 1))
            v = x_texture[i] + decode0

            # # --------------------texture module--------------------------
            x_texture.append(x_texture[i] - self.delta[i] * (
                    self.conv_up(self.conv_down(x) - y) + self.eta[i] * (x - v)))

            # # -----------------------edge module--------------------------
            x_edge = (self.candy(x))  # 这里对代码进行了置换
            x = x_edge + x_texture[i + 1]  # 这里可以增加一些倍数，直接相加可能会存在问题
            #
            outs.append(x)

        return x

In [468]:
class args:
    def __init__(self):
        self.n_colors = 3
        self.scale = [2]
        self.patch_size = 1
        self.batch_size = 1
        self.n_GPUs = 1

In [469]:
test = args()
test.n_colors

3

In [470]:
model = E_DUN(test)
model

E_DUN(
  (Encoding_block1): EncodingBlock(
    (body): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (3): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): Conv2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (act): ReLU()
  )
  (Encoding_block2): EncodingBlock(
    (body): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), paddin

In [471]:
y = model(x)

In [472]:
y.shape

torch.Size([1, 3, 40, 40])

In [473]:
import torch.nn as nn

torch.autograd.set_detect_anomaly(True)
loss_function = nn.L1Loss()

In [474]:
z = torch.rand(1, 3, 40, 40)
loss = loss_function(y, z)

In [475]:

loss.backward()

In [None]:
loss


In [None]:
loss.backward()

In [453]:
import numpy as np
import torch
import torch.nn as nn
from scipy import signal as signal


class CandyNet(nn.Module):
    def __init__(self, threshold=10.0, use_cuda=False):
        super(CandyNet, self).__init__()

        self.threshold = threshold
        self.use_cuda = use_cuda

        filter_size = 5
        generated_filters = signal.gaussian(filter_size, std=1.0).reshape([1, filter_size])

        self.gaussian_filter_horizontal = nn.Conv2d(1, 1, kernel_size=(1, filter_size), padding=(0, filter_size // 2))
        self.gaussian_filter_horizontal.weight.data.copy_(torch.from_numpy(generated_filters))
        self.gaussian_filter_horizontal.bias.data.copy_(torch.from_numpy(np.array([0.0])))

        self.gaussian_filter_vertical = nn.Conv2d(1, 1, kernel_size=(filter_size, 1), padding=(filter_size // 2, 0))
        self.gaussian_filter_vertical.weight.data.copy_(torch.from_numpy(generated_filters.T))
        self.gaussian_filter_vertical.bias.data.copy_(torch.from_numpy(np.array([0.0])))

        sobel_filter = np.array([[1, 0, -1],
                                 [2, 0, -2],
                                 [1, 0, -1]])

        self.sobel_filter_horizontal = nn.Conv2d(1, 1, kernel_size=sobel_filter.shape,
                                                 padding=sobel_filter.shape[0] // 2)
        self.sobel_filter_horizontal.weight.data.copy_(torch.from_numpy(sobel_filter))
        self.sobel_filter_horizontal.bias.data.copy_(torch.from_numpy(np.array([0.0])))

        self.sobel_filter_vertical = nn.Conv2d(1, 1, kernel_size=sobel_filter.shape, padding=sobel_filter.shape[0] // 2)
        self.sobel_filter_vertical.weight.data.copy_(torch.from_numpy(sobel_filter.T))
        self.sobel_filter_vertical.bias.data.copy_(torch.from_numpy(np.array([0.0])))

        # filters were flipped manually
        filter_0 = np.array([[0, 0, 0],
                             [0, 1, -1],
                             [0, 0, 0]])

        filter_45 = np.array([[0, 0, 0],
                              [0, 1, 0],
                              [0, 0, -1]])

        filter_90 = np.array([[0, 0, 0],
                              [0, 1, 0],
                              [0, -1, 0]])

        filter_135 = np.array([[0, 0, 0],
                               [0, 1, 0],
                               [-1, 0, 0]])

        filter_180 = np.array([[0, 0, 0],
                               [-1, 1, 0],
                               [0, 0, 0]])

        filter_225 = np.array([[-1, 0, 0],
                               [0, 1, 0],
                               [0, 0, 0]])

        filter_270 = np.array([[0, -1, 0],
                               [0, 1, 0],
                               [0, 0, 0]])

        filter_315 = np.array([[0, 0, -1],
                               [0, 1, 0],
                               [0, 0, 0]])

        all_filters = np.stack(
            [filter_0, filter_45, filter_90, filter_135, filter_180, filter_225, filter_270, filter_315])

        self.directional_filter = nn.Conv2d(1, 8, kernel_size=filter_0.shape, padding=filter_0.shape[-1] // 2)
        self.directional_filter.weight.data.copy_(torch.from_numpy(all_filters[:, None, ...]))
        self.directional_filter.bias.data.copy_(torch.from_numpy(np.zeros(shape=(all_filters.shape[0],))))

    def forward(self, img):  # (batch,channel,height, width)

        batch = img.shape[0]
        img_r = img[:, 0:1]  # batch,1,height, width
        img_g = img[:, 1:2]  # batch,1,height, width
        img_b = img[:, 2:3]  # batch,1,height, width

        blur_horizontal_r = self.gaussian_filter_horizontal(img_r)  # batch,1,height,width
        blurred_img_r = self.gaussian_filter_vertical(blur_horizontal_r)  # batch,1,height,width
        blur_horizontal_g = self.gaussian_filter_horizontal(img_g)  # batch,1,height,width
        blurred_img_g = self.gaussian_filter_vertical(blur_horizontal_g)  # batch,1,height,width
        blur_horizontal_b = self.gaussian_filter_horizontal(img_b)  # batch,1,height,width
        blurred_img_b = self.gaussian_filter_vertical(blur_horizontal_b)  # batch,1,height,width

        blurred_img_ = torch.stack([blurred_img_r, blurred_img_g, blurred_img_b], dim=1)  # batch,1,height,width
        blurred_img = torch.stack([torch.squeeze(blurred_img_)])  # batch,1,height,width

        grad_x_r = self.sobel_filter_horizontal(blurred_img_r)  # batch,1,height,width
        grad_y_r = self.sobel_filter_vertical(blurred_img_r)  # batch,1,height,width
        grad_x_g = self.sobel_filter_horizontal(blurred_img_g)  # batch,1,height,width
        grad_y_g = self.sobel_filter_vertical(blurred_img_g)  # batch,1,height,width
        grad_x_b = self.sobel_filter_horizontal(blurred_img_b)  # batch,1,height,width
        grad_y_b = self.sobel_filter_vertical(blurred_img_b)  # batch,1,height,width

        # COMPUTE THICK EDGES
        grad_mag_1 = torch.sqrt(grad_x_r ** 2 + grad_y_r ** 2)  # batch,1,height,width
        grad_mag_2 = grad_mag_1 + torch.sqrt(grad_x_g ** 2 + grad_y_g ** 2)  # batch,1,height,width
        grad_mag = grad_mag_2 + torch.sqrt(grad_x_b ** 2 + grad_y_b ** 2)  # batch,1,height,width
        grad_orientation_1 = (  # batch,1,height,width
                torch.atan2(grad_y_r + grad_y_g + grad_y_b, grad_x_r + grad_x_g + grad_x_b) * (180.0 / 3.14159))
        grad_orientation_2 = grad_orientation_1 + 180.0  # batch,1,height,width
        grad_orientation = torch.round(grad_orientation_2 / 45.0) * 45.0  # batch,1,height,width

        # THIN EDGES (NON-MAX SUPPRESSION)

        all_filtered = self.directional_filter(grad_mag)  # batch,8,height,width
        inidices_positive = (grad_orientation / 45) % 8  # batch,1,height,width
        inidices_negative = ((grad_orientation / 45) + 4) % 8  # batch,1,height,width

        height = inidices_positive.size()[2]
        width = inidices_positive.size()[3]
        pixel_count = height * width

        pixel_range = torch.FloatTensor([range(pixel_count)])  # batch,pixel_range
        if self.use_cuda:
            pixel_range = torch.cuda.FloatTensor([range(pixel_count)])

        indices = (  # batch,pixel_range
                inidices_positive.view(
                    inidices_positive.shape[0],
                    pixel_count).data * pixel_count + pixel_range)

        channel_select_filtered_positive = torch.ones(batch, 1, height, width)  # batch, 1, height, width
        for i in range(batch):
            channel_select_filtered_positive_temp = all_filtered[i].view(-1)[indices[i].long()].view(1, height, width)
            channel_select_filtered_positive[i] = channel_select_filtered_positive_temp

        indices = (  # batch,pixel_range
                inidices_negative.view(
                    inidices_negative.shape[0],
                    pixel_count).data * pixel_count + pixel_range)

        channel_select_filtered_negative = torch.ones(batch, 1, height, width)  # batch, 1, height, width
        for i in range(batch):
            channel_select_filtered_negative_temp = all_filtered[i].view(-1)[indices[i].long()].view(1, height, width)
            channel_select_filtered_negative[i] = channel_select_filtered_negative_temp

        channel_select_filtered = torch.stack(  # batch, 2, height, width
            [channel_select_filtered_positive, channel_select_filtered_negative], dim=1)

        is_max = channel_select_filtered.min(dim=1)[0] > 0.0

        thin_edges = grad_mag.clone()
        thin_edges[is_max == 0] = 0.0

        # THRESHOLD

        thresholded = thin_edges.clone()
        thresholded[thin_edges < self.threshold] = 0.0

        early_threshold = grad_mag.clone()
        early_threshold[grad_mag < self.threshold] = 0.0

        # assert grad_mag.size() == grad_orientation.size() == thin_edges.size() == thresholded.size() == early_threshold.size()

        return thresholded


if __name__ == '__main__':
    CandyNet()


In [454]:
candy = CandyNet(3)
y = candy(x)
loss = loss_function(y, x)

  return F.l1_loss(input, target, reduction=self.reduction)


In [455]:
loss.backward()