<a href="https://colab.research.google.com/github/rahulbhalley/swift-for-tensorflow-examples/blob/master/NeuralNetworks/WassersteinGAN/WassersteinGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import dependencies

In [0]:
import TensorFlow
import Python
PythonLibrary.useVersion(3)

let plt = Python.import("matplotlib.pyplot")
let time = Python.import("time")

## Data Downloading and Loading Helpers

In [0]:
func downloadCIFAR10IfNotPresent(to directory: String = ".") {
  let subprocess = Python.import("subprocess")
  let path = Python.import("os.path")
  let filepath = "\(directory)/cifar-10-batches-py"
  let isdir = Bool(path.isdir(filepath))!
  if !isdir {
    print("Downloading CIFAR data...")
    let command = "wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - -C \(directory)"
    subprocess.call(command, shell: true)
  }
}

struct Example: TensorGroup {
  var label: Tensor<Int32>
  var data: Tensor<Float>
}

// Each CIFAR data file is provided as a Python pickle of NumPy arrays
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
  downloadCIFAR10IfNotPresent(to: directory)
  let np = Python.import("numpy")
  let pickle = Python.import("pickle")
  let path = "\(directory)/cifar-10-batches-py/\(name)"
  let f = Python.open(path, "rb")
  let res = pickle.load(f, encoding: "bytes")

  let bytes = res[Python.bytes("data", encoding: "utf8")]
  let labels = res[Python.bytes("labels", encoding: "utf8")]

  let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
  let images = Tensor<UInt8>(numpy: bytes)!
  let imageCount = images.shape[0]

  // reshape and transpose from the provided N(CHW) to TF default NHWC
  let imageTensor = Tensor<Float>(images
      .reshaped(to: [imageCount, 3, 32, 32])
      .transposed(withPermutations: [0, 2, 3, 1]))

  let mean = Tensor<Float>([0.485, 0.456, 0.406])
  let std  = Tensor<Float>([0.229, 0.224, 0.225])
  let imagesNormalized = ((imageTensor / 255.0) - mean) / std

  return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}

func loadCIFARTrainingFiles() -> Example {
  let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
  return Example(
    label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
    data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
  )
}

func loadCIFARTestFile() -> Example {
  return loadCIFARFile(named: "test_batch")
}

func loadCIFAR10() -> (
  training: Dataset<Example>, test: Dataset<Example>) {
    let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
    let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
    return (training: trainingDataset, test: testDataset)
}

## Extra functions
- LeakyReLU

In [0]:
@differentiable
func leakyReLU(_ tensor: Tensor<Float>, negativeSlope: Float) -> Tensor<Float> {
  let zeros = Tensor<Float>(zeros: tensor.shape)
  let minimum = min(zeros, tensor)
  let maximum = max(zeros, tensor)
  let output = maximum + negativeSlope * minimum
  return output
}
/*
func clamp<Parameter: BinaryFloatingPoint, Parameters: KeyPathIterable>(_ parameters: inout Parameters, 
                                                                        low: Parameter = -0.01, 
                                                                        high: Parameter = 0.01) {
  // Iterate over recursively all writable key paths to the
  // parameter type.
  for kp in parameters.recursivelyAllWritableKeyPaths(to: Parameter.self) {
    // Define some new Tensors
//     let lowTensor = Tensor<Float>(repeating: low, shape: parameters[keyPath: kp].shape)
//     let highTensor = Tensor<Float>(repeating: high, shape: parameters[keyPath: kp].shape)
    print(parameters[keyPath: kp])
    if parameters[keyPath: kp] < low {
      parameters[keyPath: kp] = low
    } else if parameters[keyPath: kp] > high {
      parameters[keyPath: kp] = high
    }
//     print(parameters[keyPath: kp])
    print("Yeah!")
  }
//   print("Yeah!")
}
*/

## Configurations

In [0]:
let batchSize = 64
let imageSize = 64
let numberOfChannels = 3
let zLatent = 128
let numberOfEpochs = 5

let lowClampValue: Float = -0.01
let highClampValue: Float = 0.01

## Model

### Helper Layers
- ConvBN
- TransposedConvBN
- UpSampleConvBN

In [0]:
struct ConvBN: Layer, KeyPathIterable {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>

  var conv: Conv2D<Float>
  var norm: BatchNorm<Float>

  init(
    filterShape: (Int, Int, Int, Int),
    strides: (Int, Int) = (1, 1),
    padding: Padding = .valid
  ) {
    self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
    self.norm = BatchNorm(featureCount: filterShape.3)
  }

