In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import numpy as np
import math

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 256

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomRotation(degrees=(-45, 45)),
#     transforms.ColorJitter(brightness=.5,hue=0.5), # 改变图像的亮度和饱和度
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, pin_memory=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size*2,
                                         shuffle=False, pin_memory=True, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:11<00:00, 15261596.41it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


# SK-Net

In [3]:
class SKConv(nn.Module):
    def __init__(self, features, M=2, G=32, r=16, stride=1, L=32):
        super(SKConv, self).__init__()
        
        d = max(int(features/r), L)
        self.M = M
        self.features = features
        
        self.convs = nn.ModuleList([])
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3, stride=stride,
                          padding=1+i, dilation=1+i, groups=G, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(True)
            ))
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Conv2d(features, d, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(d),
            nn.ReLU(True)
        )
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(
                nn.Conv2d(d, features, kernel_size=1, stride=1)
            )
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        batch_size = x.shape[0]
        
        feats = [conv(x) for conv in self.convs]
        feats = torch.cat(feats, dim=1)
        feats = feats.view(batch_size, self.M, self.features, feats.shape[2], feats.shape[3])
        
        feats_U = torch.sum(feats, dim=1)
        feats_S = self.gap(feats_U)
        feats_Z = self.fc(feats_S)
        
        attention = [fc(feats_Z) for fc in self.fcs]
        attention = torch.cat(attention, dim=1)
        attention = attention.view(batch_size, self.M, self.features, 1, 1)
        attention = self.softmax(attention)
        
        feats_V = torch.sum(feats * attention, dim=1)
        return feats_V
    
    
class SKUnit(nn.Module):
    def __init__(self, in_features, middle_features, out_features, M=2, G=32, r=16, stride=1, L=32):
        super(SKUnit, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_features, middle_features, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(middle_features),
            nn.ReLU(True)
        )
        self.conv2_sk = SKConv(middle_features, M, G, r, stride, L)
        self.conv3 = nn.Sequential(
            nn.Conv2d(middle_features, out_features, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_features)
        )
        
        if in_features == out_features:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_features, out_features, 1, stride, bias=False),
                nn.BatchNorm2d(out_features)
            )
        
        self.relu = nn.ReLU(True)
        
    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2_sk(out)
        out = self.conv3(out)
        out = self.relu(out + self.shortcut(identity))
        return out
    
class SKNet(nn.Module):
    def __init__(self, num_classes, block_config, strides_list=[1, 2, 2, 2]):
        super(SKNet, self).__init__()
        self.basic_conv = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        
        self.maxpool = nn.MaxPool2d(3, 2, 1)
        
        self.stage_1 = self.make_layer(64, 128, 256, num_blocks=block_config[0], stride=strides_list[0])
        self.stage_2 = self.make_layer(256, 256, 512, num_blocks=block_config[1], stride=strides_list[1])
        self.stage_3 = self.make_layer(512, 512, 1024, num_blocks=block_config[2], stride=strides_list[2])
        self.stage_4 = self.make_layer(1024, 1024, 2048, num_blocks=block_config[3], stride=strides_list[3])
        
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(2048, num_classes)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
    def make_layer(self, in_features, middle_features, out_features, num_blocks, stride=1):
        layers = [SKUnit(in_features, middle_features, out_features, stride)]
        for _ in range(1, num_blocks):
            layers.append(SKUnit(out_features, middle_features, out_features))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.basic_conv(x)
        out = self.maxpool(out)
        out = self.stage_1(out)
        out = self.stage_2(out)
        out = self.stage_3(out)
        out = self.stage_4(out)
        out = self.gap(out)
        out = torch.squeeze(out)
        out = self.classifier(out)
        return out

    
def SKNet26(num_classes=100):
    return SKNet(num_classes, [2,2,2,2])

def SKNet50(num_classes=100):
    return SKNet(num_classes, [3,4,6,3])

In [4]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
def adjust_learning_rate(optimizer, current_iter, warmup_iter, max_warm_up_lr):
    if current_iter <= warmup_iter:
        lr = max_warm_up_lr * current_iter / warmup_iter
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [5]:
loss_arr = []
acc_arr = []
lr_arr = []

In [6]:
def fit(model, epochs, train_loader, valid_loader, optimizer, lr_scheduler,
        loss_fn, updata_lr_every_epoch, warm_up=False, max_warm_up_lr=0.01, grad_clip=None, PATH='./Res2Net/Res2Net.pth'):
    log_interval = int((50000/batch_size)/2)
    for epoch in range(epochs):
        print(f"{'='*20} Epoch: {epoch+1} {'='*20}\n")
        model.train()
        avg_loss = 0
        
        for i, (inputs, targets) in enumerate(train_loader):
            if warm_up != False and epoch <= warm_up:
                adjust_learning_rate(optimizer, (i+1)+epoch*len(train_loader), len(train_loader)*warm_up, max_warm_up_lr=max_warm_up_lr)
