In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [17]:
# Seed PyTorch's global RNG so random draws made later in the notebook
# (e.g. torch.randint in get_batch, nn.Module weight init) repeat across runs.
torch.manual_seed(64)

<torch._C.Generator at 0x222a449c150>

In [18]:
# Utility function to compare manual and PyTorch gradients
def cmp(s, dt, t):
    """
    Compare a manually derived gradient with the one autograd stored on a tensor.

    Args:
        s (str): Label printed in the first column of the report line.
        dt (torch.Tensor): Manually computed gradient; must be broadcastable
            against ``t.grad``.
        t (torch.Tensor): Tensor whose ``.grad`` was populated by ``backward()``.

    Raises:
        ValueError: If ``t.grad`` is ``None``, i.e. ``backward()`` has not been
            run yet or ``t`` is a non-leaf tensor that never had
            ``retain_grad()`` called on it.

    Prints:
        One line with the exact-equality flag, the ``torch.allclose`` flag and
        the largest absolute element-wise difference.
    """
    grad = t.grad
    if grad is None:
        # Without this guard the comparison below degrades to `torch.all(False)`
        # and fails with an unrelated-looking TypeError.
        raise ValueError(
            f"{s}: tensor has no .grad; run backward() first and call "
            "retain_grad() on non-leaf tensors"
        )
    exact = torch.all(dt == grad).item()
    approx = torch.allclose(dt, grad)
    maxdiff = (dt - grad).abs().max().item()
    print(
        f"{s:15s} | exact: {str(exact):5s} | approximate: {str(approx):5s} | maxdiff: {maxdiff:.6f}"
    )

In [19]:
# Load Tiny Shakespeare Data
# ------------------------------------------------------------------------------------
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables: character -> integer id, and integer id -> character.
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return "".join(itos[i] for i in l)


# Train and test splits: the first 90% of the encoded text is train, the rest val.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]
# -------------------------------------------------------------------------------------

# Model Variables
# ---------------------------------------------------------------------------------
vocab_size = len(itos)
block_size = 8  # number of consecutive characters per training sequence
d_model = 24  # embedding dimension
n_hidden = 200  # width of the recurrent hidden state
batch_size = 1  # sequences drawn per training batch


# ------------------------------------------------------------------------------------
def get_batch(split):
    """
    Build input/target index tensors from one of the data splits.

    For ``split == "train"`` this draws ``batch_size`` random start offsets
    into ``train_data``. For any other value it uses every valid start offset
    of ``val_data`` in order, i.e. the whole validation set as one batch.

    Args:
        split (str): ``"train"`` for a random training batch; anything else
            selects the full validation set.

    Returns:
        tuple: ``(x, y)`` on ``device``, each with ``block_size`` columns;
        ``y`` is ``x`` shifted one position to the right in the source text.
    """
    is_train = split == "train"
    source = train_data if is_train else val_data
    last_start = len(source) - block_size
    if is_train:
        starts = torch.randint(last_start, (batch_size,))
    else:
        starts = torch.arange(last_start)

    inputs = [source[i : i + block_size] for i in starts]
    targets = [source[i + 1 : i + block_size + 1] for i in starts]
    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y

