In [1]:
!sudo nvidia-persistenced
!sudo nvidia-smi -ac 877,1530

Applications clocks set to "(MEM 877, SM 1530)" for GPU 00000000:00:1E.0
All done.


In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container {width:95% !important;}</style>"))

from core import *
from torch_backend import *

colors = ColorMap()
draw = lambda graph: display(DotGraph({p: ({'fillcolor': colors[type(v)], 'tooltip': repr(v)}, inputs) for p, (v, inputs) in graph.items() if v is not None}))

### Network definitions

In [2]:
batch_norm = partial(BatchNorm, weight_init=None, bias_init=None)

def res_block(c_in, c_out, stride, **kw):
    block = {
        'bn1': batch_norm(c_in, **kw),
        'relu1': nn.ReLU(True),
        'branch': {
            'conv1': nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
            'bn2': batch_norm(c_out, **kw),
            'relu2': nn.ReLU(True),
            'conv2': nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1, bias=False),
        }
    }
    projection = (stride != 1) or (c_in != c_out)    
    if projection:
        block['conv3'] = (nn.Conv2d(c_in, c_out, kernel_size=1, stride=stride, padding=0, bias=False), ['relu1'])
    block['add'] =  (Add(), [('conv3' if projection else 'relu1'), 'branch/conv2'])
    return block

def DAWN_net(c=64, block=res_block, prep_bn_relu=False, concat_pool=True, **kw):    
    if isinstance(c, int):
        c = [c, 2*c, 4*c, 4*c]
        
    classifier_pool = {
        'in': Identity(),
        'maxpool': nn.MaxPool2d(4),
        'avgpool': (nn.AvgPool2d(4), ['in']),
        'concat': (Concat(), ['maxpool', 'avgpool']),
    } if concat_pool else {'pool': nn.MaxPool2d(4)}
    
    return {
        'input': (None, []),
        'prep': union({'conv': nn.Conv2d(3, c[0], kernel_size=3, stride=1, padding=1, bias=False)},
                      {'bn': batch_norm(c[0], **kw), 'relu': nn.ReLU(True)} if prep_bn_relu else {}),
        'layer1': {
            'block0': block(c[0], c[0], 1, **kw),
            'block1': block(c[0], c[0], 1, **kw),
        },
        'layer2': {
            'block0': block(c[0], c[1], 2, **kw),
            'block1': block(c[1], c[1], 1, **kw),
        },
        'layer3': {
            'block0': block(c[1], c[2], 2, **kw),
            'block1': block(c[2], c[2], 1, **kw),
        },
        'layer4': {
            'block0': block(c[2], c[3], 2, **kw),
            'block1': block(c[3], c[3], 1, **kw),
        },
        'final': union(classifier_pool, {
            'flatten': Flatten(),
            'linear': nn.Linear(2*c[3] if concat_pool else c[3], 10, bias=True),
        }),
        'logits': Identity(),
    }


def conv_bn(c_in, c_out, bn_weight_init=1.0, **kw):
    return {
        'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False), 
        'bn': batch_norm(c_out, bn_weight_init=bn_weight_init, **kw), 
        'relu': nn.ReLU(True)
    }

def basic_net(channels, weight,  pool, **kw):
    return {
        'input': (None, []),
        'prep': conv_bn(3, channels['prep'], **kw),
        'layer1': dict(conv_bn(channels['prep'], channels['layer1'], **kw), pool=pool),
        'layer2': dict(conv_bn(channels['layer1'], channels['layer2'], **kw), pool=pool),
        'layer3': dict(conv_bn(channels['layer2'], channels['layer3'], **kw), pool=pool),
        'pool': nn.MaxPool2d(4),
        'flatten': Flatten(),
        'linear': nn.Linear(channels['layer3'], 10, bias=False),
        'logits': Mul(weight),
    }

