In [51]:
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Iterator, Optional, Protocol, Union
import numpy as np
from matplotlib import pyplot as plt
from functools import wraps
import warnings
IS_CI = os.getenv("IS_CI")
Arr = np.ndarray
grad_tracking_enabled = True

import torch
import torch.utils.data
from torchvision import datasets, transforms
from tqdm.auto import tqdm

In [45]:
TEST_FN_PASSED = {}


def run_and_report(test_func: Callable, name: str, *test_func_args, **test_func_kwargs):
    start = time.time()
    out = test_func(*test_func_args, **test_func_kwargs)
    elapsed = time.time() - start
    print(f"{name} passed in {elapsed:.2f}s.")
    if not TEST_FN_PASSED.get(name):
        report_success(name)
        TEST_FN_PASSED[name] = True
    return out

def report(test_func):
    name = f"{test_func.__module__}.{test_func.__name__}"
    # This can happen when using autoreload, so don't complain about it.
    # if name in TEST_FN_PASSED:
    #     raise KeyError(f"Already registered: {name}")
    TEST_FN_PASSED[name] = False

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        return run_and_report(test_func, name, *args, **kwargs)

    return wrapper

@report
def test_log_back(log_back):
    a = np.array([1, np.e, np.e**np.e])
    b = np.log(a)
    grad_out = np.array([2.0, 2.0, 2.0])
    actual = log_back(grad_out, b, a)
    expected = [2.0, 2.0 / np.e, 2.0 / (np.e**np.e)]
    assert np.allclose(actual, expected)
    

def report_success(testname):
    """POST to the server indicating success at the given test.

    Used to help the TAs know how long each section takes to complete.
    """
    server = os.environ.get("MLAB_SERVER")
    email = os.environ.get("MLAB_EMAIL")
    if server:
        if email:
            r = requests.post(
                server + "/api/report_success",
                json=dict(email=email, testname=testname),
            )
            if r.status_code != http.HTTPStatus.NO_CONTENT:
                raise ValueError(f"Got status code from server: {r.status_code}")
        else:
            raise ValueError(f"Server set to {server} but no MLAB_EMAIL set!")
    else:
        if email:
            raise ValueError(f"Email set to {email} but no MLAB_SERVER set!")
        else:
            return  # local dev, do nothing
        
@report
def test_unbroadcast(unbroadcast):
    small = np.ones((2, 1, 3))
    large = np.broadcast_to(small, (5, 1, 2, 4, 3))
    out = unbroadcast(large, small)
    assert out.shape == small.shape
    assert (out == 20.0).all(), "Each element in the small array appeared 20 times in the large array."

    small = np.ones((2, 1, 3))
    large = np.broadcast_to(small, (5, 1, 2, 1, 3))
    out = unbroadcast(large, small)
    assert out.shape == small.shape
    assert (out == 5.0).all(), "Each element in the small array appeared 5 times in the large array."

    small = np.ones((2, 1, 3))
    large = np.broadcast_to(small, (2, 4, 3))
    out = unbroadcast(large, small)
    assert out.shape == small.shape
    assert (out == 4.0).all(), "Each element in the small array appeared 4 times in the large array."
    
@report
def test_multiply_back(multiply_back0, multiply_back1):
    a = np.array([1, 2, 3])
    b = np.array([2])
    c = a * b
    grad_out = np.array([2.0, 2.0, 2.0])
    actual = multiply_back0(grad_out, c, a, b)
    expected = [4.0, 4.0, 4.0]
    assert np.allclose(actual, expected)
    actual = multiply_back1(grad_out, c, a, b)
    expected = [12.0]
    assert np.allclose(actual, expected)


@report
def test_multiply_back_float(multiply_back0, multiply_back1):
    a = np.array([1, 2, 3])
    b = 2
    c = a * b
    grad_out = np.array([2.0, 2.0, 2.0])
    actual = multiply_back0(grad_out, c, a, b)
    expected = [4.0, 4.0, 4.0]
    assert np.allclose(actual, expected)

    a = np.array([1, 2, 3])
    b = 2
    c = a * b
    grad_out = np.array([2.0, 2.0, 2.0])
    actual = multiply_back1(grad_out, c, b, a)
    expected = [4.0, 4.0, 4.0]
    assert np.allclose(actual, expected)
    
@report
def test_forward_and_back(forward_and_back):
    a = np.array([1, 2, 3])
    b = np.array([2, 3, 1])
    c = np.array([10])
    dg_da, dg_db, dg_dc = forward_and_back(a, b, c)
    expected_dg_da = np.array([1, 1 / 2, 1 / 3])
    expected_dg_db = np.array([1 / 2, 1 / 3, 1])
    expected_dg_dc = np.array([0.13028834])
    assert np.allclose(dg_da, expected_dg_da)
    assert np.allclose(dg_db, expected_dg_db)
    assert np.allclose(dg_dc, expected_dg_dc)
    
@report
def test_back_func_lookup(BackwardFuncLookup):
    backward_funcs = BackwardFuncLookup()
    backward_funcs.add_back_func(np.log, 0, np.exp)
    assert backward_funcs.get_back_func(np.log, 0) == np.exp
    backward_funcs.add_back_func(np.multiply, 0, np.divide)
    assert backward_funcs.get_back_func(np.multiply, 0) == np.divide
    backward_funcs.add_back_func(np.multiply, 1, np.add)
    assert backward_funcs.get_back_func(np.multiply, 1) == np.add
    


@report
def test_log(Tensor, log_forward):
    a = Tensor([np.e, np.e**np.e], requires_grad=True)
    b = log_forward(a)
    assert np.allclose(b.array, [1, np.e])
    assert b.requires_grad == True, "Should require grad because input required grad."
    assert b.is_leaf == False
    assert b.recipe is not None
    assert len(b.recipe.parents) == 1 and b.recipe.parents[0] is a
    assert len(b.recipe.args) == 1 and b.recipe.args[0] is a.array
    assert b.recipe.kwargs == {}
    assert b.recipe.func is np.log
    c = log_forward(b)
    assert np.allclose(c.array, [0, 1])


@report
def test_log_no_grad(Tensor, log_forward):
    d = Tensor([1, np.e])
    e = log_forward(d)
    assert e.requires_grad == False, "Should not require grad because input did not."
    assert e.recipe is None
    assert np.allclose(e.array, [0, 1])
    

@report
def test_multiply(Tensor, multiply):
    a = Tensor([0, 1, 2, 3], requires_grad=True)
    b = Tensor([[0], [1], [10]], requires_grad=True)
    c = multiply(a, b)
    assert c.requires_grad == True, "Should require grad because input required grad."
    assert c.is_leaf == False
    assert c.recipe is not None
    assert len(c.recipe.parents) == 2 and c.recipe.parents[0] is a and c.recipe.parents[1] is b
    assert len(c.recipe.args) == 2 and c.recipe.args[0] is a.array and c.recipe.args[1] is b.array
    assert c.recipe.kwargs == {}
    assert c.recipe.func is np.multiply
    expected = np.array([[0, 0, 0, 0], [0, 1, 2, 3], [0, 10, 20, 30]])
    np.allclose(c.array, expected)


