In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1 or  classname.find('InstanceNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def get_norm_layer(norm_type):
    if norm_type == 'batch':
        norm_layer = nn.BatchNorm2d
    elif norm_type == 'instance':
        norm_layer = nn.InstanceNorm2d
    else:
        print('normalization layer [%s] is not found' % norm)
    return norm_layer

# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids

        # currently support only input_nc == output_nc
        assert(input_nc == output_nc)

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        # embedded = self.embedding(input).view(1, 1, -1)
        # embedded.size
        if  self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc, affine=True)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc, affine=True)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([self.model(x), x], 1)


In [132]:
# Defines the Unet+MobileNet generator.

class UnetMobileNetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
        super(UnetMobileNetGenerator, self).__init__()
        self.gpu_ids = gpu_ids

        # currently support only input_nc == output_nc
        assert(input_nc == output_nc)

        # construct unet structure
        unet_block = UnetMobileNetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetMobileNetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetMobileNetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
        unet_block = UnetMobileNetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
        unet_block = UnetMobileNetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
        unet_block = UnetMobileNetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        # embedded = self.embedding(input).view(1, 1, -1)
        # embedded.size
        if  self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

# Defines the submodule with skip connection.
class UnetMobileNetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetMobileNetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        #downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
        #                     stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc, affine=True)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc, affine=True)

        if outermost: # Equal to call conv_bn
            downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=True)
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv, downnorm] # different from original setup
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost: # Equal to call conv_dw without submodule            
            #downconv_d = nn.Conv2d(outer_nc, outer_nc, kernel_size=4,
            #                 stride=2, padding=1)
            downconv_d = nn.Conv2d(outer_nc, outer_nc, kernel_size=4,
                             stride=2, padding=1, groups=outer_nc, bias=True)
            downconv_s = nn.Conv2d(outer_nc, inner_nc, kernel_size=1,
                             stride=1, padding=0, bias=True)
            
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            upconv_d = nn.ConvTranspose2d(inner_nc, inner_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, groups=inner_nc)
            upconv_s = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=1, stride=1,
                                        padding=0)
            upnorm_d = norm_layer(inner_nc, affine=True)
            
            #down = [downrelu, downconv_d, downnorm, downrelu, downconv_s, downnorm]
            down = [downrelu, downconv_d, downnorm, downrelu, downconv_s, downnorm]
            up = [uprelu, upconv_d, upnorm_d, uprelu, upconv_s, upnorm]
            #up = [uprelu, upconv, upnorm]
            model = down + up
        else: # Equal to call conv_dw            
            downconv_d = nn.Conv2d(outer_nc, outer_nc, kernel_size=4,
                             stride=2, padding=1, groups=outer_nc, bias=True)
            downconv_s = nn.Conv2d(outer_nc, inner_nc, kernel_size=1,
                             stride=1, padding=0, bias=True)
            downnorm_d = norm_layer(outer_nc, affine=True)
            
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            
            
            upconv_d = nn.ConvTranspose2d(inner_nc * 2, inner_nc * 2,
                                        kernel_size=4, stride=2,
                                        padding=1, groups=inner_nc * 2)
            upconv_s = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=1, stride=1,
                                        padding=0)
            upnorm_d = norm_layer(inner_nc * 2, affine=True)
            
            down = [downrelu, downconv_d, downnorm_d, downrelu, downconv_s, downnorm]
            up = [uprelu, upconv_d, upnorm_d, uprelu, upconv_s, upnorm]
            #up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def conv_bn(inp, oup, stride):
        return nn.Sequential(
            nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True)
        )

    def conv_dw(inp, oup, stride):
        return nn.Sequential(
            nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
    
            nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([self.model(x), x], 1)


In [21]:
# Defines the Unet+MobileNet V2 generator.

class UnetMobileNet2Generator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
        super(UnetMobileNet2Generator, self).__init__()
        self.gpu_ids = gpu_ids

        # currently support only input_nc == output_nc
        assert(input_nc == output_nc)

        # construct unet structure
        unet_block = UnetMobileNet2SkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
        
        #for i in range(num_downs - 5):
        #    unet_block = UnetMobileNet2SkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        #unet_block = UnetMobileNet2SkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
        #unet_block = UnetMobileNet2SkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
        #unet_block = UnetMobileNet2SkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
        #unet_block = UnetMobileNet2SkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        # embedded = self.embedding(input).view(1, 1, -1)
        # embedded.size
        if  self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

# Defines the submodule with skip connection.
class UnetMobileNet2SkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False, expand_ratio=6, 
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetMobileNet2SkipConnectionBlock, self).__init__()
        self.outermost = outermost

        downrelu = nn.ReLU6(inplace=True)
        downnorm = norm_layer(inner_nc, affine=True)
        uprelu = nn.ReLU6(inplace=True)
        upnorm = norm_layer(outer_nc, affine=True)

        if outermost: # Equal to call conv_bn
            downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=True)
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv, downnorm] # different from original setup
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] # + up
        elif innermost: # Equal to call InvertedResidual without submodule            
            
            # pw
            downconv_p = nn.Conv2d(outer_nc, outer_nc*expand_ratio, kernel_size=1,
                                  stride=1, padding=0, bias=False)
            downnorm_p = nn.BatchNorm2d(outer_nc*expand_ratio)
            # dw
            downconv_d = nn.Conv2d(outer_nc*expand_ratio, outer_nc*expand_ratio, kernel_size=4,
                                  stride=2, padding=1, groups=outer_nc*expand_ratio, bias=False)
            downnorm_d = nn.BatchNorm2d(outer_nc*expand_ratio)
            # pw-linear
            downconv_l = nn.Conv2d(outer_nc*expand_ratio, inner_nc, kernel_size=1,
                                  stride=1, padding=0, bias=False)
            downnorm_l = nn.BatchNorm2d(inner_nc)
            
            # pw
            upconv_p = nn.ConvTranspose2d(inner_nc, inner_nc*expand_ratio,
                                        kernel_size=1, stride=1,
                                        padding=0, bias=False)
            upnorm_p = nn.BatchNorm2d(inner_nc*expand_ratio)
            # dw
            upconv_d = nn.ConvTranspose2d(inner_nc*expand_ratio, inner_nc*expand_ratio,
                                         kernel_size=4, stride=2,
                                         padding=1, groups=inner_nc*expand_ratio, bias=False)
            upnorm_d = nn.BatchNorm2d(inner_nc*expand_ratio)
            # pw-linear
            upconv_l = nn.ConvTranspose2d(inner_nc*expand_ratio, outer_nc,
                                         kernel_size=1, stride=1,
                                         padding=0, bias=False)
            upnorm_l = nn.BatchNorm2d(outer_nc) # Normalization can be switched to the norm_layer
            
            down = [downrelu, downconv_p, downnorm_p, downrelu, downconv_d, downnorm_d,
                    downrelu, downconv_l, downnorm_l]
            up = [uprelu, upconv_p, upnorm_p, uprelu, upconv_d, upnorm_d, 
                  uprelu, upconv_l, upnorm_l]
            model = down # + up
        else: # Equal to call InvertedResidual
            
            # pw
            downconv_p = nn.Conv2d(outer_nc, outer_nc*expand_ratio, kernel_size=1,
                             stride=1, padding=0, bias=False)
            downnorm_p = nn.BatchNorm2d(outer_nc*expand_ratio)
            # dw
            downconv_d = nn.Conv2d(outer_nc*expand_ratio, outer_nc*expand_ratio, kernel_size=4,
                                   stride=2, padding=1, groups=outer_nc*expand_ratio, bias=False)
            downnorm_d = nn.BatchNorm2d(outer_nc*expand_ratio)
            # pw-linear
            downconv_l = nn.Conv2d(outer_nc*expand_ratio, inner_nc, kernel_size=1,
                             stride=1, padding=0, bias=False)
            downnorm_l = nn.BatchNorm2d(inner_nc)
            
            # pw
            upconv_p = nn.ConvTranspose2d(inner_nc*2, inner_nc*2*expand_ratio, kernel_size=1,
                                         stride=1, padding=0, bias=False)
            upnorm_p = nn.BatchNorm2d(inner_nc*2*expand_ratio)
            # dw
            upconv_d = nn.ConvTranspose2d(inner_nc*2*expand_ratio, inner_nc*2*expand_ratio, kernel_size=4,
                                          stride=2, padding=1, groups=inner_nc*2*expand_ratio, bias=False)
            upnorm_d = nn.BatchNorm2d(inner_nc*2*expand_ratio)
            # pw-linear
            upconv_l = nn.ConvTranspose2d(inner_nc*2*expand_ratio, outer_nc, kernel_size=1,
                                         stride=1, padding=0, bias=False)
            upnorm_l = nn.BatchNorm2d(outer_nc)
            
            down = [downrelu, downconv_p, downnorm_p, downrelu, downconv_d, downnorm_d, 
                    downrelu, downconv_l, downnorm_l]
            up = [uprelu, upconv_p, upnorm_p, uprelu, upconv_d, upnorm_d,
                  uprelu, upconv_l, upnorm_l]

            if use_dropout:
                model = down + [submodule] # + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] # + up

        self.model = nn.Sequential(*model)
    
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([self.model(x), x], 1)


