Squeeze-Excite-Aggregate Networks (SEANet) with a ResNet-20 backbone.

Dataset: CIFAR-100

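SEANet (SE-ResNet-20, batch size 64, reduction 8) reaches a best accuracy of 78.67% on the CIFAR-100 test split. This is the `best_acc1` value stored in `best_model.pth`, whose recorded epoch is 170 of a 200-epoch run. The figure is the best score across epochs, measured on the same test split that is used for model selection. It is therefore an optimistic estimate rather than a truly held-out result.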

In [1]:
from pathlib import Path
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from se_resnet import se_resnet20
from baseline import resnet20
from utils import Trainer

In [2]:
def get_dataloader(batch_size, root="~/cifar100",
                   mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Build CIFAR-100 train and test DataLoaders.

    The training split is augmented (random 32x32 crop with 4-pixel padding,
    random horizontal flip), normalized, and shuffled. The test split is only
    normalized and is iterated in a fixed order so evaluation is repeatable.

    Args:
        batch_size: Mini-batch size used by both loaders.
        root: Dataset directory (``~`` is expanded). It is created if missing,
            and the dataset is downloaded into it when not already present.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel standard deviation passed to ``transforms.Normalize``.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    root = Path(root).expanduser()
    # parents/exist_ok: don't fail on a missing parent directory or if the
    # directory appears between an existence check and the create.
    root.mkdir(parents=True, exist_ok=True)
    root = str(root)

    # NOTE: the default mean/std are the widely used CIFAR-10 statistics, not
    # CIFAR-100's (~(0.5071, 0.4865, 0.4409) / (0.2673, 0.2564, 0.2762)).
    # They stay the defaults for consistency with the training run recorded in
    # this notebook; pass CIFAR-100 values explicitly for fresh experiments.
    to_normalized_tensor = [transforms.ToTensor(),
                            transforms.Normalize(mean, std)]
    data_augmentation = [transforms.RandomCrop(32, padding=4),
                         transforms.RandomHorizontalFlip()]

    train_loader = DataLoader(
        datasets.CIFAR100(root, train=True, download=True,
                          transform=transforms.Compose(data_augmentation + to_normalized_tensor)),
        batch_size=batch_size, shuffle=True)
    # Evaluation needs no shuffling; a fixed order keeps per-batch metrics repeatable.
    test_loader = DataLoader(
        datasets.CIFAR100(root, train=False, download=True,
                          transform=transforms.Compose(to_normalized_tensor)),
        batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [3]:
def main(batch_size, baseline, reduction, *, epochs=200, lr=1e-1, lr_step=80,
         lr_gamma=0.1, num_classes=100, seed=None):
    """Train ResNet-20 or SE-ResNet-20 on CIFAR-100 via ``Trainer.loop``.

    Args:
        batch_size: Mini-batch size passed to ``get_dataloader``.
        baseline: If True, build ``resnet20``; otherwise build ``se_resnet20``.
        reduction: Passed to ``se_resnet20`` (presumably the SE squeeze ratio;
            confirm in ``se_resnet``). Unused when ``baseline`` is True.
        epochs: Number of epochs passed to ``Trainer.loop``.
        lr: Initial SGD learning rate.
        lr_step: ``StepLR`` step size.
        lr_gamma: ``StepLR`` multiplicative decay factor.
        num_classes: Number of model outputs (100 for CIFAR-100).
        seed: Optional seed for Python's and torch's global RNGs; ``None``
            leaves RNG state untouched. Does not force deterministic cuDNN kernels.
    """
    if seed is not None:
        import random
        import torch
        random.seed(seed)
        torch.manual_seed(seed)

    print("Batch_size: {}, Baseline(Resnet20): {}, Reduction: {}".format(batch_size, baseline, reduction))
    train_loader, test_loader = get_dataloader(batch_size)

    if baseline:
        model = resnet20(num_classes=num_classes)
        print("\n_________Baseline ResNet-20 model______________ \n")
    else:
        model = se_resnet20(num_classes=num_classes, reduction=reduction)
        # This branch builds the SE model, so the banner must not say "Baseline".
        print("\n_________SE-ResNet-20 model______________ \n")

    # SGD with momentum and L2 weight decay. Assuming Trainer.loop steps the
    # scheduler once per epoch, lr is multiplied by lr_gamma every lr_step epochs.
    optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=0.9,
                          weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, lr_step, lr_gamma)
    trainer = Trainer(model, optimizer, F.cross_entropy)
    trainer.loop(epochs, train_loader, test_loader, scheduler)


if __name__ == '__main__':
    # Run configuration. BASELINE=False makes main() build se_resnet20
    # (with REDUCTION) instead of the plain resnet20 baseline.
    BATCH_SIZE = 64
    REDUCTION = 8
    BASELINE = False
    main(batch_size=BATCH_SIZE, baseline=BASELINE, reduction=REDUCTION)

Batch_size: 64, Baseline(Resnet20): False, Reduction: 8
Files already downloaded and verified

_________Baseline Se_ResNet-20 model______________ 

epochs: 0
>>>[train] loss: 4.42/accuracy: 3.07%
>>>[test] loss: 4.01/accuracy: 6.28%
Best val acc till now:  0.0628
epochs: 1
>>>[train] loss: 3.89/accuracy: 8.78%
>>>[test] loss: 3.81/accuracy: 10.72%
Best val acc till now:  0.1072
epochs: 2
>>>[train] loss: 3.53/accuracy: 14.90%
>>>[test] loss: 3.43/accuracy: 18.78%
Best val acc till now:  0.1878
epochs: 3
>>>[train] loss: 2.98/accuracy: 24.44%
>>>[test] loss: 2.85/accuracy: 26.88%
Best val acc till now:  0.2688
epochs: 4
>>>[train] loss: 2.52/accuracy: 33.70%
>>>[test] loss: 2.58/accuracy: 34.12%
Best val acc till now:  0.3412
epochs: 5
>>>[train] loss: 2.17/accuracy: 41.43%
>>>[test] loss: 2.39/accuracy: 39.05%
Best val acc till now:  0.3905
epochs: 6
>>>[train] loss: 1.90/accuracy: 47.64%
>>>[test] loss: 2.04/accuracy: 46.60%
Best val acc till now:  0.466
epochs: 7
>>>[train] loss: 1.6

In [1]:
import torch

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
# map_location="cpu" lets the file load on a CPU-only machine even if it was
# saved with CUDA tensors.
checkpoint = torch.load("./best_model.pth", map_location="cpu")

In [2]:
# Top-level entries stored in the checkpoint file.
checkpoint_fields = checkpoint.keys()
checkpoint_fields

dict_keys(['epoch', 'state_dict', 'optimizer', 'best_acc1'])

In [3]:
# Best test accuracy recorded during training, as a fraction (0.7867 == 78.67%).
best_accuracy = checkpoint['best_acc1']
best_accuracy

0.7867

In [4]:
# Epoch value saved with the checkpoint (presumably when best_acc1 was reached;
# confirm against Trainer's saving logic).
saved_epoch = checkpoint['epoch']
saved_epoch

170