# MNIST GAN

## Imports

In [1]:
import Datasets
import Foundation
import ModelSupport
import TensorFlow
%include "GANSupport.swift"

## Parameters

In [2]:
// Training schedule.
let epochCount: Int = 10
let batchSize: Int = 32

// Directory that receives the per-epoch sample grids.
let outputFolder: String = "./output/"

// MNIST image geometry; `imageSize` is the length of one flattened image.
let imageWidth: Int = 28
let imageHeight: Int = 28
let imageSize: Int = imageWidth * imageHeight

// Length of each noise vector fed to the generator.
let latentSize: Int = 64

In [3]:
/// Maps latent noise vectors to flattened, MNIST-sized images.
///
/// Three hidden `Dense` + `BatchNorm` stages widen the representation from
/// `latentSize` to `latentSize * 8` features. A final `Dense` layer projects to
/// `imageSize` outputs through `tanh`, so every output value lies in [-1, 1].
struct Generator: Layer {
    // Hidden stage 1: latentSize -> 2 * latentSize.
    var dense1 = Dense<Float>(
        inputSize: latentSize,
        outputSize: latentSize * 2,
        activation: { leakyRelu($0) })

    // Hidden stage 2: 2 * latentSize -> 4 * latentSize.
    var dense2 = Dense<Float>(
        inputSize: latentSize * 2,
        outputSize: latentSize * 4,
        activation: { leakyRelu($0) })

    // Hidden stage 3: 4 * latentSize -> 8 * latentSize.
    var dense3 = Dense<Float>(
        inputSize: latentSize * 4,
        outputSize: latentSize * 8,
        activation: { leakyRelu($0) })

    // Output projection: one tanh-squashed value per pixel.
    var dense4 = Dense<Float>(
        inputSize: latentSize * 8,
        outputSize: imageSize,
        activation: tanh)

    // Normalization applied after each hidden Dense layer (widths match dense1...dense3).
    var batchnorm1 = BatchNorm<Float>(featureCount: latentSize * 2)
    var batchnorm2 = BatchNorm<Float>(featureCount: latentSize * 4)
    var batchnorm3 = BatchNorm<Float>(featureCount: latentSize * 8)

    /// Generates a batch of flattened images from a batch of noise vectors.
    ///
    /// - Parameter input: Latent noise; presumably `[batch, latentSize]` as produced
    ///   by `sampleVector(size:)`. TODO confirm for other callers.
    /// - Returns: `imageSize` values in [-1, 1] per input row.
    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let features = input.sequenced(through: dense1, batchnorm1, dense2, batchnorm2)
        return features.sequenced(through: dense3, batchnorm3, dense4)
    }
}

## Discriminator Model

In [4]:
/// Scores flattened images with one unnormalized logit per image.
///
/// Four `Dense` layers shrink `imageSize` inputs to 256, 64, 16 and finally 1
/// feature. The last layer uses `identity`, so its outputs are raw logits, which is
/// the form `sigmoidCrossEntropy(logits:labels:)` expects.
struct Discriminator: Layer {
    var dense1 = Dense<Float>(inputSize: imageSize, outputSize: 256, activation: { leakyRelu($0) })
    var dense2 = Dense<Float>(inputSize: 256, outputSize: 64, activation: { leakyRelu($0) })
    var dense3 = Dense<Float>(inputSize: 64, outputSize: 16, activation: { leakyRelu($0) })
    // No squashing here: the loss functions consume logits directly.
    var dense4 = Dense<Float>(inputSize: 16, outputSize: 1, activation: identity)

    /// Returns one logit per input row.
    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let h1 = dense1(input)
        let h2 = dense2(h1)
        let h3 = dense3(h2)
        return dense4(h3)
    }
}

## Loss functions

In [5]:
/// Non-saturating generator loss.
///
/// Scores the discriminator's logits on generated images against an all-ones
/// ("real") target, so minimizing it rewards fooling the discriminator.
///
/// - Parameter fakeLogits: Discriminator logits for generated images.
/// - Returns: The sigmoid cross-entropy against ones.
@differentiable
func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {
    let realTargets = Tensor<Float>(ones: fakeLogits.shape)
    return sigmoidCrossEntropy(logits: fakeLogits, labels: realTargets)
}

/// Discriminator loss: real images should score 1, generated images 0.
///
/// - Parameters:
///   - realLogits: Discriminator logits for a batch of real images.
///   - fakeLogits: Discriminator logits for a batch of generated images.
/// - Returns: The sum of the sigmoid cross-entropy terms for the real and fake batches.
@differentiable
func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
    let realTargets = Tensor<Float>(ones: realLogits.shape)
    let fakeTargets = Tensor<Float>(zeros: fakeLogits.shape)
    return sigmoidCrossEntropy(logits: realLogits, labels: realTargets)
        + sigmoidCrossEntropy(logits: fakeLogits, labels: fakeTargets)
}


In [6]:
/// Draws a batch of latent noise vectors for the generator.
///
/// Values come from `Tensor(randomNormal:)` with its default mean and standard deviation.
///
/// - Parameter size: Number of noise vectors (rows) to draw.
/// - Returns: A `[size, latentSize]` tensor.
func sampleVector(size: Int) -> Tensor<Float> {
    let noiseShape: TensorShape = [size, latentSize]
    return Tensor<Float>(randomNormal: noiseShape)
}

