This is an investigation into graph and state, which could be turned into a real tutorial or appended to the existing "pytorch for torchies" tutorial.

Also see https://discuss.pytorch.org/t/understanding-graphs-and-state/224

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

## State/buffer clearing during backward pass

This behaves a little differently than I expected.

This is straight from the "pytorch for torchies" tutorial.

In [None]:
x = Variable(torch.ones(2, 2), requires_grad = True)
y = x + 2
z = y * y * 3
out = z.mean()
# backward on part of the graph
z.backward(torch.range(1, 4).view(2, 2)) # x.grad[i] = grad_output[i] * 6 * (x[i] + 2), with grad_output = [1, 2, 3, 4]
print(x.grad)
# out.backward() # fails because the buffers have been freed

So we'd expect this to fail too, but it doesn't:

In [None]:
x = Variable(torch.ones(2,3), requires_grad=True)
y = x.mean(dim=1).squeeze() + 3 # size (2,)
z = y.pow(2).mean() # size 1

In [None]:
y.backward(torch.ones(2))
z.backward() # should fail! But only fails on second execution
y.backward(torch.ones(2)) # still fine, though we're calling it for the second time
z.backward() # this fails (finally!)

My guess: an error is not guaranteed to be raised on a second backward pass through part of the graph. But if we need to keep the buffers on part of the graph, we still have to pass `retain_variables=True`, because the buffers *could* have been freed.

Probably the specific simple operations for `y` (mean, add) don't need buffers for backward, while `z = y.pow(2).mean()` does need one: the backward of `y.pow(2)` is `2 * y * grad`, so it has to keep its input `y`. Correct?
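
A toy model makes this guess concrete. The sketch below is plain Python, not PyTorch's actual implementation: `Node`, `backward_fn` and `needs_saved` are hypothetical names, and all values are scalars. It assumes that a node frees its saved buffer after its first backward pass, and that only `pow` saves anything (its input `y`).

```python
class Node:
    """One operation in a toy autograd graph (not PyTorch's real implementation)."""

    def __init__(self, name, parent, backward_fn, saved=None, needs_saved=False):
        self.name = name
        self.parent = parent            # previous Node, or None for a leaf
        self.backward_fn = backward_fn  # maps (grad, saved) -> grad for the parent
        self.saved = saved              # buffer kept from the forward pass
        self.needs_saved = needs_saved  # assumption: only some ops need a buffer
        self.freed = False

    def backward(self, grad):
        node = self
        while node is not None:
            if node.needs_saved and node.freed:
                raise RuntimeError("buffer of %r was already freed" % node.name)
            grad = node.backward_fn(grad, node.saved)
            if node.needs_saved:
                # assumption: the buffer is released after its first use
                node.saved, node.freed = None, True
            node = node.parent
        return grad


x = 1.0
leaf = Node("leaf", None, lambda g, s: g)
y_val = x + 3
y_node = Node("add", leaf, lambda g, s: g)  # d(x + 3)/dx = 1, nothing to save
z_val = y_val ** 2
z_node = Node("pow", y_node, lambda g, s: g * 2 * s, saved=y_val, needs_saved=True)

grad_y_first = y_node.backward(1.0)   # no buffer involved
grad_z_first = z_node.backward(1.0)   # uses and then frees the saved y
grad_y_second = y_node.backward(1.0)  # still fine: nothing was freed on this path
try:
    z_node.backward(1.0)
    second_z_failed = False
except RuntimeError:
    second_z_failed = True            # the pow buffer is gone
```

Under these assumptions the toy reproduces the pattern from the cell above: backward through `y` can be repeated, while only the second backward through `z` raises.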

## Auxiliary loss functions on a small convnet

The next question is when a new graph (and thus new state) is allocated.

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3)
        self.conv2 = nn.Conv2d(1, 1, 3)
    def forward(self, x):
        out1 = F.relu(self.conv1(x))
        out2 = F.relu(self.conv2(out1))
        return out1, out2
net = Net()
inp = Variable(torch.randn(1,1,6,6))
inp2 = Variable(torch.randn(1,1,6,6))

In [None]:
out1, out2 = net(inp)

In [None]:
net.zero_grad()
print(net.conv1.weight.grad)
print(net.conv2.weight.grad)
out1.backward(torch.ones(1,1,4,4), retain_variables=True) # out2.backward fails without the flag, as expected
out2.backward(torch.ones(1,1,2,2))
print(net.conv1.weight.grad)
print(net.conv2.weight.grad)

Everything as expected here.

As explained in the tutorial, sending two different inputs through a net gives two different graphs, which hold different state:

In [None]:
_, out = net(inp)
_, out2 = net(inp2)
out.backward(torch.ones(1,1,2,2))
out2.backward(torch.ones(1,1,2,2))
# out.backward(torch.ones(1,1,2,2)) # fails as expected, buffers are freed

But passing the same variable twice doesn't overwrite the state in the same graph; rather, the two forward passes become separate graphs?

In [None]:
_, out = net(inp)
_, out2 = net(inp) # same input this time
out.backward(torch.ones(1,1,2,2))
out2.backward(torch.ones(1,1,2,2)) # doesn't fail -> has different state than the first forward pass?!

The problem I see with this design is that often (during testing, when you `detach()` to cut off gradients, or any time you add an extra operation just for monitoring) there is only a forward pass on part of the graph. Is that state then kept around forever, consuming more memory on every new forward pass of the same variable?

I understand that the `volatile` flag was probably introduced for this problem, and I see it is used during testing in most example code.

But I think these are some examples where there is only a forward pass, without the `volatile` flag:

+ `fake = netG(noise).detach()` to avoid bpropping through netG  https://github.com/pytorch/examples/blob/master/dcgan/main.py#L216
+ test on non-volatile variables: https://github.com/pytorch/examples/blob/master/super_resolution/main.py#L74
+ If you finetune only top layers of a feedforward net, bottom layers see only fw-passes

But in general, if I understand this design correctly, any time part of a network isn't backpropped through, you need to supply the `volatile` flag? And when you then use that intermediate volatile variable in another part of the network which is backpropped through, you need to re-wrap it and turn `volatile` off?
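
On the memory question raised above: as far as I know, a graph is owned by the variables it produced, so its buffers are released once nothing references the output any more, even if backward was never called. The sketch below mimics that ownership in plain Python with `weakref`. `Buffer`, `Output` and `forward` are hypothetical stand-ins, not PyTorch classes, and the result relies on CPython-style reference counting.

```python
import gc
import weakref


class Buffer:
    """Stands in for an intermediate result saved during the forward pass."""


class Output:
    """Stands in for a Variable that owns the graph which produced it."""

    def __init__(self, saved):
        self.saved = saved


def forward():
    buf = Buffer()
    # Return the output plus a weak reference that lets us observe the buffer
    # without keeping it alive.
    return Output(buf), weakref.ref(buf)


refs = []
for _ in range(3):
    out, ref = forward()  # rebinding `out` drops the previous output and its buffer
    refs.append(ref)
gc.collect()
alive_after_loop = [r() is not None for r in refs]

del out                   # drop the last output as well
gc.collect()
all_freed = all(r() is None for r in refs)
```

In this toy only the buffer of the most recent forward pass is still alive after the loop, and it disappears as soon as the last output is dropped. If the real graph behaves the same way, forward-only passes would not accumulate state as long as old outputs are not kept around.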

## Two Sequentials from the same list of modules

Check: if we build a second `Sequential()` from the same list of modules and feed it the same variable, do we get the same graph?

This question actually becomes irrelevant if the observation above is correct.

In [None]:
mods = [
    nn.Conv2d(1, 1, 3),
    nn.ReLU(),
    nn.Conv2d(1, 1, 3),
    nn.ReLU(),
    nn.Conv2d(1, 1, 3),
]
seq1 = nn.Sequential(*mods)
seq2 = nn.Sequential(*mods)
inp = Variable(torch.randn(1,1,7,7))

In [None]:
out1 = seq1(inp).squeeze()
out2 = seq2(inp).squeeze()
seq1.zero_grad() # seq1 and seq2 wrap the same modules, so this clears the gradients for both
out1.backward()
out2.backward() # doesn't fail, so it is a separate graph with separate state
# out1.backward() # fails
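
The observation can be mirrored in a toy model where the layers are shared but every call records its own chain of nodes. This is a plain-Python sketch under that assumption, not how `nn.Sequential` is implemented: `ToyLayer`, `ToyNode` and `ToySequential` are hypothetical names, and each "layer" just multiplies by a scalar weight.

```python
class ToyLayer:
    """Hypothetical stand-in for an nn.Module: owns one weight."""

    def __init__(self, weight):
        self.weight = weight


class ToyNode:
    """Record created by one forward call; holds that call's saved input."""

    def __init__(self, layer, saved_input, parent):
        self.layer = layer
        self.saved_input = saved_input
        self.parent = parent


class ToySequential:
    def __init__(self, *layers):
        self.layers = list(layers)

    def __call__(self, x):
        node = None
        for layer in self.layers:
            node = ToyNode(layer, x, node)  # a fresh node on every call
            x = x * layer.weight
        return x, node


mods = [ToyLayer(2.0), ToyLayer(3.0)]
seq1 = ToySequential(*mods)
seq2 = ToySequential(*mods)
out1, graph1 = seq1(1.0)
out2, graph2 = seq2(1.0)

shared_layers = all(a is b for a, b in zip(seq1.layers, seq2.layers))
separate_graphs = graph1 is not graph2
```

Here both containers point at the very same layer objects (so weights and gradients would be shared), while the per-call state lives in separate node chains. That matches what the cell above shows: both backward calls succeed once each.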