In [None]:
import torch
from torch import nn
import torch.autograd

from torchviz import make_dot  # https://github.com/szagoruyko/pytorchviz

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import IPython.display as disp

import pprint

### Initialise things

In [None]:
# Prefer the first GPU when one is available, but fall back to the CPU so the notebook
# still runs on machines without CUDA.
# NOTE(review): `device` is not used by any of the cells visible here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pp = pprint.PrettyPrinter(indent=4)  # Pretty-printer for the parameter dicts printed below.

### Display functions 

In [None]:
def tinfo(tensor, name=None):
    """Show a tensor's shape, dtype and device, followed by the tensor itself.

    Args:
        tensor: the tensor to describe.
        name: optional heading printed (with an underline) above the summary.
    """
    rule = "------------------"
    if name is not None:
        print("{}:".format(name))
        print(rule)
    for attribute in (tensor.shape, tensor.dtype, tensor.device):
        disp.display(attribute)
    print(rule)
    disp.display(tensor)
    print()

# Simple computation graph example 

In [None]:
# Tensors:
# Note: wrapped in nn.Parameter(...) only because that allows printing the parameter name on the computational graph, not necessary for computation.
# nn.Parameter already marks its tensor as requiring grad, so no explicit requires_grad=True is needed.
# dtype=torch.float64 is what the Python `float` dtype maps to; it is spelled out here for clarity.
x = nn.Parameter(torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float64))                                 # shape (1, 5)
W = nn.Parameter(torch.tensor([[10, 10], [20, 20], [30, 30], [40, 40], [50, 50]], dtype=torch.float64))  # shape (5, 2)
b = nn.Parameter(torch.tensor([1.2, 1.2], dtype=torch.float64))                                        # shape (2,)

tinfo(x, "x")
tinfo(W, "W")
tinfo(b, "b")

# Define computations:
xW = torch.mm(x, W)  # (1, 5) @ (5, 2) -> (1, 2)
z = xW + b           # b broadcasts over the single row
o = torch.log(z)

tinfo(xW, "xW = torch.mm(x, W)")
tinfo(z, "z = xW + b")
tinfo(o, "o = torch.log(z)")

In [None]:
# Map display names to tensors so the graph nodes can be labelled.
params = dict(x=x, W=W, b=b)

In [None]:
# Display graph: the object returned by make_dot is rendered as this cell's output.
make_dot(o, params)

# Model Example
* From https://github.com/szagoruyko/pytorchviz/blob/master/examples.ipynb

In [None]:
# Small MLP: Linear(8 -> 16) -> tanh -> Linear(16 -> 1).
# Seed first so the random initialisation (and therefore the printed parameters) is reproducible.
torch.manual_seed(9)
model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

# Dedicated name so the `x` defined by the previous example is not silently overwritten.
x_demo = torch.randn(1, 8)

params_dict = dict(model.named_parameters())
pp.pprint(params_dict)

make_dot(model(x_demo), params=params_dict)

# Implement gradient descent

In [None]:
# Data: 10 evenly spaced samples of sin over one full period, [0, 2*pi].
x_np = np.linspace(0.0, 2.0 * np.pi, num=10)
y_np = np.sin(x_np)
disp.display(x_np, y_np)
plt.scatter(x_np, y_np)

In [None]:
# Model.

class Model(nn.Module):
    """Affine map from `n_features` inputs to `n_features` outputs: forward(x) = W1(x) + b.

    `W1` is an nn.Linear, which carries its own bias. `b` is an additional, explicitly named
    (1, n_features) bias initialised to ones. It is redundant with W1's bias, and is presumably
    kept so that a hand-made parameter can be inspected by name.
    """

    def __init__(self, n_features=10):
        """
        Args:
            n_features: input/output width. The default of 10 matches the number of data points
                generated in this notebook.
        """
        super().__init__()
        self.W1 = nn.Linear(n_features, n_features)
        # torch.ones consumes no RNG state, so seeded initialisation of W1 is unaffected.
        self.b = nn.Parameter(torch.ones(1, n_features, dtype=torch.float32))

    def forward(self, x):
        return self.W1(x) + self.b

In [None]:
# Loss functions: mean squared error for torch tensors (`mse`) and NumPy arrays (`mse_np`).
def mse(y, t):
    """Mean squared error of `y` against `t`, averaged over all elements (torch tensors)."""
    squared_error = (y - t).pow(2)
    return squared_error.mean()
