# MNIST GAN in Swift for TensorFlow

## Importing dependencies

In [0]:
import TensorFlow
import Foundation
import FoundationNetworking

In [0]:
// Notebook directive: pulls the contents of EnableIPythonDisplay.swift into this cell.
%include "EnableIPythonDisplay.swift"
// Ask the IPython shell to render matplotlib figures inline.
IPythonDisplay.shell.enable_matplotlib("inline")

In [0]:
import Python
// Python modules reached through Swift's Python interop; `plt` is used below for plotting.
let plt = Python.import("matplotlib.pyplot")
let np = Python.import("numpy")

## Loading MNIST

This section recreates much of the MNIST dataset API from https://github.com/tensorflow/swift-models, because that package does not currently work in Colab. The swift-models team has told us they are working on an update (see https://github.com/tensorflow/swift-models/issues/233).

In [0]:
/// A pair of tensors: integer labels and the floating-point data they describe.
public struct LabeledExample: TensorGroup {
    /// Label tensor.
    public var label: Tensor<Int32>
    /// Data tensor.
    public var data: Tensor<Float>

    /// Creates an example from an explicit label tensor and data tensor.
    public init(label: Tensor<Int32>, data: Tensor<Float>) {
        self.label = label
        self.data = data
    }

    /// Creates an example from exactly two tensor handles: the first is wrapped as the
    /// label tensor, the second as the data tensor.
    public init<C: RandomAccessCollection>(
        _handles: C
    ) where C.Element: _AnyTensorHandle {
        precondition(_handles.count == 2)
        let first = _handles.startIndex
        let second = _handles.index(after: first)
        self.init(
            label: Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[first])),
            data: Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[second])))
    }
}

/// Helpers for locating, downloading, and unpacking gzip-compressed dataset files.
public struct DatasetUtilities {
    /// The process's current working directory; the default place resources are stored.
    public static let currentWorkingDirectoryURL = URL(
        fileURLWithPath: FileManager.default.currentDirectoryPath)

    /// Returns the bytes of `filename`, downloading and extracting it first when it is not
    /// already present in `localStorageDirectory`.
    ///
    /// Every failure (download, write, extraction, read) stops the program via `fatalError`.
    ///
    /// - Parameters:
    ///   - filename: Name of the uncompressed file.
    ///   - remoteRoot: URL under which `<filename>.gz` is fetched.
    ///   - localStorageDirectory: Directory in which the file is looked up and stored.
    /// - Returns: The contents of the local, uncompressed file.
    public static func fetchResource(
        filename: String,
        remoteRoot: URL,
        localStorageDirectory: URL = currentWorkingDirectoryURL
    ) -> Data {
        print("Loading resource: \(filename)")

        let resource = ResourceDefinition(
            filename: filename,
            remoteRoot: remoteRoot,
            localStorageDirectory: localStorageDirectory)

        let localURL = resource.localURL

        if !FileManager.default.fileExists(atPath: localURL.path) {
            print(
                "File does not exist locally at expected path: \(localURL.path) and must be fetched"
            )
            fetchFromRemoteAndSave(resource)
        }

        do {
            print("Loading local data at: \(localURL.path)")
            let data = try Data(contentsOf: localURL)
            print("Succesfully loaded resource: \(filename)")
            return data
        } catch {
            // Include the underlying error so a failed read can actually be diagnosed.
            fatalError("Failed to load contents of resource: \(localURL) with error: \(error)")
        }
    }

    /// Describes where a resource lives remotely and where it is kept locally.
    struct ResourceDefinition {
        let filename: String
        let remoteRoot: URL
        let localStorageDirectory: URL

        /// Location of the uncompressed file on disk.
        var localURL: URL {
            localStorageDirectory.appendingPathComponent(filename)
        }

        /// Location of the gzip archive on the remote server.
        var remoteURL: URL {
            remoteRoot.appendingPathComponent(filename).appendingPathExtension("gz")
        }

