Instructions:

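Before running, place the CIFAR-10 batch files `data_batch_1`, `data_batch_2`, `data_batch_3`, `data_batch_4`, `data_batch_5` and `test_batch` in the directory given by `root_path` in the training cells (the filesystem root `/` by default).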

In our experiments we use `epoch_num = 100` and `batch_size = 100`.

Load CIFAR-10

In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from keras.datasets import cifar10
import time

# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
# CUDA only honours these variables if they are set before it is initialised.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

class MyDataset(Dataset):
    """CIFAR-10 dataset backed by the pickled batch files.

    Args:
        mode: ``'train'`` loads ``data_batch_1`` .. ``data_batch_5`` via
            ``load_traindata``; ``'test'`` loads ``test_batch``.
        root_path: directory containing the batch files.

    Each item is ``(image, label)``, where ``image`` is a float32 array of
    shape ``(3, 32, 32)`` with pixel values divided by 255.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, mode='train', root_path='/'):
        super(MyDataset, self).__init__()
        if mode == 'train':
            file_path = os.path.join(root_path, 'data_batch_{}')
            data, labels = load_traindata(file_path=file_path)
        elif mode == 'test':
            data_dict = unpickle(os.path.join(root_path, 'test_batch'))
            data, labels = data_dict[b'data'], data_dict[b'labels']
        else:
            # Fail fast instead of crashing later on a missing self.data.
            raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))
        # Scale once in float32. Dividing an integer array directly would
        # build a float64 copy (twice the memory) that __getitem__ then
        # down-casts on every access anyway.
        self.data = np.asarray(data, dtype=np.float32) / 255
        self.labels = labels
        self.num = len(self.labels)

    def __len__(self):
        return self.num

    def __getitem__(self, index):
        # Each row is a flat vector reshaped channel-first to (3, 32, 32).
        return self.data[index, :].reshape(3, 32, 32).astype(np.float32), self.labels[index]


def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return the unpickled object.

    ``encoding='bytes'`` makes Python 2 ``str`` objects load as ``bytes``,
    which is why callers index the result with ``b'data'`` / ``b'labels'``.

    NOTE: pickle can execute arbitrary code while loading; only call this
    on trusted dataset files.
    """
    import pickle
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def load_traindata(file_path, num_batches=5):
    """Load and concatenate the CIFAR-10 training batch files.

    Args:
        file_path: path template with one ``{}`` placeholder for the
            1-based batch index, e.g. ``'/data/data_batch_{}'``.
        num_batches: number of batch files to read (default 5,
            ``data_batch_1`` .. ``data_batch_5``).

    Returns:
        ``(data, labels)``: every batch's ``b'data'`` stacked along axis 0,
        and every batch's ``b'labels'`` concatenated into one array.

    Raises:
        FileNotFoundError: if the first batch file does not exist.
    """
    first_path = file_path.format(1)
    if not os.path.exists(first_path):
        # Raise a catchable error instead of exit(), which terminates the
        # whole process and hides the cause from callers.
        raise FileNotFoundError('wrong dataset path : {}'.format(first_path))
    # Gather all batches, then concatenate once instead of re-copying the
    # growing array on every iteration.
    batches = [unpickle(file_path.format(i)) for i in range(1, num_batches + 1)]
    train_data = np.concatenate([batch[b'data'] for batch in batches], axis=0)
    train_labels = np.concatenate([batch[b'labels'] for batch in batches], axis=0)
    return train_data, train_labels

ResNeXt

In [None]:
class ResNeXtUnit(nn.Module):
    """ResNeXt bottleneck unit: 1x1 reduce -> 3x3 grouped conv -> 1x1 expand.

    Args:
        in_features: input channels.
        out_features: output channels.
        mid_features: bottleneck width; defaults to ``out_features // 2``.
        stride: stride of the grouped 3x3 conv and of the projection shortcut.
        groups: cardinality (number of groups of the 3x3 conv);
            ``mid_features`` must be divisible by it.

    ``forward`` returns ``feas(x) + shortcut(x)`` with no final activation;
    ResNeXt applies ``nn.ReLU`` after each unit.
    """

    def __init__(self, in_features, out_features, mid_features=None, stride=1, groups=32):
        super(ResNeXtUnit, self).__init__()
        if mid_features is None:
            mid_features = out_features // 2
        # ReLU after the first two BNs (as in the ResNeXt paper). Without
        # them the three convs compose into an essentially affine map.
        # Note: this shifts the ``feas.*`` indices, so checkpoints saved with
        # the previous layout will not load.
        self.feas = nn.Sequential(
            nn.Conv2d(in_features, mid_features, 1, stride=1),
            nn.BatchNorm2d(mid_features),
            nn.ReLU(),
            nn.Conv2d(mid_features, mid_features, 3, stride=stride, padding=1, groups=groups),
            nn.BatchNorm2d(mid_features),
            nn.ReLU(),
            nn.Conv2d(mid_features, out_features, 1, stride=1),
            nn.BatchNorm2d(out_features)
        )
        # Identity only when the shape is unchanged. A strided unit needs the
        # projection even if the channel counts match.
        if in_features == out_features and stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_features, out_features, 1, stride=stride),
                nn.BatchNorm2d(out_features)
            )

    def forward(self, x):
        fea = self.feas(x)
        return fea + self.shortcut(x)


class ResNeXt(nn.Module):
    """ResNeXt-style classifier for 32x32 inputs built from ResNeXtUnit blocks.

    Layout:
        3x3 stem (64 ch)
        -> three stages of three ResNeXtUnits (256, 512 and 1024 channels;
           the first unit of stages 2 and 3 uses stride 2)
        -> global average pool
        -> linear classifier.

    Args:
        class_num: number of output classes.

    ``forward`` returns raw logits of shape ``(batch, class_num)``.
    """

    def __init__(self, class_num):
        super(ResNeXt, self).__init__()
        # ReLU added after the stem BN. It has no parameters, so the
        # state_dict keys of basic_conv are unchanged.
        self.basic_conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )  # 32x32
        self.stage_1 = nn.Sequential(
            ResNeXtUnit(64, 256, mid_features=128),
            nn.ReLU(),
            ResNeXtUnit(256, 256),
            nn.ReLU(),
            ResNeXtUnit(256, 256),
            nn.ReLU()
        )  # 32x32
        self.stage_2 = nn.Sequential(
            ResNeXtUnit(256, 512, stride=2),
            nn.ReLU(),
            ResNeXtUnit(512, 512),
            nn.ReLU(),
            ResNeXtUnit(512, 512),
            nn.ReLU()
        )  # 16x16
        self.stage_3 = nn.Sequential(
            ResNeXtUnit(512, 1024, stride=2),
            nn.ReLU(),
            ResNeXtUnit(1024, 1024),
            nn.ReLU(),
            ResNeXtUnit(1024, 1024),
            nn.ReLU()
        )  # 8x8
        # Equivalent to AvgPool2d(8) on the 8x8 maps from a 32x32 input,
        # but also works for other input sizes.
        self.pool = nn.AdaptiveAvgPool2d(1)
        # No softmax: nn.CrossEntropyLoss expects raw logits.
        self.classifier = nn.Sequential(
            nn.Linear(1024, class_num),
        )

    def forward(self, x):
        fea = self.basic_conv(x)
        fea = self.stage_1(fea)
        fea = self.stage_2(fea)
        fea = self.stage_3(fea)
        fea = self.pool(fea)
        # flatten(..., 1) keeps the batch dim. torch.squeeze would also drop
        # it for a batch of size 1, breaking the loss computation.
        fea = torch.flatten(fea, 1)  # (batch, 1024)
        return self.classifier(fea)  # (batch, class_num)

SKNet

In [None]:
class SKConv(nn.Module):
    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        """Selective-Kernel convolution.

        Args:
            features: input (and output) channel dimensionality.
            WH: input spatial size. Kept for interface compatibility but
                unused; global pooling is a spatial mean, not a fixed kernel.
            M: number of branches; branch i uses a (3+2i)x(3+2i) kernel.
            G: number of convolution groups in each branch.
            r: reduction ratio used to compute d, the length of z.
            stride: stride of every branch, default 1.
            L: minimum length of z, default 32.
        """
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features
        # padding=1+i keeps every branch at the same spatial size, so the
        # branch outputs can be stacked.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False)
            )
            for i in range(M)
        ])
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([nn.Linear(d, features) for _ in range(M)])
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Split: branch outputs stacked on dim 1 -> (batch, M, C, H, W).
        feas = torch.stack([conv(x) for conv in self.convs], dim=1)
        # Fuse: sum over branches, then global spatial mean -> (batch, C).
        fea_U = feas.sum(dim=1)
        fea_s = fea_U.mean(-1).mean(-1)
        # Compact descriptor z with a nonlinearity. Without it, fc and
        # fcs[i] would compose into a single linear map.
        fea_z = torch.relu(self.fc(fea_s))
        # Select: softmax over the M branches for each channel, then
        # broadcast over H and W -> (batch, M, C, 1, 1).
        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        return (feas * attention_vectors).sum(dim=1)


class SKUnit(nn.Module):
    def __init__(self, in_features, out_features, WH, M, G, r, mid_features=None, stride=1, L=32):
        """Bottleneck residual unit whose middle convolution is an SKConv.

        Args:
            in_features: input channel dimensionality.
            out_features: output channel dimensionality.
            WH: input spatial size, forwarded to SKConv.
            M: number of SKConv branches.
            G: number of convolution groups in each branch.
            r: reduction ratio used to compute d, the length of z.
            mid_features: bottleneck channel width (the SKConv's channels);
                defaults to ``out_features // 2``.
            stride: stride of the SKConv and of the projection shortcut.
            L: minimum length of z.

        ``forward`` returns ``feas(x) + shortcut(x)`` with no final
        activation; SKNet applies ``nn.ReLU`` after each unit.
        """
        super(SKUnit, self).__init__()
        if mid_features is None:
            mid_features = out_features // 2
        # ReLU after the 1x1 reduction so it does not compose linearly with
        # the SKConv branch convolutions. Note: this shifts the ``feas.*``
        # indices, so checkpoints saved with the previous layout will not load.
        self.feas = nn.Sequential(
            nn.Conv2d(in_features, mid_features, 1, stride=1),
            nn.BatchNorm2d(mid_features),
            nn.ReLU(),
            SKConv(mid_features, WH, M, G, r, stride=stride, L=L),
            nn.BatchNorm2d(mid_features),
            nn.Conv2d(mid_features, out_features, 1, stride=1),
            nn.BatchNorm2d(out_features)
        )
        # Identity only when the shape is unchanged. A strided unit needs the
        # projection even if the channel counts match.
        if in_features == out_features and stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_features, out_features, 1, stride=stride),
                nn.BatchNorm2d(out_features)
            )

    def forward(self, x):
        fea = self.feas(x)
        return fea + self.shortcut(x)


class SKNet(nn.Module):
    """SKNet-style classifier for 32x32 inputs built from SKUnit blocks.

    Layout:
        3x3 stem (64 ch)
        -> three stages of three SKUnits (256, 512 and 1024 channels;
           the first unit of stages 2 and 3 uses stride 2)
        -> global average pool
        -> linear classifier.

    Every SKUnit is built with WH=32, M=2 branches, G=8 groups and r=2.

    Args:
        class_num: number of output classes.

    ``forward`` returns raw logits of shape ``(batch, class_num)``.
    """

    def __init__(self, class_num):
        super(SKNet, self).__init__()
        # ReLU added after the stem BN. It has no parameters, so the
        # state_dict keys of basic_conv are unchanged.
        self.basic_conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )  # 32x32
        self.stage_1 = nn.Sequential(
            SKUnit(64, 256, 32, 2, 8, 2),
            nn.ReLU(),
            SKUnit(256, 256, 32, 2, 8, 2),
            nn.ReLU(),
            SKUnit(256, 256, 32, 2, 8, 2),
            nn.ReLU()
        )  # 32x32
        self.stage_2 = nn.Sequential(
            SKUnit(256, 512, 32, 2, 8, 2, stride=2),
            nn.ReLU(),
            SKUnit(512, 512, 32, 2, 8, 2),
            nn.ReLU(),
            SKUnit(512, 512, 32, 2, 8, 2),
            nn.ReLU()
        )  # 16x16
        self.stage_3 = nn.Sequential(
            SKUnit(512, 1024, 32, 2, 8, 2, stride=2),
            nn.ReLU(),
            SKUnit(1024, 1024, 32, 2, 8, 2),
            nn.ReLU(),
            SKUnit(1024, 1024, 32, 2, 8, 2),
            nn.ReLU()
        )  # 8x8
        # Equivalent to AvgPool2d(8) on the 8x8 maps from a 32x32 input,
        # but also works for other input sizes.
        self.pool = nn.AdaptiveAvgPool2d(1)
        # No softmax: nn.CrossEntropyLoss expects raw logits.
        self.classifier = nn.Sequential(
            nn.Linear(1024, class_num),
        )

    def forward(self, x):
        fea = self.basic_conv(x)
        fea = self.stage_1(fea)
        fea = self.stage_2(fea)
        fea = self.stage_3(fea)
        fea = self.pool(fea)
        # flatten(..., 1) keeps the batch dim. torch.squeeze would also drop
        # it for a batch of size 1, breaking the loss computation.
        fea = torch.flatten(fea, 1)  # (batch, 1024)
        return self.classifier(fea)  # (batch, class_num)

Training and testing

In [None]:
def train_epoch(model, optimizer, train_loader, criterion, epoch, writer=None, log_interval=10):
    """Run one training epoch over ``train_loader``.

    Args:
        model: network to train; each batch is moved to the device of the
            model's parameters.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: sized iterable of ``(data, label)`` batches.
        criterion: loss function called as ``criterion(output, label)``.
        epoch: epoch index, used in the printout and the TensorBoard step.
        writer: optional ``SummaryWriter``; receives the ``'loss'`` scalar.
        log_interval: print/log every this many batches (default 10).
    """
    model.train()
    device = next(model.parameters()).device
    num = len(train_loader)
    for i, (data, label) in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.to(device)
        label = label.to(device).long()
        loss = criterion(model(data), label)
        loss.backward()
        optimizer.step()
        if i % log_interval == 0:
            # .item() prints a plain number instead of the tensor repr
            # (device and grad_fn included).
            loss_value = loss.item()
            print('epoch {}, [{}/{}], loss {}'.format(epoch, i, num, loss_value))
            if writer is not None:
                writer.add_scalar('loss', loss_value, epoch * num + i)


def test(model, test_loader, criterion, epoch, writer=None):
    """Evaluate ``model`` on ``test_loader``, print and optionally log the results.

    Args:
        model: network to evaluate; each batch is moved to the device of
            the model's parameters.
        test_loader: DataLoader yielding ``(data, label)``; ``len(test_loader.dataset)``
            is used as the sample count.
        criterion: loss called as ``criterion(output, label)``. It is assumed
            to use mean reduction (the ``nn.CrossEntropyLoss`` default), so
            each batch value is re-weighted by its batch size.
        epoch: epoch index, used in the printout and as the TensorBoard step.
        writer: optional ``SummaryWriter``; receives ``'test_loss'`` and ``'acc'``.

    Returns:
        ``(test_loss, accuracy)``: per-sample average loss and fraction of
        correct predictions.
    """
    model.eval()
    device = next(model.parameters()).device
    num_samples = len(test_loader.dataset)
    denom = max(num_samples, 1)  # avoid dividing by zero on an empty set
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)
            label = label.to(device).long()
            result = model(data)
            # Weight each batch's mean loss by its size so the total is a
            # per-sample average, independent of the number of batches.
            test_loss += criterion(result, label).item() * label.size(0)
            pred = result.argmax(dim=1)
            correct += pred.eq(label).sum().item()
    test_loss /= denom
    accuracy = correct / denom
    print('epoch {}, test loss {}, acc [{}/{}]'.format(epoch, test_loss, correct, num_samples))
    if writer is not None:
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('acc', accuracy, epoch)
    return test_loss, accuracy


# An evaluation helper, not a pytest test: keep pytest from collecting it by name.
test.__test__ = False

ResNeXt model

In [None]:
# Wall-clock timer: process_time() counts only this process's CPU time, so it
# under-reports time spent waiting on the GPU.
tic = time.perf_counter()
root_path = '/'

train_loader = torch.utils.data.DataLoader(MyDataset('train', root_path=root_path), batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(MyDataset('test', root_path=root_path), batch_size=100)

net = ResNeXt(10)

net.cuda()
optimizer = optim.Adam(net.parameters(), weight_decay=1e-5, betas=(0.9, 0.999))
criterion = nn.CrossEntropyLoss().cuda()

# One log directory per model. Both runs use the same scalar tags, which
# would otherwise interleave in TensorBoard.
log_path = './logs/ResNeXt/'
writer = SummaryWriter(log_path)

epoch_num = 100
lr0 = 1e-3

for epoch in range(epoch_num):
    # Step schedule: halve the learning rate every 50 epochs.
    current_lr = lr0 / 2 ** (epoch // 50)
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr
    train_epoch(net, optimizer, train_loader, criterion, epoch, writer=writer)
    test(net, test_loader, criterion, epoch, writer=writer)
    if (epoch + 1) % 5 == 0:
        # Periodic checkpoint every 5 epochs.
        torch.save(net.state_dict(), '/ResNeXt_model_{}.pth'.format(epoch))
torch.save(net.state_dict(), '/ResNeXt_model.pth')
# Flush pending TensorBoard events to disk.
writer.close()

print("Processing time:", time.perf_counter() - tic)

SKNet model

In [None]:
# Wall-clock timer: process_time() counts only this process's CPU time, so it
# under-reports time spent waiting on the GPU.
tic = time.perf_counter()
root_path = '/'

train_loader = torch.utils.data.DataLoader(MyDataset('train', root_path=root_path), batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(MyDataset('test', root_path=root_path), batch_size=100)

net = SKNet(10)

net.cuda()
optimizer = optim.Adam(net.parameters(), weight_decay=1e-5, betas=(0.9, 0.999))
criterion = nn.CrossEntropyLoss().cuda()

# One log directory per model. Both runs use the same scalar tags, which
# would otherwise interleave in TensorBoard.
log_path = './logs/SKNet/'
writer = SummaryWriter(log_path)

epoch_num = 100
lr0 = 1e-3

for epoch in range(epoch_num):
    # Step schedule: halve the learning rate every 50 epochs.
    current_lr = lr0 / 2 ** (epoch // 50)
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr
    train_epoch(net, optimizer, train_loader, criterion, epoch, writer=writer)
    test(net, test_loader, criterion, epoch, writer=writer)
    if (epoch + 1) % 5 == 0:
        # Periodic checkpoint every 5 epochs.
        torch.save(net.state_dict(), '/SKNet_model_{}.pth'.format(epoch))
torch.save(net.state_dict(), '/SKNet_model.pth')
# Flush pending TensorBoard events to disk.
writer.close()

print("Processing time:", time.perf_counter() - tic)