# Image-to-Image Translation with Conditional Adversarial Networks: Pix2Pix

`GPU` `Advanced` `Computer Vision` `End-to-End`



## Pix2Pix Overview

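Pix2Pix is a deep learning model for image-to-image translation built on conditional generative adversarial networks (cGAN, Conditional Generative Adversarial Networks). Phillip Isola et al. proposed it at CVPR 2017. It can translate semantic labels to photos, grayscale images to color images, aerial photos to maps, daytime scenes to nighttime scenes, and line drawings to photos of objects. Pix2Pix is a classic application of cGANs to supervised image-to-image translation, and it consists of two models: a **generator** and a **discriminator**.

A cGAN generator works differently from a traditional GAN generator. It takes the input image as guiding information and keeps trying to produce "fake" images from it that fool the discriminator. Turning an input image into the corresponding "fake" image is essentially a pixel-to-pixel mapping. A traditional GAN generator instead produces images from a given random noise vector, and its output can only be steered through other constraints. This is the difference between cGANs and GANs in image translation.

The Pix2Pix discriminator decides whether an image is a real training image or a "fake" produced by the generator. As the generator and discriminator compete, the model ideally reaches an equilibrium. At that point the generator's outputs are indistinguishable from real training data, so the discriminator is correct only 50% of the time.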

Before starting, we define the notation used throughout the tutorial:

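- $x$: the observed (input) image.
- $z$: random noise.
- $y$: the real target image paired with $x$ in the training data.
- $G(x,z)$: the generator network. It produces a "fake" image from the observed image $x$ and the random noise $z$, where $x$ comes from the training data rather than from the generator.
- $D(x,G(x,z))$: the discriminator network. It outputs the probability that an image is real, where $x$ comes from the training data and $G(x,z)$ comes from the generator.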

The cGAN objective can be written as:

$$L_{cGAN}(G,D)=E_{(x,y)}[\log D(x,y)]+E_{(x,z)}[\log(1-D(x,G(x,z)))]$$

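This is the cGAN loss. `D` tries its best to tell real images from "fake" ones, which means it maximizes the objective by pushing $D(x,y)$ toward 1 and $D(x,G(x,z))$ toward 0. `G` tries its best to fool `D` with the generated image $G(x,z)$ without being caught, which means it minimizes $\log(1-D(x,G(x,z)))$. The cGAN objective can therefore be summarized as: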

$$\arg\min_{G}\max_{D}L_{cGAN}(G,D)$$

![1.png](./images/1.png)
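The plain-Python sketch below puts numbers on this objective. It is not part of the MindSpore pipeline. It estimates $L_{cGAN}$ from a batch of discriminator probabilities. `cgan_value` is a hypothetical helper. A discriminator that confidently separates real and fake pairs drives the value toward its upper bound of 0. At the equilibrium, where $D$ outputs 0.5 for every input, the value is $2\log 0.5=-\log 4$.

```python
import math


def cgan_value(d_real, d_fake):
    """Monte Carlo estimate of L_cGAN(G, D) from discriminator probabilities.

    d_real: D(x, y) for real pairs; d_fake: D(x, G(x, z)) for generated pairs.
    All probabilities must lie strictly between 0 and 1.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - q) for q in d_fake) / len(d_fake)
    return real_term + fake_term


# A discriminator that separates real from fake pushes the value toward 0 (its maximum)
strong_d = cgan_value([0.99, 0.98], [0.01, 0.02])
# At equilibrium D outputs 0.5 for every input, giving 2 * log(0.5) = -log(4)
equilibrium = cgan_value([0.5, 0.5], [0.5, 0.5])
print(strong_d, equilibrium, -math.log(4))
```

`D` moves the value up toward 0, and `G` moves it down by making $D(x,G(x,z))$ larger.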

To contrast cGANs with GANs, here is the objective of an unconditional GAN:

$$L_{GAN}(G,D)=E_{y}[\log D(y)]+E_{z}[\log(1-D(G(z)))]$$


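As the formula shows, a GAN generates "fake" images directly from random noise $z$ and uses no information from an observed image $x$. Previous work has found it beneficial to combine the GAN objective with a traditional loss. The discriminator's task stays the same: it still tells real images from "fake" ones. The generator, however, must not only fool the discriminator but also stay close to the ground-truth output under the traditional loss. Pix2Pix uses the L1 distance for this term, because L1 encourages less blurring than L2. Mixing the cGAN objective with an L1 term gives: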

$$L_{L1}(G)=E_{(x,y,z)}[\|y-G(x,z)\|_{1}]$$

which leads to the final objective:

$$\arg\min_{G}\max_{D}L_{cGAN}(G,D)+\lambda L_{L1}(G)$$
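The sketch below evaluates the generator's side of this objective for a toy batch in plain Python. It makes three assumptions:

- The adversarial term uses the non-saturating form, binary cross-entropy with target 1 on the discriminator's logits. This is the form the `LossG` cell later in this tutorial uses.
- $\lambda=100$, the value reported in the Pix2Pix paper. In the MindSpore code the weights come from `config.lambda_gan` and `config.lambda_l1` instead.
- The adversarial weight is 1.

`bce_with_logits` and `generator_objective` are illustrative names.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit, in the numerically stable form."""
    # Equals -[t*log(s) + (1 - t)*log(1 - s)] with s = sigmoid(logit)
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def generator_objective(d_fake_logits, fake, real, lambda_gan=1.0, lambda_l1=100.0):
    """Adversarial term (fakes labeled 1) plus lambda_l1 times the mean absolute error."""
    adv = sum(bce_with_logits(a, 1.0) for a in d_fake_logits) / len(d_fake_logits)
    l1 = sum(abs(f - r) for f, r in zip(fake, real)) / len(real)
    return lambda_gan * adv + lambda_l1 * l1


# Toy batch: two discriminator logits and a 4-pixel "image" flattened to a list
print(generator_objective([0.0, 2.0], [0.1, 0.5, -0.3, 0.9], [0.0, 0.5, -0.2, 1.0]))
```

With $\lambda=100$, even a small average pixel error dominates the total. This is why the L1 term keeps the output close to the ground truth while the adversarial term sharpens details.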

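Image translation is essentially a pixel-to-pixel mapping problem. Pix2Pix uses exactly the same network architecture and objective for all of the tasks listed above and only swaps the training dataset. This tutorial implements Pix2Pix with the MindSpore framework.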

## Preparation

### Configuring the Environment

In this tutorial, the experiments run on GPU in graph mode.

```python
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
```

### Preparing the Data

This tutorial uses the [pix2pix datasets](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/), a collection of six datasets:

- facades: 606 images
- cityscapes: 3475 images
- maps: 2194 images
- night2day, day and night scenes: 20120 images
- edges2shoes, shoe photos with matching edge drawings: 50025 images
- edges2handbags, handbag photos with matching edge drawings: 138767 images

Each dataset is stored in its own subfolder of `./data/datasets`, for example `./data/datasets/maps`.

## Data Processing

First, define some configuration parameters for the run:

```python
import argparse


def parse_args():
    """Parse the training configuration from the command line."""
    parser = argparse.ArgumentParser(description='config')
    parser.add_argument('--train_data_dir', default='../data/maps/train/', type=str)
    parser.add_argument('--device_target', default='GPU', choices=['GPU', 'Ascend'], type=str)
    parser.add_argument('--train_fakeimg_dir', default='results/fake_img/', type=str)
    parser.add_argument('--loss_show_dir', default='results/loss_show', type=str)
    parser.add_argument('--ckpt_dir', default='results/ckpt', type=str)
    parser.add_argument('--epoch_num', default=200, type=int)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--beta1', default=0.5, type=float)
    parser.add_argument('--beta2', default=0.999, type=float)
    return parser.parse_args()
```

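Parameter descriptions:

- `train_data_dir`: path to the training set.
- `device_target`: device to run on.
- `train_fakeimg_dir`: path where fake images generated during training are saved.
- `loss_show_dir`: path where loss plots generated during training are saved.
- `ckpt_dir`: path where checkpoint files are saved during training.
- `epoch_num`: number of training epochs; set it per dataset.
- `batch_size`: batch size used in training; set it per dataset.
- `beta1`: `beta1` of the Adam optimizer.
- `beta2`: `beta2` of the Adam optimizer.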

The paper recommends a different number of training epochs (`epoch_num`) and batch size (`batch_size`) for each dataset:

| Dataset | epoch_num | batch_size |
| --- | --- | --- |
| facades | 200 | 1 |
| cityscapes | 200 | 1 |
| maps | 200 | 1 |
| night2day | 17 | 4 |
| edges2shoes | 15 | 4 |
| edges2handbags | 15 | 4 |
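To avoid hard-coding these values, one option is to keep them in a lookup table and pass them to the command-line options defined in `parse_args` above. The sketch assumes the dataset names used in the table. `RECOMMENDED` and `training_args` are hypothetical names. The small parser mirrors only two of the options from `parse_args`.

```python
import argparse

# Recommended (epoch_num, batch_size) per dataset, copied from the table above
RECOMMENDED = {
    "facades": (200, 1),
    "cityscapes": (200, 1),
    "maps": (200, 1),
    "night2day": (17, 4),
    "edges2shoes": (15, 4),
    "edges2handbags": (15, 4),
}


def training_args(dataset_name):
    """Hypothetical helper: build command-line overrides for one dataset."""
    epoch_num, batch_size = RECOMMENDED[dataset_name]
    return ["--epoch_num", str(epoch_num), "--batch_size", str(batch_size)]


# Mirrors only the two matching options of parse_args(), to show the overrides parse
parser = argparse.ArgumentParser(description='config')
parser.add_argument('--epoch_num', default=200, type=int)
parser.add_argument('--batch_size', default=1, type=int)
args = parser.parse_args(training_args("night2day"))
print(args.epoch_num, args.batch_size)
```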

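- Define the `Pix2PixDataset` class and the `create_train_dataset` function to preprocess and augment the training data.
  - Each image file stores an input/target pair side by side. `Pix2PixDataset` splits the file into its two halves and resizes both to `config.load_size`. It then crops both at the same random position to `config.train_pic_size`.
  - `create_train_dataset` applies a synchronized random horizontal flip, normalizes pixel values to [-1, 1], and converts HWC layout to CHW.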

```python
import os

import numpy as np
from PIL import Image

import mindspore
from mindspore import dataset as ds
import mindspore.dataset.vision.c_transforms as transforms

from src.config.pix2pix_config import pix2pix_config as config


def get_params():
    """
    Get a random crop position.

    Return:
        x, y. Top-left corner of the crop, shared by the input and target images.
    """

    new_h = new_w = config.load_size

    # np.random.randint excludes its upper bound and requires it to be > 0,
    # so add 1: every valid offset is possible, and 0 is returned when load_size == train_pic_size
    x = np.random.randint(0, np.maximum(0, new_w - config.train_pic_size) + 1)
    y = np.random.randint(0, np.maximum(0, new_h - config.train_pic_size) + 1)

    return x, y


def crop(img, pos, size=config.train_pic_size):
    """
    Crop the images.

    Args:
        img (PIL.Image): image.
        pos (tuple): (x, y) crop position.
        size (int): train image size.

    Return:
        img. output img.
    """

    ow = oh = config.load_size
    x1, y1 = pos
    tw = th = size
    # Only crop when the resized image is larger than the training size
    if ow > tw or oh > th:
        img = img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def sync_random_horizontal_flip(input_images, target_images):
    """
    Randomly flip the input and target images with the same flip decision.

    Args:
        input_images (numpy.ndarray): input image.
        target_images (numpy.ndarray): target image.

    Return:
        out_input: randomly flipped input image.
        out_target: randomly flipped target image.
    """

    # Reset the same seed before each op so that both images get the same flip decision
    seed = np.random.randint(0, 2000000000)
    mindspore.set_seed(seed)
    op = transforms.RandomHorizontalFlip(prob=0.5)
    out_input = op(input_images)
    mindspore.set_seed(seed)
    op = transforms.RandomHorizontalFlip(prob=0.5)
    out_target = op(target_images)
    return out_input, out_target


class Pix2PixDataset():
    """
    Define the training dataset.

    Args:
        root_dir (str): train dataset path.

    Outputs:
        a_crop. cropped image a.
        b_crop. cropped image b.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # File names are expected to be numbers with a 3-character extension, e.g. "1.jpg"
        self.list_files = sorted(os.listdir(self.root_dir), key=lambda x: int(x[:-4]))

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        ab = Image.open(img_path).convert('RGB')
        w, h = ab.size
        w2 = int(w / 2)

        # Each file stores the pair side by side: a is the right half, b is the left half
        a = ab.crop((w2, 0, w, h))
        b = ab.crop((0, 0, w2, h))

        a = a.resize((config.load_size, config.load_size))
        b = b.resize((config.load_size, config.load_size))

        # Crop both halves at the same random position so the pair stays aligned
        transform_params = get_params()
        a_crop = crop(a, transform_params, size=config.train_pic_size)
        b_crop = crop(b, transform_params, size=config.train_pic_size)

        return a_crop, b_crop


def create_train_dataset(dataset, batch_size):
    """
    Create the training dataset.

    Args:
        dataset (Pix2PixDataset): image dataset.
        batch_size (int): batch size.

    Return:
        The batched training dataset.
    """

    # Normalize maps pixel values from [0, 255] to [-1, 1]
    mean = [0.5 * 255] * 3
    std = [0.5 * 255] * 3

    trans = [
        transforms.Normalize(mean=mean, std=std),
        transforms.HWC2CHW()
    ]

    train_ds = ds.GeneratorDataset(dataset, column_names=["input_images", "target_images"], shuffle=False)

    # Flip the input and target together so the pair stays aligned
    train_ds = train_ds.map(operations=[sync_random_horizontal_flip], input_columns=["input_images", "target_images"])

    train_ds = train_ds.map(operations=trans, input_columns=["input_images"])
    train_ds = train_ds.map(operations=trans, input_columns=["target_images"])

    train_ds = train_ds.batch(batch_size=batch_size, drop_remainder=True)

    return train_ds
```
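The key detail in this pipeline is that the input and target images share one crop offset and one flip decision, so the pair stays pixel-aligned. In addition, `Normalize` with a mean and std of 127.5 maps pixel values from [0, 255] to [-1, 1], which matches the generator's `Tanh` output range. The plain-Python sketch below shows both pieces. The 286 and 256 sizes are assumptions taken from the paper's random-jitter setting; they are not read from `pix2pix_config`. `get_crop_offset` and `normalize_pixel` are hypothetical names.

```python
import random

LOAD_SIZE = 286       # assumption: resize size from the paper's random-jitter setting
TRAIN_PIC_SIZE = 256  # assumption: crop size, matching the 256x256 shapes printed below


def get_crop_offset(load_size=LOAD_SIZE, crop_size=TRAIN_PIC_SIZE, rng=random):
    """Draw one (x, y) offset; reusing it for both halves keeps the pair aligned."""
    max_offset = max(0, load_size - crop_size)
    # random.randint includes its upper bound, so every offset 0..max_offset is possible
    return rng.randint(0, max_offset), rng.randint(0, max_offset)


def normalize_pixel(value, mean=0.5 * 255, std=0.5 * 255):
    """Per-pixel arithmetic of Normalize(mean=[127.5] * 3, std=[127.5] * 3)."""
    return (value - mean) / std


offset = get_crop_offset()
print(offset, normalize_pixel(0), normalize_pixel(255))
```

Python's `random.randint` includes its upper bound. NumPy's `np.random.randint` excludes it and also requires it to be positive.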

- Call `Pix2PixDataset` and `create_train_dataset` to read the test set:

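```python
dataset = Pix2PixDataset(root_dir='../data/maps/test/')  # choose a folder that contains images
# Use a new name so the `ds` alias of mindspore.dataset used by create_train_dataset is not overwritten
test_ds = create_train_dataset(dataset, batch_size=config.batch_size)
print("ds:", test_ds.get_dataset_size())
print("ds:", test_ds.get_col_names())
print("ds.shape:", test_ds.output_shapes())
```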

```text
ds: 54                                              # the test folder contains 54 images
ds: ['input_images', 'target_images']               # column names
ds.shape: [[1, 3, 256, 256], [1, 3, 256, 256]]      # output shapes
```

## Building the Network

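Once the data is ready, we can build the network. The following sections cover the generator, the discriminator, and the loss functions in turn.

- The generator G uses a U-Net architecture. It encodes the input outline image $x$ and decodes it into a realistic image.
- The discriminator D uses PatchGAN, a conditional discriminator proposed by the authors. Conditioned on the outline image $x$, D should classify the generated image $G(x)$ as fake and the real image as real.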

### Generator G Architecture

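U-Net is a fully convolutional architecture proposed by the Pattern Recognition and Image Processing group at the University of Freiburg, Germany. It has two parts:

- The left side is a contracting path made of convolutions and downsampling.
- The right side is an expansive path made of convolutions and upsampling. The input to each block in this path is the upsampled output of the previous layer, concatenated with the features from the corresponding level of the contracting path.

The network as a whole is U-shaped, hence the name U-Net. Common encoder-decoder networks first downsample to a low-dimensional representation and then upsample back to the original resolution. U-Net differs by adding skip connections: each encoder feature map is concatenated along the channel axis with the decoder feature map of the same size. This preserves pixel-level detail at different resolutions.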
![2.png](./images/2.png)
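The sketch below traces the encoder of the `UNetGenerator` defined below to show how the spatial size shrinks along the contracting path. It assumes a 256×256 input and the defaults `ngf=64` and `n_layers=8`. Every downsampling convolution uses `kernel_size=4, stride=2, padding=1`, which halves the size, so eight of them reduce 256 to a 1×1 bottleneck. `unet_encoder_trace` is a hypothetical helper written in plain Python.

```python
def unet_encoder_trace(size=256, ngf=64, n_layers=8):
    """Return (spatial size, channels) after each downsampling conv of UNetGenerator."""
    # Channel widths of the blocks, from outermost to innermost
    channels = [ngf, ngf * 2, ngf * 4] + [ngf * 8] * (n_layers - 3)
    trace = []
    for ch in channels:
        # Conv2d(kernel_size=4, stride=2, padding=1): (size + 2 - 4) // 2 + 1 = size // 2
        size = (size + 2 * 1 - 4) // 2 + 1
        trace.append((size, ch))
    return trace


for level, (size, ch) in enumerate(unet_encoder_trace(), start=1):
    print(f"down {level}: {size}x{size}, {ch} channels")
```

On the way back up, each skip connection concatenates encoder and decoder maps of equal size. This is why every up-convolution in `UNetSkipConnectionBlock` except the innermost takes `inner_nc * 2` input channels.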


#### Defining the U-Net Skip Connection Block

```python
import mindspore.nn as nn
import mindspore.ops as ops


class UNetSkipConnectionBlock(nn.Cell):
    """
    Unet submodule with skip connection.

    Args:
        outer_nc (int): The number of filters in the outer conv layer.
        inner_nc (int): The number of filters in the inner conv layer.
        in_planes (int): The number of channels in input images/features.
        dropout (bool): Use dropout or not. Default: False.
        submodule (Cell): Previously defined submodules.
        outermost (bool): If this module is the outermost module.
        innermost (bool): If this module is the innermost module.
        alpha (float): LeakyRelu slope. Default: 0.2.
        norm_mode (str): Specifies norm method. The optional values are "batch", "instance".

    Outputs:
        Tensor, output tensor of Unet submodule.
    """

    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        downnorm = nn.BatchNorm2d(inner_nc)
        upnorm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            downnorm = nn.BatchNorm2d(inner_nc, affine=False)
            upnorm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        downconv = nn.Conv2d(in_planes, inner_nc, kernel_size=4, stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        downrelu = nn.LeakyReLU(alpha)
        uprelu = nn.ReLU()

        if outermost:
            upconv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, pad_mode='pad')
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.Conv2dTranspose(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost
        self.concat = ops.Concat(axis=1)

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = self.concat((out, x))
        return out
```

#### U-Net-Based Generator

```python
import mindspore.nn as nn

from src.models.unet_block import UNetSkipConnectionBlock


class UNetGenerator(nn.Cell):
    """
    Unet based generator.

    Args:
        in_planes (int): the number of channels in input images.
        out_planes (int): the number of channels in output images.
        ngf (int): the number of filters in the last conv layer.Default: 64.
        n_layers (int): the number of downsamplings in UNet.Default: 8.
        norm_mode (str): Specifies norm method.
        dropout (bool): Use dropout or not. Default: False.

    Outputs:
        Tensor, output tensor.
    """

    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()

        # construct unet structure
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)

        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)
```

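The original cGAN takes two inputs, the condition $x$ and the noise $z$. This generator uses only the condition, so on its own it cannot produce diverse outputs. The Pix2Pix paper therefore supplies noise only in the form of dropout, applied at both training and test time, although the authors observed only minor stochasticity in the outputs. Note that `UNetGenerator` defaults to `dropout=False`, and `get_generator` below does not override this default.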

### Discriminator

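The discriminator uses the PatchGAN architecture, which is fully convolutional. Each element of its output matrix corresponds to a small region (patch) of the input image. The value of that element judges whether the corresponding patch is real or fake.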
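With the default `Discriminator` settings below (`n_layers=3`, `kernel_size=4`, `padding=1`), the network stacks three stride-2 convolutions and two stride-1 convolutions. The plain-Python sketch below computes two things: the output map size for a 256×256 input, and the receptive field of one output element, which is the patch size each value judges. The layer list is read off the code; `output_size` and `receptive_field` are illustrative helpers.

```python
# (kernel_size, stride, padding) of each conv in Discriminator with n_layers=3
PATCHGAN_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]


def output_size(size, layers=PATCHGAN_LAYERS):
    """Spatial size of the discriminator output for a square input."""
    for k, s, p in layers:
        size = (size + 2 * p - k) // s + 1
    return size


def receptive_field(layers=PATCHGAN_LAYERS):
    """Input patch size seen by one output element."""
    r = 1
    # Walk from the output back to the input: r_in = (r_out - 1) * stride + kernel
    for k, s, _ in reversed(layers):
        r = (r - 1) * s + k
    return r


print(output_size(256), receptive_field())
```

For a 256×256 input, the discriminator outputs a 30×30 map in which each value scores a 70×70 patch. This matches the 70×70 PatchGAN described in the paper.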

```python
import mindspore.nn as nn
from mindspore.ops import Concat

from src.config.pix2pix_config import pix2pix_config as config


class ConvNormRelu(nn.Cell):
    """
    Convolution fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size. Default: 4.
        stride (int): Stride size for the first convolutional layer. Default: 2.
        alpha (float): Slope of LeakyReLU. Default: 0.2.
        norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
        pad_mode (str): Specifies padding mode. The optional values are CONSTANT, REFLECT, SYMMETRIC. Default: CONSTANT.
        use_relu (bool): Use relu or not. Default: True.
        padding (int): Pad size, if it is None, it will calculate by kernel_size. Default: None.

    Outputs:
        Tensor, output tensor of module layer.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':    # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        # Derive the padding from kernel_size only when none is given (an explicit 0 is kept)
        if padding is None:
            padding = (kernel_size - 1) // 2
        # config.pad_mode overrides the pad_mode argument when it requests REFLECT or SYMMETRIC
        if config.pad_mode == 'REFLECT':
            pad_mode = "REFLECT"
        elif config.pad_mode == "SYMMETRIC":
            pad_mode = "SYMMETRIC"
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class Discriminator(nn.Cell):
    """
    Discriminator of Model.

    Args:
        in_planes (int): Input channels of the concatenated (x, y) pair. Default: 3.
        ndf (int): the number of filters in the last conv layer. Default: 64.
        n_layers (int): The number of ConvNormRelu blocks. Default: 3.
        alpha (float): LeakyRelu slope. Default: 0.2.
        norm_mode (str): Specifies norm method. The optional values are "batch", "instance". Default: "batch".

    Outputs:
        Tensor, output tensor of discriminator of model.
    """

    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)  # use the alpha argument (get_discriminator passes config.alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))

        self.features = nn.SequentialCell(layers)
        self.concat = Concat(axis=1)

    def construct(self, x, y):
        # Condition on the input image by concatenating it with the (real or fake) output along channels
        x_y = self.concat((x, y))
        output = self.features(x_y)
        return output
```


### Initializing the Pix2Pix Generator and Discriminator

```python
import mindspore.nn as nn
from mindspore.common import initializer as init

from src.config.pix2pix_config import pix2pix_config as config
from src.models.generator import UNetGenerator
from src.models.discriminator import Discriminator


def init_weights(net):
    """Initialize conv weights by config.init_type and reset BatchNorm gamma/beta."""
    for _, cell in net.cells_and_names():
        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
            if config.init_type == 'normal':
                cell.weight.set_data(init.initializer(init.Normal(config.init_gain), cell.weight.shape))
            elif config.init_type == 'xavier':
                cell.weight.set_data(init.initializer(init.XavierUniform(config.init_gain), cell.weight.shape))
            elif config.init_type == 'constant':
                cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % config.init_type)
        elif isinstance(cell, nn.BatchNorm2d):
            cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
            cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
    return net


def get_generator():
    """
    Return a generator by args.

    Returns:
        net_generator. initialized generator network.
    """

    net_generator = UNetGenerator(in_planes=config.g_in_planes, out_planes=config.g_out_planes,
                                  ngf=config.g_ngf, n_layers=config.g_layers)
    return init_weights(net_generator)


def get_discriminator():
    """
    Return a discriminator by args.

    Returns:
        net_discriminator. initialized discriminator network.
    """

    net_discriminator = Discriminator(in_planes=config.d_in_planes, ndf=config.d_ndf,
                                      alpha=config.alpha, n_layers=config.d_layers)
    return init_weights(net_discriminator)


class Pix2Pix(nn.Cell):
    """
    pix2pix model network.

    Args:
        discriminator (Cell): a discriminator.
        generator (Cell): a generator.

    Inputs:
        -**reala** - real input image.

    Outputs:
        fakeb, the generated fake image.
    """
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.netd = discriminator
        self.netg = generator

    def construct(self, reala):
        fakeb = self.netg(reala)
        return fakeb
```

Instantiate the Pix2Pix generator and discriminator and print the network structure.

```python
net_generator = get_generator()
net_discriminator = get_discriminator()
pix2pix = Pix2Pix(generator=net_generator, discriminator=net_discriminator)

# Iterate over all parameters and print each layer's name and attributes
for m in pix2pix.parameters_and_names():
    print(m)
```

```text
('netd.features.0.weight', Parameter (name=netd.features.0.weight, shape=(64, 6, 4, 4), dtype=Float32, requires_grad=True))
('netd.features.2.features.0.weight', Parameter (name=netd.features.2.features.0.weight, shape=(128, 64, 4, 4), dtype=Float32, requires_grad=True))
('netd.features.2.features.1.moving_mean', Parameter (name=netd.features.2.features.1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False))
('netd.features.2.features.1.moving_variance', Parameter (name=netd.features.2.features.1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
('netd.features.2.features.1.gamma', Parameter (name=netd.features.2.features.1.gamma, shape=(128,), dtype=Float32, requires_grad=True))
('netd.features.2.features.1.beta', Parameter (name=netd.features.2.features.1.beta, shape=(128,), dtype=Float32, requires_grad=True))
('netd.features.3.features.0.weight', Parameter (name=netd.features.3.features.0.weight, shape=(256, 128, 4, 4), dtype=Float32, requires_grad=True))
('netd.features.3.features.1.moving_mean', Parameter (name=netd.features.3.features.1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False))
('netd.features.3.features.1.moving_variance', Parameter (name=netd.features.3.features.1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
('netd.features.3.features.1.gamma', Parameter (name=netd.features.3.features.1.gamma, shape=(256,), dtype=Float32, requires_grad=True))
('netd.features.3.features.1.beta', Parameter (name=netd.features.3.features.1.beta, shape=(256,), dtype=Float32, requires_grad=True))
```

(The output is long; only part of it is shown.)
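The printed shapes can be cross-checked against the channel arithmetic in `Discriminator.__init__`. The plain-Python sketch below reproduces that arithmetic. It assumes `in_planes=6`: a 3-channel input image concatenated with a 3-channel target, as the first printed shape `(64, 6, 4, 4)` indicates. `discriminator_weight_shapes` is a hypothetical helper. Its first three entries match the printed convolution weights. The last two belong to the stride-1 block and the final convolution, which fall in the truncated part of the output.

```python
def discriminator_weight_shapes(in_planes=6, ndf=64, n_layers=3, kernel_size=4):
    """Conv weight shapes (out, in, kh, kw) following Discriminator.__init__."""
    shapes = [(ndf, in_planes, kernel_size, kernel_size)]
    nf_mult = ndf
    for i in range(1, n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2 ** i, 8) * ndf
        shapes.append((nf_mult, nf_mult_prev, kernel_size, kernel_size))
    # Stride-1 ConvNormRelu block, then the final 1-channel conv
    nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8) * ndf
    shapes.append((nf_mult, nf_mult_prev, kernel_size, kernel_size))
    shapes.append((1, nf_mult, kernel_size, kernel_size))
    return shapes


for shape in discriminator_weight_shapes():
    print(shape)
```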

### Connecting the Network and Loss Functions

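MindSpore wraps loss functions, optimizers, and similar operations in `Cell`s. A GAN's structure is unusual: its loss has multiple outputs, one for the discriminator and one for the generator, which sets it apart from an ordinary classification network. We therefore define custom `WithLossCell` classes to connect each network to its loss.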

```python
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops import functional as opsf
import mindspore.ops.operations as opsp
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, _get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.nn.loss.loss import LossBase

from src.config.pix2pix_config import pix2pix_config as config


class SigmoidCrossEntropyWithLogits(LossBase):
    """
    Sigmoid cross-entropy loss on raw logits.

    Inputs:
        -**data** (Tensor) - Logits predicted by the discriminator.
        -**label** (Tensor) - Target labels (1 for real, 0 for fake).

    Outputs:
        Tensor, the reduced sigmoid cross-entropy loss.
    """

    def __init__(self):
        super(SigmoidCrossEntropyWithLogits, self).__init__()
        self.cross_entropy = opsp.SigmoidCrossEntropyWithLogits()

    def construct(self, data, label):
        x = self.cross_entropy(data, label)
        return self.get_loss(x)


class LossD(LossBase):
    """
    Define the discriminator loss.

    Args:
        reduction (str): Reduction applied to the per-sample loss. Default: "mean".

    Inputs:
        -**pred1** (Tensor) - Discriminator logits for real pairs (x, y).
        -**pred0** (Tensor) - Discriminator logits for fake pairs (x, G(x)).

    Outputs:
        Tensor, the discriminator loss.
    """

    def __init__(self, reduction="mean"):    # "mean": return the averaged loss of the samples
        super(LossD, self).__init__(reduction)
        self.sig = SigmoidCrossEntropyWithLogits()
        self.ones = ops.OnesLike()
        self.zeros = ops.ZerosLike()
        self.lambda_dis = config.lambda_dis

    def construct(self, pred1, pred0):
        # Real pairs should be classified as 1, fake pairs as 0
        loss = self.sig(pred1, self.ones(pred1)) + self.sig(pred0, self.zeros(pred0))
        dis_loss = loss * self.lambda_dis
        return dis_loss


class WithLossCellD(nn.Cell):
    """
    Connect the discriminator with its loss.

    Args:
        backbone (Cell): Pix2Pix network holding `netd` and `netg`.
        loss_fn (Cell): discriminator loss function.

    Inputs:
        -**reala** (Tensor) - real input image a.
        -**realb** (Tensor) - real target image b.

    Outputs:
        Tensor, the discriminator loss.
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCellD, self).__init__(auto_prefix=True)
        # Pix2Pix stores its sub-networks as `netd` and `netg`
        self.net_discriminator = backbone.netd
        self.net_generator = backbone.netg
        self._loss_fn = loss_fn

    def construct(self, reala, realb):
        fakeb = self.net_generator(reala)
        pred1 = self.net_discriminator(reala, realb)
        pred0 = self.net_discriminator(reala, fakeb)
        return self._loss_fn(pred1, pred0)


class LossG(LossBase):
    """
    Define the generator loss.

    Args:
        reduction (str): Reduction applied to the per-sample loss. Default: "mean".

    Inputs:
        -**fakeb** (Tensor) - generated fake image.
        -**realb** (Tensor) - real target image.
        -**pred0** (Tensor) - discriminator logits for fake pairs (x, G(x)).

    Outputs:
        Tensor, the generator loss.
    """

    def __init__(self, reduction="mean"):   # "mean": return the averaged loss of the samples
        super(LossG, self).__init__(reduction)
        self.sig = SigmoidCrossEntropyWithLogits()
        self.l1_loss = nn.L1Loss()
        self.ones = ops.OnesLike()
        self.lambda_gan = config.lambda_gan
        self.lambda_l1 = config.lambda_l1

    def construct(self, fakeb, realb, pred0):
        # Adversarial term: G wants D to label its output as real (target 1)
        loss_1 = self.sig(pred0, self.ones(pred0))
        # L1 term: keep the output close to the ground-truth image
        loss_2 = self.l1_loss(fakeb, realb)
        loss = loss_1 * self.lambda_gan + loss_2 * self.lambda_l1
        return loss


class WithLossCellG(nn.Cell):
    """
    Connect the generator with its loss.

    Args:
        backbone (Cell): Pix2Pix network holding `netd` and `netg`.
        loss_fn (Cell): generator loss function.

    Inputs:
        -**reala** (Tensor) - real input image a.
        -**realb** (Tensor) - real target image b.

    Outputs:
        Tensor, the generator loss.
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCellG, self).__init__(auto_prefix=True)
        # Pix2Pix stores its sub-networks as `netd` and `netg`
        self.net_discriminator = backbone.netd
        self.net_generator = backbone.netg
        self._loss_fn = loss_fn

    def construct(self, reala, realb):
        fakeb = self.net_generator(reala)
        pred0 = self.net_discriminator(reala, fakeb)
        return self._loss_fn(fakeb, realb, pred0)


class TrainOneStepCell(nn.Cell):
    """
    Wrap one training step that updates the discriminator and then the generator.

    Args:
        loss_netd (Cell): loss network of the discriminator.
        loss_netg (Cell): loss network of the generator.
        optimizerd (Cell): optimizer that updates discriminator network parameters.
        optimizerg (Cell): optimizer that updates generator network parameters.
        sens (numbers.Number): Scaling factor fed to backpropagation. Default: 1.
        auto_prefix (bool): whether to auto prefix parameter names. Default: True.

    Inputs:
        -**reala** (Tensor) - real input image a.
        -**realb** (Tensor) - real target image b.

    Outputs:
        d_res, discriminator loss.
        g_res, generator loss.
    """

    def __init__(self, loss_netd, loss_netg, optimizerd, optimizerg, sens=1, auto_prefix=True):
        super(TrainOneStepCell, self).__init__(auto_prefix=auto_prefix)
        self.loss_net_d = loss_netd
        self.loss_net_d.set_grad()
        self.loss_net_d.add_flags(defer_inline=True)

        self.loss_net_g = loss_netg
        self.loss_net_g.set_grad()
        self.loss_net_g.add_flags(defer_inline=True)

        self.weights_g = optimizerg.parameters
        self.optimizerg = optimizerg
        self.weights_d = optimizerd.parameters
        self.optimizerd = optimizerd

        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

        # Parallel processing
        self.reducer_flag = False
        self.grad_reducer_g = opsf.identity
        self.grad_reducer_d = opsf.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer_g = DistributedGradReducer(self.weights_g, mean, degree)
            self.grad_reducer_d = DistributedGradReducer(self.weights_d, mean, degree)

    def set_sens(self, value):
        self.sens = value

    def construct(self, reala, realb):
        """Run one discriminator update followed by one generator update."""
        d_loss = self.loss_net_d(reala, realb)
        g_loss = self.loss_net_g(reala, realb)
        d_sens = ops.Fill()(ops.DType()(d_loss), ops.Shape()(d_loss), self.sens)
        d_grads = self.grad(self.loss_net_d, self.weights_d)(reala, realb, d_sens)
        # Aggregate gradients across devices in data/hybrid parallel mode (identity otherwise)
        d_grads = self.grad_reducer_d(d_grads)
        d_res = ops.depend(d_loss, self.optimizerd(d_grads))
        g_sens = ops.Fill()(ops.DType()(g_loss), ops.Shape()(g_loss), self.sens)
        g_grads = self.grad(self.loss_net_g, self.weights_g)(reala, realb, g_sens)
        g_grads = self.grad_reducer_g(g_grads)
        g_res = ops.depend(g_loss, self.optimizerg(g_grads))
        return d_res, g_res
```

### Training

Training consists of two main parts: training the discriminator and training the generator.

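- Train the discriminator

   The goal of training the discriminator is to maximize the probability of correctly telling real images from fake ones. The discriminator is updated by ascending its stochastic gradient, that is, by maximizing $\log D(x,y)+\log(1-D(x,G(x,z)))$.

- Train the generator

   The generator is trained to produce better fake images by minimizing $\log(1-D(x,G(x,z)))$. In practice, `LossG` above uses the common non-saturating form instead. It applies binary cross-entropy with target 1 to the discriminator's output for fake pairs, which amounts to maximizing $\log D(x,G(x,z))$. This term is weighted by `lambda_gan` and added to the L1 term weighted by `lambda_l1`.

   For both parts, the losses are recorded at every training step so they can be summarized per epoch.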
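Why does `LossG` use target 1 rather than minimizing $\log(1-D)$ directly? Early in training the discriminator easily rejects fakes, and the original term then provides almost no gradient. The plain-Python sketch below compares the gradients of both forms with respect to the discriminator's logit $a$, where $D=\sigma(a)$. The function names are illustrative.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the term G minimizes in the original objective."""
    return -sigmoid(a)


def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), i.e. BCE with target 1 as in LossG."""
    return -(1.0 - sigmoid(a))


# D confidently rejects the fake: D(x, G(x, z)) = sigmoid(-6) is about 0.0025
a = -6.0
print(grad_saturating(a), grad_non_saturating(a))
```

Both gradients are negative, so gradient descent on either loss raises $a$. At $a=-6$, however, the saturating form's gradient has a magnitude of only about 0.0025, while the non-saturating form's is about 0.9975.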

Now run the training:

```python
# Import the required modules
import datetime

import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor

from models.loss import WithLossCellD, LossD, WithLossCellG, LossG, TrainOneStepCell
from models.pix2pix import Pix2Pix, get_generator, get_discriminator
from process_datasets.dataset import Pix2PixDataset, create_train_dataset
from utils.tools import get_lr
from utils.device_adapter import get_device_num
from src.config.pix2pix_config import pix2pix_config as arg
```

```python
# Configuration
device_num = get_device_num()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# Preprocess the data for training
dataset = Pix2PixDataset(root_dir=arg.train_data_dir)
ds = create_train_dataset(dataset, batch_size=arg.batch_size)
steps_per_epoch = ds.get_dataset_size()
```

```python
# Network
net_generator = get_generator()
net_discriminator = get_discriminator()
pix2pix = Pix2Pix(generator=net_generator, discriminator=net_discriminator)

# loss
d_loss_fn = LossD()
g_loss_fn = LossG()
d_loss_net = WithLossCellD(backbone=pix2pix, loss_fn=d_loss_fn)
g_loss_net = WithLossCellG(backbone=pix2pix, loss_fn=g_loss_fn)

# optimizer
d_opt = nn.Adam(pix2pix.netd.trainable_params(), learning_rate=get_lr(),
                beta1=arg.beta1, beta2=arg.beta2, loss_scale=1)
g_opt = nn.Adam(pix2pix.netg.trainable_params(), learning_rate=get_lr(),
                beta1=arg.beta1, beta2=arg.beta2, loss_scale=1)

# train net
train_net = TrainOneStepCell(loss_netd=d_loss_net, loss_netg=g_loss_net, optimizerd=d_opt, optimizerg=g_opt, sens=1)
train_net.set_train()
```

```python
# Training loop
from mindspore.communication.management import get_rank

g_losses = []
d_losses = []
data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=arg.epoch_num)
for epoch in range(arg.epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_net(input_image, target_image)
        end_time = datetime.datetime.now()
        # total_seconds() includes whole seconds; timedelta.microseconds alone would drop them
        ms_per_step = (end_time - start_time).total_seconds() * 1000
        if i % 100 == 0:
            print("================start===================")
            print("Date time: ", start_time)
            if arg.run_distribute:
                print("Device ID :", str(get_rank()))
            print("ms per step :", ms_per_step)
            print("epoch: ", epoch + 1, "/", arg.epoch_num)
            print("step: ", i, "/", steps_per_epoch)
            print("Dloss: ", dis_loss)
            print("Gloss: ", gen_loss)
            print("=================end====================")

        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
```

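- The loop trains the network and collects the generator and discriminator losses at every iteration, so that the loss curves can be plotted afterwards.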
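`d_losses` and `g_losses` hold one value per step, so a per-epoch summary only requires averaging consecutive chunks of `steps_per_epoch` values. Below is a minimal plain-Python sketch. `epoch_means` is a hypothetical helper. It calls `float()` so that it also accepts the 0-d NumPy arrays returned by `asnumpy()`.

```python
def epoch_means(losses, steps_per_epoch):
    """Average a flat list of per-step losses into one value per epoch."""
    means = []
    for start in range(0, len(losses), steps_per_epoch):
        chunk = losses[start:start + steps_per_epoch]
        # float() also accepts the 0-d NumPy arrays produced by asnumpy()
        means.append(sum(float(v) for v in chunk) / len(chunk))
    return means


# Usage after training (names from the cells above):
# d_epoch = epoch_means(d_losses, steps_per_epoch)
# g_epoch = epoch_means(g_losses, steps_per_epoch)
print(epoch_means([1.0, 2.0, 3.0, 4.0, 5.0], 2))
```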

#### Training Output

```text
================start===================
Date time:  2022-07-06 23:29:31.486978
ms per step : 573.098
epoch:  1 / 15
step:  0 / 34641
Dloss:  0.8268157
Gloss:  96.03406
=================end====================
================start===================
Date time:  2022-07-06 23:29:42.813233
ms per step : 48.086
epoch:  1 / 15
step:  100 / 34641
Dloss:  0.08200404
Gloss:  14.9787245
=================end====================
================start===================
Date time:  2022-07-06 23:29:47.699990
ms per step : 49.513
epoch:  1 / 15
step:  200 / 34641
Dloss:  1.7021327
Gloss:  11.134652
=================end====================
...
================start===================
Date time:  2022-07-07 06:29:31.024008
ms per step : 48.005
epoch:  15 / 15
step:  34400 / 34641
Dloss:  6.0905662e-05
Gloss:  13.91369
=================end====================
================start===================
Date time:  2022-07-07 06:29:35.837478
ms per step : 48.055
epoch:  15 / 15
step:  34500 / 34641
Dloss:  6.202688e-05
Gloss:  17.475426
=================end====================
================start===================
Date time:  2022-07-07 06:29:40.601357
ms per step : 47.031
epoch:  15 / 15
step:  34600 / 34641
Dloss:  1.3251489e-06
Gloss:  17.247965
=================end====================
```

## Results

- The loss values collected during this training run are shown below:

![3.png](./images/3.png)


- Pix2Pix inference outputs from the models trained on each dataset are shown below:

![4.png](./images/4.png)
