In [28]:
# Simplified example

import gc
import resource
import torch
from torch.autograd import Variable, backward
from torch.optim import SGD

i = 0  # Iterations
T = 5  # Time steps
b = 3  # Batch size
step = 0.5
h = torch.rand(b, 10)
W = Variable(torch.randn(10, 10), requires_grad=True)
R = Variable(torch.randn(10, 10), requires_grad=True)
O = Variable(torch.randn(10, 10), requires_grad=True)

def register_grad_hook(var):
    def hook(grad):
        var.recoding_grad = grad

    var.register_hook(hook)
    
def get_mem():
    gc.collect()
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024

#optim = SGD(params=[W, R, O], lr=0.1)
prev_mem_used = get_mem()

while True:
    
    for t in range(T):
        # Dummy "RNN" logic
        # Placing this line here causes the memory spill
        h = Variable(h, requires_grad=True)
        h = h @ W  # Apply "RNN"
        
        # Placing this line here results in (almost) no new memory being 
        # allocated
        # But this doesn't make sense - we need to record operations that
        # happen during the RNN logic
        #h = Variable(h.data, requires_grad=True)
        register_grad_hook(h)

        # "Recoding"
        delta = h @ R
        backward(delta, grad_tensors=torch.ones(delta.shape))
        h = h - step * h.recoding_grad
        # Detach so recoding gradients only depend on stuff happening at
        # the current time step
        h = h.detach()
        
    # Main backward pass
    out = h @ O  # Predict "output"
    loss = out.sum()
    #print("Loss: ", loss)
    loss.backward()
    #optim.step()
    #optim.zero_grad()
    
    i += 1

    if i % 1000 == 0:
        current_mem_used = get_mem()
        print("Add. mem allocated: {:.2f} MB".format(current_mem_used - prev_mem_used))
        prev_mem_used = current_mem_used

Add. mem allocated: 26072.00 MB
Add. mem allocated: 26200.00 MB
Add. mem allocated: 26184.00 MB
Add. mem allocated: 26200.00 MB
Add. mem allocated: 26228.00 MB


KeyboardInterrupt: 