<a href="https://colab.research.google.com/gist/marcrasi/32bd67b4497a2151242cb646819d6404/differentiable-physics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Imports

%install '.package(url: "https://github.com/marcrasi/swift-vec2", .branch("master"))' Vec2

// Clear installation messages from output area.
// "\u{001B}[2J" is the ANSI escape sequence that erases the visible output.
print("\u{001B}[2J")




In [0]:
#@title Helpers

let iterationCount = 100
let maxLr = Float(1e-1)
let minLr = Float(1e-2)

/// Cosine-annealed learning rate for training iteration `i`.
///
/// Equals `maxLr` at `i == 0` and decays along half a cosine period to `minLr`
/// at `i == iterationCount`.
func learningRate(_ i: Int = 0) -> Float {
    let angle = Float.pi * Float(i) / Float(iterationCount)
    // Maps cos(angle) from [-1, 1] onto a weight in [0, 1].
    let cosineWeight = (cos(angle) + 1) / 2
    return minLr + (maxLr - minLr) * cosineWeight
}

In [0]:
#@title State Datastructures

import Vec2

/// Position and velocity of a single ball. Every ball has radius `ballRadius`.
struct BallState: AdditiveArithmetic, Differentiable {
  /// Location of the ball's center, in world coordinates.
  var position: Vec2
  /// Change in position per unit of simulated time.
  var velocity: Vec2

  /// Radius shared by all balls.
  static let ballRadius = Float(1)

  @differentiable
  init(position: Vec2, velocity: Vec2) {
    self.velocity = velocity
    self.position = position
  }

  /// Returns this ball with its position replaced by `newPosition`.
  @differentiable
  func updating(position newPosition: Vec2) -> BallState {
    return BallState(position: newPosition, velocity: self.velocity)
  }

  /// Returns this ball with its velocity replaced by `newVelocity`.
  @differentiable
  func updating(velocity newVelocity: Vec2) -> BallState {
    return BallState(position: self.position, velocity: newVelocity)
  }

  /// Returns this ball translated by `delta`.
  @differentiable
  func moved(_ delta: Vec2) -> BallState {
    let shifted = position + delta
    return updating(position: shifted)
  }

  /// Returns this ball with `delta` added to its velocity.
  @differentiable
  func impulsed(_ delta: Vec2) -> BallState {
    let boosted = velocity + delta
    return updating(velocity: boosted)
  }
}

/// Outcome of a ball–ball collision, as returned by `BallState.collide`.
struct BallTup: Differentiable {
  /// First ball after the collision.
  var ball1: BallState
  /// Second ball after the collision.
  var ball2: BallState
  /// Squared magnitude of the impulse exchanged by the balls (0 when no impulse was applied).
  var collisionEnergy: Float

  @differentiable
  init(_ first: BallState, _ second: BallState, _ energy: Float) {
    collisionEnergy = energy
    ball2 = second
    ball1 = first
  }
}

/// A straight wall segment that balls bounce off.
struct Wall {
  /// Endpoints of the segment, in world coordinates.
  var p1, p2: Vec2
}

/// Settings that control a single simulation step.
struct SimulationParameters {
  /// Length of one time step.
  var dt = Float(0.02)
  /// Rate of random velocity perturbations. In each step a ball is perturbed with probability
  /// `1 - exp(-lambda * dt)`; the default of 0 disables perturbations entirely.
  var lambda = Float(0.0)
}

