In [1]:
"""
@Description :   CNN 方法对比
@Author      :   Xubo Luo 
@Time        :   2024/01/08 16:04:58
"""

from utils.core import *
from utils.torch_backend import *

## 网络&函数定义

In [8]:
colors = ColorMap()


def draw(graph):
    """Render a built network graph (output of `build_graph`) as a DOT diagram.

    Each node is coloured by the type of its module and carries the module's
    repr as a tooltip. Nodes whose value is None (such as the 'input'
    placeholder used by the net builders) are skipped.
    """
    visible_nodes = {}
    for path, (value, inputs) in graph.items():
        if value is None:
            continue
        node_attrs = {'fillcolor': colors[type(value)], 'tooltip': repr(value)}
        visible_nodes[path] = (node_attrs, inputs)
    return display(DotGraph(visible_nodes))


# BatchNorm factory used by every block builder below. Passing weight_init=None
# and bias_init=None presumably leaves the affine parameters at the framework's
# default initialisation instead of filling them explicitly — TODO confirm
# against BatchNorm in utils.torch_backend. Call-site keywords override these
# (functools.partial semantics), so callers can still pass weight_init=... .
batch_norm = partial(BatchNorm, weight_init=None, bias_init=None)

def res_block(c_in, c_out, stride, **kw):
    """Pre-activation residual block as a graph sub-dict.

    Layout: BN -> ReLU ('relu1') feeds a branch of conv3x3(stride) -> BN ->
    ReLU -> conv3x3, which is summed ('add') with a shortcut. The shortcut is
    'relu1' itself when the shape is unchanged, or a strided 1x1 conv of
    'relu1' ('conv3') when the stride or channel count changes.

    Entries of the form ``(module, [input_paths])`` name their inputs
    explicitly. Other entries presumably take the preceding node as input
    (see `build_graph`).

    Args:
        c_in: input channels.
        c_out: output channels.
        stride: stride of the first branch conv (and of the projection).
        **kw: forwarded to `batch_norm`.
    """
    needs_projection = stride != 1 or c_in != c_out

    # Construct modules in the same order as the dict layout so parameter
    # initialisation consumes the RNG in a stable order.
    pre_bn = batch_norm(c_in, **kw)
    pre_relu = nn.ReLU(True)
    branch = {
        'conv1': nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
        'bn2': batch_norm(c_out, **kw),
        'relu2': nn.ReLU(True),
        'conv2': nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1, bias=False),
    }
    block = {'bn1': pre_bn, 'relu1': pre_relu, 'branch': branch}

    shortcut_source = 'relu1'
    if needs_projection:
        projection = nn.Conv2d(c_in, c_out, kernel_size=1, stride=stride, padding=0, bias=False)
        block['conv3'] = (projection, ['relu1'])
        shortcut_source = 'conv3'

    block['add'] = (Add(), [shortcut_source, 'branch/conv2'])
    return block

