<a href="https://colab.research.google.com/github/rahulbhalley/swift-for-tensorflow-examples/blob/master/GAN_S4TF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import dependencies

In [0]:
import TensorFlow
import Python
PythonLibrary.useVersion(3)

let time = Python.import("time")

In [2]:
let animals = ["🐶", "🐹", "🐻", "🐸"]
for animal in animals {
  print(animal)
}

🐶
🐹
🐻
🐸


## Data Downloading and Loading Helpers

In [0]:
func downloadCIFAR10IfNotPresent(to directory: String = ".") {
  let subprocess = Python.import("subprocess")
  let path = Python.import("os.path")
  let filepath = "\(directory)/cifar-10-batches-py"
  let isdir = Bool(path.isdir(filepath))!
  if !isdir {
    print("Downloading CIFAR data...")
    let command = "wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - -C \(directory)"
    subprocess.call(command, shell: true)
  }
}

struct Example: TensorGroup {
  var label: Tensor<Int32>
  var data: Tensor<Float>
}

// Each CIFAR data file is provided as a Python pickle of NumPy arrays
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
  downloadCIFAR10IfNotPresent(to: directory)
  let np = Python.import("numpy")
  let pickle = Python.import("pickle")
  let path = "\(directory)/cifar-10-batches-py/\(name)"
  let f = Python.open(path, "rb")
  let res = pickle.load(f, encoding: "bytes")

  let bytes = res[Python.bytes("data", encoding: "utf8")]
  let labels = res[Python.bytes("labels", encoding: "utf8")]

  let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
  let images = Tensor<UInt8>(numpy: bytes)!
  let imageCount = images.shape[0]

  // reshape and transpose from the provided N(CHW) to TF default NHWC
  let imageTensor = Tensor<Float>(images
      .reshaped(to: [imageCount, 3, 32, 32])
      .transposed(withPermutations: [0, 2, 3, 1]))

  let mean = Tensor<Float>([0.485, 0.456, 0.406])
  let std  = Tensor<Float>([0.229, 0.224, 0.225])
  let imagesNormalized = ((imageTensor / 255.0) - mean) / std

  return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}

func loadCIFARTrainingFiles() -> Example {
  let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
  return Example(
    label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
    data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
  )
}

func loadCIFARTestFile() -> Example {
  return loadCIFARFile(named: "test_batch")
}

func loadCIFAR10() -> (
  training: Dataset<Example>, test: Dataset<Example>) {
    let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
    let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
    return (training: trainingDataset, test: testDataset)
}

## Some function
- LeakyReLU

In [0]:
@differentiable
func leakyReLU(_ tensor: Tensor<Float>, negativeSlope: Float) -> Tensor<Float> {
  let zeros = Tensor<Float>(zeros: tensor.shape)
  let minimum = min(zeros, tensor)
  let maximum = max(zeros, tensor)
  let output = maximum + negativeSlope * minimum
  return output
}

## Some variables

In [0]:
let batchSize = 32
let imageSize = 64
let numberOfChannels = 3
let zLatent = 128

## Model

### LeNet

In [0]:
struct LeNet: Layer {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>
  
  var conv1 = Conv2D<Float>(filterShape: (5, 5, 3, 6), activation: relu)
  var pool1 = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
  var conv2 = Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
  var pool2 = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
  var flatten = Flatten<Float>()
  var dense1 = Dense<Float>(inputSize: 16 * 5 * 5, outputSize: 120, activation: relu)
  var dense2 = Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
  var dense3 = Dense<Float>(inputSize: 84, outputSize: 10, activation: identity)
  
  @differentiable
  func call(_ input: Input) -> Output {
    let convolved = input.sequenced(through: conv1, pool1, conv2, pool2)
    return convolved.sequenced(through: flatten, dense1, dense2, dense3)
  }
}

### Helper Layers

In [0]:
struct ConvBN: Layer {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>

  var conv: Conv2D<Float>
  var norm: BatchNorm<Float>

  init(
    filterShape: (Int, Int, Int, Int),
    strides: (Int, Int) = (1, 1),
    padding: Padding = .valid
  ) {
    self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
    self.norm = BatchNorm(featureCount: filterShape.3)
  }

  @differentiable
  func call(_ input: Input) -> Output {
    return input.sequenced(through: conv, norm)
  }
}

