In [None]:
import IPython.display as disp
import pandas as pd

import torch

from torchviz import make_dot

In [None]:
def _print(text):
    print("\n{}".format(text))
    
def print_(text):
    """Print ``text`` followed by one blank line."""
    print(f"{text}\n")

def _print_(text):
    print("\n{}\n".format(text))
    
def print_all(W, b, x, y, t):
    """Print each of the five arguments together with its ``.grad`` attribute.

    The arguments are reported in the order W, b, x, y, t. For every one, a
    line ``<name>.data:`` is followed by the formatted object itself (not its
    ``.data`` attribute), then ``<name>.grad:`` followed by ``obj.grad``.
    """
    labelled = zip(("W", "b", "x", "y", "t"), (W, b, x, y, t))
    for label, value in labelled:
        print(f"{label}.data:\n{value}")
        print(f"{label}.grad:\n{value.grad}")

def print_all_id(id_, W, b, x, y, t):
    """Print the five arguments and their ``.grad`` under id-suffixed names.

    A header line naming fast model ``id_`` comes first. Each argument is then
    reported as ``<name><id_>.data:`` followed by the formatted object itself,
    and ``<name><id_>.grad:`` followed by ``obj.grad``. A blank line ends the
    report.
    """
    print(f"-- Parameters of fast model {id_} --")
    for base, value in zip(("W", "b", "x", "y", "t"), (W, b, x, y, t)):
        label = f"{base}{id_}"
        print(f"{label}.data:\n{value}")
        print(f"{label}.grad:\n{value.grad}")
    print()

def print_wb(W, b):
    """Print ``W`` and ``b`` with their ``.grad`` between header/footer lines.

    Each argument is reported as ``<name>.data:`` followed by the formatted
    object itself, then ``<name>.grad:`` followed by ``obj.grad``.
    """
    print("-- W, b --")
    for label, value in (("W", W), ("b", b)):
        print(f"{label}.data:\n{value}")
        print(f"{label}.grad:\n{value.grad}")
    print("-- -- --\n")

In [None]:
# Check the double gradient calculation.
#
# A one-layer linear "model" y = x @ W + b takes a single gradient step that is
# itself built from autograd operations. A second loss is then evaluated with
# the updated W and b and differentiated again. Because the updated parameters
# are functions of the first gradients, the second autograd.grad call can only
# differentiate *through* those gradients when the first call was made with
# create_graph=True.

CREATE_GRAPH = True  # Toggling this toggles second order calculation's inclusion of gradients.


# Define "model".

W = torch.tensor([[1., 0.5], [-0.8, 1.3]], requires_grad=True)
b = torch.tensor([[1., 1.]], requires_grad=True)

x = torch.tensor([[4.5, 9.5]], requires_grad=True)

y = torch.matmul(x, W) + b
t = torch.tensor([[2., 3.]], requires_grad=False)

disp.display(make_dot(y, params={"W": W, "b": b, "x": x, "y": y}))
print_all(W, b, x, y, t)

# Mean squared error against the first target.
loss = ((t - y)**2).mean()
_print("loss = {}".format(loss))



_print_("------------------------------------------------------------------")

print("Calling autograd.grad...\n")
# retain_graph=True keeps the first graph usable after this call; with
# create_graph=True the returned gradients carry their own graph.
grads = torch.autograd.grad(outputs=loss, inputs=(W, b, x), grad_outputs=torch.tensor(1.), retain_graph=True, create_graph=CREATE_GRAPH)
print_("grads calculated: {}".format(grads))

print_("Updating W, b...")
# Use the returned gradients directly instead of first storing them in
# W.grad / b.grad: a gradient produced with create_graph=True references the
# tensor it was computed for, so parking it in that tensor's .grad can create a
# parameter <-> gradient reference cycle (the leak PyTorch's docs warn about).
# The resulting values and graph are the same as before.
# 0.1 is the step size of this single update. From here on W and b name the
# updated, non-leaf tensors.
W = W - 0.1 * grads[0]
b = b - 0.1 * grads[1]

print_all(W, b, x, y, t)

_print("W graph (after updating):")
disp.display(make_dot(W))

# Second forward pass, with the updated parameters and a new target.
t2 = torch.tensor([[2.5, 3.5]], requires_grad=False)
y2 = torch.matmul(x, W) + b
# disp.display(make_dot(y2))
loss2 = ((t2 - y2)**2).mean()



_print_("------------------------------------------------------------------")

print_("Calling autograd.grad 2nd time...")
print("grads:")
# Inputs are the *updated* W and b plus the original x. Nothing is
# differentiated after this call, so no graph is retained or created.
grads = torch.autograd.grad(outputs=loss2, inputs=(W, b, x), grad_outputs=torch.tensor(1.), retain_graph=False, create_graph=False)
print_(grads)


