In [26]:
import torch
import torch.nn as nn
import functools
import torch.nn.functional as F
from collections import OrderedDict
# from .src.model.loss import VGG19


class VGG19(torch.nn.Module):
    """Frozen, pretrained VGG-19 that returns a dict of intermediate activations.

    The convolutional trunk ``vgg19.features[0:36]`` is split into 16 sequential
    slices named ``relu1_1`` ... ``relu5_4``; ``forward`` runs them in order and
    returns every slice output. All parameters have ``requires_grad=False``.

    NOTE(review): ``models`` is not imported anywhere in the visible notebook;
    this class needs e.g. ``from torchvision import models`` in scope. Newer
    torchvision versions deprecate ``pretrained=True`` in favour of ``weights=``.

    Args:
        stage: when ``1`` the classifier head is also built and run, and its
            output is returned under ``'linear_3'``; otherwise ``'linear_3'`` is
            ``None`` and no classifier modules are registered.
    """

    # (attribute name, start, stop) half-open index ranges into ``vgg19.features``.
    # Presumably each slice ends on the ReLU it is named after in torchvision's
    # VGG-19 layout — confirm if the torchvision version changes.
    _FEATURE_SLICES = (
        ('relu1_1', 0, 2),
        ('relu1_2', 2, 4),
        ('relu2_1', 4, 7),
        ('relu2_2', 7, 9),
        ('relu3_1', 9, 12),
        ('relu3_2', 12, 14),
        ('relu3_3', 14, 16),  # was wrongly appended to relu3_2, leaving relu3_3 empty
        ('relu3_4', 16, 18),
        ('relu4_1', 18, 21),
        ('relu4_2', 21, 23),
        ('relu4_3', 23, 25),
        ('relu4_4', 25, 27),
        ('relu5_1', 27, 30),
        ('relu5_2', 30, 32),
        ('relu5_3', 32, 34),
        ('relu5_4', 34, 36),
    )

    def __init__(self, stage):
        super(VGG19, self).__init__()
        vgg19 = models.vgg19(pretrained=True)
        features = vgg19.features
        self.stage = stage

        for name, start, stop in self._FEATURE_SLICES:
            block = torch.nn.Sequential()
            for idx in range(start, stop):
                # Child names are the original feature indices, so state_dict keys
                # look like "relu1_1.0.weight" exactly as before.
                block.add_module(str(idx), features[idx])
            setattr(self, name, block)

        if self.stage == 1:
            classifier = vgg19.classifier
            self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            self.class_other = torch.nn.Sequential()
            self.linear_3 = torch.nn.Sequential()
            for idx in range(0, 6):
                self.class_other.add_module(str(idx), classifier[idx])
            # Previously range(6, 6): an empty loop that left linear_3 as an identity.
            # Index 6 is presumably the final Linear of torchvision's classifier.
            for idx in range(6, 7):
                self.linear_3.add_module(str(idx), classifier[idx])

        # The network is used as a fixed feature extractor.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through every feature slice (and the head when ``stage == 1``).

        Returns:
            dict mapping ``'relu1_1'`` ... ``'relu5_4'`` to the corresponding slice
            outputs (in network order), plus ``'linear_3'`` which is the classifier
            output for ``stage == 1`` and ``None`` otherwise.
        """
        out = {}
        h = x
        for name, _, _ in self._FEATURE_SLICES:
            h = getattr(self, name)(h)
            out[name] = h

        linear_3 = None
        if self.stage == 1:
            # NOTE(review): the flattened size must match classifier[0]'s input,
            # which presumably requires 224x224 images — confirm for other sizes.
            pooled = self.maxpool(h)
            flat = pooled.view(pooled.shape[0], -1)
            linear_3 = self.linear_3(self.class_other(flat))
        out['linear_3'] = linear_3
        return out



class BaseNetwork(nn.Module):
    """``nn.Module`` base class that adds a configurable ``init_weights`` method."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self, init_type='xavier', gain=0.02):
        """Re-initialise Conv*/Linear and BatchNorm2d submodules in place.

        Conv*/Linear modules (matched by class name, so ConvTranspose is included)
        get their weight drawn by ``init_type`` and their bias, if any, zeroed.
        BatchNorm2d modules with affine parameters get weight ~ N(1.0, gain) and a
        zero bias; non-affine ones (``weight is None``) are skipped instead of
        crashing.

        Args:
            init_type: ``'normal'``, ``'xavier'``, ``'kaiming'`` or ``'orthogonal'``.
            gain: std for ``'normal'`` and BatchNorm2d weights; gain for
                ``'xavier'``/``'orthogonal'``; ignored by ``'kaiming'``.

        Raises:
            ValueError: if ``init_type`` is not one of the names above (it used to
                be silently ignored, leaving PyTorch's default initialisation).
        """
        initializers = {
            'normal': lambda w: nn.init.normal_(w, 0.0, gain),
            'xavier': lambda w: nn.init.xavier_normal_(w, gain=gain),
            'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
            'orthogonal': lambda w: nn.init.orthogonal_(w, gain=gain),
        }
        if init_type not in initializers:
            raise ValueError(
                'init_type must be one of {}, got {!r}'.format(sorted(initializers), init_type))
        init_weight = initializers[init_type]

        def init_func(m):
            classname = m.__class__.__name__
            weight = getattr(m, 'weight', None)
            if weight is None:
                # Modules without weights, or norm layers built with affine=False.
                return
            if 'Conv' in classname or 'Linear' in classname:
                init_weight(weight.data)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif 'BatchNorm2d' in classname:
                nn.init.normal_(weight.data, 1.0, gain)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)
        
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)`` with two reflection-padded 3x3 convs.

    ``F`` is: ReflectionPad(dilation) -> conv(dilation) -> InstanceNorm -> ReLU ->
    ReflectionPad(1) -> conv -> InstanceNorm [-> Dropout(0.5)]. Each pad width
    equals the following conv's dilation, so ``F`` keeps the spatial size and the
    channel count ``dim``.

    Args:
        dim: input and output channels.
        dilation: dilation (and reflection-pad width) of the first conv.
        use_spectral_norm: wrap both convs in spectral norm and drop their bias.
        use_dropout: append ``Dropout(0.5)`` at the end of ``F``.
    """

    def __init__(self, dim, dilation=1, use_spectral_norm=False, use_dropout=False):
        super(ResnetBlock, self).__init__()

        def make_conv(dil):
            # Bias is redundant when spectral norm is applied.
            conv = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0,
                             dilation=dil, bias=not use_spectral_norm)
            return spectral_norm(conv, use_spectral_norm)

        layers = [
            nn.ReflectionPad2d(dilation),
            make_conv(dilation),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            make_conv(1),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

def spectral_norm(module, mode=True):
    """Return ``module`` with spectral normalisation applied when ``mode`` is truthy.

    ``nn.utils.spectral_norm`` modifies ``module`` in place and returns it; with a
    falsy ``mode`` the module is returned unchanged.
    """
    if not mode:
        return module
    return nn.utils.spectral_norm(module)
    
class ImagineNet(BaseNetwork):
    """Encoder / dilated-residual bottleneck / decoder generator with outputs in [0, 1].

    ``encoder`` maps ``in_channels`` to 256 channels with two stride-2 convs,
    ``middle`` applies ``residual_blocks`` dilation-2 ``ResnetBlock``s, and
    ``outer`` upsamples with two stride-2 transposed convs before a 7x7 conv.
    The result is squashed to [0, 1] via ``(tanh(x) + 1) / 2``.

    Spatial size: each stride-2 conv gives ``ceil(n / 2)`` and each transposed
    conv gives ``2 * n``, so output size equals input size only when it is
    divisible by 4 (e.g. 66 -> 33 -> 17 -> 34 -> 68).

    Args:
        residual_blocks: number of residual blocks in ``middle``.
        init_weights: if true, call ``self.init_weights()`` with its defaults.
        in_channels: channels of the input tensor.
        out_channels: channels produced by the final conv. This argument used to
            be ignored (the conv was hard-coded to 3); the default keeps that.
        expand: unused; kept so existing call sites keep working.
    """

    def __init__(self, residual_blocks=8, init_weights=True, in_channels=3, out_channels=3, expand=True):
        super(ImagineNet, self).__init__()

        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=1, padding=0),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True),
        )

        # Bottleneck: residual blocks with dilation 2 on the 256-channel features.
        self.middle = nn.Sequential(*[ResnetBlock(256, 2) for _ in range(residual_blocks)])

        self.outer = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0),
        )
        self.tanh = nn.Tanh()

        if init_weights:
            self.init_weights()

    def forward(self, x):
        """Generate an image from ``x``; values are in [0, 1]."""
        x = self.encoder(x)
        x = self.middle(x)
        x = self.outer(x)
        # Map tanh's (-1, 1) range to (0, 1).
        return (self.tanh(x) + 1) / 2


class Dis_Imagine(BaseNetwork):
    """Five-layer fully convolutional discriminator.

    Every conv has kernel 4 and padding 1; the first three use stride 2 and the
    last two stride 1. ``conv1``-``conv4`` are followed by LeakyReLU(0.2) and
    ``conv5`` outputs a single channel of scores.

    Args:
        in_channels: channels of the input tensor.
        use_sigmoid: apply ``torch.sigmoid`` to the final map.
        use_spectral_norm: wrap every conv in spectral norm (and drop its bias).
        init_weights: if true, call ``self.init_weights()`` with its defaults.
    """

    def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True):
        super(Dis_Imagine, self).__init__()
        self.use_sigmoid = use_sigmoid

        def sn_conv(cin, cout, stride):
            layer = nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=4, stride=stride,
                              padding=1, bias=not use_spectral_norm)
            return spectral_norm(layer, use_spectral_norm)

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        self.conv1 = nn.Sequential(sn_conv(in_channels, 64, 2), leaky())
        self.conv2 = nn.Sequential(sn_conv(64, 128, 2), leaky())
        self.conv3 = nn.Sequential(sn_conv(128, 256, 2), leaky())
        self.conv4 = nn.Sequential(sn_conv(256, 512, 1), leaky())
        self.conv5 = nn.Sequential(sn_conv(512, 1, 1))

        if init_weights:
            self.init_weights()

    def forward(self, x):
        """Return ``(scores, [conv1, conv2, conv3, conv4, conv5])``.

        ``scores`` is the ``conv5`` map, passed through a sigmoid when
        ``use_sigmoid`` is set; the list holds each stage's output (``conv5``
        always pre-sigmoid).
        """
        feats = []
        h = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            h = layer(h)
            feats.append(h)

        last = feats[-1]
        scores = torch.sigmoid(last) if self.use_sigmoid else last
        return scores, feats
class SpiralNet(BaseNetwork):
    """Outpaints an image by repeatedly extending it left, up, right and down.

    Each stage grows the current canvas, in the order left/up/right/down
    (``direct`` 0/1/2/3), by up to ``each`` pixels on every side that has not yet
    reached the target bounds (0 on the left/top, ``wmax``/``hmax`` on the
    right/bottom). The new strip on a side is produced by ``self.G`` from the
    outermost ``each`` pixels of the canvas on that side.

    NOTE(review): ``Gen`` is not defined in the visible part of this notebook, so
    the contract of ``self.G(slice_in, styslice, inits, stg)`` is inferred only
    from its call site.

    Args:
        out_size: target bounds, read as ``(wmax, hmax)``.
        device: stored as ``self.device``; not used by the methods shown here.
        ks: stored as ``self.ks``; not used by the methods shown here.
        each: width in pixels of the strip added per side per stage.
    """

    def __init__(self, out_size, device, ks=3, each=16):
        super(SpiralNet, self).__init__()

        self.ks, self.out_size, self.device, self.each = ks, out_size, device, each

        self.G = Gen(inc=3, outc=3, mid=64)

        self.init_weights()
        self.strid = self.each  # must divide by 4

        self.wmax, self.hmax = self.out_size[0], self.out_size[1]

    def cat_sf(self, x, sfdata, p, direct):
        """Crop a region of ``sfdata`` with the same height/width as ``x``.

        For ``direct`` 0/1 (left/up) ``p = (px, py)`` is the crop's top-left
        corner; for 2/3 (right/down) it is the bottom-right corner.

        NOTE(review): despite the name, only the crop is returned (as before); the
        concatenation with ``x`` that used to be computed was never used. Confirm
        whether ``G`` was meant to receive the concatenated tensor.

        Raises:
            ValueError: if ``direct`` is not 0-3.
        """
        h, w = x.shape[2], x.shape[3]
        px, py = p[0], p[1]
        if direct in (0, 1):
            return sfdata[:, :, py:py + h, px:px + w]
        if direct in (2, 3):
            return sfdata[:, :, py - h:py, px - w:px]
        raise ValueError('direct must be 0, 1, 2 or 3, got {}'.format(direct))

    def sliceOperator_1(self, x, direct):
        """Return the outermost ``self.each``-wide strip of ``x`` on side ``direct``.

        ``direct`` is 0/1/2/3 for left/up/right/down; other values return None.
        """
        e = self.each
        if direct == 0:
            return x[:, :, :, :e]
        if direct == 1:
            return x[:, :, :e, :]
        if direct == 2:
            return x[:, :, :, -e:]
        if direct == 3:
            return x[:, :, -e:, :]
        return None

    def extrapolateOperator(self, x, direct, conv_data, loc):
        """Attach part of the generated strip ``conv_data`` to side ``direct`` of ``x``.

        At most ``self.each`` columns/rows are attached; fewer when the previous
        edge ``loc`` is closer than that to the border (0, ``wmax`` or ``hmax``).
        Left/up prepend the leading columns/rows of ``conv_data``; right/down
        append its trailing ones. Any other ``direct`` returns ``x`` unchanged.
        """
        cat_edge = self.each
        if direct == 0:
            if loc < self.each:
                cat_edge = loc
            x = torch.cat((conv_data[:, :, :, :cat_edge], x), dim=3)
        if direct == 1:
            if loc < self.each:
                cat_edge = loc
            x = torch.cat((conv_data[:, :, :cat_edge, :], x), dim=2)
        if direct == 2:
            if loc > (self.wmax - self.each):
                cat_edge = self.wmax - loc
            x = torch.cat((x, conv_data[:, :, :, -cat_edge:]), dim=3)
        if direct == 3:
            if loc > (self.hmax - self.each):
                cat_edge = self.hmax - loc
            x = torch.cat((x, conv_data[:, :, -cat_edge:, :]), dim=2)

        return x

    def sliceOperator_from(self, x, direct, gt, coord):
        """Crop from ``gt`` the strip being generated on side ``direct``.

        ``coord`` is ``[x1, y1, x2, y2]`` after the side has been moved outward.
        The strip is ``self.each`` wide across the growth direction and matches
        ``x`` along it.
        """
        h, w = x.shape[2], x.shape[3]
        x1, y1, x2, y2 = coord[0], coord[1], coord[2], coord[3]
        if direct == 0:
            return gt[:, :, y1:y1 + h, x1:x1 + self.each]
        if direct == 1:
            return gt[:, :, y1:y1 + self.each, x1:x1 + w]
        if direct == 2:
            return gt[:, :, y2 - h:y2, x2 - self.each:x2]
        if direct == 3:
            return gt[:, :, y2 - self.each:y2, x1:x1 + w]

    def sliceGenerator(self, stg, coord, inits, subimage, fm, each, direct=0, post_train=False, gt=None):
        """Grow ``subimage`` by one strip on side ``direct`` if it is not yet at the border.

        Args:
            stg: base stage indicator list passed to ``G`` (see note below).
            coord: ``[x1, y1, x2, y2]`` of the current canvas; mutated in place.
            inits: original input, forwarded to ``G``.
            subimage: current canvas.
            fm: tensor that ``cat_sf`` crops the strip's conditioning input from.
            each: step in pixels by which the side moves outward.
            direct: 0/1/2/3 for left/up/right/down.
            post_train: if false, the style slice is cropped from ``gt``;
                otherwise it is taken from the canvas strip itself.
            gt: ground-truth tensor, required when ``post_train`` is false. Was
                previously referenced without being defined (NameError).

        Returns:
            ``(canvas, a, b)`` where ``(a, b)`` is ``(x1, y1)`` for left/up and
            ``(x2, y2)`` for right/down, after any growth.

        Raises:
            ValueError: for an invalid ``direct``, or when growth is needed with
                ``post_train=False`` and ``gt`` is None.
        """
        if direct not in (0, 1, 2, 3):
            raise ValueError('direct must be 0, 1, 2 or 3, got {}'.format(direct))

        iner = subimage
        # Left/up report and anchor on the top-left corner; right/down on bottom-right.
        ax, ay = (0, 1) if direct < 2 else (2, 3)

        if direct < 2:
            can_grow = coord[direct] > 0
        else:
            limit = self.wmax if direct == 2 else self.hmax
            can_grow = coord[direct] < limit
        if not can_grow:
            return iner, coord[ax], coord[ay]

        loc = coord[direct]
        if direct < 2:
            coord[direct] -= each
            coord[direct] = coord[direct] if coord[direct] > 0 else 0
        else:
            coord[direct] += each
            coord[direct] = coord[direct] if coord[direct] < limit else limit

        onehot = [0, 0, 0, 0]
        onehot[direct] = 1
        # NOTE(review): this concatenates the lists (8 entries, first four zero as
        # called from forward) instead of adding element-wise. Kept unchanged
        # because G's expected format is not visible — confirm.
        stg = stg + onehot

        edge = self.sliceOperator_1(iner, direct)
        if not post_train:
            if gt is None:
                raise ValueError('gt is required when post_train is False')
            styslice = self.sliceOperator_from(edge, direct, gt, coord)
        else:
            styslice = self.sliceOperator_1(edge, direct)
        slice_in = self.cat_sf(edge, fm, (coord[ax], coord[ay]), direct)
        generated = self.G(slice_in, styslice, inits, stg)
        iner = self.extrapolateOperator(iner, direct, generated, loc)
        return iner, coord[ax], coord[ay]

    def forward(self, x, gt, fm, position, stage, post_train=False):
        """Run ``stage`` rounds of left/up/right/down outpainting starting from ``x``.

        Args:
            x: initial canvas.
            gt: ground truth used for style slices when ``post_train`` is false.
            fm: tensor passed to ``cat_sf`` for the conditioning crops.
            position: ``((x1, y1), (x2, y2))`` of ``x`` inside the target canvas;
                presumably scalar tensors since ``.item()`` is called — confirm.
            stage: number of outpainting rounds (converted with ``int``).
            post_train: forwarded to ``sliceGenerator``.

        Returns:
            ``(canvas, coord_all)`` where ``coord_all`` holds ``[x1, y1, x2, y2]``
            after every individual side step (four per stage).
        """
        x1, y1 = position[0][0].item(), position[0][1].item()
        x2, y2 = position[1][0].item(), position[1][1].item()
        inits = x
        coord_all = []

        for _ in range(int(stage)):
            gen = x
            stg = [0, 0, 0, 0]
            for direct in range(4):
                gen, a, b = self.sliceGenerator(stg, [x1, y1, x2, y2], inits, gen, fm, self.each,
                                                direct=direct, post_train=post_train, gt=gt)
                if direct < 2:
                    x1, y1 = a, b
                else:
                    x2, y2 = a, b
                coord_all.append([x1, y1, x2, y2])

            x = gen

        return x, coord_all

In [28]:
# Build the discriminator for 6-channel inputs; the cell's last expression displays its module tree.
Dis_Imagine(in_channels=6)

Dis_Imagine(
  (conv1): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv5): Sequential(
    (0): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
  )
)

In [33]:
from torchsummary import summary
net = ImagineNet()
# Per-layer output shapes and parameter counts for a 3x66x66 input. Since 66 is
# not divisible by 4, the generator's output will be 68x68 rather than 66x66.
summary(net, (3, 66, 66))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ReflectionPad2d-1            [-1, 3, 72, 72]               0
            Conv2d-2           [-1, 64, 66, 66]           9,472
    InstanceNorm2d-3           [-1, 64, 66, 66]               0
              ReLU-4           [-1, 64, 66, 66]               0
            Conv2d-5          [-1, 128, 33, 33]          73,856
    InstanceNorm2d-6          [-1, 128, 33, 33]               0
              ReLU-7          [-1, 128, 33, 33]               0
            Conv2d-8          [-1, 256, 17, 17]         295,168
    InstanceNorm2d-9          [-1, 256, 17, 17]               0
             ReLU-10          [-1, 256, 17, 17]               0
  ReflectionPad2d-11          [-1, 256, 21, 21]               0
           Conv2d-12          [-1, 256, 17, 17]         590,080
   InstanceNorm2d-13          [-1, 256, 17, 17]               0
             ReLU-14          [-1, 256,