@report
def test_multiply_no_grad(Tensor, multiply):
    a = Tensor([0, 1, 2, 3], requires_grad=False)
    b = Tensor([[0], [1], [10]], requires_grad=False)
    c = multiply(a, b)
    assert c.requires_grad == False, "Should not require grad because input did not require grad."
    assert c.recipe is None
    expected = np.array([[0, 0, 0, 0], [0, 1, 2, 3], [0, 10, 20, 30]])
    np.allclose(c.array, expected)


@report
def test_multiply_float(Tensor, multiply):
    a = Tensor([0, 1, 2, 3], requires_grad=True)
    b = 3
    c = multiply(a, b)
    assert c.requires_grad == True
    assert c.recipe is not None
    assert len(c.recipe.parents) == 1 and c.recipe.parents[0] is a
    assert len(c.recipe.args) == 2 and c.recipe.args[0] is a.array and c.recipe.args[1] is b
    assert c.recipe.kwargs == {}
    assert c.recipe.func is np.multiply
    expected = np.array([0, 3, 6, 9])
    np.allclose(c.array, expected)

    a = Tensor([0, 1, 2, 3], requires_grad=True)
    b = 3
    c = multiply(b, a)
    assert c.requires_grad == True
    assert c.recipe is not None
    assert len(c.recipe.parents) == 1 and c.recipe.parents[1] is a
    assert len(c.recipe.args) == 2 and c.recipe.args[0] is b and c.recipe.args[1] is a.array
    assert c.recipe.kwargs == {}
    assert c.recipe.func is np.multiply
    expected = np.array([0, 3, 6, 9])
    np.allclose(c.array, expected)
    
@report
def test_topological_sort_linked_list(topological_sort):
    z = Node()
    y = Node(z)
    x = Node(y)

    expected = [z, y, x]
    for e, a in zip(expected, topological_sort(x, get_children)):
        assert e is a
        
        
@report
def test_topological_sort_branching(topological_sort):
    z = Node()
    y = Node()
    x = Node(y, z)
    w = Node(x)

    name_lookup = {w: "w", x: "x", y: "y", z: "z"}
    out = "".join([name_lookup[n] for n in topological_sort(w, get_children)])
    assert out == "zyxw" or out == "yzxw"
    
@report
def test_topological_sort_rejoining(topological_sort):
    z = Node()
    y = Node(z)
    x = Node(y)
    w = Node(z, x)

    name_lookup = {w: "w", x: "x", y: "y", z: "z"}
    out = "".join([name_lookup[n] for n in topological_sort(w, get_children)])
    assert out == "zyxw"
    

@report
def test_topological_sort_cyclic(topological_sort):
    z = Node()
    y = Node(z)
    x = Node(y)
    z.children = [x]

    try:
        topological_sort(x, get_children)
    except:
        assert True
    else:
        assert False
        

@report
def test_backprop(Tensor):
    a = Tensor([np.e, np.e**np.e], requires_grad=True)
    b = a.log()
    c = b.log()
    c.backward(end_grad=np.array([1.0, 1.0]))
    assert c.grad is None
    assert b.grad is None
    assert a.grad is not None
    assert np.allclose(a.grad.array, 1 / b.array / a.array)


@report
def test_backprop_branching(Tensor):
    a = Tensor([1, 2, 3], requires_grad=True)
    b = Tensor([1, 2, 3], requires_grad=True)
    c = a * b
    c.backward(end_grad=np.array([1.0, 1.0, 1.0]))
    assert np.allclose(a.grad.array, b.array)
    assert np.allclose(b.grad.array, a.array)


@report
def test_backprop_requires_grad_false(Tensor):
    a = Tensor([1, 2, 3], requires_grad=True)
    b = Tensor([1, 2, 3], requires_grad=False)
    c = a * b
    c.backward(end_grad=np.array([1.0, 1.0, 1.0]))
    assert np.allclose(a.grad.array, b.array)
    assert b.grad is None


@report
def test_backprop_float_arg(Tensor):
    a = Tensor([1, 2, 3], requires_grad=True)
    b = 2
    c = a * b
    d = 2
    e = d * c
    e.backward(end_grad=np.array([1.0, 1.0, 1.0]))
    assert e.grad is None
    assert c.grad is None
    assert a.grad is not None
    assert np.allclose(a.grad.array, np.array([4.0, 4.0, 4.0]))
    
@report
def test_negative_back(Tensor):
    a = Tensor([-1, 0, 1], requires_grad=True)
    b = -a
    c = -b
    c.backward(end_grad=np.array([[1.0, 1.0, 1.0]]))
    assert a.grad is not None
    assert np.allclose(a.grad.array, [1, 1, 1])
    

@report
def test_exp_back(Tensor):
    a = Tensor([-1.0, 0.0, 1.0], requires_grad=True)
    b = a.exp()
    b.backward(end_grad=np.array([[1.0, 1.0, 1.0]]))
    assert a.grad is not None
    assert np.allclose(a.grad.array, 1 / np.e, 0, np.e)

    a = Tensor([-1.0, 0.0, 1.0], requires_grad=True)
    b = a.exp()
    c = b.exp()
    c.backward(end_grad=np.array([[1.0, 1.0, 1.0]]))

    def d(x):
        return (np.e**x) * (np.e ** (np.e**x))

    assert a.grad is not None
    assert np.allclose(a.grad.array, *[d(x) for x in a.array])
    
@report
def test_reshape_back(Tensor):
    a = Tensor([1, 2, 3, 4, 5, 6], requires_grad=True)
    b = a.reshape((3, 2))
    b.backward(end_grad=np.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]))
    assert a.grad is not None and np.allclose(a.grad.array, np.ones(6))
    
@report
def test_permute_back(Tensor):
    a = Tensor(np.arange(24).reshape((2, 3, 4)), requires_grad=True)
    out = a.permute((2, 0, 1))
    out.backward(np.arange(24).reshape((4, 2, 3)))

    assert a.grad is not None
    assert np.allclose(
        a.grad.array,
        np.array(
            [
                [[0.0, 6.0, 12.0, 18.0], [1.0, 7.0, 13.0, 19.0], [2.0, 8.0, 14.0, 20.0]],
                [[3.0, 9.0, 15.0, 21.0], [4.0, 10.0, 16.0, 22.0], [5.0, 11.0, 17.0, 23.0]],
            ]
        ),
    )
    
@report
def test_expand(Tensor):
    a = Tensor(np.ones((2, 1, 3)), requires_grad=True)
    b = a.expand((5, 1, 2, 4, 3))
    b.backward(np.full_like(b.array, 10.0))
    assert a.grad is not None and a.grad.shape == a.array.shape
    assert (a.grad.array == 20 * 10.0).all()


@report
def test_expand_negative_length(Tensor):
    a = Tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
    b = a.expand((3, 2, -1))
    assert b.shape == (3, 2, 5)
    b.backward(end_grad=np.ones(b.shape))
    assert a.grad is not None and a.grad.shape == a.array.shape
    assert (a.grad.array == 6).all()
    
@report
def test_sum_keepdim_false(Tensor):
    a = Tensor(np.array([[0.0, 1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0, 9.0]]), requires_grad=True)
    b = a.sum(0)
    c = b.sum(0)
    c.backward(np.array(2))
    assert a.grad is not None
    assert a.grad.shape == a.shape
    assert (a.grad.array == 2).all()