        /// Location of the downloaded gzip archive on disk.
        var archiveURL: URL {
            localURL.appendingPathExtension("gz")
        }
    }

    /// Downloads the resource's archive, writes it to `archiveURL`, and extracts it.
    static func fetchFromRemoteAndSave(_ resource: ResourceDefinition) {
        let remoteLocation = resource.remoteURL
        let archiveLocation = resource.archiveURL

        do {
            print("Fetching URL: \(remoteLocation)...")
            let archiveData = try Data(contentsOf: remoteLocation)
            print("Writing fetched archive to: \(archiveLocation.path)")
            try archiveData.write(to: archiveLocation)
        } catch {
            fatalError("Failed to fetch and save resource with error: \(error)")
        }
        print("Archive saved to: \(archiveLocation.path)")

        extractArchive(for: resource)
    }

    /// Runs `gunzip` on the resource's archive and stops the program if extraction fails.
    static func extractArchive(for resource: ResourceDefinition) {
        print("Extracting archive...")

        let archivePath = resource.archiveURL.path

        // Keep the original per-platform preference, but fall back to the other common
        // install location when the preferred binary is not present.
        #if os(macOS)
            let gunzipCandidates = ["/usr/bin/gunzip", "/bin/gunzip"]
        #else
            let gunzipCandidates = ["/bin/gunzip", "/usr/bin/gunzip"]
        #endif
        let gunzipLocation =
            gunzipCandidates.first(where: { FileManager.default.isExecutableFile(atPath: $0) })
            ?? gunzipCandidates[0]

        let task = Process()
        task.executableURL = URL(fileURLWithPath: gunzipLocation)
        task.arguments = [archivePath]
        do {
            try task.run()
            task.waitUntilExit()
        } catch {
            fatalError("Failed to extract \(archivePath) with error: \(error)")
        }

        // `run()` only throws when the process cannot be launched; a corrupt archive is
        // reported through the exit status, which would otherwise go unnoticed here.
        guard task.terminationStatus == 0 else {
            fatalError(
                "Failed to extract \(archivePath): gunzip exited with status \(task.terminationStatus)"
            )
        }
    }
}

/// The MNIST training and test sets, each wrapped in a `Dataset` of `LabeledExample`.
public struct MNIST {
    /// Examples loaded from the `train-*` files.
    public let trainingDataset: Dataset<LabeledExample>
    /// Examples loaded from the `t10k-*` files.
    public let testDataset: Dataset<LabeledExample>
    /// Hard-coded number of training examples.
    public let trainingExampleCount = 60000

    /// Loads MNIST with flattening and normalizing both turned off.
    public init() {
        self.init(flattening: false, normalizing: false)
    }

    /// Loads both splits through `fetchDataset`.
    ///
    /// - Parameters:
    ///   - flattening: Forwarded to `fetchDataset` for both splits.
    ///   - normalizing: Forwarded to `fetchDataset` for both splits.
    ///   - localStorageDirectory: Directory in which the dataset files are looked up and stored.
    public init(
        flattening: Bool = false, normalizing: Bool = false,
        localStorageDirectory: URL = DatasetUtilities.currentWorkingDirectoryURL
    ) {
        // Both splits are loaded the same way; only the two file names differ.
        func loadSplit(images: String, labels: String) -> Dataset<LabeledExample> {
            let examples = fetchDataset(
                localStorageDirectory: localStorageDirectory,
                imagesFilename: images,
                labelsFilename: labels,
                flattening: flattening,
                normalizing: normalizing)
            return Dataset<LabeledExample>(elements: examples)
        }

        trainingDataset = loadSplit(
            images: "train-images-idx3-ubyte",
            labels: "train-labels-idx1-ubyte")
        testDataset = loadSplit(
            images: "t10k-images-idx3-ubyte",
            labels: "t10k-labels-idx1-ubyte")
    }
}

