In [1]:
import tensorflow as tf
import numpy as np

logits1 = tf.random.uniform((1, 4, 5, 3), dtype=tf.float32)
labels1 = tf.concat([
  tf.zeros((1, 1), dtype=tf.int32),
  tf.random.uniform((1, 4), 1, 3, dtype=tf.int32)
], axis=-1)

logits2 = tf.random.uniform((1, 3, 4, 3), dtype=tf.float32)
labels2 = tf.random.uniform((1, 4), 1, 3, dtype=tf.int32)

logits2_padded = tf.pad(logits2, [[0, 0], [0, tf.shape(logits1)[1] - tf.shape(logits2)[1]], [0, tf.shape(logits1)[2] - tf.shape(logits2)[2]], [0, 0]], constant_values=0)
labels2_padded = tf.pad(labels2, [[0, 0], [0, tf.shape(labels1)[1] - tf.shape(labels2)[1]]], constant_values=0)

logits = tf.concat([logits1, logits2_padded], axis=0)
labels = tf.concat([labels1, labels2_padded], axis=0)

logits_rev = tf.reverse(logits, axis=[0])
labels_rev = tf.reverse(labels, axis=[0])


In [2]:
def rnnt_loss (logits, labels, time_lengths, label_lengths):
  log_pr = tf.math.log_softmax(logits, axis=-1)
  pr = pr_loss(log_pr, labels, time_lengths, label_lengths)
  ret = tf.reduce_sum(pr)
  return ret

@tf.custom_gradient
def pr_loss (log_pr, labels, time_lengths, label_lengths):
  LOG_0 = float('-inf')
  batch_size = log_pr.shape[0]
  max_time_lengths = log_pr.shape[1]
  max_label_lengths = log_pr.shape[2]

  def get_truth_log_pr (log_pr, labels):
    labels_one_hot = tf.one_hot(labels, tf.shape(log_pr)[-1], axis=-1, dtype=tf.float32)
    labels_one_hot = tf.expand_dims(labels_one_hot, axis=1)
    labels_one_hot = tf.repeat(labels_one_hot, tf.shape(log_pr)[1], axis=1)
    ret = tf.reduce_sum(log_pr * labels_one_hot, axis=-1)
    return ret
  
  def get_blank_log_pr (log_pr):
    return log_pr[:, :, :, 0]

  truth_log_pr = get_truth_log_pr(log_pr, labels)
  blank_log_pr = get_blank_log_pr(log_pr)

  def get_alpha (log_pr, labels, time_lengths, label_lengths):
    alpha = LOG_0 * np.ones((batch_size, max_time_lengths, max_label_lengths + 1), dtype=np.float32)
    log_pr = tf.concat([
      LOG_0 * tf.ones((batch_size, max_time_lengths, 1, tf.shape(log_pr)[-1]), tf.float32),
      log_pr
    ], axis=-2)
    labels = tf.concat([
      tf.zeros((batch_size, 1), tf.int32),
      labels
    ], axis=-1)
    for b in range(batch_size):
      alpha[b][0][1] = 0
      for u in tf.range(2, label_lengths[b] + 1):
        alpha[b][0][u] = alpha[b][0][u - 1] + log_pr[b][0][u - 1][labels[b][u]]

      for t in tf.range(1, time_lengths[b]):
        for u in tf.range(1, label_lengths[b] + 1):
          alpha[b][t][u] = tf.reduce_logsumexp(
            tf.stack([
              alpha[b][t - 1][u] + log_pr[b][t - 1][u][0],
              alpha[b][t][u - 1] + log_pr[b][t][u - 1][labels[b][u]],
            ], axis=0),
            axis=0
          )

    return alpha[:, :, 1: ]
  
  def get_beta (log_pr, labels, time_lengths, label_lengths):
    beta = LOG_0 * np.ones((batch_size, max_time_lengths, max_label_lengths + 1), dtype=np.float32)
    log_pr = tf.concat([log_pr, LOG_0 * tf.ones((batch_size, max_time_lengths, 1, tf.shape(log_pr)[-1]), tf.float32)], axis=-2)
    labels = tf.concat([labels, tf.zeros((batch_size, 1), tf.int32)], axis=-1)

    for b in range(batch_size):
      beta[b][time_lengths[b] - 1][label_lengths[b] - 1] = log_pr[b][time_lengths[b] - 1][label_lengths[b] - 1][0]
      for u in tf.reverse(tf.range(label_lengths[b] + 1 - 2), axis=[-1]):
        beta[b][time_lengths[b] - 1][u] = beta[b][time_lengths[b] - 1][u + 1] + log_pr[b][time_lengths[b] - 1][u][labels[b][u + 1]]
      for t in tf.reverse(tf.range(time_lengths[b] - 1), axis=[-1]):
        for u in tf.reverse(tf.range(label_lengths[b] + 1 - 1), axis=[-1]):
          beta[b][t][u] = tf.reduce_logsumexp(tf.stack([
              beta[b][t + 1][u] + log_pr[b][t][u][0],
              beta[b][t][u + 1] + log_pr[b][t][u][labels[b][u + 1]]
            ]), axis=-1
          )

    return beta[:, :, : -1]

  alpha = get_alpha(log_pr, labels, time_lengths, label_lengths)
  beta = get_beta(log_pr, labels, time_lengths, label_lengths)

  print('alpha, beta: ', alpha, beta)
  
  total_log_pr = beta[:, 0, 0]

  vocab_size = log_pr.shape[-1]
  def grad (upstream):
    ret = np.zeros((batch_size, max_time_lengths, max_label_lengths, vocab_size), dtype=np.float32)
    for b in range(batch_size):
      for t in range(time_lengths[b]):
        for u in range(label_lengths[b]):
          if u + 1 < label_lengths[b]:
            ret[b][t][u][labels[b][u + 1]] = -upstream[b] * tf.math.exp(alpha[b][t][u] + beta[b][t][u + 1] - total_log_pr[b])
          if t + 1 < time_lengths[b]:
            ret[b][t][u][0] = -upstream[b] * tf.math.exp(alpha[b][t][u] + beta[b][t + 1][u] - total_log_pr[b])
          else:
            if u < label_lengths[b] - 1:
              # ret[b][t][u][0] = -upstream[b] * tf.math.exp(alpha[b][t][u] + LOG_0 - total_log_pr[b])
              ret[b][t][u][0] = -upstream[b] * 0
            else:
              ret[b][t][u][0] = -upstream[b] * tf.math.exp(alpha[b][t][u] - total_log_pr[b])

    return [tf.convert_to_tensor(ret) * tf.exp(log_pr)] + [None] * 3

  return -total_log_pr, grad