@report
def test_sum_keepdim_true(Tensor):
    a = Tensor(np.array([[0.0, 1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0, 9.0]]), requires_grad=True)
    b = a.sum(1, keepdim=True)
    c = a.sum(0, keepdim=True)
    assert np.allclose(c.array, np.array([[5.0, 7.0, 9.0, 11.0, 13.0]]))
    c.backward(end_grad=np.ones(c.shape))
    assert a.grad is not None
    assert a.grad.shape == a.shape
    assert (a.grad.array == 1).all()


@report
def test_sum_dim_none(Tensor):
    a = Tensor(np.array([[0.0, 1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0, 9.0]]), requires_grad=True)
    b = a.sum()
    b.backward(np.array(4))
    assert a.grad is not None
    assert a.grad.shape == a.shape
    assert (a.grad.array == 4).all()
    

@report
def test_getitem_int(Tensor):
    a = Tensor([[0, 1, 2], [3, 4, 5]], requires_grad=True)
    b = a[1]
    c = b.sum(0)
    c.backward(np.array(10.0))
    assert a.grad is not None and np.allclose(a.grad.array, np.array([[0, 0, 0], [10, 10, 10]]))


@report
def test_getitem_tuple(Tensor):
    a = Tensor([[0, 1, 2], [3, 4, 5]], requires_grad=True)
    b = a[(1, 2)]
    b.backward(np.array(10.0))
    assert a.grad is not None and np.allclose(a.grad.array, np.array([[0, 0, 0], [0, 0, 10]]))


@report
def test_getitem_integer_array(Tensor):
    a = Tensor([[0, 1, 2], [3, 4, 5]], requires_grad=True)
    index = np.array([0, 1, 0, 1, 0]), np.array([0, 0, 1, 2, 0])
    out = a[index]
    out.sum().backward(np.array(10.0))
    assert a.grad is not None
    assert np.allclose(a.grad.array, np.array([[20, 10, 0], [10, 0, 10]]))


@report
def test_getitem_integer_tensor(Tensor):
    a = Tensor([[0, 1, 2], [3, 4, 5]], requires_grad=True)
    index = Tensor(np.array([0, 1, 0, 1, 0])), Tensor(np.array([0, 0, 1, 2, 0]))
    out = a[index]
    out.sum().backward(np.array(10.0))
    assert a.grad is not None
    assert np.allclose(a.grad.array, np.array([[20, 10, 0], [10, 0, 10]]))
    

@report
def test_add_broadcasted(Tensor):
    a = Tensor([0, 1, 2, 3], requires_grad=True)
    b = Tensor([[0], [1], [10]], requires_grad=True)
    c = a + b
    c.backward(end_grad=np.ones(c.shape))
    assert a.grad is not None
    assert a.grad.shape == a.shape
    assert (a.grad.array == 3).all()
    assert b.grad is not None
    assert b.grad.shape == b.shape
    assert (b.grad.array == 4).all()


@report
def test_subtract_broadcasted(Tensor):
    a = Tensor([0, 1, 2, 3], requires_grad=True)
    b = Tensor([[0], [1], [10]], requires_grad=True)
    c = a - b
    c.backward(end_grad=np.ones(c.shape))
    assert a.grad is not None
    assert a.grad.shape == a.shape
    assert (a.grad.array == 3).all()
    assert b.grad is not None
    assert b.grad.shape == b.shape
    assert (b.grad.array == -4).all()


@report
def test_truedivide_broadcasted(Tensor):
    a = Tensor([0, 6, 12, 18], requires_grad=True)
    b = Tensor([[1], [2], [3]], requires_grad=True)
    c = a / b
    c.backward(end_grad=np.ones(c.shape))
    assert a.grad is not None
    assert a.grad.shape == a.shape
    assert (a.grad.array == (1 + 1 / 2 + 1 / 3)).all()
    assert b.grad is not None
    assert b.grad.shape == b.shape
    assert np.equal(b.grad.array, np.array([[-36.0], [-9.0], [-4.0]])).all()
    
@report
def test_maximum(Tensor):
    a = Tensor([0, 1, 2], requires_grad=True)
    b = Tensor([-1, 1, 3], requires_grad=True)
    out = a.maximum(b)
    assert np.allclose(out.array, [0, 1, 3])
    out.backward(end_grad=np.ones(out.shape))

    assert a.grad is not None
    assert b.grad is not None
    assert np.allclose(a.grad.array, [1, 0.5, 0])
    assert np.allclose(b.grad.array, [0, 0.5, 1])


@report
def test_maximum_broadcasted(Tensor):
    a = Tensor([0, 1, 2], requires_grad=True)
    b = Tensor([[-1], [1], [3]], requires_grad=True)
    out = a.maximum(b)
    assert np.allclose(out.array, np.array([[0, 1, 2], [1, 1, 2], [3, 3, 3]]))
    out.backward(end_grad=np.ones(out.shape))
    assert a.grad is not None and np.allclose(a.grad.array, np.array([1.0, 1.5, 2.0]))
    assert b.grad is not None and np.allclose(b.grad.array, np.array([[0.0], [1.5], [3.0]]))
    

@report
def test_relu(Tensor):
    a = Tensor([-1, 0, 1], requires_grad=True)
    out = a.relu()
    out.backward(end_grad=np.ones(out.shape))
    assert a.grad is not None and np.allclose(a.grad.array, np.array([0, 0.5, 1.0]))
    

@report
def test_matmul2d(Tensor):
    a = Tensor(np.arange(-3, 3).reshape((2, 3)), requires_grad=True)
    b = Tensor(np.arange(-4, 5).reshape((3, 3)), requires_grad=True)
    out = a @ b
    out.backward(end_grad=np.ones(out.shape))
    assert a.grad is not None
    assert b.grad is not None
    assert np.allclose(a.grad.array, np.array([[-9, 0, 9], [-9, 0, 9]]))
    assert np.allclose(b.grad.array, np.array([[-3, -3, -3], [-1, -1, -1], [1, 1, 1]]))
    
@report
def test_cross_entropy(Tensor, cross_entropy):
    logits = Tensor([[float("-inf"), float("-inf"), 0], [1 / 3, 1 / 3, 1 / 3], [float("-inf"), 0, 0]])
    true_labels = Tensor([2, 0, 0])
    expected = Tensor([0.0, np.log(3), float("inf")])
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        actual = cross_entropy(logits, true_labels)
    assert np.allclose(actual.array, expected.array)
    

@report
def test_no_grad(Tensor, NoGrad):
    a = Tensor([1], requires_grad=True)
    with NoGrad():
        b = a + a
    c = a + a
    assert b.requires_grad == False
    assert b.recipe is None
    assert c.requires_grad
    assert c.recipe is not None


@report
def test_no_grad_nested(Tensor, NoGrad):
    a = Tensor([1], requires_grad=True)
    with NoGrad():
        with NoGrad():
            with NoGrad():
                b = a + a
    assert b.requires_grad == False
    assert b.recipe is None

In [3]:
def log_back(grad_out : Arr, Out : Arr, x : Arr) -> Arr:
    
    """Backwards function for f(x) = log(x)
    
    grad_out: Gradient of some loss wrt out
    out: the output of np.log(x). Provided as an optimization in case it's cheaper to express the gradient in terms of the output.
    x: the input of np.log.
    
    Return: gradient of the given loss wrt x
    """
    return (1 / x) * grad_out
    
test_log_back(log_back)

__main__.test_log_back passed in 0.00s.