  @differentiable
  func call(_ input: Input) -> Output {
    return input.sequenced(through: conv, norm)
  }
}

// DOES NOT WORK
// Currently, TransposedConv2D requires same number of input & output channels. FACK! 😒
struct TransposedConvBN: Layer, KeyPathIterable {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>

  var conv: TransposedConv2D
  var norm: BatchNorm<Float>

  init(
    filterShape: (Int, Int, Int, Int),
    strides: (Int, Int) = (1, 1),
    padding: Padding = .valid
  ) {
    self.conv = TransposedConv2D(filterShape: filterShape, strides: strides, padding: padding)
    self.norm = BatchNorm(featureCount: filterShape.3)
  }

  @differentiable
  func call(_ input: Input) -> Output {
    return input.sequenced(through: conv, norm)
  }
}

struct UpSampleConvBN: Layer, KeyPathIterable {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>
  
  var upSample: UpSampling2D<Float>
  var conv: Conv2D<Float>
  var norm: BatchNorm<Float>
  
  init(
    filterShape: (Int, Int, Int, Int),
    strides: (Int, Int),
    padding: Padding = .same
  ) {
    self.upSample = UpSampling2D(size: 4)
    self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
    self.norm = BatchNorm(featureCount: filterShape.3)
  }
  
  @differentiable
  func call(_ input: Input) -> Output {
    return input.sequenced(through: upSample, conv, norm)
  }
}

### Generative Adversarial Network
- Generator
- Critic

In [0]:
struct Generator: Layer, KeyPathIterable {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>
  
  var dense = Dense<Float>(inputSize: zLatent, outputSize: 2 * 2 * 1024)
  var norm = BatchNorm<Float>(featureCount: 4096)
  var upSampleConvBN1 = UpSampleConvBN(
    filterShape: (4, 4, 512, 256),
    strides: (2, 2),
    padding: .same
  )
  var upSampleConvBN2 = UpSampleConvBN(
    filterShape: (4, 4, 256, 128),
    strides: (2, 2),
    padding: .same
  )
  var upSampleConvBN3 = UpSampleConvBN(
    filterShape: (4, 4, 128, 64),
    strides: (2, 2),
    padding: .same
  )
//   var upSampleConvBN4 = UpSampleConvBN(
//     filterShape: (4, 4, 128, 64),
//     strides: (2, 2),
//     padding: .same
//   )
  var upSample = UpSampling2D<Float>(size: 4)
  var conv = Conv2D<Float>(
    filterShape: (4, 4, 64, numberOfChannels), 
    strides: (2, 2), 
    padding: .same
  )
  
  @differentiable
  func call(_ input: Input) -> Output {
    var output: Output
    // Dense activation
    output = relu(norm(dense(input)))
    // Expand shape to 4D
    output = output.expandingShape(at: 0).expandingShape(at: 0)
    // Reshape and transpose from N(CHW) to N(HWC)
//     print(output.shape)
    output = output.reshaped(to: [batchSize, 1024, 2, 2]) // N(CHW)
    output = output.transposed(withPermutations: [0, 2, 3, 1]) // NHWC
    // Upsample and convolve through remaining layers
    // 🧻: (Distill: Deconvolution and Checkerboard Artifacts)
    output = relu(upSampleConvBN1(output))
    output = relu(upSampleConvBN2(output))
    output = relu(upSampleConvBN3(output))
//     output = relu(upSampleConvBN4(output))
    output = tanh(conv(upSample(output)))
    return output
  }
}

struct Critic: Layer, KeyPathIterable {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>
  
  var convBN1 = Conv2D<Float>(
    filterShape: (4, 4, numberOfChannels, 64),
    strides: (2, 2),
    padding: .same
  )
  var convBN2 = ConvBN(
    filterShape: (4, 4, 64, 128),
    strides: (2, 2),
    padding: .same
  )
  var convBN3 = ConvBN(
    filterShape: (4, 4, 128, 256),
    strides: (2, 2),
    padding: .same
  )
  var convBN4 = ConvBN(
    filterShape: (4, 4, 256, 512),
    strides: (2, 2),
    padding: .same
  )
//   var convBN5 = ConvBN(
//     filterShape: (4, 4, 512, 1024),
//     strides: (2, 2),
//     padding: .same
//   )
  var flatten = Flatten<Float>()
  var dense = Dense<Float>(inputSize: 2 * 2 * 512, outputSize: 1)
  
