# Pytorch Version

In [1]:
import torch.nn as nn
import torch
import torch.optim as optim
import numpy as np

np.random.seed(0)
torch.manual_seed(0)

class ResBlock(nn.Module):
    def __init__(self, in_filters, out_filters, resample=None, normalize=False, activation=None, 
                 first_activation=True):
        super(ResBlock, self).__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.resample = resample
        self.normalize = normalize
        self.conv1 = self.conv_layer(in_filters, out_filters, padding=1)
        self.conv2 = self.conv_layer(out_filters, out_filters, padding=1)
        if resample:
            self.conv3 = self.conv_layer(in_filters, out_filters, kernel_size=1)
        if normalize:
            self.bn1 = self.bn(in_filters)
            self.bn2 = self.bn(out_filters)
            
        if activation is not None:
            self.activation = activation
        else:
            self.activation = nn.ReLU()
            
        self.first_activation = first_activation     
    
    def forward(self, x): 
        orig_input = x
        if self.normalize:
            x = self.bn1(x)
            
        if self.first_activation:
            x = self.activation(x)
        
        if self.resample == 'up':
            x = self.upsample(x)
       
        x = self.conv1(x)

        if self.normalize:
            x = self.bn2(x)
            
        x = self.activation(x)
        x = self.conv2(x)
        
        if self.resample == 'down':
            x = self.downsample(x)
        
        # Shortcut
        if self.resample == 'down': 
            shortcut_x = self.downsample(orig_input)
            shortcut_x = self.conv3(shortcut_x)
        elif self.resample == 'up':
            shortcut_x = self.upsample(orig_input)
            shortcut_x = self.conv3(shortcut_x)
        elif self.resample == None:
            shortcut_x = orig_input
        return x + shortcut_x
    
    
    def upsample(self, x):
        return nn.functional.interpolate(x, scale_factor=2, mode="nearest") #TODO: Check this is the same

    def downsample(self, x):
        batch, height, width, channels = x.shape
        return nn.functional.interpolate(x, scale_factor=0.5, 
                                         mode="bilinear") #TODO: Check this is the same

    def conv_layer(self, in_filters, out_filters=32, kernel_size=3, padding=0):
        return nn.Conv2d(in_filters, out_filters, kernel_size=kernel_size, padding=padding)

    def bn(self, channels):
        return nn.BatchNorm2d(channels, eps=1e-5)
    