In [4]:
def unbroadcast(broadcasted : Arr, original : Arr) -> Arr:
    
    """Sum 'broadcasted' until it has the shape of 'original'.
    broadcasted: An array that was formerly of the same shape of 'original' and was expanded by broadcasting rules.
    """
    
    # 1) Sum and remove dimensions that were prepended to the front of the original shape.
    prepended_dims = broadcasted.ndim - original.ndim
    if prepended_dims > 0:
        broadcasted = broadcasted.sum(axis = tuple(range(prepended_dims)))
        
    # 2) Sum dimensions that were originally 1 back to the size 1 (using keepdims=True).
    broadcasted_dims = tuple([axis for axis, size in enumerate(original.shape) if size == 1 and broadcasted.shape[axis] > 1])
    broadcasted = broadcasted.sum(axis = broadcasted_dims, keepdims = True)
    assert broadcasted.shape == original.shape
    return broadcasted

test_unbroadcast(unbroadcast)

__main__.test_unbroadcast passed in 0.00s.


In [5]:
def multiply_back0(grad_out : Arr, out : Arr, x : Arr,
                  y : Union[Arr, float]) -> Arr:
    """Backwards function for x * y wrt argument 0 aka x."""
    op = unbroadcast(grad_out * y, x)
    return op
    

def multiply_back1(grad_out : Arr, out : Arr, x : Union[Arr, float],
                  y : Arr) -> Arr:
    """Backwards function for x * y wrt argument 1 aka y."""
    op = unbroadcast(grad_out * x, y)
    return op

test_multiply_back(multiply_back0, multiply_back1)
test_multiply_back_float(multiply_back0, multiply_back1)

__main__.test_multiply_back passed in 0.00s.
__main__.test_multiply_back_float passed in 0.00s.


In [6]:
def forward_and_back(a : Arr, b : Arr, c : Arr) -> tuple[Arr, Arr, Arr]:
    
    """
    Calculates the output of the computational graph above (g), then backpropogates the gradients and returns dg/da, dg/db, and dg/dc
    """
    d = a * b
    e = np.log(c)
    f = d * e
    g = np.log(f)
    
    dg_df = log_back(1, g, f)
    df_dd = multiply_back0(dg_df, f, d, e)
    df_de = multiply_back1(dg_df, f, d, e)
    dg_da = multiply_back0(df_dd, d, a, b)
    dg_db = multiply_back1(df_dd, d, a, b)
    dg_dc = log_back(df_de, f, c)
    return dg_da, dg_db, dg_dc

test_forward_and_back(forward_and_back)

__main__.test_forward_and_back passed in 0.00s.


In [7]:
@dataclass(frozen=True)
class Recipe:
    """Extra information necessary to run backpropagation. You don't need to modify this."""

    func: Callable
    "The 'inner' NumPy function that does the actual forward computation."
    args: tuple
    "The input arguments passed to func."
    kwargs: dict[str, Any]
    "Keyword arguments passed to func. To keep things simple today, we aren't going to backpropagate with respect to these."
    parents: dict[int, "Tensor"]
    "Map from positional argument index to the Tensor at that position, in order to be able to pass gradients back along the computational graph."



In [8]:
class BackwardFuncLookup:
    
    def __init__(self) -> None:
        self.back_funcs : defaultdict[Callable, defaultdict[int, Callable]] = defaultdict(dict)
    
    def add_back_func(self, forward_fn : Callable,
                     arg_position : int, back_fn : Callable) -> None:
        self.back_funcs[forward_fn][arg_position] = back_fn
    
    def get_back_func(self, forward_fn : Callable,
                     arg_position: int) -> Callable:
        return self.back_funcs[forward_fn][arg_position]
    
test_back_func_lookup(BackwardFuncLookup)
BACK_FUNCS = BackwardFuncLookup()
BACK_FUNCS.add_back_func(np.log, 0, log_back)
BACK_FUNCS.add_back_func(np.multiply, 0, multiply_back0)
BACK_FUNCS.add_back_func(np.multiply, 1, multiply_back1)

__main__.test_back_func_lookup passed in 0.00s.


In [9]:
class Tensor:
    """
    A drop-in replacement for torch.Tensor supporting a subset of features.
    """

    array: Arr
    "The underlying array. Can be shared between multiple Tensors."
    requires_grad: bool
    "If True, calling functions or methods on this tensor will track relevant data for backprop."
    grad: Optional["Tensor"]
    "Backpropagation will accumulate gradients into this field."
    recipe: Optional[Recipe]
    "Extra information necessary to run backpropagation."

    def __init__(self, array: Union[Arr, list], requires_grad=False):
        self.array = array if isinstance(array, Arr) else np.array(array)
        self.requires_grad = requires_grad
        self.grad = None
        self.recipe = None
        "If not None, this tensor's array was created via recipe.func(*recipe.args, **recipe.kwargs)."

    def __neg__(self) -> "Tensor":
        return negative(self)

    def __add__(self, other) -> "Tensor":
        return add(self, other)

    def __radd__(self, other) -> "Tensor":
        return add(other, self)

    def __sub__(self, other) -> "Tensor":
        return subtract(self, other)

    def __rsub__(self, other):
        return subtract(other, self)

    def __mul__(self, other) -> "Tensor":
        return multiply(self, other)

    def __rmul__(self, other):
        return multiply(other, self)

    def __truediv__(self, other):
        return true_divide(self, other)

    def __rtruediv__(self, other):
        return true_divide(self, other)

    def __matmul__(self, other):
        return matmul(self, other)

    def __rmatmul__(self, other):
        return matmul(other, self)

    def __eq__(self, other):
        return eq(self, other)

    def __repr__(self) -> str:
        return f"Tensor({repr(self.array)}, requires_grad={self.requires_grad})"

    def __len__(self) -> int:
        if self.array.ndim == 0:
            raise TypeError
        return self.array.shape[0]

    def __hash__(self) -> int:
        return id(self)

    def __getitem__(self, index) -> "Tensor":
        return getitem(self, index)

    def add_(self, other: "Tensor", alpha: float = 1.0) -> "Tensor":
        add_(self, other, alpha=alpha)
        return self

    @property
    def T(self) -> "Tensor":
        return permute(self)

    def item(self):
        return self.array.item()

    def sum(self, dim=None, keepdim=False):
        return sum(self, dim=dim, keepdim=keepdim)

    def log(self):
        return log(self)

    def exp(self):
        return exp(self)

    def reshape(self, new_shape):
        return reshape(self, new_shape)

    def expand(self, new_shape):
        return expand(self, new_shape)

    def permute(self, dims):
        return permute(self, dims)

    def maximum(self, other):
        return maximum(self, other)

    def relu(self):
        return relu(self)

    def argmax(self, dim=None, keepdim=False):
        return argmax(self, dim=dim, keepdim=keepdim)

    def uniform_(self, low: float, high: float) -> "Tensor":
        self.array[:] = np.random.uniform(low, high, self.array.shape)
        return self

    def backward(self, end_grad: Union[Arr, "Tensor", None] = None) -> None:
        if isinstance(end_grad, Arr):
            end_grad = Tensor(end_grad)
        return backprop(self, end_grad)

    def size(self, dim: Optional[int] = None):
        if dim is None:
            return self.shape
        return self.shape[dim]

    @property
    def shape(self):
        return self.array.shape

    @property
    def ndim(self):
        return self.array.ndim

    @property
    def is_leaf(self):
        """Same as https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf.html"""
        if self.requires_grad and self.recipe and self.recipe.parents:
            return False
        return True

    def __bool__(self):
        if np.array(self.shape).prod() != 1:
            raise RuntimeError("bool value of Tensor with more than one value is ambiguous")
        return bool(self.item())