In [3]:
with tf.GradientTape() as tape:
  tape.watch(logits)
  loss = rnnt_loss(
    logits,
    labels,
    tf.convert_to_tensor([
      logits1.shape[1],
      logits2.shape[1]
    ], dtype=tf.int32),
    tf.convert_to_tensor([
      logits1.shape[2],
      logits2.shape[2]
    ], dtype=tf.int32)
  )

grads = tape.gradient(loss, logits)
print(loss)
print(grads)

########################################

# with tf.GradientTape() as tape:
#   tape.watch(logits1)
#   loss = rnnt_loss(logits1, labels1,
#     tf.convert_to_tensor([
#       logits1.shape[1]
#     ], dtype=tf.int32),
#     tf.convert_to_tensor([
#       logits1.shape[2]
#     ], dtype=tf.int32)
#   )

# grads = tape.gradient(loss, logits1)
# print(loss)
# print(grads)

alpha, beta:  [[[ 0.         -0.92599946 -2.3156157  -3.055272   -4.2102556 ]
  [-1.2887211  -1.6292272  -2.5907452  -3.0962415  -4.017257  ]
  [-2.0283346  -2.1455922  -2.7752166  -3.330913   -4.218552  ]
  [-2.9248867  -2.9254246  -3.0963216  -3.4934251  -4.2247667 ]]

 [[ 0.         -0.90289307 -1.5848131  -2.5483189         -inf]
  [-1.4627242  -1.6832632  -2.0035896  -2.5004194         -inf]
  [-2.5198498  -2.1539145  -2.2927456  -2.5829546         -inf]
  [       -inf        -inf        -inf        -inf        -inf]]] [[[-5.392599  -4.8158193 -3.948039  -3.6297054 -3.8899534]
  [-5.325465  -4.5423517 -3.7696152 -3.1361618 -2.853404 ]
  [-5.462956  -4.405232  -3.5552168 -2.658824  -2.0091853]
  [-5.9284177 -4.674135  -3.803445  -2.4680066 -1.1678319]]

 [[-3.8387084 -3.2207396 -2.971735  -3.2371302       -inf]
  [-3.770599  -3.0109746 -2.4250593 -2.1882272       -inf]
  [-3.7591357 -2.8791437 -2.1036036 -1.2557541       -inf]
  [      -inf       -inf       -inf       -inf       -i

In [6]:
def nan_to_zero (tensor):
  return tf.where(tf.math.is_nan(tensor), tf.zeros_like(tensor), tensor)

def rnnt_loss (logits, labels, time_lengths, label_lengths):
  log_pr = tf.math.log_softmax(logits, axis=-1)
  pr = pr_loss(log_pr, labels, time_lengths, label_lengths)
  ret = tf.reduce_sum(pr)
  return ret

@tf.custom_gradient
def pr_loss (log_pr, labels, time_lengths, label_lengths):
  LOG_0 = float('-inf')
  batch_size = tf.shape(log_pr)[0]
  max_time_lengths = tf.shape(log_pr)[1]
  max_label_lengths = tf.shape(log_pr)[2]

  def get_truth_log_pr (log_pr, labels):
    labels_one_hot = tf.one_hot(labels, tf.shape(log_pr)[-1], axis=-1, dtype=tf.float32)
    labels_one_hot = labels_one_hot[:, 1: ]
    labels_one_hot = tf.expand_dims(labels_one_hot, axis=1)
    labels_one_hot = tf.repeat(labels_one_hot, tf.shape(log_pr)[1], axis=1)
    ret = tf.reduce_sum(log_pr[:, :, : -1, :] * labels_one_hot, axis=-1)
    ret = tf.concat([
      ret,
      LOG_0 * tf.ones((tf.shape(log_pr)[0], tf.shape(log_pr)[1], 1), dtype=tf.float32)
    ], axis=-1)
    return ret
  
  def get_blank_log_pr (log_pr):
    return log_pr[:, :, :, 0]

  truth_log_pr = get_truth_log_pr(log_pr, labels)
  blank_log_pr = get_blank_log_pr(log_pr)

  def get_alpha (truth_log_pr, blank_log_pr):
    reversed_truth_log_pr = tf.reverse(truth_log_pr, axis=[-1])
    padded_truth_log_pr = tf.pad(
      reversed_truth_log_pr,
      [[0, 0], [0, 0], [tf.shape(reversed_truth_log_pr)[-2] - 1, 0]],
      constant_values=LOG_0
    )
    truth_diag = tf.linalg.diag_part(
      padded_truth_log_pr,
      k=(0, tf.shape(padded_truth_log_pr)[-1] - 1),
      padding_value=LOG_0,
      align='LEFT_RIGHT'
    )
    truth_diag = tf.transpose(truth_diag, perm=[1, 0, 2])

    reversed_blank_log_pr = tf.reverse(blank_log_pr, axis=[-1])
    padded_blank_log_pr = tf.pad(
      reversed_blank_log_pr,
      [[0, 0], [0, 0], [tf.shape(reversed_blank_log_pr)[-2] - 1, 0]],
      constant_values=LOG_0
    )
    blank_diag = tf.linalg.diag_part(
      padded_blank_log_pr,
      k=(0, tf.shape(padded_blank_log_pr)[-1] - 1),
      padding_value=LOG_0,
      align='LEFT_RIGHT'
    )
    blank_diag = tf.concat([
      LOG_0 * tf.ones((tf.shape(blank_diag)[0], tf.shape(blank_diag)[1], 1), dtype=tf.float32),
      blank_diag[:, :, : -1]
    ], axis=-1)
    blank_diag = tf.transpose(blank_diag, perm=[1, 0, 2])

    initial_diag = tf.concat([
      tf.zeros((tf.shape(blank_diag)[1], 1), dtype=tf.float32),
      LOG_0 * tf.ones((tf.shape(blank_diag)[1], tf.shape(blank_diag)[-1] - 1), dtype=tf.float32)
    ], axis=-1)

    def step (a, x):
      t, b = x
      return (
        tf.reduce_logsumexp(
          tf.stack([
              a + t,
              tf.concat([
                LOG_0 * tf.ones((tf.shape(a)[0], 1), dtype=tf.float32),
                a[:, : -1]
              ], axis=-1) + b
            ],
            axis=0
          ), axis=0
        )
      )
    alpha_diag =  tf.concat([
      tf.expand_dims(initial_diag, axis=0),
      tf.scan(step, (truth_diag, blank_diag), initial_diag)
    ], axis=0)
    alpha_diag = tf.transpose(alpha_diag, perm=[1, 2, 0])
    alpha = tf.linalg.diag_part(alpha_diag, k=(0, max_label_lengths - 1))
    alpha = tf.reverse(alpha, axis=[-2])
    alpha = tf.transpose(alpha, perm=[0, 2, 1])
    return alpha

  def get_beta (truth_log_pr, blank_log_pr):
    reversed_truth_log_pr = tf.reverse(truth_log_pr, axis=[-1])
    reversed_truth_log_pr = tf.concat([
      LOG_0 * tf.ones((tf.shape(reversed_truth_log_pr)[0], tf.shape(reversed_truth_log_pr)[1], 1), dtype=tf.float32),
      reversed_truth_log_pr[:, :, 1: ]
    ], axis=-1)
    padded_truth_log_pr = tf.pad(
      reversed_truth_log_pr,
      [[0, 0], [0, 0], [tf.shape(reversed_truth_log_pr)[-2] - 1, 0]],
      constant_values=LOG_0
    )
    truth_diag = tf.linalg.diag_part(
      padded_truth_log_pr,
      k=(0, tf.shape(padded_truth_log_pr)[-1] - 1),
      padding_value=LOG_0,
      align='LEFT_RIGHT'
    )
    truth_diag = tf.transpose(truth_diag, perm=[1, 0, 2])
    truth_diag = truth_diag[: -1]

    reversed_blank_log_pr = tf.reverse(blank_log_pr, axis=[-1])
    reversed_blank_log_pr = tf.concat([
      reversed_blank_log_pr[:, : -1, :],
      LOG_0 * tf.ones((tf.shape(reversed_blank_log_pr)[0], 1, tf.shape(reversed_blank_log_pr)[-1]), dtype=tf.float32)
    ], axis=-2)
    padded_blank_log_pr = tf.pad(
      reversed_blank_log_pr,
      [[0, 0], [0, 0], [tf.shape(reversed_blank_log_pr)[-2] - 1, 0]],
      constant_values=LOG_0
    )
    blank_diag = tf.linalg.diag_part(
      padded_blank_log_pr,
      k=(0, tf.shape(padded_blank_log_pr)[-1] - 1),
      padding_value=LOG_0,
      align='LEFT_RIGHT'
    )
    blank_diag = tf.transpose(blank_diag, perm=[1, 0, 2])
    blank_diag = blank_diag[: -1]

    mask = tf.sequence_mask(
      time_lengths + label_lengths - 2,
      tf.shape(log_pr)[1] + tf.shape(log_pr)[2] - 2,
      dtype=tf.float32
    )
    mask = tf.transpose(mask, perm=[1, 0])

    dp_start_value = tf.gather_nd(
      blank_log_pr,
      indices=tf.stack([time_lengths, label_lengths], axis=-1) - 1,
      batch_dims=1
    )

    initial_diag_mask = tf.one_hot(time_lengths - 1, depth=tf.shape(log_pr)[1])
    initial_diag = tf.expand_dims(dp_start_value, axis=1) * initial_diag_mask + nan_to_zero(LOG_0 * (1.0 - initial_diag_mask))

    def step (a, x):
      m, t, b = x
      a_next = tf.reduce_logsumexp(
        tf.stack([
          a + t,
          tf.concat([
            a[:, 1: ],
            LOG_0 * tf.ones((tf.shape(a)[0], 1), dtype=tf.float32)
          ], axis=-1) + b
        ], axis=0),
        axis=0
      )
      masked_a_next = nan_to_zero(a_next * tf.expand_dims(m, axis=1)) + nan_to_zero(a * tf.expand_dims(1.0 - m, axis=1))
      return masked_a_next

    beta_diag = tf.concat([
      tf.scan(step, (mask, truth_diag, blank_diag), initial_diag, reverse=True),
      tf.expand_dims(initial_diag, axis=0)
    ], axis=0)

    beta_diag = tf.transpose(beta_diag, perm=[1, 2, 0])
    beta = tf.linalg.diag_part(beta_diag, k=(0, tf.shape(log_pr)[2] - 1), padding_value=LOG_0)
    beta = tf.transpose(beta, perm=[0, 2, 1])
    beta = tf.reverse(beta, axis=[-1])

    return beta
  
  time_mask = tf.sequence_mask(time_lengths, tf.shape(log_pr)[1], dtype=tf.float32)
  label_mask = tf.sequence_mask(label_lengths, tf.shape(log_pr)[2], dtype=tf.float32)
  total_mask = tf.expand_dims(time_mask, axis=2) * tf.expand_dims(label_mask, axis=1)

  alpha = get_alpha(truth_log_pr, blank_log_pr)
  alpha = alpha + nan_to_zero((1.0 - total_mask) * LOG_0)
  beta = get_beta(truth_log_pr, blank_log_pr)
  beta = beta + nan_to_zero((1.0 - total_mask) * LOG_0)

  indices = tf.concat([
    tf.expand_dims(tf.range(0, tf.shape(log_pr)[0]), axis=1),
    tf.stack([
      time_lengths,
      label_lengths - 1
    ], axis=-1),
  ], axis=-1)
  
  beta_mask = tf.scatter_nd(
    indices,
    tf.ones(tf.shape(indices)[0], tf.float32),
    (tf.shape(log_pr)[0], tf.shape(log_pr)[1] + 1, tf.shape(log_pr)[2])
  )
  beta_mask = 1.0 - beta_mask
  beta = nan_to_zero(tf.pad(beta, [[0, 0], [0, 1], [0, 0]], constant_values=LOG_0) * beta_mask)

  total_mask = tf.expand_dims(total_mask, axis=-1)

  total_log_pr = beta[:, 0, 0]

  def grad (upstream):
    blank_grads = \
      alpha + beta[:, 1: , :] \
      - tf.reshape(total_log_pr, shape=(tf.shape(total_log_pr)[0], 1, 1))
    truth_grads = \
      alpha + tf.pad(
        beta[:, : -1, 1: ],
        [[0, 0], [0, 0], [0, 1]],
        constant_values=LOG_0
      ) \
      - tf.reshape(total_log_pr, shape=(tf.shape(total_log_pr)[0], 1, 1))
    blank_one_hot = tf.one_hot(tf.zeros_like(labels, dtype=tf.int32), tf.shape(log_pr)[-1], dtype=tf.float32)
    blank_one_hot = tf.expand_dims(blank_one_hot, axis=1)
    blank_one_hot = tf.repeat(blank_one_hot, tf.shape(log_pr)[1], axis=1)
    blank_grads = tf.exp(tf.expand_dims(blank_grads, axis=-1) + log_pr) * blank_one_hot
    truth_one_hot = tf.one_hot(labels[:, 1: ], tf.shape(log_pr)[-1], dtype=tf.float32)
    truth_one_hot = tf.concat([truth_one_hot, tf.zeros((tf.shape(log_pr)[0], 1, tf.shape(log_pr)[-1]), dtype=tf.float32)], axis=-2)
    truth_one_hot = tf.expand_dims(truth_one_hot, axis=1)
    truth_one_hot = tf.repeat(truth_one_hot, tf.shape(log_pr)[1], axis=1)
    truth_grads = tf.exp(tf.expand_dims(truth_grads, axis=-1) + log_pr) * truth_one_hot

    grads = blank_grads + truth_grads

    return (
      [
        tf.reshape(-upstream, shape=(tf.shape(upstream)[0], 1, 1, 1))
        * grads * total_mask
      ] +
      [None] * 3
    )

  return -total_log_pr, grad


In [9]:
with tf.GradientTape() as tape:
  tape.watch(logits)
  loss = rnnt_loss(
    logits,
    labels,
    tf.convert_to_tensor([
      logits1.shape[1],
      logits2.shape[1]
    ], dtype=tf.int32),
    tf.convert_to_tensor([
      logits1.shape[2],
      logits2.shape[2]
    ], dtype=tf.int32)
  )

grads = tape.gradient(loss, logits)
print(loss)
print(grads)

##########################################

# with tf.GradientTape() as tape:
#   tape.watch(logits1)
#   loss = rnnt_loss(
#     logits1,
#     labels1,
#     tf.convert_to_tensor([
#       logits1.shape[1],
#     ], dtype=tf.int32),
#     tf.convert_to_tensor([
#       logits1.shape[2],
#     ], dtype=tf.int32)
#   )

# grads = tape.gradient(loss, logits1)
# print(loss)
# print(grads)

##########################################

# with tf.GradientTape() as tape:
#   tape.watch(logits2)
#   loss = rnnt_loss(
#     logits2,
#     labels2,
#     tf.convert_to_tensor([
#       logits2.shape[1],
#     ], dtype=tf.int32),
#     tf.convert_to_tensor([
#       logits2.shape[2],
#     ], dtype=tf.int32)
#   )

# grads = tape.gradient(loss, logits2)
# print(loss)
# print(grads)


tf.Tensor(9.231308, shape=(), dtype=float32)
tf.Tensor(
[[[[-0.01913884 -0.30910283  0.3282418 ]
   [-0.06860353 -0.2427846   0.31138816]
   [-0.02351385  0.09838533 -0.07487151]
   [-0.08099362  0.06116809  0.01982552]
   [-0.04304052  0.02878935  0.01425118]]

  [[ 0.01807259 -0.09347808  0.07540552]
   [-0.02855265 -0.12704903  0.15560171]
   [-0.0301151   0.1351668  -0.10505171]
   [-0.10263944  0.14238776 -0.03974833]
   [-0.13002956  0.06003369  0.06999588]]

  [[ 0.01861684 -0.05953721  0.04092037]
   [ 0.02425869 -0.1347973   0.11053863]
   [ 0.03141599  0.13431704 -0.16573304]
   [-0.05987129  0.15816449 -0.09829323]
   [-0.24678552  0.11987937  0.12690614]]

  [[ 0.01138065 -0.02244763  0.01106698]
   [ 0.03280412 -0.06396724  0.03116312]
   [ 0.0459091   0.11735388 -0.16326298]
   [ 0.20161878  0.21029055 -0.41190934]
   [-0.6889597   0.30324212  0.38571763]]]


 [[[-0.01632415  0.36300027 -0.34667614]
   [-0.05000587  0.1575295  -0.10752374]
   [-0.14531748  0.10193256  0.0

In [8]:
import numpy as np

rst = np.zeros_like(logits, dtype=np.float64)
loss_before = rnnt_loss(
  logits,
  labels,
  tf.convert_to_tensor([
    logits1.shape[1],
    logits2.shape[1]
  ], dtype=tf.int32),
  tf.convert_to_tensor([
    logits1.shape[2],
    logits2.shape[2]
  ], dtype=tf.int32)
)
for i in range(logits.shape[0]):
  for j in range(logits.shape[1]):
    for k in range(logits.shape[2]):
      for w in range(logits.shape[3]):
        h = logits[i][j][k][w] * 0.005
        a = np.zeros_like(logits, dtype=np.float64)
        a[i][j][k][w] += h
        loss_after = rnnt_loss(
          logits + a,
          labels,
          tf.convert_to_tensor([
            logits1.shape[1],
            logits2.shape[1]
          ], dtype=tf.int32),
          tf.convert_to_tensor([
            logits1.shape[2],
            logits2.shape[2]
          ], dtype=tf.int32)
        )
        rst[i][j][k][w] = (loss_after - loss_before) / h

print(rst)


[[[[-0.01950178 -0.30899599  0.3284134 ]
   [-0.06956212 -0.24224073  0.31165951]
   [-0.02381245  0.0965459  -0.07500952]
   [-0.08121312  0.06100342  0.01958293]
   [-0.04330702  0.0284321   0.01583567]]

  [[ 0.01805459 -0.09361105  0.07520283]
   [-0.02879723 -0.12782161  0.15542826]
   [-0.03050284  0.13506851 -0.10523877]
   [-0.1027229   0.1424693  -0.03993363]
   [-0.13012888  0.05957048  0.07005575]]

  [[ 0.01833876 -0.05947057  0.04068129]
   [ 0.02410301 -0.13517386  0.11060436]
   [ 0.03130596  0.13415524 -0.1659625 ]
   [-0.06003268  0.15764309 -0.11254264]
   [-0.24703215  0.1189993   0.12697737]]

  [[ 0.01088513 -0.02291232  0.0107045 ]
   [ 0.03228918 -0.06428875  0.03048189]
   [ 0.04443169  0.11729471 -0.16313082]
   [ 0.20091861  0.20994577 -0.41478291]
   [-0.68885601  0.30341071  0.38600522]]]


 [[[-0.01673768  0.36341387 -0.34660751]
   [-0.04998233  0.15688762 -0.10757051]
   [-0.14544274  0.10301723  0.04330653]
   [-0.09286078  0.04694804  0.04555332]
   [  