In [4]:
import time
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from datasets.dataset import CUBDataset
from torch.utils.data import DataLoader
from torch.autograd import Variable

In [2]:
epoch = 1
learning_rate = 1e-3
save_part_name = 'models/part.pth'
conv_model_name = 'models/resnet.pth'


In [3]:
trainset = CUBDataset()
trainloader = DataLoader(dataset=trainset, batch_size=10, shuffle=True)
testset = CUBDataset(is_test = True)
testloader = DataLoader(dataset=testset, batch_size=10, shuffle=True)

In [26]:
class Part(nn.Module):

    def __init__(self):
        super(Part, self).__init__()
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 256)

    def forward(self, x):
        print(x.shape)
        conv_matrix = torch.clone(x)
        conv_matrix = conv_matrix.reshape(conv_matrix.size(0), 256, 1, 784)
        conv_matrix = conv_matrix.transpose(1, 3)
        x = F.avg_pool2d(x, kernel_size=28, stride=28)
        x = x.view(x.size(0), -1)
        x = torch.tanh(self.fc1(x))
        x = self.fc2(x)
        x = torch.sigmoid(x).unsqueeze(1).unsqueeze(1)
        x = F.interpolate(x, (1, 784), mode='bilinear', align_corners=True)
        x = x.squeeze(1).squeeze(1).unsqueeze(2).unsqueeze(3)
        x = x * conv_matrix
        x = F.avg_pool2d(x, kernel_size=(1, 256), stride=256)
        x = x * 0.1
        x = F.softmax(x, dim=1)
        x = torch.exp(x)
        x = x + 1
        x = torch.log(x)
        x = x * 4
        x = x.squeeze(2).squeeze(2)
        return x.reshape(x.size(0), 28, 28)


In [27]:
class Loss(nn.Module):

    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, tensor):
        loss_sum = torch.zeros(1)
        indexes = Loss.get_max_index(tensor)
        for i in range(len(indexes)):
            max_x, max_y = indexes[i]
            for j in range(tensor.size(1)):
                for k in range(tensor.size(2)):
                    loss_sum += ((max_x - j) * (max_x - j) + (max_y - k) * (max_y - k)) * tensor[i, j, k]
        return loss_sum
    
    @staticmethod
    def get_max_index(tensor):
        shape = tensor.shape
        indexes = []
        for i in range(shape[0]):
            mx = tensor[i, 0, 0]
            x, y = 0, 0
            for j in range(shape[1]):
                for k in range(shape[2]):
                    if tensor[i, j, k] > mx:
                        mx = tensor[i, j, k]
                        x, y = j, k
            indexes.append([x, y])
        return indexes


In [28]:
def get_channels(c, data):
    return c.layer3(c.layer2(c.layer1(c.maxpool(c.relu(c.bn1(c.conv1(data)))))))

In [29]:
conv

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [30]:
conv = torch.load(conv_model_name)
part = Part()
loss_fn = Loss()
optimizer = torch.optim.Adam(part.parameters(), lr = learning_rate)

In [31]:
for epoch_number in range(epoch):
    running_loss, count, acc = 0., 0, 0.
    print(time.asctime())
    for img, _ in trainloader:
        img = Variable(img)
        channels = get_channels(conv, img)
        output = part(channels)
        optimizer.zero_grad()
        loss = loss_fn(output)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        count += img.size(0)
    print(epoch_number, count, running_loss, Loss.get_max_index(output))

Mon May 17 18:59:28 2021
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([8, 256, 28, 28])
0 198 121130220.5 [[9, 0], [26, 16], [8, 10], [16, 16], [8, 0], [17, 0], [5, 10], [26, 0]]


In [21]:
torch.save(part, save_part_name)

In [24]:
count=0
for img, _ in testloader:
    img = Variable(img)
    channels = get_channels(conv, img)
    output = part(channels)
    count += img.size(0)
    print(count, Loss.get_max_index(output))

torch.Size([10, 784, 1, 1]) torch.Size([10, 784, 1, 256])
torch.Size([10, 784, 1, 256])
10 [[17, 0], [19, 0], [22, 14], [10, 0], [11, 14], [17, 8], [18, 12], [17, 0], [11, 14], [25, 23]]
torch.Size([10, 784, 1, 1]) torch.Size([10, 784, 1, 256])
torch.Size([10, 784, 1, 256])
20 [[19, 15], [22, 13], [20, 13], [25, 8], [10, 0], [13, 18], [22, 17], [19, 21], [13, 2], [17, 15]]
torch.Size([10, 784, 1, 1]) torch.Size([10, 784, 1, 256])
torch.Size([10, 784, 1, 256])
30 [[20, 13], [10, 5], [25, 0], [26, 18], [18, 0], [19, 0], [26, 0], [21, 18], [23, 20], [7, 0]]
torch.Size([10, 784, 1, 1]) torch.Size([10, 784, 1, 256])
torch.Size([10, 784, 1, 256])
40 [[17, 0], [20, 0], [16, 0], [10, 0], [21, 8], [15, 0], [20, 15], [17, 0], [18, 18], [17, 7]]