def empty(*shape: int) -> Tensor:
    """Like torch.empty."""
    return Tensor(np.empty(shape))


def zeros(*shape: int) -> Tensor:
    """Like torch.zeros."""
    return Tensor(np.zeros(shape))


def arange(start: int, end: int, step=1) -> Tensor:
    """Like torch.arange(start, end)."""
    return Tensor(np.arange(start, end, step=step))


def tensor(array: Arr, requires_grad=False) -> Tensor:
    """Like torch.tensor."""
    return Tensor(array, requires_grad=requires_grad)
    
    

In [10]:
def log_forward(x : Tensor) -> Tensor:
    output = Tensor(np.log(x.array))
    if grad_tracking_enabled and (x.requires_grad or x.recipe is not None):
        output.recipe = Recipe(np.log, (x.array,), {}, {0:x})
        output.requires_grad = True
        
    return output
    

log = log_forward
test_log(Tensor, log_forward)
test_log_no_grad(Tensor, log_forward)
a = Tensor([1], requires_grad = True)
grad_tracking_enabled = False
b = log_forward(a)
grad_tracking_enabled = True
assert not b.requires_grad, "should not require grad if grad tracking globally disabled"
assert b.recipe is None, "should not create recipe if grad tracking globally disabled"


__main__.test_log passed in 0.00s.
__main__.test_log_no_grad passed in 0.00s.


In [11]:
def multiply_forward(a : Union[Tensor, int], b : Union[Tensor, int]) -> Tensor:
    assert isinstance(a, Tensor) or isinstance(b, Tensor)
    a_arr = a.array if isinstance(a, Tensor) else a
    b_arr = b.array if isinstance(b, Tensor) else b
    out_arr = a_arr * b_arr
    output = Tensor(out_arr)
    assert isinstance(out_arr, np.ndarray)
    if grad_tracking_enabled and (
        isinstance(a, Tensor)
        and (a.requires_grad or a.recipe is not None)
        or isinstance(b, Tensor)
        and (b.requires_grad or b.recipe is not None)
    ):
        parents = {}
        if isinstance(a, Tensor):
            parents[0] = a
        if isinstance(b, Tensor):
            parents[1] = b
            
        output.recipe = Recipe(np.multiply, (a_arr, b_arr), {}, parents)
        output.requires_grad = True
        
    return output
    


multiply = multiply_forward
test_multiply(Tensor, multiply_forward)
test_multiply_no_grad(Tensor, multiply_forward)
test_multiply_float(Tensor, multiply_forward)
a = Tensor([2], requires_grad=True)
b = Tensor([3], requires_grad=True)
grad_tracking_enabled = False
b = multiply_forward(a, b)
grad_tracking_enabled = True
assert not b.requires_grad, "should not require grad if grad tracking globally disabled"
assert b.recipe is None, "should not create recipe if grad tracking globally disabled"

__main__.test_multiply passed in 0.00s.
__main__.test_multiply_no_grad passed in 0.00s.
__main__.test_multiply_float passed in 0.00s.


In [12]:
def wrap_forward_fn(numpy_func: Callable, is_differentiable=True) -> Callable:
    """
    numpy_func: function. It takes any number of positional arguments, some of which may be NumPy arrays, and any number of keyword arguments which we aren't allowing to be NumPy arrays at present. It returns a single NumPy array.
    is_differentiable: if True, numpy_func is differentiable with respect to some input argument, so we may need to track information in a Recipe. If False, we definitely don't need to track information.

    Return: function. It has the same signature as numpy_func, except wherever there was a NumPy array, this has a Tensor instead.
    """

    def tensor_func(*args: Any, **kwargs: Any) -> Tensor:
        
        args_arr = tuple([arg.array if isinstance(arg, Tensor) else arg for arg in args])
        for k, v in kwargs.items():
            if isinstance(v, Tensor):
                raise ValueError(f"Keyword tensors not supported, got key: {k}, value: {v}")
        
        out = Tensor(numpy_func(*args_arr, **kwargs))
        if grad_tracking_enabled and is_differentiable:
            parents = {
                i : arg
                for i, arg in enumerate(args)
                if isinstance(arg, Tensor) and (arg.requires_grad or arg.recipe is not None)
            }
            if parents:
                out.recipe = Recipe(numpy_func, args_arr, kwargs, parents)
                out.requires_grad = True
                
        return out

    return tensor_func

log = wrap_forward_fn(np.log)
multiply = wrap_forward_fn(np.multiply)
test_log(Tensor, log)
test_log_no_grad(Tensor, log)
test_multiply(Tensor, multiply)
test_multiply_no_grad(Tensor, multiply)
test_multiply_float(Tensor, multiply)

try:
    log(x = Tensor([100]))
except Exception as e:
    print("Got a nice exception as intended:")
    print(e)
else:
    assert False, "Passing tensor by keyword should raise some informative exception."


__main__.test_log passed in 0.00s.
__main__.test_log_no_grad passed in 0.00s.
__main__.test_multiply passed in 0.00s.
__main__.test_multiply_no_grad passed in 0.00s.
__main__.test_multiply_float passed in 0.00s.
Got a nice exception as intended:
Keyword tensors not supported, got key: x, value: Tensor(array([100]), requires_grad=False)


In [13]:
class Node(Protocol):
    """
    A protocol defining the Node's interface in topological sort. Any object will do!
    """


class ChildrenGetter(Protocol):
    """A protocol defining the get_children_fns passed to topological sort, to get the node's children"""

    def __call__(self, node: Any) -> list[Any]:
        """Get the given node's children, returning a list of nodes"""
        ...


def topological_sort(node: Node, get_children_fn: ChildrenGetter) -> list[Any]:
    """
    Return a list of node's descendants in reverse topological order from future to past.
    """
    
    visited: set[Node] = set()
    solution: set[Node] = set()
    result: list[Node] = []
        
    def visit(curr : Node):
        
        if curr in solution:
            return
        if curr in visited:
            raise ValueError('Detected cyclic graph')
        visited.add(curr)
        
        for next in get_children_fn(curr):
            visit(next)
            
        visited.remove(curr)
        solution.add(curr)
        result.append(curr)
        
    visit(node)
    return result
        
"""
test_topological_sort_linked_list(topological_sort)
test_topological_sort_branching(topological_sort)
test_topological_sort_rejoining(topological_sort)
test_topological_sort_cyclic(topological_sort)
"""

'\ntest_topological_sort_linked_list(topological_sort)\ntest_topological_sort_branching(topological_sort)\ntest_topological_sort_rejoining(topological_sort)\ntest_topological_sort_cyclic(topological_sort)\n'

In [14]:
def sorted_computational_graph(node: Tensor) -> list[Tensor]:
    """
    For a given tensor, return a list of Tensors that make up the nodes of the given Tensor's computational graph, in reverse topological order
    """
    
    def get_parents(node : Tensor) -> list[Tensor]:
        if node.recipe is None:
            return []
        return list(node.recipe.parents.values())
    
    return list(reversed(topological_sort(node, get_parents)))

