In [87]:
import torch
import timeit
import os
import time
import copy
import numpy as np

In [107]:
# extending Conv2D and Deconv2D layers for equalized learning rate logic
class _equalized_conv2d(torch.nn.Module):
    """ conv2d with the concept of equalized learning rate
        Args:
            :param c_in: input channels
            :param c_out:  output channels
            :param k_size: kernel size (h, w) should be a tuple or a single integer
            :param stride: stride for conv
            :param pad: padding
            :param bias: whether to use bias or not
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
        from torch.nn.modules.utils import _pair
        from numpy import sqrt, prod

        super(_equalized_conv2d, self).__init__()

        # define the weight and bias if to be used
        self.weight = torch.nn.Parameter(torch.nn.init.normal_(torch.empty(c_out, c_in, *_pair(k_size))))

        self.use_bias = bias
        self.stride = stride
        self.pad = pad

        if self.use_bias:
            self.bias = torch.nn.Parameter(torch.FloatTensor(c_out).fill_(0))

        # prod 计算所有元素的积，可以指定维度
        # scale为什么这么算？
        fan_in = prod(_pair(k_size)) * c_in
        self.scale = sqrt(2) / sqrt(fan_in) 

    def forward(self, x):
        from torch.nn.functional import conv2d

        return conv2d(input=x,
                      weight=self.weight * self.scale,
                      bias=self.bias if self.use_bias else None,
                      stride=self.stride,
                      padding=self.pad)

    # 输出的是各层权重的形状
    def extra_repr(self):
        return ",".join(map(str, self.weight.shape))

In [108]:
class _equalized_deconv2d(torch.nn.Module):
    """ Transpose convolution using the equalized learning rate
        Args:
            :param c_in: input channels
            :param c_out: output channels
            :param k_size: kernel size
            :param stride: stride for convolution transpose
            :param pad: padding
            :param bias: whether to use bias or not
    """
    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
        """ constructor for the class """
        from torch.nn.modules.utils import _pair
        from numpy import sqrt

        super(_equalized_deconv2d, self).__init__()

        # define the weight and bias if to be used
        self.weight = torch.nn.Parameter(torch.nn.init.normal_(torch.empty(c_in, c_out, *_pair(k_size))))

        self.use_bias = bias
        self.stride = stride
        self.pad = pad

        if self.use_bias:
            self.bias = torch.nn.Parameter(torch.FloatTensor(c_out).fill_(0))

        fan_in = c_in  # value of fan_in for deconv
        self.scale = sqrt(2) / sqrt(fan_in)

    def forward(self, x):
        from torch.nn.functional import conv_transpose2d

        return conv_transpose2d(input=x,
                                weight=self.weight * self.scale,
                                bias=self.bias if self.use_bias else None,
                                stride=self.stride,
                                padding=self.pad)

    def extra_repr(self):
        return ",".join(map(str, self.weight.shape))


In [90]:
#----------------------------------------------------------------
# Pixelwise feature vector normalization.
# reference: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120
#----------------------------------------------------------------
class PixelwiseNorm(torch.nn.Module):
    def __init__(self):
        super(PixelwiseNorm, self).__init__()
    
    def forward(self, x, alpha=1e-8):
        # [N1HW], 对三个通道求均值，三个通道压缩成了单个通道
        y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).sqrt()
        # 带有广播机制，应该对图片的每个像素都除了一个均值
        y = x / y
        return y

In [91]:
# ==========================================================
# Layers required for Building The generator and
# discriminator
# ==========================================================
class GenInitialBlock(torch.nn.Module):
    """ Module implementing the initial block of the input """
    def __init__(self, in_channels, use_eql):
        """
        constructor for the inner class
        :param in_channels: number of input channels to the block
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import LeakyReLU

        super(GenInitialBlock, self).__init__()

        if use_eql:
            self.conv_1 = _equalized_deconv2d(in_channels, in_channels, (4, 4), bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (3, 3), pad=1, bias=True)

        else:
            from torch.nn import Conv2d, ConvTranspose2d
            self.conv_1 = ConvTranspose2d(in_channels, in_channels, (4, 4), bias=True)
            self.conv_2 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True)

        self.pixNorm = PixelwiseNorm()

        self.lrelu = LeakyReLU(0.2)


    def forward(self, x):
        # 在前面加两维度
        y = torch.unsqueeze(torch.unsqueeze(x, -1), -1)

        # 先升采样再降采样？为什么这样做？
        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))

        y = self.pixNorm(y)

        return y

