# [lightweight GAN (by lucidrains)](https://github.com/lucidrains/lightweight-gan)
TensorFlow implementation

# Operations

In [None]:
!pip install tfa-nightly

Collecting tfa-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/9e/b6/30b7476e9b80b82f7e541077809e232154b629b65108f949fdced11f0d8e/tfa_nightly-0.13.0.dev20210213171101-cp36-cp36m-manylinux2010_x86_64.whl (706kB)
[K     |▌                               | 10kB 17.2MB/s eta 0:00:01[K     |█                               | 20kB 15.1MB/s eta 0:00:01[K     |█▍                              | 30kB 10.4MB/s eta 0:00:01[K     |█▉                              | 40kB 8.8MB/s eta 0:00:01[K     |██▎                             | 51kB 5.6MB/s eta 0:00:01[K     |██▉                             | 61kB 5.8MB/s eta 0:00:01[K     |███▎                            | 71kB 6.1MB/s eta 0:00:01[K     |███▊                            | 81kB 6.2MB/s eta 0:00:01[K     |████▏                           | 92kB 6.0MB/s eta 0:00:01[K     |████▋                           | 102kB 6.5MB/s eta 0:00:01[K     |█████                           | 112kB 6.5MB/s eta 0:00:01[K     |█████▋

In [None]:
import os
import pickle
from random import random

import numpy as np
from PIL import Image
import tensorflow as tf
import tensorflow_addons as tfa

from tqdm import tqdm

## Custom Layers

In [None]:
class GSA(tf.keras.layers.Layer):
    """Multi-head global self-attention over the positions of an NHWC feature map.

    A bias-free 1x1 convolution produces queries, keys and values for `heads`
    heads of `n_keys` channels each. Keys are softmax-normalised over the
    flattened spatial positions and queries over their channel axis; keys and
    values are contracted into a per-head context which the queries then read
    from. A final 1x1 convolution maps the result to `output_filters` channels.
    """

    def __init__(self, output_filters, n_keys=64, heads=8, **kwargs):
        super().__init__(**kwargs)

        self.output_filters = output_filters
        self.n_keys = n_keys
        self.heads = heads

        # One convolution emits q, k and v stacked along the channel axis.
        # `kernel_initializer` is the notebook-level initializer and is looked
        # up when the layer is instantiated.
        qkv_channels = n_keys * heads * 3
        self.conv1 = tf.keras.layers.Conv2D(qkv_channels, 1, kernel_initializer=kernel_initializer, use_bias=False)
        self.out_conv = tf.keras.layers.Conv2D(output_filters, 1, kernel_initializer=kernel_initializer)

    def call(self, inputs, **kwargs):
        height, width = inputs.shape[1], inputs.shape[2]
        per_head_shape = [-1, height*width, self.heads, self.n_keys]

        # (B, H, W, 3*heads*n_keys) -> three tensors of (B, H*W, heads, n_keys).
        q, k, v = [tf.reshape(part, per_head_shape) for part in tf.split(self.conv1(inputs), 3, axis=3)]

        k = tf.nn.softmax(k, axis=1)  # normalise keys over positions
        q = tf.nn.softmax(q, axis=3)  # normalise queries over channels

        # Sum over positions: one (n_keys x n_keys) context per head.
        context = tf.einsum('bihj, bihk -> bjhk', k, v)
        # Every position reads the context through its query.
        attended = tf.einsum('bihc, bjhi -> bjhc', context, q)

        attended = tf.reshape(attended, [-1, height, width, self.heads*self.n_keys])
        return self.out_conv(attended)

    def get_config(self):
        # Constructor arguments are added on top of the base layer config so
        # the layer can be re-created when a saved model is loaded.
        own_config = dict(output_filters=self.output_filters, n_keys=self.n_keys, heads=self.heads)
        return {**super().get_config(), **own_config}

class ScalingLayer(tf.keras.layers.Layer):
    """Multiply the input by a single trainable scalar initialised to 1e-3."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def build(self, *args, **kwargs):
        # One float32 scalar; starting at 1e-3 keeps the scaled branch small
        # at the beginning of training.
        self.weight = self.add_weight(
            name='scale',
            shape=(1,),
            dtype=tf.float32,
            initializer=tf.keras.initializers.Constant(1e-3),
            trainable=True,
            aggregation=tf.compat.v1.VariableAggregation.MEAN,
        )
        super().build(*args, **kwargs)

    def call(self, inputs, **kwargs):
        scaled = inputs * self.weight
        return scaled

class UpFIR2d(tf.keras.layers.Layer):
    """Resample an NHWC tensor with a fixed 2-D FIR filter.

    Exactly one of `up` / `down` must be set (enforced by an assert). The
    filter is the normalised outer product of the 1-D taps `k` and is stored
    as a non-trainable weight. With `conv=False` the layer itself changes the
    resolution by `scale` (zero insertion when upsampling, strided slicing
    when downsampling). With `conv=True` both resampling factors are 1 and
    only the padding is adjusted using `conv_k_size`.

    NOTE(review): the padding arithmetic looks like a port of StyleGAN2's
    upfirdn_2d helpers, where `conv=True` pairs the filter with a strided
    convolution applied elsewhere - confirm against that reference.
    """
    def __init__(self, scale=2, k=None, gain=1.0, up=False, down=False, conv=False, conv_k_size=None, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
        """Store the configuration and create the constant filter weight.

        Args:
            scale: integer resampling factor.
            k: 1-D filter taps; defaults to `scale` ones (a box filter).
            gain: multiplier applied to the normalised filter.
            up / down: direction of the resampling; exactly one must be True.
            conv: if True, use the padding variant that accounts for a
                convolution of kernel size `conv_k_size`.
            conv_k_size: kernel size used by the `conv=True` padding formula.
            trainable, name, dtype, dynamic, **kwargs: forwarded to the Keras
                base layer.
        """
        super(UpFIR2d, self).__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)
        
        self.scale = scale
        self.gain = gain
        self.up = up
        self.down = down
        # up and down are mutually exclusive and one of them is required.
        assert not(self.up == self.down)
        self.conv = conv
        self.conv_k_size = conv_k_size
        if k is None:
            self.k = (1,) * self.scale
        else:
            self.k = tuple(k)
        f = self._get_filter()
        # The filter is kept as a non-trainable weight so it is saved with the model.
        self.filter = self.add_weight(name='resample_kernel',
                                      shape=f.shape,
                                      dtype=tf.float32,
                                      initializer=tf.keras.initializers.Constant(f),
                                      trainable=False,
                                      aggregation=tf.VariableAggregation.MEAN)
    
    def _get_filter(self):
        """Return the 2-D filter as a float32 array of shape (len(k), len(k), 1, 1)."""
        k = np.asarray(self.k, dtype=np.float32)
        k = np.outer(k, k)
        # Normalise to unit sum before applying the gain.
        k /= np.sum(k)
        if self.up:
            # Upsampling inserts zeros, so the filter is scaled by scale**2 on top of the gain.
            k = k * (self.gain * (self.scale ** 2))
        elif self.down:
            k = k * self.gain
        return k[:,:,np.newaxis, np.newaxis]
    
    def _get_params(self):
        """Compute the up/down factors and the paddings passed to `_upfirdn2d`.

        Returns:
            dict with keys `up`, `down`, `p0` (padding before) and `p1`
            (padding after). Paddings may be negative, which means cropping.
        """
        # Base padding: filter length minus the resampling factor.
        p = self.filter.shape[0] - self.scale
        if self.conv:
            if self.up:
                p -= (self.conv_k_size - 1)
            elif self.down:
                p += (self.conv_k_size - 1)
        if self.up:
            # With conv=True no zero insertion is done here (factor 1).
            up = 1 if self.conv else self.scale
            down = 1
            p0 = (p+1)//2+self.scale-1
            p1 = p//2 + 1 if self.conv else p//2
        elif self.down:
            up = 1
            # With conv=True no striding is done here (factor 1).
            down = 1 if self.conv else self.scale
            p0 = (p+1)//2
            p1 = p//2
        return dict(up=up, down=down, p0=p0, p1=p1)
    
    def _upfirdn2d_op(self, x, upx, upy, downx, downy, px0, px1, py0, py1):
        """Upsample by zero insertion, pad/crop, filter, then downsample by striding."""
        xs = tf.shape(x)
        #x = tf.reshape(x, [-1, x.shape[1], 1, x.shape[2], 1, x.shape[3]])
        # Zero insertion: add a unit axis after H and after W, pad those axes
        # with (up - 1) zeros and fold them back into H and W.
        x = tf.reshape(x, [-1, xs[1], 1, xs[2], 1, xs[3]])
        x = tf.pad(x, [[0, 0], [0, 0], [0, upy-1], [0, 0], [0, upx-1], [0, 0]])
        x = tf.reshape(x, [-1, xs[1]*upy, xs[2]*upx, xs[3]])

        # Positive paddings are applied as zero padding, negative ones as cropping.
        x = tf.pad(x, [[0, 0], [max(py0, 0), max(py1, 0)], [max(px0, 0), max(px1, 0)], [0, 0]])
        x = x[:, max(-py0, 0):tf.shape(x)[1] - max(-py1, 0), max(-px0, 0):tf.shape(x)[2] - max(-px1, 0), :]

        # The single-channel filter is tiled so every channel is filtered independently.
        x = tf.nn.depthwise_conv2d(x, tf.tile(self.filter, [1, 1, tf.shape(x)[-1], 1]), strides=[1, 1, 1, 1], padding='VALID')
        # Downsample by keeping every `down`-th row and column.
        return x[:, ::downy, ::downx, :]

    def _upfirdn2d(self, x, up=1, down=1, p0=0, p1=0):
        """Call `_upfirdn2d_op` with the same factors and paddings on both spatial axes."""
        x = self._upfirdn2d_op(x, upx=up, upy=up, downx=down, downy=down, px0=p0, px1=p1, py0=p0, py1=p1)
        return x
    
    def compute_output_shape(self, input_shape):
        """Return the static output shape: H and W multiplied (up) or floor-divided (down) by `scale`."""
        if self.up:
            return (None, input_shape[1]*self.scale, input_shape[2]*self.scale, input_shape[3])
        elif self.down:
            return (None, input_shape[1]//self.scale, input_shape[2]//self.scale, input_shape[3])
    
    def call(self, inputs, **kwargs):
        """Resample `inputs` and attach the static output shape to the result."""
        params = self._get_params()
        x = self._upfirdn2d(inputs, **params)
        # The op works on dynamic shapes, so the static shape is restored explicitly.
        x.set_shape(self.compute_output_shape(inputs.shape))
        return x
    
    def get_config(self):
        """Return the base layer config extended with this layer's constructor arguments."""
        basse_config = super(UpFIR2d, self).get_config()
        config = dict(
            scale=self.scale,
            gain=self.gain,
            up=self.up,
            down=self.down,
            conv=self.conv,
            conv_k_size=self.conv_k_size,
            k=self.k,
        )
        return dict(list(basse_config.items()) + list(config.items()))


## Operations

In [None]:
# Kernel initializer passed to the convolution layers built by GSA and by the
# helper functions in this notebook. Alternatives that were tried are kept
# below for reference:
# kernel_initializer = tf.keras.initializers.random_normal(mean=0.0, stddev=0.02)
# kernel_initializer = tf.keras.initializers.glorot_uniform()
# In use: variance scaling with scale=2 over fan-in, drawn from a truncated normal.
kernel_initializer = tf.keras.initializers.VarianceScaling(scale=2, mode='fan_in', distribution='truncated_normal')

def global_context(input, output_channels):
    """Compute a squeeze-and-excitation style channel gate from `input`.

    The spatial positions of `input` (NHWC) are pooled with a single softmax
    attention map, and the pooled vector goes through two 1x1 convolutions
    (leaky ReLU in between) followed by a sigmoid.

    Returns:
        Tensor of shape (batch, 1, 1, output_channels) with values in (0, 1).
    """
    n_positions = input.shape[1]*input.shape[2]
    n_channels = input.shape[3]

    # One attention logit per position, normalised over all positions.
    logits = tf.keras.layers.Conv2D(1, 1)(input)
    weights = tf.reshape(logits, [-1, n_positions, 1])
    weights = tf.nn.softmax(weights, axis=1)

    # Attention-weighted sum of the features: (B, 1, HW) @ (B, HW, C) -> (B, 1, C).
    flat = tf.reshape(input, [-1, n_positions, n_channels])
    pooled = tf.matmul(weights, flat, transpose_a=True)
    pooled = tf.expand_dims(pooled, axis=1)

    # Bottleneck with at least 3 channels, then expansion to the gate width.
    hidden = tf.keras.layers.Conv2D(max(3, output_channels//2), 1, kernel_initializer=kernel_initializer)(pooled)
    hidden = tf.nn.leaky_relu(hidden, alpha=0.1)
    gate = tf.keras.layers.Conv2D(output_channels, 1, kernel_initializer=kernel_initializer)(hidden)
    return tf.nn.sigmoid(gate)

def global_self_attention(x, output_filters, n_keys=64, heads=8):
    """Apply a newly created `GSA` layer to `x` and return its output."""
    attention = GSA(output_filters, n_keys, heads)
    return attention(x)

def blur2D(input):
    """Blur every channel of an NHWC tensor with a fixed 3x3 kernel.

    The kernel is the outer product of [1, 2, 1] with itself, normalised to
    unit sum. The input is zero-padded by one pixel on each side so the spatial
    size is unchanged. The blur is a non-trainable depthwise convolution.
    """
    taps = np.array([1., 2., 1.])
    kernel = taps[:, np.newaxis, np.newaxis, np.newaxis] * taps[np.newaxis, :, np.newaxis, np.newaxis]
    kernel = kernel / np.sum(np.abs(kernel))
    # Same 3x3 kernel for every input channel: (3, 3, 1, 1) -> (3, 3, C, 1).
    kernel = np.tile(kernel, [1, 1, input.shape[-1], 1])

    padded = tf.pad(input, [[0, 0], [1, 1], [1, 1], [0, 0]])
    blur = tf.keras.layers.DepthwiseConv2D(
        3,
        padding='valid',
        use_bias=False,
        depthwise_initializer=tf.keras.initializers.Constant(kernel),
        trainable=False,
    )
    return blur(padded)

def batchnorm(x):
    """Apply a new synchronised batch-normalisation layer (momentum 0.9, epsilon 1e-4) to `x`."""
    # Plain tf.keras.layers.BatchNormalization() is the non-synchronised alternative.
    norm = tf.keras.layers.experimental.SyncBatchNormalization(momentum=0.9, epsilon=1e-4)
    return norm(x)

def glu(x):
    """Gated linear unit on axis 3: the first half of the channels times the sigmoid of the second half."""
    value, gate = tf.split(x, 2, axis=3)
    return value * tf.nn.sigmoid(gate)

def upscale(input, filters, attn_res=[]):
    """Double the spatial resolution of `input` and return `filters` channels.

    If the input height is listed in `attn_res`, a scaled self-attention
    branch is first added to the input as a residual. The main path is
    upsampling -> blur -> 3x3 convolution (2*filters channels) -> batch norm
    -> GLU, which halves the channels back to `filters`.
    """
    if input.shape[1] in attn_res:
        scaling = ScalingLayer()
        input = scaling(GSA(input.shape[-1])(input)) + input

    upsampled = tf.keras.layers.UpSampling2D()(input)
    blurred = blur2D(upsampled)
    # Twice the requested filters because the GLU consumes half as gates.
    features = tf.keras.layers.Conv2D(filters*2, 3, padding='SAME', kernel_initializer=kernel_initializer)(blurred)
    return glu(batchnorm(features))

def downscale(input, filters, attn_res=[]):
    """Halve the spatial resolution of `input` and return `filters` channels.

    If the input height is listed in `attn_res`, a scaled self-attention
    branch is first added to the input as a residual. The blurred input then
    feeds two branches whose outputs are summed: a strided 4x4 convolution
    followed by a 3x3 convolution, and an average pooling followed by a 1x1
    convolution. Every convolution is followed by a leaky ReLU (slope 0.1).
    """
    if input.shape[1] in attn_res:
        scaling = ScalingLayer()
        input = scaling(GSA(input.shape[-1])(input)) + input
    blurred = blur2D(input)

    # Convolutional branch.
    conv_branch = tf.keras.layers.Conv2D(filters, 4, 2, padding='SAME', kernel_initializer=kernel_initializer)(blurred)
    conv_branch = tf.nn.leaky_relu(conv_branch, 0.1)
    conv_branch = tf.keras.layers.Conv2D(filters, 3, padding='SAME', kernel_initializer=kernel_initializer)(conv_branch)
    conv_branch = tf.nn.leaky_relu(conv_branch, 0.1)

    # Pooling branch.
    pool_branch = tf.keras.layers.AveragePooling2D()(blurred)
    pool_branch = tf.keras.layers.Conv2D(filters, 1, kernel_initializer=kernel_initializer)(pool_branch)
    pool_branch = tf.nn.leaky_relu(pool_branch, 0.1)

    return conv_branch + pool_branch


# Models

## Generator

In [None]:
def create_generator(img_size, latent_dims=256, filter_max=512, attn_res=[]):
    """Build the generator mapping a latent vector to an image.

    The returned model chains two sub-models:

    * 'Low_res_part' maps the latent vector to the 128x128 feature map and
      also outputs the feature maps used as skip-layer-excitation (SLE)
      sources.
    * 'High_res_part' upsamples the 128x128 features to `img_size`, gating
      each stage with `global_context` of the matching SLE source, and ends
      with a 3x3 convolution to 3 channels (no output activation).

    SLE pairing (source resolution -> gated resolution):
    16 -> 128, 32 -> 256, 64 -> 512, 128 -> 1024.

    Args:
        img_size: output resolution; must be one of 256, 512 or 1024.
        latent_dims: length of the latent input vector.
        filter_max: upper bound applied to every per-resolution filter count.
        attn_res: resolutions at which `upscale` adds self-attention (only read).

    Returns:
        tf.keras.Model taking (batch, latent_dims) float32 latents.
    """
    filter_num = {
        8    : 512,
        16   : 512,
        32   : 256,
        64   : 128,
        128  :  64,
        256  :  32,
        512  :  16,
        1024 :   8
    }
    sle_resolution_pairs = {16 : 128, 32 : 256, 64 : 512, 128 : 1024}

    def n_filters(resolution):
        # Channel count of the feature map at `resolution`, capped by filter_max.
        return min(filter_num[resolution], filter_max)

    # SLE sources that are needed for this output size, in output order.
    sle_sources = [16, 32]
    if img_size > 256:
        sle_sources.append(64)
    if img_size > 512:
        sle_sources.append(128)

    # ---- Low-resolution part: latent -> 128x128 features ----
    inp = tf.keras.Input(shape=(latent_dims,), dtype=tf.float32)
    x = tf.reshape(inp, [-1, 1, 1, latent_dims])
    with tf.keras.backend.name_scope('Initial_conv'):
        x = tf.keras.layers.Conv2DTranspose(latent_dims*2, 4, kernel_initializer=kernel_initializer)(x)
        x = batchnorm(x)
        x = glu(x)
    x = tf.nn.l2_normalize(x, axis=3)

    sle_feature = []
    for resolution in (8, 16, 32, 64, 128):
        with tf.keras.backend.name_scope('Upsample_to_%dx%d' % (resolution, resolution)):
            x = upscale(x, filters=n_filters(resolution), attn_res=attn_res)
            if resolution in sle_sources:
                sle_feature.append(x)
    lr_model = tf.keras.Model(inp, [x]+sle_feature, name='Low_res_part')

    # ---- High-resolution part: 128x128 features -> image ----
    feature128 = tf.keras.Input(shape=x.shape[1:])
    sle_feature_input = [tf.keras.Input(shape=f.shape[1:]) for f in sle_feature]
    sle_input_at = dict(zip(sle_sources, sle_feature_input))

    def excite(features, source_res):
        # Gate `features` with the SLE source that belongs to their resolution.
        # Fix 1: the 512 and 1024 stages used to read sle_feature_input[0]
        # (the 16x16 source) instead of the 64x64 / 128x128 sources.
        # Fix 2: the gate width is capped by filter_max like the features it
        # multiplies, so the channel counts match for small filter_max too.
        target_res = sle_resolution_pairs[source_res]
        return features * global_context(sle_input_at[source_res], n_filters(target_res))

    x = excite(feature128, 16)
    with tf.keras.backend.name_scope('Upsample_to_256x256'):
        x = upscale(x, filters=n_filters(256), attn_res=attn_res)
        x = excite(x, 32)
    if img_size > 256:
        with tf.keras.backend.name_scope('Upsample_to_512x512'):
            x = upscale(x, filters=n_filters(512), attn_res=attn_res)
            x = excite(x, 64)
    if img_size > 512:
        with tf.keras.backend.name_scope('Upsample_to_1024x1024'):
            x = upscale(x, filters=n_filters(1024), attn_res=attn_res)
            x = excite(x, 128)
    x = tf.keras.layers.Conv2D(3, 3, padding='SAME', kernel_initializer=kernel_initializer)(x)
    # x = tf.nn.tanh(x)
    hr_model = tf.keras.Model([feature128]+sle_feature_input, x, name='High_res_part')

    # Chain both parts; the nested sub-models are the first two non-input layers.
    latent = tf.keras.Input(shape=(latent_dims,), dtype=tf.float32)
    return tf.keras.Model(latent, hr_model(lr_model(latent)))

## simple decoder

In [None]:
def simple_decoder(x, n_upsample=4):
    """Decode a feature map with `n_upsample` stages of upsampling -> 3x3 conv -> GLU.

    Each stage doubles the resolution. Intermediate stages halve the channel
    count; the last stage outputs 3 channels.
    """
    final_stage = n_upsample - 1
    for stage in range(n_upsample):
        if stage == final_stage:
            filters = 3
        else:
            filters = x.shape[-1]//2
        x = tf.keras.layers.UpSampling2D()(x)
        # The convolution emits 2*filters channels; the GLU keeps `filters` of them.
        x = tf.keras.layers.Conv2D(filters*2, 3, padding='same', kernel_initializer=kernel_initializer)(x)
        x = glu(x)
    return x

## Discriminator

In [None]:
def create_discriminator(img_size, filter_max=512, attn_res=[]):
    """Build the discriminator for `img_size` x `img_size` RGB images.

    Args:
        img_size: input resolution; must be one of 256, 512 or 1024.
        filter_max: accepted for symmetry with the generator; not read here.
        attn_res: resolutions at which `downscale` adds self-attention.

    Returns:
        tf.keras.Model whose outputs are
        [logits, lr_logits] + the intermediate feature maps whose height is
        at most 16 (collected in the order they are produced).
    """
    # img_size must be in [256, 512, 1024]
    filter_num = {8    : 512,
                  16   : 256,
                  32   : 128,
                  64   :  64,
                  128  :  32,
                  256  :  16,
                  512  :   8,
                  1024 :   8}

    inp = tf.keras.Input(shape=(img_size, img_size, 3), dtype=tf.float32)

    # Main branch: first bring inputs larger than 256 down to 256x256.
    x = inp
    n_pre_downsample = int(np.log2(img_size//256))
    for _ in range(n_pre_downsample):
        x = blur2D(x)
        x = tf.keras.layers.Conv2D(filter_num[x.shape[1]//2], 4, 2, padding='SAME', kernel_initializer=kernel_initializer)(x)
        x = tf.nn.leaky_relu(x, 0.1)

    # Five downscale blocks; the small feature maps are also returned.
    mid_feature = []
    for _ in range(5):
        x = downscale(x, filter_num[x.shape[1]//2], attn_res=attn_res)
        if x.shape[1] <= 16:
            mid_feature.append(x)
    x = tf.keras.layers.Conv2D(filter_num[8], 1, kernel_initializer=kernel_initializer)(x)
    x = tf.nn.leaky_relu(x, 0.1)
    logits = tf.keras.layers.Conv2D(1, 4, kernel_initializer=kernel_initializer)(x)

    # Low-resolution branch: judge a filtered 32x32 copy of the input.
    small_inp = UpFIR2d(scale=img_size//32, k=[1, 3, 3, 1], down=True)(inp)
    small_inp.set_shape([None, 32, 32, 3])
    x = tf.keras.layers.Conv2D(64, 3, padding='SAME', kernel_initializer=kernel_initializer)(small_inp)
    x = x + ScalingLayer()(global_self_attention(x, 64))
    x = downscale(x, 32)
    x = x + ScalingLayer()(global_self_attention(x, 32))
    x = tfa.layers.AdaptiveAveragePooling2D((4, 4))(x)
    lr_logits = tf.keras.layers.Conv2D(1, 4, kernel_initializer=kernel_initializer)(x)

    return tf.keras.Model(inp, [logits, lr_logits]+mid_feature)

## Differentiable Augmentation

In [None]:
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
# https://github.com/mit-han-lab/data-efficient-gans

def DiffAugment(x, policy=[], channels_first=False):
    """Apply the augmentations named in `policy` to the batch `x`.

    Args:
        x: image batch, NHWC unless `channels_first` is True (then NCHW).
        policy: iterable of keys of `AUGMENT_FNS`; an empty policy returns `x` unchanged.
        channels_first: transpose to NHWC before augmenting and back afterwards.
    """
    if not policy:
        return x

    if channels_first:
        x = tf.transpose(x, [0, 2, 3, 1])
    for policy_name in policy:
        for augment in AUGMENT_FNS[policy_name]:
            x = augment(x)
    if channels_first:
        x = tf.transpose(x, [0, 3, 1, 2])
    return x


def rand_brightness(x):
    """Add one random offset per sample, drawn uniformly from [-0.5, 0.5), to all pixels."""
    n_samples = tf.shape(x)[0]
    offset = tf.random.uniform([n_samples, 1, 1, 1]) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each pixel's deviation from its channel mean by a per-sample factor in [0, 2)."""
    n_samples = tf.shape(x)[0]
    factor = tf.random.uniform([n_samples, 1, 1, 1]) * 2
    channel_mean = tf.reduce_mean(x, axis=3, keepdims=True)
    return (x - channel_mean) * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a per-sample factor in [0.5, 1.5)."""
    n_samples = tf.shape(x)[0]
    factor = tf.random.uniform([n_samples, 1, 1, 1]) + 0.5
    sample_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Shift every sample by a random integer offset along axis 1 and axis 2.

    The offsets are drawn per sample from [-s, s], where s is the axis length
    times `ratio`, rounded to the nearest integer. Content shifted in from
    outside the image is zero.
    """
    batch_size = tf.shape(x)[0]
    image_size = tf.shape(x)[1:3]
    max_shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)

    # One offset per sample and axis (axis 1 first, then axis 2).
    offset_1 = tf.random.uniform([batch_size, 1], -max_shift[0], max_shift[0] + 1, dtype=tf.int32)
    offset_2 = tf.random.uniform([batch_size, 1], -max_shift[1], max_shift[1] + 1, dtype=tf.int32)

    def source_index(length, offset):
        # Index into the zero-padded axis (hence the +1); clipping sends
        # out-of-range positions to the zero border at 0 or length + 1.
        positions = tf.expand_dims(tf.range(length, dtype=tf.int32), 0)
        return tf.clip_by_value(positions + offset + 1, 0, length + 1)

    index_1 = source_index(image_size[0], offset_1)
    index_2 = source_index(image_size[1], offset_2)

    border = [[0, 0], [1, 1], [0, 0], [0, 0]]
    swap_spatial = [0, 2, 1, 3]

    def shift_axis_1(images, index):
        # Pad axis 1 with one zero on each side, then gather per sample.
        return tf.gather_nd(tf.pad(images, border), tf.expand_dims(index, -1), batch_dims=1)

    x = shift_axis_1(x, index_1)
    # Axis 2 is shifted by swapping it into axis 1 and swapping back afterwards.
    x = tf.transpose(shift_axis_1(tf.transpose(x, swap_spatial), index_2), swap_spatial)
    return x


def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per sample of an NHWC batch.

    The rectangle measures `ratio` times the image size along each spatial
    axis (rounded to the nearest integer). Its position is random, and parts
    that fall outside the image are clipped to the border.
    """
    batch_size = tf.shape(x)[0]
    # Static spatial size; unlike rand_translation this reads x.shape, so the
    # spatial dimensions must be known when the function is traced.
    image_size = x.shape[1:3]
    cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
    # Random centre of the rectangle for every sample, one draw per axis.
    offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32)
    offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32)
    # Index grids covering every pixel of the rectangle for every sample.
    grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32),
                                             tf.range(cutout_size[0], dtype=tf.int32),
                                             tf.range(cutout_size[1], dtype=tf.int32), indexing='ij')
    # (sample, row, column) coordinates of the rectangle, centred on the offsets.
    cutout_grid = tf.stack([grid_batch, grid_x + offset_x - cutout_size[0] // 2, grid_y + offset_y - cutout_size[1] // 2], axis=-1)
    mask_shape = tf.stack([batch_size, image_size[0], image_size[1]])
    # Clamp coordinates that fall outside the image onto its border.
    cutout_grid = tf.maximum(cutout_grid, 0)
    cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3]))
    # scatter_nd accumulates duplicate (clamped) coordinates, so 1 - count can
    # go negative; the maximum keeps the mask at 0 inside the rectangle and 1 elsewhere.
    mask = tf.maximum(1 - tf.scatter_nd(cutout_grid, tf.ones([batch_size, cutout_size[0], cutout_size[1]], dtype=tf.float32), mask_shape), 0)
    x = x * tf.expand_dims(mask, axis=3)
    return x


# Maps each DiffAugment policy name to the augmentation functions that are
# applied for it, in list order.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

## Define Trainer

In [None]:
class GAN:
    def __init__(
        self,
        data_dir,
        save_dir,
        batchsize=8,
        grad_accumulation=4,
        latent_dims=256,
        lr=2e-4,
        ttur_mul=1.0,
        beta_1=0.5,
        augment_policy=['translation', 'cutout'],
        d_steps=2,
        gp_type='r1',
        gp_target=0.0,
        gp_weight=10.0,
        attn_res=[],
        seed=12345):
        """Prepare output folders, data, models, loss history and optimizers.

        Args:
            data_dir: path of the pickled dataset opened by `load_data`.
            save_dir: root folder for results, models, optimizer states and
                the hyper-parameter summary `hparams_info.txt`.
            batchsize: stored batch size; the dataset yields batches of
                `batchsize * grad_accumulation` images.
            grad_accumulation: multiplier applied to the dataset batch size.
            latent_dims: length of the latent vectors.
            lr: Adam learning rate of the generator.
            ttur_mul: factor applied to `lr` for the discriminator optimizer.
            beta_1: Adam beta_1 of both optimizers.
            augment_policy: DiffAugment policy names (only read).
            d_steps: number of discriminator updates per training step.
            gp_type: 'r1', 'r2' or 'wgan'; selects `self.gradient_penalty`.
                Any other value leaves that attribute unset.
            gp_target: target gradient norm of the penalty.
            gp_weight: stored weight of the penalty.
            attn_res: resolutions with self-attention, passed to the model builders.
            seed: global TensorFlow random seed.
        """
        self.save_dir = save_dir
        self.batchsize = batchsize
        self.grad_accumulation = grad_accumulation
        self.batchsize_per_replica = self.batchsize
        self.latent_dims = latent_dims
        self.lr = lr
        self.ttur_mul = ttur_mul
        self.beta_1 = beta_1
        self.d_steps = d_steps
        self.gp_target = gp_target
        self.gp_weight = gp_weight
        self.policy = augment_policy
        self.attn_res = attn_res
        tf.random.set_seed(seed)

        # Output folders and a plain-text record of the hyper-parameters.
        os.makedirs(os.path.join(self.save_dir, 'result', 'generated'), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, 'result', 'reconstruction'), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, 'result', 'part_reconstruction'), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, 'models'), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, 'optimizers'), exist_ok=True)
        with open(os.path.join(save_dir, 'hparams_info.txt'), 'w') as f:
            f.write(f'batchsize : {batchsize}\nlatent_dims : {latent_dims}\nlearning_rater : {lr}\nttur_mul : {ttur_mul}\nAdam_beta_1 : {beta_1}\ndiff_augment_policy : {augment_policy}\ngp_type : {gp_type}\ngp_target : {gp_target}\ngp_weight : {gp_weight}\nattn_res : {attn_res}\nseed : {seed}\n\n')

        print('Load dataset...')
        self.dataset, self.data_shape, self.test_data = self.load_data(data_dir, self.batchsize*self.grad_accumulation)
        self.test_z = tf.random.normal(shape=(25, self.latent_dims))
        # Reference tiles of the fixed test images at both reconstruction resolutions.
        Image.fromarray(self.batch2tile(tf.image.resize(self.test_data, (128, 128)), 8, 8).numpy()).save(os.path.join(self.save_dir, 'result', 'reconstruction', 'raw.jpg'))
        Image.fromarray(self.batch2tile(tf.image.resize(self.test_data, (256, 256)), 8, 8).numpy()).save(os.path.join(self.save_dir, 'result', 'part_reconstruction', 'raw.jpg'))

        self.model_names = ('G', 'D', 'whole_decoder', 'part_decoder')
        self.G, self.D, self.dec, self.part_dec = self.initialize_models()
        self.G_params = self.G.trainable_weights
        # The two decoders are trained together with the discriminator.
        self.D_params = self.D.trainable_weights + self.dec.trainable_weights + self.part_dec.trainable_weights

        self.loss_names = ('real_score', 'fake_score', 'real_score_32', 'fake_score_32', 'whole_rec', 'part_rec', 'gp')
        try:
            with open(os.path.join(self.save_dir, 'result', 'losses.pkl'), 'rb') as file:
                self.losses = pickle.load(file)
                # Reuse the stored test latents so snapshots stay comparable.
                self.test_z = self.losses['test_z']
        except Exception:
            # No usable history (missing, unreadable or incomplete file):
            # start from scratch. `Exception` instead of a bare `except`
            # lets KeyboardInterrupt and SystemExit propagate.
            self.losses = {key : [] for key in self.loss_names}
            self.losses['iterations'] = 0
            self.losses['d_iterations'] = 0
            self.losses['test_z'] = self.test_z

        penalties = {
            'r1': self.r1_gradient_penalty,
            'r2': self.r2_gradient_penalty,
            'wgan': self.wgangp_gradient_penalty,
        }
        if gp_type in penalties:
            self.gradient_penalty = penalties[gp_type]

        self.G_opt, self.D_opt = self.init_optimizers()
        self.it = tf.Variable(self.losses['iterations'], trainable=False, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        self.d_it = tf.Variable(self.losses['d_iterations'], trainable=False, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        self.G_opt.iterations.assign(self.losses['iterations'])
        # NOTE(review): the discriminator optimizer is also restored from
        # 'iterations', not 'd_iterations' - confirm this is intended.
        self.D_opt.iterations.assign(self.losses['iterations'])
    
    # def load_data(self, dir, batchsize=8):
    #     with open(dir, 'rb') as file:
    #         data = pickle.load(file)
    #     dataset = tf.data.Dataset.from_tensor_slices(data)
    #     proc_fn = lambda x : tf.image.random_flip_left_right((tf.cast(x, tf.float32) / 127.5) - 1.0)
    #     dataset = dataset.map(proc_fn).shuffle(4096, reshuffle_each_iteration=True).repeat().batch(batchsize, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
    #     return dataset, data.shape, tf.cast(data[:64], tf.float32)/127.5 - 1.0
 
    def load_data(self, dir, batchsize=8):
        """Load the pickled JPEG dataset and build the training pipeline.

        Args:
            dir: path of a pickle file holding a sequence of encoded JPEG images.
            batchsize: number of images per dataset batch.

        Returns:
            (dataset, test_shape, test_data): an endlessly repeating, shuffled
            and batched dataset of randomly flipped float32 images scaled to
            [-1, 1]; the shape of the test tensor; and the first 64 images
            decoded and scaled the same way (without flipping).
        """
        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(dir, 'rb') as file:
            data = pickle.load(file)

        def proc_fn(x):
            # Decode, flip at random, then map uint8 [0, 255] to float32 [-1, 1].
            decoded = tf.image.decode_jpeg(x)
            flipped = tf.image.random_flip_left_right(decoded)
            return tf.cast(flipped, tf.float32) / 127.5 - 1.0

        dataset = tf.data.Dataset.from_tensor_slices(data)
        dataset = dataset.map(proc_fn)
        dataset = dataset.shuffle(4096, reshuffle_each_iteration=True)
        dataset = dataset.repeat()
        dataset = dataset.batch(batchsize, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        # Fixed evaluation images: the first 64 entries, stacked into one batch.
        decoded_head = [tf.image.decode_jpeg(x)[np.newaxis, :, :, :] for x in data[:64]]
        test_data = tf.cast(tf.concat(decoded_head, axis=0), tf.float32) / 127.5 - 1.0
        return dataset, test_data.shape, test_data
 
    def batch2tile(self, batch, row, col):
        """Arrange a batch of images into one uint8 image of `row` x `col` tiles.

        Args:
            batch: tensor of shape (row*col, H, W, C), nominally in [-1, 1].
            row: number of tile rows.
            col: number of tile columns.

        Returns:
            uint8 tensor of shape (H*row, W*col, C) with values in [0, 255].
        """
        shape = batch.shape
        # (row, col, H, W, C) -> (row, H, col, W, C) so the reshape below lays
        # the tiles out row by row.
        tiles = tf.reshape(batch, [row, col]+shape[1:])
        tiles = tf.transpose(tiles, [0, 2, 1, 3, 4])
        img = tf.reshape(tiles, [shape[1]*row, shape[2]*col, -1])
        # Clip before casting: the generator has no output activation, so its
        # images can leave [-1, 1], and an out-of-range float cast to uint8
        # would wrap around instead of saturating.
        pixels = tf.clip_by_value((img + 1.0)*127.5, 0.0, 255.0)
        return tf.cast(pixels, tf.uint8)
 
    def initialize_models(self):
        """Load the four models from `<save_dir>/models`, or build new ones.

        The saved models are used only if the `.h5` file of every name in
        `self.model_names` exists; otherwise all models are created from scratch.

        Returns:
            (G, D, dec, part_dec): generator, discriminator, decoder for the
            last discriminator feature map and decoder for the second to last.

        Raises:
            AssertionError: if the dataset resolution is not 256, 512 or 1024.
        """
        assert self.data_shape[1] in [256, 512, 1024]
        model_paths = [os.path.join(self.save_dir, 'models', name+'.h5') for name in self.model_names]
        if all(os.path.exists(path) for path in model_paths):
            print('Load models...')
            custom_objects = {'ScalingLayer':ScalingLayer, 'GSA':GSA, 'UpFIR2d': UpFIR2d}
            return tuple(
                tf.keras.models.load_model(path, custom_objects=custom_objects, compile=False)
                for path in model_paths
            )

        print('Initialize models...')
        # Pass the configured latent size on to the generator; it was built
        # with its own default before, which broke any latent_dims != 256.
        G = create_generator(self.data_shape[1], latent_dims=self.latent_dims, attn_res=self.attn_res)
        D = create_discriminator(self.data_shape[1], attn_res=self.attn_res)
        d_output_shape = D.output_shape
        # Decoder of the last feature map returned by the discriminator.
        dec_input = tf.keras.Input(shape=(8, 8, d_output_shape[-1][-1]), dtype=tf.float32)
        dec = tf.keras.Model(dec_input, simple_decoder(dec_input))
        # Decoder of 8x8 crops of the second to last feature map.
        partdec_input = tf.keras.Input(shape=(8, 8, d_output_shape[-2][-1]), dtype=tf.float32)
        part_dec = tf.keras.Model(partdec_input, simple_decoder(partdec_input))
        return G, D, dec, part_dec
    
    def init_optimizers(self):
        """Create both Adam optimizers and restore their slots when saved states exist.

        Returns:
            (G_opt, D_opt). The discriminator learning rate is `lr * ttur_mul`.
        """
        print('Initialize optimizers...')
        G_opt = tf.keras.optimizers.Adam(self.lr, beta_1=self.beta_1)
        D_opt = tf.keras.optimizers.Adam(self.lr*self.ttur_mul, beta_1=self.beta_1)

        restorable = (('G_opt', self.G_params, G_opt), ('D_opt', self.D_params, D_opt))
        for f_name, variables, opt in restorable:
            state_path = os.path.join(self.save_dir, 'optimizers', f_name+'.pkl')
            if not os.path.exists(state_path):
                continue
            print('Load %s state...'%f_name)
            with open(state_path, 'rb') as file:
                opt_state = pickle.load(file)
            # Slots are re-created under the optimizer name that was stored with them.
            with tf.name_scope(opt_state['opt_name']):
                saved_slots = opt_state['weights']
                for v in variables:
                    if v.name not in saved_slots.keys():
                        continue
                    for slot in saved_slots[v.name].keys():
                        initializer = tf.initializers.Constant(saved_slots[v.name][slot])
                        opt.add_slot(v, slot, initializer=initializer)
        return G_opt, D_opt
    
    # @tf.function
    def sampling_image(self):
        """Produce the three image batches that are saved as snapshots.

        Returns:
            (gen_img, rec, part_rec): images generated from the fixed test
            latents, decoder reconstructions of the test images, and
            reconstructions stitched together from the four decoded quarters.
        """
        gen_img = self.random_generate(z=self.test_z)

        _, _, f16, f8 = self.D(self.test_data, training=True)
        rec = self.dec(f8, training=True)

        # Decode each quarter of the larger feature map separately.
        quarters = self.split_quarter(f16)
        part = [self.part_dec(quarters[i], training=True) for i in tf.range(4)]
        # Quarters 0/1 form the upper half, 2/3 the lower half.
        upper_half = tf.concat([part[0], part[1]], axis=2)
        lower_half = tf.concat([part[2], part[3]], axis=2)
        part_rec = tf.concat([upper_half, lower_half], axis=1)

        return gen_img, rec, part_rec
    
    def save_snapshot_image(self, epoch:int):
        """Save the sampled images as square tile grids named `<epoch, 6 digits>.jpg`.

        One file is written to each of the `generated`, `reconstruction` and
        `part_reconstruction` folders under `<save_dir>/result`.
        """
        file_name = '%06d.jpg'%epoch
        folders = ('generated', 'reconstruction', 'part_reconstruction')
        for img, folder in zip(self.sampling_image(), folders):
            # Square grid: side length is the integer square root of the batch size.
            side = int(np.sqrt(img.shape[0]))
            tile = self.batch2tile(img, side, side).numpy()
            Image.fromarray(tile).save(os.path.join(self.save_dir, 'result', folder, file_name))
 
    def save(self):
        for name, model in zip(self.model_names, [self.G, self.D, self.dec, self.part_dec]):
            model.save(os.path.join(self.save_dir, 'models', name+'.h5'))
        if self.losses['iterations'] % 10000 == 0:
            os.makedirs(os.path.join(self.save_dir, 'models', str(self.losses['iterations'])), exist_ok=True)
            for name, model in zip(self.model_names, [self.G, self.D, self.dec, self.part_dec]):
                model.save(os.path.join(self.save_dir, 'models', str(self.losses['iterations']), name+'.h5'))
        for opt, val, f_name in zip([self.G_opt, self.D_opt], [self.G_params, self.D_params], ['G_opt', 'D_opt']):
            slot_names = opt.get_slot_names()
            w = {}
            for v in val:
                w[v.name] = {}
                for slot in slot_names:
                    w[v.name][slot] = opt.get_slot(v, slot).numpy()
            opt_weight = {'opt_name':opt._name, 'weights' : w}
            with open(os.path.join(self.save_dir, 'optimizers', f_name+'.pkl'), 'wb') as file:
                pickle.dump(opt_weight, file)
                file.flush()
        with open(os.path.join(self.save_dir, 'result', 'losses.pkl'), 'wb') as file:
            pickle.dump(self.losses, file)
            file.flush()
    
    def split_quarter(self, x):
        """Cut an NHWC batch into its four spatial quarters.

        Returns:
            Tensor of shape (4, batch, H//2, W//2, C); index 0 is the
            top-left quarter, 1 top-right, 2 bottom-left, 3 bottom-right.
        """
        half_h = x.shape[1]//2
        half_w = x.shape[2]//2
        channels = x.shape[3]
        # Split H and W into (2, half) each, then move the two "which half"
        # axes to the front and merge them into one axis of length 4.
        blocks = tf.reshape(x, [-1, 2, half_h, 2, half_w, channels])
        blocks = tf.transpose(blocks, [1, 3, 0, 2, 4, 5])
        return tf.reshape(blocks, [4, -1, half_h, half_w, channels])
    
    def random_crop(self, feature, img):
        """Pick the same random quarter of `feature` and of `img` resized to 256x256.

        Returns:
            (feature_quarter, img_quarter) for one quarter index drawn from 0..3.
        """
        quarter = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        resized = tf.image.resize(img, (256, 256))
        img_quarter = self.split_quarter(resized)[quarter]
        feature_quarter = self.split_quarter(feature)[quarter]
        return feature_quarter, img_quarter
    
    def get_latent_z(self, z=None, n=None, normalize=True):
        """Return latent vectors, sampling `n` standard-normal ones if `z` is None.

        With `normalize=True` every vector is divided by its L2 norm.
        """
        if z is None:
            z = tf.random.normal(shape=(n, self.latent_dims), dtype=tf.float32)
        if not normalize:
            return z
        return z/tf.norm(z, axis=-1, keepdims=True)
 
    def random_generate(self, z=None, n=None, normalize_latent=True, training=True):
        """Generate images from the given latents `z`, or from `n` freshly sampled ones."""
        latent = self.get_latent_z(z=z, n=n, normalize=normalize_latent)
        return self.G(latent, training=training)
    
    def mixing_generate(self, z_list=None, n=None, normalize_latent=True, training=True):
        """Generate images whose low-resolution features are mixed from two latents.

        The generator's low-resolution sub-model is run on two latent batches;
        the first `idx` outputs come from the first batch and the rest from
        the second, where `idx` is drawn at random. The mix is fed to the
        high-resolution sub-model.

        Args:
            z_list: two latent batches, or None to sample two batches of size `n`.
            n: batch size used when sampling.
            normalize_latent: forwarded to `get_latent_z`.
            training: forwarded to both sub-models.
        """
        # Layers 1 and 2 of the generator are its two nested sub-models.
        low_res = self.G.get_layer(index=1)
        high_res = self.G.get_layer(index=2)

        if z_list is None:
            latents = [self.get_latent_z(n=n, normalize=normalize_latent) for _ in range(2)]
        else:
            latents = [self.get_latent_z(z, normalize=normalize_latent) for z in z_list]

        features_a = low_res(latents[0], training=training)
        features_b = low_res(latents[1], training=training)

        # Random split point in the list of low-resolution outputs.
        idx = tf.random.uniform(shape=(), maxval=len(low_res.output_shape), dtype=tf.int32)
        mixed = features_a[:idx] + features_b[idx:]
        return high_res(mixed, training=training)
 
    def r1_gradient_penalty(self, real, fake):
        """Per-sample gradient penalty evaluated on the real images.

        Returns (||g|| - gp_target)**2 per sample, where g is the gradient of
        both discriminator logits with respect to `real`, taken through the
        augmentation. `fake` is not used.
        """
        n_values = real.shape[1]*real.shape[2]*real.shape[3]
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(real)
            augmented = DiffAugment(real, self.policy)
            logits, logits32, _, _ = self.D(augmented, training=True)
        grad = tape.gradient([logits, logits32], real)
        # L2 norm of each sample's flattened gradient.
        grad_norm = tf.norm(tf.reshape(grad, [-1, n_values]), axis=-1)
        return (grad_norm - self.gp_target)**2
    
    def r2_gradient_penalty(self, real, fake):
        """Per-sample gradient penalty evaluated on the generated images.

        Returns (||g|| - gp_target)**2 per sample, where g is the gradient of
        both discriminator logits with respect to `fake`, taken through the
        augmentation. `real` only supplies the image shape.
        """
        n_values = real.shape[1]*real.shape[2]*real.shape[3]
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(fake)
            augmented = DiffAugment(fake, self.policy)
            logits, logits32, _, _ = self.D(augmented, training=True)
        grad = tape.gradient([logits, logits32], fake)
        # L2 norm of each sample's flattened gradient.
        grad_norm = tf.norm(tf.reshape(grad, [-1, n_values]), axis=-1)
        return (grad_norm - self.gp_target)**2
    
    def wgangp_gradient_penalty(self, real, fake):
        bs = real.shape
        alpha = tf.random.uniform(shape=[bs[0], 1, 1, 1])
        image = real * alpha + fake * (1 - alpha)
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(image)
            aug_image = DiffAugment(image, self.policy)
            d_out, d_out32, _, _  = self.D(aug_image, training=True)
        grad = tape.gradient([d_out, d_out32], image)
        gp = tf.norm(tf.reshape(grad, [-1, bs[1]*bs[2]*bs[3]]), axis=-1)
        gp = (gp - self.gp_target)**2
        return gp
    
    def hinge_loss(self, real_logits, fake_logits):
        """Per-sample hinge loss on the discriminator logits.

        Real logits contribute ``relu(1 + real)`` and fake logits
        ``relu(1 - fake)``, so the loss is lowest when real logits are at or
        below -1 and fake logits at or above +1. Each term is averaged over
        axes 1, 2 and 3 before the two are added.

        Args:
            real_logits: Discriminator output for real images.
            fake_logits: Discriminator output for generated images.

        Returns:
            One loss value per sample.
        """
        reduce_axes = [1, 2, 3]
        loss_on_real = tf.reduce_mean(tf.nn.relu(1.0 + real_logits), axis=reduce_axes)
        loss_on_fake = tf.reduce_mean(tf.nn.relu(1.0 - fake_logits), axis=reduce_axes)
        return loss_on_real + loss_on_fake
    
    def reconstruction_loss(self, rec, real):
        """Per-sample mean squared error between a reconstruction and its target.

        Args:
            rec: Reconstructed batch.
            real: Target batch.

        Returns:
            The squared difference averaged over axes 1, 2 and 3.
        """
        residual = real - rec
        return tf.reduce_mean(residual**2, axis=[1, 2, 3])
    
    def _train_step(self, real_img):
        """Run ``self.d_steps`` discriminator updates and one generator update.

        Args:
            real_img: Batch of real images; its leading dimension sets the
                number of fake images generated.

        Returns:
            Tuple of seven scalar means, in this order: real_score,
            fake_score, real_score32, fake_score32, rec_loss, part_rec_loss,
            gp. ``fake_score``/``fake_score32`` are the values from the
            generator step; the other entries come from the last
            discriminator step.
        """
        bs = real_img.shape[0]
        # The generator tape records only G's parameters. It is opened here to
        # capture image generation and re-entered after the discriminator
        # updates to record the generator loss.
        with tf.GradientTape(watch_accessed_variables=False) as g_tape:
            g_tape.watch(self.G_params)
            # fake_img = self.mixing_generate(n=bs)
            fake_img = self.random_generate(n=bs)
            fake_aug = DiffAugment(fake_img, self.policy)
        for _ in range(self.d_steps):
            with tf.GradientTape() as d_tape:
                real_aug = DiffAugment(real_img, self.policy)
                fake_score, fake_score32, _, _ = self.D(fake_aug, training=True)
                # For real images the discriminator's two extra outputs
                # (f16, f8) feed the reconstruction decoders below.
                real_score, real_score32, f16, f8 = self.D(real_aug, training=True)
 
                d_adv_loss = self.hinge_loss(real_score, fake_score)
                d_adv_loss_32 = self.hinge_loss(real_score32, fake_score32)
 
                # Whole-image reconstruction from f8, compared against the
                # augmented real image resized to 128x128.
                rec_loss = self.reconstruction_loss(self.dec(f8, training=True), tf.image.resize(real_aug, (128, 128)))
                # Partial reconstruction: random_crop returns a feature crop
                # and the image crop used as its target.
                crop_feature, crop_img = self.random_crop(f16, real_aug)
                part_rec_loss = self.reconstruction_loss(self.part_dec(crop_feature, training=True), crop_img)
 
                d_loss = d_adv_loss + d_adv_loss_32 + rec_loss + part_rec_loss
 
                # Sum over this batch and divide by self.batchsize
                # (presumably the global batch size -- confirm against the
                # dataset batching).
                d_loss = tf.reduce_sum(d_loss) * (1. / self.batchsize)
            d_grad = d_tape.gradient(d_loss, self.D_params, unconnected_gradients=tf.UnconnectedGradients.ZERO)
 
            # Gradient penalty: its gradient is added to the discriminator
            # update only on every 4th discriminator iteration.
            if self.d_it % 4 == 0:
                with tf.GradientTape() as d_tape:
                    gp = self.gradient_penalty(real_img, fake_img)
                    gp_loss = self.gp_weight * tf.reduce_sum(gp) * (1. / self.batchsize)
                gp_grad = d_tape.gradient(gp_loss, self.D_params, unconnected_gradients=tf.UnconnectedGradients.ZERO)
                d_grad = [x+y for x, y in zip(d_grad, gp_grad)]
            else:
                # The penalty is still evaluated so `gp` exists for the
                # returned statistics (presumably also so both branches define
                # it when this conditional is traced -- confirm).
                gp = self.gradient_penalty(real_img, fake_img)
            self.D_opt.apply_gradients(zip(d_grad, self.D_params))
            self.d_it.assign_add(1)
        # Resume recording on the generator tape, scoring the same augmented
        # fakes with the updated discriminator.
        with g_tape:
            fake_score, fake_score32, _, _ = self.D(fake_aug, training=True)
            fake_score, fake_score32 = tf.reduce_mean(fake_score, axis=[1, 2, 3]), tf.reduce_mean(fake_score32, axis=[1, 2, 3])
 
            # The generator minimizes the discriminator's score on fakes.
            g_loss = fake_score + fake_score32
            g_loss = tf.reduce_sum(g_loss) * (1. / self.batchsize)
        g_grad = g_tape.gradient(g_loss, self.G_params, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        self.G_opt.apply_gradients(zip(g_grad, self.G_params))
        return tuple(tf.reduce_mean(x) for x in (real_score, fake_score, real_score32, fake_score32, rec_loss, part_rec_loss, gp))

    def _train_step_gradacc(self, real_img):
        """Training step that accumulates gradients over several micro-batches.

        ``real_img`` is split into ``self.grad_accumulation`` equal parts along
        axis 0. Discriminator gradients are summed over the parts and applied
        once; generator gradients are then summed over the same number of
        freshly generated batches and applied once.

        NOTE(review): ``train_step`` calls ``_train_step``, not this method,
        in the code visible here.

        Args:
            real_img: Batch of real images to split into micro-batches.

        Returns:
            Tuple of seven scalars averaged over the accumulation steps, in
            this order: real_score, fake_score, real_score32, fake_score32,
            rec_loss, part_rec_loss, gp.
        """
        # Running sums of the reported statistics.
        total_real_score = tf.constant(0, dtype=tf.float32)
        total_fake_score = tf.constant(0, dtype=tf.float32)
        total_real_score32 = tf.constant(0, dtype=tf.float32)
        total_fake_score32 = tf.constant(0, dtype=tf.float32)
        total_rec_loss = tf.constant(0, dtype=tf.float32)
        total_part_rec_loss = tf.constant(0, dtype=tf.float32)
        total_gp = tf.constant(0, dtype=tf.float32)

        # ---- Discriminator: accumulate gradients over the micro-batches ----
        d_grad = [tf.zeros_like(v) for v in self.D_params]
        for batch in tf.split(real_img, self.grad_accumulation, axis=0):
            bs = batch.shape[0]
            fake_img = self.random_generate(n=bs)
            with tf.GradientTape() as d_tape:
                real_aug = DiffAugment(batch, self.policy)
                fake_score, fake_score32, _, _ = self.D(DiffAugment(fake_img, self.policy), training=True)
                real_score, real_score32, f16, f8 = self.D(real_aug, training=True)
 
                d_adv_loss = self.hinge_loss(real_score, fake_score)
                d_adv_loss_32 = self.hinge_loss(real_score32, fake_score32)
 
                # Whole-image and cropped reconstruction losses.
                rec_loss = self.reconstruction_loss(self.dec(f8, training=True), tf.image.resize(real_aug, (128, 128)))
                crop_feature, crop_img = self.random_crop(f16, real_aug)
                part_rec_loss = self.reconstruction_loss(self.part_dec(crop_feature, training=True), crop_img)
 
                d_loss = d_adv_loss + d_adv_loss_32 + rec_loss + part_rec_loss
 
                # Scaled by batchsize * grad_accumulation so the accumulated
                # sum is normalized once (presumably self.batchsize is the
                # micro-batch size here -- confirm).
                d_loss = tf.reduce_sum(d_loss) * (1. / (self.batchsize * self.grad_accumulation))
            
            d_grad = [x + y for x, y in zip(d_grad, d_tape.gradient(d_loss, self.D_params, unconnected_gradients=tf.UnconnectedGradients.ZERO))]

            total_real_score += tf.reduce_mean(real_score)
            total_real_score32 += tf.reduce_mean(real_score32)
            total_rec_loss += tf.reduce_mean(rec_loss)
            total_part_rec_loss += tf.reduce_mean(part_rec_loss)
            
            # Gradient penalty, added to the update only when self.it % 4 == 0.
            # NOTE(review): _train_step keys this schedule on self.d_it, which
            # this method never increments -- confirm self.it is intended.
            if self.it % 4 == 0:
                with tf.GradientTape() as d_tape:
                    gp = self.gradient_penalty(batch, fake_img)
                    gp_loss = self.gp_weight * tf.reduce_sum(gp) * (1. / (self.batchsize * self.grad_accumulation))
                d_grad = [x + y for x, y in zip(d_grad, d_tape.gradient(gp_loss, self.D_params, unconnected_gradients=tf.UnconnectedGradients.ZERO))]
            else:
                # Evaluated only so `gp` can be reported.
                gp = self.gradient_penalty(batch, fake_img)
            total_gp += tf.reduce_mean(gp)
        self.D_opt.apply_gradients(zip(d_grad, self.D_params))

        # ---- Generator: accumulate gradients over fresh fake batches ----
        g_grad = [tf.zeros_like(v) for v in self.G_params]
        # NOTE(review): this loop iterates over tf.range, while the loop above
        # iterates over a Python list -- confirm the difference is intended.
        for _ in tf.range(self.grad_accumulation):
            with tf.GradientTape(watch_accessed_variables=False) as g_tape:
                g_tape.watch(self.G_params)
                fake_img = self.random_generate(n=self.batchsize_per_replica)
                fake_score, fake_score32, _, _ = self.D(DiffAugment(fake_img, self.policy), training=True)
                fake_score, fake_score32 = tf.reduce_mean(fake_score, axis=[1, 2, 3]), tf.reduce_mean(fake_score32, axis=[1, 2, 3])
    
                # The generator minimizes the discriminator's score on fakes.
                g_loss = fake_score + fake_score32
                g_loss = tf.reduce_sum(g_loss) * (1. / (self.batchsize * self.grad_accumulation))
            g_grad = [x+y for x, y in zip(g_grad, g_tape.gradient(g_loss, self.G_params, unconnected_gradients=tf.UnconnectedGradients.ZERO))]
            total_fake_score += tf.reduce_mean(fake_score)
            total_fake_score32 += tf.reduce_mean(fake_score32)
        self.G_opt.apply_gradients(zip(g_grad, self.G_params))
        # Report the per-step averages.
        return tuple(x / self.grad_accumulation for x in (total_real_score, total_fake_score, total_real_score32, total_fake_score32, total_rec_loss, total_part_rec_loss, total_gp))
    
    @tf.function
    def train_step(self, img, strategy=None):
        """Traced entry point for one training step.

        Args:
            img: Input batch handed to ``_train_step``.
            strategy: Optional distribution strategy. When given, the step is
                run through ``strategy.run`` and every returned value is
                mean-reduced across replicas.

        Returns:
            Tuple of the values returned by ``_train_step``.
        """
        if strategy is None:
            return tuple(self._train_step(img))
        per_replica = strategy.run(self._train_step, args=(img,))
        return tuple(
            strategy.reduce(tf.distribute.ReduceOp.MEAN, value, axis=None)
            for value in per_replica
        )
    
    def train_TPU(self, strategy=None, iterations=50000):
        """Training loop; continues from ``self.losses['iterations']``.

        Each iteration runs ``train_step`` on the next dataset batch, appends
        the returned values to ``self.losses`` and shows them on the progress
        bar. A snapshot image and a checkpoint are written whenever the
        iteration counter (read before it is incremented) is a multiple of
        1000, and once more when the loop ends.

        Args:
            strategy: Optional distribution strategy. When given,
                ``self.dataset`` is replaced by its distributed version and
                the per-replica batch size is derived from it.
            iterations: Iteration count at which training stops.
        """
        if strategy is None:
            self.batchsize_per_replica = self.batchsize
        else:
            self.dataset = strategy.experimental_distribute_dataset(self.dataset)
            self.batchsize_per_replica = self.batchsize // strategy.num_replicas_in_sync
        batches = iter(self.dataset)

        first_step = self.losses['iterations'] + 1
        with tqdm(range(first_step, iterations + 1)) as progress:
            for _ in progress:
                step_losses = self.train_step(next(batches), strategy=strategy)

                # Record this step's statistics and show them on the bar.
                postfix = {'iteration': self.losses['iterations'] + 1}
                for name, tensor in zip(self.loss_names, step_losses):
                    value = tensor.numpy()
                    self.losses[name].append(value)
                    postfix[name] = value
                progress.set_postfix(postfix)

                # Periodic snapshot/checkpoint, keyed on the counter before
                # it is incremented below.
                completed = self.losses['iterations']
                if completed % 1000 == 0:
                    self.save_snapshot_image(completed)
                    self.save()

                self.it.assign_add(1)
                self.losses['iterations'] = self.it.numpy()
                self.losses['d_iterations'] = self.d_it.numpy()
            # Final snapshot/checkpoint once the loop has finished.
            self.save_snapshot_image(self.losses['iterations'])
            self.save()

# Training

In [None]:
%%shell
# Download Dataset
# (%%shell has to be the first line of the cell to be recognised as a cell magic.)

FILE_ID=1-0DvJgxitJtoVYL9HwYgKD3_5GpTDSYK
FILE_NAME=dataset.pkl
# First request: only store the cookies Google Drive sets for this file.
curl -sc /tmp/cookie "https://drive.google.com/uc?export=download&id=${FILE_ID}" > /dev/null
# Take the last field of the cookie line containing "_warning_" as the confirm code.
CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"
# Second request: send the cookies and the confirm code, following redirects, and save the file.
curl -Lb /tmp/cookie "https://drive.google.com/uc?export=download&confirm=${CODE}&id=${FILE_ID}" -o "${FILE_NAME}"

In [None]:
# Keyword arguments for the GAN constructor.
args = {
    'data_dir': 'dataset.pkl',
    'save_dir': 'LightweightGAN_animeface',
    'batchsize': 16,
    'latent_dims': 256,
    'lr': 2e-4,
    'ttur_mul': 1.0,
    # Reference learning rates:
    #   BigGAN
    #     res 128x128           -> D: 2e-4, G: 5e-5
    #     res 256x256 or higher -> D and G: 2.5e-5
    #   SAGAN
    #     D: 4e-4, G: 1e-4
    'gp_type': 'r1',        # one of 'r1', 'r2' or 'wgan'
    'gp_target': 1.0,       # center of the gradient penalty
    'gp_weight': 10.0,
    'd_steps': 1,
    'augment_policy': ['translation', 'cutout'],
    'attn_res': [],
}

## Train on TPU

In [None]:
# setup TPU
# The TPU address is read from the COLAB_TPU_ADDR environment variable; this
# raises KeyError when the variable is not set.
tpu_grpc_url = 'grpc://' + os.environ['COLAB_TPU_ADDR']
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
# Connect to the TPU cluster and initialize the TPU system before creating the strategy.
tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
# `strategy` is used by the training cell below.
strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)

INFO:tensorflow:Initializing the TPU system: grpc://10.43.221.178:8470


INFO:tensorflow:Initializing the TPU system: grpc://10.43.221.178:8470


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Found TPU system:


INFO:tensorflow:Found TPU system:


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


In [None]:
# Reset Keras' global state before building the models.
tf.keras.backend.clear_session()
# Build the GAN from `args` and train it inside the TPU strategy scope.
with strategy.scope():
    gan = GAN(**args)
    # Train until the iteration counter reaches 300000.
    gan.train_TPU(strategy, iterations=300000)

Load dataset...
Load models...
Initialize optimizers...
Load G_opt state...
Load D_opt state...


 69%|██████▊   | 85112/124000 [10:25:12<4:29:17,  2.41it/s, iteration=261113, real_score=-.374, fake_score=0.249, real_score_32=0.345, fake_score_32=0.557, whole_rec=0.0417, part_rec=0.0334, gp=0.00538]