In [2]:
import tensorflow as tf

In [3]:
class R2(tf.keras.metrics.Metric):
  """Streaming coefficient of determination (R^2), computed per target.

  Running sums are accumulated across every `update_state` call, and
  `result()` evaluates, independently for each target,

      R^2 = 1 - SS_res / SS_tot
      SS_res = sum w * (y_pred - y_true)^2
      SS_tot = sum w * (y_true - mean_w(y_true))^2

  where w is the sample weight (1 when no weights are given). The last axis
  of `y_true` / `y_pred` indexes targets; all leading axes (batch, and e.g.
  sequence position) are pooled together.

  Args:
    num_targets: size of the trailing (target) axis of the inputs.
    summarize: if True, `result()` returns the mean R^2 across targets;
      otherwise it returns a vector of shape (num_targets,).
    name: metric name.
    **kwargs: forwarded to `tf.keras.metrics.Metric` (e.g. `dtype`).
  """

  def __init__(self, num_targets, summarize=True, name='r2', **kwargs):
    super().__init__(name=name, **kwargs)
    self._summarize = summarize
    self._shape = (num_targets,)
    # Sum of weights per target (a plain element count when unweighted).
    self._count = self.add_weight(name='count', shape=self._shape, initializer='zeros')

    # Sums of y_true and y_true^2: together with the count they give SS_tot.
    self._true_sum = self.add_weight(name='true_sum', shape=self._shape, initializer='zeros')
    self._true_sumsq = self.add_weight(name='true_sumsq', shape=self._shape, initializer='zeros')

    # Sums of y_true*y_pred and y_pred^2: together with true_sumsq they give SS_res.
    self._product = self.add_weight(name='product', shape=self._shape, initializer='zeros')
    self._pred_sumsq = self.add_weight(name='pred_sumsq', shape=self._shape, initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Add one batch to the running sums.

    Args:
      y_true: ground-truth values, rank >= 2, targets on the last axis.
      y_pred: predictions with the same shape as `y_true`.
      sample_weight: optional weights. Trailing singleton axes are appended
        until the weight has the rank of `y_true`, and it is then broadcast
        to `y_true`'s shape (so e.g. a (batch,) weight applies to every target).

    Raises:
      ValueError: if `y_true` has rank < 2.
    """
    y_true = tf.cast(y_true, 'float32')
    y_pred = tf.cast(y_pred, 'float32')

    rank = len(y_true.shape)
    if rank < 2:
      raise ValueError(
          'R2 expects inputs of rank >= 2 with targets on the last axis; '
          'got rank {}.'.format(rank))
    # Pool every axis except the trailing target axis, so each accumulator
    # stays shaped (num_targets,).
    reduce_axes = list(range(rank - 1))

    if sample_weight is None:
      weight = tf.ones_like(y_true)
    else:
      weight = tf.cast(sample_weight, 'float32')
      while len(weight.shape) < rank:
        weight = tf.expand_dims(weight, axis=-1)
      weight = tf.broadcast_to(weight, tf.shape(y_true))

    self._count.assign_add(tf.reduce_sum(weight, axis=reduce_axes))
    self._true_sum.assign_add(tf.reduce_sum(weight * y_true, axis=reduce_axes))
    self._true_sumsq.assign_add(
        tf.reduce_sum(weight * tf.math.square(y_true), axis=reduce_axes))
    self._product.assign_add(tf.reduce_sum(weight * y_true * y_pred, axis=reduce_axes))
    self._pred_sumsq.assign_add(
        tf.reduce_sum(weight * tf.math.square(y_pred), axis=reduce_axes))

  def result(self):
    """Return per-target R^2, or its mean across targets if `summarize`.

    Before any update (count == 0), or for a target whose true values are
    all equal (SS_tot == 0), the division produces NaN or inf.
    """
    true_mean = tf.divide(self._true_sum, self._count)
    # SS_tot = sum w*y^2 - (sum w) * mean^2
    total = self._true_sumsq - tf.multiply(self._count, tf.math.square(true_mean))

    # SS_res = sum w*(p - y)^2, expanded as sum w*p^2 - 2*sum w*p*y + sum w*y^2
    resid = self._pred_sumsq - 2.0 * self._product + self._true_sumsq

    # A scalar 1.0 broadcasts to (num_targets,). Note that
    # tf.ones_like(self._shape) would build a (1,) tensor from the tuple's value.
    r2 = 1.0 - tf.divide(resid, total)

    if self._summarize:
      return tf.reduce_mean(r2)
    return r2

  def reset_state(self):
    """Zero every accumulator so the next epoch / evaluation starts fresh."""
    for variable in self.variables:
      variable.assign(tf.zeros(variable.shape, dtype=variable.dtype))

  def get_config(self):
    """Include constructor arguments so the metric can be rebuilt from config."""
    config = super().get_config()
    config.update({'num_targets': self._shape[0], 'summarize': self._summarize})
    return config

In [10]:
# Pitfall: ones_like treats its argument as a *value*, not a shape. The tuple
# (54,) -- or equivalently the list [54] -- is a 1-element tensor, so this yields
# shape (1,), not (54,). Use tf.ones((54,)) to build a 54-element vector.
tf.ones_like([54], dtype=tf.float32)

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>

In [4]:
# Toy single-target regression data, shape (4, 1) = (samples, targets).
# x holds the ground-truth values and y the predictions: they are passed as
# update_state(x, y) further down.
x = tf.constant([3.0, -0.5, 2.0, 7.0])[:, tf.newaxis]
y = tf.constant([2.5, 0.0, 2.0, 8.0])[:, tf.newaxis]

In [5]:
# Inspect the ground-truth tensor: expect a (4, 1) float32 column.
x

<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[ 3. ],
       [-0.5],
       [ 2. ],
       [ 7. ]], dtype=float32)>

In [6]:
# NOTE(review): despite the variable name, this is the R^2 (coefficient of
# determination) metric, not a Pearson correlation. The name is kept only
# because the following cells refer to it.
pearsonr = R2(summarize=False, num_targets=1)  # per-target vector output

In [7]:
# Feed the whole toy dataset as a single batch.
pearsonr.update_state(y_true=x, y_pred=y)

In [8]:
# Hand check: mean(y_true) = 2.875, SS_tot = 29.1875, SS_res = 1.5,
# so R^2 = 1 - 1.5 / 29.1875 ~= 0.9486. These toy values appear to be the
# scikit-learn r2_score docstring example, which reports the same number.
pearsonr.result()

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.94860816], dtype=float32)>