In [6]:
from tensor import Tensor, Dependency, add
import numpy as np

In [2]:
# Smoke test: build a gradient-tracking 1-D tensor, reduce it with .sum(),
# then call backward() on the result without passing a seed gradient.
# NOTE(review): presumably .sum() yields a 0-d tensor, which is why no seed
# gradient is supplied here — confirm against tensor.py.
t1 = Tensor([1, 2, 3], requires_grad=True)
t2 = t1.sum()
t2.backward()

In [3]:
# Scratch check: np.ones_like accepts a plain Python list and returns an
# array of ones with the same shape (recorded output: array([1])).
np.ones_like([1])

array([1])

In [None]:
# Scratch check: values built with np.array are instances of np.ndarray.
isinstance(np.array([1, 1]), np.ndarray)

True

In [None]:
def grad_fn(grad: np.ndarray) -> np.ndarray:
    """Spread an upstream gradient across a fixed three-element input.

    The upstream gradient is multiplied by an integer vector of three ones,
    so NumPy broadcasting decides the result: a 0-d gradient is repeated for
    each of the three elements, while a length-3 gradient passes through
    element-wise.

    Args:
        grad: Upstream gradient to propagate.

    Returns:
        ``grad`` broadcast against a ones vector of shape ``(3,)``.
    """
    # The literal only supplies the shape (3,); ones_like discards its values.
    unit_weights = np.ones_like([3, 3, 3])
    return grad * unit_weights

In [None]:
# Hand-wire a small dependency graph to exercise backward() through shared
# nodes. Every edge uses the same grad_fn, and no forward op is run: the
# tensor values are supplied as literals.
#
# Edges (child -> parents):
#   t3 -> t1
#   t4 -> t2
#   t5 -> t3, t4
#   t6 -> t2, t5
#   t7 -> t6, t5
# t2 and t5 each have more than one consumer, so they are reached along
# several paths from t7.
t1 = Tensor([3.0, 2.0, 4.0], requires_grad=True)
t2 = Tensor([3.0, 3.0, 3.0], requires_grad=True)
t3 = Tensor([3.0, 3.0, 3.0], requires_grad=True, depends_on=[Dependency(t1, grad_fn)])
t4 = Tensor([3.0, 3.0, 3.0], requires_grad=True, depends_on=[Dependency(t2, grad_fn)])
t5 = Tensor(
    [3.0, 3.0, 3.0],
    requires_grad=True,
    depends_on=[Dependency(t3, grad_fn), Dependency(t4, grad_fn)],
)
t6 = Tensor(
    [3.0, 3.0, 3.0],
    requires_grad=True,
    depends_on=[Dependency(t2, grad_fn), Dependency(t5, grad_fn)],
)
t7 = Tensor(
    [3.0, 3.0, 3.0],
    requires_grad=True,
    depends_on=[Dependency(t6, grad_fn), Dependency(t5, grad_fn)],
)
# Seed the backward pass from t7 with an explicit gradient of [9, 9, 9].
t7.backward(Tensor([9, 9, 9]))


In [29]:
# Display the gradient stored on t5. The recorded output is [18, 18, 18].
# NOTE(review): presumably 9 arriving directly from t7 plus 9 arriving via
# t6 — confirm against Tensor.backward's accumulation logic.
t5.grad

Tensor([18. 18. 18.], requires_grad=False)

In [13]:
# Sum the values held in t7.data (the tensor's data, not its gradient).
t7.data.sum()

np.float64(9.0)

In [2]:
# Forward broadcasting check for add(): a (2, 1) column plus a (2, 2) matrix.
# The recorded output is [[4 5], [7 8]], i.e. the column is added to every
# column of the matrix.
t1 = Tensor(np.array([[1], [2]]), requires_grad=True)  # Shape (2,1)
t2 = Tensor(np.array([[3, 4], [5, 6]]), requires_grad=True)  # Shape (2,2)

result = add(t1, t2)
result

Tensor([[4 5]
 [7 8]], requires_grad=True)

In [9]:
# Backward broadcasting check for add(): a scalar operand added to a
# 3-element vector, with an explicit seed gradient of ones.
t1 = Tensor([1, 2, 3], requires_grad=True)
t2 = Tensor(10, requires_grad=True)  # Scalar
t3 = add(t1, t2)
t3.backward(Tensor([1, 1, 1]))  # Passing a gradient of [1, 1, 1]

