In [7]:
import time
import torch.optim
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from datasets.dataset import CUBDataset
from torch.utils.data import DataLoader
from torch.autograd import Variable

In [8]:
target_accuracy = 95.
learning_rate = 1e-3
input_channels = 1
output_features = 6
save_clf_name = 'models/clf.pth'
part_model_name = 'models/part.pth'
conv_model_name = 'models/resnet.pth'

In [9]:
trainset = CUBDataset()
trainloader = DataLoader(dataset=trainset, batch_size=10, shuffle=True)
testset = CUBDataset(is_test = True)
testloader = DataLoader(dataset=testset, batch_size=10, shuffle=True)

In [17]:
class Part(nn.Module):

    def __init__(self):
        super(Part, self).__init__()
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 256)

    def forward(self, x):
        conv_matrix = torch.clone(x)
        conv_matrix = conv_matrix.reshape(conv_matrix.size(0), 256, 1, 784)
        conv_matrix = conv_matrix.transpose(1, 3)
        x = F.avg_pool2d(x, kernel_size=28, stride=28)
        x = x.view(x.size(0), -1)
        x = torch.tanh(self.fc1(x))
        x = self.fc2(x)
        x = torch.sigmoid(x).unsqueeze(1).unsqueeze(1)
        x = F.interpolate(x, (1, 784), mode='bilinear', align_corners=True)
        x = x.squeeze(1).squeeze(1).unsqueeze(2).unsqueeze(3)
        x = x * conv_matrix
        x = F.avg_pool2d(x, kernel_size=(1, 256), stride=256)
        x = x * 0.1
        x = F.softmax(x, dim=1)
        x = torch.exp(x)
        x = x + 1
        x = torch.log(x)
        x = x * 4
        x = x.squeeze(2).squeeze(2)
        return x.reshape(x.size(0), 28, 28)


In [18]:
class Clf(nn.Module):

    def __init__(self):
        super(Clf, self).__init__()
        self.res1 = models.resnet18()
        self.res1.conv1 = nn.Conv2d(256, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.res1.fc = nn.Linear(in_features=512, out_features=output_features, bias=True)
        self.res2 = models.resnet18()
        self.res2.conv1 = nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.res2.fc = nn.Linear(in_features=512, out_features=output_features, bias=True)

    def forward(self, channels, attention):
        xc = self.res1(channels)
        xa = self.res2(attention)
        return F.softmax(xc + xa, dim=1)


In [19]:
def get_channels(c, data):
    return c.layer3(c.layer2(c.layer1(c.maxpool(c.relu(c.bn1(c.conv1(data)))))))

In [20]:
clf = Clf()

loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.RMSprop(clf.parameters(), lr = learning_rate)

In [21]:
conv = torch.load(conv_model_name)
part = torch.load(part_model_name)

In [23]:
epoch = 0
while True:
    for data in trainloader:
        img, label = data
        img = Variable(img)
        label = Variable(label).type(torch.long)
        channels = get_channels(conv, img)
        attention = part(channels).unsqueeze(1)
        output = clf(channels, attention)
        optimizer.zero_grad()
        loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()
    count, acc = 0, 0.
    for data in testloader:
        img, label = data
        img = Variable(img)
        label = Variable(label)
        channels = get_channels(conv, img)
        attention = part(channels).unsqueeze(1)
        output = clf(channels, attention)
        acc += (torch.max(output, dim=1)[1]==label).sum()
        count += img.size(0)
    print(epoch, (int(acc)/count)*100,'%', time.asctime())
    if (int(acc)/count)*100 > target_accuracy:
        torch.save(clf, save_clf_name)
        break
    epoch += 1


0 37.5 % Mon May 17 18:41:51 2021
1 35.0 % Mon May 17 18:43:02 2021
2 45.0 % Mon May 17 18:44:13 2021
3 25.0 % Mon May 17 18:45:26 2021


KeyboardInterrupt: 

In [None]:
chaos_matrix = torch.zeros((output_features, output_features))
for data in testloader:
    img, label = data
    img = Variable(img)
    label = Variable(label)
    channels = get_channels(conv, img)
    attention = part(channels).unsqueeze(1)
    output = clf(channels, attention)
    output = torch.max(output, dim=1)[1]
    for (ix, iy) in zip(output, label):
        chaos_matrix[ix, iy] += 1

print(chaos_matrix)