  @differentiable
  func call(_ input: Input) -> Output {
    var output: Output
    output = leakyReLU(convBN1(input),  negativeSlope: 0.2)
    output = leakyReLU(convBN2(output), negativeSlope: 0.2)
    output = leakyReLU(convBN3(output), negativeSlope: 0.2)
    output = leakyReLU(convBN4(output), negativeSlope: 0.2)
//     output = leakyReLU(convBN5(output), negativeSlope: 0.2)
    output = flatten(output)
    output = dense(output)
    return output
  }
}

## Set up dataset, inputs, models, and optimizers

- Dataset

In [0]:
// Load the dataset
let cifarDataset = loadCIFAR10()
let testBatches = cifarDataset.test.batched(batchSize)

- Input
- Generative Adversarial Nets
- Optimizers

In [0]:
// Inputs
var zInput = Tensor<Float>(randomUniform: [batchSize, zLatent])
// var zInput: Tensor<Float>

// Iterations for Critic
let numberOfCriticIterations = 1

// Models
var generatorModel = Generator()
var criticModel = Critic()

// Optimizers
// let generatorOptimizer = Adam(for: generatorModel, learningRate: 0.0002, beta2: 0.5)
// let criticOptimizer = Adam(for: criticModel, learningRate: 0.0002, beta2: 0.5)
let generatorOptimizer = RMSProp(for: generatorModel, learningRate: 5e-5)
let criticOptimizer = RMSProp(for: criticModel, learningRate: 5e-5)

In [0]:
// var input = Tensor<Float>(glorotUniform: [batchSize, imageSize, imageSize, numberOfChannels])
// var input = Tensor<Float>(glorotUniform: [batchSize, zLatent])
// print("input: \(input.shape)")

// var outputGenerator = generatorModel(input)
// print("generator: \(outputGenerator.shape)")

// var outputCritic = criticModel(outputGenerator)
// print("outputCritic: \(outputCritic.shape)")

// criticModel.allDifferentiableVariables = criticModel.allDifferentiableVariables.withoutDerivative()

In [10]:
print(Context.local.learningPhase)
Context.local.learningPhase = .training
print(Context.local.learningPhase)

inference
training


## Training

In [11]:
// Store arrays of all the losses
var wassersteinDistances = [Float]()
var criticLosses = [Float]()
var generatorLosses = [Float]()

// Begin training.
print("Started. 🛫")

// Setup the training context
Context.local.learningPhase = .training

