In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
from models.awpooling import AWPool2d, AWPool2d_
from models.vggaw import VGG11
# from models.vgg import 

In [2]:
train_ds = torchvision.datasets.CIFAR100(root='/notebooks/nfs/work/dataset', train=True, download=True, transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=128 , num_workers=2)
images, labels = next(iter(train_loader))
t = Variable(torch.Tensor([1., 1., 1., 1., 1.]), requires_grad=True)

Files already downloaded and verified


In [44]:
class VGG11(nn.Module):
    def __init__(self, num_class=100, init_weights=True):
        super(VGG11, self).__init__()
        self.temperature = Variable(torch.Tensor([1., 1., 1., 1., 1.]), requires_grad=True)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
        )
        self.aw1 = AWPool2d(kernel_size=2, stride=2, temperature=self.temperature[0])
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),
        )
        self.aw2 = AWPool2d(kernel_size=2, stride=2, temperature=self.temperature[1])
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256)
        )
        self.aw3 = AWPool2d(kernel_size=2, stride=2, temperature=self.temperature[2])
        
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),
        )
        self.aw4 = AWPool2d(kernel_size=2, stride=2, temperature=self.temperature[3])
        
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),
        )
        self.aw5 = AWPool2d(kernel_size=2, stride=2, temperature=self.temperature[4])
        self.globalavg = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )
        
        if init_weights:
            self._initialize_weights()
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.aw1(x)
        x = self.conv2(x)
        x = self.aw2(x)
        x = self.conv3(x)
        x = self.aw3(x)
        x = self.conv4(x)
        x = self.aw4(x)
        x = self.conv5(x)
        x = self.aw5(x)
        x = self.globalavg(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        
        return x
        
    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
        

In [12]:
criterion = nn.CrossEntropyLoss().cuda()

In [45]:
model = VGG11()
model = model.cuda()

In [46]:
opt1 = torch.optim.SGD(model.parameters(), lr=0.1)
opt2 = torch.optim.SGD([model.temperature], lr=0.01)

In [63]:
# for n,p in model.named_parameters():
#     print(n)
print(model.temperature)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 0.9999], requires_grad=True)


In [76]:
images, labels = images.cuda(), labels.cuda()
pred = model(images)
loss = criterion(pred, labels)
loss.backward()
opt1.step()
opt2.step()
opt1.zero_grad()
opt2.zero_grad()

In [77]:
print(model.temperature)

tensor([0.9997, 1.0000, 0.9997, 1.0001, 0.9961], requires_grad=True)


In [2]:
from models.awpooling import AWPool2d_

In [44]:
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.aw = AWPool2d_(kernel_size=2, stride=2)
        self.classifier = nn.Linear(64, 100)
    def forward(self, x):
        x = self.aw(self.conv(x))
        x = torch.flatten(x, start_dim=1)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [60]:
model = net().cuda()
criterion = nn.CrossEntropyLoss()
data = torch.randn(1, 1, 2, 2).cuda()
label = torch.arange(1.0,101.0).cuda()
label = label.view(1, -1)
opt1 = torch.optim.SGD(params=[*model.conv.parameters(), *model.classifier.parameters()], lr=0.1)
opt = torch.optim.SGD(model.aw.parameters(), lr=0.1)

In [71]:
param_list = []
for n, p in model.named_parameters():
    if 'aw' in n:
        param_list += [{'params': p, 'lr': 0.1}]
    else:
        param_list += [{'params': p, 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 1e-4}]
opt = torch.optim.SGD(params=param_list)

In [76]:
model.aw.parameters()

<generator object Module.parameters at 0x7fb980f0d5f0>

In [64]:
a = torch.FloatTensor([1,2,3,4]).cuda()
a = a.view(1, 1, 2, 2)
model.aw(a)

tensor([[[[3.1754]]]], device='cuda:0', grad_fn=<SumBackward1>)

In [62]:
loss = criterion(pred, label)
loss.backward()
opt.step()
opt1.step()