def mse_np(y, t):
    """Mean squared error of `y` against `t`, averaged over all elements (NumPy arrays)."""
    return np.mean(np.square(y - t))

In [None]:
# Show computation graph.

# Inputs and targets as (1, 10) float32 row vectors. They are data, not trainable
# parameters, so they do not need requires_grad (it defaults to False).
x = torch.tensor(x_np[np.newaxis, :], dtype=torch.float32)
y = torch.tensor(y_np[np.newaxis, :], dtype=torch.float32)

torch.manual_seed(9)
model = Model()

# Label the model parameters and the input on the graph.
params_dict = {**dict(model.named_parameters()), "x": x}
pp.pprint(params_dict)

make_dot(model(x), params=params_dict)

In [None]:
# Initial "prediction" of the freshly initialised model. It is only inspected, so
# gradient tracking is disabled and the output can be converted to NumPy directly.
with torch.no_grad():
    y_pred = model(x).squeeze().numpy()

def show_plot(x, y, y_pred, ylim=(-1., 1.), print_mse=True):
    """Redraw the current axes with targets (red) and predictions (blue).

    Args:
        x: x-coordinates shared by both series.
        y: target values.
        y_pred: predicted values.
        ylim: (bottom, top) y-axis limits, or None to leave the y-limits to matplotlib's autoscaling.
        print_mse: if True, also print the mean squared error between `y_pred` and `y`.
    """
    # Clear the current axes so repeated calls (e.g. inside a training loop) redraw in place.
    plt.cla()
    plt.scatter(x, y, color="red")
    plt.scatter(x, y_pred, color="blue")
    # Passing None to plt.ylim only "works" by freezing the current limits and turning autoscale
    # off, so skip the call explicitly instead.
    if ylim is not None:
        plt.ylim(ylim)
    if print_mse:
        print("MSE = {}".format(mse_np(y_pred, y)))

# Let the y-axis fit the data rather than using the default (-1, 1) range.
show_plot(x=x_np, y=y_np, y_pred=y_pred, ylim=None, print_mse=True)

In [None]:
# Implement GD.
#
# Fit `Model` to the sine data with plain gradient descent, using either a hand-written update
# rule (use_manual_gd = True) or torch.optim.SGD (use_manual_gd = False). With momentum=0 both
# apply `param -= lr * param.grad`. Set VERB = True to print the bias parameter `b` (and its
# gradient) around each step.

VERB = False
use_manual_gd = True

torch.manual_seed(9)
model = Model()

iters = 20
lr = 0.01

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0)


def _print_param(module, name, stage, show_grad=False):
    """Print parameter `name` of `module` under a `stage` heading, optionally followed by its gradient."""
    for param_name, parameter in module.named_parameters():
        if param_name == name:
            print("\n{}:\n---\n".format(param_name))
            print(stage)
            print(parameter)
            if show_grad:
                print("GRAD")
                print(parameter.grad)


for it in range(iters):
    # Replace the previous iteration's output as soon as this iteration prints something, so the
    # cell always shows one iteration's log followed by its plot (including the final iteration).
    disp.clear_output(wait=True)

    print("------------------------------\nit {}:\n------------------------------\n".format(it + 1))

    # Zero the gradients: backward() accumulates into .grad rather than overwriting it.
    with torch.no_grad():
        for parameter in model.parameters():
            if parameter.grad is not None:
                parameter.grad.zero_()

    if VERB:
        _print_param(model, "b", "AT THE START")
        print()

    # Compute loss.
    y_pred = model(x)

    print(">>> y_pred:")
    print(y_pred)

    print(">>> y:")
    print(y)

    mse_loss = mse(y_pred, y)

    print(">>> mse = {}".format(mse_loss.item()))

    # Backprop. The loss is a scalar, so no explicit `gradient` argument is needed.
    mse_loss.backward()

    # Gradient descent step (plain truthiness, so any truthy/falsy flag selects a branch).
    if VERB:
        _print_param(model, "b", "BEFORE", show_grad=True)
    if use_manual_gd:
        with torch.no_grad():  # The update itself must not be recorded by autograd.
            for parameter in model.parameters():
                parameter -= lr * parameter.grad
    else:
        opt.step()
    if VERB:
        _print_param(model, "b", "AFTER")

    # Display the predictions computed at the top of this iteration (i.e. before this step's update).
    show_plot(x_np, y_np, y_pred.squeeze().detach().numpy(), ylim=(-4., 4.), print_mse=False)
    disp.display(plt.gcf())