for epoch in 1...numberOfEpochs {
  var startTime = time.time()

  //var trainingLossSum: Float = 0
  var trainingBatchCount = 0

  // Shuffle the dataset
  let trainingShuffled = cifarDataset.training.shuffled(
      sampleCount: 50000,
      randomSeed: Int64(epoch)
  )

  // Start the training now
  for (iteration, batch) in trainingShuffled.batched(batchSize).enumerated() {
    //let (_, images) = (batch.label, batch.data)
    let images = batch.data

    if images.shape[0] != batchSize {
      print("Shapes not compatible: \(images.shape) != [\(batchSize), 1]. Skipping")
      continue
    }

    // Track Wasserstein distance
    var wassersteinDistance: Float = 0

    /*
     Train `criticModel` one step
     */
    var averageCriticLoss: Float = 0

    for _ in 1...numberOfCriticIterations {
      let (criticLoss, criticGradients) = valueWithGradient(at: criticModel) {
        criticModel -> Tensor<Float> in

        // Sample uniformly from normal distribution of [`batchSize`, `zLatent`] shaped `Tensor`
        zInput = Tensor<Float>(randomUniform: [batchSize, zLatent])
        //print(zInput.shape)

        // Forward pass `zInput` through `generatorModel`
        // and don't track gradients of `generatorModel`.
        let Gz = generatorModel(zInput).withoutDerivative()
        //print(generatedImage.shape)

        // Critize the `generatedImage` through `criticModel`
        let CGz = criticModel(Gz)
        //print(criticProbability.shape)

        // Real image critic probability
        let Cx = criticModel(images)

        // Calculate the loss finally.
        //let loss = -(log(Cx) + log(1.0 - CGz)).mean()
        //var loss = Cx.mean()
        wassersteinDistance = (Cx - CGz).mean().scalarized()
        let loss = (CGz - Cx).mean()
        return loss
        //return softmaxCrossEntropy(logits: logits, labels: labels)
      }
      averageCriticLoss += criticLoss.scalarized()
      //trainingLossSum += loss.scalarized()
      trainingBatchCount += 1
      criticOptimizer.update(&criticModel.allDifferentiableVariables, along: criticGradients)
      
      // Clamp the parameters
      for kp in criticModel.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
        // Define high and low values
        let lowTensor = Tensor<Float>(repeating: lowClampValue, shape: criticModel[keyPath: kp].shape)
        let highTensor = Tensor<Float>(repeating: highClampValue, shape: criticModel[keyPath: kp].shape)
        // Adjust the parameters by clamping in range [lowTensor, highTensor]
        if criticModel[keyPath: kp] < lowTensor {
          criticModel[keyPath: kp] = lowTensor
        } else if criticModel[keyPath: kp] > highTensor {
          criticModel[keyPath: kp] = highTensor
        }
        // Let's check if it's done correctly.
        //print(min(lowTensor, criticModel[keyPath: kp]))
        //print(max(zerosTensor, criticModel[keyPath: kp]))
      }

    }
    // Average the criticLoss
    averageCriticLoss /= Float(numberOfCriticIterations)
    // Get all the criticLosses and Wasserstein distances
    criticLosses.append(averageCriticLoss)
    wassersteinDistances.append(wassersteinDistance)

    /*
     Train `generatorModel` one step
     */
    let (generatorLoss, generatorGradients) = valueWithGradient(at: generatorModel) {
      generatorModel -> Tensor<Float> in

      // Sample uniformly from normal distribution of [`batchSize`, `zLatent`] shaped `Tensor`
      zInput = Tensor<Float>(randomUniform: [batchSize, zLatent])
      //print(zInput.shape)

      // Forward pass `zInput` through `generatorModel`.
      let Gz = generatorModel(zInput) //.withoutDerivative()
      //print(generatedImage.shape)

      // Critize the `generatedImage` through `criticModel` withoutDerivative()
      let CGz = criticModel(Gz)//.withoutDerivative()
      //print(criticProbability.shape)

      // Calculate the loss finally.
      //let loss = log(1.0 - CGz).mean()
      let loss = -CGz.mean()
      return loss
    }
    generatorOptimizer.update(&generatorModel.allDifferentiableVariables, along: generatorGradients)
    generatorLosses.append(generatorLoss.scalarized())

    // Print statistics
    if iteration % 10 == 0{
      print("""
        Epoch: [\(epoch)/\(numberOfEpochs)] \
        iteration: \(iteration) \
        wassersteinDistance: \(wassersteinDistance) \
        criticLoss: \(averageCriticLoss) \
        generatorLoss: \(generatorLoss.scalarized())
        """)
    }
    
  }
  // Save the losses plot
  plt.plot(Array(1...wassersteinDistances.count),
           wassersteinDistances,
           "r-", label: "Wasserstein Distance")
  plt.plot(Array(1...criticLosses.count),
           criticLosses,
           "g-", label: "Critic Loss")
  plt.plot(Array(1...generatorLosses.count),
           generatorLosses,
           "b-", label: "Generator Loss")

  plt.xlabel("Iterations")
  plt.legend(loc: 9)
  plt.title("Wasserstein GAN in S4TF")

  plt.savefig("LossGraph.pdf")
  plt.close()
  
  var endTime = time.time()
  
  print("Took: \(endTime - startTime) s")
}

print("Finished. 🛬")

Started. 🛫
Epoch: [1/5] iteration: 0 wassersteinDistance: 0.012591884 criticLoss: -0.012591884 generatorLoss: 0.00045472168
Epoch: [1/5] iteration: 10 wassersteinDistance: 4.6538014e-05 criticLoss: -4.6538014e-05 generatorLoss: 0.0005876661
Epoch: [1/5] iteration: 20 wassersteinDistance: -1.8906581e-05 criticLoss: 1.8906581e-05 generatorLoss: 0.00088533247
Epoch: [1/5] iteration: 30 wassersteinDistance: -8.566382e-05 criticLoss: 8.566382e-05 generatorLoss: 0.0008633041
Epoch: [1/5] iteration: 40 wassersteinDistance: 6.801309e-05 criticLoss: -6.801309e-05 generatorLoss: 0.00093450915
Epoch: [1/5] iteration: 50 wassersteinDistance: -5.296575e-06 criticLoss: 5.296575e-06 generatorLoss: 0.00087156787
Epoch: [1/5] iteration: 60 wassersteinDistance: -6.156739e-05 criticLoss: 6.156739e-05 generatorLoss: 0.00072640553
Epoch: [1/5] iteration: 70 wassersteinDistance: -1.8690374e-05 criticLoss: 1.8690374e-05 generatorLoss: 0.00094254105
Epoch: [1/5] iteration: 80 wassersteinDistance: 1.3872006e-0