In [5]:
# model_G = define_G(3, 3, 64, which_model_netG='unet_256')
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_mobilenet_256':
        netG = UnetMobileNetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_mobilenet2_256':
        netG = UnetMobileNet2Generator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    else:
        print('Generator model name [%s] is not recognized' % which_model_netG)
    if len(gpu_ids) > 0:
        netG.cuda(device_id=gpu_ids[0])
    netG.apply(weights_init)
    return netG


def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]):
    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    else:
        print('Discriminator model name [%s] is not recognized' %
              which_model_netD)
    if use_gpu:
        netD.cuda(device_id=gpu_ids[0])
    netD.apply(weights_init)
    return netD



In [6]:
def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)
    

In [7]:
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

In [24]:
model_G = define_G(3, 3, 64, which_model_netG='unet_256')

In [133]:
mobile_model_G = define_G(3, 3, 64, which_model_netG='unet_mobilenet_256')

In [45]:
print_network(model_G)

UnetGenerator(
  (model): UnetSkipConnectionBlock(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): UnetSkipConnectionBlock(
        (model): Sequential(
          (0): LeakyReLU(0.2, inplace)
          (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
          (3): UnetSkipConnectionBlock(
            (model): Sequential(
              (0): LeakyReLU(0.2, inplace)
              (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
              (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
              (3): UnetSkipConnectionBlock(
                (model): Sequential(
                  (0): LeakyReLU(0.2, inplace)
                  (1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
                  (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
                  

In [134]:
print_network(mobile_model_G)

UnetMobileNetGenerator(
  (model): UnetMobileNetSkipConnectionBlock(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (2): UnetMobileNetSkipConnectionBlock(
        (model): Sequential(
          (0): LeakyReLU(0.2, inplace)
          (1): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), groups=64)
          (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
          (3): LeakyReLU(0.2, inplace)
          (4): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
          (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
          (6): UnetMobileNetSkipConnectionBlock(
            (model): Sequential(
              (0): LeakyReLU(0.2, inplace)
              (1): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), groups=128)
              (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
              (3):

In [22]:
mobile2_model_G = define_G(3, 3, 64, which_model_netG='unet_mobilenet2_256')

In [23]:
print_network(mobile2_model_G)

UnetMobileNet2Generator(
  (model): UnetMobileNet2SkipConnectionBlock(
    (model): Sequential(
      (0): ReLU6(inplace)
      (1): Conv2d(512, 3072, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True)
      (3): ReLU6(inplace)
      (4): Conv2d(3072, 3072, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), groups=3072, bias=False)
      (5): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True)
      (6): ReLU6(inplace)
      (7): Conv2d(3072, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (8): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    )
  )
)
Total number of parameters: 3208192