# The figure has already been displayed explicitly; close it so it is not rendered again when
# the cell finishes.
plt.close()

# Investigate graph node state 

In [None]:
def show_state(vars_tuple_list):
    """Tabulate the autograd state of named tensors.

    Args:
        vars_tuple_list: iterable of (name, tensor) pairs.

    Returns:
        A DataFrame with one row per pair (index 0..n-1) and the columns
        name, requires_grad, is_leaf, "requires_grad & is_leaf", grad and grad_fn.
        The grad and grad_fn values are their string representations.
        An empty input yields an empty frame that still has these columns.
    """
    columns = ["name", "requires_grad", "is_leaf", "requires_grad & is_leaf", "grad", "grad_fn"]
    # Collect all rows first and build the frame once, instead of filling a preallocated
    # object-dtype frame cell by cell with .loc.
    rows = [
        {
            "name": name,
            "requires_grad": var.requires_grad,
            "is_leaf": var.is_leaf,
            "requires_grad & is_leaf": var.requires_grad and var.is_leaf,
            "grad": str(var.grad),
            "grad_fn": str(var.grad_fn),
        }
        for name, var in vars_tuple_list
    ]
    return pd.DataFrame(rows, columns=columns)

In [None]:
# Re-seed so the model starts from the same initial weights every time this cell runs.
torch.manual_seed(9)
model = Model()

# Forward pass, all the way to the scalar loss.
y_pred = model(x)
mse_loss = mse(y_pred, y)

# Every tensor whose autograd state is inspected below, labelled by name.
params_list = [("x", x), *model.named_parameters(), ("y_pred", y_pred), ("y", y), ("mse_loss", mse_loss)]

# Render the computation graph, using the names above as node labels.
print("\nCalculation Graph, all the way to the loss:")
chart = make_dot(mse_loss, params=dict(params_list))
disp.display(chart)

# Table of requires_grad / is_leaf / grad / grad_fn for each node, before backprop...
print("\nGraph nodes (BEFORE backprop):")
df = show_state(params_list)
disp.display(df)

# --------------------------
# A single backward pass from the loss.
mse_loss.backward()

# ...and the same table afterwards, to compare which .grad fields got populated.
print("\nGraph nodes (AFTER backprop):")
df = show_state(params_list)
disp.display(df)

In [None]:
# Investigate `retain_graph`

# Re-seed so the model starts from the same initial weights every time this cell runs.
torch.manual_seed(9)
model = Model()

# Forward pass to the scalar loss.
y_pred = model(x)
mse_loss = mse(y_pred, y)

# Named tensors whose autograd state is tabulated below.
params_list = [("x", x), *model.named_parameters(), ("y_pred", y_pred), ("y", y), ("mse_loss", mse_loss)]

print("\nGraph nodes (BEFORE backprop):")
df = show_state(params_list)
disp.display(df)

# --------------------------
# Two backward passes through the SAME graph. The first one has to keep the graph alive
# (retain_graph=True): calling `mse_loss.backward()` twice in a row without it raises an
# exception, because the first call frees the graph.
#
# This just computes the same 1st-order gradient twice and accumulates it into .grad;
# it does NOT compute 2nd-order derivatives.
mse_loss.backward(retain_graph=True)
mse_loss.backward()

print("\nGraph nodes (AFTER backprop):")
df = show_state(params_list)
disp.display(df)

In [None]:
# Investigate `create_graph`

torch.manual_seed(9)
model = Model()

y_pred = model(x)
mse_loss = mse(y_pred, y)

params_list = [("x", x)] + list(model.named_parameters()) + [("y_pred", y_pred), ("y", y), ("mse_loss", mse_loss)]

# Graph.
print("\n`mse_loss` graph:")
chart = make_dot(mse_loss, params=dict(params_list))
disp.display(chart)

# --------------------------
# Do one run of backprop. create_graph=True records the backward computation itself as a graph,
# so the gradients can be differentiated again. It also implies retain_graph=True.
mse_loss.backward(create_graph=True)

# There is now a graph for model.W1.weight.**grad**.
# Without create_graph=True, there is no graph there.
print("\n`model.W1.weight.grad` graph (after `backward()` with `create_graph=True`):")
W1w_grad = model.W1.weight.grad
chart = make_dot(W1w_grad)
disp.display(chart)