def DAWN_net(c=64, block=res_block, prep_bn_relu=False, concat_pool=True, **kw):
    """Build a four-stage residual network graph as a nested dict.

    Structure: 'input' -> 'prep' (3x3 stem conv, optionally + BN + ReLU) ->
    'layer1'..'layer4' (two `block` units each; the first unit of layers 2-4
    uses stride 2) -> 'final' (4x4 pooling, flatten, linear to 10 outputs) ->
    'logits'.

    Args:
        c: base width. An int expands to [c, 2c, 4c, 4c]; otherwise a sequence
            of four per-stage channel counts.
        block: factory ``block(c_in, c_out, stride, **kw)`` returning one unit.
        prep_bn_relu: add BatchNorm + ReLU after the stem conv.
        concat_pool: concatenate 4x4 max- and avg-pooled features (doubling
            the linear layer's input width) instead of max pooling only.
        **kw: forwarded to `batch_norm` and to `block`.

    Returns:
        dict describing the network, for `build_graph` / `Network`.
    """
    widths = [c, 2 * c, 4 * c, 4 * c] if isinstance(c, int) else c

    if concat_pool:
        classifier_pool = {
            'in': Identity(),
            'maxpool': nn.MaxPool2d(4),
            'avgpool': (nn.AvgPool2d(4), ['in']),
            'concat': (Concat(), ['maxpool', 'avgpool']),
        }
    else:
        classifier_pool = {'pool': nn.MaxPool2d(4)}

    stem = {'conv': nn.Conv2d(3, widths[0], kernel_size=3, stride=1, padding=1, bias=False)}
    stem_tail = {'bn': batch_norm(widths[0], **kw), 'relu': nn.ReLU(True)} if prep_bn_relu else {}

    graph = {
        'input': (None, []),
        'prep': union(stem, stem_tail),
    }

    # (input width, output width, stride of the first unit) for each stage.
    stage_specs = zip([widths[0], *widths[:3]], widths, [1, 2, 2, 2])
    for stage_idx, (c_in, c_out, stride) in enumerate(stage_specs, start=1):
        graph[f'layer{stage_idx}'] = {
            'block0': block(c_in, c_out, stride, **kw),
            'block1': block(c_out, c_out, 1, **kw),
        }

    head_in = 2 * widths[3] if concat_pool else widths[3]
    graph['final'] = union(classifier_pool, {
        'flatten': Flatten(),
        'linear': nn.Linear(head_in, 10, bias=True),
    })
    graph['logits'] = Identity()
    return graph


def conv_bn(c_in, c_out, bn_weight_init=1.0, **kw):
    """3x3 conv (stride 1, padding 1) -> BatchNorm -> ReLU as a graph sub-dict.

    Args:
        c_in: input channels.
        c_out: output channels (also the BatchNorm width).
        bn_weight_init: initial BatchNorm weight. With the default value the
            BatchNorm is built exactly as `batch_norm` configures it (its own
            ``weight_init``). A non-default value is forwarded as
            ``weight_init`` unless the caller already supplied ``weight_init``
            in ``kw``. Previously this argument was silently ignored.
        **kw: forwarded to `batch_norm`.

    Returns:
        dict with keys 'conv', 'bn', 'relu'. The spatial size is preserved.
    """
    bn_kw = dict(kw)
    if bn_weight_init != 1.0:
        # `batch_norm` takes `weight_init`, not `bn_weight_init`. Translate the
        # name only for explicit non-default requests so default callers see
        # no change, and let an explicit weight_init in kw win.
        bn_kw.setdefault('weight_init', bn_weight_init)
    return {
        'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
        'bn': batch_norm(c_out, **bn_kw),
        'relu': nn.ReLU(True),
    }

def basic_net(channels, weight, pool, **kw):
    """Plain (non-residual) backbone: prep + three conv_bn stages each followed by `pool`.

    Args:
        channels: dict with widths for 'prep', 'layer1', 'layer2', 'layer3'.
        weight: constant the logits are multiplied by (final `Mul` node).
        pool: pooling module appended to each of layer1-3. The same object is
            reused for all three stages.
        **kw: forwarded to `conv_bn`.

    Returns:
        dict describing the network. Classification uses a 4x4 max-pool,
        flatten, and a bias-free linear layer to 10 outputs.
    """
    graph = {
        'input': (None, []),
        'prep': conv_bn(3, channels['prep'], **kw),
    }
    prev_stage = 'prep'
    for stage in ('layer1', 'layer2', 'layer3'):
        graph[stage] = {**conv_bn(channels[prev_stage], channels[stage], **kw), 'pool': pool}
        prev_stage = stage

    graph['pool'] = nn.MaxPool2d(4)
    graph['flatten'] = Flatten()
    graph['linear'] = nn.Linear(channels['layer3'], 10, bias=False)
    graph['logits'] = Mul(weight)
    return graph