def net(channels=None, weight=0.125, pool=nn.MaxPool2d(2), extra_layers=(), res_layers=('layer1', 'layer3'), **kw):
    channels = channels or {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}
    residual = lambda c, **kw: {'in': Identity(), 'res1': conv_bn(c, c, **kw), 'res2': conv_bn(c, c, **kw), 
                                'add': (Add(), ['in', 'res2/relu'])}
    n = basic_net(channels, weight, pool, **kw)
    for layer in res_layers:
        n[layer]['residual'] = residual(channels[layer], **kw)
    for layer in extra_layers:
        n[layer]['extra'] = conv_bn(channels[layer], channels[layer], **kw)       
    return n

remove_identity_nodes = lambda net: remove_by_type(net, Identity)

### Download and preprocess data

In [3]:
DATA_DIR = './data'
dataset = cifar10(DATA_DIR)
timer = Timer()
print('Preprocessing training data')
transforms = [
    partial(normalise, mean=np.array(cifar10_mean, dtype=np.float32), std=np.array(cifar10_std, dtype=np.float32)),
    partial(transpose, source='NHWC', target='NCHW'), 
]
train_set = list(zip(*preprocess(dataset['train'], [partial(pad, border=4)] + transforms).values()))
print(f'Finished in {timer():.2} seconds')
print('Preprocessing test data')
test_set = list(zip(*preprocess(dataset['valid'], transforms).values()))
print(f'Finished in {timer():.2} seconds')

Files already downloaded and verified
Files already downloaded and verified
Preprocessing training data
Finished in 2.4 seconds
Preprocessing test data
Finished in 0.063 seconds


### Training loop

In [4]:
def train(model, lr_schedule, train_set, test_set, batch_size, num_workers=0):
    train_batches = DataLoader(train_set, batch_size, shuffle=True, set_random_choices=True, num_workers=num_workers)
    test_batches = DataLoader(test_set, batch_size, shuffle=False, num_workers=num_workers)
    
    lr = lambda step: lr_schedule(step/len(train_batches))/batch_size
    opts = [SGD(trainable_params(model).values(), {'lr': lr, 'weight_decay': Const(5e-4*batch_size), 'momentum': Const(0.9)})]
    logs, state = Table(), {MODEL: model, LOSS: x_ent_loss, OPTS: opts}
    for epoch in range(lr_schedule.knots[-1]):
        logs.append(union({'epoch': epoch+1, 'lr': lr_schedule(epoch+1)}, 
                          train_epoch(state, Timer(torch.cuda.synchronize), train_batches, test_batches)))
    return logs

### [Post 1: Baseline](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_1/) - DAWNbench baseline + no initial bn-relu+ efficient dataloading/augmentation, 1 dataloader process (301s)

In [None]:
lr_schedule = PiecewiseLinear([0, 15, 30, 35], [0, 0.1, 0.005, 0])
batch_size = 128

n = DAWN_net()
draw(build_graph(n))
model = Network(n).to(device)
#convert all children including batch norms to half precision (triggering slow codepath!)
for v in model.children(): 
    v.half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR()])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=1)

FileNotFoundError: [Errno 2] "dot" not found in path.

