In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import numpy as np
import math

In [None]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs end to end on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 256

# Per-channel mean / std handed to Normalize in both pipelines.
# NOTE(review): these look like ImageNet channel statistics rather than values
# computed on CIFAR-100 -- confirm this is intended.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random flip and reflect-padded random crop as augmentation.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
#     transforms.RandomRotation(degrees=(-45, 45)),
#     transforms.ColorJitter(brightness=.5,hue=0.5), # change the image brightness and saturation
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# Evaluation pipeline: no random augmentation, so the reported test metrics are
# deterministic and measured on the unmodified images.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# Pinned host memory is only useful when batches are copied to a CUDA device.
use_pinned_memory = (device == 'cuda')

trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, pin_memory=use_pinned_memory, num_workers=4)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=test_transform)
# No gradients are kept during evaluation, so a larger batch is used there.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size*2,
                                         shuffle=False, pin_memory=use_pinned_memory, num_workers=4)

# Triple Attention

In [None]:
class BasicConv(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of output channels of the convolution.
        kernel_size, stride, padding, dilation, groups, bias: forwarded to
            ``nn.Conv2d`` unchanged.
        relu: when false, no ReLU is applied and ``self.relu`` is ``None``.
        bn: when false, no batch norm is applied and ``self.bn`` is ``None``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Disabled stages are stored as None so forward() can skip them.
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU(True)
        else:
            self.relu = None

    def forward(self, x):
        """Apply the convolution, then whichever of batch norm / ReLU are enabled."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out
    
class ChannelPool(nn.Module):
    """Reduce dim 1 to two entries: its maximum followed by its mean."""

    def forward(self, x):
        """Return the max and the mean over dim 1, concatenated along dim 1."""
        peak = torch.max(x, dim=1, keepdim=True).values
        average = torch.mean(x, dim=1, keepdim=True)
        return torch.cat([peak, average], dim=1)
    
class SpatialGate(nn.Module):
    """Scale the input by a sigmoid attention map derived from its dim-1 statistics.

    The input is pooled to two maps by ``ChannelPool``, reduced to one map by a
    7x7 ``BasicConv`` (batch norm, no ReLU), squashed with a sigmoid and
    multiplied back onto the input.
    """

    def __init__(self):
        super().__init__()
        kernel = 7
        # With stride 1, padding (kernel - 1) // 2 keeps the spatial size unchanged.
        same_padding = (kernel - 1) // 2
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel, stride=1, padding=same_padding, relu=False)

    def forward(self, x):
        """Return ``x`` multiplied by its attention map (broadcast over dim 1)."""
        attention = self.spatial(self.compress(x))
        # In-place sigmoid turns the conv output into the gating weights.
        attention.sigmoid_()
        return x * attention
    
class TripleAttention(nn.Module):
    """Average of several ``SpatialGate`` branches applied to differently permuted views.

    Two branches gate the input after swapping dim 1 with dim 2 and dim 1 with
    dim 3 respectively (swapping back afterwards). Unless ``no_spatial`` is set,
    a third branch gates the input as-is. The branch outputs are averaged.

    Args:
        gate_channels: accepted but not used by this implementation.
        reduction_ratio: accepted but not used by this implementation.
        pool_types: accepted but not used by this implementation.
        no_spatial: when true, the un-permuted branch is not created or used.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=["avg", "max"], no_spatial=False):
        super().__init__()
        self.ChannelGate_H = SpatialGate()
        self.ChannelGate_W = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    @staticmethod
    def _gate_swapped(gate, x, order):
        """Permute ``x`` by ``order``, apply ``gate``, and permute back.

        Both orders used here swap exactly two dims, so applying the same
        permutation a second time restores the original layout.
        """
        swapped = x.permute(*order).contiguous()
        return gate(swapped).permute(*order).contiguous()

    def forward(self, x):
        """Return the mean of the enabled attention branches for ``x``."""
        out_CW = self._gate_swapped(self.ChannelGate_H, x, (0, 2, 1, 3))
        out_CH = self._gate_swapped(self.ChannelGate_W, x, (0, 3, 2, 1))
        if self.no_spatial:
            return (1/2) * (out_CW + out_CH)
        plain = self.SpatialGate(x)
        return (1/3) * (plain + out_CW + out_CH)

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)

class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions with attention before the skip addition.

    Args:
        in_planes: channels of the block input.
        out_planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """

    # Output channels are out_planes * expansion.
    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(in_planes, out_planes, stride)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(True)
        self.conv2 = conv3x3(out_planes, out_planes)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.downsample = downsample
        self.stride = stride

        self.triplet_attention = TripleAttention(out_planes, 16)

    def forward(self, x):
        """Return relu(attention(bn2(conv2(relu(bn1(conv1(x)))))) + shortcut)."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.triplet_attention(out)

        # Residual connection, then the final activation.
        out += shortcut
        return self.relu(out)
    