a = Tensor([1], requires_grad = True)
b = Tensor([2], requires_grad = True)
c = Tensor([3], requires_grad = True)
d = a * b
e = c.log()
f = d * e
g = f.log()
name_lookup = {a: "a", b: "b", c: "c", d: "d", e: "e", f: "f", g: "g"}
print([name_lookup[t] for t in sorted_computational_graph(g)])
    

['g', 'f', 'e', 'c', 'd', 'b', 'a']


In [15]:
def backprop(end_node: Tensor, end_grad: Optional[Tensor] = None) -> None:
    """Accumulate gradients in the grad field of each leaf node.
    
    tensor.backward() is equivalent to backprop(tensor).
    
    end_node: the rightmost node in the computation graph. If it contains more than one element, end_grad must be provided.
    end_grad: A tensor of the same shape as end_node. Set to 1 if not specified and end_node has only one element. 
    """
    
    if np.array(end_node.shape).prod() > 1 and end_grad is None:
        raise RuntimeError("backprop from non-scalar tensors requires end_grad")
    end_grad_arr = np.ones_like(end_node.array) if end_grad is None else end_grad.array
    grads: dict[Tensor, Arr] = {end_node: end_grad_arr}
    
    for node in sorted_computational_graph(end_node):
        last_grad = grads.pop(node)
        if node.is_leaf:
            if node.grad is None:
                node.grad = Tensor(last_grad)
            else:
                node.grad += last_grad
        
        if node.recipe is None:
            continue
            
        for argnum, parent in node.recipe.parents.items():
            back_fn = BACK_FUNCS.get_back_func(node.recipe.func, argnum)
            in_grad = back_fn(last_grad, node.array, *node.recipe.args, **node.recipe.kwargs)
            cur_pgrad = grads.get(parent)
            if cur_pgrad is None:
                grads[parent] = in_grad
            else:
                grads[parent] += in_grad
        

test_backprop(Tensor)
test_backprop_branching(Tensor)
test_backprop_float_arg(Tensor)
test_backprop_requires_grad_false(Tensor)

__main__.test_backprop passed in 0.00s.
__main__.test_backprop_branching passed in 0.00s.
__main__.test_backprop_float_arg passed in 0.00s.
__main__.test_backprop_requires_grad_false passed in 0.00s.


In [16]:
def _argmax(x:Arr, dim = None, keepdim = False):
    """Like torch.argmax"""
    return np.argmax(x, axis = dim, keepdims = keepdim)

argmax = wrap_forward_fn(_argmax, is_differentiable = False)
eq = wrap_forward_fn(np.equal, is_differentiable = False)

a = Tensor([1.0, 0.0, 3.0, 4.0], requires_grad = False)
b = a.argmax()
assert not b.requires_grad
assert b.recipe is None
assert b.item() == 3

In [17]:
def negative_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = -x elementwise."""
    return -np.ones_like(x) * grad_out

negative = wrap_forward_fn(np.negative)
BACK_FUNCS.add_back_func(np.negative, 0, negative_back)
test_negative_back(Tensor)

__main__.test_negative_back passed in 0.00s.


In [18]:
def exp_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    return out * grad_out

exp = wrap_forward_fn(np.exp)
BACK_FUNCS.add_back_func(np.exp, 0, exp_back)
test_exp_back(Tensor)

__main__.test_exp_back passed in 0.00s.


In [19]:
def reshape_back(grad_out: Arr, out: Arr, x: Arr, new_shape: tuple) -> Arr:
    return grad_out.reshape(x.shape)

reshape = wrap_forward_fn(np.reshape)
BACK_FUNCS.add_back_func(np.reshape, 0, reshape_back)
test_reshape_back(Tensor)

__main__.test_reshape_back passed in 0.00s.


In [20]:
def permute_back(grad_out: Arr, out: Arr, x: Arr, axes: tuple) -> Arr:
    new_axes = np.argsort(axes)
    return grad_out.transpose(new_axes)

BACK_FUNCS.add_back_func(np.transpose, 0, permute_back)
permute = wrap_forward_fn(np.transpose)
test_permute_back(Tensor)

__main__.test_permute_back passed in 0.00s.


In [21]:
def expand_back(grad_out: Arr, out: Arr, x: Arr, new_shape: tuple) -> Arr:
    return unbroadcast(grad_out, x)

def _expand(x: Arr, new_shape) -> Arr:
    """Like torch.expand, calling np.broadcast_to internally.
    
    Note torch.expand supports -1 for a dimension size meaning "don't change the size".
    np.broadcast_to does not natively support this.
    """
    
    shape_diff = len(new_shape) - x.ndim
    new_expanded_shape = tuple([x.shape[i - shape_diff] if s == -1 else s for i, s in enumerate(new_shape)])
    return np.broadcast_to(x, new_expanded_shape)

expand = wrap_forward_fn(_expand)
BACK_FUNCS.add_back_func(_expand, 0, expand_back)
test_expand(Tensor)
test_expand_negative_length(Tensor)

__main__.test_expand passed in 0.00s.
__main__.test_expand_negative_length passed in 0.00s.


In [22]:
def _coerce_dim(x, dim: Union[None, int, tuple[int]]) -> tuple[int]:
    if dim is None:
        return tuple(range(x.ndim))
    elif isinstance(dim, int):
        return tuple([dim])
    return dim

def sum_back(grad_out: Arr, out: Arr, x: Arr,
            dim = None, keepdim = False):
    dim = _coerce_dim(x, dim)
    if not keepdim:
        unsqueezed = tuple([1 if d in dim else size for d, size in enumerate(x.shape)])
        grad_out = grad_out.reshape(unsqueezed)
    return np.broadcast_to(grad_out, x.shape)
    

def _sum(x: Arr, dim = None, keepdim = False) -> Arr:
    """Like torch.sum, calling np.sum internally."""
    return x.sum(axis = dim, keepdims = keepdim)

sum = wrap_forward_fn(_sum)
BACK_FUNCS.add_back_func(_sum, 0, sum_back)
test_sum_keepdim_false(Tensor)
test_sum_keepdim_true(Tensor)
test_sum_dim_none(Tensor)

__main__.test_sum_keepdim_false passed in 0.00s.
__main__.test_sum_keepdim_true passed in 0.00s.
__main__.test_sum_dim_none passed in 0.00s.


In [23]:
Index = Union[int, tuple[int, ...], tuple[Arr], tuple[Tensor]]

def _coerce_index(item: Union[int, Arr, Tensor]) -> Union[int, Arr]:
    
    if isinstance(item, Arr):
        return item
    if isinstance(item, Tensor):
        return item.array
    return item

def _getitem(x: Arr, index: Index) -> Arr:
    """Like x[index] when x is a torch.Tensor."""
    if isinstance(index, int):
        return x[index]
    else:
        new_idx = tuple([_coerce_index(idx) for idx in index])
        return x[new_idx]

def getitem_back(grad_out: Arr, out: Arr, x: Arr,
                index: Index):
    """Backwards function for _getitem.
    
    Hint: use np.add.at(a, indices, b)
    """
    new_arr = np.zeros_like(x)
    if isinstance(index, int):
        np.add.at(new_arr, index, grad_out)
    else:
        new_idx = tuple([_coerce_index(idx) for idx in index])
        np.add.at(new_arr, new_idx, grad_out)
    return new_arr

getitem = wrap_forward_fn(_getitem)
BACK_FUNCS.add_back_func(_getitem, 0, getitem_back)
test_getitem_int(Tensor)
test_getitem_tuple(Tensor)
test_getitem_integer_array(Tensor)
test_getitem_integer_tensor(Tensor)

__main__.test_getitem_int passed in 0.00s.
__main__.test_getitem_tuple passed in 0.00s.
__main__.test_getitem_integer_array passed in 0.00s.
__main__.test_getitem_integer_tensor passed in 0.00s.


In [24]:
add = wrap_forward_fn(np.add)
subtract = wrap_forward_fn(np.subtract)
true_divide = wrap_forward_fn(np.true_divide)
BACK_FUNCS.add_back_func(np.add, 0, lambda grad_out, out, x, y : unbroadcast(grad_out, x))
BACK_FUNCS.add_back_func(np.add, 1, lambda grad_out, out, x, y : unbroadcast(grad_out, y))
BACK_FUNCS.add_back_func(np.subtract, 0, lambda grad_out, out, x, y : unbroadcast(grad_out, x))
BACK_FUNCS.add_back_func(np.subtract, 1, lambda grad_out, out, x, y : unbroadcast(-grad_out, y))
BACK_FUNCS.add_back_func(np.true_divide, 0, lambda grad_out, out, x, y : unbroadcast(grad_out * (1 / y), x))
BACK_FUNCS.add_back_func(np.true_divide, 1, lambda grad_out, out, x, y : unbroadcast(grad_out * (x / -np.square(y)), y))

test_add_broadcasted(Tensor)
test_subtract_broadcasted(Tensor)
test_truedivide_broadcasted(Tensor)

__main__.test_add_broadcasted passed in 0.00s.
__main__.test_subtract_broadcasted passed in 0.00s.
__main__.test_truedivide_broadcasted passed in 0.00s.


In [25]:
def add_(x: Tensor, other: Tensor, alpha: float = 1.0) -> Tensor:
    """Like torch.add_. Compute x += other * alpha in-place and return tensor."""
    np.add(x.array, other.array * alpha, out = x.array)
    
def safe_example():
    """This example should work properly."""
    a = Tensor([0.0, 1.0, 2.0, 3.0], requires_grad = True)
    b = Tensor([2.0, 3.0, 4.0, 5.0], requires_grad = True)
    a.add_(b)
    c = a * b
    c.sum().backward()
    assert a.grad is not None and np.allclose(a.grad.array, [2.0, 3.0, 4.0, 5.0])
    assert b.grad is not None and np.allclose(b.grad.array, [2.0, 4.0, 6.0, 8.0])
    
def unsafe_example():
    """This example is expected to compute the wrong gradients."""
    a = Tensor([0.0, 1.0, 2.0, 3.0], requires_grad = True)
    b = Tensor([2.0, 3.0, 4.0, 5.0], requires_grad = True)
    c = a * b
    a.add_(b)
    c.sum().backward()
    if a.grad is not None and np.allclose(a.grad.array, [2.0, 3.0, 4.0, 5.0]):
        print("Grad wrt a is OK!")
    else:
        print("Grad wrt a is WRONG!")
        
    if b.grad is not None and np.allclose(b.grad.array, [0.0, 1.0, 2.0, 3.0]):
        print("Grad wrt b is OK!")
    else:
        print("Grad wrt b is WRONG!")
        
safe_example()
unsafe_example()

Grad wrt a is OK!
Grad wrt b is WRONG!


In [26]:
a = Tensor([0, 1, 2, 3], requires_grad=True)
(a * 2).sum().backward()
b = Tensor([0, 1, 2, 3], requires_grad=True)
(2 * b).sum().backward()
assert a.grad is not None
assert b.grad is not None
assert np.allclose(a.grad.array, b.grad.array)

In [27]:
def maximum_back0(grad_out: Arr, out: Arr,
                 x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt x."""
    out = grad_out.copy()
    out = np.where(x == y, grad_out / 2, out)
    out = np.where(x < y, 0.0, out)
    return unbroadcast(out, x)