<core.DotGraph at 0x7f3e203bd910>

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:1026.)
  dw.add_(weight_decay, w).mul_(-lr)
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16]

       epoch           lr   train time   train loss    train acc   valid time   valid loss    valid acc   total time
           1       0.0067       8.9709       1.6573       0.3901       0.4978       1.1870       0.5663       8.9709


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a016400000004'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

           2       0.0133       6.8396       1.0298       0.6308       0.4474       1.0531       0.6460       6.8396


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a016600000006'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

           3       0.0200       6.8876       0.7743       0.7298       0.4466       1.0056       0.6826       6.8876


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a016800000008'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

           4       0.0267       6.8439       0.6491       0.7772       0.4492       0.6657       0.7764       6.8439


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a016a0000000a'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

           5       0.0333       6.8630       0.5722       0.8046       0.4543       0.6490       0.7908       6.8630


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a016c0000000c'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

           6       0.0400       6.8963       0.5106       0.8254       0.4422       0.6044       0.8065       6.8963


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a016e0000000e'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

           7       0.0467       6.8388       0.4645       0.8406       0.4455       0.5552       0.8189       6.8388


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a017000000010'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

           8       0.0533       6.8134       0.4280       0.8544       0.4451       0.5107       0.8319       6.8134


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a017200000012'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

           9       0.0600       6.8321       0.4010       0.8613       0.4467       0.4965       0.8285       6.8321


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a017400000014'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          10       0.0667       6.9679       0.3832       0.8676       0.4456       0.4663       0.8431       6.9679


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a017600000016'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          11       0.0733       6.8227       0.3683       0.8732       0.4425       0.4729       0.8412       6.8227


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a017800000018'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          12       0.0800       6.8363       0.3581       0.8762       0.4466       0.5066       0.8304       6.8363


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a017a0000001a'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          13       0.0867       6.8269       0.3525       0.8784       0.4457       0.4914       0.8323       6.8269


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a017c0000001c'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          14       0.0933       6.8988       0.3501       0.8800       0.4520       0.4686       0.8436       6.8988


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a017e0000001e'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          15       0.1000       6.8701       0.3583       0.8776       0.4480       0.5240       0.8269       6.8701


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a018000000020'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          16       0.0937       6.8892       0.3518       0.8787       0.4576       0.4120       0.8607       6.8892


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a018200000022'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          17       0.0873       6.8307       0.3279       0.8869       0.4495       0.5772       0.8096       6.8307


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a018400000024'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          18       0.0810       6.8793       0.3173       0.8917       0.4494       0.4127       0.8592       6.8793


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a018600000026'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          19       0.0747       6.8353       0.2935       0.8973       0.4469       0.4140       0.8626       6.8353


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a018800000028'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          20       0.0683       6.8441       0.2807       0.9031       0.4492       0.4592       0.8500       6.8441


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a018a0000002a'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          21       0.0620       6.8191       0.2605       0.9099       0.4498       0.3379       0.8897       6.8191


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a018c0000002c'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          22       0.0557       6.8111       0.2362       0.9186       0.4515       0.3449       0.8857       6.8111


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a018e0000002e'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          23       0.0493       6.8286       0.2213       0.9229       0.4496       0.3679       0.8807       6.8286


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a019000000030'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          24       0.0430       6.8143       0.2050       0.9287       0.4480       0.3299       0.8920       6.8143


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a019200000032'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          25       0.0367       6.8170       0.1801       0.9363       0.4483       0.2686       0.9113       6.8170


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a019400000034'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          26       0.0303       6.8315       0.1532       0.9467       0.4439       0.3119       0.8954       6.8315


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a019600000036'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          27       0.0240       6.9069       0.1287       0.9553       0.4467       0.2422       0.9211       6.9069


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a019800000038'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          28       0.0177       6.8732       0.0985       0.9655       0.4449       0.2380       0.9271       6.8732


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a019a0000003a'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", li

          29       0.0113       6.8898       0.0733       0.9757       0.4456       0.2132       0.9330       6.8898


Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.8/shutil.py", line 718, in rmtree
    _rmtree_safe_fd(fd, path, onerror)
  File "/opt/conda/lib/python3.8/shutil.py", line 675, in _rmtree_safe_fd
    onerror(os.unlink, fullname, sys.exc_info())
  File "/opt/conda/lib/python3.8/shutil.py", line 673, in _rmtree_safe_fd
    os.unlink(entry.name, dir_fd=topfd)
OSError: [Errno 16] Device or resource busy: '.nfs000000000a7a019c0000003c'


### [Post 1: Baseline](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_1/) - 0 dataloader processes (297s)

