In [None]:
import time
import traceback

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10

In [None]:
!pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

In [None]:
class ResNet18(nn.Module):
    """Plain torchvision ResNet-18 whose classifier head is swapped for a 10-way linear layer."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet18()
        # Keep the backbone's feature width (512 for ResNet-18) and emit 10 class scores.
        backbone.fc = nn.Linear(in_features=backbone.fc.in_features, out_features=10, bias=True)
        self.resnet = backbone

    def forward(self, x):
        """Return the class logits produced by the wrapped backbone."""
        return self.resnet(x)

In [None]:
class ResNet50(nn.Module):
    """Plain torchvision ResNet-50 whose classifier head is swapped for a 10-way linear layer."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet50()
        # Keep the backbone's feature width (2048 for ResNet-50) and emit 10 class scores.
        backbone.fc = nn.Linear(in_features=backbone.fc.in_features, out_features=10, bias=True)
        self.resnet = backbone

    def forward(self, x):
        """Return the class logits produced by the wrapped backbone."""
        return self.resnet(x)

In [None]:
class DGNet18(nn.Module):
    """ResNet-18 backbone re-wired as a densely connected network with learned gates.

    Every torchvision BasicBlock is split into two nodes (conv1-bn1-relu and
    conv2-bn2-relu), giving 4 nodes per stage; the blocks' residual additions and
    downsample branches are not used. Within a stage, after node ``j`` produces
    ``y`` its router emits one gate per node ``k >= j``, and ``gate[k - j] * y`` is
    added to the input of node ``k + 1`` (the stage output when ``k`` is the last
    node of the stage).

    Args:
        threshold: router gates (sigmoid outputs) that are ``<= threshold`` are
            replaced by 0 (``nn.Threshold(threshold, 0)``).
    """
    def __init__(self, threshold = 0):
        super(DGNet18, self).__init__()
        resnet = models.resnet18()
        layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        # Output channels of every node, per stage (2 blocks x 2 nodes each).
        self.c_out = [[64]*4, [128]*4, [256]*4, [512]*4]
        # Stem reused unchanged from ResNet-18.
        self.first = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        # Head: global average pool -> flatten -> 10-way linear classifier.
        self.fc = nn.Sequential(
            resnet.avgpool,
            nn.Flatten(1),
            nn.Linear(in_features=512, out_features=10, bias=True),
        )
        self.nodes = nn.ModuleList([nn.ModuleList() for _ in range(4)])
        for i in range(4):
            layer = layers[i]
            for j in range(2):
                self.nodes[i].append(nn.Sequential(
                    layer[j].conv1,
                    layer[j].bn1,
                    layer[j].relu,
                ))
                self.nodes[i].append(nn.Sequential(
                    layer[j].conv2,
                    layer[j].bn2,
                    layer[j].relu,
                ))
        # routers[i][j] maps node j's output to (len(nodes[i]) - j) gates in [0, 1),
        # one for each node k >= j of the same stage.
        self.routers = nn.ModuleList([nn.ModuleList() for _ in range(len(self.nodes))])
        for i in range(len(self.nodes)):
            for j in range(len(self.nodes[i])):
                self.routers[i].append(nn.Sequential(
                    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                    nn.Flatten(1),
                    nn.Linear(self.c_out[i][j],  len(self.nodes[i]) - j, bias=True),
                    nn.Sigmoid(),
                    nn.Threshold(threshold, 0),
                ))

    def forward(self, x):
        """Run the gated network on a batch of 4-D feature maps and return class logits."""
        out = self.first(x)
        for stage_nodes, stage_routers in zip(self.nodes, self.routers):
            n_nodes = len(stage_nodes)
            # pending[k] is the running (out-of-place) sum of gated contributions that
            # forms the input of node k + 1, or the stage output for the last k.
            pending = [None] * n_nodes
            for j, (node, router) in enumerate(zip(stage_nodes, stage_routers)):
                y = node(out)
                gates = router(y)  # column (k - j) gates the contribution to target k
                for k in range(j, n_nodes):
                    # Broadcast the per-sample gate over (C, H, W) instead of
                    # materialising a repeated full-size gate tensor.
                    contribution = gates[:, k - j].reshape(-1, 1, 1, 1) * y
                    pending[k] = contribution if pending[k] is None else pending[k] + contribution
                out = pending[j]
        return self.fc(out)

