# Swift for Tensorflow

helper functions:

In [0]:
import Foundation

var stderr = FileHandle.standardError

extension FileHandle: TextOutputStream {
    public func write(_ string: String) {
        guard let data = string.data(using: .utf8) else { return }
        self.write(data)
    }
}

public func printError(_ message: String) {
    print(message, to: &stderr)
}

## Data

LabeledExample:

In [0]:
import TensorFlow

public struct LabeledExample: TensorGroup {
    public var label: Tensor<Int32>
    public var data: Tensor<Float>

    public init(label: Tensor<Int32>, data: Tensor<Float>) {
        self.label = label
        self.data = data
    }

    public init<C: RandomAccessCollection>(
        _handles: C
    ) where C.Element: _AnyTensorHandle {
        precondition(_handles.count == 2)
        let labelIndex = _handles.startIndex
        let dataIndex = _handles.index(labelIndex, offsetBy: 1)
        label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
        data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
    }
}

ImageClassificationDataset:

In [0]:
public protocol ImageClassificationDataset {
    init()
    var trainingDataset: Dataset<LabeledExample> { get }
    var testDataset: Dataset<LabeledExample> { get }
    var trainingExampleCount: Int { get }
    var testExampleCount: Int { get }
}

CIFAR10:

In [0]:
import Foundation
import ModelSupport

#if canImport(FoundationNetworking)
    import FoundationNetworking
#endif

public struct CIFAR10: ImageClassificationDataset {
    public let trainingDataset: Dataset<LabeledExample>
    public let testDataset: Dataset<LabeledExample>
    public let trainingExampleCount = 50000
    public let testExampleCount = 10000

    public init() {
        self.init(
            localStorageDirectory: FileManager.default.temporaryDirectory.appendingPathComponent(
                "CIFAR10"))
    }

    public init(localStorageDirectory: URL) {
        self.trainingDataset = Dataset<LabeledExample>(
            elements: loadCIFARTrainingFiles(localStorageDirectory: localStorageDirectory))
        self.testDataset = Dataset<LabeledExample>(
            elements: loadCIFARTestFile(localStorageDirectory: localStorageDirectory))
    }
}

func downloadCIFAR10IfNotPresent(to directory: URL) {
    if !FileManager.default.fileExists(atPath: directory.path) {
        do {
            try FileManager.default.createDirectory(
                at: directory, withIntermediateDirectories: false)
        } catch {
            fatalError(
                "Failed to create storage directory: \(directory.path), error: \(error)"
            )
        }
    }

    let downloadPath = directory.appendingPathComponent("cifar-10-batches-bin").path
    let directoryExists = FileManager.default.fileExists(atPath: downloadPath)

    guard !directoryExists else { return }

    printError("Downloading CIFAR dataset...")
    let archivePath = directory.appendingPathComponent("cifar-10-binary.tar.gz").path
    let archiveExists = FileManager.default.fileExists(atPath: archivePath)
    if !archiveExists {
        printError("Archive missing, downloading...")
        do {
            let downloadedFile = try Data(
                contentsOf: URL(
                    string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
            try downloadedFile.write(to: URL(fileURLWithPath: archivePath))
        } catch {
            printError("Could not download CIFAR dataset, error: \(error)")
            exit(-1)
        }
    }

    printError("Archive downloaded, processing...")

    #if os(macOS)
        let tarLocation = "/usr/bin/tar"
    #else
        let tarLocation = "/bin/tar"
    #endif

    let task = Process()
    task.executableURL = URL(fileURLWithPath: tarLocation)
    task.arguments = ["xzf", archivePath, "-C", directory.path]
    do {
        try task.run()
        task.waitUntilExit()
    } catch {
        printError("CIFAR extraction failed with error: \(error)")
    }

    do {
        try FileManager.default.removeItem(atPath: archivePath)
    } catch {
        printError("Could not remove archive, error: \(error)")
        exit(-1)
    }

    printError("Unarchiving completed")
}

func loadCIFARFile(named name: String, in directory: URL) -> LabeledExample {
    downloadCIFAR10IfNotPresent(to: directory)
    let path = directory.appendingPathComponent("cifar-10-batches-bin/\(name)").path

    let imageCount = 10000
    guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
        printError("Could not read dataset file: \(name)")
        exit(-1)
    }
    guard fileContents.count == 30_730_000 else {
        printError(
            "Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
        exit(-1)
    }

    var bytes: [UInt8] = []
    var labels: [Int64] = []

    let imageByteSize = 3073
    for imageIndex in 0..<imageCount {
        let baseAddress = imageIndex * imageByteSize
        labels.append(Int64(fileContents[baseAddress]))
        bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)])
    }

    let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
    let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)

    // Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
    let imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))

    let mean = Tensor<Float>([0.485, 0.456, 0.406])
    let std = Tensor<Float>([0.229, 0.224, 0.225])
    let imagesNormalized = ((imageTensor / 255.0) - mean) / std

    return LabeledExample(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}