In [7]:
// MNIST loaded with the `flattening` and `normalizing` options enabled.
let dataset = MNIST(
    batchSize: batchSize,
    flattening: true,
    normalizing: true)

// The two competing networks; each is updated in place by its own optimizer below.
var generator = Generator()
var discriminator = Discriminator()

// Both optimizers share the same Adam settings (learning rate 2e-4, beta1 0.5).
let optG: Adam<Generator> = Adam(for: generator, learningRate: 2e-4, beta1: 0.5)
let optD: Adam<Discriminator> = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)

Loading resource: train-images-idx3-ubyte
Loading local data at: /notebooks/TFWorld 2019 Finished Examples/train-images-idx3-ubyte
Succesfully loaded resource: train-images-idx3-ubyte
Loading resource: train-labels-idx1-ubyte
Loading local data at: /notebooks/TFWorld 2019 Finished Examples/train-labels-idx1-ubyte
Succesfully loaded resource: train-labels-idx1-ubyte
2019-10-21 10:43:35.369831: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 188160000 exceeds 10% of system memory.
2019-10-21 10:43:35.415497: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2019-10-21 10:43:35.441600: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2400000000 Hz
2019-10-21 10:43:35.445355: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x259c050 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2

In [8]:
// Fixed noise batch, sampled once and rendered after every epoch.
// Because the inputs never change, successive grids show how the generator's
// output for the same latent points evolves. The grid is
// `testImageGridSize` x `testImageGridSize` images.
let testImageGridSize: Int = 4
let testVector = sampleVector(size: testImageGridSize * testImageGridSize)

In [9]:
/// Tiles a batch of flattened images into one square grid image and saves it.
///
/// Each tile gets a 1-pixel border of value 1 before tiling. The tiled result is
/// then rescaled from [-1, 1] to [0, 1].
///
/// - Parameters:
///   - testImage: `testImageGridSize * testImageGridSize` flattened images of
///     `imageHeight * imageWidth` values each. Values are presumably in [-1, 1]
///     (the generator's tanh range).
///   - name: Output file name passed to `saveImage`, written under `outputFolder`.
/// - Throws: Whatever `saveImage` throws.
func saveImageGrid(_ testImage: Tensor<Float>, name: String) throws {
    let tileHeight = imageHeight + 2
    let tileWidth = imageWidth + 2

    let tiled = testImage
        .reshaped(to: [testImageGridSize, testImageGridSize, imageHeight, imageWidth])
        // Border every tile with one pixel on each side.
        .padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
        // Reorder to [gridRow, pixelRow, gridCol, pixelCol] so rows flatten contiguously.
        .transposed(withPermutations: [0, 2, 1, 3])
        .reshaped(to: [tileHeight * testImageGridSize, tileWidth * testImageGridSize])

    // Map [-1, 1] onto [0, 1].
    let rescaled = (tiled + 1) / 2

    try saveImage(
        rescaled,
        size: (rescaled.shape[0], rescaled.shape[1]),
        directory: outputFolder,
        name: name)
}

## Training and Inference

In [None]:
print("Start training...")

// Each minibatch does one generator step followed by one discriminator step.
for epoch in 1...epochCount {
    // Switch layers that depend on the learning phase (e.g. BatchNorm) to training behavior.
    Context.local.learningPhase = .training
    for i in 0 ..< dataset.trainingSize / batchSize {
        // Generator step: minimize generatorLoss on the discriminator's view of fresh fakes.
        let vec1 = sampleVector(size: batchSize)

        let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
            let fakeImages = generator(vec1)
            let fakeLogits = discriminator(fakeImages)
            return generatorLoss(fakeLogits: fakeLogits)
        }
        optG.update(&generator, along: 𝛁generator)

        // Discriminator step: real minibatch `i` against newly generated images.
        // `fakeImages` is produced outside the gradient closure, so this step only
        // differentiates with respect to the discriminator.
        let realImages = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
        let vec2 = sampleVector(size: batchSize)
        let fakeImages = generator(vec2)

        let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
            let realLogits = discriminator(realImages)
            let fakeLogits = discriminator(fakeImages)
            return discriminatorLoss(realLogits: realLogits, fakeLogits: fakeLogits)
        }
        optD.update(&discriminator, along: 𝛁discriminator)
    }

    // Inference phase: render the fixed test noise and report the generator loss.
    Context.local.learningPhase = .inference
    let testImage = generator(testVector)

    do {
        try saveImageGrid(testImage, name: "epoch-\(epoch)-output")
    } catch {
        // Saving the preview is best-effort; training continues either way.
        print("Could not save image grid with error: \(error)")
    }

    // generatorLoss expects discriminator logits. Passing `testImage` (tanh pixel
    // values) directly would score pixels as if they were logits, giving a meaningless number.
    let lossG = generatorLoss(fakeLogits: discriminator(testImage))
    print("[Epoch: \(epoch)] Loss-G: \(lossG)")
}

Start training...