extension BallState {
  /// Returns this ball advanced by one time step of length `params.dt`.
  ///
  /// A friction vector of length `frictionAcceleration * params.dt`, built along
  /// `velocity.direction`, is subtracted from the velocity. If that vector is longer than the
  /// velocity, the ball stops instead of reversing. With probability
  /// `1 - exp(-params.lambda * params.dt)`, the resulting velocity is then perturbed: a random
  /// vector, with each component uniform in [-0.5, 0.5] and scaled by the current speed, is
  /// added to it. Finally, the position is advanced using the updated velocity.
  @differentiable
  func stepped(_ params: SimulationParameters) -> BallState {
    let frictionAcceleration = Float(1)
    let friction = Vec2(magnitude: frictionAcceleration * params.dt, direction: velocity.direction)
    var newVelocity = friction.magnitude > velocity.magnitude ? Vec2(0, 0) : velocity - friction

    // `1 - exp(-lambda * dt)` is the chance of at least one event within `dt` for events
    // arriving at rate `lambda`. With `lambda == 0`, exp(0) == 1 and this branch never runs.
    if Float.random(in: 0..<1) > exp(-params.lambda * params.dt) {
        newVelocity = newVelocity + newVelocity.magnitude * Vec2(Float.random(in: (-0.5...0.5)), Float.random(in: (-0.5...0.5)))
    }

    return BallState(
      position: position + params.dt * newVelocity,
      velocity: newVelocity)
  }

  /// Whether this ball overlaps or touches `other` (center distance at most two radii).
  func touches(_ other: BallState) -> Bool {
    return (position - other.position).magnitude <= 2 * BallState.ballRadius
  }

  /// Parameter `t` of the orthogonal projection of this ball's center onto the line through
  /// `wall.p1` and `wall.p2`; `t == 0` corresponds to `p1` and `t == 1` to `p2`.
  ///
  /// This is the expanded form of `(position - p1)·(p2 - p1) / |p2 - p1|²`, so it divides by
  /// zero for a zero-length wall.
  private func projected(to wall: Wall) -> Float {
    (wall.p1.magnitudeSquared + position.dot(wall.p2 - wall.p1) - wall.p1.dot(wall.p2)) / (wall.p2 - wall.p1).magnitudeSquared
  }

  /// Whether this ball's center projects onto the wall segment and lies within one radius of it.
  func touches(_ wall: Wall) -> Bool {
    let t = projected(to: wall)
    if t < 0 || t > 1 { return false }
    let projection = (1 - t) * wall.p1 + t * wall.p2
    return (position - projection).magnitudeSquared <= BallState.ballRadius * BallState.ballRadius
  }

  /// Reflects the component of the velocity that is normal to `wall`, keeping the tangential one.
  ///
  /// If the ball is already moving away from the wall (its velocity has a positive component
  /// along the wall-to-center displacement), it is returned unchanged. This stops a ball that
  /// still overlaps the wall on the next step from being bounced back into it. The position is
  /// never modified.
  @differentiable
  func bounced(_ wall: Wall) -> BallState {
    let tangent = wall.p2 - wall.p1
    let unitTangent = tangent / tangent.magnitude
    let unitNormal = Vec2(-unitTangent.y, unitTangent.x)
    let t = projected(to: wall)
    let projection = (1 - t) * wall.p1 + t * wall.p2
    let displacement = position - projection
    if velocity.dot(displacement) > 0 { return self }
    let newVelocity = velocity.dot(unitTangent) * unitTangent - velocity.dot(unitNormal) * unitNormal
    return BallState(position: position, velocity: newVelocity)
  }

  /// Resolves a collision between balls `a` and `b`.
  ///
  /// Returns both balls after the collision, together with the squared magnitude of the
  /// exchanged impulse.
  @differentiable
  static func collide(_ a: BallState, _ b: BallState) -> BallTup {
    updateCollisionVelocities(a, b)
  }

  /// Perfectly elastic collision between two equal-mass balls.
  ///
  /// The impulse acts along the line of centers `p` and swaps the normal components of the two
  /// velocities, which conserves kinetic energy. The third field of the result is the impulse's
  /// squared magnitude.
  private static func updateCollisionVelocities(_ ball1: BallState, _ ball2: BallState) -> BallTup {
    // Perfectly elastic collision. This is the impulse along the normal that preserves kinetic energy.
    let p = ball2.position - ball1.position
    let v = ball2.velocity - ball1.velocity
    let vdotp = v.dot(p)
    // No impulse if the balls are separating (vdotp > 0) or have no approach along p
    // (vdotp == 0). For p != 0 and vdotp == 0, the formula below would give a zero impulse
    // anyway. For coinciding centers (p == 0), it would compute 0 / 0 = NaN.
    if vdotp >= 0 { return BallTup(ball1, ball2, 0) }
    let impulse = (-vdotp / p.magnitudeSquared) * p
    return BallTup(ball1.impulsed(-1 * impulse), ball2.impulsed(impulse), impulse.magnitudeSquared)
  }
}