class GenGeneralConvBlock(torch.nn.Module):
    """ Module implementing a general convolutional block """
    
    def __init__(self, in_channels, out_channels, use_eql):
        from torch.nn import LeakyReLU
        from torch.nn.functional import interpolate

        super(GenGeneralConvBlock, self).__init__()

        self.upsample = lambda x: interpolate(x, scale_factor=2)

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels, out_channels, (3, 3), pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(out_channels, out_channels, (3, 3), pad=1, bias=True)
        
        else:
            from torch.nn import Conv2d
            self.conv_1 = Conv2d(in_channels, out_channels, (3, 3), padding=1, bias=True)
            self.conv_2 = Conv2d(out_channels, out_channels, (3, 3), padding=1, bias=True)

        self.pixNorm = PixelwiseNorm()

        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        y = self.upsample(x)
        y = self.pixNorm(self.lrelu(self.conv_1(y)))
        y = self.pixNorm(self.lrelu(self.conv_2(y)))

        return y



In [113]:
class Generator(torch.nn.Module):
    def __init__(self, depth=7, latent_size=512, use_eql=True):
        """
        constructor for the Generator class
        :param depth: required depth of the Network
        :param latent_size: size of the latent manifold
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import ModuleList
        from torch.nn.functional import interpolate

        super(Generator, self).__init__()

        # 这个trick值得学习，如果是2的n次幂，那么它和自身减一按位与必是0
        assert latent_size != 0 and ((latent_size & (latent_size -1))==0), "latent size not a power of 2"

        if depth >= 4:
            assert latent_size >= np.power(2, depth -4), "latent size will diminish to zero"

        self.use_eql = use_eql
        self.depth = depth
        self.latent_size = latent_size

        self.initial_block = GenInitialBlock(self.latent_size, use_eql=self.use_eql)

        self.layers = ModuleList([])

        if self.use_eql:
            self.toRGB = lambda in_channels: _equalized_conv2d(in_channels, 3, (1, 1), bias=True)

        else:
            from torch.nn import Conv2d
            self.toRGB = lambda in_channels: Conv2d(in_channels, 3, (1, 1), bias=True)

        self.rgb_converters = ModuleList([self.toRGB(self.latent_size)])

        # create the remaining layers
        for i in range(self.depth - 1):
            if i <= 2:
                layer = GenGeneralConvBlock(self.latent_size, self.latent_size, use_eql=self.use_eql)
                rgb = self.toRGB(self.latent_size)

            else:
                layer = GenGeneralConvBlock(
                    int(self.latent_size // np.power(2, i-3)),
                    int(self.latent_size // np.power(2, i-2)),
                    use_eql = self.use_eql
                )
                rgb = self.toRGB(int(self.latent_size // np.power(2, i-2)))
            self.layers.append(layer)
            self.rgb_converters.append(rgb)
        
        # register the temporary upsampler
        self.temporaryUpsampler = lambda x: interpolate(x, scale_factor=2)

    def forward(self, x, depth, alpha):

        assert depth < self.depth, "Requested output depth cannot be produced"

        y = self.initial_block(x)

        if depth > 0:
            for block in self.layers[:depth-1]:
                y = block(y)
            
            # 旧层
            residual = self.rgb_converters[depth-1](self.temporaryUpsampler(y))
            # 新层
            straight = self.rgb_converters[depth](self.layers[depth - 1](y))

            out = (alpha * straight) + ((1 - alpha) * residual)
        else:
            out = self.rgb_converters[0](y)

        return out


In [93]:
class MinibatchStdDev(torch.nn.Module):
    """
    Minibatch standard deviation layer for the discriminator
    """

    def __init__(self):
        super(MinibatchStdDev, self).__init__()
    
    def forward(self, x, alpha=1e-8):
        """
        forward pass of the layer
        :param x: input activation volume
        :param alpha: small number for numerical stability
        :return: y => x appended with standard deviation constant map
        """
        batch_size, _, height. width = x.shape

        # [B x C x H x W] Subtract mean over batch.
        y = x - x.mean(dim=0, keepdim=True)

        # [1 x C x H x W]  Calc standard deviation over batch
        # y.pow: 方差 --> sqrt: 标准差 --> 对batch求均值，Minibatch的标准差
        y = torch.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha)

        # 所有batch，所有通道，所有像素点的平均标准差
        y = y.mean().view(1, 1, 1, 1)

        # 将上面求得的标准差广播成[B x 1 x H x W]
        y = y.repeat(batch_size, 1, height, width)

        # 在通道那一维拼接x和y
        y = torch.cat([x, y], 1)

        return y


In [94]:
class DisFinalBlock(torch.nn.Module):
    """ Final block for the Discriminator """
    def __init__(self, in_channels, use_eql):
        from torch.nn import LeakyReLU

        super(DisFinalBlock, self).__init__()

        # declare the required modules for forward pass
        self.batch_discriminator = MinibatchStdDev()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True)
            # final conv layer emulates a fully connected layer
            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
        else:
            from torch.nn import _equalized_conv2d
            self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
            # final conv layer emulates a fully connected layer
            self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)

        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        # minibatch_std_dev layer
        y = self.batch_discriminator(x)

        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))

        # fully connected layer
        # 原文中用的全连接层，这里用1x1卷积+view替代了
        y = self.conv_3(y)

        # flatten the output raw discriminator scores
        return y.view(-1)

In [95]:
class DisGeneralConvBlock(torch.nn.Module):
    """ General block in the discriminator  """
    def __init__(self, in_channels, out_channels, use_eql):
        from torch.nn import AvgPool2d, LeakyReLU

        super(DisGeneralConvBlock, self).__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3), pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3), pad=1, bias=True)
        
        else:
            from torch.nn import _equalized_conv2d
            self.conv_1 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, out_channels, (3, 3), padding=1, bias=True)

        self.downSampler = AvgPool2d(2)

        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        y = self.lrelu(self.conv_1(x))
        y = self.lrelu(self.conv_2(y))
        y = self.downSampler(y)

        return y


In [96]:
#================================================================
# Discriminator Module
# can be used with ProGAN or standalone (for inference).
# Note this cannot be used with ConditionalProGAN
#================================================================
class Discriminator(torch.nn.Module):
    """ Discriminator of the GAN """
    def __init__(self, height=7, feature_size=512, use_eql=True):
        """
        constructor for the class
        :param height: total height of the discriminator (Must be equal to the Generator depth)
        :param feature_size: size of the deepest features extracted
                             (Must be equal to Generator latent_size)
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import ModuleList, AvgPool2d

        super(Discriminator, self).__init__()

        assert feature_size != 0 and ((feature_size & feature_size - 1) == 0), "feature size not a power of 2"
        if height >= 4:
            assert feature_size >= np.power(2, height-4), "feature size cannot be produced"

        self.use_eql = use_eql
        self.height = height
        self.feature_size = feature_size

        self.final_block = DisFinalBlock(self.feature_size, use_eql=self.use_eql)

        self.layers = ModuleList([])

        if self.use_eql:
            self.fromRGB = lambda out_channels: _equalized_conv2d(3, out_channels, (1, 1), bias=True)
        else:
            from torch.nn import _equalized_conv2d
            self.fromRGB = lambda out_channels: Conv2d(3, out_channels, (1, 1), bias=True)

        self.rgb_to_features = ModuleList([self.fromRGB(self.feature_size)])

        # create the remaining layers
        for i in range(self.height - 1):
            if i > 2:
                layer = DisGeneralConvBlock(
                    int(self.feature_size // np.power(2, i-2)),
                    int(self.feature_size // np.power(2, i-3)),
                    use_eql = self.use_eql
                )
                rgb = self.fromRGB(int(self.feature_size // np.power(2, i-2)))
            else:
                layer = DisGeneralConvBlock(self.feature_size, 
                                            self.feature_size,
                                            use_eql=self.use_eql)
                rgb - self.fromRGB(self.feature_size)

            self.layers.append(layer)
            self.rgb_to_features.append(rgb)

        self.temporaryDownsampler = AvgPool2d(2)

    def forward(self, x, height, alpha):
        assert height < self.height, "Requested output depth cannot be produced"

        if height > 0:
            # 旧层
            residual = self.rgb_to_features[height-1](self.temporaryDownsampler(x))
            # 新层
            straight = self.layer[height-1](self.rgb_to_features[height](x))

            for block in reversed(self.layers[:height - 1]):
                y = block(y)
        else:
            y = self.rgb_to_features[0](x)

        out = self.final_block(y)

        return out


In [97]:
# function to calculate the Exponential moving averages for the Generator weights
# This function updates the exponential average weights based on the current training
def update_average(model_tgt, model_src, beta):

    # utility function for toggling the gradient requirements of the models
    def toggle_grad(model, requires_grad):
        for p in model.parameters():
            p.requires_grad_(requires_grad)

    # turn off gradient calculation
    toggle_grad(model_tgt, False)
    toggle_grad(model_src, False)

    param_dict_src = dict(model_src.named_parameters())

    for p_name, p_tgt in model_tgt.named_parameters():
        p_src = param_dict_src[p_name]
        assert (p_src is not p_tgt)
        p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)

    toggle_grad(model_tgt, True)
    toggle_grad(model_src, True)

In [98]:
class GANLoss:
    """ Base class for all losses
        @args:
            dis: Discriminator used for calculating the loss
                 Note this must be a part of the GAN framework
    """

    def __init__(self, dis):
        self.dis = dis

    def dis_loss(self, real_samps, fake_samps, height, alpha):
        raise NotImplementedError("dis_loss method has not been implemented")


    def gen_loss(self, real_samps, fake_samps, height, alpha):
        raise NotImplementedError("gen_loss method has not been implemented")

class WGAN_GP(GANLoss):
    
    def __init__(self, dis, drift=0.001, use_gp=False):
        # 在python3中super().xxx相当于super(Class, self).xxx
        super().__init__(dis)

        self.drift = drift
        self.use_gp = use_gp

    def __gradient_penalty(self, real_samps, fake_samps, height, alpha, reg_lambda=10):

        batch_size = real_samps.shape[0]

        epsilon = torch.rand((batch_size, 1, 1 ,1)).to(fake_samps.device)

        merged = epsilon * real_samps + ((1 - epsilon) * fake_samps)
        merged.requires_grad_(True)

        op = self.dis(merged, height, alpha)

        gradient = torch.autograd.grad(
                                outputs=op, 
                                inputs = merged,
                                grad_outputs=torch.ones_like(op),
                                create_graph=True,
                                retain_graph=True,
                                only_inputs=True)[0]
        
        gradient = gradient.view(gradient.shape[0], -1)
        # 对第二维的梯度求L2范数
        penalty = reg_lambda * ((gradient.norm(p=2, dim=1) - 1) ** 2).mean()

        return penalty

    def dis_loss(self, real_samps, fake_samps, height, alpha):
        
        fake_out = self.dis(fake_samps, height, alpha)
        real_out = self.dis(real_samps, height, alpha)

        # 我们的目的是real和fake越相近越好，也就是real - fake越小越好
        # 而Discriminator的目的是区分real和fake，也就是real - fake越大越好
        # 也就是fake-real越小越好
        loss = (torch.mean(fake_out) - torch.mean(real_out) + (self.drift * torch.mean(real_out ** 2)))

        if self.use_gp:
            # gradient penalty使得梯度稳在1-Lipschitz范数
            gp = slef.__gradient_penalty(real_samps, fake_samps, height, alpha)
            loss += gp

        return loss

    def gen_loss(self, real_samps, fake_samps, height, alpha):
        
        loss = -torch.mean(self.dis(fake_samps, height, alpha))

        return loss

In [99]:
def get_transform(new_size=None):
    """
    obtain the image transforms required for the input data
    :param new_size: size of the resized images
    :return: image_transform => transform object from TorchVision
    """
    from torchvision.transforms import ToTensor, Normalize, Compose, Resize

    if new_size is not None:
        image_transform = Compose([
            Resize(new_size),
            ToTensor(),
            Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

    else:
        image_transform = Compose([
            ToTensor(),
            Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
    return image_transform

def get_data_loader(dataset, batch_size, num_workers):
    """
    generate the data_loader from the given dataset
    :param dataset: dataset for training (Should be a PyTorch dataset)
                    Make sure every item is an Image
    :param batch_size: batch size of the data
    :param num_workers: num of parallel readers
    :return: dl => dataloader for the dataset
    """
    from torch.utils.data import DataLoader

    dl = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    return dl

In [100]:

#================================================================
# ProGAN Module (Unconditional)
#================================================================
class ProGAN:
    """ Wrapper around the Generator and the Discriminator """

    def __init__(self, depth=7, latent_size=512, learning_rate=0.001, beta_1=0, beta_2=0.99, eps=1e-8, drift=0.001, n_critic=1, use_eql=True, loss="wgan-gp", use_ema=True, ema_decay=0.999, device=torch.device("cuda")):
        from torch.optim import Adam
        from torch.nn import DataParallel

        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
        self.dis = Discriminator(depth, latent_size, use_eql=use_eql).to(device)

        if device == torch.device("cuda"):
            self.gen = DataParallel(self.gen)
            self.dis = DataParallel(self.dis)

        self.latent_size = latent_size
        self.depth = depth
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.n_critic = n_critic
        self.use_eql = use_eql
        self.device = device
        self.drift = drift

        self.gen_optim = Adam(self.gen.parameters(), lr=learning_rate, betas=(beta_1, beta_2), eps=eps)

        self.loss = self.__setup_loss(loss)

        if self.use_ema:
            
            # create a shadow copy of the generator
            self.gen_shadow = copy.deepcopy(self.gen)

            self.ema_updater = update_average

            self.ema_updater(self.gen_shadow, self.gen, beta=0)

    def __setup_loss(self, loss):
        
        loss = WGAN_GP(self.dis, self.drift, use_gp=True)

        return loss

    def __progressive_downsampling(self, real_batch, depth, alpha):
        from torch.nn import AvgPool2d
        from torch.nn.functional import interpolate

        # downsample the real_batch for the given depth
        down_sample_factor = int(np.power(2, self.depth  - depth -1))
        prior_downsample_factor = max(int(np.power(2, self.depth - depth)), 0)

        ds_real_samples = AvgPool2d(down_sample_factor)(real_batch)

        if depth > 0:
            prior_ds_real_samples = interpolate(AvgPool2d(prior_downsample_factor))

        else:
            prior_ds_real_samples = ds_real_samples

        # real samples are a combination of ds_real_samples and prior_ds_real_samples
        real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)

        return real_samples


    def optimize_discriminator(self, noise, real_batch, depth, alpha):
        real_samples = self.__progressive_downsampling(real_batch, depth, alpha)

        loss_val = 0
        for _ in range(self.n_critic):

            # generate a batch of samples
            fake_samples = self.gen(noise, depth, alpha).detach()

            loss = self.loss.dis_loss(real_samples, fake_samples, depth, alpha)

            # optimize discriminator
            self.dis_optim.zero_grad()
            loss.backward()
            self.dis_optim.step()

            loss_val += loss.item()

        return loss_val / self.n_critic

    
    def optimize_generator(self, noise, real_batch, depth, alpha):

        real_samples = self.__progressive_downsampling(real_batch, depth, alpha)

        # generate fake samples:
        fake_samples = self.gen(noise, depth, alpha)

        loss = self.loss.gen_loss(real_samples, fake_samples, depth, alpha)

        # optimize the generator
        self.gen_optim.zero_grad()
        loss.backward()
        self.gen_optim.step()

        # if use_ema is true, apply ema to the generator parameters
        if self.use_ema:
            self.ema_updater(self.gen_shadow, self.gen, self.ema_decay)

        # return the loss value
        return loss.item()


    def create_grid(samples, scale_factor, img_file):

        from torchvision.utils import save_image
        from torch.nn.functional import interpolate

        # upsample the image
        if scale_factor > 1:
            samples = interpolate(samples, scale_factor=scale_factor)

        # save the images:
        save_image(samples, img_file, nrow=int(np.sqrt(len(samples))), normalize=True, scale_each=True)

    # def train(self, dataset, epochs, batch_sizes,
    #           fade_in_percentage, num_samples=16,
    #           start_depth=0, num_workers=3, feedback_factor=100,
    #           log_dir="./models/", sample_dir="./samples/", save_dir="./models/",
    #           checkpoint_factor=1):

    #     assert self.depth == len(batch_sizes), "batch_sizes not compatible with depth"

    #     # turn the generator and discriminator into train mode
    #     self.gen.train()
    #     self.dis.train()
    #     if self.use_ema:
    #         self.gen_shadow.train()

    #     # create a global time counter
    #     global_time = time.time()

    #     # create fixed_input for debugging
    #     fixed_input = torch.randn(num_samples, self.latent_size).to(self.device)

In [101]:
import argparse
import torch
import numpy as np
import os
from torch.backends import cudnn
from torch.nn.functional import interpolate
from tqdm import tqdm

In [127]:
# turn on the fast GPU processing mode on
cudnn.benchmark = True
# define the device for the training script
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def parse_arguments():
    """
    default command line argument parser
    :return: args => parsed command line arguments
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("--generator_file", 
                        type=str, 
                        default="GAN_GEN_8.pth",
                        help="pretrained weights file for generator")

    parser.add_argument("--latent_size",
                        type=int,
                        default=512,
                        help="latent size for the generator")

    parser.add_argument("--depth", action="store",
                        default=9,
                        help="depth of the network. **Starts from 1")

    parser.add_argument("--out_depth", 
                        type=int,
                        default=6,
                        help="output depth of images. **Starts from 0")

    parser.add_argument("--num_samples",
                        type=int,
                        default=300,
                        help="number of synchronized grids to be generated")

    parser.add_argument("--out_dir", 
                        type=str,
                        default="interp_animation_frames/",
                        help="path to the output directory for the frames")

    args = parser.parse_known_args()[0]

    return args


def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
    """
    adjust the dynamic colour range of the given input data
    :param data: input image data
    :param drange_in: original range of input
    :param drange_out: required range of output
    :return: img => colour range adjusted images
    """
    if drange_in != drange_out:
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))
        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
        data = data * scale + bias
    return torch.clamp(data, min=0, max=1)


In [159]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import cv2

# loader使用torchvision中自带的transforms函数
loader = transforms.Compose([
    transforms.ToTensor()])  

unloader = transforms.ToPILImage()

def save_image(tensor, epoch):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)  # remove the fake batch dimension
    image = unloader(image)
    image.save('save_pic/epoch_%d.png'%epoch)

In [160]:
args = parse_arguments()
print("Creating generator object ...")
# create the generator object
gen = torch.nn.DataParallel(Generator(
    depth=args.depth,
    latent_size=args.latent_size
))

print("Loading the generator weights from:", args.generator_file)
# load the weights into it
gen.load_state_dict(
    torch.load(args.generator_file, map_location=str(device))
)

#path for saving the files:
save_path = args.out_dir

print("Generating scale synchronized images ...")
for img_num in tqdm(range(1, args.num_samples + 1)):
    # generate the images:
    with torch.no_grad():
        point = torch.randn(1, args.latent_size)
        point = (point / point.norm()) * (args.latent_size ** 0.5)
        ss_image = gen(point, depth=args.out_depth, alpha=1)
        # color adjust the generated image:
        ss_image = adjust_dynamic_range(ss_image)

    #save the ss_image in the directory
    save_image(ss_image, img_num)
    break
print("Generated %d images at %s" % (args.num_samples, save_path))

Creating generator object ...
  0%|          | 0/300 [00:00<?, ?it/s]Loading the generator weights from: GAN_GEN_8.pth
Generating scale synchronized images ...
torch.Size([3, 256, 256])
  0%|          | 0/300 [00:04<?, ?it/s]Generated 300 images at interp_animation_frames/