class Bottleneck(nn.Module):
    """Residual bottleneck block (1x1 -> 3x3 -> 1x1) with attention before the skip addition.

    Args:
        in_planes: channels of the block input.
        out_planes: channels of the two inner convolutions; the block outputs
            ``out_planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """

    # Ratio between the block's output channels and out_planes.
    expansion = 4

    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(in_planes, out_planes)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.conv2 = conv3x3(out_planes, out_planes, stride)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.conv3 = conv1x1(out_planes, out_planes*self.expansion)
        self.bn3 = nn.BatchNorm2d(out_planes*self.expansion)
        self.relu = nn.ReLU(True)
        self.downsample = downsample
        self.stride = stride

        # Class attributes are not visible as bare names inside a method, so
        # the expansion factor has to be read through self.
        self.triplet_attention = TripleAttention(out_planes * self.expansion, 16)

    def forward(self, x):
        """Run the three conv/bn stages, apply attention, add the shortcut and activate."""
        if self.downsample is not None:
            identity = self.downsample(x)
        else:
            identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.triplet_attention(out)

        # Residual connection, then the final activation.
        out += identity
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    """ResNet backbone: 7x7 stem, max pool, four residual stages, global pool and classifier.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            accept ``(in_planes, out_planes, stride, downsample)``.
        layers: four integers, the number of blocks in each stage.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, block, layers, num_classes=100):
        super().__init__()
        # Tracks the channel count feeding the next block; updated by make_layer.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        self.relu = nn.ReLU(True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages 2-4 use stride 2 in their first block.
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512*block.expansion, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Conv2d weight and reset every BatchNorm2d to weight 1 / bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                # Zero-mean normal with std sqrt(2 / (kh * kw * out_channels)).
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def make_layer(self, block, out_planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.in_planes``.

        The first block receives ``stride`` and, when the stride is not 1 or the
        channel count changes, a 1x1-conv + batch-norm shortcut projection.
        """
        expanded = out_planes * block.expansion
        downsample = None
        if stride != 1 or self.in_planes != expanded:
            downsample = nn.Sequential(
                conv1x1(self.in_planes, expanded, stride),
                nn.BatchNorm2d(expanded)
            )

        stage = [block(self.in_planes, out_planes, stride, downsample)]
        self.in_planes = expanded
        stage.extend(block(self.in_planes, out_planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the ``num_classes`` logits for a batch of 3-channel images."""
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        out = self.avgpool(out)
        return self.fc(torch.flatten(out, 1))
    
def Triple_Attention_ResNet18():
    """Build a ``ResNet`` from ``BasicBlock`` with two blocks in each of its four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage)

In [None]:
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)
def adjust_learning_rate(optimizer, current_iter, warmup_iter, max_warm_up_lr=0.1):
    """Linearly ramp the learning rate of every parameter group during warm-up.

    While ``current_iter <= warmup_iter`` the learning rate is set to
    ``max_warm_up_lr * current_iter / warmup_iter``. Once warm-up is over, or
    when ``warmup_iter`` is not positive, the optimizer is left untouched so
    that whatever schedule follows stays in control of the learning rate.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry.
        current_iter: index of the current warm-up iteration.
        warmup_iter: total number of warm-up iterations.
        max_warm_up_lr: learning rate reached at the end of warm-up.
    """
    # Nothing to ramp: warm-up is disabled or already finished. Returning here
    # also avoids a division by zero and the use of an unassigned `lr`.
    if warmup_iter <= 0 or current_iter > warmup_iter:
        return
    lr = max_warm_up_lr * current_iter / warmup_iter
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [None]:
# Training history filled in by fit(): validation loss and top-1 accuracy per
# epoch, plus every recorded learning rate.
loss_arr, acc_arr, lr_arr = [], [], []

def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _evaluate(model, valid_loader, loss_fn, run_device):
    """Evaluate ``model`` on ``valid_loader`` without gradients.

    Returns:
        ``(top1_correct, top5_correct, total, avg_loss)`` where ``avg_loss`` is
        the mean of the per-batch losses as a Python float (0.0 for an empty
        loader).
    """
    top1_correct = 0
    top5_correct = 0
    total = 0
    loss_sum = 0.0
    num_batches = 0
    model.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            labels = labels.to(run_device)
            outputs = model(images.to(run_device))
            # Never ask for more candidates than there are classes.
            k = min(5, outputs.size(1))
            _, predicted = torch.topk(outputs, k, dim=1, largest=True, sorted=True)
            # hits[i, j] is True when the j-th ranked guess for sample i is right.
            hits = predicted == labels.unsqueeze(1)
            top1_correct += hits[:, 0].sum().item()
            top5_correct += hits.any(dim=1).sum().item()
            total += labels.size(0)
            loss_sum += loss_fn(outputs, labels).item()
            num_batches += 1
    avg_loss = loss_sum / num_batches if num_batches else 0.0
    return top1_correct, top5_correct, total, avg_loss


