In [1]:
import gc
import resource
import torch
from torch.autograd import Variable, backward
from torch.optim import SGD
import time

def register_grad_hook(var):
    """Record the gradient arriving at ``var`` on ``var.recoding_grad``.

    A backward hook is registered on ``var``; whenever autograd computes a
    gradient for ``var`` the hook stores it as the attribute
    ``recoding_grad`` on that same tensor (overwriting any earlier value).

    Note that the gradient recorded is the one with respect to exactly the
    tensor passed in. If ``var`` is the result of an operation (a non-leaf),
    that is the gradient w.r.t. the result, not w.r.t. the operation's inputs.

    Args:
        var: Tensor that takes part in a graph requiring gradients.

    Returns:
        The handle returned by ``var.register_hook``. Call ``handle.remove()``
        to unregister the hook once the gradient has been read. Previously
        the handle was discarded, so the hook could never be removed.
    """
    def hook(grad):
        # The closure keeps a reference to `var`, so the hook (and through it
        # `var`) stays referenced for as long as the hook is registered.
        var.recoding_grad = grad

    return var.register_hook(hook)
    
def get_mem():
    """Return the peak resident set size of this process in kilobytes.

    A garbage collection is forced first so that unreachable Python objects
    are collected before the measurement is taken.

    The value comes from ``ru_maxrss``, which is a high-water mark: it can
    only stay equal or grow over the lifetime of the process, so differences
    between two calls show growth of the peak, never a release of memory.

    Returns:
        float: peak RSS in KB.
    """
    import sys  # local import: only needed for the platform check below

    gc.collect()
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux.
    # Dividing unconditionally by 1024 gave megabytes on Linux although every
    # caller labels the number "KB".
    if sys.platform == "darwin":
        return peak / 1024
    return float(peak)

In [2]:
SEED = 0  # Fixed RNG seed so a fresh Restart & Run All reproduces the same tensors
BREAK = 10  # End after this many iterations
T = 10  # Time steps
b = 10  # Batch size
HIDDEN_SIZE = 10  # Width of the hidden state and of the square weight matrices
step = 0.5  # Step size of the recoding update applied to h

# Seed before drawing anything random; without this every run compares
# different tensors and the recorded outputs cannot be reproduced.
torch.manual_seed(SEED)

h_init = torch.rand(b, HIDDEN_SIZE)  # Initial hidden state (no gradient needed)
# Leaf parameters. Passing requires_grad=True at creation replaces the
# deprecated Variable(...) wrapper and yields the same kind of tensor.
W_init = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, requires_grad=True)  # "RNN" weights
R_init = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, requires_grad=True)  # "Recoding" weights
O_init = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, requires_grad=True)  # "Output" weights

In [3]:
# +++ Old version +++
# Use an actual optimizer to perform the recoding gradient step.
# Serves as the reference: the gradients collected in `old_gradients` are
# compared against the hook-based versions further down.

# Work on copies so the *_init tensors can be reused by the other versions.
h = h_init.clone()
W, R, O = W_init.clone(), R_init.clone(), O_init.clone()
prev_mem_used = get_mem()
old_gradients = []  # One recoding gradient (w.r.t. h) per time step
i = 0  # Number of completed outer iterations

while True:
    
    for t in range(T):
        
        # Re-wrap h as a fresh tensor that requires grad, and build a new
        # SGD optimizer over just this tensor for the current time step.
        h = Variable(h, requires_grad=True)
        optim_h = SGD([h], lr=step)
        optim_h.zero_grad()
        
        # Keep the result under a separate name so `h` still refers to the
        # tensor handed to the optimizer.
        h_prime = h @ W  # Apply "RNN"

        # "Recoding"
        delta = h_prime @ R
        # Backpropagate with an all-ones upstream gradient (same shape as
        # delta); this fills h.grad.
        backward(delta, grad_tensors=torch.ones(delta.shape))
        old_gradients.append(h.grad)
        # SGD step with lr=step on h, using the gradient just computed.
        optim_h.step()

        # Detach so recoding gradients only depend on stuff happening at
        # the current time step
        h = h.detach()
        
    # Main backward pass
    out = h @ O  # Predict "output"
    loss = out.sum()
    loss.backward()
    
    i += 1

    # Print how much the value reported by get_mem() grew during this outer
    # iteration. `old_gradients` gains T tensors per iteration, so some
    # growth is expected even without a leak.
    current_mem_used = get_mem()
    print("Add. mem allocated: + {:.2f} KB".format(current_mem_used - prev_mem_used))
    prev_mem_used = current_mem_used

    # Stop after BREAK outer iterations.
    if i >= BREAK:
        break

Add. mem allocated: + 564.00 KB
Add. mem allocated: + 20.00 KB
Add. mem allocated: + 16.00 KB
Add. mem allocated: + 8.00 KB
Add. mem allocated: + 12.00 KB
Add. mem allocated: + 12.00 KB
Add. mem allocated: + 8.00 KB
Add. mem allocated: + 16.00 KB
Add. mem allocated: + 12.00 KB
Add. mem allocated: + 4.00 KB


In [4]:
# +++ New version after refactoring (kept as the FAULTY example) +++
# Use a gradient hook to store the gradient on a special attribute
# (`recoding_grad`) and perform the update step manually instead of via SGD.
# Looks cleaner, but the recorded output below shows noticeably larger
# per-iteration memory growth than the other versions, and the collected
# gradients differ from the reference ones.

# Work on copies so the *_init tensors can be reused by the other versions.
h = h_init.clone()
W, R, O = W_init.clone(), R_init.clone(), O_init.clone()
prev_mem_used = get_mem()
faulty_gradients = []  # One recorded gradient per time step
i = 0  # Number of completed outer iterations

