# Dynamic Graphs

## "Static" declaration vs "Dynamic" declaration

The majority of DL frameworks use a "static" declaration paradigm for the definition of the computational graph. This programming paradigm requires programmers to define the network architecture once using symbolic expressions before beginning execution (running the session). Then, for a given graph and data samples, the software toolkits can automatically derive the correct algorithm for training or inference following backpropagation and auto-differentiation rules. This procedure can be described with the following pseudo-code:

<img src="files/static.png" width="600">

- execution naturally batched
- improved parallelization
- optimization of the graph at declaration time

However, in some fields we could need a dynamic NN. This dynamicity can come from multiple dimensions, like: variably sized I/O variably structured I/O, non-trivial inference algorithms. Consider this example:

<img src="files/dybamic_example.png" width="400">


The above figure shows an example of a network that takes into account this syntactic structure, generating representations for the sentence by traversing the parse tree bottom-up and combining the representations for each sub-tree using a dynamic NN called **Tree Structured Long Short-term Memory (Tree-LSTM)**. Each node of the tree maps to a LSTM function. Each node takes a variable number of inputs and returns a vector representing the parsing semantics up to that point back to the leaf node. This goes on until the root LSTM node returns a vector representing the semantics of the entire sentence.

**It’s important to observe that the NN structure varies with the underlying parsing tree over each input sample, but the same LSTM cell (i.e. the parametrization point of the model) is constant in shape and repeated at each internal node.**

Such architecture cannot be defined in a static way. Every possible sample could need a different graph. We need the dynamic declaration:

<img src="files/dynamic.png" width="600">

- batching now is not "natural"
- parallelization can be harder. In the wrost case, you must use BATCH_SIZE = 1
- more difficult to debug


## pyTorch Dynamic declaration

We have seen that the computational graph is built in the forward pass of the NN class.
Since each forward pass builds a dynamic computation graph, we can use normal Python control-flow operators like loops or conditional statements when defining the forward pass of the model. All these are perfectly legal, and will be handled correctly by autograd.

To showcase the power of PyTorch dynamic graphs, we will implement a very strange model: a fully-connected ReLU network that on each forward pass randomly chooses a number between 1 and 4 and has that many hidden layers, reusing the same weights multiple times to compute the innermost hidden layers. We will use this for MNIST classification.

In [2]:
from __future__ import print_function
import argparse
import torch
import numpy as np
from torch.nn.functional import sigmoid, relu
import random
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from livelossplot import PlotLosses

TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000
EPOCHS = 8
LOG_INTERVAL = 100   # how many batches to wait before logging training status

LR = 0.01
MOMENTUM = 0.5

class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = x.view(-1, 784)
        x = relu(self.input_linear(x))
        coin = random.randint(0, 3)   # throw a coin to choose {0,1,2,3}
        for _ in range(coin):         # add "coin" number of layers 
            x = relu(self.middle_linear(x))
        x = relu(self.output_linear(x))
        return F.log_softmax(x, dim=1)

def train(model, device, train_loader, optimizer, epoch, liveloss):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        liveloss.update({'loss': loss.item()})
        #liveloss.draw()
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

def main():


    seed = 666
    torch.manual_seed(seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Download datasets
    train_MNIST = datasets.MNIST('../data', train=True, download=True,
                                 transform=transforms.Compose([transforms.ToTensor(),
                                                               transforms.Normalize((0.1307,), (0.3081,))]))
    test_MNIST =  datasets.MNIST('../data', train=False,
                                 transform=transforms.Compose([transforms.ToTensor(),
                                                               transforms.Normalize((0.1307,), (0.3081,))]))
    # Fork into the DataLoader object
    train_loader = torch.utils.data.DataLoader(dataset=train_MNIST,batch_size=TRAIN_BATCH_SIZE, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_MNIST, batch_size=TEST_BATCH_SIZE, shuffle=True, **kwargs)

    # Define the model (remember to assign it to cuda)
    D_in, H, D_out = 784, 100, 10
    model = DynamicNet(D_in, H, D_out).to(device)
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
    
    liveloss_train = PlotLosses()
    for epoch in range(1, EPOCHS + 1):
        train(model, device, train_loader, optimizer, epoch, liveloss_train)
        test(model, device, test_loader)


if __name__ == '__main__':
    main()


Test set: Average loss: 0.5941, Accuracy: 8227/10000 (82%)


Test set: Average loss: 0.3436, Accuracy: 8998/10000 (90%)


Test set: Average loss: 0.2839, Accuracy: 9195/10000 (92%)


Test set: Average loss: 0.2371, Accuracy: 9305/10000 (93%)


Test set: Average loss: 0.2311, Accuracy: 9328/10000 (93%)


Test set: Average loss: 0.1874, Accuracy: 9452/10000 (95%)


Test set: Average loss: 0.1711, Accuracy: 9492/10000 (95%)


Test set: Average loss: 0.1728, Accuracy: 9496/10000 (95%)



Another example can be the following: let's throw a coin and create a "tree like" structure randomly, varying both in the number of layers and neurons per layer

In [4]:
from __future__ import print_function
import argparse
import torch
import numpy as np
from torch.nn.functional import sigmoid, relu
import random
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from livelossplot import PlotLosses

TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000
EPOCHS = 8
LOG_INTERVAL = 100   # how many batches to wait before logging training status

LR = 0.01
MOMENTUM = 0.5

class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H1, H2, D_out):
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H1)
        self.middle_linear1 = torch.nn.Linear(H1, H2)
        self.middle_linear2 = torch.nn.Linear(H2, H1)
        self.middle_linear3 = torch.nn.Linear(H1, H1)
        self.output_linear = torch.nn.Linear(H1, D_out)

    def forward(self, x):
        x = x.view(-1, 784)
        x = relu(self.input_linear(x))
        
        coin = random.randint(0, 3)
        if coin == 1:
            x = relu(self.middle_linear1(x))
        elif coin == 2:
            x = relu(self.middle_linear1(x))
            x = relu(self.middle_linear2(x))
        elif coin == 3:
            x = relu(self.middle_linear1(x))
            x = relu(self.middle_linear2(x))
            x = relu(self.middle_linear3(x))
        else:
            x = relu(self.output_linear(x))
        return F.log_softmax(x, dim=1)

