# Part 4: More advanced networks

__Before starting, we recommend you enable GPU acceleration if you're running on Colab.__

In [16]:
# Execute this code block to install dependencies when running on colab
try:
    import torch
except:
    from os.path import exists
    from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
    accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

    !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision

try:
    import torchbearer
except:
    !pip install torchbearer

Recent network models, such as the deep residual network (ResNet) and GoogLeNet architectures, do not follow a straight path from input to output. Instead, these models incorporate branches and merges to create a computation graph. Branching and merging is easy to implement in PyTorch as shown in the following code snippet:

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

class BranchModel(nn.Module):
    def __init__(self):
        super(BranchModel, self).__init__()
        self.left  = nn.Conv2d(1, 16, (1, 1), padding=0)
        self.right = nn.Conv2d(1, 16, (5, 5), padding=2)
        self.fc1 = nn.Linear(16*14*14, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        out_l = self.left(x)
        out_l = F.relu(out_l)

        out_r = self.right(x)
        out_r = F.relu(out_r)

        out = out_l + out_r

        out = F.max_pool2d(out, (2,2))
        out = F.dropout(out, 0.2)
        out = out.view(out.shape[0], -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)

        return out

This defines a variant of our initial simple CNN model in which the input is split into two paths and then merged again; the left hand path consists of a 1x1 convolution layer, whilst the right-hand path has a 5x5 convolutional layer. The 1x1 convolutions will have the effect of increasing the number of bands in the input from 1 to 16 (with each band a (potentially different) scalar multiple of the input). Padding is used to ensure the feature maps have the same shape on the left and right branches. In this case the left and right branches are merged by summing them together (element-wise, layer by layer).

__Use the code block below to train and evaluate the above model.__

In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchbearer
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchbearer import Trial

seed = 7
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# convert each image to tensor format
transform = transforms.Compose([
    transforms.ToTensor()  # convert to tensor
])

# load data
trainset = MNIST(".", train=True, download=True, transform=transform)
testset = MNIST(".", train=False, download=True, transform=transform)

# create data loaders
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=128, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 14837933.37it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 101623904.21it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 3761883.30it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 20662178.71it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [None]:
print(trainloader.batch_sampler.sampler.)

AttributeError: ignored

In [None]:
# build the model
model = BranchModel()

# define the loss function and the optimiser
loss_function = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters())

device = "cuda:0" if torch.cuda.is_available() else "cpu"
trial = Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy']).to(device)
trial.with_generators(trainloader, test_generator=testloader)
trial.run(epochs=10)
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

# {'test_loss': 0.05024200305342674, 'test_acc': 0.9830999970436096}

0/10(t):   0%|          | 0/469 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

## Going further

None of the network topology we have experimented with thus far are optimised. Nor are they reproductions of network topologies from recent papers.

__There is a lot of opportunity for you to tune and improve upon these models. What is the best error rate score you can achieve?__


In [22]:
import torch
import torch.nn.functional as F
from torch import nn

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchbearer
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchbearer import Trial

import torch
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch import nn
import torchbearer

class MyDataset(Dataset):
  def __init__(self, size=5000, dim=40, random_offset=0):
        super(MyDataset, self).__init__()
        self.size = size
        self.dim = dim
        self.random_offset = random_offset

  def __getitem__(self, index):
      if index >= len(self):
          raise IndexError("{} index out of range".format(self.__class__.__name__))

      rng_state = torch.get_rng_state()
      torch.manual_seed(index + self.random_offset)

      while True:
        img = torch.zeros(self.dim, self.dim)
        dx = torch.randint(-10,10,(1,),dtype=torch.float)
        dy = torch.randint(-10,10,(1,),dtype=torch.float)
        c = torch.randint(-20,20,(1,), dtype=torch.float)

        params = torch.cat((dy/dx, c))
        xy = torch.randint(0,img.shape[1], (20, 2), dtype=torch.float)
        xy[:,1] = xy[:,0] * params[0] + params[1]

        xy.round_()
        xy = xy[ xy[:,1] > 0 ]
        xy = xy[ xy[:,1] < self.dim ]
        xy = xy[ xy[:,0] < self.dim ]

        for i in range(xy.shape[0]):
          x, y = xy[i][0], self.dim - xy[i][1]
          img[int(y), int(x)]=1
        if img.sum() > 2:
          break

      torch.set_rng_state(rng_state)
      return img.unsqueeze(0), params

  def __len__(self):
      return self.size

train_data = MyDataset()
val_data = MyDataset(size=500, random_offset=33333)
test_data = MyDataset(size=500, random_offset=99999)

In [24]:
trainloader = DataLoader(train_data, batch_size=128, shuffle=True)
testloader = DataLoader(test_data, batch_size=128, shuffle=True)
valloader = DataLoader(val_data, batch_size=128, shuffle=True, )

print(trainloader.dataset[0][0])


tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])


In [47]:
from torch.cuda import OutOfMemoryError
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 48, (3, 3), padding=1, stride=1)
        self.conv2 = nn.Conv2d(48, 48, (3, 3), padding=1, stride=1)
        self.maxpool2d = nn.AdaptiveMaxPool2d((1,1))
        self.fc1 = nn.Linear(48, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = self.conv2(out)
        out = F.relu(out)
        out = self.maxpool2d(out)
        out = out.view(out.shape[0], -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out

# build the model
model = CNNModel()

# define the loss function and the optimiser
loss_function = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters())

device = "cuda:0" if torch.cuda.is_available() else "cpu"
trial = Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy']).to(device)
trial.with_generators(trainloader, test_generator=testloader, val_generator=valloader)
trial.run(epochs=100)
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

#1:  {'test_loss': -27173800.0, 'test_acc': 0.628000020980835}

0/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

0/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

1/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

1/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

2/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

2/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

3/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

3/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

4/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

4/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

5/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

5/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

6/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

6/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

7/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

7/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

8/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

8/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

9/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

9/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

10/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

10/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

11/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

11/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

12/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

12/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

13/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

13/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

14/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

14/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

15/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

15/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

16/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

16/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

17/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

17/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

18/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

18/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

19/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

19/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

20/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

20/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

21/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

21/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

22/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

22/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

23/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

23/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

24/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

24/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

25/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

25/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

26/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

26/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

27/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

27/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

28/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

28/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

29/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

29/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

30/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

30/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

31/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

31/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

32/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

32/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

33/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

33/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

34/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

34/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

35/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

35/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

36/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

36/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

37/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

37/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

38/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

38/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

39/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

39/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

40/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

40/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

41/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

41/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

42/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

42/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

43/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

43/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

44/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

KeyboardInterrupt: ignored