In [None]:
class DGNet50(nn.Module):
    """ResNet-50 backbone re-wired as a densely connected network with learned gates.

    Each torchvision Bottleneck block is used whole as one node, giving 3/4/6/3
    nodes per stage. Within a stage, after node ``j`` produces ``y`` its router
    emits one gate per node ``k >= j``, and ``gate[k - j] * y`` is added to the
    input of node ``k + 1`` (the stage output when ``k`` is the last node of the
    stage). Router linear weights are drawn from N(1, 1) by ``_init_weight``.

    Args:
        threshold: router gates (sigmoid outputs) that are ``<= threshold`` are
            replaced by 0 (``nn.Threshold(threshold, 0)``).
    """
    def __init__(self, threshold = 0):
        super(DGNet50, self).__init__()
        resnet = models.resnet50()
        layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        # Output channels of every node (one Bottleneck block each), per stage.
        self.c_out = [[256]*3, [512]*4, [1024]*6, [2048]*3]
        # Stem reused unchanged from ResNet-50.
        self.first = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        # Head: global average pool -> flatten -> 10-way linear classifier.
        self.fc = nn.Sequential(
            resnet.avgpool,
            nn.Flatten(1),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )
        self.nodes = nn.ModuleList([nn.ModuleList() for _ in range(len(layers))])
        for i in range(len(layers)):
            for j in range(len(self.c_out[i])):
                self.nodes[i].append(nn.Sequential(
                    layers[i][j],
                ))
        # routers[i][j] maps node j's output to (len(nodes[i]) - j) gates in [0, 1),
        # one for each node k >= j of the same stage.
        self.routers = nn.ModuleList([nn.ModuleList() for _ in range(len(self.nodes))])
        for i in range(len(self.nodes)):
            for j in range(len(self.nodes[i])):
                self.routers[i].append(nn.Sequential(
                    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                    nn.Flatten(1),
                    nn.Linear(self.c_out[i][j],  len(self.nodes[i]) - j, bias=True),
                    nn.Sigmoid(),
                    nn.Threshold(threshold, 0),
                ))
                self.routers[i][j].apply(self._init_weight)

    def _init_weight(self, module):
        """Re-initialise the weight of every ``nn.Linear`` in a router from N(1, 1); biases keep their default init."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=1.0, std=1.0)

    def forward(self, x):
        """Run the gated network on a batch of 4-D feature maps and return class logits."""
        out = self.first(x)
        for stage_nodes, stage_routers in zip(self.nodes, self.routers):
            n_nodes = len(stage_nodes)
            # pending[k] is the running (out-of-place) sum of gated contributions that
            # forms the input of node k + 1, or the stage output for the last k.
            pending = [None] * n_nodes
            for j, (node, router) in enumerate(zip(stage_nodes, stage_routers)):
                y = node(out)
                gates = router(y)  # column (k - j) gates the contribution to target k
                for k in range(j, n_nodes):
                    # Broadcast the per-sample gate over (C, H, W) instead of
                    # materialising a repeated full-size gate tensor.
                    contribution = gates[:, k - j].reshape(-1, 1, 1, 1) * y
                    pending[k] = contribution if pending[k] is None else pending[k] + contribution
                out = pending[j]
        return self.fc(out)

In [None]:
class DGNet18_Norm(nn.Module):
    """DGNet18 variant whose router gates are LayerNorm-normalised instead of thresholded.

    Topology: every torchvision BasicBlock is split into two nodes (conv1-bn1-relu
    and conv2-bn2-relu), 4 nodes per stage, without the blocks' residual adds.
    After node ``j`` of a stage produces ``y`` its router emits one gate per node
    ``k >= j``, and ``gate[k - j] * y`` is added to the input of node ``k + 1``
    (the stage output when ``k`` is the last node).

    Router: avgpool -> linear -> sigmoid -> LayerNorm over the gates (no affine
    parameters). A router with a single gate (the last node of each stage) skips
    the LayerNorm: normalising one value always yields exactly 0, which would
    silence that node and block its gradients.

    NOTE(review): with exactly two gates LayerNorm returns values of equal
    magnitude and opposite sign (close to +1/-1 unless the two sigmoid outputs are
    nearly equal), so one contribution is always subtracted — confirm intended.
    """
    def __init__(self):
        super(DGNet18_Norm, self).__init__()
        resnet = models.resnet18()
        layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        # Output channels of every node, per stage (2 blocks x 2 nodes each).
        self.c_out = [[64]*4, [128]*4, [256]*4, [512]*4]
        # Stem reused unchanged from ResNet-18.
        self.first = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        # Head: global average pool -> flatten -> 10-way linear classifier.
        self.fc = nn.Sequential(
            resnet.avgpool,
            nn.Flatten(1),
            nn.Linear(in_features=512, out_features=10, bias=True),
        )
        self.nodes = nn.ModuleList([nn.ModuleList() for _ in range(4)])
        for i in range(4):
            layer = layers[i]
            for j in range(2):
                self.nodes[i].append(nn.Sequential(
                    layer[j].conv1,
                    layer[j].bn1,
                    layer[j].relu,
                ))
                self.nodes[i].append(nn.Sequential(
                    layer[j].conv2,
                    layer[j].bn2,
                    layer[j].relu,
                ))
        # routers[i][j] maps node j's output to (len(nodes[i]) - j) gates, one for
        # each node k >= j of the same stage.
        self.routers = nn.ModuleList([nn.ModuleList() for _ in range(len(self.nodes))])
        for i in range(len(self.nodes)):
            for j in range(len(self.nodes[i])):
                n_gates = len(self.nodes[i]) - j
                self.routers[i].append(nn.Sequential(
                    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                    nn.Flatten(1),
                    nn.Linear(self.c_out[i][j], n_gates, bias=True),
                    nn.Sigmoid(),
                    # LayerNorm over a single value is identically 0; keep the raw sigmoid gate instead.
                    nn.LayerNorm(n_gates, elementwise_affine=False) if n_gates > 1 else nn.Identity(),
                ))

    def forward(self, x):
        """Run the gated network on a batch of 4-D feature maps and return class logits."""
        out = self.first(x)
        for stage_nodes, stage_routers in zip(self.nodes, self.routers):
            n_nodes = len(stage_nodes)
            # pending[k] is the running (out-of-place) sum of gated contributions that
            # forms the input of node k + 1, or the stage output for the last k.
            pending = [None] * n_nodes
            for j, (node, router) in enumerate(zip(stage_nodes, stage_routers)):
                y = node(out)
                gates = router(y)  # column (k - j) gates the contribution to target k
                for k in range(j, n_nodes):
                    # Broadcast the per-sample gate over (C, H, W) instead of
                    # materialising a repeated full-size gate tensor.
                    contribution = gates[:, k - j].reshape(-1, 1, 1, 1) * y
                    pending[k] = contribution if pending[k] is None else pending[k] + contribution
                out = pending[j]
        return self.fc(out)

In [None]:
class DGNet50_Norm(nn.Module):
    """DGNet50 variant whose router gates are LayerNorm-normalised instead of thresholded.

    Topology: each torchvision Bottleneck block is used whole as one node
    (3/4/6/3 nodes per stage). After node ``j`` of a stage produces ``y`` its
    router emits one gate per node ``k >= j``, and ``gate[k - j] * y`` is added to
    the input of node ``k + 1`` (the stage output when ``k`` is the last node).

    Router: avgpool -> linear -> sigmoid -> LayerNorm over the gates (no affine
    parameters). A router with a single gate (the last node of each stage) skips
    the LayerNorm: normalising one value always yields exactly 0, which would
    silence that node and block its gradients.

    NOTE(review): with exactly two gates LayerNorm returns values of equal
    magnitude and opposite sign (close to +1/-1 unless the two sigmoid outputs are
    nearly equal), so one contribution is always subtracted — confirm intended.
    """
    def __init__(self):
        super(DGNet50_Norm, self).__init__()
        resnet = models.resnet50()
        layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        # Output channels of every node (one Bottleneck block each), per stage.
        self.c_out = [[256]*3, [512]*4, [1024]*6, [2048]*3]
        # Stem reused unchanged from ResNet-50.
        self.first = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        # Head: global average pool -> flatten -> 10-way linear classifier.
        self.fc = nn.Sequential(
            resnet.avgpool,
            nn.Flatten(1),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )
        self.nodes = nn.ModuleList([nn.ModuleList() for _ in range(len(layers))])
        for i in range(len(layers)):
            for j in range(len(self.c_out[i])):
                self.nodes[i].append(nn.Sequential(
                    layers[i][j],
                ))
        # routers[i][j] maps node j's output to (len(nodes[i]) - j) gates, one for
        # each node k >= j of the same stage.
        self.routers = nn.ModuleList([nn.ModuleList() for _ in range(len(self.nodes))])
        for i in range(len(self.nodes)):
            for j in range(len(self.nodes[i])):
                n_gates = len(self.nodes[i]) - j
                self.routers[i].append(nn.Sequential(
                    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                    nn.Flatten(1),
                    nn.Linear(self.c_out[i][j], n_gates, bias=True),
                    nn.Sigmoid(),
                    # LayerNorm over a single value is identically 0; keep the raw sigmoid gate instead.
                    nn.LayerNorm(n_gates, elementwise_affine=False) if n_gates > 1 else nn.Identity(),
                ))

    def forward(self, x):
        """Run the gated network on a batch of 4-D feature maps and return class logits."""
        out = self.first(x)
        for stage_nodes, stage_routers in zip(self.nodes, self.routers):
            n_nodes = len(stage_nodes)
            # pending[k] is the running (out-of-place) sum of gated contributions that
            # forms the input of node k + 1, or the stage output for the last k.
            pending = [None] * n_nodes
            for j, (node, router) in enumerate(zip(stage_nodes, stage_routers)):
                y = node(out)
                gates = router(y)  # column (k - j) gates the contribution to target k
                for k in range(j, n_nodes):
                    # Broadcast the per-sample gate over (C, H, W) instead of
                    # materialising a repeated full-size gate tensor.
                    contribution = gates[:, k - j].reshape(-1, 1, 1, 1) * y
                    pending[k] = contribution if pending[k] is None else pending[k] + contribution
                out = pending[j]
        return self.fc(out)

In [None]:
class DGNet18_Soft(nn.Module):
    """DGNet18 variant whose router gates are a softmax over the sigmoid outputs (no threshold).

    Topology: every torchvision BasicBlock is split into two nodes (conv1-bn1-relu
    and conv2-bn2-relu), 4 nodes per stage, without the blocks' residual adds.
    After node ``j`` of a stage produces ``y`` its router emits one gate per node
    ``k >= j``, and ``gate[k - j] * y`` is added to the input of node ``k + 1``
    (the stage output when ``k`` is the last node).

    NOTE(review): the softmax is applied to sigmoid outputs in (0, 1), so within
    one router the ratio between any two gates stays below e; and the single-gate
    router of each stage's last node always outputs exactly 1.0 (so it receives no
    gradient). Confirm both are intended.
    """
    def __init__(self):
        super(DGNet18_Soft, self).__init__()
        resnet = models.resnet18()
        layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        # Output channels of every node, per stage (2 blocks x 2 nodes each).
        self.c_out = [[64]*4, [128]*4, [256]*4, [512]*4]
        # Stem reused unchanged from ResNet-18.
        self.first = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        # Head: global average pool -> flatten -> 10-way linear classifier.
        self.fc = nn.Sequential(
            resnet.avgpool,
            nn.Flatten(1),
            nn.Linear(in_features=512, out_features=10, bias=True),
        )
        self.nodes = nn.ModuleList([nn.ModuleList() for _ in range(4)])
        for i in range(4):
            layer = layers[i]
            for j in range(2):
                self.nodes[i].append(nn.Sequential(
                    layer[j].conv1,
                    layer[j].bn1,
                    layer[j].relu,
                ))
                self.nodes[i].append(nn.Sequential(
                    layer[j].conv2,
                    layer[j].bn2,
                    layer[j].relu,
                ))
        # routers[i][j] maps node j's output to (len(nodes[i]) - j) gates that sum
        # to 1, one for each node k >= j of the same stage.
        self.routers = nn.ModuleList([nn.ModuleList() for _ in range(len(self.nodes))])
        for i in range(len(self.nodes)):
            for j in range(len(self.nodes[i])):
                self.routers[i].append(nn.Sequential(
                    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                    nn.Flatten(1),
                    nn.Linear(self.c_out[i][j],  len(self.nodes[i]) - j, bias=True),
                    nn.Sigmoid(),
                    nn.Softmax(dim=-1),
                ))

    def forward(self, x):
        """Run the gated network on a batch of 4-D feature maps and return class logits."""
        out = self.first(x)
        for stage_nodes, stage_routers in zip(self.nodes, self.routers):
            n_nodes = len(stage_nodes)
            # pending[k] is the running (out-of-place) sum of gated contributions that
            # forms the input of node k + 1, or the stage output for the last k.
            pending = [None] * n_nodes
            for j, (node, router) in enumerate(zip(stage_nodes, stage_routers)):
                y = node(out)
                gates = router(y)  # column (k - j) gates the contribution to target k
                for k in range(j, n_nodes):
                    # Broadcast the per-sample gate over (C, H, W) instead of
                    # materialising a repeated full-size gate tensor.
                    contribution = gates[:, k - j].reshape(-1, 1, 1, 1) * y
                    pending[k] = contribution if pending[k] is None else pending[k] + contribution
                out = pending[j]
        return self.fc(out)

In [None]:
class DGNet50_Soft(nn.Module):
    """DGNet50 variant whose router gates are a softmax over the sigmoid outputs (no threshold).

    Topology: each torchvision Bottleneck block is used whole as one node
    (3/4/6/3 nodes per stage). After node ``j`` of a stage produces ``y`` its
    router emits one gate per node ``k >= j``, and ``gate[k - j] * y`` is added to
    the input of node ``k + 1`` (the stage output when ``k`` is the last node).

    NOTE(review): the softmax is applied to sigmoid outputs in (0, 1), so within
    one router the ratio between any two gates stays below e; and the single-gate
    router of each stage's last node always outputs exactly 1.0 (so it receives no
    gradient). Confirm both are intended.
    """
    def __init__(self):
        super(DGNet50_Soft, self).__init__()
        resnet = models.resnet50()
        layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        # Output channels of every node (one Bottleneck block each), per stage.
        self.c_out = [[256]*3, [512]*4, [1024]*6, [2048]*3]
        # Stem reused unchanged from ResNet-50.
        self.first = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        # Head: global average pool -> flatten -> 10-way linear classifier.
        self.fc = nn.Sequential(
            resnet.avgpool,
            nn.Flatten(1),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )
        self.nodes = nn.ModuleList([nn.ModuleList() for _ in range(len(layers))])
        for i in range(len(layers)):
            for j in range(len(self.c_out[i])):
                self.nodes[i].append(nn.Sequential(
                    layers[i][j],
                ))
        # routers[i][j] maps node j's output to (len(nodes[i]) - j) gates that sum
        # to 1, one for each node k >= j of the same stage.
        self.routers = nn.ModuleList([nn.ModuleList() for _ in range(len(self.nodes))])
        for i in range(len(self.nodes)):
            for j in range(len(self.nodes[i])):
                self.routers[i].append(nn.Sequential(
                    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                    nn.Flatten(1),
                    nn.Linear(self.c_out[i][j],  len(self.nodes[i]) - j, bias=True),
                    nn.Sigmoid(),
                    nn.Softmax(dim=-1),
                ))

    def forward(self, x):
        """Run the gated network on a batch of 4-D feature maps and return class logits."""
        out = self.first(x)
        for stage_nodes, stage_routers in zip(self.nodes, self.routers):
            n_nodes = len(stage_nodes)
            # pending[k] is the running (out-of-place) sum of gated contributions that
            # forms the input of node k + 1, or the stage output for the last k.
            pending = [None] * n_nodes
            for j, (node, router) in enumerate(zip(stage_nodes, stage_routers)):
                y = node(out)
                gates = router(y)  # column (k - j) gates the contribution to target k
                for k in range(j, n_nodes):
                    # Broadcast the per-sample gate over (C, H, W) instead of
                    # materialising a repeated full-size gate tensor.
                    contribution = gates[:, k - j].reshape(-1, 1, 1, 1) * y
                    pending[k] = contribution if pending[k] is None else pending[k] + contribution
                out = pending[j]
        return self.fc(out)

In [None]:
class Dense18(nn.Module):
    """Ungated dense counterpart of DGNet18: every gate is fixed to 1.

    Every torchvision BasicBlock is split into two nodes (conv1-bn1-relu and
    conv2-bn2-relu), 4 nodes per stage, without the blocks' residual adds. Node
    ``j + 1`` of a stage receives the plain sum of the outputs of nodes ``0..j``
    of that stage; the stage output is the sum of all its node outputs.
    """
    def __init__(self):
        super(Dense18, self).__init__()
        resnet = models.resnet18()
        layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        # Output channels of every node, per stage; not used by this model's forward
        # (there are no routers), kept so the attribute matches the DG models.
        self.c_out = [[64]*4, [128]*4, [256]*4, [512]*4]
        # Stem reused unchanged from ResNet-18.
        self.first = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        # Head: global average pool -> flatten -> 10-way linear classifier.
        self.fc = nn.Sequential(
            resnet.avgpool,
            nn.Flatten(1),
            nn.Linear(in_features=512, out_features=10, bias=True),
        )
        self.nodes = nn.ModuleList([nn.ModuleList() for _ in range(4)])
        for i in range(4):
            layer = layers[i]
            for j in range(2):
                self.nodes[i].append(nn.Sequential(
                    layer[j].conv1,
                    layer[j].bn1,
                    layer[j].relu,
                ))
                self.nodes[i].append(nn.Sequential(
                    layer[j].conv2,
                    layer[j].bn2,
                    layer[j].relu,
                ))

    def forward(self, x):
        """Run the dense network on a batch of feature maps and return class logits."""
        out = self.first(x)
        for stage_nodes in self.nodes:
            # With every gate equal to 1, all pending sums of a stage are identical,
            # so one running (out-of-place) sum replaces the O(n^2) per-target copies.
            running = None
            for node in stage_nodes:
                y = node(out)
                running = y if running is None else running + y
                out = running
        return self.fc(out)

In [None]:
class Dense50(nn.Module):
    """Ungated dense counterpart of DGNet50: every gate is fixed to 1.

    Each torchvision Bottleneck block is used whole as one node (3/4/6/3 nodes per
    stage). Node ``j + 1`` of a stage receives the plain sum of the outputs of
    nodes ``0..j`` of that stage; the stage output is the sum of all its node
    outputs.
    """
    def __init__(self):
        super(Dense50, self).__init__()
        resnet = models.resnet50()
        layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        # Output channels of every node, per stage; also sets how many blocks each
        # stage uses below.
        self.c_out = [[256]*3, [512]*4, [1024]*6, [2048]*3]
        # Stem reused unchanged from ResNet-50.
        self.first = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        # Head: global average pool -> flatten -> 10-way linear classifier.
        self.fc = nn.Sequential(
            resnet.avgpool,
            nn.Flatten(1),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )
        self.nodes = nn.ModuleList([nn.ModuleList() for _ in range(len(layers))])
        for i in range(len(layers)):
            for j in range(len(self.c_out[i])):
                self.nodes[i].append(nn.Sequential(
                    layers[i][j],
                ))

    def forward(self, x):
        """Run the dense network on a batch of feature maps and return class logits."""
        out = self.first(x)
        for stage_nodes in self.nodes:
            # With every gate equal to 1, all pending sums of a stage are identical,
            # so one running (out-of-place) sum replaces the O(n^2) per-target copies.
            running = None
            for node in stage_nodes:
                y = node(out)
                running = y if running is None else running + y
                out = running
        return self.fc(out)

In [None]:
BATCH_SIZE = 128
EPOCH = 100
SEED = 0
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG before any model is built or any batch is shuffled, so
# weight init and data order are reproducible across Restart & Run All.
# NOTE(review): some cuDNN kernels may still be non-deterministic on GPU.
torch.manual_seed(SEED)

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps each channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = CIFAR10('data/cifar10', train=True, download=True, transform=transform)
test_dataset = CIFAR10('data/cifar10', train=False, download=True, transform=transform)
# drop_last on the training loader only: every optimisation step sees a full batch,
# while evaluation still covers the whole test set.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed targets.

    With ``C`` classes (the last dim of the logits) and smoothing ``eps``, the
    target distribution puts ``eps / C`` on every class plus ``1 - eps`` extra on
    the true class, so it sums to 1. This matches
    ``F.cross_entropy(logit, y, label_smoothing=eps)``. The previous version put
    only ``1 - eps`` on the true class, so its targets summed to ``1 - eps / C``.

    Args:
        coefficient: smoothing amount ``eps``; 0 gives plain cross-entropy.
    """
    def __init__(self, coefficient):
        super(LabelSmoothingLoss, self).__init__()
        self.coefficient = coefficient

    def forward(self, logit, y):
        """Return the batch-mean smoothed cross-entropy.

        Args:
            logit: unnormalised class scores with classes along the last dim.
            y: integer class indices with ``logit``'s shape minus the last dim.
        """
        log_prob = logit.log_softmax(dim=-1)
        num_classes = log_prob.size(-1)
        uniform_mass = self.coefficient / num_classes
        with torch.no_grad():
            # Targets are constants; build them outside the autograd graph.
            target = torch.ones_like(log_prob) * uniform_mass
            target.scatter_(-1, y.unsqueeze(-1), 1 - self.coefficient + uniform_mass)
        return torch.mean(torch.sum(-target * log_prob, -1))