def net(channels=None, weight=0.125, pool=None, extra_layers=(), res_layers=('layer1', 'layer3'), **kw):
    """Build the ResNet-9-style network graph (`basic_net` plus residual units).

    Args:
        channels: dict with widths for 'prep', 'layer1', 'layer2', 'layer3'.
            Defaults to 64 / 128 / 256 / 512.
        weight: constant the logits are multiplied by.
        pool: pooling module used after each of layer1-3. ``None`` (the
            default) creates a fresh ``nn.MaxPool2d(2)`` for this call.
        extra_layers: layer names that receive an additional conv_bn unit
            under the key 'extra'.
        res_layers: layer names that receive a residual unit (identity skip
            around two conv_bn units) under the key 'residual'.
        **kw: forwarded to `basic_net` and `conv_bn`.

    Returns:
        dict describing the network, for `build_graph` / `Network`.
    """
    channels = channels or {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}
    if pool is None:
        # Previously the default was a module instance evaluated once at
        # definition time, so every network built here shared one object.
        pool = nn.MaxPool2d(2)

    def residual(c, **block_kw):
        # Identity skip ('in') added to the output of two conv_bn units.
        return {
            'in': Identity(),
            'res1': conv_bn(c, c, **block_kw),
            'res2': conv_bn(c, c, **block_kw),
            'add': (Add(), ['in', 'res2/relu']),
        }

    n = basic_net(channels, weight, pool, **kw)
    for layer in res_layers:
        n[layer]['residual'] = residual(channels[layer], **kw)
    for layer in extra_layers:
        n[layer]['extra'] = conv_bn(channels[layer], channels[layer], **kw)
    return n

def remove_identity_nodes(net):
    """Return `net` with every `Identity` node removed (delegates to `remove_by_type`)."""
    return remove_by_type(net, Identity)

def train(model, lr_schedule, train_set, test_set, batch_size, num_workers=0):
    """Train `model` for ``lr_schedule.knots[-1]`` epochs and return the epoch log.

    Optimiser: SGD with momentum 0.9 and weight decay 5e-4 * batch_size. The
    learning rate is ``lr_schedule(fractional_epoch) / batch_size``. This
    scaling is presumably there because the loss is summed, not averaged,
    over the batch (TODO confirm against `x_ent_loss`).

    Args:
        model: network whose trainable parameters are optimised.
        lr_schedule: callable epoch -> lr exposing `knots`. Its last knot is
            the number of epochs.
        train_set: training data; shuffled, with `set_random_choices=True`.
        test_set: evaluation data; not shuffled.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker count.

    Returns:
        `Table` with one row per epoch: 'epoch', 'lr', plus `train_epoch` stats.
    """
    train_batches = DataLoader(train_set, batch_size, shuffle=True, set_random_choices=True, num_workers=num_workers)
    test_batches = DataLoader(test_set, batch_size, shuffle=False, num_workers=num_workers)

    def lr(step):
        # Assumes `step` counts optimiser steps, so step / len(loader) is the fractional epoch.
        return lr_schedule(step / len(train_batches)) / batch_size

    hyperparams = {'lr': lr, 'weight_decay': Const(5e-4 * batch_size), 'momentum': Const(0.9)}
    opts = [SGD(trainable_params(model).values(), hyperparams)]
    logs = Table()
    state = {MODEL: model, LOSS: x_ent_loss, OPTS: opts}

    for epoch in range(lr_schedule.knots[-1]):
        header = {'epoch': epoch + 1, 'lr': lr_schedule(epoch + 1)}
        epoch_stats = train_epoch(state, Timer(torch.cuda.synchronize), train_batches, test_batches)
        logs.append(union(header, epoch_stats))
    return logs

## 导入数据

In [3]:
# Load CIFAR-10 once and pre-process it up front. Random augmentation
# (crop / flip / cutout) is added later in each experiment cell via `Transform`.
DATA_DIR = '../../cifar-10-python/'  # relative to this notebook; adjust for your checkout
dataset = cifar10(DATA_DIR)
timer = Timer()
print('Preprocessing training data')
transforms = [
    partial(normalise, mean=np.array(cifar10_mean, dtype=np.float32), std=np.array(cifar10_std, dtype=np.float32)),
    partial(transpose, source='NHWC', target='NCHW'),
]
# Training images are additionally padded by 4 px so the later Crop(32, 32)
# can take shifted 32x32 windows; test images are left at their original size.
train_set = list(zip(*preprocess(dataset['train'], [partial(pad, border=4)] + transforms).values()))
# `.2f` (fixed point) rather than `.2` (2 significant digits), which would
# print durations >= 10 s in exponent form such as "1.2e+01".
print(f'Finished in {timer():.2f} seconds')
print('Preprocessing test data')
test_set = list(zip(*preprocess(dataset['valid'], transforms).values()))
print(f'Finished in {timer():.2f} seconds')