In [0]:
#@title WorldState
/// Complete simulation state: the balls, per-target progress, elapsed time, and the static scene.
struct WorldState: Differentiable {
    /// Dynamic state of every ball.
    var balls: [BallState]
    /// For each entry of `targets`, the running minimum over time of its distance to the
    /// nearest ball (maintained by `stepped`).
    var minDistanceToTarget: [Float]
    /// Elapsed simulated time.
    var t: Float

    /// Target points; excluded from differentiation.
    @noDerivative var targets: [Vec2]
    /// Walls balls can bounce off; excluded from differentiation.
    @noDerivative var walls: [Wall]

    @differentiable
    init(
      balls: [BallState],
      minDistanceToTarget: [Float],
      t: Float,
      targets: [Vec2],
      walls: [Wall]
    ) {
        self.targets = targets
        self.walls = walls
        self.t = t
        self.minDistanceToTarget = minDistanceToTarget
        self.balls = balls
    }
}

In [0]:
#@title Simulation Logic

extension WorldState {
  /// Creates a world at time zero in which no target has been approached yet: every
  /// `minDistanceToTarget` entry starts at infinity.
  @differentiable
  init(balls: [BallState], targets: [Vec2], walls: [Wall]) {
    // The number of targets is discrete, so keep it out of the derivative computation.
    let targetCount = withoutDerivative(at: targets.count)
    self.init(
      balls: balls,
      minDistanceToTarget: [Float](repeating: .infinity, count: targetCount),
      t: 0,
      targets: targets,
      walls: walls
    )
  }

  /// The fixed demo scene:
  /// - the first ball starts at (-20, 0) with `ball1InitialVelocity`;
  /// - a second ball rests at (-10, 0);
  /// - there are three targets and one vertical wall at x = 7.
  @differentiable
  init(ball1InitialVelocity: Vec2) {
    let launchedBall = BallState(position: Vec2(-20, 0), velocity: ball1InitialVelocity)
    let restingBall = BallState(position: Vec2(-10, 0), velocity: Vec2(0, 0))
    self.init(
      balls: [launchedBall, restingBall],
      targets: [Vec2(0, 5), Vec2(-5, 15), Vec2(-2, -20)],
      walls: [Wall(p1: Vec2(7, -10), p2: Vec2(7, 20))]
    )
  }
}

extension WorldState {
  /// Whether the simulation is finished.
  ///
  /// True once more than 7 units of simulated time have elapsed (a hard cap on the run
  /// length), or when no ball has a positive speed.
  var still: Bool {
    let timedOut = t > 7
    return timedOut || !balls.contains(where: { $0.velocity.magnitude > 0 })
  }
}