def fit(model, epochs, train_loader, valid_loader, optimizer, lr_scheduler,
        loss_fn, updata_lr_every_epoch, warm_up=False, max_warm_up_lr=0.01, grad_clip=None, PATH='./SKNet.pth'):
    """Train ``model``, evaluate it after every epoch and save it to ``PATH``.

    Validation loss, top-1 accuracy and learning rates are appended to the
    module-level lists ``loss_arr``, ``acc_arr`` and ``lr_arr``.

    Args:
        model: network to train; batches are moved to the device of its parameters.
        epochs: number of passes over ``train_loader``.
        train_loader: sized iterable of ``(inputs, targets)`` batches.
        valid_loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer updating ``model``.
        lr_scheduler: scheduler to step, or ``None``.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        updata_lr_every_epoch: ``True`` steps the scheduler once per epoch,
            ``False`` once per batch.
        warm_up: number of initial epochs with linear warm-up, or ``False`` for none.
        max_warm_up_lr: learning rate reached at the end of warm-up.
        grad_clip: if not ``None``, gradients are clipped to ``[-grad_clip, grad_clip]``.
        PATH: file the whole model object is saved to after the last epoch.
    """
    run_device = _model_device(model)
    # Report the running training loss twice per epoch (at least every batch).
    log_interval = max(1, len(train_loader) // 2)
    for epoch in range(epochs):
        print(f"{'='*20} Epoch: {epoch+1} {'='*20}\n")
        model.train()
        running_loss = 0.0

        for i, (inputs, targets) in enumerate(train_loader):
            # Epochs are 0-based, so warm-up covers epochs 0 .. warm_up - 1 and
            # the last warm-up iteration is exactly len(train_loader) * warm_up.
            if warm_up and epoch < warm_up:
                adjust_learning_rate(optimizer, (i+1)+epoch*len(train_loader), len(train_loader)*warm_up, max_warm_up_lr=max_warm_up_lr)
                lr_arr.append(get_lr(optimizer))
            outputs = model(inputs.to(run_device))
            loss = loss_fn(outputs, targets.to(run_device))
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            if lr_scheduler is not None and not updata_lr_every_epoch:
                lr_scheduler.step()
                lr_arr.append(get_lr(optimizer))
            optimizer.zero_grad()

            running_loss += loss.item()
            if i % log_interval == log_interval-1:
                print(f"batch: {i+1}, train_loss: {running_loss/log_interval:.4f}")
                running_loss = 0.0

        if lr_scheduler is not None and updata_lr_every_epoch:
            lr_scheduler.step()
            lr_arr.append(get_lr(optimizer))
        if lr_scheduler is None:
            lr_arr.append(get_lr(optimizer))

        top1_correct, top5_correct, total, avg_loss = _evaluate(model, valid_loader, loss_fn, run_device)
        # An empty validation loader yields 0% instead of a division by zero.
        top1_acc = 100 * top1_correct / total if total else 0.0
        top5_acc = 100 * top5_correct / total if total else 0.0
        loss_arr.append(avg_loss)
        acc_arr.append(top1_acc)

        print(f'Top1 Accuracy: {top1_acc}% ({top1_correct} / {total}), Top5 Accuracy: {top5_acc}% ({top5_correct} / {total}), Loss: {avg_loss:.3f}, Last_lr: {get_lr(optimizer):.5f}')
    # Saves the whole module object (not just its state_dict), as before.
    torch.save(model, PATH)

In [None]:
# Build the network on the configured device and report its parameter count
# in millions.
model = Triple_Attention_ResNet18().to(device)
num_param = sum(p.numel() for p in model.parameters())
print("Number of parameter: %.2fM" % (num_param/1e6))

In [None]:
epochs = 30
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# OneCycleLR below drives the learning rate from its own max_lr, so the lr
# given to Adam here is only the value the optimizer is constructed with.
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# sched = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.2)
# One-cycle schedule with max_lr=0.001, sized for one scheduler step per batch.
sched = optim.lr_scheduler.OneCycleLR(optimizer, 0.001, epochs=epochs,
                                      steps_per_epoch=len(trainloader))

# updata_lr_every_epoch=False steps the scheduler after every batch, matching
# the steps_per_epoch given to OneCycleLR. The checkpoint is written next to
# the notebook rather than to the filesystem root, which is often not writable
# and would make the save fail only after all epochs have run.
fit(model=model,
    epochs=epochs,
    train_loader=trainloader,
    valid_loader=testloader,
    optimizer=optimizer,
    lr_scheduler=sched,
    loss_fn=nn.CrossEntropyLoss(),
    warm_up=False, grad_clip=None, updata_lr_every_epoch=False, PATH='./TripleAttention.pth')

In [None]:
# One small figure per recorded history. loss_arr and acc_arr get one entry per
# epoch; lr_arr gets one entry each time a learning rate is recorded.
for history, title, xlabel in (
    (loss_arr, 'loss', 'epoch'),
    (acc_arr, 'Accuracy', 'epoch'),
    (lr_arr, 'LR', 'recorded step'),
):
    fig, ax = plt.subplots(figsize=(3, 2))
    ax.plot(history)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
plt.show()