In [None]:
def train_net(net, optimizer, scheduler, criterion=None, epochs=None,
              train_loader=None, test_loader=None):
    """Train ``net`` and return the best per-epoch test accuracy.

    Each epoch runs one pass over the training loader, then evaluates on the test
    loader, prints the test loss/accuracy, and calls ``scheduler.step()`` once
    (the scheduler is stepped per epoch, not per batch).

    Args:
        net: model to train; assumed to already live on the global ``device``.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        criterion: ``(logits, targets) -> scalar loss``; defaults to ``LabelSmoothingLoss(0.1)``.
        epochs: number of epochs; defaults to the global ``EPOCH``.
        train_loader: training batches; defaults to the global ``train_dataloader``.
        test_loader: evaluation batches; defaults to the global ``test_dataloader``.

    Returns:
        The highest test accuracy seen, as a fraction in [0, 1] (0 if no epoch ran).
    """
    # Build the loss once instead of a new module per batch.
    if criterion is None:
        criterion = LabelSmoothingLoss(0.1)
    if epochs is None:
        epochs = EPOCH
    if train_loader is None:
        train_loader = train_dataloader
    if test_loader is None:
        test_loader = test_dataloader

    best_accuracy = 0
    for epoch in range(epochs):
        net.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            loss = criterion(net(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        net.eval()
        test_loss, test_correct, test_num_data = 0.0, 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                logit = net(x)
                # Weight the batch-mean loss by batch size so the last (smaller) batch counts correctly.
                test_loss += criterion(logit, y).item() * x.shape[0]
                # Count exact hits rather than rescaling a float32 batch mean.
                test_correct += (logit.argmax(dim=1) == y).sum().item()
                test_num_data += x.shape[0]
        # An empty test loader reports 0 instead of raising ZeroDivisionError.
        denom = max(test_num_data, 1)
        test_loss /= denom
        test_accuracy = test_correct / denom
        print(f'Test result of epoch {epoch + 1}/{epochs} || loss : {test_loss:.3f}, acc: {test_accuracy:.3f}')
        best_accuracy = max(test_accuracy, best_accuracy)
        scheduler.step()
    return best_accuracy

In [None]:
# Router thresholds swept for DGNet18. Labels are derived from the same values, so
# `net_types` and `networks` cannot drift out of sync.
DG18_THRESHOLDS = [0, 1e-1, 1e-2, 1e-3, 1e-4, 0.5]
net_types = ['ResNet18'] + [f'DGNet18({t})' for t in DG18_THRESHOLDS]
networks = [ResNet18().to(device)] + [DGNet18(t).to(device) for t in DG18_THRESHOLDS]

# Alternative experiments: replace the two assignments above with one of these pairs.
#net_types = ['ResNet50', 'DGNet50(0)', 'DGNet50(0.1)', 'DGNet50(0.01)', 'DGNet50(0.001)', 'DGNet50(0.0001)', 'DGNet50(0.00001)']
#networks = [ResNet50().to(device), DGNet50(0).to(device), DGNet50(1e-1).to(device), DGNet50(1e-2).to(device), DGNet50(1e-3).to(device), DGNet50(1e-4).to(device), DGNet50(1e-5).to(device)]

#net_types = ['Dense18', 'DGNet18_Norm', 'DGNet18_Soft', 'Dense50', 'DGNet50_Norm', 'DGNet50_Soft']
#networks = [Dense18().to(device), DGNet18_Norm().to(device), DGNet18_Soft().to(device), Dense50().to(device), DGNet50_Norm().to(device), DGNet50_Soft().to(device)]

# zip() below would silently drop unmatched entries, so fail loudly on a mismatch.
assert len(net_types) == len(networks), 'every network needs exactly one label'

for net_type, net in zip(net_types, networks):
    # Count only trainable parameters.
    num_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'# of parameters in {net_type} : {num_parameters}')

In [None]:
# Peak learning rate: the SGD base lr and the schedule's max_lr must agree.
LR = 0.4

final_accs = {}
for net_type, net in zip(net_types, networks):
    try:
        optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=0.0001)
        # One cosine cycle over the whole run: train_net steps the scheduler once per
        # epoch, so the cycle length is tied to EPOCH rather than a separate literal.
        scheduler = CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=EPOCH, cycle_mult=1, max_lr=LR, min_lr=0, warmup_steps=5, gamma=1)
        t1 = time.time()
        accuracy = train_net(net, optimizer, scheduler)
        t = time.time() - t1
        print(f'Best test accuracy of {net_type} network : {(accuracy*100):.3f}% took {t:.3f} secs')
        final_accs[net_type] = accuracy * 100
    except Exception:
        # Keep sweeping the remaining networks, but say which one failed and show the full traceback.
        print(f'Training {net_type} failed:')
        traceback.print_exc()

In [None]:
# Best test accuracy of every network that finished training (failed runs are absent).
for net_name, best_acc in final_accs.items():
    print(f'Best accuracy of {net_name} = {best_acc:.3f}%')