In [None]:
import TensorFlow

In [None]:
//export
/// A `Layer` that supports callbacks: it carries a list of delegate closures
/// that are handed the layer's output.
public protocol FALayer: Layer {
    associatedtype Input
    associatedtype Output

    /// Callbacks that receive the layer's output.
    var delegates: [(Output) -> Void] { get set }

    /// The layer's actual computation.
    ///
    /// Conforming types implement this instead of `callAsFunction(_:)`.
    @differentiable
    func forward(_ input: Input) -> Output
}

In [None]:
//export
public extension FALayer {

    /// Runs `forward(_:)` on `input`, passes the result to every closure in
    /// `delegates` (in array order), and returns that same result.
    ///
    /// The second `@differentiable` attribute registers `callGrad(_:)` as the
    /// VJP of this method.
    @differentiable(wrt: (self))
    @differentiable(vjp: callGrad)
    func callAsFunction(_ input: Input) -> Output {
        let output = forward(input)
        for notify in delegates { notify(output) }
        return output
    }

    /// VJP used for `callAsFunction(_:)`.
    ///
    /// Returns the value of `forward(_:)` together with its pullback with
    /// respect to both the layer and the input. Only `forward(_:)` is
    /// evaluated here; the delegates are not invoked.
    func callGrad(_ input: Input) ->
        (Output, (Self.Output.TangentVector) -> (Self.TangentVector, Self.Input.TangentVector)) {
        return Swift.valueWithPullback(at: self, input) { (layer, x) in layer.forward(x) }
    }

    /// Appends `delegate` to the end of `delegates`.
    mutating func addDelegate(_ delegate: @escaping (Output) -> Void) {
        delegates.append(delegate)
    }
}

In [None]:
//export
/// A dense `FALayer` whose forward pass computes
/// `activation(input • weight + bias)`.
@frozen
public struct FADense<Scalar: TensorFlowFloatingPoint>: FALayer {
    // Note: remove the explicit typealiases after TF-603 is resolved.
    public typealias Input = Tensor<Scalar>
    public typealias Output = Tensor<Scalar>
    /// The differentiable tensor-to-tensor function applied to the layer's result.
    public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

    /// The tensor that the input is multiplied by.
    public var weight: Tensor<Scalar>
    /// The tensor added after the multiplication.
    public var bias: Tensor<Scalar>
    /// Output callbacks; excluded from differentiation.
    @noDerivative public var delegates: [(Output) -> Void] = []
    /// The activation function; excluded from differentiation.
    @noDerivative public let activation: Activation

    /// Creates a layer from explicit parameters.
    ///
    /// - Parameters:
    ///   - weight: Initial value of `weight`.
    ///   - bias: Initial value of `bias`.
    ///   - activation: Function applied to `input • weight + bias`.
    public init(
        weight: Tensor<Scalar>,
        bias: Tensor<Scalar>,
        activation: @escaping Activation
    ) {
        self.activation = activation
        self.bias = bias
        self.weight = weight
    }

    /// Returns `activation(input • weight + bias)`.
    @differentiable
    //@differentiable(wrt: (self))
    public func forward(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        let preActivation = input • weight + bias
        return activation(preActivation)
    }
}

public extension FADense {
    /// Creates a layer with a `[nIn, nOut]` weight initialised via
    /// `Tensor(randomNormal:)` and an `[nOut]` bias of zeros.
    ///
    /// - Parameters:
    ///   - nIn: First dimension of the weight tensor.
    ///   - nOut: Second dimension of the weight tensor and length of the bias.
    ///   - activation: Function applied in the forward pass; defaults to `identity`.
    init(_ nIn: Int, _ nOut: Int, activation: @escaping Activation = identity) {
        let initialWeight = Tensor<Scalar>(randomNormal: [nIn, nOut])
        let initialBias = Tensor<Scalar>(zeros: [nOut])
        self.init(weight: initialWeight, bias: initialBias, activation: activation)
    }
}