In [1]:
%run Net_block.ipynb

class Net(object):
    """Static helpers that parse an architecture spec and initialize weights."""

    @staticmethod
    def parse_arch(arch):
        """Split an architecture spec into input dims, layer keys and layer params.

        Parameters:
            arch (sequence) -- entries indexable as (key, params). Entry 0
                               describes the input; only its second element
                               is read. Every later entry describes one layer.
        Returns:
            (input_dims, layer_list, param_list) where the two lists hold the
            first and second element of each entry after the first.
        """
        input_dims = arch[0][1]
        body = arch[1:]
        layer_list = [entry[0] for entry in body]
        param_list = [entry[1] for entry in body]
        return input_dims, layer_list, param_list

    @staticmethod
    def parse_layers(arch):
        """Instantiate every layer named in `arch`.

        Each layer key is looked up in `Block_mapping.module_mapping` and the
        resulting callable is invoked with that layer's params.

        Returns:
            (input_dims, layers) with `layers` in the same order as `arch[1:]`.
        """
        input_dims, layer_list, param_list = Net.parse_arch(arch)
        layers = [
            Block_mapping.module_mapping[key](params)
            for key, params in zip(layer_list, param_list)
        ]
        return input_dims, layers

    @staticmethod
    def init_weights(net, init_type='normal', init_gain=0.02):
        """Initialize network weights.
        Parameters:
            net (network)   -- network to be initialized
            init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
            init_gain (float)    -- scaling factor for normal, xavier and orthogonal.
        Raises:
            NotImplementedError -- if `init_type` is unknown and a Conv/Linear
                                   module with a weight is encountered.
        """
        def init_func(m):
            classname = m.__class__.__name__
            has_weight = getattr(m, 'weight', None) is not None
            if has_weight and ('Conv' in classname or 'Linear' in classname):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, init_gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=init_gain)
                elif init_type == 'kaiming':
                    # kaiming does not take init_gain
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=init_gain)
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif 'BatchNorm2d' in classname:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
                # With affine=False the layer has weight/bias set to None,
                # so each one is only touched when it actually exists.
                if has_weight:
                    nn.init.normal_(m.weight.data, 1.0, init_gain)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        print('initialize network with %s' % init_type)
        net.apply(init_func)  # apply the initialization function <init_func>
    
class Network_template(nn.Module):
    """Sequential network assembled from a list of modules.

    Parameters:
        ngpu (int)      -- stored on the instance; when greater than 1 and the
                           input is a CUDA tensor, the forward pass is run
                           through `nn.parallel.data_parallel` with device ids
                           `range(ngpu)`.
        arch (iterable) -- modules chained, in order, into `self.main`.
    """

    def __init__(self, ngpu, arch):
        super().__init__()
        self.ngpu = ngpu
        stacked = nn.ModuleList(arch)
        self.main = nn.Sequential(*stacked)

    def forward(self, input):
        """Run `input` through `self.main` and return the result unsqueezed."""
        multi_gpu = input.is_cuda and self.ngpu > 1
        if not multi_gpu:
            return self.main(input)
        device_ids = range(self.ngpu)
        return nn.parallel.data_parallel(self.main, input, device_ids)
    