class Generator(nn.Module):
    def __init__(self, input_size, num_filters=128, num_conv=3, first_linear_size=4):
        super(Generator, self).__init__()
        self.first_linear = nn.Linear(input_size, first_linear_size * first_linear_size * num_filters)
        self.num_filters = num_filters
        self.activation = nn.ReLU()
        self.first_linear_size = first_linear_size
        self.resblocks = []
        for _ in range(num_conv):
            self.resblocks.append(
                ResBlock(in_filters=self.num_filters, 
                         out_filters=self.num_filters, 
                         resample='up', 
                         normalize=True,
                         activation=self.activation))
        orig_num_channels = 3
        self.last_layer = nn.Conv2d(num_filters, orig_num_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(num_filters, eps=1e-5)
    
    def forward(self, noise):
        x = self.first_linear(noise)
        x = x.view(-1, self.num_filters, self.first_linear_size, self.first_linear_size)
        for resblock in self.resblocks:
            x = resblock(x)
            
        x = self.activation(self.bn(x))
        result = self.last_layer(x)
        return torch.tanh(result)
    

class Discriminator(nn.Module):
    def __init__(self, num_filters=128, num_conv=4):
        super(Discriminator, self).__init__()
        self.num_filters = num_filters
        self.activation = nn.ReLU()
        
        assert(num_conv >= 2, "Number of conv layers in the discriminator must be >= 2 b/c I'm laxy.")
        
        orig_num_filters = 3
        self.resblocks = []
        self.resblocks.append(ResBlock(in_filters=orig_num_filters, 
                     out_filters=self.num_filters, 
                     resample='down', 
                     normalize=False,
                     activation=self.activation,
                     first_activation=False))
        self.resblocks.append(ResBlock(in_filters=self.num_filters, 
                     out_filters=self.num_filters, 
                     resample='down', 
                     normalize=True,
                     activation=self.activation))
        for layer in range(num_conv - 2):
            self.resblocks.append(ResBlock(in_filters=self.num_filters, 
                         out_filters=self.num_filters, 
                         normalize=True,
                         activation=self.activation))
            
        self.last_linear = nn.Linear(num_filters, 1) #TODO: Choose size better
        
   
    def forward(self, x):
        for resblock in self.resblocks:
            x = resblock(x)
            
        x = self.activation(x)
        x = torch.mean(torch.mean(x, dim=3), dim=2)
        batch_size = x.shape[0]
        x = self.last_linear(x) #
        return x
        

class GAN(nn.Module):
    def __init__(self, input_size=32, num_filters=128, num_conv_generator=3, num_conv_discriminator=4,
                 c=5, s=0, batch_size=64,
                 channels=3, exponent=2):  
        super(GAN, self).__init__()
        self.generator = Generator(input_size, num_filters=num_filters, num_conv=num_conv_generator)
        self.discriminator = Discriminator(num_filters=num_filters, num_conv=num_conv_discriminator)
        self.c = c
        self.s = s
        self.batch_size = batch_size
        self.channels = channels
        self.exponent = exponent
        self.input_size = input_size

    def forward_train_generator(self, noise=None):
        generated_image = self.forward_predict_generator(noise)
        discriminator_score = self.forward_predict_discriminator(generated_image)
        return self.generator_loss(discriminator_score)
    
    def forward_train_discriminator(self, real_images, noise=None):
        generated_images = self.forward_predict_generator(noise)
        discriminator_score_generated = self.forward_predict_discriminator(generated_images)
        discriminator_score_real = self.forward_predict_discriminator(real_images)
        return self.discriminator_loss(discriminator_score_real, discriminator_score_generated, real_images, generated_images)
        
    def forward_predict_generator(self, noise=None):
        if noise is None:
            noise = self.generate_noise([self.batch_size, self.input_size])
        return self.generator(noise)
        
    def forward_predict_discriminator(self, image):
        return self.discriminator(image)
    
    def generate_noise(self, tensor_shape):
        return torch.randn(tensor_shape)
    
    def generator_loss(self, d_generated_train):
        return torch.mean(d_generated_train) # Technically we should scale by gamma but why bother??
    
    def discriminator_loss(self, d_score_real, d_score_generated, real_images, generated_images):
        gamma = self.calc_gamma(real_images)
        lamb = self.calc_lambda(real_images)
        
        wasserstein_scaled = (torch.mean(d_score_generated) - torch.mean(d_score_real))
        wasserstein_loss = wasserstein_scaled / gamma
        
        epsilon = torch.empty(self.batch_size, 1, 1, 1).uniform_(0, 1)
        real_fake_mix = epsilon * generated_images + (1 - epsilon) * real_images 
        d_score_mix = torch.mean(self.discriminator(real_fake_mix))
        
        
        gradients = []
        
        def update_grad_variable(grad):
            gradients.append(grad)
            
        real_fake_mix.register_hook(update_grad_variable)    
        d_score_mix.backward(retain_graph=True)
        gradients = gradients[0]
    
        dual_sobolev_gradients = self.sobolev_filter(gradients, c=self.c, s=-self.s)
        dual_exponent = 1 / (1 - 1/self.exponent) if self.exponent != 1 else np.inf
        ddx = self.stable_norm(dual_sobolev_gradients, order=dual_exponent)

        d_regularizer = torch.mean((ddx / gamma - 1) ** 2)
        d_regularizer_mean = torch.mean(d_score_real ** 2)
        
        d_loss = (-wasserstein_loss +
          lamb * d_regularizer +
          1e-5 * d_regularizer_mean)
        
        return d_loss
    
    def calc_gamma(self, real_images):      
        # DUAL NORM
        sobolev_true = self.sobolev_filter(real_images, c=self.c, s=self.s)
        dual_exponent = 1 / (1 - 1/self.exponent) if self.exponent != 1 else np.inf
        gamma = torch.mean(self.stable_norm(sobolev_true, order=dual_exponent))
        return gamma
    
    def calc_lambda(self, real_images):
        sobolev_true = self.sobolev_filter(real_images, c=self.c, s=self.s)
        lamb = torch.mean(self.stable_norm(sobolev_true, order=self.exponent))
        return lamb
    
    def stable_norm(self, x, order):
        x = x.view(self.batch_size, -1) 
        alpha, _ = torch.max(torch.abs(x) + 1e-5, dim=1)
        alpha = alpha.unsqueeze(1)
        result = alpha * torch.norm(x / alpha, p=order, dim=1)
        return result
    
    def sobolev_filter(self, x, c=5, s=1):
        """Apply sobolev filter.

        Parameters
        ----------
        x : pytorch tensor of shape B C W H
        c : float
            Scaling of the coordinate systems (1 / pixel size)
        s : float
            Order of the Sobolev norm
        """
        fft_dim = 2
        fft_x = torch.rfft(x, fft_dim)

        shape = fft_x.shape
        sx = shape[3]
        sy = shape[2]

        # Construct meshgrid for the scale
        x = torch.FloatTensor(range(sx))
        x = torch.min(x, sx - x)
        x = x / (sx // 2)
        y = torch.FloatTensor(range(sy))
        y = torch.min(y, sy - y)
        y = y / (sy // 2)
        X = x.expand(2, sy, sx).permute(1, 2, 0)
        Y = y.expand(2, sx, sy).permute(2, 1, 0)
        X = X.unsqueeze(0).unsqueeze(0)
        Y = Y.unsqueeze(0).unsqueeze(0)

        scale = (1 + c * (X ** 2 + Y ** 2)) ** (s / 2)

        # Compute spatial gradient in fourier space
        fft_x = scale * fft_x

        result_x = torch.irfft(fft_x, fft_dim)
        return result_x
    
    
    

  assert(num_conv >= 2, "Number of conv layers in the discriminator must be >= 2 b/c I'm laxy.")


Guess: 30 compile
    
    EACH:
        - runs
        - Gives correct outuput
        - Correct formula
        - Comments
        okay

In [2]:
def clip_grads(net):
    for p in net.parameters():
        if p.grad is not None:
            p.grad.data.clamp_(-5, 5)

def train(model,
            dset_loader,
            optimizer,
            lr_scheduler=None,
            num_epochs=20,
            use_cuda=False,
            num_discriminator=5):

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        if lr_scheduler is not None:
            lr_scheduler.step(epoch)

        # Iterate over data.
        for batch in dset_loader:
            if use_cuda:
                batch = batch.cuda()
            
            for _ in range(num_discriminator):
                loss = model.forward_train_discriminator(batch)
                loss.backward()
                clip_grads(model)
                optimizer.step()
                # zero the parameter gradients
                optimizer.zero_grad()
                
            loss = model.forward_train_generator(batch)
            loss.backward()
            clip_grads(model)
            optimizer.step()
            # zero the parameter gradients
            optimizer.zero_grad()


    return model


In [3]:
input_size = 32
batch_size = 64
model = GAN(input_size)
x_train = [torch.randn(batch_size, 3, 32, 32) for _ in range(3)]
optimizer = optim.Adam(model.parameters(), lr=2e-4)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=500, factor=0.8)
num_epochs = 3
use_cuda=False

import time
train(model, x_train, optimizer, lr_scheduler, num_epochs, use_cuda)

# kangaroo

Epoch 0/2
----------


  "See the documentation of nn.Upsample for details.".format(mode))


RuntimeError: $ Torch: not enough memory: you tried to allocate 3GB. Buy new RAM! at /pytorch/aten/src/TH/THGeneral.cpp:201

# TF VERSION

In [None]:
"""Utilities for computing the sobolev norm."""

import tensorflow as tf

def sobolev_filter(x, c=5, s=1):
    """Apply sobolev filter.

    Parameters
    ----------
    x : tensorflow.Tensor of shape B W H C
        txt
    c : float
        Scaling of the cooridinate systems (1 / pixel size)
    s : float
        Order of the Sobolev norm
    """
    with tf.name_scope('sobolev'):
        # FFT is taken over the innermost axes, so move channel to beginning.
        x = tf.transpose(x, [0, 3, 1, 2])
        fft_x = tf.spectral.fft2d(tf.cast(x, 'complex64'))

        shape = tf.shape(fft_x)
        sx = shape[3]
        sy = shape[2]

        # Construct meshgrid for the scale
        x = tf.range(sx)
        x = tf.minimum(x, sx - x)
        x = tf.cast(x, dtype='complex64') / tf.cast(sx // 2, dtype='complex64')
        y = tf.range(sy)
        y = tf.minimum(y, sy - y)
        y = tf.cast(y, dtype='complex64') / tf.cast(sy // 2, dtype='complex64')
        X, Y = tf.meshgrid(x, y)
        X = X[None, None]
        Y = Y[None, None]

        scale = (1 + c * (X ** 2 + Y ** 2)) ** (s / 2)

        # Compute spatial gradient in fourier space
        fft_x = scale * fft_x

        result_x = tf.spectral.ifft2d(fft_x)
        result_x = tf.real(result_x)
        return tf.transpose(result_x, [0, 2, 3, 1])


In [None]:
# FROM ADLER - https://github.com/adler-j/adler/blob/master/adler/tensorflow/summaries.py

def scalars_summary(name, x):
    with tf.name_scope(name):
        x = tf.reshape(x, [-1])
        mean, var = tf.nn.moments(x, axes=0)
        tf.summary.scalar('mean', mean)
        tf.summary.scalar('std', tf.sqrt(var))
        tf.summary.histogram('histogram', x)
        
        
def image_grid(x, size=8):
    t = tf.unstack(x[:size * size], num=size*size, axis=0)
    rows = [tf.concat(t[i*size:(i+1)*size], axis=0) for i in range(size)]
    image = tf.concat(rows, axis=1)
    return image[None]
        
def image_grid_summary(name, x):
    with tf.name_scope(name):
        tf.summary.image('grid', image_grid(x))
        
def default_checkpoint_path(name):
    checkpoint_dir = join(get_base_dir(), 'checkpoints')
    if not exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    checkpoint_path = join(checkpoint_dir,
                           '{}.ckpt'.format(name))

    return checkpoint_path

def default_tensorboard_dir(name):
    tensorboard_dir = join(get_base_dir(), 'tensorboard', name)
    if not exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    return tensorboard_dir

def get_base_dir():
    """Get the data directory."""
    base_odl_dir = os.environ.get('ADLER_HOME',
                                  expanduser(join('~', '.adler')))
    data_home = join(base_odl_dir, 'tensorflow')
    if not exists(data_home):
        os.makedirs(data_home)
    return data_home


def summary_writers(name, cleanup=False, session=None, write_graph=True):
    if session is None:
        session = tf.get_default_session()

    dname = default_tensorboard_dir(name)

    if cleanup and os.path.exists(dname):
        shutil.rmtree(dname, ignore_errors=True)

    if write_graph:
        graph = session.graph
    else:
        graph = None

    test_summary_writer = tf.summary.FileWriter(dname + '/test', graph)
    train_summary_writer = tf.summary.FileWriter(dname + '/train')

    return test_summary_writer, train_summary_writer

class EMAHelper(object):
    def __init__(self, decay=0.99, session=None):
        if session is None:
            self.session = tf.get_default_session()
        else:
            self.session = session

        self.all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        self.ema = tf.train.ExponentialMovingAverage(decay=decay)
        self.apply = self.ema.apply(self.all_vars)
        self.averages = [self.ema.average(var) for var in self.all_vars]

    def average_dict(self):
        ema_averages_results = self.session.run(self.averages)
        return {var: value for var, value in zip(self.all_vars,
                                                 ema_averages_results)}

    def variables_to_restore(self):
        return self.ema.variables_to_restore(tf.moving_average_variables())



In [None]:
"""Code for training Banach Wasserstein GAN on CIFAR 10.

With all the dependencies installed, the code should run as-is. 
Data is downloaded on the fly.
"""

import sys
import tensorflow as tf
import numpy as np
import tensordata
import functools

# User selectable parameters
EXPONENT = 2
SOBOLEV_C = 5.0
SOBOLEV_S = 0
MAX_ITERS = 100000
SUMMARY_FREQ = 10
INCEPTION_FREQ = 1000
BATCH_SIZE = 64
BATCH_SIZE_TEST = 100
reset = True

# set seeds for reproducibility
np.random.seed(0)
tf.set_random_seed(0)

sess = tf.InteractiveSession()

# Training specific parameters
size = 32
DUAL_EXPONENT = 1 / (1 - 1/EXPONENT) if EXPONENT != 1 else np.inf



with tf.name_scope('placeholders'):
    x_train_ph, _ = tensordata.get_cifar10_tf(batch_size=BATCH_SIZE)
    print("SHAPE", x_train_ph.shape)
    x_test_ph, _ = tensordata.get_cifar10_tf(batch_size=BATCH_SIZE_TEST)

    is_training = tf.placeholder(bool, name='is_training')
    use_agumentation = tf.identity(is_training, name='is_training')


with tf.name_scope('pre_process'):
    x_train = (x_train_ph - 0.5) * 2.0
    x_test = (x_test_ph - 0.5) * 2.0

    x_true = tf.cond(is_training,
                     lambda: x_train,
                     lambda: x_test)

def apply_conv(x, filters=32, kernel_size=3, he_init=True):
    if he_init:
        initializer = tf.contrib.layers.variance_scaling_initializer(uniform=True)
    else:
        initializer = tf.contrib.layers.xavier_initializer(uniform=True)

    return tf.layers.conv2d(x, filters=filters, kernel_size=kernel_size,
                            padding='SAME', kernel_initializer=initializer)


def activation(x):
    with tf.name_scope('activation'):
        return tf.nn.relu(x)


def bn(x):
    return tf.contrib.layers.batch_norm(x,
                                    decay=0.9,
                                    center=True,
                                    scale=True,
                                    epsilon=1e-5,
                                    zero_debias_moving_mean=True,
                                    is_training=is_training)


def stable_norm(x, ord):
    x = tf.contrib.layers.flatten(x)
    alpha = tf.reduce_max(tf.abs(x) + 1e-5, axis=1)
    result = alpha * tf.norm(x / alpha[:, None], ord=ord, axis=1)
    return result


def downsample(x):
    with tf.name_scope('downsample'):
        x = tf.identity(x)
        return tf.add_n([x[:,::2,::2,:], x[:,1::2,::2,:],
                         x[:,::2,1::2,:], x[:,1::2,1::2,:]]) / 4.

def upsample(x):
    with tf.name_scope('upsample'):
        x = tf.identity(x)
        x = tf.concat([x, x, x, x], axis=-1)
        return tf.depth_to_space(x, 2)


def conv_meanpool(x, **kwargs):
    return downsample(apply_conv(x, **kwargs))

def meanpool_conv(x, **kwargs):
    return apply_conv(downsample(x), **kwargs)

def upsample_conv(x, **kwargs):
    return apply_conv(upsample(x), **kwargs)

def resblock(x, filters, resample=None, normalize=False):
    if normalize:
        norm_fn = bn
    else:
        norm_fn = tf.identity

    if resample == 'down':
        conv_1 = functools.partial(apply_conv, filters=filters)
        conv_2 = functools.partial(conv_meanpool, filters=filters)
        conv_shortcut = functools.partial(conv_meanpool, filters=filters,
                                          kernel_size=1, he_init=False)
    elif resample == 'up':
        conv_1 = functools.partial(upsample_conv, filters=filters)
        conv_2 = functools.partial(apply_conv, filters=filters)
        conv_shortcut = functools.partial(upsample_conv, filters=filters,
                                          kernel_size=1, he_init=False)
    elif resample == None:
        conv_1 = functools.partial(apply_conv, filters=filters)
        conv_2 = functools.partial(apply_conv, filters=filters)
        conv_shortcut = tf.identity

    with tf.name_scope('resblock'):
        x = tf.identity(x)
        update = conv_1(activation(norm_fn(x)))
        update = conv_2(activation(norm_fn(update)))

        skip = conv_shortcut(x)
        return skip + update


def resblock_optimized(x, filters):
    with tf.name_scope('resblock'):
        x = tf.identity(x)
        update = apply_conv(x, filters=filters)
        update = conv_meanpool(activation(update), filters=filters)

        skip = meanpool_conv(x, filters=128, kernel_size=1, he_init=False)
        return skip + update


def generator(z, reuse):
    with tf.variable_scope('generator', reuse=reuse):
        with tf.name_scope('pre_process'):
            z = tf.layers.dense(z, 4 * 4 * 128)
            x = tf.reshape(z, [-1, 4, 4, 128])

        with tf.name_scope('x1'):
            x = resblock(x, filters=128, resample='up', normalize=True) # 8
            x = resblock(x, filters=128, resample='up', normalize=True) # 16
            x = resblock(x, filters=128, resample='up', normalize=True) # 32

        with tf.name_scope('post_process'):
            x = activation(bn(x))
            result = apply_conv(x, filters=3, he_init=False)
            return tf.tanh(result)


def discriminator(x, reuse):
    with tf.variable_scope('discriminator', reuse=reuse):
        with tf.name_scope('pre_process'):
            x = resblock_optimized(x, filters=128)

        with tf.name_scope('x1'):
            x = resblock(x, filters=128, resample='down') # 8
            x = resblock(x, filters=128) # 16
            x = resblock(x, filters=128) # 32

        with tf.name_scope('post_process'):
            x = activation(x)
            x = tf.reduce_mean(x, axis=[1, 2])
            flat = tf.contrib.layers.flatten(x)
            flat = tf.layers.dense(flat, 1)
            return flat


with tf.name_scope('gan'):
    z = tf.random_normal([tf.shape(x_true)[0], 128], name="z")

    x_generated = generator(z, reuse=True) # ORIGINALLY FALSE

    d_true = discriminator(x_true, reuse=True) # ORIGINALLY FALSE
    d_generated = discriminator(x_generated, reuse=True)

    z_gen = tf.random_normal([BATCH_SIZE * 2, 128], name="z")
    d_generated_train = discriminator(generator(z_gen, reuse=True), reuse=True)

with tf.name_scope('dual_norm'):
    sobolev_true = sobolev_filter(x_true, c=SOBOLEV_C, s=SOBOLEV_S)
    lamb = tf.reduce_mean(stable_norm(sobolev_true, ord=EXPONENT))
    dual_sobolev_true = sobolev_filter(x_true, c=SOBOLEV_C, s=-SOBOLEV_S)
    gamma = tf.reduce_mean(stable_norm(sobolev_true, ord=DUAL_EXPONENT))

with tf.name_scope('regularizer'):
    epsilon = tf.random_uniform([tf.shape(x_true)[0], 1, 1, 1], 0.0, 1.0)
    x_hat = epsilon * x_generated + (1 - epsilon) * x_true
    d_hat = discriminator(x_hat, reuse=True)

    gradients = tf.gradients(d_hat, x_hat)[0]
    dual_sobolev_gradients = sobolev_filter(gradients, c=SOBOLEV_C, s=-SOBOLEV_S)
    ddx = stable_norm(dual_sobolev_gradients, ord=DUAL_EXPONENT)

    d_regularizer = tf.reduce_mean(tf.square(ddx / gamma - 1))
    d_regularizer_mean = tf.reduce_mean(tf.square(d_true))

with tf.name_scope('loss_gan'):
    wasserstein_scaled = (tf.reduce_mean(d_generated) - tf.reduce_mean(d_true))
    wasserstein = wasserstein_scaled / gamma

    g_loss = tf.reduce_mean(d_generated_train) / gamma
    d_loss = (-wasserstein +
              lamb * d_regularizer +
              1e-5 * d_regularizer_mean)

with tf.name_scope('optimizer'):
    ema = EMAHelper(decay=0.99)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    decay = tf.maximum(0., 1.-(tf.cast(global_step, tf.float32)/MAX_ITERS))
    learning_rate = 2e-4 * decay
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0., beta2=0.9)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='gan/generator')
    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    with tf.control_dependencies(update_ops):
        g_train = optimizer.minimize(g_loss, var_list=g_vars,
                                     global_step=global_step)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='gan/discriminator')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    with tf.control_dependencies(update_ops):
        d_train = optimizer.minimize(d_loss, var_list=d_vars)