func loadCIFARTrainingFiles(localStorageDirectory: URL) -> LabeledExample {
    let data = (1..<6).map {
        loadCIFARFile(named: "data_batch_\($0).bin", in: localStorageDirectory)
    }
    return LabeledExample(
        label: Tensor(concatenating: data.map { $0.label }, alongAxis: 0),
        data: Tensor(concatenating: data.map { $0.data }, alongAxis: 0)
    )
}

func loadCIFARTestFile(localStorageDirectory: URL) -> LabeledExample {
    return loadCIFARFile(named: "test_batch.bin", in: localStorageDirectory)
}

## Model

In [0]:
// Ported from github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py
struct KerasModel: Layer {
    typealias Input = Tensor<Float>
    typealias Output = Tensor<Float>

    var conv1a = Conv2D<Float>(filterShape: (3, 3, 3, 32), padding: .same, activation: relu)
    var conv1b = Conv2D<Float>(filterShape: (3, 3, 32, 32), activation: relu)
    var pool1 = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    var dropout1 = Dropout<Float>(probability: 0.25)
    var conv2a = Conv2D<Float>(filterShape: (3, 3, 32, 64), padding: .same, activation: relu)
    var conv2b = Conv2D<Float>(filterShape: (3, 3, 64, 64), activation: relu)
    var pool2 = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    var dropout2 = Dropout<Float>(probability: 0.25)
    var flatten = Flatten<Float>()
    var dense1 = Dense<Float>(inputSize: 64 * 6 * 6, outputSize: 512, activation: relu)
    var dropout3 = Dropout<Float>(probability: 0.5)
    var dense2 = Dense<Float>(inputSize: 512, outputSize: 10, activation: identity)

    @differentiable
    func callAsFunction(_ input: Input) -> Output {
        let conv1 = input.sequenced(through: conv1a, conv1b, pool1, dropout1)
        let conv2 = conv1.sequenced(through: conv2a, conv2b, pool2, dropout2)
        return conv2.sequenced(through: flatten, dense1, dropout3, dense2)
    }
}

## Train

In [0]:
let batchSize = 32

let dataset = CIFAR10()
let testBatches = dataset.testDataset.batched(batchSize)

var model = KerasModel()
let optimizer = RMSProp(for: model, learningRate: 0.0001, decay: 1e-6)

print("Starting training...")

for epoch in 1...20 {
    Context.local.learningPhase = .training
    var trainingLossSum: Float = 0
    var trainingBatchCount = 0
    let trainingShuffled = dataset.trainingDataset.shuffled(
        sampleCount: dataset.trainingExampleCount, randomSeed: Int64(epoch))
    for batch in trainingShuffled.batched(batchSize) {
        let (labels, images) = (batch.label, batch.data)
        let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
            let logits = model(images)
            return softmaxCrossEntropy(logits: logits, labels: labels)
        }
        trainingLossSum += loss.scalarized()
        trainingBatchCount += 1
        optimizer.update(&model, along: gradients)
    }

    Context.local.learningPhase = .inference
    var testLossSum: Float = 0
    var testBatchCount = 0
    var correctGuessCount = 0
    var totalGuessCount = 0
    for batch in testBatches {
        let (labels, images) = (batch.label, batch.data)
        let logits = model(images)
        testLossSum += softmaxCrossEntropy(logits: logits, labels: labels).scalarized()
        testBatchCount += 1

        let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
        correctGuessCount = correctGuessCount
            + Int(
                Tensor<Int32>(correctPredictions).sum().scalarized())
        totalGuessCount = totalGuessCount + batchSize
    }

    let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
    print(
        """
        [Epoch \(epoch)] \
        Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
        Loss: \(testLossSum / Float(testBatchCount))
        """
    )
}

Starting training...
[Epoch 1] Accuracy: 4989/10016 (0.49810302) Loss: 1.431187
[Epoch 2] Accuracy: 5582/10016 (0.5573083) Loss: 1.2572505
[Epoch 3] Accuracy: 6011/10016 (0.6001398) Loss: 1.1344422
[Epoch 4] Accuracy: 6396/10016 (0.6385783) Loss: 1.0295383
[Epoch 5] Accuracy: 6591/10016 (0.65804714) Loss: 0.9724738
[Epoch 6] Accuracy: 6870/10016 (0.68590254) Loss: 0.9059218
[Epoch 7] Accuracy: 6867/10016 (0.685603) Loss: 0.88318604
[Epoch 8] Accuracy: 7128/10016 (0.71166134) Loss: 0.8305917
[Epoch 9] Accuracy: 7201/10016 (0.7189497) Loss: 0.8110487
[Epoch 10] Accuracy: 7332/10016 (0.7320288) Loss: 0.77304375
[Epoch 11] Accuracy: 7288/10016 (0.7276358) Loss: 0.7856382
[Epoch 12] Accuracy: 7412/10016 (0.740016) Loss: 0.7428276
[Epoch 13] Accuracy: 7417/10016 (0.7405152) Loss: 0.75255996
[Epoch 14] Accuracy: 7416/10016 (0.74041533) Loss: 0.7470886
[Epoch 15] Accuracy: 7504/10016 (0.7492013) Loss: 0.7293965
[Epoch 16] Accuracy: 7426/10016 (0.7414137) Loss: 0.7562756
[Epoch 17] Accuracy: 76