In [78]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
writer.add_scalars('Test', {'t': model.aw.t}, 1)

In [5]:
ck11 = torch.load('checkpoints/tiny-imagenet/vgg11awt_best.pth.tar')
# ck13 = torch.load('checkpoints/tiny-imagenet/vgg13awt_best.pth.tar')
# ck16 = torch.load('checkpoints/tiny-imagenet/vgg16awt_best.pth.tar')
# ck19 = torch.load('checkpoints/tiny-imagenet/vgg19awt_best.pth.tar')

print(ck11['epoch'], ck11['best_acc1'], ck11['temperature'],sep='\n')

# print(ck11['best_acc1'], ck13['best_acc1'], ck16['best_acc1'], ck19['best_acc1'])
# print(ck11['temperature'], ck13['temperature'], ck16['temperature'], ck19['temperature'], sep='\n')
# print(ck11['epoch'], ck13['epoch'], ck16['epoch'], ck19['epoch'])

84
49.84
[Parameter containing:
tensor([1.], device='cuda:0'), Parameter containing:
tensor([1.], device='cuda:0'), Parameter containing:
tensor([1.], device='cuda:0'), Parameter containing:
tensor([1.], device='cuda:0'), Parameter containing:
tensor([1.], device='cuda:0')]


In [2]:
def set_mode(model, mode='normal'):
    if mode == 'normal':
        for n, p in model.named_parameters():
            if 'aw' in n:
                p.requires_grad = False
            else:
                p.requires_grad = True
    elif mode == 'temperature':
        for n, p in model.named_parameters():
            if 'aw' in n:
                p.requires_grad = True
            else:
                p.requires_grad = False

In [2]:
model = VGG11()

model.disable_t()
for n, p in model.named_parameters():
    print(n, p.requires_grad)

# set_mode(model, mode='temperature')

True
False
conv1.0.weight True
conv1.0.bias True
conv1.2.weight True
conv1.2.bias True
aw1.t False
conv2.0.weight True
conv2.0.bias True
conv2.2.weight True
conv2.2.bias True
aw2.t True
conv3.0.weight True
conv3.0.bias True
conv3.2.weight True
conv3.2.bias True
conv3.3.weight True
conv3.3.bias True
conv3.5.weight True
conv3.5.bias True
aw3.t True
conv4.0.weight True
conv4.0.bias True
conv4.2.weight True
conv4.2.bias True
conv4.3.weight True
conv4.3.bias True
conv4.5.weight True
conv4.5.bias True
aw4.t True
conv5.0.weight True
conv5.0.bias True
conv5.2.weight True
conv5.2.bias True
conv5.3.weight True
conv5.3.bias True
conv5.5.weight True
conv5.5.bias True
aw5.t True
classifier.0.weight True
classifier.0.bias True
classifier.3.weight True
classifier.3.bias True
classifier.6.weight True
classifier.6.bias True


In [15]:
for n, p in model.named_parameters():
    print(n, p.requires_grad)


conv1.0.weight False
conv1.0.bias False
conv1.2.weight False
conv1.2.bias False
aw1.t True
conv2.0.weight False
conv2.0.bias False
conv2.2.weight False
conv2.2.bias False
aw2.t True
conv3.0.weight False
conv3.0.bias False
conv3.2.weight False
conv3.2.bias False
conv3.3.weight False
conv3.3.bias False
conv3.5.weight False
conv3.5.bias False
aw3.t True
conv4.0.weight False
conv4.0.bias False
conv4.2.weight False
conv4.2.bias False
conv4.3.weight False
conv4.3.bias False
conv4.5.weight False
conv4.5.bias False
aw4.t True
conv5.0.weight False
conv5.0.bias False
conv5.2.weight False
conv5.2.bias False
conv5.3.weight False
conv5.3.bias False
conv5.5.weight False
conv5.5.bias False
aw5.t True
classifier.0.weight False
classifier.0.bias False
classifier.3.weight False
classifier.3.bias False
classifier.6.weight False
classifier.6.bias False


# 分開訓練, 交互訓練