Files already downloaded and verified
Files already downloaded and verified
Preprocessing training data
Finished in 1.6 seconds
Preprocessing test data
Finished in 0.066 seconds


## 1x1 conv shortcut (strided 1x1 conv + BN + ReLU for downsampling)

In [4]:
def shortcut_block(c_in, c_out, stride, **kw):
    """Shortcut-only unit using a strided 1x1 convolution for downsampling.

    Returns an identity node when neither stride nor width changes.
    Otherwise it returns 1x1 conv (stride `stride`, no padding) -> BN -> ReLU
    mapping `c_in` to `c_out` channels. This variant has no separate 3x3
    residual branch.
    """
    if stride == 1 and c_in == c_out:
        return {'id': Identity()}
    return {
        'conv': nn.Conv2d(c_in, c_out, kernel_size=1, stride=stride, padding=0, bias=False),
        'bn': batch_norm(c_out, **kw),
        'relu': nn.ReLU(True),
    }

# 20 epochs: lr rises linearly from 0 to 0.4 by epoch 4, then falls linearly to 0.
batch_size = 512
lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])

n = DAWN_net(block=shortcut_block, prep_bn_relu=True)
draw(build_graph(n))
model = (
    Network(n)
    .to(device)
    .half()
)
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8, 8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size, num_workers=0)

pydot is needed for network visualisation

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:1025.)
  dw.add_(weight_decay, w).mul_(-lr)


       epoch           lr   train time   train loss    train acc   valid time   valid loss    valid acc   total time
           1       0.1000       5.7514       2.1632       0.2573       0.2142       1.7764       0.3457       5.7514
           2       0.2000       2.2766       1.7967       0.3425       0.1582       1.6905       0.4017       2.2766
           3       0.3000       2.2493       1.6980       0.3822       0.1511       1.6319       0.4077       2.2493
           4       0.4000       2.2360       1.6616       0.3977       0.1608       1.6039       0.4261       2.2360
           5       0.3750       2.2701       1.6383       0.4064       0.1509       1.5561       0.4418       2.2701
           6       0.3500       2.2868       1.6021       0.4220       0.1601       1.5527       0.4362       2.2868
           7       0.3250       2.3301       1.5742       0.4303       0.1669       1.4972       0.4617       2.3301
           8       0.3000       2.3133       1.5483       0.4395

## 3x3 conv shortcut (strided 3x3 conv + BN + ReLU for downsampling)

In [5]:
def shortcut_block(c_in, c_out, stride, **kw):
    """Shortcut-only unit using a strided 3x3 convolution for downsampling.

    Returns an identity node when neither stride nor width changes.
    Otherwise it returns 3x3 conv (stride `stride`, padding 1) -> BN -> ReLU
    mapping `c_in` to `c_out` channels.
    """
    if stride == 1 and c_in == c_out:
        return {'id': Identity()}
    return {
        'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
        'bn': batch_norm(c_out, **kw),
        'relu': nn.ReLU(True),
    }

# Same 20-epoch schedule as the 1x1 experiment, for a like-for-like comparison.
batch_size = 512
lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])

n = DAWN_net(block=shortcut_block, prep_bn_relu=True)
draw(build_graph(n))
model = (
    Network(n)
    .to(device)
    .half()
)
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8, 8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size, num_workers=0)

pydot is needed for network visualisation

       epoch           lr   train time   train loss    train acc   valid time   valid loss    valid acc   total time
           1       0.1000       8.3853       1.9080       0.3319       0.4811       1.4173       0.4839       8.3853
           2       0.2000       7.0179       1.4362       0.4807       0.3416       1.1469       0.5898       7.0179
           3       0.3000       6.9860       1.2225       0.5640       0.3327       1.0714       0.6154       6.9860
           4       0.4000       7.1769       1.1203       0.6029       0.3511       1.0311       0.6370       7.1769
           5       0.3750       7.1263       1.0148       0.6419       0.3477       0.8888       0.6886       7.1263
           6       0.3500       7.3021       0.9288       0.6728       0.3431       0.8877       0.6986       7.3021
           7       0.3250       7.3242       0.8602       0.6974       0.3510       0.9942       0.6666       7.3242
           8       0.3000       7.4000       0.8101       0.7158