// DOES NOT WORK
// TransposedConv2D requires same number of input & output channels. FACK! 😒
struct TransposedConvBN: Layer {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>

  var conv: TransposedConv2D
  var norm: BatchNorm<Float>

  init(
    filterShape: (Int, Int, Int, Int),
    strides: (Int, Int) = (1, 1),
    padding: Padding = .valid
  ) {
    self.conv = TransposedConv2D(filterShape: filterShape, strides: strides, padding: padding)
    self.norm = BatchNorm(featureCount: filterShape.3)
  }

  @differentiable
  func call(_ input: Input) -> Output {
    return input.sequenced(through: conv, norm)
  }
}

struct UpSampleConvBN: Layer {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>
  
  var upSample: UpSampling2D<Float>
  var conv: Conv2D<Float>
  var norm: BatchNorm<Float>
  
  init(
    filterShape: (Int, Int, Int, Int),
    strides: (Int, Int),
    padding: Padding = .same
  ) {
    self.upSample = UpSampling2D(size: 4)
    self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
    self.norm = BatchNorm(featureCount: filterShape.3)
  }
  
  @differentiable
  func call(_ input: Input) -> Output {
    return input.sequenced(through: upSample, conv, norm)
  }
}

### Generative Adversarial Network
- Critic
- Generator

In [0]:
struct Generator: Layer {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>
  
  var dense = Dense<Float>(inputSize: zLatent, outputSize: 2 * 2 * 1024)
  var norm = BatchNorm<Float>(featureCount: 4096)
  var upSampleConvBN1 = UpSampleConvBN(
    filterShape: (4, 4, 512, 256),
    strides: (2, 2),
    padding: .same
  )
  var upSampleConvBN2 = UpSampleConvBN(
    filterShape: (4, 4, 256, 128),
    strides: (2, 2),
    padding: .same
  )
  var upSampleConvBN3 = UpSampleConvBN(
    filterShape: (4, 4, 128, 64),
    strides: (2, 2),
    padding: .same
  )
//   var upSampleConvBN4 = UpSampleConvBN(
//     filterShape: (4, 4, 128, 64),
//     strides: (2, 2),
//     padding: .same
//   )
  var upSample = UpSampling2D<Float>(size: 4)
  var conv = Conv2D<Float>(
    filterShape: (4, 4, 64, numberOfChannels), 
    strides: (2, 2), 
    padding: .same
  )
  
  @differentiable
  func call(_ input: Input) -> Output {
    var output: Output
    // Dense activation
    output = relu(norm(dense(input)))
    // Expand shape to 4D
    output = output.expandingShape(at: 0).expandingShape(at: 0)
    // Reshape and transpose from N(CHW) to N(HWC)
//     print(output.shape)
    output = output.reshaped(to: [batchSize, 1024, 2, 2]) // N(CHW)
    output = output.transposed(withPermutations: [0, 2, 3, 1]) // NHWC
    // Upsample and convolve through remaining layers
    // 🧻: (Distill: Deconvolution and Checkerboard Artifacts)
    output = relu(upSampleConvBN1(output))
    output = relu(upSampleConvBN2(output))
    output = relu(upSampleConvBN3(output))
//     output = relu(upSampleConvBN4(output))
    output = tanh(conv(upSample(output)))
    return output
  }
}

struct Critic: Layer {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>
  
  var convBN1 = Conv2D<Float>(
    filterShape: (4, 4, numberOfChannels, 64),
    strides: (2, 2),
    padding: .same
  )
  var convBN2 = ConvBN(
    filterShape: (4, 4, 64, 128),
    strides: (2, 2),
    padding: .same
  )
  var convBN3 = ConvBN(
    filterShape: (4, 4, 128, 256),
    strides: (2, 2),
    padding: .same
  )
  var convBN4 = ConvBN(
    filterShape: (4, 4, 256, 512),
    strides: (2, 2),
    padding: .same
  )
//   var convBN5 = ConvBN(
//     filterShape: (4, 4, 512, 1024),
//     strides: (2, 2),
//     padding: .same
//   )
  var flatten = Flatten<Float>()
  var dense = Dense<Float>(inputSize: 2 * 2 * 512, outputSize: 1)
  