while True:
    
    for t in range(T):
        # Dummy "RNN" logic
        # h is first wrapped as a fresh tensor requiring grad ...
        h = Variable(h, requires_grad=True)
        # ... and then immediately REBOUND to the result of h @ W. This
        # reassignment is the problem: from here on `h` names the output of
        # the matmul, and the tensor created on the previous line is no
        # longer reachable by name.
        h = h @ W  # Apply "RNN"

        # "Recoding"
        delta = h @ R
        # The hook is therefore registered on the matmul result, so
        # `recoding_grad` is the gradient w.r.t. the post-"RNN" state rather
        # than w.r.t. the state that entered the "RNN".
        # NOTE(review): the extra memory growth presumably comes from the
        # hook closure capturing the non-leaf tensor it is attached to --
        # TODO confirm.
        register_grad_hook(h)
        # Backpropagate with an all-ones upstream gradient (same shape as delta).
        backward(delta, grad_tensors=torch.ones(delta.shape))
        faulty_gradients.append(h.recoding_grad)
        h = h - step * h.recoding_grad  # Manual update step
        
        # Detach so recoding gradients only depend on stuff happening at
        # the current time step
        h = h.detach()
        
    # Main backward pass
    out = h @ O  # Predict "output"
    loss = out.sum()
    loss.backward()
    
    i += 1

    # Print how much the value reported by get_mem() grew during this outer
    # iteration. `faulty_gradients` gains T tensors per iteration, so some
    # growth is expected even without a leak.
    current_mem_used = get_mem()
    print("Add. mem allocated: +{:.2f} KB".format(current_mem_used - prev_mem_used))
    prev_mem_used = current_mem_used
        
    # Stop after BREAK outer iterations.
    if i >= BREAK:
        break

Add. mem allocated: +156.00 KB
Add. mem allocated: +52.00 KB
Add. mem allocated: +56.00 KB
Add. mem allocated: +52.00 KB
Add. mem allocated: +56.00 KB
Add. mem allocated: +60.00 KB
Add. mem allocated: +52.00 KB
Add. mem allocated: +44.00 KB
Add. mem allocated: +64.00 KB
Add. mem allocated: +60.00 KB


In [5]:
# So what was the difference here? In the faulty version we assigned
# h = h @ W back to the same variable. The hook was then registered on the
# result of the matmul instead of on the tensor that entered it, so the
# recorded gradient is taken w.r.t. a different tensor, and the run shows a
# memory leak (the code still runs but produces faulty gradients and,
# according to the original note, at some point an OOM exception).
# If we did this with the old version, h.grad would simply be None.

# Work on copies so the *_init tensors can be reused by the other versions.
h = h_init.clone()
W, R, O = W_init.clone(), R_init.clone(), O_init.clone()
prev_mem_used = get_mem()
new_gradients = []  # One recoding gradient (w.r.t. h) per time step
i = 0  # Number of completed outer iterations

while True:
    
    for t in range(T):
        # Dummy "RNN" logic
        # Wrap h as a fresh tensor requiring grad; unlike in the faulty
        # version, this name is NOT rebound before the hook is registered.
        h = Variable(h, requires_grad=True)
        
        # +++ Only difference! Call this new variable h_prime
        # and don't reassign to h +++
        h_prime = h @ W  # Apply "RNN"

        # "Recoding"
        delta = h_prime @ R
        # The hook sits on the tensor that entered the "RNN", so
        # `recoding_grad` is the gradient w.r.t. that tensor.
        register_grad_hook(h)
        # Backpropagate with an all-ones upstream gradient (same shape as delta).
        backward(delta, grad_tensors=torch.ones(delta.shape))
        new_gradients.append(h.recoding_grad)
        h = h - step * h.recoding_grad  # Manual update step
        
        # Detach so recoding gradients only depend on stuff happening at
        # the current time step
        h = h.detach()
        
    # Main backward pass
    out = h @ O  # Predict "output"
    loss = out.sum()
    loss.backward()
    
    i += 1

    # Print how much the value reported by get_mem() grew during this outer
    # iteration. `new_gradients` gains T tensors per iteration, so some
    # growth is expected even without a leak.
    current_mem_used = get_mem()
    print("Add. mem allocated: +{:.2f} KB".format(current_mem_used - prev_mem_used))
    prev_mem_used = current_mem_used
        
    # Stop after BREAK outer iterations.
    if i >= BREAK:
        break

Add. mem allocated: +44.00 KB
Add. mem allocated: +12.00 KB
Add. mem allocated: +20.00 KB
Add. mem allocated: +16.00 KB
Add. mem allocated: +8.00 KB
Add. mem allocated: +20.00 KB
Add. mem allocated: +0.00 KB
Add. mem allocated: +12.00 KB
Add. mem allocated: +12.00 KB
Add. mem allocated: +8.00 KB


In [6]:
# Compare the gradients collected by the three versions.
# Each list is stacked into a single tensor; the printed number is the sum of
# squared element-wise differences against the optimizer-based reference
# (0.0 means the gradients are identical).
stacked_old = torch.stack(old_gradients)
stacked_new = torch.stack(new_gradients)
stacked_faulty = torch.stack(faulty_gradients)

sq_err_faulty = (stacked_old - stacked_faulty).pow(2).sum()
sq_err_new = (stacked_old - stacked_new).pow(2).sum()

print("Diff old - faulty: ", sq_err_faulty.item())
print("Diff old - new: ", sq_err_new.item())

Diff old - faulty:  2157421.25
Diff old - new:  0.0