extension WorldState {
  /// Advances the whole world by one time step of `params.dt`.
  ///
  /// In order, this step:
  /// 1. integrates every ball;
  /// 2. bounces each ball off the first wall (in array order) that it touches;
  /// 3. resolves a collision between balls 0 and 1;
  /// 4. updates the per-target minimum distances.
  ///
  /// Indexes `balls[0]` and `balls[1]` directly, so it requires at least two balls. Only those
  /// two take part in ball–ball collisions and target tracking.
  @differentiable
  func stepped(_ params: SimulationParameters = SimulationParameters()) -> WorldState {
    // Integrate the ball velocity.
    var updatedBalls = balls.differentiableMap { $0.stepped(params) }

    // Collide the balls with the walls.
    // NOTE(review): `walls` is captured by value via the capture list, presumably so the
    // differentiable closure does not capture `self` — confirm.
    updatedBalls = updatedBalls.differentiableMap { [walls = walls] (ball: BallState) -> BallState in
      for i in withoutDerivative(at: walls.indices) {
        let wall = walls[i]
        if ball.touches(wall) {
          return ball.bounced(wall)
        }
      }
      return ball
    }
    
    // Collide the balls with each other.
    if updatedBalls[0].touches(updatedBalls[1]) {
      let collidedBalls = BallState.collide(updatedBalls[0], updatedBalls[1])
      updatedBalls = [collidedBalls.ball1, collidedBalls.ball2]
    }

    // Update min target distance.
    var newMinTargetDistance: [Float] = []
    for i in withoutDerivative(at: targets.indices) {
        let distTo1 = (updatedBalls[0].position - targets[i]).magnitude
        let distTo2 = (updatedBalls[1].position - targets[i]).magnitude
        var curTargetDistance = distTo1 < distTo2 ? distTo1 : distTo2
        // Floor the distance at contact distance (2 radii): approaching a target closer than
        // touching it does not lower its value any further.
        if curTargetDistance < 2 * BallState.ballRadius { curTargetDistance = 2 * BallState.ballRadius }
        // NOTE(review): the array is built by concatenation rather than `append`, presumably
        // because mutating appends were not differentiable in the targeted toolchain — confirm.
        if curTargetDistance < minDistanceToTarget[i] {
            newMinTargetDistance = newMinTargetDistance + [curTargetDistance]
        } else {
            newMinTargetDistance = newMinTargetDistance + [minDistanceToTarget[i]]
        }
    }

    return WorldState(
        balls: updatedBalls,
        // If the incoming derivative for this array is empty, replace it with zeros of the
        // array's length so later derivative arithmetic sees a correctly sized tangent.
        // NOTE(review): presumably an empty array is how a zero tangent arrives here — confirm.
        minDistanceToTarget: newMinTargetDistance.withDerivative { [count = withoutDerivative(at: newMinTargetDistance.count)] (d: inout Array<Float>.DifferentiableView) -> () in
            if d.base.count == 0 {
                d = Array.DifferentiableView(Array(repeating: 0, count: count))
            }
        },
        t: t + params.dt,
        targets: targets,
        walls: walls
    )
  }

  /// Steps the world with `params` until `still` is true, then returns the final state.
  ///
  /// `f` is called on every visited state: the starting state, each intermediate state, and
  /// the final state. It is called exactly once per state.
  @differentiable
  func steppedUntilStill(_ params: SimulationParameters, _ f: (WorldState) -> () = { _ in }) -> WorldState {
    var state = self
    while !state.still {
      f(state)
      state = state.stepped(params)
    }
    f(state)
    return state
  }
}

/// An arrow that `svg` draws at a ball's starting position, in SVG (screen) units.
struct DrawingArrow {
    /// Start of the arrow relative to the ball's on-screen center.
    var offset: Vec2
    /// SVG stroke color, e.g. "blue".
    var color: String
    /// Vector from the arrow's start point to its tip.
    var direction: Vec2
}

