In [None]:
import tensorflow as tf
import numpy as np

In [None]:
# Dimensions of the tensor M used throughout the notebook: M is created with
# shape (a, b, c) and the accumulation runs along the last axis (length c).
a, b, c = 5, 2, 5

In [None]:
# Fix the global TensorFlow seed so the randomly initialised M and p are reproducible.
tf.random.set_seed(90)
# M: trainable tensor of shape (a, b, c), sampled from a normal distribution
# with mean 5 and standard deviation 2.
M = tf.Variable(tf.random.normal((a, b, c), 5, 2), name='M')
# p: trainable scalar sampled uniformly from [0, 1); it multiplies the running
# value in the accumulation below.
# `name` is an argument of tf.Variable (as for M above), not of the
# tf.random.uniform op; passing it to the op left the variable itself unnamed.
p = tf.Variable(tf.random.uniform((), 0, 1), name='p')

## Implementation using accumulative functions

In [None]:
# Scan-based implementation. For every vector along the last axis of M compute
# the running accumulation N[k] = N[k-1]*p + M[k]; tf.scan is given no
# initializer, so the first element seeds the accumulator (N[0] = M[0]).
# The tape is persistent because tape.gradient is called twice below.
with tf.device("/cpu:0"), tf.GradientTape(persistent=True) as tape:
    # Merge the two leading axes into a*b rows so that map_fn can walk over
    # plain vectors of the last dimension.
    reshaped_M = tf.reshape(M, [a*b, -1])
    N = tf.reshape(
        tf.map_fn(
            lambda row: tf.scan(lambda running, element: running*p + element, row),
            reshaped_M,
        ),
        M.shape,  # restore the original (a, b, c) layout
    )
# For a non-scalar target, tape.gradient differentiates the sum of its elements.
print('dN/dM:')
print(tape.gradient(N, M))
print('dN/dp:', tape.gradient(N, p).numpy())

In [None]:
# Display the accumulated tensor N produced by the scan-based cell above.
N

## The same, but implemented with numpy - for verification purposes

In [None]:
# Binary accumulation step wrapped as a NumPy ufunc so that its .accumulate
# method can be used for the reference computation. The first argument is the
# running value, the second the next element; p is read from the TensorFlow
# variable on every call.
memory_accumulation = np.frompyfunc(
    lambda running, element: element + p.numpy()*running,
    2,  # number of inputs
    1,  # number of outputs
)

In [None]:
# NumPy reference result: accumulate along the last axis (axis=2) of M.
# Ufuncs built with np.frompyfunc work on Python objects, hence dtype=object
# during accumulation and the conversion back to a float array afterwards.
N_numpy = (
    memory_accumulation
    .accumulate(M.numpy(), axis=2, dtype=object)
    .astype(float)
)

In [None]:
# Check the TensorFlow result against the NumPy reference. np.allclose compares
# within a tolerance; exact comparison doesn't work for numerical reasons.
np.allclose(N.numpy(), N_numpy)

## Alternative - utilizing matrix multiplication

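The solution above, even though effective, lacks some elegance. We can instead express it as a matrix multiplication. The crucial observation is that unrolling the recurrence along the last axis makes every output element a polynomial in `p`: $N_j = \sum_{i \le j} p^{\,j-i} M_i$. Hence $N = M P$, where $P$ is the upper-triangular matrix with entries $P_{ij} = p^{\,j-i}$ for $i \le j$ and $0$ otherwise.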

In [None]:
# Matrix-multiplication implementation: build the (c, c) upper-triangular
# matrix whose entry (i, j) is p**(j - i) for j >= i and multiply M by it.
# The tape is persistent because tape.gradient is called twice below.
with tf.device("/cpu:0"), tf.GradientTape(persistent=True) as tape:
    # Ones strictly above the diagonal, zeros on and below it.
    strict_upper_ones = tf.linalg.set_diag(
        tf.linalg.band_part(tf.ones((c, c)), 0, -1),
        tf.zeros(c),
    )
    # Cumulative sum along each row gives the exponent j - i above the
    # diagonal and 0 elsewhere.
    exponents = tf.math.cumsum(strict_upper_ones, 1)
    # p**0 == 1 also fills the lower triangle, so mask it out again and keep
    # only the diagonal and everything above it.
    powers_of_p = tf.linalg.band_part(p**exponents, 0, -1)
    # The (c, c) matrix is broadcast over the leading axis of M.
    N = tf.matmul(M, powers_of_p)
# For a non-scalar target, tape.gradient differentiates the sum of its elements.
print('dN/dM:')
print(tape.gradient(N, M))
print('dN/dp:', tape.gradient(N, p).numpy())

In [None]:
# Check the matrix-multiplication result against the NumPy reference.
# np.allclose compares within a tolerance; exact comparison doesn't work for
# numerical reasons.
np.allclose(N.numpy(), N_numpy)

## Comparison of performance

In [None]:
%%timeit
with tf.device("/cpu:0"):
    with tf.GradientTape(persistent=True) as tape:
        reshaped_M = tf.reshape(M, [a*b, -1])
        N = tf.reshape(tf.map_fn(lambda row: tf.scan(lambda x, y: x*p+y, row), reshaped_M), M.shape)
tape.gradient(N, M)
tape.gradient(N, p)

In [None]:
%%timeit
with tf.device("/cpu:0"):
    with tf.GradientTape(persistent=True) as tape:
        powers_of_p = tf.linalg.band_part(p**tf.math.cumsum(tf.linalg.set_diag(tf.linalg.band_part(tf.ones((c, c)), 0, -1), tf.zeros((c))), 1), 0, -1)
        N = tf.matmul(M, powers_of_p)
tape.gradient(N, M)
tape.gradient(N, p)

It seems that the latter solution (matrix multiplication) is ~15x faster on the laptop I ran this code on and, additionally, doesn't raise warnings.