with tf.name_scope('summaries'):
    tf.summary.scalar('wasserstein_scaled', wasserstein_scaled)
    tf.summary.scalar('wasserstein', wasserstein)

    tf.summary.scalar('g_loss', g_loss)

    tf.summary.scalar('d_loss', d_loss)
    scalars_summary('d_true', d_true)
    scalars_summary('d_generated', d_generated)
    tf.summary.scalar('d_regularizer', d_regularizer)
    tf.summary.scalar('d_regularizer_mean', d_regularizer_mean)

    tf.summary.scalar('learning_rate', learning_rate)
    tf.summary.scalar('global_step', global_step)

    scalars_summary('x_generated', x_generated)
    scalars_summary('x_true', x_true)

    scalars_summary('gamma', gamma)
    scalars_summary('lamb', lamb)

    image_grid_summary('x_true', x_true)
    image_grid_summary('x_generated', x_generated)
    image_grid_summary('gradients', gradients)
    image_grid_summary('dual_sobolev_gradients', dual_sobolev_gradients)

    scalars_summary('ddx', ddx)
    scalars_summary('gradients', gradients)
    scalars_summary('dual_sobolev_gradients', dual_sobolev_gradients)

    merged_summary = tf.summary.merge_all()

    # Advanced metrics
    with tf.name_scope('inception'):
        # Specific function to compute inception score for very large
        # number of samples
        def generate_and_classify(z):
            INCEPTION_OUTPUT = 'logits:0'
            x = generator(z, reuse=True)
            x = tf.image.resize_bilinear(x, [299, 299])
            return tf.contrib.gan.eval.run_inception(x, output_tensor=INCEPTION_OUTPUT)

        # Fixed z for fairness between runs
        inception_z = tf.constant(np.random.randn(10000, 128), dtype='float32')
        inception_score = tf.contrib.gan.eval.classifier_score(inception_z,
                                                               classifier_fn=generate_and_classify,
                                                               num_batches=10000 // 100)

        inception_summary = tf.summary.merge([
                tf.summary.scalar('inception_score', inception_score)])

        full_summary = tf.summary.merge([merged_summary, inception_summary])

    test_summary_writer, train_summary_writer = summary_writers(name, cleanup=reset, write_graph=False)

# Initialize all TF variables
sess.run([tf.global_variables_initializer(),
          tf.local_variables_initializer()])

# Coordinate the loading of image files.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

# Add op to save and restore
saver = tf.train.Saver()
if not reset:
    saver.restore(sess, default_checkpoint_path(name))

# Standardized validation z
z_validate = np.random.randn(BATCH_SIZE_TEST, 128)

# Train the network
while True:
    i = sess.run(global_step)
    if i >= MAX_ITERS:
        break

    num_d_train = 5
    for j in range(num_d_train):
        _, d_loss_result = sess.run([d_train, d_loss],
                                    feed_dict={is_training: True})

    _, g_loss_result, _ = sess.run([g_train, g_loss, ema.apply],
             feed_dict={is_training: True})

    print('s={}, i={}, j={}, d_loss={}, g_loss={}'.format(SOBOLEV_S, i, j,
                                                    d_loss_result,
                                                    g_loss_result))

    if i % SUMMARY_FREQ == SUMMARY_FREQ - 1:
        ema_dict = ema.average_dict()
        merged_summary_result_train = sess.run(merged_summary,
                                         feed_dict={is_training: False,
                                                    **ema_dict})
        train_summary_writer.add_summary(merged_summary_result_train, i)
    if i % INCEPTION_FREQ == INCEPTION_FREQ - 1:
        ema_dict = ema.average_dict()
        merged_summary_result_test = sess.run(full_summary,
                                         feed_dict={z: z_validate,
                                                    is_training: False,
                                                    **ema_dict})
        test_summary_writer.add_summary(merged_summary_result_test, i)


    if i % 1000 == 999:
        saver.save(sess, default_checkpoint_path(name))