In [None]:
# GAN training
# this is pseudocode! don't expect it to run!
def train_step(real_batch):
    """Run one GAN update: a discriminator step followed by a generator step.

    This is pseudocode. `generator`, `discriminator`, `noise_function`,
    `ones`, `zeros`, `concat`, `crossentropy`,
    `compute_and_apply_gradients`, `d_variables` and `g_variables` are
    placeholders that the surrounding framework code has to provide.

    Args:
        real_batch: batch of real training examples; the first axis is
            taken to be the batch axis.

    Returns:
        None. The function only applies the two parameter updates.
    """
    # train discriminator
    generated_batch = generator(noise_function())
    # one label per example, NOT one per input element: the labels have to
    # line up with the discriminator output, not with the data.
    # assumes D emits a single value per example, i.e. shape (batch, 1)
    # -- TODO confirm against your discriminator's output layer.
    real_labels = ones((real_batch.shape[0], 1))
    generated_labels = zeros((generated_batch.shape[0], 1))

    # concatenating is optional
    full_batch = concat(real_batch, generated_batch)
    full_labels = concat(real_labels, generated_labels)
    # NOTE it's enough to start gradient tape here as we don't train G
    d_output = discriminator(full_batch)

    # as always, be mindful of sigmoid and logits.
    # i.e. for numerical stability it's good to use cross-entropy from_logits.
    # then do not have a sigmoid in your discriminator output layer!
    loss = crossentropy(full_labels, d_output)
    compute_and_apply_gradients(loss, d_variables)

    # train generator
    # this time, you need the tape here already, since we backpropagate through G
    generated_batch = generator(noise_function())
    # ones this time!! G is trained to make D label its samples as real.
    fake_labels = ones((generated_batch.shape[0], 1))
    d_output = discriminator(generated_batch)

    # this implements the "flipped loss" which has better gradients for G
    loss = crossentropy(fake_labels, d_output)
    compute_and_apply_gradients(loss, g_variables)
    

# here are some ideas for further "tricks" that MAY help

# (one-sided) label-smoothing
# - use soft labels such as 0.1 instead of 0, and 0.9 instead of 1
#   - in the one-sided version, it is proposed to only smooth the 1 label.
#     also, ONLY smooth it when training D, not when training G.

# normalize data into [-1, 1] instead of [0, 1]
# - accordingly, use tanh as output activation in G instead of sigmoid.
# - seems like it should do nothing, but supposedly it helps with the learning dynamics.

# "dequantize" the training data
# - just add noise! in theory, uniform noise with range of +/- 1 pixel
# - so, if your data is in [0, 1], that is +/- 1/256.

# DO NOT use batchnorm in the discriminator. DEFINITELY not in the first layer.
# - you can use e.g. LayerNormalization, or InstanceNormalization, GroupNormalization from tensorflow_addons.
# - batchnorm in the generator is okay.
# if you are interested in this issue, you could read this https://ovgu-ailab.github.io/blog/methods/2022/07/07/batchnorm-gans.html

# feature matching is great! see the paper "Improved Techniques for Training GANs",
# section 3.1: https://arxiv.org/pdf/1606.03498.pdf
# here is a very rough sketch of how it may be implemented

# again this is just pseudocode!
# discriminator definition
# NOTE(review): the discriminator is called on data batches further down, so
# the input shape here should presumably be the data shape rather than the
# noise shape -- confirm and rename accordingly.
inputs = tf.keras.Input(noise_shape)
h1 = layer(inputs)
h2 = layer(h1)
output = layer(h2)  # final layer builds on the last hidden activation

# now don't do this
discriminator = Model(inputs, output)
# ...do this instead!
# there is no theory on which hidden layer(s) to pick.
# you can choose one, or multiple and compute feature matching for all chosen layers, and add up the losses.
# the model now returns a list: index 0 is the usual output, the remaining
# entries are the hidden activations used for feature matching.
discriminator = Model(inputs, [output, h1, h2])

# the discriminator update itself is unchanged; the only difference is that
# the model now returns a list, so pick out its first entry
outputs = discriminator(batch)
logits = outputs[0]  # entry 0 is the final layer; the hidden activations are ignored here
...

# the generator update is where feature matching comes in:
# run D on real and on generated data and compare the hidden activations
real_outputs = discriminator(real_batch)
fake_outputs = discriminator(generated_batch)
# entry 0 holds the logits, everything after it is a hidden layer output
real_features, fake_features = real_outputs[1:], fake_outputs[1:]

total_loss = 0
for real_feature, fake_feature in zip(real_features, fake_features):
    # the batch mean (axis 0) is the monte carlo estimate of the expectation
    real_mean = mean(real_feature, axis=0)
    fake_mean = mean(fake_feature, axis=0)
    squared_difference = (real_mean - fake_mean)**2
    # reduce over the feature dimensions and accumulate across the chosen layers
    total_loss += sum(squared_difference)
    
# you can also still train G on the normal loss (just add it), but I dunno if it really makes a difference