def train(model, device, train_loader, optimizer, epoch, liveloss):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        liveloss.update({'loss': loss.item()})
        #liveloss.draw()
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

def main():


    seed = 666
    torch.manual_seed(seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Download datasets
    train_MNIST = datasets.MNIST('../data', train=True, download=True,
                                 transform=transforms.Compose([transforms.ToTensor(),
                                                               transforms.Normalize((0.1307,), (0.3081,))]))
    test_MNIST =  datasets.MNIST('../data', train=False,
                                 transform=transforms.Compose([transforms.ToTensor(),
                                                               transforms.Normalize((0.1307,), (0.3081,))]))
    # Fork into the DataLoader object
    train_loader = torch.utils.data.DataLoader(dataset=train_MNIST,batch_size=TRAIN_BATCH_SIZE, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_MNIST, batch_size=TEST_BATCH_SIZE, shuffle=True, **kwargs)

    # Define the model (remember to assign it to cuda)
    D_in, H1, H2, D_out = 784, 100, 300, 10
    model = DynamicNet(D_in, H1, H2, D_out).to(device)
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
    
    liveloss_train = PlotLosses()
    for epoch in range(1, EPOCHS + 1):
        train(model, device, train_loader, optimizer, epoch, liveloss_train)
        test(model, device, test_loader)


if __name__ == '__main__':
    main()


Test set: Average loss: 0.7379, Accuracy: 8389/10000 (84%)


Test set: Average loss: 0.7323, Accuracy: 8476/10000 (85%)


Test set: Average loss: 0.4656, Accuracy: 8864/10000 (89%)


Test set: Average loss: 0.6679, Accuracy: 8495/10000 (85%)


Test set: Average loss: 0.6558, Accuracy: 8559/10000 (86%)


Test set: Average loss: 0.5252, Accuracy: 8800/10000 (88%)



Process Process-34:
KeyboardInterrupt
Traceback (most recent call last):
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/lib/python2.7/dist-packages/torchvision/datasets/mnist.py", line 77, in __getitem__
    img = self.transform(img)
  File "/usr/local/lib/python2.7/dist-packages/torchvision/transforms/transforms.py", line 49, in __call__
    img = t(img)
  File "/usr/local/lib/python2.7/dist-packages/torchvision/transforms/transforms.py", line 76, in __call__
    return F.to_tensor(pic)
  File "/usr/local/lib/python2.7/dist-packages/torchvision/transforms/functional.py", line 70, in to_tensor
    img = torch.ByteTensor(torch.ByteStorage.

Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/ultratb.py", line 1132, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/ultratb.py", line 313, in wrapped
    return f(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/ultratb.py", line 376, in _fixed_getinnerframes
    lines = ulinecache.getlines(file)[start:end]
  File "/usr/local/lib/python2.7/dist-packages/IPython/utils/ulinecache.py", line 37, in getlines
    return [l.decode(encoding, 'replace') for l in lines]
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 227, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 11671) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.


IndexError: string index out of range

Basically, in your model definition you can go full crazy and use arbitrary python code to define your model structure. Conditional statements can be done even evaluating some input tensor propriety, like:

```
 while x.norm(2) < 10:
     x = self.conv1(x)
```