/// Renders simulation states as a self-looping animated SVG.
///
/// Several elements are taken from `states.first`:
/// - targets, drawn as red circles that turn green once reached;
/// - walls;
/// - the balls' initial positions.
///
/// After an initial pause of `delay` seconds, ball centers are animated between each
/// consecutive pair of states, with one `params.dt`-long segment per pair. `vectors[i]` lists
/// arrows drawn at ball `i`'s initial position. They fade in when each loop starts and fade out
/// after `delay` seconds.
///
/// - Returns: The `<svg>` markup. For empty `states`, only the frame and loop timer are emitted.
func svg(states: [WorldState], params: SimulationParameters = SimulationParameters(), vectors: [[DrawingArrow]] = [], delay: Float = 0) -> String {
  // World-to-screen mapping: world coordinates are scaled by `scale` and the world origin is
  // drawn at `origin`.
  let scale = Float(10)
  let origin = Vec2(250, 250)
  let size = Vec2(500, 500)
  // Minimum-distance value at which a target counts as reached. `WorldState.stepped` never
  // lets a target distance drop below two ball radii, i.e. the distance at which a ball
  // touches the (ball-sized) target circle.
  let reachedDistance = 2 * BallState.ballRadius

  func transformed(_ position: Vec2) -> Vec2 {
    scale * position + origin
  }

  var r = ""
  r += """
    <svg width="\(size.x)" height="\(size.y)">\n
  """

  // Loop timer. The rect has no size, so it is never rendered. Its animation restarts itself
  // when it ends (`begin="0;looper.end"`), and every other animation below is scheduled
  // relative to `looper.begin`, which makes the whole drawing repeat. "hidden" is used because
  // it is a valid `visibility` value.
  let totalDuration = (states.last?.t ?? 0) + delay
  r += """
    <rect>
        <animate
            id="looper"
            begin="0;looper.end"
            attributeName="visibility"
            from="hidden"
            to="hidden"
            dur="\(totalDuration)s" />
    </rect>
  """

  func circle(id: String, cx: String, cy: String, fill: String) -> String {
    """
      <circle
        id="\(id)"
        r="\(scale * BallState.ballRadius)"
        cx="\(cx)"
        cy="\(cy)"
        fill="\(fill)" />\n
    """
  }

  func line(p1: Vec2, p2: Vec2) -> String {
    """
      <line
        x1="\(p1.x)"
        y1="\(p1.y)"
        x2="\(p2.x)"
        y2="\(p2.y)"
        style="stroke:#000;stroke-width:2" />\n
    """
  }

  func arrow(id: String, _ arrow: DrawingArrow, _ base: Vec2) -> String {
      let p1 = base + arrow.offset
      let p2 = p1 + arrow.direction
      return """
        <line
            id="\(id)"
            x1="\(p1.x)"
            y1="\(p1.y)"
            x2="\(p2.x)"
            y2="\(p2.y)"
            style="stroke:\(arrow.color);stroke-width:2" />\n
      """
  }

  for (index, finalPosition) in (states.first?.targets ?? []).enumerated() {
    let position = transformed(finalPosition)
    r += circle(id: "target\(index)", cx: "\(position.x)", cy: "\(position.y)", fill: "red")
  }

  for (index, ball) in (states.first?.balls ?? []).enumerated() {
    let position = transformed(ball.position)
    r += circle(id: "ball\(index)", cx: "\(position.x)", cy: "\(position.y)", fill: "orange")

    if vectors.count > index {
        for (index2, vector) in vectors[index].enumerated() {
            let id = "arrow\(index)_\(index2)"
            r += arrow(id: id, vector, position)
            // Fade the arrow in at the start of each loop and out once `delay` has elapsed.
            r += """
                
                <animate
                    xlink:href="#\(id)"
                    attributeName="opacity"
                    from="0"
                    to="1"
                    dur="\(params.dt)s"
                    begin="looper.begin"
                    fill="freeze" />
                <animate
                    xlink:href="#\(id)"
                    attributeName="opacity"
                    from="1"
                    to="0"
                    dur="\(params.dt)s"
                    begin="looper.begin+\(delay)s"
                    fill="freeze" />
            """
        }
    }
  }

  for wall in (states.first?.walls ?? []) {
    r += line(p1: transformed(wall.p1), p2: transformed(wall.p2))
  }

  for (timeIndex, (state, nextState)) in zip(states, states.dropFirst(1)).enumerated() {
    let t = Float(timeIndex) * params.dt + delay
    for (ballIndex, (ballState, nextBallState)) in zip(state.balls, nextState.balls).enumerated() {
      func animate(attributeName: String, from: String, to: String) -> String {
        """
          <animate
            xlink:href="#ball\(ballIndex)"
            attributeName="\(attributeName)"
            from="\(from)"
            to="\(to)"
            dur="\(params.dt)s"
            begin="looper.begin+\(t)s" />\n
        """
      }
      let position = transformed(ballState.position)
      let nextPosition = transformed(nextBallState.position)
      r += animate(attributeName: "cx", from: String(position.x), to: String(nextPosition.x))
      r += animate(attributeName: "cy", from: String(position.y), to: String(nextPosition.y))
    }

    // When a target is first reached in this step: reset it to red at the start of every loop,
    // then turn it green at the moment it is reached.
    for (targetIndex, (targetDistance, nextTargetDistance)) in zip(state.minDistanceToTarget, nextState.minDistanceToTarget).enumerated() {
        if targetDistance > reachedDistance && nextTargetDistance <= reachedDistance {
            r += """
                <animate
                    xlink:href="#target\(targetIndex)"
                    attributeName="fill"
                    from="green"
                    to="red"
                    dur="\(params.dt)s"
                    begin="looper.begin"
                    fill="freeze" />
                """
            r += """
                <animate
                    xlink:href="#target\(targetIndex)"
                    attributeName="fill"
                    from="red"
                    to="green"
                    dur="\(params.dt)s"
                    begin="looper.begin+\(t)s"
                    fill="freeze" />
                """
        }
    }
  }

  r += "</svg>\n"
  return r
}