class StandardLoss(nn.Module):
    """Supervised loss wrapper; only the 'xentropy' mode is implemented.

    Parameters:
        mode (str)      -- loss name; 'xentropy' selects `nn.CrossEntropyLoss`.
        reduction (str) -- forwarded to the criterion's `reduction` argument.
    Raises:
        NotImplementedError -- if `mode` is not a supported loss name.
    """

    def __init__(self, mode, reduction):
        super(StandardLoss, self).__init__()
        self.loss_mode = mode
        if mode == 'xentropy':
            self.criterion = nn.CrossEntropyLoss(reduction=reduction)
        else:
            raise NotImplementedError('loss mode %s not implemented' % mode)

    def xentropy_loss(self, params):
        """Run the network on the data and score the predictions.

        Parameters:
            params (sequence) -- exactly three items:
                                 [0] net, [1] Dtrain, [2] labels.
        Returns:
            The criterion applied to `net(Dtrain)` and `labels`.
        """
        assert len(params) == 3, 'expected params = (net, Dtrain, labels)'
        net, inputs, labels = params[0], params[1], params[2]
        predictions = net(inputs)
        return self.criterion(predictions, labels)

    def __call__(self, params):
        """Dispatch to the loss selected by `self.loss_mode`.

        This overrides `nn.Module.__call__` directly rather than defining
        `forward`.

        Raises:
            NotImplementedError -- if `self.loss_mode` is not a supported mode
                                   (for example, changed after construction).
        """
        if self.loss_mode == 'xentropy':
            return self.xentropy_loss(params)
        # Without this branch the original fell through to an unbound local.
        raise NotImplementedError('loss mode %s not implemented' % (self.loss_mode,))
    