In [None]:
lr_schedule = PiecewiseLinear([0, 15, 30, 35], [0, 0.1, 0.005, 0])
batch_size = 128

n = DAWN_net()
draw(build_graph(n))
model = Network(n).to(device)
#convert all children including batch norms to half precision (triggering slow codepath!)
for v in model.children(): 
    v.half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR()])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 2: Mini-batches](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_2/) - batch size=512 (256s)

In [None]:
lr_schedule = PiecewiseLinear([0, 15, 30, 35], [0, 0.44, 0.005, 0])
batch_size = 512

n = DAWN_net()
draw(build_graph(n))
model = Network(n).to(device)
#convert all children including batch norms to half precision (triggering slow codepath!)
for v in model.children(): 
    v.half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR()])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 3: Regularisation](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_3/) - speed up batch norms (186s)

In [None]:
lr_schedule = PiecewiseLinear([0, 15, 30, 35], [0, 0.44, 0.005, 0])
batch_size = 512

n = DAWN_net()
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR()])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 3: Regularisation](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_3/) - cutout+30 epochs+batch_size=512 (161s)

In [None]:
lr_schedule = PiecewiseLinear([0, 8, 30], [0, 0.4, 0])
batch_size = 512

n = DAWN_net()
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 3: Regularisation](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_3/)  - batch_size=768 (154s)

In [None]:
lr_schedule = PiecewiseLinear([0, 8, 30], [0, 0.6, 0])
batch_size = 768

n = DAWN_net()
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - backbone (36s; test acc 55.9%)

It seems reasonable to study how the shortest path through the network trains in isolation and to take steps to improve this before adding back the longer branches. 
Eliminating the long branches yields the following backbone network in which all convolutions, except for the initial one, have a stride of two.

Training the shortest path network for 20 epochs yields an unimpressive test accuracy of 55.9% in 36 seconds.

In [None]:
def shortcut_block(c_in, c_out, stride, **kw):
    block = {
        'bn1': batch_norm(c_in, **kw),
        'relu1': nn.ReLU(True),
    }
    projection = (stride != 1) or (c_in != c_out)    
    if projection:
        block['conv3'] = (nn.Conv2d(c_in, c_out, kernel_size=1, stride=stride, padding=0, bias=False), ['relu1'])
    return block

lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])
batch_size = 512

n = DAWN_net(block=shortcut_block)
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - backbone, remove repeat bn-relu (32s; test acc 56.0%)

Removing the repeated batch norm-ReLU groups, reduces training time to 32s and leaves test accuracy approximately unchanged.

In [None]:
def shortcut_block(c_in, c_out, stride, **kw):
    projection = (stride != 1) or (c_in != c_out)
    if projection:
        return {
            'conv':  nn.Conv2d(c_in, c_out, kernel_size=1, stride=stride, padding=0, bias=False), 
            'bn': batch_norm(c_out, **kw),
            'relu': nn.ReLU(True),
        }
    else:
        return {'id': Identity()}

lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])
batch_size = 512

n = DAWN_net(block=shortcut_block, prep_bn_relu=True)
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - backbone, 3x3 convs (36s; test acc 85.6%)

A serious shortcoming of this network is that the downsampling convolutions have 1x1 kernels and a stride of two, so that rather than enlarging the receptive field they are simply discarding information. 

If we replace these with 3x3 convolutions, things improve considerably and test accuracy after 20 epochs is 85.6% in a time of 36s.

In [None]:
def shortcut_block(c_in, c_out, stride, **kw):
    projection = (stride != 1) or (c_in != c_out)
    if projection:
        return {
            'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False), 
            'bn': batch_norm(c_out, **kw),
            'relu': nn.ReLU(True),
        }
    else:
        return {'id': Identity()}

lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])
batch_size = 512

n = DAWN_net(block=shortcut_block, prep_bn_relu=True)
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - backbone, maxpool downsampling (43s; test acc 89.7%)