In [20]:
class CharRNN(nn.Module):
    """
    Single-layer character-level RNN with a tanh recurrence.

    For each time step t of an input sequence:

        x_t      = C[token_t]
        a_t      = U x_t + b + W h_{t-1}
        h_t      = tanh(a_t)
        logits_t = V h_t + c

    where ``h_0`` is a constant vector filled with 0.1.

    Args:
        vocab_size (int): Number of distinct tokens.
        d_model (int): Embedding dimension.
        n_hidden (int): Size of the hidden state.
        block_size (int): Sequence length; stored but not used by ``forward``,
            which takes the length from its input.
    """

    def __init__(self, vocab_size, d_model, n_hidden, block_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_hidden = n_hidden
        self.block_size = block_size
        self.C = nn.Embedding(self.vocab_size, self.d_model)
        self.U = nn.Linear(self.d_model, self.n_hidden, bias=True)
        self.W = nn.Linear(n_hidden, n_hidden, bias=False)
        self.tan = nn.Tanh()
        self.V = nn.Linear(n_hidden, vocab_size, bias=True)
        # Initial hidden state. It is a plain tensor (not a parameter or
        # buffer), so forward() moves it to the input's device on each call.
        self.h_0 = 0.1 * torch.ones(1, n_hidden)

    def forward(self, xb, targets=None):
        """
        Run the RNN over a batch of token sequences.

        Args:
            xb (torch.Tensor): Integer token ids of shape [B, T].
            targets (torch.Tensor, optional): Target ids of shape [B, T].
                When omitted, no loss is computed.

        Returns:
            tuple: ``(logits, loss, intermediate_embeddings, intermediate_tensors)``
            where ``logits`` is [B, T, vocab_size], ``loss`` is the mean
            cross-entropy over all B*T positions (or ``None``), and the two
            lists hold the per-step embeddings x_t and hidden states h_t.
        """
        B, T = xb.shape
        emb = self.C(xb)  # [B, T, d_model]
        h_t = self.h_0.to(xb.device)
        # Use the instance's hidden size, not a notebook global, so the model
        # works for whatever n_hidden it was constructed with.
        h_all = torch.zeros(B, T, self.n_hidden, device=xb.device, dtype=emb.dtype)
        intermediate_tensors = []
        intermediate_embeddings = []
        for t in range(T):
            x_t = emb[:, t, :]  # [B, d_model]
            intermediate_embeddings.append(x_t)
            a_t = self.U(x_t) + self.W(h_t)  # [B, n_hidden]
            h_t = self.tan(a_t)  # [B, n_hidden]
            h_all[:, t, :] = h_t
            # Keep the gradient of this non-leaf tensor for inspection.
            # retain_grad() raises when autograd is disabled (torch.no_grad),
            # so only request it when the tensor is part of a graph.
            if h_t.requires_grad:
                h_t.retain_grad()
            intermediate_tensors.append(h_t)
        logits = self.V(h_all)  # applied over the last dimension: [B, T, vocab_size]
        if logits.requires_grad:
            logits.retain_grad()
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
        return logits, loss, intermediate_embeddings, intermediate_tensors

    def get_parameters(self):
        """
        Return the trainable tensors in the order ``(b, c, C, U, W, V)``.

        ``b`` is the bias of ``U``, ``c`` is the bias of ``V``, and the
        remaining four are the weight matrices of the corresponding layers.
        """
        b = self.U.bias
        c = self.V.bias
        return b, c, self.C.weight, self.U.weight, self.W.weight, self.V.weight

In [21]:
# Pick the compute device: GPU when one is available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# get_batch() moves every batch to `device`, so the model's parameters must
# live on the same device or the forward pass fails with a device mismatch.
model = CharRNN(vocab_size, d_model, n_hidden, block_size).to(device)

In [22]:
# Forward pass and get PyTorch grads
# backward() accumulates into the parameters' .grad, so clear anything left
# from a previous run of this cell; otherwise re-running it would make the
# autograd gradients a sum over several batches.
model.zero_grad()
xb, yb = get_batch("train")
logits, loss, intermediate_embeddings, intermediate_tensors = model(xb, yb)
b, c, C, U, W, V = model.get_parameters()
loss.backward()

Now we manually compute the gradients (see derivations.pdf) and use the function
`cmp` to compare them to the PyTorch gradients. Because floating-point arithmetic
is not associative, the manual and autograd computations can differ by rounding
error, so the exact check (`torch.all`) may return False; as long as
`torch.allclose` returns True, we can be confident the gradients are correct.

In [23]:
with torch.no_grad():
    # Manual backpropagation through time for the model
    #     a_t = U x_t + b + W h_{t-1},  h_t = tanh(a_t),  logits_t = V h_t + c
    # with loss = mean cross-entropy over all B*T positions.
    # Written for an arbitrary batch size B (the notebook uses B = 1).
    B, T = xb.shape

    # d loss / d logits: softmax minus one-hot, scaled by the mean's 1/(B*T).
    one_hot = F.one_hot(yb, num_classes=vocab_size).float()
    dlogits = (F.softmax(logits, dim=-1) - one_hot) / (B * T)

    # Walk backwards in time. dhs and das are stored in REVERSE time order
    # (index 0 is the last step): dhs[i] = dL/dh_{T-i}, das[i] = dL/da_{T-i}.
    dhs = []
    das = []
    for t in reversed(range(T)):
        dh_t = dlogits[:, t, :] @ V  # contribution through the output layer
        if das:
            # contribution through the next step's pre-activation a_{t+1}
            dh_t = das[-1] @ W + dh_t
        dhs.append(dh_t)
        # tanh'(a_t) = 1 - h_t^2, applied element-wise
        das.append(dh_t * (1 - intermediate_tensors[t] ** 2))
    das_fwd = das[::-1]  # forward time order: das_fwd[t] = dL/da at step t

    # Hidden state feeding each step: h_0 for the first step, then h_1..h_{T-1}.
    # h_0 is a plain [1, n_hidden] tensor on the model, so move it to the
    # compute device and expand it over the batch.
    h_init = model.h_0.to(xb.device).expand(B, -1)
    h_prev = [h_init] + intermediate_tensors[:-1]

    # Biases: sum the upstream gradient over batch and time.
    dc = dlogits.sum(dim=(0, 1))
    db = torch.zeros_like(b)
    for da_t in das_fwd:
        db += da_t.sum(dim=0)

    # Weight matrices: sum over time of (upstream gradient)^T @ (layer input).
    dV = torch.zeros_like(V)
    dW = torch.zeros_like(W)
    dU = torch.zeros_like(U)
    dC = torch.zeros_like(C)
    for t in range(T):
        dV += dlogits[:, t, :].T @ intermediate_tensors[t]
        dW += das_fwd[t].T @ h_prev[t]
        dU += das_fwd[t].T @ intermediate_embeddings[t]
        # Embedding rows: scatter-add dL/dx_t into the rows that were looked
        # up; index_add_ accumulates correctly when a token id repeats.
        dC.index_add_(0, xb[:, t], das_fwd[t] @ U)

    # Compare gradients
    cmp("logits_grad", dlogits, logits)
    for i in range(T):
        string = f"h_{T - i}_grad"
        cmp(string, dhs[i], intermediate_tensors[-(i + 1)])
    cmp("c_grad", dc, c)

    cmp("b_grad", db, b)
    cmp("V_grad", dV, V)
    cmp("W_grad", dW, W)
    cmp("U_grad", dU, U)
    cmp("C_grad", dC, C)

logits_grad     | exact: False | approximate: True  | maxdiff: 0.000000
h_8_grad        | exact: False | approximate: True  | maxdiff: 0.000000
h_7_grad        | exact: False | approximate: True  | maxdiff: 0.000000
h_6_grad        | exact: False | approximate: True  | maxdiff: 0.000000
h_5_grad        | exact: False | approximate: True  | maxdiff: 0.000000
h_4_grad        | exact: False | approximate: True  | maxdiff: 0.000000
h_3_grad        | exact: False | approximate: True  | maxdiff: 0.000000
h_2_grad        | exact: False | approximate: True  | maxdiff: 0.000000
h_1_grad        | exact: False | approximate: True  | maxdiff: 0.000000
c_grad          | exact: False | approximate: True  | maxdiff: 0.000000
b_grad          | exact: False | approximate: True  | maxdiff: 0.000000
V_grad          | exact: False | approximate: True  | maxdiff: 0.000000
W_grad          | exact: False | approximate: True  | maxdiff: 0.000000
U_grad          | exact: False | approximate: True  | maxdiff: 0