class GANLoss(nn.Module):
    """Adversarial loss dispatcher.

    Parameters:
        gan_mode (str)  -- one of 'lsgan' (MSE criterion), 'vanilla' or
                           'vanilla_topo' (BCE-with-logits criterion),
                           'wgan' or 'wgangp' (no criterion object).
        reduction (str) -- forwarded to the criterion's `reduction` argument.
    Raises:
        NotImplementedError -- if `gan_mode` is not one of the names above.

    Every loss method takes a single `params` sequence. For the adversarial
    modes its first item is "G" or "D":
        "G": (tag, netD, Dfake)
        "D": (tag, netD, device, Dreal, Dfake, ...)
    """

    def __init__(self, gan_mode, reduction):
        super(GANLoss, self).__init__()
        # Registered as buffers, so they are part of the module's state.
        self.register_buffer('real_label', torch.tensor(1.0))
        self.register_buffer('fake_label', torch.tensor(0.0))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.criterion = nn.MSELoss(reduction=reduction)
        elif gan_mode in ('vanilla', 'vanilla_topo'):
            self.criterion = nn.BCEWithLogitsLoss(reduction=reduction)
        elif gan_mode in ('wgangp', 'wgan'):
            pass  # Wasserstein losses are computed directly from critic outputs.
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, is_real):
        """Return the real or fake label expanded to the shape of `prediction`."""
        target_tensor = self.real_label if is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def calc_gradient_penalty(self, netD, device, Dreal, Dfake, constant=1.0, lambda_gp=10.0):
        """Gradient penalty on random interpolations between `Dreal` and `Dfake`.

        Parameters:
            netD (callable)   -- critic evaluated on the interpolated samples.
            device            -- device on which the mixing weights are drawn.
            Dreal, Dfake      -- batches with the batch on dimension 0; any
                                 number of trailing dimensions is accepted.
            constant (float)  -- target value for the per-sample gradient norm.
            lambda_gp (float) -- multiplier applied to the penalty.
        Returns:
            lambda_gp * mean((||grad|| - constant) ** 2) over the batch.
        """
        batch_size = Dreal.shape[0]
        # One mixing weight per sample, broadcast over all remaining dims.
        # The shape follows the input rank instead of assuming 4-D batches.
        alpha_shape = [batch_size] + [1] * (Dreal.dim() - 1)
        alpha = torch.rand(alpha_shape, device=device)
        interpolated = alpha * Dreal + (1 - alpha) * Dfake
        interpolated.requires_grad_(True)
        out_interp = netD(interpolated)
        # ones_like matches the critic output's device and dtype.
        gradients = torch.autograd.grad(outputs=out_interp, inputs=interpolated,
                                        grad_outputs=torch.ones_like(out_interp),
                                        create_graph=True, retain_graph=True)[0]
        gradients = gradients.reshape(batch_size, -1)
        # The epsilon keeps sqrt differentiable when the gradient is zero.
        grad_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
        gp_penalty = ((grad_norm - constant) ** 2).mean()
        return gp_penalty * lambda_gp

    def vanilla_topo_loss(self, params):
        """Criterion between two image batches flattened over their last two dims.

        params: Dfake_device, Dfix_device
        both Dfake_device and Dfix_device should be output from tanh() layer,
        which have values between -1.0 and 1.0.
        """
        # NOTE(review): the criterion here is BCEWithLogitsLoss, whose targets
        # are documented as values in [0, 1]; confirm that tanh-range targets
        # are intended.
        image_shape = list(params[0].shape)[-2:]
        pixels = int(np.prod(image_shape))
        flat_fake = params[0].view(-1, pixels)
        flat_fix = params[1].view(-1, pixels)
        return self.criterion(flat_fake, flat_fix)

    def wgan_gp_loss(self, params):
        """Wasserstein loss; the "D" variant adds the gradient penalty.

        Raises:
            NotImplementedError -- if params[0] is neither "G" nor "D".
        """
        loss = self.wgan_loss(params)
        if params[0] == "D":
            # params[1]: netD, params[2]: device, params[3]: Dreal, params[4]: Dfake
            loss = loss + self.calc_gradient_penalty(
                params[1], params[2], params[3], params[4].detach())
        return loss

    def wgan_loss(self, params):
        """Wasserstein loss without gradient penalty.

        "G": -mean(netD(Dfake)).
        "D": mean(netD(Dfake.detach())) - mean(netD(Dreal)).

        Raises:
            NotImplementedError -- if params[0] is neither "G" nor "D".
        """
        if params[0] == "G":
            # if G params[1]: netD, params[2]: Dfake
            assert len(params) == 3, 'expected params = ("G", netD, Dfake)'
            out_fake = params[1](params[2])
            return -out_fake.mean()
        elif params[0] == "D":
            # if D params[1]: netD, params[2]: device, params[3]: Dreal, params[4]: Dfake
            assert len(params) >= 5, 'expected params = ("D", netD, device, Dreal, Dfake)'
            out_real = params[1](params[3])
            out_fake = params[1](params[4].detach())
            return out_fake.mean() - out_real.mean()
        else:
            raise NotImplementedError('Unrecognized network %s' % (params[0],))

    def gan_loss(self, params):
        """Criterion-based GAN loss (used by the 'lsgan' and 'vanilla' modes).

        "G": criterion(netD(Dfake), real label).
        "D": criterion(netD(Dreal), real label)
             + criterion(netD(Dfake.detach()), fake label).

        Raises:
            NotImplementedError -- if params[0] is neither "G" nor "D".
        """
        if params[0] == "G":
            # if G params[1]: netD, params[2]: Dfake
            assert len(params) == 3, 'expected params = ("G", netD, Dfake)'
            out_fake = params[1](params[2])
            target_tensor = self.get_target_tensor(out_fake, True)
            return self.criterion(out_fake, target_tensor)
        elif params[0] == "D":
            # if D params[1]: netD, params[2]: device, params[3]: Dreal, params[4]: Dfake
            assert len(params) >= 5, 'expected params = ("D", netD, device, Dreal, Dfake)'
            out_real = params[1](params[3])
            out_fake = params[1](params[4].detach())
            err_real = self.criterion(out_real, self.get_target_tensor(out_real, True))
            err_fake = self.criterion(out_fake, self.get_target_tensor(out_fake, False))
            return err_real + err_fake
        else:
            raise NotImplementedError('Unrecognized network %s' % (params[0],))

    def __call__(self, params):
        """Dispatch to the loss method selected by `self.gan_mode`.

        This overrides `nn.Module.__call__` directly rather than defining
        `forward`.

        Raises:
            NotImplementedError -- if `self.gan_mode` is not a supported mode
                                   (for example, changed after construction).
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            return self.gan_loss(params)
        elif self.gan_mode == 'vanilla_topo':
            return self.vanilla_topo_loss(params)
        elif self.gan_mode == 'wgangp':
            return self.wgan_gp_loss(params)
        elif self.gan_mode == "wgan":
            return self.wgan_loss(params)
        # Without this branch the original fell through to an unbound local.
        raise NotImplementedError('gan mode %s not implemented' % (self.gan_mode,))