def maximum_back1(grad_out: Arr, out: Arr,
                 x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt y."""
    out = grad_out.copy()
    out = np.where(x == y, grad_out / 2, out)
    out = np.where(x > y, 0.0, out)
    return unbroadcast(out, y)

maximum = wrap_forward_fn(np.maximum)
BACK_FUNCS.add_back_func(np.maximum, 0, maximum_back0)
BACK_FUNCS.add_back_func(np.maximum, 1, maximum_back1)
test_maximum(Tensor)
test_maximum(Tensor)


__main__.test_maximum passed in 0.00s.
__main__.test_maximum passed in 0.00s.


In [28]:
def relu(x: Tensor) -> Tensor:
    """Like torch.nn.function.relu(x, inplace=False)."""
    return maximum(0.0, x)

test_relu(Tensor)

__main__.test_relu passed in 0.00s.


In [33]:
def _matmul2d(x: Arr, y: Arr) -> Arr:
    """Matrix multiply restricted to the case where both inputs are exactly 2D."""
    return x @ y

def matmul2d_back0(grad_out: Arr, out: Arr,
                  x: Arr, y: Arr) -> Arr:
    return grad_out @ np.transpose(y)

def matmul2d_back1(grad_out: Arr, out: Arr,
                  x: Arr, y: Arr) -> Arr:
    return np.transpose(x) @ grad_out

matmul = wrap_forward_fn(_matmul2d)
BACK_FUNCS.add_back_func(_matmul2d, 0, matmul2d_back0)
BACK_FUNCS.add_back_func(_matmul2d, 1, matmul2d_back1)
test_matmul2d(Tensor)

__main__.test_matmul2d passed in 0.00s.


In [34]:
class Parameter(Tensor):
    def __init__(self, tensor : Tensor, requires_grad = True):
        """Share the array with the provided tensor."""
        return super().__init__(tensor.array, requires_grad = requires_grad)
    
    def __repr__(self):
        return f"Parameter containing:\n{super().__repr__()}"
    
x = Tensor([1.0, 2.0, 3.0])
p = Parameter(x)
assert p.requires_grad
assert p.array is x.array
assert repr(p) == "Parameter containing:\nTensor(array([1., 2., 3.]), requires_grad=True)"
x.add_(Tensor(np.array(2.0)))
assert np.allclose(
        p.array, np.array([3.0, 4.0, 5.0])
    ), "in-place modifications to the original tensor should affect the parameter"


In [37]:
class Module:
    _modules: dict[str, "Module"]
    _parameters: dict[str, Parameter]
        
    def __init__(self):
        self._modules = {}
        self._parameters = {}
    
    def modules(self):
        """Return the direct child modules of this module."""
        return self.__dict__["_modules"].values()
    
    def _iter_params(self):
        yield from self._parameters.values()
        for child in self._modules.values():
            yield from child.parameters(recurse=True)
    
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Return an iterator over Module parameters.
        
        recurse: if True, the iterator includes parameters of submodules, recursively.
        """
        if recurse:
            return self._iter_params()
        else:
            return iter(self._parameters.values())
    
    def __setattr__(self, key: str, val: Any) -> None:
        """
        If val is a Parameter or Module, store it in the appropriate _parameters or _modules dict.
        Otherwise, call the superclass.
        """
        if isinstance(val, Parameter):
            self.__dict__["_parameters"][key] = val
        elif isinstance(val, Module):
            self.__dict__["_modules"][key] = val
        else:
            super().__setattr__(key, val)
    
    def __getattr__(self, key: str) -> Union[Parameter, "Module"]:
        """
        If key is in _parameters or _modules, return the corresponding value.
        Otherwise, raise KeyError.
        """
        if key in self.__dict__['_parameters']:
            return self.__dict__['_parameters'][key]
        
        if key in self.__dict__['_modules']:
            return self.__dict__['_modules'][key]
        
        raise KeyError(key)
    
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
    
    def forward(self):
        raise NotImplementedError("Subclasses must implement forward!")
        
    def __repr__(self):
        
        def _addindent(s_, numSpaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(numSpaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        child_lines = []

        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = child_lines

        main_str = self.__class__.__name__ + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        return main_str
    
class TestInnerModule(Module):
    def __init__(self):
        super().__init__()
        self.param1 = Parameter(Tensor([1.0]))
        self.param2 = Parameter(Tensor([2.0]))

class TestModule(Module):
    def __init__(self):
        super().__init__()
        self.inner = TestInnerModule()
        self.param3 = Parameter(Tensor([3.0]))

mod = TestModule()
assert list(mod.modules()) == [mod.inner]
assert list(mod.parameters()) == [
    mod.param3,
    mod.inner.param1,
    mod.inner.param2,
], "parameters should come before submodule parameters"
print("Manually verify that the repr looks reasonable:")
print(mod)
    


Manually verify that the repr looks reasonable:
TestModule(
  (inner): TestInnerModule()
)


In [38]:
class Linear(Module):
    weight: Parameter
    bias: Optional[Parameter]

    def __init__(self, in_features: int, out_features: int, bias=True):
        """A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        bound = in_features ** -0.5
        self.weight = Parameter(empty(self.out_features, self.in_features).uniform_(-bound, bound))
        if bias:
            self.bias = Parameter(empty(out_features).uniform_(-bound, bound))
        else:
            self.bias = None

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (*, in_features)
        Return: shape (*, out_features)
        """
        out = x @ self.weight.permute((1, 0))
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
    


In [39]:
class MLP(Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(28 * 28, 64)
        self.linear2 = Linear(64, 64)
        self.output = Linear(64, 10)

    def forward(self, x):
        x = x.reshape((x.shape[0], 28 * 28))
        x = relu(self.linear1(x))
        x = relu(self.linear2(x))
        x = self.output(x)
        return x

In [43]:
def cross_entropy(logits: Tensor, true_labels: Tensor) -> Tensor:
    """Like torch.nn.functional.cross_entropy with reduction='none'.
    
    logits: shape (batch, classes)
    true_labels: shape (batch, ). Each element is the index of the correct label in the logits.
    
    Return: shape(batch, ) containing the per-example loss.
    """
    batch_size, n_classes = logits.shape
    true = logits[arange(0, batch_size), true_labels]
    return -log(exp(true) / exp(logits).sum(1))

test_cross_entropy(Tensor, cross_entropy)

__main__.test_cross_entropy passed in 0.00s.


In [46]:
class NoGrad:
    """Context manager that disables grad inside the block. Like torch.no_grad."""

    was_enabled: bool

    def __enter__(self):
        global grad_tracking_enabled
        self.was_enabled = grad_tracking_enabled
        grad_tracking_enabled = False

    def __exit__(self, type, value, traceback):
        global grad_tracking_enabled
        grad_tracking_enabled = self.was_enabled


test_no_grad(Tensor, NoGrad)
test_no_grad_nested(Tensor, NoGrad)

__main__.test_no_grad passed in 0.00s.
__main__.test_no_grad_nested passed in 0.00s.


In [52]:
def visualize(dataloader):
    """Call this if you want to see some of your data."""
    plt.figure(figsize=(12, 12))
    (sample, sample_labels) = next(iter(dataloader))
    for i in range(10):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(sample[i, 0], cmap=plt.cm.binary)
    plt.show()
    
def get_mnist(subsample: Optional[int] = None):
    """Return MNIST data using the provided Tensor class."""
    mnist_train = datasets.MNIST("../data", train=True, download=True)
    mnist_test = datasets.MNIST("../data", train=False)
    if subsample is None:
        subsample = 1
    print("Preprocessing data...")
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.28,), (0.35,))])
    train_indexes = range(0, len(mnist_train), subsample)
    train_reduced = [mnist_train[i] for i in train_indexes]
    train_tensors = torch.utils.data.TensorDataset(
        torch.stack([transform(img) for img, label in tqdm(train_reduced, desc="Training data")]),
        torch.tensor([label for img, label in train_reduced]),
    )

    test_indexes = range(0, len(mnist_test), subsample)
    test_reduced = [mnist_test[i] for i in test_indexes]
    test_tensors = torch.utils.data.TensorDataset(
        torch.stack([transform(img) for img, label in tqdm(test_reduced, desc="Test data")]),
        torch.tensor([label for img, label in test_reduced]),
    )

    train_loader = torch.utils.data.DataLoader(train_tensors, shuffle=True, batch_size=512)
    test_loader = torch.utils.data.DataLoader(test_tensors, batch_size=512)
    return train_loader, test_loader
    