In [None]:
# Two "fast models" are cloned from the shared theta parameters (W, b). Each
# takes one inner gradient step on its own (x, t) pair. The losses are then
# re-evaluated with the adapted parameters, summed, and differentiated with
# respect to theta to drive a single outer ("meta") update of W and b.

# When True, the inner gradients keep their graph, so the meta-gradient can
# differentiate through the inner step.
CREATE_GRAPH = True

# Learning rates.
lri = 0.001  # Inner.
lro = 0.05  # Outer.



# Theta parameters:

W = torch.tensor([[1., 0.5], [-0.8, 1.3]], requires_grad=True)
b = torch.tensor([[1., 1.]], requires_grad=True)
print("W, b initially:")
print_wb(W, b)



# -------------------------------------------------------
# Fast model 1.

_print_("------------------------------------------------------------------")
print_("Fast model 1:")

# clone() is tracked by autograd, so W1 / b1 stay connected to theta.
W1 = W.clone()
b1 = b.clone()

x1 = torch.tensor([[4.5, 9.5]], requires_grad=True)

y1 = torch.matmul(x1, W1) + b1
t1 = torch.tensor([[2., 3.]], requires_grad=False)

print_all_id(1, W1, b1, x1, y1, t1)

loss1 = ((t1 - y1)**2).mean()
_print_("loss1 = {}".format(loss1))
print("loss1 graph:")
disp.display(make_dot(loss1, params={"W": W, "b": b, "W1": W1, "b1": b1, "x1": x1, "y1": y1, "loss1": loss1}))

grads1 = torch.autograd.grad(outputs=loss1, inputs=(W1, b1), grad_outputs=torch.tensor(1.), retain_graph=True, create_graph=CREATE_GRAPH)
# Inner step. The gradients are used directly rather than stored in
# W1.grad / b1.grad: with create_graph=True they reference the tensors they were
# computed for, and parking them in .grad can create a reference cycle.
W1 = W1 - lri * grads1[0]
b1 = b1 - lri * grads1[1]



# -------------------------------------------------------
# Fast model 2.

_print_("------------------------------------------------------------------")
print_("Fast model 2:")

W2 = W.clone()
b2 = b.clone()

x2 = torch.tensor([[11.1, 22.2]], requires_grad=True)

y2 = torch.matmul(x2, W2) + b2
t2 = torch.tensor([[4., 7.]], requires_grad=False)

print_all_id(2, W2, b2, x2, y2, t2)

loss2 = ((t2 - y2)**2).mean()
_print_("loss2 = {}".format(loss2))
print("loss2 graph:")
disp.display(make_dot(loss2, params={"W": W, "b": b, "W2": W2, "b2": b2, "x2": x2, "y2": y2, "loss2": loss2}))

grads2 = torch.autograd.grad(outputs=loss2, inputs=(W2, b2), grad_outputs=torch.tensor(1.), retain_graph=True, create_graph=CREATE_GRAPH)
# Inner step for fast model 2 (same pattern as fast model 1).
W2 = W2 - lri * grads2[0]
b2 = b2 - lri * grads2[1]



# -------------------------------------------------------
# Meta update.

# print_wb(W, b)  # <-- To check these theta parameters are unchanged from their initial values.

_print_("------------------------------------------------------------------")
print("Meta update...")

# Calculate losses using theta-dash (updated) parameters.
y1_ = torch.matmul(x1, W1) + b1
loss1_ = ((t1 - y1_)**2).mean()
_print("loss1_ = {}".format(loss1_))
y2_ = torch.matmul(x2, W2) + b2
loss2_ = ((t2 - y2_)**2).mean()
print_("loss2_ = {}".format(loss2_))

sumloss_ = loss1_ + loss2_
print_("sumloss_ = {}".format(sumloss_))

print("sumloss_ graph:")
disp.display(make_dot(sumloss_))

# Differentiate the summed post-adaptation loss with respect to theta itself.
grads_meta = torch.autograd.grad(outputs=sumloss_, inputs=(W, b), grad_outputs=torch.tensor(1.), retain_graph=False, create_graph=False)
# These gradients carry no graph (create_graph=False), so storing them in .grad
# is safe and lets print_wb report the meta-gradient below.
W.grad = grads_meta[0]
b.grad = grads_meta[1]
# Apply the outer step in place and outside autograd. W and b stay leaf
# parameters, keep their .grad, and do not accumulate graph history. Rebinding
# them to `W - lro * W.grad` would instead produce non-leaf tensors whose .grad
# reads as None.
with torch.no_grad():
    W -= lro * W.grad
    b -= lro * b.grad

print("W, b after update:")
print_wb(W, b)