/// Loads one MNIST split (an images file plus a labels file) into a `LabeledExample`.
///
/// Pixel values are always scaled from [0, 255] to [0, 1]. When `normalizing` is true they
/// are additionally mapped to [-1, 1]; this applies to both the flattened and the image
/// layout.
///
/// - Parameters:
///   - localStorageDirectory: Directory in which the dataset files are looked up and stored.
///   - imagesFilename: Name of the uncompressed IDX image file.
///   - labelsFilename: Name of the uncompressed IDX label file.
///   - flattening: If true, data has shape `[count, 28 * 28]`; otherwise `[count, 28, 28, 1]`.
///   - normalizing: If true, pixel values are rescaled from [0, 1] to [-1, 1].
/// - Returns: All labels and images of the split as two tensors.
fileprivate func fetchDataset(
    localStorageDirectory: URL,
    imagesFilename: String,
    labelsFilename: String,
    flattening: Bool,
    normalizing: Bool
) -> LabeledExample {
    guard let remoteRoot:URL = URL(string: "http://yann.lecun.com/exdb/mnist") else {
        fatalError("Failed to create MNST root url: http://yann.lecun.com/exdb/mnist")
    }

    let imagesData = DatasetUtilities.fetchResource(
        filename: imagesFilename,
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)
    let labelsData = DatasetUtilities.fetchResource(
        filename: labelsFilename,
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)

    // Skip the IDX headers: 16 bytes for the image file, 8 bytes for the label file.
    let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
    let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)

    let rowCount = labels.count
    let (imageWidth, imageHeight) = (28, 28)

    // A truncated or corrupt download would otherwise fail inside the tensor initializer
    // with a far less helpful message.
    precondition(
        images.count == rowCount * imageHeight * imageWidth,
        "Image file \(imagesFilename) does not contain \(rowCount) images of size 28x28")

    var imageTensor: Tensor<Float>
    if flattening {
        imageTensor = Tensor<Float>(
            shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0
    } else {
        imageTensor = Tensor<Float>(
            shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
            .transposed(withPermutations: [0, 2, 3, 1]) / 255  // NHWC
    }

    // Previously this step only ran in the flattened branch, so `normalizing` was silently
    // ignored for image-shaped output.
    if normalizing {
        imageTensor = imageTensor * 2.0 - 1.0
    }

    return LabeledExample(label: Tensor(labels), data: imageTensor)
}

In [0]:
// Load MNIST as image-shaped (non-flattened) tensors and request normalization.
let mnist = MNIST(flattening: false, normalizing: true)

## Building a DCGAN

In [0]:
// Length of the noise vector fed to the generator (input size of its first dense layer).
let zDim = 100

### Generator

In [0]:
/// Generator network: projects a noise vector of length `zDim` through a dense layer and
/// three transposed convolutions, finishing with `tanh`.
struct Generator: Layer {
    /// Shared flattening layer, applied before the second and third batch norms.
    var flatten = Flatten<Float>()

    // Stage 1: dense -> batch norm -> leaky ReLU -> reshape to [n, 7, 7, 256].
    var dense1 = Dense<Float>(inputSize: zDim, outputSize: 7 * 7 * 256)
    var batchNorm1 = BatchNorm<Float>(featureCount: 7 * 7 * 256)

    // Stage 2: transposed conv (stride 1) -> flatten -> batch norm -> leaky ReLU.
    var transConv2D1 = TransposedConv2D<Float>(
        filterShape: (5, 5, 128, 256),
        strides: (1, 1),
        padding: .same
    )
    var batchNorm2 = BatchNorm<Float>(featureCount: 7 * 7 * 128)

    // Stage 3: transposed conv (stride 2) -> flatten -> batch norm -> leaky ReLU.
    var transConv2D2 = TransposedConv2D<Float>(
        filterShape: (5, 5, 64, 128),
        strides: (2, 2),
        padding: .same
    )
    var batchNorm3 = BatchNorm<Float>(featureCount: 14 * 14 * 64)

    // Stage 4: transposed conv (stride 2) -> tanh.
    var transConv2D3 = TransposedConv2D<Float>(
        filterShape: (5, 5, 1, 64),
        strides: (2, 2),
        padding: .same
    )

    /// Runs the four stages on `input` and returns the `tanh`-activated result.
    ///
    /// Each reshape recovers the leading dimension by dividing the tensor's total scalar
    /// count by the per-example feature count of that stage.
    @differentiable
    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let projected = leakyRelu(batchNorm1(dense1(input)))
        let projectedRows = projected.shape.contiguousSize / (7 * 7 * 256)
        let projectedMap = projected.reshaped(to: TensorShape(projectedRows, 7, 7, 256))

        let middle = leakyRelu(batchNorm2(flatten(transConv2D1(projectedMap))))
        let middleRows = middle.shape.contiguousSize / (7 * 7 * 128)
        let middleMap = middle.reshaped(to: TensorShape(middleRows, 7, 7, 128))

        let upsampled = leakyRelu(batchNorm3(flatten(transConv2D2(middleMap))))
        let upsampledRows = upsampled.shape.contiguousSize / (14 * 14 * 64)
        let upsampledMap = upsampled.reshaped(to: TensorShape(upsampledRows, 14, 14, 64))

        return tanh(transConv2D3(upsampledMap))
    }
}