subsample = 20 if IS_CI else None
(train_loader, test_loader) = get_mnist(subsample)
    
class SGD:
    def __init__(self, params: Iterable[Parameter], lr: float):
        """Vanilla SGD with no additional features."""
        self.params = list(params)
        self.lr = lr
        self.b = [None for _ in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    def step(self) -> None:
        with NoGrad():
            for (i, p) in enumerate(self.params):
                assert isinstance(p.grad, Tensor)
                p.add_(p.grad, -self.lr)
                


def train(model, train_loader, optimizer, epoch):
    for (batch_idx, (data, target)) in enumerate(train_loader):
        data = Tensor(data.numpy())
        target = Tensor(target.numpy())
        optimizer.zero_grad()
        output = model(data)
        loss = cross_entropy(output, target).sum() / len(output)
        loss.backward()
        optimizer.step()
        if batch_idx % 50 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


def test(model, test_loader):
    test_loss = 0
    correct = 0
    with NoGrad():
        for (data, target) in test_loader:
            data = Tensor(data.numpy())
            target = Tensor(target.numpy())
            output = model(data)
            test_loss += cross_entropy(output, target).sum().item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += (pred == target.reshape(pred.shape)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 82366633.21it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 38639774.74it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 19008527.03it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 13406424.19it/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw

Preprocessing data...


Training data:   0%|          | 0/60000 [00:00<?, ?it/s]

Test data:   0%|          | 0/10000 [00:00<?, ?it/s]

In [53]:
num_epochs = 5
model = MLP()
start = time.time()
optimizer = SGD(model.parameters(), 0.01)
for epoch in range(num_epochs):
    train(model, train_loader, optimizer, epoch)
    test(model, test_loader)
    optimizer.step()
    
print(f"Completed in {time.time() - start: .2f}s")


Test set: Average loss: 1.8699, Accuracy: 5055/10000 (51%)


Test set: Average loss: 1.0379, Accuracy: 7679/10000 (77%)


Test set: Average loss: 0.6645, Accuracy: 8418/10000 (84%)


Test set: Average loss: 0.5214, Accuracy: 8653/10000 (87%)


Test set: Average loss: 0.4510, Accuracy: 8793/10000 (88%)

Completed in  16.97s
