In [16]:
using Flux, DataFrames, CSV

In [2]:
# Minimal recurrent-cell demo: a cell built with sizes (2, 3) is applied once to a
# length-2 input and a length-3 state.
rnn = Flux.RNNCell(2, 3)

x = rand(Float32, 2) # dummy data
h = rand(Float32, 3)  # initial hidden state

# The cell is called as `rnn(state, input)` and returns a 2-tuple, destructured
# into `h` (rebinding the initial state above) and `y`.
h, y = rnn(h, x)

(Float32[-0.9481508, 0.9169154, -0.6440355], Float32[-0.9481508, 0.9169154, -0.6440355])

In [3]:
# Stack `x` on top of `h` into one vector (entries of `x` first, then those of `h`).
vcat(x, h)

5-element Vector{Float32}:
  0.15125108
  0.9087713
 -0.9481508
  0.9169154
 -0.6440355

In [4]:
# Two-layer model: `RNN(2, 5)` followed by `Dense(5, 1)`, evaluated once on a
# random length-2 Float32 vector. NOTE: `m` is rebound later in this file.
m = Chain(RNN(2, 5), Dense(5, 1))
m(rand(Float32, 2))

1-element Vector{Float32}:
 0.26668862

In [14]:
# --- Base sizes -------------------------------------------------------------
# Row counts of the two arrays that make up each input pair of the sequence.
const n_input_temps = 2
const n_non_input_temps = 3
# Row count of the recurrent state and of the model output.
const n_targets = 3
# Column count used for every array in the smoke test.
const n_profiles = 49

# --- Derived sizes ----------------------------------------------------------
# Rows contributed by one input pair.
const n_total_inputs = n_non_input_temps + n_input_temps
# State rows plus the `n_input_temps` input rows.
const n_temps = n_input_temps + n_targets


"""
    HeatTransferLayer(n_temps, n_targets, conductance_net, adj_mat)

State of the callable heat-transfer layer.

The type is parameterised on the type of `conductance_net` so that the field is
concretely typed; an untyped field is stored as `Any` and forces dynamic dispatch
on every forward call. The struct stays `mutable`, but `conductance_net` can now
only be reassigned to a value of the same type.

# Fields
- `n_temps::Int`: number of leading rows of the layer input read as temperatures.
- `n_targets::Int`: number of leading rows of the layer input read as the previous
  output; also the number of rows the layer returns.
- `conductance_net`: callable applied to the whole layer input; row `k` of its
  result is used as conductance number `k`.
- `adj_mat::Matrix{Int8}`: `adj_mat[i, j]` is the conductance row used for the
  pair `(i, j)`.
"""
mutable struct HeatTransferLayer{N}
    n_temps::Int
    n_targets::Int
    conductance_net::N
    adj_mat::Matrix{Int8}
end

# Converting fallback: the automatically generated outer constructor of a
# parametric struct only accepts arguments that already have the declared field
# types. This method keeps the previous behaviour of converting the arguments
# (e.g. any `Integer` for the size fields) to the field types.
function HeatTransferLayer(n_temps, n_targets, conductance_net::N, adj_mat) where {N}
    return HeatTransferLayer{N}(n_temps, n_targets, conductance_net, adj_mat)
end

"""
    HeatTransferLayer(n_temps::Integer, n_targets::Integer;
                      n_inputs::Integer = n_total_inputs + n_targets)

Build a `HeatTransferLayer` with a freshly initialised conductance network.

Every unordered pair `{i, j}` with `i != j` among `n_temps` indices is given a
number `k` in `1:n_temps*(n_temps-1)÷2` (counted column by column over the lower
triangle). The number is stored symmetrically in the adjacency matrix, whose
diagonal stays zero. The conductance network is `Dense(n_inputs, n_conds, σ)`,
i.e. one sigmoid output per pair.

# Keywords
- `n_inputs`: number of input rows the conductance network accepts. Defaults to
  the file-level `n_total_inputs + n_targets`, which reproduces the previous
  hard-coded behaviour.

# Throws
- `InexactError` when a pair number exceeds `typemax(Int8)` (127), which happens
  for `n_temps > 16`, because the adjacency matrix stores `Int8`.
"""
function HeatTransferLayer(n_temps::Integer, n_targets::Integer;
                           n_inputs::Integer = n_total_inputs + n_targets)
    # Work in `Int` so that narrow integer argument types cannot overflow below.
    n = Int(n_temps)

    # populate adjacency matrix (both triangles at once)
    adj_mat = zeros(Int8, n, n)
    k = 1
    for col_j in 1:n
        for row_i in (col_j + 1):n
            adj_mat[row_i, col_j] = k
            adj_mat[col_j, row_i] = k
            k += 1
        end
    end

    # n * (n - 1) is always even, so integer division is exact.
    n_conds = (n * (n - 1)) ÷ 2
    return HeatTransferLayer(n, Int(n_targets),
                             Dense(n_inputs, n_conds, σ),
                             adj_mat)
end