  @differentiable
  func call(_ input: Input) -> Output {
    var output: Output
    output = leakyReLU(convBN1(input),  negativeSlope: 0.2)
    output = leakyReLU(convBN2(output), negativeSlope: 0.2)
    output = leakyReLU(convBN3(output), negativeSlope: 0.2)
    output = leakyReLU(convBN4(output), negativeSlope: 0.2)
//     output = leakyReLU(convBN5(output), negativeSlope: 0.2)
    output = flatten(output)
    output = dense(output)
    return output
  }
}

In [0]:
var criticModel = Critic()
var generatorModel = Generator()

In [10]:
// var input = Tensor<Float>(glorotUniform: [batchSize, imageSize, imageSize, numberOfChannels])
var input = Tensor<Float>(glorotUniform: [batchSize, zLatent])
print("input: \(input.shape)")

var outputGenerator = generatorModel(input)
print("generator: \(outputGenerator.shape)")

var outputCritic = criticModel(outputGenerator)
print("outputCritic: \(outputCritic.shape)")

input: [32, 128]
generator: [32, 32, 32, 3]
outputCritic: [32, 1]


In [11]:
print(Context.local.learningPhase)
Context.local.learningPhase = .training
print(Context.local.learningPhase)

inference
training


### ResNet

In [0]:
// struct Conv2DBatchNorm: Layer {
//     typealias Input = Tensor<Float>
//     typealias Output = Tensor<Float>

//     var conv: Conv2D<Float>
//     var norm: BatchNorm<Float>

//     init(
//         filterShape: (Int, Int, Int, Int),
//         strides: (Int, Int) = (1, 1)
//     ) {
//         self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: .same)
//         self.norm = BatchNorm(featureCount: filterShape.3)
//     }

//     @differentiable
//     func call(_ input: Input) -> Output {
//         return input.sequenced(through: conv, norm)
//     }
// }

// struct BasicBlock: Layer {
//     typealias Input = Tensor<Float>
//     typealias Output = Tensor<Float>

//     var blocks: [Conv2DBatchNorm]
//     var shortcut: Conv2DBatchNorm

//     init(
//         featureCounts: (Int, Int),
//         kernelSize: Int = 3,
//         strides: (Int, Int) = (2, 2),
//         blockCount: Int = 3
//     ) {
//         self.blocks = [Conv2DBatchNorm(
//             filterShape: (kernelSize, kernelSize, featureCounts.0, featureCounts.1),
//             strides: strides)]
//         for _ in 2..<blockCount {
//             self.blocks += [Conv2DBatchNorm(
//                 filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.1))]
//         }
//         self.shortcut = Conv2DBatchNorm(
//             filterShape: (1, 1, featureCounts.0, featureCounts.1),
//             strides: strides)
//     }

//     @differentiable
//     func call(_ input: Input) -> Output {
//         let blocksReduced = blocks.differentiableReduce(input) { last, layer in
//             relu(layer(last))
//         }
//         return relu(blocksReduced + shortcut(input))
//     }
// }

// struct ResNet: Layer {
//     typealias Input = Tensor<Float>
//     typealias Output = Tensor<Float>

//     var inputLayer = Conv2DBatchNorm(filterShape: (3, 3, 3, 16))

//     var basicBlock1: BasicBlock
//     var basicBlock2: BasicBlock
//     var basicBlock3: BasicBlock

//     init(blockCount: Int = 3) {
//         basicBlock1 = BasicBlock(featureCounts:(16, 16), strides: (1, 1), blockCount: blockCount)
//         basicBlock2 = BasicBlock(featureCounts:(16, 32), blockCount: blockCount)
//         basicBlock3 = BasicBlock(featureCounts:(32, 64), blockCount: blockCount)
//     }

//     var averagePool = AvgPool2D<Float>(poolSize: (8, 8), strides: (8, 8))
//     var flatten = Flatten<Float>()
//     var classifier = Dense<Float>(inputSize: 64, outputSize: 10, activation: softmax)

//     @differentiable
//     func call(_ input: Input) -> Output {
//         let tmp = relu(inputLayer(input))
//         let convolved = tmp.sequenced(through: basicBlock1, basicBlock2, basicBlock3)
//         return convolved.sequenced(through: averagePool, flatten, classifier)
//     }
// }

// extension ResNet {
//     enum Kind: Int {
//         case resNet20 = 3
//         case resNet32 = 5
//         case resNet44 = 7
//         case resNet56 = 9
//         case resNet110 = 18
//     }