We can further improve the downsampling stages by applying 3x3 convolutions of stride one followed by a pooling layer instead of using strided convolutions. 

We choose max pooling with a 2x2 window size leading to a final test accuracy of 89.7% after 43s. Using average pooling gives a similar result but takes slightly longer.

In [None]:
def shortcut_block(c_in, c_out, stride, **kw):
    projection = (stride != 1) or (c_in != c_out)
    if projection:
        return {
            'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False), 
            'bn': batch_norm(c_out, **kw),
            'relu': nn.ReLU(True),
            'pool': nn.MaxPool2d(2),
        }
    else:
        return {'id': Identity()}

lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])
batch_size = 512

n = DAWN_net(block=shortcut_block, prep_bn_relu=True)
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - backbone, 2x output dim, global maxpool (47s; test acc 90.7%)

The final pooling layer before the classifier is a concatenation of global average pooling and max pooling layers, inherited from the original network. 

We replace this with a more standard global max pooling layer and double the output dimension of the final convolution to compensate for the reduction in input dimension to the classifier, leading to a final test accuracy of 90.7% in 47s. Note that average pooling at this stage underperforms max pooling significantly.


In [None]:
def shortcut_block(c_in, c_out, stride, **kw):
    projection = (stride != 1) or (c_in != c_out)
    if projection:
        return {
            'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False), 
            'bn': batch_norm(c_out, **kw),
            'relu': nn.ReLU(True),
            'pool': nn.MaxPool2d(2),
        }
    else:
        return {'id': Identity()}

lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])
batch_size = 512

n = DAWN_net(c=[64,128,256,512], block=shortcut_block, prep_bn_relu=True, concat_pool=False)
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - backbone, bn scale init=1, classifier weight=0.125 (47s; test acc 91.1%)

By default in PyTorch (0.4), initial batch norm scales are chosen uniformly at random from the interval [0,1]. Channels which are initialised near zero could be wasted so we replace this with a constant initialisation at 1. 
This leads to a larger signal through the network and to compensate we introduce an overall constant multiplicative rescaling of the final classifier. A rough manual optimisation of this extra hyperparameter suggest that 0.125 is a reasonable value. 
(The low value makes predictions less certain and appears to ease optimisation.) 

With these changes in place, 20 epoch training reaches a test accuracy of 91.1% in 47s. 

In [None]:
lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])
batch_size = 512

n = net(extra_layers=(), res_layers=())
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - double width, 60 epoch train! (321s; test acc 93.5%)

ne approach that doesn't seem particularly promising is to just add width. 

If we double the channel dimensions and train for 60 epochs we can reach 93.5% test accuracy with a 5 layer network. This is nice but not efficient since training now takes 321s.

In [None]:
lr_schedule = PiecewiseLinear([0, 12, 60], [0, 0.4, 0])
batch_size = 512
c = 128

n = net(channels={'prep': c, 'layer1': 2*c, 'layer2': 4*c, 'layer3': 8*c}, extra_layers=(), res_layers=())
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - extra:L1+L2+L3 network, 60 epochs, cutout=12 (180s, 95.0% test acc) 

In [None]:
lr_schedule = PiecewiseLinear([0, 12, 60], [0, 0.4, 0])
batch_size = 512
cutout=12

n = net(extra_layers=['layer1', 'layer2', 'layer3'], res_layers=())
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(cutout, cutout)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - final network Residual:L1+L3, 20 epochs (66s; test acc 93.7%)

In [None]:
lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])
batch_size = 512

n = net()
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

### [Post 4: Architecture](https://www.myrtle.ai/2018/10/26/how_to_train_your_resnet_4/)  - final network, 24 epochs (79s; test acc 94.1%)

In [None]:
lr_schedule = PiecewiseLinear([0, 5, 24], [0, 0.4, 0])
batch_size = 512

n = net()
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)