# The recorded output is 3.0.
# NOTE(review): presumably the seed gradient summed over the axis the scalar
# was broadcast along — confirm against add()'s grad_fn in tensor.py.
t2.grad.data.tolist()

3.0

In [None]:
from typing import Callable
import numpy as np


def make_grad_fn(
    original_shape: tuple, other_data: np.ndarray, chain_rule_fn: Callable
) -> Callable[[np.ndarray], np.ndarray]:
    """Create a backward function that applies a chain rule and undoes broadcasting.

    The returned closure first maps the upstream gradient through
    ``chain_rule_fn`` and then, if the result's shape differs from
    ``original_shape``, sums it down so that it matches the operand's shape.

    Args:
        original_shape: Shape of the operand the gradient is computed for.
        other_data: Data of the other operand, handed to ``chain_rule_fn``.
        chain_rule_fn: Callable taking ``(grad, other_data)`` and returning the
            local gradient, e.g. ``lambda g, x: g * x`` for multiplication.

    Returns:
        A function mapping an upstream gradient array to a gradient array for
        the operand of shape ``original_shape``.
    """

    def grad_fn(grad: np.ndarray) -> np.ndarray:
        local_grad = chain_rule_fn(grad, other_data)

        # Nothing was broadcast: hand the local gradient back untouched.
        if local_grad.shape == original_shape:
            return local_grad

        operand_rank = len(original_shape)

        # Walk the axes from the trailing one backwards (offset 1 == axis -1).
        # An axis is reduced when the operand never had it (offset beyond the
        # operand's rank) or when the operand had size 1 there.
        reduce_axes = tuple(
            -offset
            for offset in range(1, local_grad.ndim + 1)
            if offset > operand_rank or original_shape[-offset] == 1
        )

        if reduce_axes:
            # keepdims keeps the remaining axes aligned with original_shape.
            local_grad = local_grad.sum(axis=reduce_axes, keepdims=True)

        # Leading axes the operand never had are now size 1; drop them.
        if local_grad.ndim > operand_rank:
            local_grad = local_grad.reshape(original_shape)

        return local_grad

    return grad_fn

In [None]:
# Module-level inputs read by the debug copy of grad_fn defined later in this
# cell: the operand shape to reduce to, and a placeholder for the other
# operand's data (the chain_rule_fn in this cell ignores its second argument).
original_shape = (3,)
other_data = ()


def chain_rule_fn(g, x):
    """Identity chain rule: return the upstream gradient ``g`` unchanged.

    ``x`` (the other operand's data) is accepted to match the two-argument
    chain-rule signature but is not used.
    """
    return g


def grad_fn(grad: np.ndarray) -> np.ndarray:
    """Debug copy of the broadcasting-aware backward function, with tracing.

    Reads the module-level ``chain_rule_fn``, ``other_data`` and
    ``original_shape``. When the chain-rule result does not already have
    ``original_shape``, it prints the two shapes, the axes chosen for
    reduction, and the final reduced gradient.

    Args:
        grad: Upstream gradient array.

    Returns:
        The chain-rule result, summed and reshaped towards ``original_shape``
        when the shapes differ.
    """
    local_grad = chain_rule_fn(grad, other_data)

    # Shapes already agree: nothing to reduce and nothing to trace.
    if local_grad.shape == original_shape:
        return local_grad

    operand_rank = len(original_shape)
    print(local_grad.shape, original_shape)

    # Collect axes from the trailing one backwards (offset 1 == axis -1).
    # Reduce an axis when the operand never had it, or had size 1 there.
    sum_axes = []
    for offset in range(1, local_grad.ndim + 1):
        if offset > operand_rank or original_shape[-offset] == 1:
            sum_axes.append(-offset)
    print(sum_axes)

    if sum_axes:
        local_grad = local_grad.sum(axis=tuple(sum_axes), keepdims=True)

    # Drop the leftover size-1 axes that the operand never had.
    if local_grad.ndim > operand_rank:
        local_grad = local_grad.reshape(original_shape)

    print(local_grad)
    return local_grad

In [24]:
# Feed a 3x3x3 integer array filled with 3s through the debug grad_fn and
# display the reduced gradient.
upstream_shape = (3, 3, 3)
grad = grad_fn(np.full(upstream_shape, 3))
grad

(3, 3, 3) (3,)
[-2, -3]
[[[27 27 27]]]
[27 27 27]


array([27, 27, 27])