# Complete - Building a GAN

We're not here to teach the fundamentals of neural networks or ML, but we think GANs are a pretty neat demo. GANs (Generative Adversarial Networks) have two entirely separate networks (models) that work together/compete against each other to generate something.

Their overarching goal is to generate new data that is somewhat similar to some of the data they were trained with.
    
Basically, the **generator** produces fake images, and the **discriminator** judges whether each image it sees is real or generated. Competing against each other, they both get cleverer and cleverer, until the discriminator can no longer tell generator-generated images from the real thing.

## Imports

We need `Foundation` so we can use the Swift types, `FoundationNetworking` so we can download stuff, and `TensorFlow` so we can use the machine learning bits and pieces.

NOTE: If you're running this on your own local install then you might also need to import `Datasets` and `ModelSupport`, which help you work with existing datasets and files.

In [0]:
import Foundation
import FoundationNetworking
import TensorFlow

### Some support code

This is a collection of convenience methods and helpers that let us read and write local files, download files, and get the MNIST dataset. It's quite long, so leave this section collapsed. You can expand it and read it if you want, but the code here is a little beyond the scope of this session. Ask us, and if we have time we can go through it with you.

In [0]:
// This code comes from the Swift-Models repo, from the TF team.

// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

public struct DatasetUtilities {
    public static let curentWorkingDirectoryURL = URL(
        fileURLWithPath: FileManager.default.currentDirectoryPath)

    public static func fetchResource(
        filename: String,
        remoteRoot: URL,
        localStorageDirectory: URL = curentWorkingDirectoryURL
    ) -> Data {
        print("Loading resource: \(filename)")

        let resource = ResourceDefinition(
            filename: filename,
            remoteRoot: remoteRoot,
            localStorageDirectory: localStorageDirectory)

        let localURL = resource.localURL

        if !FileManager.default.fileExists(atPath: localURL.path) {
            print(
                "File does not exist locally at expected path: \(localURL.path) and must be fetched"
            )
            fetchFromRemoteAndSave(resource)
        }

        do {
            print("Loading local data at: \(localURL.path)")
            let data = try Data(contentsOf: localURL)
            print("Succesfully loaded resource: \(filename)")
            return data
        } catch {
            fatalError("Failed to load contents of resource: \(localURL)")
        }
    }

    struct ResourceDefinition {
        let filename: String
        let remoteRoot: URL
        let localStorageDirectory: URL

        var localURL: URL {
            localStorageDirectory.appendingPathComponent(filename)
        }

        var remoteURL: URL {
            remoteRoot.appendingPathComponent(filename).appendingPathExtension("gz")
        }

        var archiveURL: URL {
            localURL.appendingPathExtension("gz")
        }
    }

    static func fetchFromRemoteAndSave(_ resource: ResourceDefinition) {
        let remoteLocation = resource.remoteURL
        let archiveLocation = resource.archiveURL

        do {
            print("Fetching URL: \(remoteLocation)...")
            let archiveData = try Data(contentsOf: remoteLocation)
            print("Writing fetched archive to: \(archiveLocation.path)")
            try archiveData.write(to: archiveLocation)
        } catch {
            fatalError("Failed to fetch and save resource with error: \(error)")
        }
        print("Archive saved to: \(archiveLocation.path)")

        extractArchive(for: resource)
    }

    static func extractArchive(for resource: ResourceDefinition) {
        print("Extracting archive...")

        let archivePath = resource.archiveURL.path

        #if os(macOS)
            let gunzipLocation = "/usr/bin/gunzip"
        #else
            let gunzipLocation = "/bin/gunzip"
        #endif

        let task = Process()
        task.executableURL = URL(fileURLWithPath: gunzipLocation)
        task.arguments = [archivePath]
        do {
            try task.run()
            task.waitUntilExit()
        } catch {
            fatalError("Failed to extract \(archivePath) with error: \(error)")
        }
    }
}


public struct MNIST {
    public let trainingImages: Tensor<Float>
    public let trainingLabels: Tensor<Int32>
    public let testImages: Tensor<Float>
    public let testLabels: Tensor<Int32>

    public let trainingSize: Int
    public let testSize: Int

    public let batchSize: Int

    public init(
        batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
        localStorageDirectory: URL = DatasetUtilities.curentWorkingDirectoryURL
    ) {
        self.batchSize = batchSize

        let (trainingImages, trainingLabels) = fetchDataset(
            localStorageDirectory: localStorageDirectory,
            imagesFilename: "train-images-idx3-ubyte",
            labelsFilename: "train-labels-idx1-ubyte",
            flattening: flattening,
            normalizing: normalizing)

        self.trainingImages = trainingImages
        self.trainingLabels = trainingLabels
        self.trainingSize = Int(trainingLabels.shape[0])

        let (testImages, testLabels) = fetchDataset(
            localStorageDirectory: localStorageDirectory,
            imagesFilename: "t10k-images-idx3-ubyte",
            labelsFilename: "t10k-labels-idx1-ubyte",
            flattening: flattening,
            normalizing: normalizing)
        self.testImages = testImages
        self.testLabels = testLabels
        self.testSize = Int(testLabels.shape[0])
    }
}

extension Tensor {
    public func minibatch(at index: Int, batchSize: Int) -> Tensor {
        let start = index * batchSize
        return self[start..<start+batchSize]
    }
}

fileprivate func fetchDataset(
    localStorageDirectory: URL,
    imagesFilename: String,
    labelsFilename: String,
    flattening: Bool,
    normalizing: Bool
) -> (images: Tensor<Float>, labels: Tensor<Int32>) {
    guard let remoteRoot: URL = URL(string: "http://yann.lecun.com/exdb/mnist") else {
        fatalError("Failed to create MNIST root url: http://yann.lecun.com/exdb/mnist")
    }

    let imagesData = DatasetUtilities.fetchResource(
        filename: imagesFilename,
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)
    let labelsData = DatasetUtilities.fetchResource(
        filename: labelsFilename,
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)

    let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
    let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)

    let rowCount = labels.count
    let (imageWidth, imageHeight) = (28, 28)

    if flattening {
        var flattenedImages = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
            / 255.0
        if normalizing {
            flattenedImages = flattenedImages * 2.0 - 1.0
        }
        return (images: flattenedImages, labels: Tensor(labels))
    } else {
        return (
            images:
                Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
                    .transposed(withPermutations: [0, 2, 3, 1]) / 255,  // NHWC
            labels: Tensor(labels)
        )
    }
}

public func createDirectoryIfMissing(at path: String) throws {
    guard !FileManager.default.fileExists(atPath: path) else { return }
    try FileManager.default.createDirectory(
        atPath: path,
        withIntermediateDirectories: false,
        attributes: nil)
}


public struct Image {
    public enum ByteOrdering {
        case bgr
        case rgb
    }

    enum ImageTensor {
        case float(data: Tensor<Float>)
        case uint8(data: Tensor<UInt8>)
    }

    let imageData: ImageTensor

    public init(tensor: Tensor<UInt8>) {
        self.imageData = .uint8(data: tensor)
    }

    public init(tensor: Tensor<Float>) {
        self.imageData = .float(data: tensor)
    }

    public init(jpeg url: URL, byteOrdering: ByteOrdering = .rgb) {
        let loadedFile = Raw.readFile(filename: StringTensor(url.absoluteString))
        let loadedJpeg = Raw.decodeJpeg(contents: loadedFile, channels: 3, dctMethod: "")
        if byteOrdering == .bgr {
            self.imageData = .uint8(
                data: Raw.reverse(loadedJpeg, dims: Tensor<Bool>([false, false, false, true])))
        } else {
            self.imageData = .uint8(data: loadedJpeg)
        }
    }

    public func save(to url: URL, quality: Int64 = 95) {
        // This currently only saves in grayscale.
        let outputImageData: Tensor<UInt8>
        switch self.imageData {
        case let .uint8(data): outputImageData = data
        case let .float(data):
            let lowerBound = data.min(alongAxes: [0, 1])
            let upperBound = data.max(alongAxes: [0, 1])
            let adjustedData = (data - lowerBound) * (255.0 / (upperBound - lowerBound))
            outputImageData = Tensor<UInt8>(adjustedData)
        }

        let encodedJpeg = Raw.encodeJpeg(
            image: outputImageData, format: .grayscale, quality: quality, xmpMetadata: "")
        Raw.writeFile(filename: StringTensor(url.absoluteString), contents: encodedJpeg)
    }

    public func resized(to size: (Int, Int)) -> Image {
        switch self.imageData {
        case let .uint8(data):
            return Image(
                tensor: Raw.resizeBilinear(
                    images: Tensor<UInt8>([data]),
                    size: Tensor<Int32>([Int32(size.0), Int32(size.1)])))
        case let .float(data):
            return Image(
                tensor: Raw.resizeBilinear(
                    images: Tensor<Float>([data]),
                    size: Tensor<Int32>([Int32(size.0), Int32(size.1)])))
        }

    }
}

public func saveImage(_ tensor: Tensor<Float>, size: (Int, Int), directory: String, name: String) throws {
    try createDirectoryIfMissing(at: directory)
    let reshapedTensor = tensor.reshaped(to: [size.0, size.1, 1])
    let image = Image(tensor: reshapedTensor)
    let outputURL = URL(fileURLWithPath:"\(directory)\(name).jpg")
    image.save(to: outputURL)
}

## Parameters

Our parameters are as follows:

* `epochCount` is how many epochs it should train for. 10 is a good number to get a reasonable GAN in this case.
* `batchSize` is the size of a batch that we're going to ask the MNIST dataset for.
* `outputFolder` defines the output folder where we'll be writing things on the file system.
* `imageHeight` and `imageWidth`, together with `imageSize`, define the output image size that the Generator will make, as well as (naturally) the input image size the Discriminator will take.
* `latentSize` defines the latent representation size used by the Generator to generate.
* `testImageGridSize` defines the size of the grid of images that we'll generate to look at the result of the GAN.

In [0]:
let epochCount = 10
let batchSize = 32
let outputFolder = "./MNIST_GAN_Output/"
let imageHeight = 28
let imageWidth = 28
let imageSize = imageHeight * imageWidth
let latentSize = 64
let testImageGridSize = 4
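
As a quick sanity check, here is a plain-Swift sketch (no TensorFlow needed) of the quantities these parameters imply. The `sketch`-prefixed names are only for illustration, and `mnistTrainingCount` is an assumption set to the 60,000 images of the standard MNIST training set; in the notebook the real value comes from `dataset.trainingSize`.

```swift
// Plain-Swift sketch: quantities implied by the notebook's parameters.
let sketchBatchSize = 32
let sketchImageHeight = 28
let sketchImageWidth = 28
let sketchGridSize = 4
// Assumption: the standard MNIST training set size.
let mnistTrainingCount = 60_000

// Each flattened image has height * width pixels.
let pixelsPerImage = sketchImageHeight * sketchImageWidth
// The test grid holds gridSize x gridSize generated images.
let imagesPerGrid = sketchGridSize * sketchGridSize
// Integer division: a trailing partial batch would be skipped by the training loop.
let batchesPerEpoch = mnistTrainingCount / sketchBatchSize
let leftoverImages = mnistTrainingCount % sketchBatchSize

print(pixelsPerImage, imagesPerGrid, batchesPerEpoch, leftoverImages)  // 784 16 1875 0
```

So each epoch performs 1,875 generator updates and 1,875 discriminator updates under these assumptions.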

## Convenience helper to save an image grid

In [0]:
func saveImageGrid(_ testImage: Tensor<Float>, name: String) throws {
    var gridImage = testImage.reshaped(
        to: [
            testImageGridSize, testImageGridSize,
            imageHeight, imageWidth,
        ])

    // Add padding.
    gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)

    // Transpose to create single image.
    gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3])
    gridImage = gridImage.reshaped(
        to: [
            (imageHeight + 2) * testImageGridSize,
            (imageWidth + 2) * testImageGridSize,
        ])
        
    // Convert [-1, 1] range to [0, 1] range.
    gridImage = (gridImage + 1) / 2

    try saveImage(
        gridImage, size: (gridImage.shape[0], gridImage.shape[1]), directory: outputFolder,
        name: name)
}
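
The arithmetic inside `saveImageGrid()` can be checked without TensorFlow. This plain-Swift sketch reproduces the grid dimensions and the range conversion; the names `tileCount`, `border` and `toUnitRange` are made up for illustration.

```swift
// Plain-Swift sketch of the arithmetic inside saveImageGrid().
let tileCount = 4      // testImageGridSize
let tileHeight = 28    // imageHeight
let tileWidth = 28     // imageWidth
let border = 1         // one pixel of padding on each side of a tile

// Every tile grows by 2 * border in each dimension before the tiles are joined.
let gridHeight = (tileHeight + 2 * border) * tileCount
let gridWidth = (tileWidth + 2 * border) * tileCount

/// Maps a value from the generator's tanh range [-1, 1] to [0, 1].
func toUnitRange(_ value: Float) -> Float {
    (value + 1) / 2
}

print(gridHeight, gridWidth)                             // 120 120
print(toUnitRange(-1), toUnitRange(0), toUnitRange(1))   // 0.0 0.5 1.0
```

Because the padding value is 1, it maps to 1.0 after the conversion, so the borders between digits come out at the bright end of the range.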

## Generator Model

Our `Generator` is a `struct` conforming to the [`Layer` protocol](https://www.tensorflow.org/swift/api_docs/Protocols/Layer) (which is part of Swift for TensorFlow's API). The Generator has the following layers:

* `dense1`, a `Dense` layer (a [densely-connected layer](https://www.tensorflow.org/swift/api_docs/Structs/Dense)) that takes an `inputSize` of `latentSize` (defined earlier), and has an `outputSize` of `latentSize*2`. The `activation` function determines the output value of each node in the layer. There are many available activations; we use [leaky ReLU](https://www.tensorflow.org/swift/api_docs/Functions#leakyrelu_:alpha:), a variant of ReLU, which is a common choice for hidden layers.

* `dense2` is likewise, but with an `inputSize` of `latentSize*2` (taking the output of the previous layer), and an `outputSize` of `latentSize*4`.

* `dense3` is likewise, taking the previous output (`latentSize*4`) as input, and producing an `outputSize` of `latentSize*8`.

* `dense4` is, again, the same, but has an `outputSize` of `imageSize` instead (our final desired image size). It uses [tanh](https://www.tensorflow.org/swift/api_docs/Functions#tanh_:) as its activation; tanh (hyperbolic tangent) is sigmoidal (s-shaped) and outputs values that range from -1 to 1.

* three `BatchNorm` layers, `batchnorm1`, `batchnorm2`, `batchnorm3`, that normalise the activations of the previous layer at each batch by applying transformations that keep the mean activation close to 0 and the activation standard deviation close to 1. `featureCount` is the number of features, which matches the `outputSize` of the `Dense` layer each one follows.
    
Finally, we have our `callAsFunction()` method, which sequences through the `Dense` layers, using the `BatchNorm` layers to normalise, before finally returning the output of the fourth and final `Dense` layer.



    

In [0]:
struct Generator: Layer {
    var dense1 = Dense<Float>(
        inputSize: latentSize, outputSize: latentSize * 2,
        activation: { leakyRelu($0) })

    var dense2 = Dense<Float>(
        inputSize: latentSize * 2, outputSize: latentSize * 4,
        activation: { leakyRelu($0) })

    var dense3 = Dense<Float>(
        inputSize: latentSize * 4, outputSize: latentSize * 8,
        activation: { leakyRelu($0) })

    var dense4 = Dense<Float>(
        inputSize: latentSize * 8, outputSize: imageSize,
        activation: tanh)

    var batchnorm1 = BatchNorm<Float>(featureCount: latentSize * 2)
    var batchnorm2 = BatchNorm<Float>(featureCount: latentSize * 4)
    var batchnorm3 = BatchNorm<Float>(featureCount: latentSize * 8)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let x1 = batchnorm1(dense1(input))
        let x2 = batchnorm2(dense2(x1))
        let x3 = batchnorm3(dense3(x2))
        return dense4(x3)
    }
}
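
To make the `BatchNorm` description concrete, here is a minimal plain-Swift sketch of the normalisation step for one feature across a batch. It is not the library's implementation: the learned scale and offset, and the running statistics used at inference time, are left out, and `normalizeBatch` and its `epsilon` value are assumptions for illustration.

```swift
/// Normalises one feature across a batch to (roughly) zero mean and unit variance.
/// `epsilon` guards against division by zero; its value here is an assumption.
func normalizeBatch(_ values: [Double], epsilon: Double = 1e-3) -> [Double] {
    let count = Double(values.count)
    let mean = values.reduce(0, +) / count
    let variance = values.map { ($0 - mean) * ($0 - mean) }.reduce(0, +) / count
    return values.map { ($0 - mean) / (variance + epsilon).squareRoot() }
}

// Hypothetical activations of a single node for a batch of four inputs.
let activations = [2.0, 4.0, 6.0, 8.0]
let normalized = normalizeBatch(activations)
let normalizedMean = normalized.reduce(0, +) / Double(normalized.count)

print(normalized)       // values centred on zero
print(normalizedMean)   // very close to 0
```

This is also why the training loop switches the learning phase: during training the layer uses the statistics of the current batch, as above, whereas during inference it relies on statistics collected earlier.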

## Discriminator Model

Our `Discriminator` is a `struct` conforming to the `Layer` protocol. The `Discriminator` has the following layers:

* `dense1`, a `Dense` layer, taking an `inputSize` of `imageSize`, with an `outputSize` of 256. It uses leaky ReLU for activation.

* `dense2` and `dense3`, which map 256 inputs to 64 outputs, and 64 inputs to 16 outputs, respectively, also using leaky ReLU.

* `dense4`, which takes an `inputSize` of 16, has an `outputSize` of 1, and uses `identity` as the activation (just linear), so its output is a raw logit rather than a probability.

Finally, we have our `callAsFunction()` method, which just sequences the input through the four (`Dense`) layers.

In [0]:
struct Discriminator: Layer {
    var dense1 = Dense<Float>(
        inputSize: imageSize, outputSize: 256,
        activation: { leakyRelu($0) })

    var dense2 = Dense<Float>(
        inputSize: 256, outputSize: 64,
        activation: { leakyRelu($0) })

    var dense3 = Dense<Float>(
        inputSize: 64, outputSize: 16,
        activation: { leakyRelu($0) })

    var dense4 = Dense<Float>(
        inputSize: 16, outputSize: 1,
        activation: identity)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        input.sequenced(through: dense1, dense2, dense3, dense4)
    }
}
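
Because the last layer uses `identity`, the discriminator returns an unbounded score (a logit) rather than a probability. If you want to read a score as "probability that the image is real", you can pass it through the logistic sigmoid. A small plain-Swift sketch, with made-up logits standing in for discriminator outputs:

```swift
import Foundation

/// Logistic sigmoid: maps a raw logit to a probability in (0, 1).
func sigmoid(_ logit: Double) -> Double {
    1 / (1 + exp(-logit))
}

// Hypothetical logits, standing in for discriminator outputs.
let exampleLogits = [-2.0, 0.0, 3.0]
let probabilities = exampleLogits.map(sigmoid)

// Negative logits lean "fake", a logit of 0 is undecided (0.5), positive logits lean "real".
print(probabilities)
```

The loss functions in the next section take the logits directly, so the notebook never needs to apply the sigmoid itself.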

## Loss functions

### Discriminator Loss Function

Our `discriminatorLoss()` function takes both the real and fake [logits](https://datascience.stackexchange.com/a/31045), computes a `realLoss` (the real logits against labels of all ones) and a `fakeLoss` (the fake logits against labels of all zeros) via the `sigmoidCrossEntropy()` function, and returns their sum. That's it!

In [0]:
@differentiable
func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
    let realLoss = sigmoidCrossEntropy(
        logits: realLogits,
        labels: Tensor(ones: realLogits.shape))
    let fakeLoss = sigmoidCrossEntropy(
        logits: fakeLogits,
        labels: Tensor(zeros: fakeLogits.shape))
    return realLoss + fakeLoss
}
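
To see what the loss does with actual numbers, here is a plain-Swift sketch of sigmoid cross-entropy using the usual numerically stable formula. It assumes the per-example losses are averaged over the batch; `binaryCrossEntropy`, `meanLoss` and the logits are hypothetical names and values, not part of the notebook.

```swift
import Foundation

/// Numerically stable sigmoid cross-entropy for one logit `x` and one label `z` (0 or 1):
/// max(x, 0) - x * z + log(1 + exp(-|x|)).
func binaryCrossEntropy(logit x: Double, label z: Double) -> Double {
    max(x, 0) - x * z + log(1 + exp(-abs(x)))
}

/// Mean loss over a batch in which every example has the same label.
func meanLoss(logits: [Double], label: Double) -> Double {
    let total = logits.map { binaryCrossEntropy(logit: $0, label: label) }.reduce(0, +)
    return total / Double(logits.count)
}

// A discriminator that is undecided (all logits 0) pays log(2) on each half.
let undecidedLoss = meanLoss(logits: [0, 0], label: 1) + meanLoss(logits: [0, 0], label: 0)

// A discriminator that scores real images high and fake images low pays much less.
let confidentLoss = meanLoss(logits: [2.0, 3.0], label: 1) + meanLoss(logits: [-2.0, -1.0], label: 0)

print(undecidedLoss, confidentLoss)
```

The undecided case works out to 2 · log 2 (about 1.386), which is a handy reference point when reading discriminator loss values.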

### Generator Loss Function

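Our `generatorLoss()` function takes the fake logits, and calculates the `sigmoidCrossEntropy()` against labels of all ones: the loss is low when the discriminator scores the generator's fakes as real.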

In [0]:
@differentiable
func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {
    sigmoidCrossEntropy(
        logits: fakeLogits,
        labels: Tensor(ones: fakeLogits.shape))
}
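
Written out, the two loss functions look like this. The notation is ours rather than the notebook's: $D(\cdot)$ is the discriminator's logit, $G(z)$ is the generator's output for noise $z$, $\sigma$ is the logistic sigmoid, $m$ is the batch size, and we assume the cross-entropy is averaged over the batch.

```latex
\begin{aligned}
L_D &= -\frac{1}{m}\sum_{i=1}^{m} \log \sigma\bigl(D(x_i)\bigr)
       \;-\; \frac{1}{m}\sum_{i=1}^{m} \log\Bigl(1 - \sigma\bigl(D(G(z_i))\bigr)\Bigr) \\
L_G &= -\frac{1}{m}\sum_{i=1}^{m} \log \sigma\bigl(D(G(z_i))\bigr)
\end{aligned}
```

The first term of $L_D$ is `realLoss` (labels of one), the second is `fakeLoss` (labels of zero), and $L_G$ reuses the fake logits with labels of one.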

### Random Samples

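Our `sampleVector()` function returns `size` random noise vectors, each of length `latentSize`, drawn from a normal distribution. We feed these to the Generator during both the Generator and Discriminator updates later on.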

In [0]:
/// Returns `size` samples of noise vector.
func sampleVector(size: Int) -> Tensor<Float> {
    Tensor(randomNormal: [size, latentSize])
}
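
If you're curious what "random normal" means in practice, this plain-Swift sketch draws the same kind of noise without TensorFlow. It uses the Box-Muller transform, which is our choice for illustration and not necessarily how the library samples; `standardNormal` and `sampleNoise` are hypothetical helpers.

```swift
import Foundation

/// Draws one standard-normal sample with the Box-Muller transform.
func standardNormal<G: RandomNumberGenerator>(using generator: inout G) -> Double {
    // u1 is kept away from 0 so that log(u1) stays finite.
    let u1 = Double.random(in: Double.leastNonzeroMagnitude..<1, using: &generator)
    let u2 = Double.random(in: 0..<1, using: &generator)
    return (-2 * log(u1)).squareRoot() * cos(2 * Double.pi * u2)
}

/// Plain-Swift stand-in for sampleVector(size:): `size` rows of `latentSize` noise values.
func sampleNoise(size: Int, latentSize: Int) -> [[Double]] {
    var generator = SystemRandomNumberGenerator()
    return (0..<size).map { _ in
        (0..<latentSize).map { _ in standardNormal(using: &generator) }
    }
}

let noise = sampleNoise(size: 32, latentSize: 64)
print(noise.count, noise[0].count)   // 32 64
```

Each row is one point in the latent space; the generator learns to map points like these to digit-like images.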

## Setting up to train

### Getting a dataset

We're going to use the "Hello, world!" of machine learning, MNIST, as our dataset. This comes from some of the helper libraries we've provided for this session (which, in turn, are largely drawn from deep in the bowels of the TensorFlow project):

In [0]:
let dataset = MNIST(batchSize: batchSize, flattening: true, normalizing: true)

Loading resource: train-images-idx3-ubyte
File does not exist locally at expected path: /content/train-images-idx3-ubyte and must be fetched
Fetching URL: http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...
Writing fetched archive to: /content/train-images-idx3-ubyte.gz
Archive saved to: /content/train-images-idx3-ubyte.gz
Extracting archive...
Loading local data at: /content/train-images-idx3-ubyte
Succesfully loaded resource: train-images-idx3-ubyte
Loading resource: train-labels-idx1-ubyte
File does not exist locally at expected path: /content/train-labels-idx1-ubyte and must be fetched
Fetching URL: http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...
Writing fetched archive to: /content/train-labels-idx1-ubyte.gz
Archive saved to: /content/train-labels-idx1-ubyte.gz
Extracting archive...
Loading local data at: /content/train-labels-idx1-ubyte
Succesfully loaded resource: train-labels-idx1-ubyte
Loading resource: t10k-images-idx3-ubyte
File does not exist loc
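
The `flattening: true, normalizing: true` arguments matter for the GAN: each image becomes one row of 784 values scaled into [-1, 1], the same range the generator's `tanh` output produces. Here is a plain-Swift sketch of that scaling and of the row range that `minibatch(at:batchSize:)` selects; `normalizedPixel` and `minibatchRange` are hypothetical helpers that mirror the support code.

```swift
/// Maps a raw MNIST pixel (0...255) into [-1, 1], mirroring the flattening + normalizing path.
func normalizedPixel(_ raw: UInt8) -> Float {
    Float(raw) / 255.0 * 2.0 - 1.0
}

/// Row range selected by minibatch(at:batchSize:) in the support code.
func minibatchRange(at index: Int, batchSize: Int) -> Range<Int> {
    let start = index * batchSize
    return start..<(start + batchSize)
}

print(normalizedPixel(0), normalizedPixel(255))   // -1.0 1.0
print(minibatchRange(at: 2, batchSize: 32))       // 64..<96
```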

### Creating a generator and a discriminator

In [0]:
var generator = Generator()
var discriminator = Discriminator()

### Creating optimisers for the generator and the discriminator

We need an optimization algorithm for both the models. In each case, we'll use the [Adam](https://www.tensorflow.org/swift/api_docs/Classes/Adam) optimisation algorithm. It's a popular choice!

#### Generator's optimizer

In [0]:
let optG = Adam(for: generator, learningRate: 2e-4, beta1: 0.5)

#### Discriminator's optimizer

In [0]:
let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)
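
For a feel of what Adam does with these settings, here is a textbook Adam update for a single scalar parameter in plain Swift. It is a sketch, not the library's code: `ScalarAdam` is a hypothetical type, `learningRate` and `beta1` match the cells above, and `beta2` and `epsilon` are assumed to be the commonly used defaults.

```swift
import Foundation

/// State for a textbook Adam update of a single scalar parameter.
struct ScalarAdam {
    var learningRate = 2e-4
    var beta1 = 0.5
    var beta2 = 0.999      // assumption: commonly used default
    var epsilon = 1e-8     // assumption: commonly used default
    var firstMoment = 0.0
    var secondMoment = 0.0
    var step = 0

    /// Applies one update to `parameter` given its gradient.
    mutating func update(_ parameter: inout Double, gradient: Double) {
        step += 1
        // Exponential moving averages of the gradient and the squared gradient.
        firstMoment = beta1 * firstMoment + (1 - beta1) * gradient
        secondMoment = beta2 * secondMoment + (1 - beta2) * gradient * gradient
        // Bias correction compensates for both averages starting at zero.
        let correctedFirst = firstMoment / (1 - pow(beta1, Double(step)))
        let correctedSecond = secondMoment / (1 - pow(beta2, Double(step)))
        parameter -= learningRate * correctedFirst / (correctedSecond.squareRoot() + epsilon)
    }
}

var optimizer = ScalarAdam()
var weight = 1.0
optimizer.update(&weight, gradient: 3.0)
print(weight)   // moved by roughly one learning rate, regardless of the gradient's size
```

The lower `beta1` of 0.5 makes the moving average of the gradient forget old values faster, which is a common choice when training GANs.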

## Training and Inference

First, we'll print out a message to say we're starting training:

In [0]:
print("GAN: Training Begins")

GAN: Training Begins


To train, we iterate up to our desired `epochCount`. Each epoch runs training using both the Generator and the Discriminator, then runs inference to generate a grid of images, and prints out the current epoch and the generator's loss.

Specifically, in each epoch, we:
* set the [`Context`](https://www.tensorflow.org/swift/api_docs/Structs/Context) to `.training` so that, for example, `BatchNorm` layers (like we're using in our Generator) will compute mean and variance when applied to inputs
* iterate through the training data batches and:
  * create a random sample using the `sampleVector()` function we wrote earlier
  * for the generator's gradient (𝛁), use the generator to turn the random sample into fake images, run the discriminator on them, and calculate a loss from the discriminator's output using the `generatorLoss()` function we wrote earlier
  * update the generator model, along the generator gradient, using the generator's optimizer
  * get a batch of real images from the training data, as well as another random sample using `sampleVector()`, and use the generator to generate some generated (aka fake) images using the random sample data
  * for the discriminator's gradient (𝛁), calculate and return the loss from the discriminator running on the real images and on the fake images, using the `discriminatorLoss()` function we wrote earlier
  * update the discriminator model, along the discriminator gradient, using the discriminator's optimizer
* after iterating through the training data batches, we set the [`Context`](https://www.tensorflow.org/swift/api_docs/Structs/Context) to `.inference`
* then (after training for that epoch) we generate a test image, using the generator and a random sample of the size our parameters dictate for the test image grid
  * and attempt to save that test image, using one of our convenience functions, `saveImageGrid()`
* we then check the loss on the generator for the test image, with our `generatorLoss()` function
* and print out the current epoch and generator loss

In [0]:
for epoch in 1...epochCount {

    Context.local.learningPhase = .training

    for i in 0 ..< dataset.trainingSize / batchSize {
        // Perform alternating updates.
        // Update generator.
        let vec1 = sampleVector(size: batchSize)

        let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
            let fakeImages = generator(vec1)
            let fakeLogits = discriminator(fakeImages)
            let loss = generatorLoss(fakeLogits: fakeLogits)
            return loss
        }
        optG.update(&generator, along: 𝛁generator)

        // Update discriminator.
        let realImages = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
        let vec2 = sampleVector(size: batchSize)
        let fakeImages = generator(vec2)

        let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
            let realLogits = discriminator(realImages)
            let fakeLogits = discriminator(fakeImages)
            let loss = discriminatorLoss(realLogits: realLogits, fakeLogits: fakeLogits)
            return loss
        }
        optD.update(&discriminator, along: 𝛁discriminator)
    }

    // Start inference phase.
    Context.local.learningPhase = .inference
    let testImage = generator(sampleVector(size: testImageGridSize * testImageGridSize))

    do {
        try saveImageGrid(testImage, name: "epoch-\(epoch)-output")
    } catch {
        print("Could not save image grid with error: \(error)")
    }

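    // Note: this feeds the generated images straight into generatorLoss() as if they were logits.
    // To score them with the discriminator instead, pass discriminator(testImage).
    let lossG = generatorLoss(fakeLogits: testImage)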
    print("Current Epoch: \(epoch) | Generator Loss: \(lossG)")
}

Current Epoch: 1 | Generator Loss: 1.1418108
Current Epoch: 2 | Generator Loss: 1.1449372
Current Epoch: 3 | Generator Loss: 1.1592628
Current Epoch: 4 | Generator Loss: 1.1639445
Current Epoch: 5 | Generator Loss: 1.1449314
Current Epoch: 6 | Generator Loss: 1.1559143
Current Epoch: 7 | Generator Loss: 1.1588637
Current Epoch: 8 | Generator Loss: 1.1727781
Current Epoch: 9 | Generator Loss: 1.1668153
Current Epoch: 10 | Generator Loss: 1.1905712
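
If the alternating updates are hard to picture, this toy version in plain Swift may help. Everything in it is hypothetical and heavily simplified: a one-parameter-pair "generator" and "discriminator" on single numbers, gradients worked out by hand, plain gradient descent instead of Adam, and "real data" that is just numbers near 4. Only the order of operations mirrors the loop above.

```swift
import Foundation

/// Logistic sigmoid.
func sigmoid(_ x: Double) -> Double { 1 / (1 + exp(-x)) }

/// Small deterministic source of values in [-1, 1), so runs are repeatable.
struct NoiseSource {
    var state: UInt64 = 42
    mutating func next() -> Double {
        state = state &* 6364136223846793005 &+ 1442695040888963407
        return Double(state >> 11) / 9007199254740992.0 * 2 - 1   // divide by 2^53
    }
}

/// Toy 1-D GAN: generator g(z) = a*z + b, discriminator logit d(x) = w*x + c.
struct ToyGAN {
    var a = 1.0, b = 0.0   // generator parameters
    var w = 0.0, c = 0.0   // discriminator parameters
    var learningRate = 0.01

    func generate(_ z: Double) -> Double { a * z + b }
    func logit(_ x: Double) -> Double { w * x + c }

    /// Generator step: push d(g(z)) towards the "real" label (1).
    mutating func updateGenerator(noise z: Double) {
        let dLoss = sigmoid(logit(generate(z))) - 1   // d(loss)/d(logit) for label 1
        a -= learningRate * dLoss * w * z
        b -= learningRate * dLoss * w
    }

    /// Discriminator step: label the real sample 1 and the generated sample 0.
    mutating func updateDiscriminator(real x: Double, noise z: Double) {
        let fake = generate(z)
        let dReal = sigmoid(logit(x)) - 1   // d(loss)/d(logit) for label 1
        let dFake = sigmoid(logit(fake))    // d(loss)/d(logit) for label 0
        w -= learningRate * (dReal * x + dFake * fake)
        c -= learningRate * (dReal + dFake)
    }
}

var noise = NoiseSource()
var gan = ToyGAN()
for _ in 0..<200 {
    // Same order as the notebook: generator first, then discriminator.
    let generatorNoise = noise.next()
    gan.updateGenerator(noise: generatorNoise)

    let realSample = 4 + noise.next()
    let discriminatorNoise = noise.next()
    gan.updateDiscriminator(real: realSample, noise: discriminatorNoise)
}
print(gan.a, gan.b, gan.w, gan.c)
```

Note that each step only changes one model's parameters: the generator step reads `w` and `c` but never writes them, just as `generator.gradient` only differentiates with respect to the generator.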


## Extra Credit

Our suggestions for what to do next:


1. use a Python library to visualise some of this in the notebook, either via graphs, or via displaying images inline in the notebook
2. modify the GAN to be able to generate one image of a digit at a time, upon request (e.g. make a function that lets you request a generated 5, or a generated 6)
3. modify the GAN to generate something other than MNIST digits 

