##### Copyright 2020 The TensorFlow Authors. [Licensed under the Apache License, Version 2.0](#scrollTo=ByZjmtFgB_Y5).

In [None]:
%install-location $cwd/swift-install
%install '.package(url: "https://github.com/tensorflow/swift-models", .branch("master"))' ModelSupport
// Emit the ANSI "erase display" escape sequence, presumably to clear the package-installation
// log from the cell output.
print("\u{001B}[2J")

In [None]:
// #@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

<table class="tfo-notebook-buttons" align="left">
 <td>
  <a target="_blank" href="https://colab.research.google.com/github/BradLarson/swift-models/blob/Physarum/Examples/Physarum/Physarum.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
 </td>
 <td>
  <a target="_blank" href="https://github.com/BradLarson/swift-models/blob/Physarum/Examples/Physarum/Physarum.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
 </td>
</table>

# Physarum

This is a Swift for TensorFlow implementation of the [Physarum transport model](https://www.sagejenson.com/physarum) as described by Sage Jenson, which is in turn based on the paper [“Characteristics of pattern formation and evolution in approximations of Physarum transport networks.”](http://eprints.uwe.ac.uk/15260/1/artl.2010.16.2.pdf)

## Device setup and model parameters

We'll start by importing the appropriate modules:

In [None]:
import Foundation
import TensorFlow
import ModelSupport

Next, we'll choose the device that tensor operations will run on (use the commented-out XLA line instead to run on that backend):

In [None]:
// Run tensor operations eagerly through TensorFlow. To use the XLA backend instead,
// replace this declaration with the commented-out one below.
let device: Device = .defaultTFEager
// let device: Device = .defaultXLA
// Uncomment to echo the selected device in the cell output:
// device

To aid us in displaying images within the notebook, we'll use Swift's Python interoperability to set up an image display function.

In [None]:
import PythonKit

%include "EnableIPythonDisplay.swift"
// Route matplotlib figures into the notebook's cell output.
IPythonDisplay.shell.enable_matplotlib("inline")
// Python module that provides `Image` and other rich display objects (used by `showImageFile`).
let display = Python.import("IPython.core.display")

/// Displays the image stored at `filename` inline in the notebook.
///
/// The whole file is read into memory as bytes and handed to IPython's `Image`
/// display object.
///
/// - Parameter filename: Path of the image file to show.
func showImageFile(_ filename: String) {
  let file = Python.open(filename, "rb")
  // Close the handle deterministically rather than relying on the Python object
  // happening to be released after the expression finishes.
  defer { file.close() }
  display.Image(file.read()).display()
}

The following cell defines all of the parameters used by the simulation:

In [None]:
// Side length, in cells, of the square trail map.
let gridSize: Int = 512
// Number of simulated particles (agents).
let particleCount: Int = 1024
// Sensor angle (radians) of the Physarum model.
let senseAngle: Float = 0.20 * .pi
// Distance (in cells) from a particle to its sensors.
let senseDistance: Float = 4.0
// Factor the trail map is multiplied by on every step (fraction of trail retained).
let evaporationRate: Float = 0.95
// Angle (radians) a particle rotates by when it turns.
let moveAngle: Float = 0.1 * .pi
// Distance (in cells) a particle advances per step.
let moveStep: Float = 2.0
// NOTE(review): not referenced by any cell visible here; confirm before removing.
let channelSize: Int = 1

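## The Physarum transport model

The simulation tracks a population of particles that move over a two-dimensional trail map. On every step, each particle samples the trail map at three sensors placed a fixed distance ahead of it: one straight ahead and one rotated to each side. It then turns toward the sensor with the strongest reading and moves forward. If the center sensor has the weakest reading, the particle instead makes an occasional random turn. Each particle then deposits trail at its new location. Finally, the whole trail map is blurred with a 3×3 mean filter and scaled down, so that old trails gradually fade.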

Before we write the update step, we'll set up a helper method that converts an element-wise condition into a mask tensor. The mask holds 1 where the condition is true and 0 elsewhere:

In [None]:
extension Tensor where Scalar: Numeric {
  /// Returns a tensor shaped like `self` that holds 1 wherever `condition`
  /// is true and 0 everywhere else.
  ///
  /// - Parameter condition: Element-wise predicate evaluated on `self`.
  /// - Returns: A 0/1 tensor with the same shape and scalar type as `self`.
  func mask(condition: (Tensor) -> Tensor<Bool>) -> Tensor {
    let isSelected = condition(self)
    let ones = Tensor(onesLike: self)
    let zeros = Tensor(zerosLike: self)
    return zeros.replacing(with: ones, where: isSelected)
  }
}

We'll also need a function that converts an angle (in radians) into a unit X, Y displacement vector:

In [None]:
/// Converts angles (in radians) into unit displacement vectors.
///
/// - Parameter angle: Tensor of angles, of any shape.
/// - Returns: A tensor with one extra trailing axis of size 2 holding `(cos, sin)`
///   of each angle.
func angleToVector(_ angle: Tensor<Float>) -> Tensor<Float> {
  let components = [cos(angle), sin(angle)]
  return Tensor(stacking: components, alongAxis: -1)
}

At present, the Swift for TensorFlow API lacks high-level N-dimensional gather and scatter operations. We'll therefore define them in terms of the lower-level `_Raw` APIs (`gatherNd` and `scatterNd`), along with custom derivatives. Once high-level versions are added to the library, this extension can be removed.

In [None]:
/// N-dimensional gather and scatter built on the raw TensorFlow ops, with custom
/// derivatives so they can be used inside differentiable code.
extension Tensor where Scalar: TensorFlowFloatingPoint {
  /// Gathers values of `self` at N-dimensional index tuples (wraps `_Raw.gatherNd`).
  ///
  /// - Parameter indices: Index tuples into `self`; the last axis holds the
  ///   coordinates of each gathered element or slice.
  /// - Returns: The gathered values.
  @inlinable
  @differentiable(wrt: self)
  public func dimensionGathering<Index: TensorFlowIndex>(
    atIndices indices: Tensor<Index>
  ) -> Tensor {
    return _Raw.gatherNd(params: self, indices: indices)
  }

  /// Reverse-mode derivative of `dimensionGathering(atIndices:)` with respect to `self`.
  ///
  /// The pullback scatters the incoming gradient `v` back to the gathered positions
  /// in a tensor shaped like `self`. Positions gathered more than once receive the
  /// sum of their gradients, because `scatterNd` accumulates duplicate indices.
  @inlinable
  @derivative(of: dimensionGathering)
  func _vjpDimensionGathering<Index: TensorFlowIndex>(
    atIndices indices: Tensor<Index>
  ) -> (value: Tensor, pullback: (Tensor) -> Tensor) {
    // Shape of `self`, converted to `Index` so it matches the scalar type of `indices`.
    let shapeTensor = Tensor<Index>(self.shapeTensor)
    let value = _Raw.gatherNd(params: self, indices: indices)
    return (
      value,
      { v in
        let dparams = _Raw.scatterNd(indices: indices, updates: v, shape: shapeTensor)
        return dparams
      }
    )
  }

  /// Scatters the elements of `self` into a new zero-filled tensor of the given
  /// `shape` at N-dimensional index tuples (wraps `_Raw.scatterNd`).
  ///
  /// - Parameters:
  ///   - indices: Destination index tuples, one per element or slice of `self`.
  ///   - shape: Shape of the output tensor.
  /// - Returns: A tensor of shape `shape`; values scattered to the same index are summed.
  @inlinable
  @differentiable(wrt: self)
  public func dimensionScattering<Index: TensorFlowIndex>(
    atIndices indices: Tensor<Index>, shape: Tensor<Index>
  ) -> Tensor {
    return _Raw.scatterNd(indices: indices, updates: self, shape: shape)
  }

  /// Reverse-mode derivative of `dimensionScattering(atIndices:shape:)` with respect to `self`.
  ///
  /// The pullback gathers the incoming gradient `v` at `indices`, so each scattered
  /// value receives the gradient of the output position it was written to.
  @inlinable
  @derivative(of: dimensionScattering)
  func _vjpDimensionScattering<Index: TensorFlowIndex>(
    atIndices indices: Tensor<Index>, shape: Tensor<Index>
  ) -> (value: Tensor, pullback: (Tensor) -> Tensor) {
    let value = _Raw.scatterNd(indices: indices, updates: self, shape: shape)
    return (
      value,
      { v in
        let dparams = _Raw.gatherNd(params: v, indices: indices)
        return dparams
      }
    )
  }
}

We'll configure the initial state and some commonly reused variables:

In [None]:
// Two gridSize x gridSize trail-map buffers; `step(phase:)` reads grid[phase] and
// writes its result into grid[1 - phase].
var grid = Tensor<Float>(zeros: [2, gridSize, gridSize], on: device)
// Particle (x, y) positions, drawn uniformly from [0, gridSize) on each axis.
var positions = Tensor<Float>(randomUniform: [particleCount, 2], on: device) * Float(gridSize)
// Particle headings in radians, drawn uniformly from [0, 2π).
var headings = Tensor<Float>(randomUniform: [particleCount], on: device) * 2.0 * Float.pi
// Grid extent as an Int32 tensor: the output shape for scattering deposits and the
// modulus used to map positions onto grid cells.
let gridShape = Tensor<Int32>(
  shape: [2], scalars: [Int32(gridSize), Int32(gridSize)], on: device)
// One unit of trail deposited by each particle on every step.
let scatterValues = Tensor<Float>(ones: [particleCount], on: device)

The whole model update is defined in a single function. The grid holds two buffers: `phase` selects the buffer that is read, and the updated trail map is written to the other buffer (`1 - phase`):

In [None]:
/// Advances the Physarum simulation by one step.
///
/// Reads the trail map from `grid[phase]` and updates the global `headings` and
/// `positions` of every particle. It then deposits trail at the particles' new cells,
/// diffuses the map with a reflect-padded 3x3 mean filter, and scales it by
/// `evaporationRate`. The result is written into `grid[1 - phase]`.
///
/// - Parameter phase: Index (0 or 1) of the grid buffer to read from; the updated
///   trail map is written to the other buffer.
func step(phase: Int) {
  var currentGrid = grid[phase]

  // Perceive: sample the trail map at three sensors placed `senseDistance` ahead of
  // each particle, at heading - senseAngle, heading, and heading + senseAngle.
  let senseDirection = headings.expandingShape(at: 1).broadcasted(to: [particleCount, 3])
    + Tensor<Float>([-senseAngle, 0.0, senseAngle], on: device)
  let sensingOffset = angleToVector(senseDirection) * senseDistance
  let sensingPosition = positions.expandingShape(at: 1) + sensingOffset
  // Positions are never wrapped, so they are mapped onto grid cells with abs + modulo.
  let sensingIndices = abs(Tensor<Int32>(sensingPosition))
    % (gridShape.expandingShape(at: 0).expandingShape(at: 0))
  let sensedValues = currentGrid.expandingShape(at: 2)
    .dimensionGathering(atIndices: sensingIndices).squeezingShape(at: 2)

  // Move: turn toward the sensor with the highest reading (-1, 0 or +1 times moveAngle).
  let lowValues = sensedValues.argmin(squeezingAxis: -1)
  let highValues = sensedValues.argmax(squeezingAxis: -1)
  // 1 for particles whose center sensor reads the lowest value, 0 otherwise.
  let middleMask = lowValues.mask { $0 .== 1 }
  // Those particles ignore the gradient and instead turn by +moveAngle with probability 0.1.
  let middleDistribution = Tensor<Float>(randomUniform: [particleCount], on: device)
  let randomTurn = middleDistribution.mask { $0 .< 0.1 } * Tensor<Float>(middleMask)
  let turn = Tensor<Float>(highValues - 1) * Tensor<Float>(1 - middleMask) + randomTurn
  headings += (turn * moveAngle)
  positions += angleToVector(headings) * moveStep

  // Deposit: add one unit of trail at each particle's cell (duplicates accumulate).
  let depositIndices = abs(Tensor<Int32>(positions)) % (gridShape.expandingShape(at: 0))
  let deposits = scatterValues.dimensionScattering(atIndices: depositIndices, shape: gridShape)
  currentGrid += deposits

  // Diffuse: reflect-pad by one cell so a "valid" 3x3 mean filter keeps the grid size,
  // then scale by evaporationRate so old trail fades.
  currentGrid = currentGrid.expandingShape(at: 0).expandingShape(at: 3)
  currentGrid = currentGrid.padded(forSizes: [(0, 0), (1, 1), (1, 1), (0, 0)], mode: .reflect)
  currentGrid = avgPool2D(currentGrid, filterSize: (1, 3, 3, 1), strides: (1, 1, 1, 1), padding: .valid)
  currentGrid = currentGrid * evaporationRate
  grid[1 - phase] = currentGrid.squeezingShape(at: 3).squeezingShape(at: 0)
}

## Running the model
With the stepping function defined, we can run the model for a set number of steps. The phase is alternated on each step. The grid tensor is captured at each step for later use in constructing an animation of the evolution of the grid over time.

In [None]:
let stepCount = 100

// One captured trail-map frame per simulation step, scaled up for image output.
var steps: [Tensor<Float>] = []
steps.reserveCapacity(stepCount)
for stepIndex in 0..<stepCount {
  let phase = stepIndex % 2
  step(phase: phase)
  LazyTensorBarrier()
  // `step(phase:)` writes the updated trail map into grid[1 - phase]. Capturing grid[0]
  // unconditionally would record a blank first frame and then every frame twice.
  steps.append(grid[1 - phase].expandingShape(at: 2) * 255.0)
}

Finally, an animated GIF of the grid evolution can be generated and displayed:

In [None]:
// Encode all captured frames as an animated image, presumably written to
// output/physarum.gif (the path displayed below).
// NOTE(review): confirm the unit of `delay` in ModelSupport's saveAnimatedImage.
try steps.saveAnimatedImage(directory: "output", name: "physarum", delay: 1)
showImageFile("output/physarum.gif")