Squeeze-Excite-Aggregate Network (SEANet) built on a ResNet-20 backbone.
Dataset: CIFAR-10
SEANet reaches a best validation accuracy of 95.7% (the maximum over all epochs of the run recorded below).

Source of original SENet: https://github.com/moskomule/senet.pytorch

In [2]:
from pathlib import Path
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from se_resnet import se_resnet20
from baseline import resnet20
from utils import Trainer

In [2]:
def get_dataloader(batch_size, root="~/cifar10"):
    """Build the CIFAR-10 training and test data loaders.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        root: Directory that holds (or will receive) the dataset. A leading
            ``~`` is expanded; missing directories are created.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader applies
        random-crop / horizontal-flip augmentation and is reshuffled each
        epoch; the test loader only normalizes and iterates in a fixed order.
    """
    data_dir = Path(root).expanduser()
    # parents=True lets nested roots work; exist_ok=True avoids the
    # check-then-create race of a separate ``exists()`` test.
    data_dir.mkdir(parents=True, exist_ok=True)
    root = str(data_dir)

    # Per-channel (mean, std) passed to Normalize after conversion to a tensor.
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    # Augmentation is applied to training images only.
    data_augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]

    train_loader = DataLoader(
        datasets.CIFAR10(root, train=True, download=True,
                         transform=transforms.Compose(data_augmentation + to_normalized_tensor)),
        batch_size=batch_size, shuffle=True)
    # The test split is built after the training split, whose download=True
    # has already fetched the dataset files into ``root``.
    # Evaluation does not benefit from shuffling; a fixed order keeps
    # per-batch results reproducible.
    test_loader = DataLoader(
        datasets.CIFAR10(root, train=False, transform=transforms.Compose(to_normalized_tensor)),
        batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [3]:
def main(batch_size, baseline, reduction):
    """Train one model on CIFAR-10 and evaluate it after training.

    Args:
        batch_size: Batch size handed to ``get_dataloader``.
        baseline: If truthy, train ``resnet20()``; otherwise train
            ``se_resnet20`` with ``num_classes=10``.
        reduction: ``reduction`` argument forwarded to ``se_resnet20``
            (unused when ``baseline`` is truthy).
    """
    banner = "Batch_size: {}, Baseline(Resnet20): {}, Reduction: {}"
    print(banner.format(batch_size, baseline, reduction))

    loaders = get_dataloader(batch_size)
    train_loader, test_loader = loaders

    if not baseline:
        model = se_resnet20(num_classes=10, reduction=reduction)
        print("\n_________Baseline Se_ResNet-20 model______________ \n")
    else:
        model = resnet20()
        print("\n_________Baseline ResNet-20 model______________ \n")

    # SGD with momentum; the learning rate is multiplied by 0.1 every 80
    # scheduler steps.
    sgd = optim.SGD(
        params=model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=1e-4,
    )
    lr_schedule = optim.lr_scheduler.StepLR(sgd, step_size=80, gamma=0.1)

    # NOTE(review): 200 is presumably the epoch count expected by
    # Trainer.loop -- confirm against utils.Trainer.
    runner = Trainer(model, sgd, F.cross_entropy)
    runner.loop(200, train_loader, test_loader, lr_schedule)


if __name__ == '__main__':
    # Experiment configuration for this run.
    batchsize = 64
    reduction = 8
    baseline = False  # False selects Se_resnet20 rather than plain resnet20
    main(batch_size=batchsize, baseline=baseline, reduction=reduction)

Batch_size: 64, Baseline(Resnet20): False, Reduction: 8
Files already downloaded and verified

_________Baseline Se_ResNet-20 model______________ 

epochs: 1
>>>[train] loss: 2.09/accuracy: 23.93%
>>>[test] loss: 1.74/accuracy: 33.95%
Best val acc till now:  0.3395
epochs: 2
>>>[train] loss: 1.64/accuracy: 38.30%
>>>[test] loss: 1.50/accuracy: 43.88%
Best val acc till now:  0.4388
epochs: 3
>>>[train] loss: 1.34/accuracy: 51.18%
>>>[test] loss: 1.20/accuracy: 56.82%
Best val acc till now:  0.5682
epochs: 4
>>>[train] loss: 1.05/accuracy: 62.37%
>>>[test] loss: 1.04/accuracy: 64.75%
Best val acc till now:  0.6475
epochs: 5
>>>[train] loss: 0.90/accuracy: 68.27%
>>>[test] loss: 0.85/accuracy: 70.81%
Best val acc till now:  0.7081
epochs: 6
>>>[train] loss: 0.72/accuracy: 75.07%
>>>[test] loss: 0.65/accuracy: 78.06%
Best val acc till now:  0.7806
epochs: 7
>>>[train] loss: 0.60/accuracy: 79.09%
>>>[test] loss: 0.58/accuracy: 79.69%
Best val acc till now:  0.7969
epochs: 8
>>>[train] loss:

In [4]:
# Per-epoch validation accuracy (as fractions in [0, 1]) for the SE-ResNet-20
# run; the leading values match the "[test] accuracy" lines of the log above.
# NOTE(review): presumably transcribed from that log, one entry per epoch --
# verify before reusing these numbers.
val_acc=  [0.3395, 0.4388, 0.5682, 0.6475, 0.7081, 0.7806, 0.7969, 0.7883, 0.8169, 0.8341, 0.8451, 0.8506, 0.8712, 0.8557, 0.8653, 0.8511, 0.8581, 0.8344, 0.8634, 0.8359, 0.8635, 0.8851, 0.8918, 0.8749, 0.8707, 0.8741, 0.8974, 0.8819, 0.8822, 0.8833, 0.8973, 0.8607, 0.8888, 0.8884, 0.8556, 0.8833, 0.888, 0.8922, 0.8939, 0.8905, 0.8915, 0.8826, 0.8829, 0.8991, 0.8973, 0.8956, 0.8915, 0.8775, 0.9004, 0.8895, 0.8887, 0.8946, 0.8846, 0.8933, 0.8972, 0.8981, 0.9014, 0.8961, 0.8763, 0.8864, 0.9071, 0.8521, 0.9, 0.9013, 0.8979, 0.8823, 0.8855, 0.9108, 0.9149, 0.8977, 0.9124, 0.9013, 0.8966, 0.8954, 0.9008, 0.8871, 0.8921, 0.9001, 0.908, 0.9119, 0.9453, 0.9498, 0.9511, 0.9512, 0.9513, 0.952, 0.9535, 0.9529, 0.9519, 0.9532, 0.9529, 0.9549, 0.954, 0.9541, 0.9548, 0.9542, 0.9538, 0.954, 0.9535, 0.9545, 0.9535, 0.9538, 0.9539, 0.9544, 0.9544, 0.9534, 0.9543, 0.955, 0.9537, 0.9526, 0.9539, 0.9562, 0.9547, 0.9543, 0.9533, 0.9541, 0.9533, 0.9528, 0.955, 0.9531, 0.954, 0.953, 0.9542, 0.9538, 0.9535, 0.9537, 0.9544, 0.9542, 0.955, 0.9528, 0.9523, 0.9538, 0.9528, 0.9513, 0.9523, 0.9527, 0.9531, 0.953, 0.953, 0.9521, 0.953, 0.9533, 0.952, 0.9534, 0.9523, 0.9528, 0.9533, 0.9539, 0.954, 0.9528, 0.9529, 0.9545, 0.9548, 0.952, 0.9532, 0.9529, 0.9531, 0.9555, 0.9536, 0.9521, 0.9539, 0.9541, 0.9543, 0.9553, 0.9546, 0.9548, 0.9549, 0.955, 0.9552, 0.955, 0.954, 0.9538, 0.9552, 0.955, 0.9546, 0.9542, 0.9554, 0.9543, 0.9543, 0.955, 0.9557, 0.9552, 0.9557, 0.9558, 0.9548, 0.9554, 0.9566, 0.9554, 0.957, 0.956, 0.955, 0.9554, 0.9564, 0.9544, 0.9556, 0.9545, 0.9561, 0.9554, 0.9555, 0.9537]

In [5]:
# Best validation accuracy over all recorded epochs (displayed as the cell result).
max(val_acc)

0.957