In [0]:
// Declared `var` because the training loop updates it in place through the optimizer.
var generator = Generator()

In [0]:
// Draw a single normally distributed noise vector of length zDim and pass it through the
// generator. The same `noise` is reused later to render one image per epoch.
let noise = Tensor<Float>(randomNormal: TensorShape(1, zDim))
let generatedImage = generator(noise)
// Bare expression; presumably here so the notebook displays the output shape.
generatedImage.shape

In [0]:
// View the generated sample as a 28x28 array and display it with matplotlib.
plt.imshow(generatedImage.reshaped(to: TensorShape(28, 28)).makeNumpyArray())
plt.show()

### Discriminator

In [0]:
/// Discriminator network: two strided convolutions, each followed by leaky ReLU and
/// dropout, then a dense layer producing a single output value per example.
struct Discriminator: Layer {
    var conv2D1 = Conv2D<Float>(
        filterShape: (5, 5, 1, 64),
        strides: (2, 2),
        padding: .same
    )
    /// One dropout layer, applied after each of the two convolution stages.
    var dropout = Dropout<Float>(probability: 0.3)

    var conv2D2 = Conv2D<Float>(
        filterShape: (5, 5, 64, 128),
        strides: (2, 2),
        padding: .same
    )

    var flatten = Flatten<Float>()
    // 7 * 7 * 128 == 6272 flattened features going into the dense layer.
    var dense = Dense<Float>(inputSize: 7 * 7 * 128, outputSize: 1)

    /// Returns the dense layer's raw output (no activation is applied to it here).
    @differentiable
    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let firstStage = dropout(leakyRelu(conv2D1(input)))
        let secondStage = dropout(leakyRelu(conv2D2(firstStage)))
        return dense(flatten(secondStage))
    }
}

In [0]:
// Declared `var` because the training loop updates it in place through the optimizer.
var discriminator = Discriminator()

In [0]:
// Run the discriminator once on the generated sample; presumably a smoke test whose
// result the notebook displays.
discriminator(generatedImage)

## Training

In [0]:
// Number of examples per training batch; also the number of noise vectors drawn per step.
let batchSize = 512

In [0]:
/// Generator loss: sigmoid cross-entropy between the discriminator's outputs for generated
/// images and an all-ones target, i.e. the generator is rewarded when fakes are judged real.
///
/// - Parameter fakeLabels: Discriminator outputs (logits) for generated images.
@differentiable
func generatorLoss(fakeLabels: Tensor<Float>) -> Tensor<Float> {
    let realTargets = Tensor<Float>(ones: fakeLabels.shape)
    return sigmoidCrossEntropy(logits: fakeLabels, labels: realTargets)
}

