In [43]:
import torch
from torch import Tensor

from typing import *

In [46]:
def a_simple_function() -> None:
    """Demonstrate autograd on the scalar function f(x) = 2 * (x . x).

    Three facts are checked with assertions:

    1. Before any backward pass, the leaf tensor's ``.grad`` is ``None``.
    2. After one backward pass, ``.grad`` matches the analytic gradient ``4x``.
    3. After a second backward pass without resetting ``.grad``, the stored
       gradient has accumulated to twice the analytic gradient.
    """

    def func(x: Tensor) -> Tensor:
        """Return the scalar 2 * (x . x)."""
        return 2 * x.dot(x)

    def func_prime(x: Tensor) -> Tensor:
        """Return the analytic gradient of ``func`` at ``x``, i.e. 4x."""
        return 4 * x

    # Leaf tensor [0., 1., 2., 3.] tracked by autograd.
    point: Tensor = torch.arange(4.0, requires_grad=True)

    # Forward pass only: no gradient has been stored on the leaf yet.
    first_pass: Tensor = func(point)
    assert point.grad is None

    # Backpropagate from the scalar output; this populates point.grad.
    first_pass.backward()
    assert point.grad is not None
    assert torch.allclose(point.grad, func_prime(point))

    # Run a second forward/backward pass without clearing point.grad
    # (neither setting it to None nor calling point.grad.zero_()).
    second_pass: Tensor = func(point)
    second_pass.backward()

    # The new gradient was added onto the stored one, so it has doubled.
    assert torch.allclose(point.grad, 2.0 * func_prime(point))
a_simple_function()

In [57]:
def backward_for_non_scalar_variables() -> None:
    """Demonstrate ``backward`` on a vector-valued output.

    The output here is a vector (elementwise square of the input), so
    ``backward`` needs an explicit ``gradient`` argument. Passing a vector
    of ones yields the gradient of ``sum(y)`` with respect to the input,
    which for elementwise squaring is ``2x``. The result is checked with
    an assertion.
    """

    def func(x: Tensor) -> Tensor:
        """Square every element of ``x``."""
        return x.pow(2)

    def func_prime(x: Tensor) -> Tensor:
        """Elementwise derivative of ``func``: d/dx (x^2) = 2x."""
        return 2 * x

    inputs: Tensor = torch.arange(4.0, requires_grad=True)
    outputs: Tensor = func(inputs)

    # The derivative of a vector output w.r.t. a vector input is a Jacobian
    # matrix. The `gradient` argument is the vector that the Jacobian is
    # contracted with; a vector of ones sums the contributions of every
    # output element.
    upstream: Tensor = torch.ones_like(outputs)
    outputs.backward(gradient=upstream)

    # An alternative that gives the same gradient is to reduce to a scalar
    # first and backpropagate from that:
    # outputs.sum().backward()

    assert torch.allclose(inputs.grad, func_prime(inputs))
backward_for_non_scalar_variables()

In [74]:
def duplicate_backpropagation() -> None:
    """Demonstrate backpropagating twice through the same graph.

    The first ``backward`` call passes ``retain_graph=True`` so that the
    graph built for ``y`` can be backpropagated through a second time.
    Because ``x.grad`` is not reset in between, the second call adds onto
    the stored gradient, and the final value is asserted to be twice the
    gradient observed after the first call.
    """
    x: Tensor = torch.arange(4.0, requires_grad=True)
    y: Tensor = 2 * (x ** 5)

    # Upstream gradient for the vector-valued output: one weight per element.
    upstream: Tensor = torch.ones(len(x))

    # First pass: keep the graph alive so it can be traversed again.
    y.backward(gradient=upstream, retain_graph=True)
    grad_after_first: Tensor = x.grad.clone()

    # Second pass over the same graph; the result accumulates into x.grad.
    y.backward(gradient=upstream)

    assert torch.allclose(x.grad, 2 * grad_after_first)
duplicate_backpropagation()