# backward(create_graph=True) makes each parameter's .grad reference a graph that references the
# parameter again, which is a reference cycle that can leak memory. Reset the .grad fields to None
# once inspection is done to break the cycle.
for parameter in model.parameters():
    parameter.grad = None

# Using `torch.autograd.grad()` 

In [None]:
# Manually calculate and update gradients using autograd.

# Use `torch.autograd.grad()`.
# `torch.autograd.backward()` can be used as equivalent to calling `.backward()` on the loss.
#
# `torch.autograd.grad()` returns the gradients of `outputs` w.r.t. `inputs` (here: the model
# parameters) instead of accumulating them into `.grad`. As a result:
#   * the data tensor `x` does not need requires_grad (and is no longer mutated by this cell);
#   * no zeroing of `.grad` is needed, because each iteration simply overwrites it.

VERB = False

torch.manual_seed(9)
model = Model()

iters = 20
lr = 0.01


def _print_params_and_grads(module, stage):
    """Print every parameter of `module` and its current `.grad` under a `stage` heading."""
    print("-------------------")
    print(stage)
    for p in module.parameters():
        print("parameter:")
        print(p)
        print("parameter.grad:")
        print(p.grad)
        print("----")


for it in range(iters):
    # Replace the previous iteration's output as soon as this iteration prints something, so the
    # cell always shows one iteration's log followed by its plot (including the final iteration).
    disp.clear_output(wait=True)

    print("------------------------------\nit {}:\n------------------------------\n".format(it + 1))

    # Compute loss.
    y_pred = model(x)

    print(">>> y_pred:")
    print(y_pred)

    print(">>> y:")
    print(y)

    mse_loss = mse(y_pred, y)

    print(">>> mse = {}".format(mse_loss.item()))

    # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
    # Backprop.
    if VERB:
        _print_params_and_grads(model, "BEFORE UPDATE:")

    params_list = list(model.parameters())
    # Compute the gradients explicitly and store them in `.grad`. The loss is a scalar, so
    # `grad_outputs` can be omitted (`only_inputs` is deprecated and no longer has any effect).
    grads = torch.autograd.grad(outputs=mse_loss, inputs=params_list)
    for parameter, grad in zip(params_list, grads):
        parameter.grad = grad

    if VERB:
        _print_params_and_grads(model, "AFTER UPDATE:")
    # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---

    with torch.no_grad():  # The update itself must not be recorded by autograd.
        for parameter in params_list:
            parameter -= lr * parameter.grad

    # Display the predictions computed at the top of this iteration (i.e. before this step's update).
    show_plot(x_np, y_np, y_pred.squeeze().detach().numpy(), ylim=(-4., 4.), print_mse=False)
    disp.display(plt.gcf())

# The figure has already been displayed explicitly; close it so it is not rendered again when
# the cell finishes.
plt.close()

# Higher order derivatives 

In [None]:
# Simple 1D function f(x) = exp(x) + 20 * ln(x), evaluated at x = 3.
# Uses its own tensor name so the (1, 10) data tensor `x` from the cells above is not overwritten
# (re-running those cells after this one would otherwise fail).
x_scalar = torch.tensor([3.], requires_grad=True)
f = torch.exp(x_scalar) + 20 * torch.log(x_scalar)

# Show graph:
disp.display(make_dot(f, params={"x": x_scalar}))

print("\n f = {}. Expected: 42.05778...".format(f.item()))

# First order derivative wrt x: df/dx = exp(x) + 20 / x.
# create_graph=True records the derivative computation itself so it can be differentiated again
# (it also keeps the graph, i.e. implies retain_graph=True).
df_dx = torch.autograd.grad(outputs=f, inputs=x_scalar, create_graph=True)[0]
print("\n df_dx = {}. Expected: 26.75220...".format(df_dx.item()))

# Second order derivative wrt x: d2f/dx2 = exp(x) - 20 / x**2.
d2f_dx2 = torch.autograd.grad(outputs=df_dx, inputs=x_scalar)[0]
print("\n d2f_dx2 = {}. Expected: 17.8633...".format(d2f_dx2.item()))

# Check against the analytic values at x = 3 (float32 precision).
assert abs(f.item() - (np.exp(3.) + 20 * np.log(3.))) < 1e-4
assert abs(df_dx.item() - (np.exp(3.) + 20 / 3.)) < 1e-4
assert abs(d2f_dx2.item() - (np.exp(3.) - 20 / 9.)) < 1e-4