In [0]:
/// Discriminator loss: the sum of two sigmoid cross-entropy terms, one pushing outputs for
/// real images towards 1 and one pushing outputs for generated images towards 0.
///
/// - Parameters:
///   - realLabels: Discriminator outputs (logits) for real images.
///   - fakeLabels: Discriminator outputs (logits) for generated images.
@differentiable
func discriminatorLoss(realLabels: Tensor<Float>, fakeLabels: Tensor<Float>) -> Tensor<Float> {
    // Real images should be classified as real (target 1).
    let realTargets = Tensor<Float>(ones: realLabels.shape)
    let realLoss = sigmoidCrossEntropy(logits: realLabels, labels: realTargets)

    // Generated images should be classified as fake (target 0).
    let fakeTargets = Tensor<Float>(zeros: fakeLabels.shape)
    let fakeLoss = sigmoidCrossEntropy(logits: fakeLabels, labels: fakeTargets)

    // The total loss is cumulative over both terms.
    return realLoss + fakeLoss
}

In [0]:
// One Adam optimizer per network, both using a learning rate of 0.0001.
let optG = Adam(for: generator, learningRate: 0.0001)
let optD = Adam(for: discriminator, learningRate: 0.0001)

In [0]:
// Upper bound of the epoch range used by the training loop.
let epochs = 20

In [0]:
// Adversarial training loop. For every batch the generator is updated first, then the
// discriminator. After each epoch a sample is rendered from the fixed `noise` vector and
// the generator loss for that sample is printed.
//
// The range is half-open so exactly `epochs` passes are run; a closed range `0...epochs`
// would run one extra pass.
for epoch in 0..<epochs {
    Context.local.learningPhase = .training
    // Reshuffle each epoch, seeded by the epoch index so runs are repeatable.
    let trainingShuffled = mnist.trainingDataset.shuffled(
        sampleCount: mnist.trainingExampleCount, randomSeed: Int64(epoch))
    for batch in trainingShuffled.batched(batchSize) {
        let realImages = batch.data

        // Train the generator: the gradient is taken with respect to the generator only;
        // the discriminator is captured from the enclosing scope and is not updated here.
        let noiseG = Tensor<Float>(randomNormal: TensorShape(batchSize, zDim))
        let generatorGradient = generator.gradient { generator -> Tensor<Float> in
            let fakeImages = generator(noiseG)
            let fakeLabels = discriminator(fakeImages)
            return generatorLoss(fakeLabels: fakeLabels)
        }
        optG.update(&generator, along: generatorGradient)

        // Train the discriminator: fake images are produced outside the gradient closure,
        // so only the discriminator's parameters receive gradients.
        let noiseD = Tensor<Float>(randomNormal: TensorShape(batchSize, zDim))
        let fakeImages = generator(noiseD)

        let discriminatorGradient = discriminator.gradient { discriminator -> Tensor<Float> in
            let realLabels = discriminator(realImages)
            let fakeLabels = discriminator(fakeImages)
            return discriminatorLoss(realLabels: realLabels, fakeLabels: fakeLabels)
        }
        optD.update(&discriminator, along: discriminatorGradient)
    }

    // Evaluate in inference mode.
    Context.local.learningPhase = .inference

    // Render the image generated from the fixed noise vector.
    let generatedImage = generator(noise)
    plt.imshow(generatedImage.reshaped(to: TensorShape(28, 28)).makeNumpyArray())
    plt.show()

    // Report the generator loss. The loss expects the discriminator's output for the
    // generated image as logits, not the image pixels themselves.
    let generatorLoss_ = generatorLoss(fakeLabels: discriminator(generatedImage))
    print("epoch: \(epoch) | Generator loss: \(generatorLoss_)")
}