import Python
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")
let display = Python.import("IPython.core.display")

/// Renders `states` via `svg(states:params:vectors:delay:)` and shows the markup as HTML in
/// the notebook output.
func drawSVG(states: [WorldState], params: SimulationParameters = SimulationParameters(), vectors: [[DrawingArrow]] = [], delay: Float = 0) {
    let markup = svg(states: states, params: params, vectors: vectors, delay: delay)
    let html = display.HTML(markup)
    display[dynamicMember: "display"](html)
}

/// Runs the demo scene from `ball1InitialVelocity` until it is still, then shows the animation.
///
/// Each ball's initial velocity is drawn as a blue arrow (scaled by 2). The arrows are visible
/// during the one-second pause before the motion starts.
func drawSimulation(ball1InitialVelocity: Vec2, params: SimulationParameters = SimulationParameters()) {
    let initialState = WorldState(ball1InitialVelocity: ball1InitialVelocity)
    let vectors = initialState.balls.map { ball in
        [DrawingArrow(offset: Vec2(0, 0), color: "blue", direction: 2 * ball.velocity)]
    }
    var allStates: [WorldState] = []
    // The callback records every visited state, including the final one, so the returned
    // final state is not needed here; discard it explicitly.
    _ = initialState.steppedUntilStill(params) { allStates.append($0) }
    drawSVG(states: allStates, params: params, vectors: vectors, delay: 1)
}

/// Shows the demo scene's initial state with velocity and gradient arrows.
///
/// Each ball gets a blue arrow for its velocity (scaled by 2). The first ball also gets a green
/// arrow, starting at the tip of the blue one, showing `-10 * grad`: the descent direction,
/// enlarged for visibility. The one-hour delay keeps the arrows on screen.
func drawGradients(ball1InitialVelocity: Vec2, params: SimulationParameters = SimulationParameters(), grad: Vec2) {
    let initialState = WorldState(ball1InitialVelocity: ball1InitialVelocity)
    let descentArrows = [-10 * grad]
    var vectors: [[DrawingArrow]] = []
    for (ballIndex, ball) in initialState.balls.enumerated() {
        let velocityArrow = 2 * ball.velocity
        var arrows = [DrawingArrow(offset: Vec2(0, 0), color: "blue", direction: velocityArrow)]
        if ballIndex < descentArrows.count {
            arrows.append(DrawingArrow(offset: velocityArrow, color: "green", direction: descentArrows[ballIndex]))
        }
        vectors.append(arrows)
    }
    drawSVG(states: [initialState], params: params, vectors: vectors, delay: 3600)
}

In [0]:
#@title Main Simulation Loop

/// Steps `initialState` with default `SimulationParameters` until it is still, and returns the
/// final state.
@differentiable
func simulate(_ initialState: WorldState) -> WorldState {
    var world = initialState
    while !world.still {
        world = world.stepped(SimulationParameters())
    }
    return world
}