# Make `HeatTransferLayer` callable.
#
# `all_input` is indexed as `all_input[rows, :]`: rows `1:m.n_targets` are read as
# the previous output and rows `1:m.n_temps` as temperatures. `m.conductance_net`
# is applied to the whole input, and row `m.adj_mat[i, j]` of its result weights
# the pair `(i, j)`.
#
# For every target row `i` the result row is
#     sum over j in 1:n_temps, j != i of
#         (temps[j, :] - prev_out[i, :]) .* conductances[m.adj_mat[i, j], :]
# so the return value has `m.n_targets` rows and as many columns as `all_input`.
#
# The sum is built without mutating any array (`acc` is rebound, never written
# into), because Zygote cannot differentiate through in-place array updates. Terms
# are added in the same `j` order as before. The element type follows normal
# promotion of the operands instead of being forced to `eltype(prev_out)`.
function (m::HeatTransferLayer)(all_input)
    n_temps = m.n_temps
    prev_out = @views all_input[1:m.n_targets, :]
    temps = @views all_input[1:n_temps, :]

    conductances = m.conductance_net(all_input)

    # Nothing to stack when there are no target rows; return the empty result.
    m.n_targets == 0 && return zeros(eltype(prev_out), size(prev_out))

    # subtract, scale, and sum -- one 1×N row per target
    rows = map(1:m.n_targets) do i
        prev_i = prev_out[i:i, :]
        acc = zero(prev_i)
        for j in 1:n_temps
            j == i && continue  # no self-coupling; adj_mat[i, i] is not a valid row
            k = m.adj_mat[i, j]
            acc = acc .+ (temps[j:j, :] .- prev_i) .* conductances[k:k, :]
        end
        acc
    end

    return reduce(vcat, rows)
end

# Register the layer with Flux, exposing only `conductance_net` as a child.
# `adj_mat` is an `Int8` index table and the size fields are plain integers, so
# none of them should be collected as parameters or touched when the model is
# mapped over (e.g. moved to another device).
Flux.@functor HeatTransferLayer (conductance_net,)
# First branch (presumably the power-loss term, judging by the name): a small
# sigmoid MLP from the stacked input to `n_targets` outputs.
ploss_net = Chain(Dense(n_total_inputs + n_targets, 8, σ),
                  Dense(8, n_targets, σ))
# Second branch: pairwise heat-transfer term.
heat_transfer = HeatTransferLayer(n_temps, n_targets)
# Both branches receive the same input; their outputs are combined with `+`.
prll = Parallel(+, ploss_net, heat_transfer)

# Recurrent step in the `(state, input) -> (new_state, output)` form expected by
# `Flux.Recur`. The new state is `prll` applied to the stacked input; the emitted
# output is the state that was passed in.
function my_cell(prev_̂y, x)
    # `x` unpacks into two parts; they are stacked below the previous state in
    # the order (state, second part, first part).
    (non_temp_part, temp_part) = x
    stacked = vcat(prev_̂y, temp_part, non_temp_part)
    return prll(stacked), prev_̂y
end

# smoke-test the topology
# Sequence of 10 input pairs; each pair holds an (n_non_input_temps × n_profiles)
# and an (n_input_temps × n_profiles) random Float32 matrix, in that order, which
# is the order `my_cell` unpacks them in.
xs = [(rand(Float32, n_non_input_temps, n_profiles), 
        rand(Float32, n_input_temps, n_profiles)) for i in 1:10]
h = rand(Float32, n_targets, n_profiles)  # initial hidden state
# Wrap the cell so the state is carried between calls. NOTE: this rebinds `m`
# and `h`, which were used earlier in the file.
m = Flux.Recur(my_cell, h)

# predict
# The wrapper is stateful, so the elements of `xs` must be fed in order.
ys = [m(x) for x in xs]


10-element Vector{Matrix{Float32}}:
 [0.15575695 0.47646832 … 0.4927429 0.9418055; 0.8937129 0.45485008 … 0.89658284 0.403031; 0.07059038 0.9383588 … 0.69216573 0.9007287]
 [0.7597275 0.4969073 … 0.49025214 -0.53089094; -0.9860904 0.6267067 … -0.30998075 1.0288675; 1.0563452 -0.422546 … 0.20994455 -0.019340098]
 [-0.92069817 0.38168728 … 0.12233563 2.7148623; 3.793907 0.0326142 … 1.8290619 -1.2502389; -0.6850257 2.5838437 … 0.73840535 1.6738575]
 [2.7132058 1.2482662 … 1.8034438 -5.6915526; -6.7145996 2.4748693 … -2.147357 4.3415427; 5.1458282 -2.5789168 … 0.5697482 0.45128942]
 [-11.154015 -2.1126046 … -4.3827705 14.89292; 17.422298 -3.2937376 … 5.6016617 -15.819507; -1.523951 9.827733 … 0.53956985 3.3076]
 [8.212044 6.47008 … 10.537048 -60.46117; -40.75537 14.843651 … -17.05578 40.016; 20.450392 -9.17193 … 3.589607 11.748254]
 [-62.41081 -15.423424 … -50.094433 161.25975; 91.137405 -19.137611 … 43.691494 -198.73477; -1.0432239 44.88357 … 5.8553762 9.321945]
 [24.158127 27.772429 … 14