## output (3x3 conv + BN + ReLU followed by max-pool for downsampling)

In [6]:
def shortcut_block(c_in, c_out, stride, **kw):
    """Shortcut-only unit: 3x3 conv -> BN -> ReLU, then max-pool to downsample.

    Returns an identity node when neither stride nor width changes.
    Otherwise the 3x3 conv (stride 1, padding 1) maps `c_in` to `c_out`
    channels. Downsampling is done by ``nn.MaxPool2d(stride)`` instead of a
    strided conv. The pool is omitted when ``stride == 1`` (a width-only
    change). Previously a 2x2 pool was always applied, which would have
    halved the resolution even for stride 1. For the stride-2 calls made by
    DAWN_net the result is identical.
    """
    if stride == 1 and c_in == c_out:
        return {'id': Identity()}
    unit = {
        'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
        'bn': batch_norm(c_out, **kw),
        'relu': nn.ReLU(True),
    }
    if stride != 1:
        unit['pool'] = nn.MaxPool2d(stride)
    return unit

lr_schedule = PiecewiseLinear([0, 4, 20], [0, 0.4, 0])
batch_size = 512

n = DAWN_net(c=[64,128,256,512], block=shortcut_block, prep_bn_relu=True, concat_pool=False)
draw(build_graph(n))
model = Network(n).to(device).half()
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8,8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size=batch_size, num_workers=0)

pydot is needed for network visualisation

       epoch           lr   train time   train loss    train acc   valid time   valid loss    valid acc   total time
           1       0.1000      17.6184       2.6301       0.2590       1.2205       1.6399       0.4023      17.6184
           2       0.2000      11.8325       1.5426       0.4441       0.7181       1.3402       0.5103      11.8325
           3       0.3000      11.8162       1.2086       0.5693       0.7147       1.0695       0.6004      11.8162
           4       0.4000      11.7407       1.0536       0.6328       0.7219       1.0388       0.6322      11.7407
           5       0.3750      11.7504       0.8769       0.6948       0.7180       1.7411       0.5360      11.7504
           6       0.3500      11.6977       0.8053       0.7229       0.7185       0.5967       0.7943      11.6977
           7       0.3250      11.8204       0.6829       0.7624       0.7191       0.6444       0.7772      11.8204
           8       0.3000      11.8723       0.6211       0.7840

## shortcut (ResNet-9-style net with residual units in layer1 and layer3)

In [9]:
# ResNet-9-style network from `net()` (residual units in layer1 and layer3).
# 24 epochs: lr rises linearly to 0.4 by epoch 5, then falls linearly to 0.
batch_size = 512
lr_schedule = PiecewiseLinear([0, 5, 24], [0, 0.4, 0])

n = net()
draw(build_graph(n))
model = (
    Network(n)
    .to(device)
    .half()
)
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8, 8)])
summary = train(model, lr_schedule, train_set_x, test_set, batch_size, num_workers=0)

pydot is needed for network visualisation

       epoch           lr   train time   train loss    train acc   valid time   valid loss    valid acc   total time
           1       0.0800      18.7367       1.6489       0.4047       1.2715       1.2100       0.5499      18.7367
           2       0.1600      16.6009       0.9494       0.6624       1.0272       1.0279       0.6544      16.6009
           3       0.2400      16.6721       0.7362       0.7425       1.0589       1.1312       0.6476      16.6721
           4       0.3200      17.2054       0.6191       0.7855       1.0985       0.6790       0.7642      17.2054
           5       0.4000      17.4312       0.5612       0.8054       1.0870       0.6438       0.7775      17.4312
           6       0.3789      17.2692       0.4995       0.8284       1.0920       0.4727       0.8389      17.2692
           7       0.3579      17.2604       0.4427       0.8499       1.0917       0.6155       0.7845      17.2604
           8       0.3368      17.1863       0.4127       0.8587