//     init(kind: Kind) {
//         self.init(blockCount: kind.rawValue)
//     }
// }

## Training

In [19]:
// Load the dataset
let cifarDataset = loadCIFAR10()
let testBatches = cifarDataset.test.batched(batchSize)

// Inputs
var zInput = Tensor<Float>(glorotUniform: [batchSize, zLatent])

// Models
var generatorModel = Generator()
var criticModel = Critic()

// Optimizers
let generatorOptimizer = RMSProp(for: generatorModel, learningRate: 0.0001, decay: 1e-6)
let criticOptimizer = RMSProp(for: criticModel, learningRate: 0.0001, decay: 1e-6)

print("Starting training...")

// Setup the training context
Context.local.learningPhase = .training

for epoch in 1...100 {
  var startTime = time.time()

  var trainingLossSum: Float = 0
  var trainingBatchCount = 0
  
  // Shuffle the dataset
  let trainingShuffled = cifarDataset.training.shuffled(
      sampleCount: 50000, randomSeed: Int64(epoch))
  
  // Start the training now
  for batch in trainingShuffled.batched(batchSize) {
    let (labels, images) = (batch.label, batch.data)

    // Train `criticModel` one step
    let (loss, gradients) = valueWithGradient(at: criticModel) {
      criticModel -> Tensor<Float> in
      // Sample uniformly from normal distribution of [`batchSize`, `zLatent`] shaped `Tensor`
      zInput = Tensor<Float>(glorotUniform: [batchSize, zLatent])
//       print(zInput.shape)
      // Forward pass `zInput` through `generatorModel`
      // and don't track gradients of `generatorModel`.
      var generatedImage = generatorModel(zInput).withoutDerivative()
//       print(generatedImage.shape)
      // Critize the `generatedImage` though `criticModel`
      let criticProbability = criticModel(generatedImage)
//       print(criticProbability.shape)
      // Calculate the loss finally.
      return 1.0 - log(criticProbability)
      //return softmaxCrossEntropy(logits: logits, labels: labels)
    }
//     trainingLossSum += loss.scalarized()
    trainingBatchCount += 1
    criticOptimizer.update(&criticModel.allDifferentiableVariables, along: gradients)
    print("Epoch: \(epoch) Loss: \(loss.mean())")
    
    // Train `generatorModel` open step
//     let (loss, gradients) = valueWithGradient(at: generatorModel) { generatorModel -> Tensor<Float> in
//       let logits = generatorModel(images)

//       //return softmaxCrossEntropy(logits: logits, labels: labels)
//     }
//     trainingLossSum += loss.scalarized()
//     trainingBatchCount += 1
//     optimizer.update(&generatorModel.allDifferentiableVariables, along: gradients)
  }

//   var testLossSum: Float = 0
//   var testBatchCount = 0
//   var correctGuessCount = 0
//   var totalGuessCount = 0
//   for batch in testBatches {
//     let (labels, images) = (batch.label, batch.data)
//     let logits = model(images)
//     testLossSum += softmaxCrossEntropy(logits: logits, labels: labels).scalarized()
//     testBatchCount += 1

//     let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
//     correctGuessCount = correctGuessCount + Int(Tensor<Int32>(correctPredictions).sum().scalarized())
//     totalGuessCount = totalGuessCount + batchSize
//   }
//   var endTime = time.time()

//   let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
//   print("""
//         [Epoch \(epoch)] \
//         Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
//         Loss: \(testLossSum / Float(testBatchCount)) \
//         Took: \(endTime - startTime) s
//         """)
}

Starting training...
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1fffff)
Epoch: 1 Loss: nan(0x1

: ignored

In [14]:
var aTensor = Tensor<Float>(glorotUniform: [10, 1])
print(aTensor)

[[  0.58413595],
 [ -0.18717654],
 [   0.6253822],
 [  0.11087661],
 [  0.47472066],
 [  0.36341968],
 [  0.47421372],
 [-0.042357836],
 [  0.41362908],
 [ -0.10153097]]


In [15]:
print(aTensor.mean())

0.27153125


In [18]:
aTensor = aTensor.reshaped(to: [1, 10])
print(aTensor)
print(aTensor.mean().scalarized())

[[  0.58413595,  -0.18717654,    0.6253822,   0.11087661,   0.47472066,   0.36341968,
    0.47421372, -0.042357836,   0.41362908,  -0.10153097]]
0.27153125