In [8]:
// Initial velocity of the first ball: the parameter that the cells below optimize.
var ball1InitialVelocity = Vec2(20, 0.01)
drawSimulation(ball1InitialVelocity: ball1InitialVelocity)

In [9]:
/// Training objective: runs the demo scene from the first ball's initial velocity and sums,
/// over all targets, the closest distance the balls came to that target. Lower is better.
@differentiable
func loss(ball1InitialVelocity: Vec2) -> Float {
    let finalState = simulate(WorldState(ball1InitialVelocity: ball1InitialVelocity))
    let perTargetDistances = finalState.minDistanceToTarget
    return perTargetDistances.differentiableReduce(0, +)
}
loss(ball1InitialVelocity: ball1InitialVelocity)

39.977932


In [13]:
// One manual gradient-descent step: visualize the gradient, then update the velocity using
// the initial (maximum) learning rate, `learningRate()` == `learningRate(0)`.
let grad = gradient(at: ball1InitialVelocity, in: loss)
drawGradients(ball1InitialVelocity: ball1InitialVelocity, grad: grad)
ball1InitialVelocity -= learningRate() * grad

In [14]:
// Replay the simulation with the updated initial velocity.
drawSimulation(ball1InitialVelocity: ball1InitialVelocity)

In [15]:
#@title Training Loop

// Gradient descent on the first ball's initial velocity, with the cosine-annealed schedule.
for i in 0..<iterationCount {
    // Reuse the differentiable `loss` objective instead of duplicating its body. The local is
    // named `lossValue` so it does not shadow `loss` inside its own initializer.
    let (lossValue, grad) = valueWithGradient(at: ball1InitialVelocity, in: loss)
    print("\(i): loss \(lossValue)")
    ball1InitialVelocity -= learningRate(i) * grad
}

// let (loss, grad) = valueWithGradient(at: ball1InitialVelocity) { ball1InitialVelocity -> Float in
//     let initialState = WorldState(ball1InitialVelocity: ball1InitialVelocity)
//     let finalState = simulate(initialState)
//     return finalState.minDistanceToTarget.differentiableReduce(0, +)
// }
// drawGradients(ball1InitialVelocity: ball1InitialVelocity, grad: grad)
// ball1InitialVelocity -= learningRate() * grad

// drawSimulation(ball1InitialVelocity: ball1InitialVelocity)

0: loss 21.04597
1: loss 20.070435
2: loss 18.87973
3: loss 17.46678
4: loss 13.101482
5: loss 11.001682
6: loss 21.909658
7: loss 21.1699
8: loss 20.349894
9: loss 19.384205
10: loss 18.246502
11: loss 16.851357
12: loss 12.034846
13: loss 13.339533
14: loss 17.814167
15: loss 19.226051
16: loss 18.07266
17: loss 16.508118
18: loss 9.455233
19: loss 18.540794
20: loss 13.669064
21: loss 16.349358
22: loss 8.894419
23: loss 17.748875
24: loss 13.843623
25: loss 16.937824
26: loss 12.324703
27: loss 9.76158
28: loss 20.71047
29: loss 19.922564
30: loss 19.009449
31: loss 17.979296
32: loss 16.846195
33: loss 12.829235
34: loss 7.6661596
35: loss 19.012976
36: loss 17.997353
37: loss 16.422146
38: loss 10.441612
39: loss 12.143142
40: loss 19.473305
41: loss 18.596855
42: loss 17.672577
43: loss 15.762589
44: loss 10.225161
45: loss 10.676186
46: loss 19.179134
47: loss 18.37634
48: loss 17.531239
49: loss 15.660039
50: loss 10.337152
51: loss 8.296061
52: loss 18.748028
53: loss 17.962

In [16]:
// Replay the simulation with the velocity found by the training loop.
drawSimulation(ball1InitialVelocity: ball1InitialVelocity)