#                 lr_arr.append(get_lr(optimizer))
            outputs = model(inputs.to(device))
            loss = loss_fn(outputs, targets.to(device))
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            if lr_scheduler is not None and updata_lr_every_epoch == False:
                lr_scheduler.step()
                lr_arr.append(get_lr(optimizer))
            optimizer.zero_grad()
            
            avg_loss += loss.item()
            if i % log_interval == log_interval-1:
                avg_loss = avg_loss/log_interval
#                 train_loss_arr.append(avg_loss)
                print(f"batch: {i+1}, train_loss: {avg_loss:.4f}")
#                 print(f"batch: {i+1}, train_loss: {avg_loss:.4f}, last_lr: {lr_arr[-1]:.5f}")
                avg_loss = 0
        if lr_scheduler is not None and updata_lr_every_epoch == True:
            lr_scheduler.step()
        lr_arr.append(get_lr(optimizer))
        
        model.eval()
        correct = 0
        total = 0
        avg_loss = 0
        with torch.no_grad():
            for (images, labels) in valid_loader:
                outputs = model(images.to(device))
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.to(device)).sum().item()
                avg_loss += loss_fn(outputs, labels.to(device))
            avg_loss = avg_loss.cpu() / len(valid_loader)
            loss_arr.append(avg_loss)
            acc = 100 * correct / total
            acc_arr.append(acc)
            
        print(f'Accuracy: {acc}% ({correct} / {total}), Loss: {avg_loss:.3f}, Last_lr: {lr_arr[-1]:.5f}')
    torch.save(model, PATH)

In [7]:
model = SKNet50().to(device)
num_param = sum([param.nelement() for param in model.parameters()])
print("Number of parameter: %.2fM" % (num_param/1e6))

Number of parameter: 25.64M


In [None]:
%%time
epochs = 60
# optimizer = optim.SGD(model.parameters(), lr=0.0002, momentum=0.9, weight_decay=5e-4)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)# sched = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.2)
sched = optim.lr_scheduler.OneCycleLR(optimizer, 0.01, epochs=epochs,
                                      steps_per_epoch=len(trainloader))

# loaded_model = SKNet50().to(device)
# loaded_model.load_state_dict(torch.load('./SKNet_30(warmup).pth'))

fit(model=model,
    epochs=epochs,
    train_loader=trainloader,
    valid_loader=testloader,
    optimizer=optimizer,
    lr_scheduler=sched,
    loss_fn=nn.CrossEntropyLoss(),
    warm_up=False, max_warm_up_lr=0.01,
    grad_clip=None, updata_lr_every_epoch=False, PATH='./SKNet_60_OneCycle.pth')


batch: 97, train_loss: 4.4054
batch: 194, train_loss: 3.8841
Accuracy: 11.27% (1127 / 10000), Loss: 3.814, Last_lr: 0.00040

batch: 97, train_loss: 3.6435
batch: 194, train_loss: 3.5177
Accuracy: 17.89% (1789 / 10000), Loss: 3.525, Last_lr: 0.00040

batch: 97, train_loss: 3.3429
batch: 194, train_loss: 3.2525
Accuracy: 21.14% (2114 / 10000), Loss: 3.234, Last_lr: 0.00040

batch: 97, train_loss: 3.1087
batch: 194, train_loss: 3.0548
Accuracy: 24.78% (2478 / 10000), Loss: 3.029, Last_lr: 0.00040

batch: 97, train_loss: 2.9228
batch: 194, train_loss: 2.9099
Accuracy: 27.51% (2751 / 10000), Loss: 2.912, Last_lr: 0.00040

batch: 97, train_loss: 2.7643
batch: 194, train_loss: 2.7585
Accuracy: 29.41% (2941 / 10000), Loss: 2.812, Last_lr: 0.00040

batch: 97, train_loss: 2.6521
batch: 194, train_loss: 2.6186
Accuracy: 32.07% (3207 / 10000), Loss: 2.683, Last_lr: 0.00040

batch: 97, train_loss: 2.5099
batch: 194, train_loss: 2.4897
Accuracy: 32.99% (3299 / 10000), Loss: 2.639, Last_lr: 0.00040


In [None]:
# loss_arr = np.loadtxt('./loss.txt').tolist()
# acc_arr = np.loadtxt('./acc.txt').tolist()
# lr_arr = np.loadtxt('./lr.txt').tolist()
np.savetxt('./loss.txt', loss_arr)
np.savetxt('./acc.txt', acc_arr)
np.savetxt('./lr.txt', lr_arr)

In [None]:
plt.figure(figsize=(3, 2))
plt.plot(loss_arr)
plt.title('loss')

plt.figure(figsize=(3, 2))
plt.plot(acc_arr)
plt.title('Arr')

plt.figure(figsize=(3, 2))
plt.plot(lr_arr)
plt.title('LR')