In [1]:
from compgraph.nodes import *
import compgraph as cg
import numpy as np

# Finding out the correct implementation for `max_grad`

## First Attempt

In [2]:
def max(array, axis=None, keepdims=False, name=None):
    """Wrap ``np.max(array, axis=axis, keepdims=keepdims)`` in an OperationalNode.

    First attempt: nothing besides the value is recorded on the node, so
    ``max_grad`` below cannot tell which axis was reduced.
    Non-Node inputs are first wrapped with ``ConstantNode.create_using``.

    Note: this definition shadows the built-in ``max`` in the notebook.
    """
    if not isinstance(array, Node):
        array = ConstantNode.create_using(array)
    opvalue = np.max(array, axis=axis, keepdims=keepdims)
    opnode = OperationalNode.create_using(opvalue, 'max', array, name=name)

    return opnode

def max_grad(prev_adjoint, node):
    """Gradient of ``max`` w.r.t. its operand -- first, incomplete attempt.

    Builds a 0/1 mask of the operand elements equal to ``node`` (the reduced
    maximum, broadcast against the operand) and divides it by the number of
    ties, so the incoming gradient is shared among the tied maxima.

    Known limitation: the tie count is summed over the *whole* operand rather
    than per reduced slice, so it is essentially only right for ``axis=None``
    (the ``axis=0`` example below yields 0.25 instead of 0.5 / 1).

    Returns ``[gradient w.r.t. operand_a, None]``.
    """
    doperand_a = np.where(node.operand_a == node, 1, 0)
    # Counts ties across the entire array -- the source of the axis bug.
    normalizers = np.sum(doperand_a, keepdims=True)
    normalized_doperand_a = doperand_a / normalizers

    return [prev_adjoint * normalized_doperand_a, None]

### Works for an array like `[1, 4, 4]`

In [4]:
x = cg.variable(np.asarray([1, 4, 4]))
max_x = max(x, axis=None)  # reduce over every element -> a single maximum
# The two tied maxima share the gradient equally.
print(max_grad(prev_adjoint=1., node=max_x))  # prints [array([0. , 0.5, 0.5]), None]

[array([0. , 0.5, 0.5]), None]


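### But it doesn't work for `[[0, 1, 4], [0, 7, 1]]` along the first axis (`axis=0`)

The normalizer counts the tied maxima over the whole array instead of per column, so every tied element gets `1/4` of the gradient.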

In [7]:
x = cg.variable(np.array([[0, 1, 4], [0, 7, 1]]))
max_x = max(x, axis=0)
print(max_x)  # prints [0 7 4]

# The first attempt prints
#   [array([[0.25, 0.  , 0.25],
#          [0.25, 0.25, 0.  ]]), None]
# while the correct gradient is
#   [array([[0.5, 0. , 1. ],
#          [0.5, 1. , 0. ]]), None]
# i.e. ties must be counted per column, not over the whole array.
print(max_grad(1, max_x))

[0 7 4]
[array([[0.25, 0.  , 0.25],
       [0.25, 0.25, 0.  ]]), None]


## Second Attempt

We save the axis along which the `max` is taken so that `max_grad` can use it when computing the normalizers (the number of tied maxima in each reduced slice).

In [25]:
def max(array, axis=None, keepdims=False, name=None):
    """Wrap ``np.max(array, axis=axis, keepdims=keepdims)`` in an OperationalNode.

    Second attempt: additionally stores the reduction ``axis`` on the node
    so that ``max_grad`` can count ties per reduced slice.
    Non-Node inputs are first wrapped with ``ConstantNode.create_using``.

    Note: this definition shadows the built-in ``max`` in the notebook.
    """
    if not isinstance(array, Node):
        array = ConstantNode.create_using(array)
    opvalue = np.max(array, axis=axis, keepdims=keepdims)
    opnode = OperationalNode.create_using(opvalue, 'max', array, name=name)

    # save info for gradient computation
    # (axis=None means the max was taken over all elements)
    opnode.axis = axis

    return opnode

def max_grad(prev_adjoint, node):
    """Gradient of ``max`` w.r.t. its operand -- second attempt.

    Ties are now counted per reduced slice (``axis=node.axis``), so each
    slice's gradient is shared only among that slice's maxima.

    Known limitation: the mask still compares the operand with the *reduced*
    result ``node``. In the examples below that broadcasts for ``axis=0``
    (result ``(3,)`` vs operand ``(2, 3)``), and it would for ``axis=None``
    or ``keepdims=True``, but not for ``axis=1``, where the ``(2,)`` result
    cannot be broadcast against the ``(2, 3)`` operand.

    Returns ``[gradient w.r.t. operand_a, None]``.
    """
    doperand_a = np.where(node.operand_a == node, 1, 0)
    # Number of tied maxima in each reduced slice, kept broadcastable.
    normalizers = np.sum(doperand_a, axis=node.axis, keepdims=True)
    normalized_doperand_a = doperand_a / normalizers

    return [prev_adjoint * normalized_doperand_a, None]

### Now it works for `[[0, 1, 4], [0, 7, 1]]` along the first axis


In [26]:
x = cg.variable(np.array([[0, 1, 4], [0, 7, 1]]))
max_x = max(x, axis=0)
print(max_x)  # prints [0 7 4]

# prints the correct gradient:
#   [array([[0.5, 0. , 1. ],
#          [0.5, 1. , 0. ]]), None]
print(max_grad(1, max_x))

[0 7 4]
[array([[0.5, 0. , 1. ],
       [0.5, 1. , 0. ]]), None]


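### But it doesn't work when we shift the axis to `1`

With `axis=1` the reduced result has shape `(2,)` while the operand has shape `(2, 3)`, so the comparison `node.operand_a == node` cannot broadcast and this `max_grad` fails. Note that the output recorded below is misleading: this cell was executed as `In [31]`, i.e. *after* the final attempt (`In [29]`) had already redefined `max` and `max_grad`, so it shows the final attempt's (correct) result.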

In [31]:
max_x_1 = max(x, axis=1)
print(max_x_1)  # prints [4 7]

# With the second attempt, `node.operand_a` (shape (2, 3)) cannot be broadcast
# against the reduced `node` (shape (2,)), so `max_grad` is expected to raise a
# ValueError (on older NumPy, an AxisError, which subclasses ValueError).
# The failure is the point of this cell, so catch and show it instead of
# letting it abort "Restart & Run All" before the final attempt.
try:
    print(max_grad(1, max_x_1))
except ValueError as err:
    print(f"max_grad failed: {type(err).__name__}: {err}")

[4 7]
[array([[0., 0., 1.],
       [0., 1., 0.]]), None]


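## Final Attempt

Instead of comparing the operand with the reduced result, `max_grad` compares it with the same maximum computed with `keepdims=True`. Keeping the reduced axes as size-1 dimensions makes the comparison broadcast against the operand for any `axis`.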

In [29]:
def max(array, axis=None, keepdims=False, name=None):
    """Wrap ``np.max(array, axis=axis, keepdims=keepdims)`` in an OperationalNode.

    Final attempt. Two extra attributes are attached for ``max_grad``:

    * ``axis`` -- the reduction axis (``None`` = all elements);
    * ``with_keepdims`` -- the same maximum with the reduced axes kept as
      size-1 dimensions, so it broadcasts against the operand for any axis.

    Non-Node inputs are wrapped with ``ConstantNode.create_using`` first.
    Note: this definition shadows the built-in ``max`` in the notebook.
    """
    operand = array if isinstance(array, Node) else ConstantNode.create_using(array)

    result = OperationalNode.create_using(
        np.max(operand, axis=axis, keepdims=keepdims), 'max', operand, name=name
    )
    # Metadata consumed by max_grad.
    result.axis = axis
    result.with_keepdims = np.max(operand, axis=axis, keepdims=True)
    return result

def max_grad(prev_adjoint, node):
    """Gradient of ``max`` with respect to its single operand.

    The incoming gradient of each reduced slice is routed to the elements
    equal to that slice's maximum and split evenly among ties.

    Parameters
    ----------
    prev_adjoint : scalar or array_like
        Gradient flowing into ``node``: either a scalar, or an array with the
        shape of ``node``'s value (which lacks the reduced axes unless
        ``keepdims=True`` was used).
    node : OperationalNode
        Node produced by ``max``; must carry ``operand_a``, ``axis`` and
        ``with_keepdims``.

    Returns
    -------
    list
        ``[gradient w.r.t. operand_a, None]``.
    """
    # Compare against the keepdims maximum so the mask broadcasts for any axis.
    doperand_a = np.where(node.operand_a == node.with_keepdims, 1, 0)
    # Number of tied maxima in each reduced slice (>= 1 for NaN-free input).
    normalizers = np.sum(doperand_a, axis=node.axis, keepdims=True)
    normalized_doperand_a = doperand_a / normalizers

    # A non-scalar adjoint has the reduced shape (e.g. (2,) for axis=1 on a
    # (2, 3) operand), which does not broadcast against the operand-shaped
    # mask. Restore the reduced axes as size-1 dims so each slice's adjoint
    # lines up with that slice. Scalars already broadcast and are left alone.
    if np.ndim(prev_adjoint) > 0:
        prev_adjoint = np.reshape(prev_adjoint, np.shape(node.with_keepdims))

    return [prev_adjoint * normalized_doperand_a, None]

### Now it works when we shift the axis to `1`

In [32]:
max_x_1 = max(array=x, axis=1)
print(max_x_1)  # row-wise maxima: [4 7]
# Each row's gradient goes to that row's maximum.
print(max_grad(prev_adjoint=1, node=max_x_1))

[4 7]
[array([[0., 0., 1.],
       [0., 1., 0.]]), None]
