In [1]:
import time
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from datasets.dataset import CUBDataset
from torch.utils.data import DataLoader
from torch.autograd import Variable

In [2]:
epoch = 1
learning_rate = 1e-3
save_part_name = 'models/part.pth'
conv_model_name = 'models/resnet.pth'


In [3]:
trainset = CUBDataset()
trainloader = DataLoader(dataset=trainset, batch_size=10, shuffle=True)
testset = CUBDataset(is_test = True)
testloader = DataLoader(dataset=testset, batch_size=10, shuffle=True)

In [11]:
class Part(nn.Module):

    def __init__(self):
        super(Part, self).__init__()
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 256)

    def forward(self, x):
        print(x)
        conv_matrix = torch.clone(x)
        conv_matrix = conv_matrix.reshape(conv_matrix.size(0), 256, 1, 784)
        conv_matrix = conv_matrix.transpose(1, 3)
        x = F.avg_pool2d(x, kernel_size=28, stride=28)
        x = x.view(x.size(0), -1)
        x = torch.tanh(self.fc1(x))
        x = self.fc2(x)
        x = torch.sigmoid(x).unsqueeze(1).unsqueeze(1)
        x = F.interpolate(x, (1, 784), mode='bilinear', align_corners=True)
        x = x.squeeze(1).squeeze(1).unsqueeze(2).unsqueeze(3)
        x = x * conv_matrix
        x = F.avg_pool2d(x, kernel_size=(1, 512), stride=512)
        x = x * 0.1
        x = F.softmax(x, dim=1)
        x = torch.exp(x)
        x = x + 1
        x = torch.log(x)
        x = x * 4
        x = x.squeeze(2).squeeze(2)
        return x.reshape(x.size(0), 28, 28)


In [12]:
class Loss(nn.Module):

    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, tensor):
        loss_sum = torch.zeros(1)
        indexes = Loss.get_max_index(tensor)
        for i in range(len(indexes)):
            max_x, max_y = indexes[i]
            for j in range(tensor.size(1)):
                for k in range(tensor.size(2)):
                    loss_sum += ((max_x - j) * (max_x - j) + (max_y - k) * (max_y - k)) * tensor[i, j, k]
        return loss_sum
    
    @staticmethod
    def get_max_index(tensor):
        shape = tensor.shape
        indexes = []
        for i in range(shape[0]):
            mx = tensor[i, 0, 0]
            x, y = 0, 0
            for j in range(shape[1]):
                for k in range(shape[2]):
                    if tensor[i, j, k] > mx:
                        mx = tensor[i, j, k]
                        x, y = j, k
            indexes.append([x, y])
        return indexes


In [13]:
def get_channels(c, data):
    return c.layer3(c.layer2(c.layer1(c.maxpool(c.relu(c.bn1(c.conv1(data)))))))

In [None]:
conv

In [14]:
conv = torch.load(conv_model_name)
part = Part()
loss_fn = Loss()
optimizer = torch.optim.Adam(part.parameters(), lr = learning_rate)

In [15]:
for epoch_number in range(epoch):
    running_loss, count, acc = 0., 0, 0.
    print(time.asctime())
    for img, _ in trainloader:
        img = Variable(img)
        channels = get_channels(conv, img)
        output = part(channels)
        optimizer.zero_grad()
        loss = loss_fn(output)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        count += img.size(0)
    print(epoch_number, count, running_loss, Loss.get_max_index(output))

Sun May 16 16:46:16 2021
tensor([[[[0.0000e+00, 9.2169e-01, 1.6786e+00,  ..., 1.8804e+00,
           1.9598e+00, 2.2441e+00],
          [0.0000e+00, 1.6703e+00, 2.2644e-01,  ..., 5.3503e-01,
           6.5314e-01, 8.6552e-01],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.2412e-01,
           3.3317e-01, 3.9077e-01],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.9310e-01,
           3.3819e-01, 3.4165e-01],
          [0.0000e+00, 6.4395e-02, 0.0000e+00,  ..., 3.3094e-01,
           3.5505e-01, 3.6393e-01],
          [0.0000e+00, 2.5339e+00, 0.0000e+00,  ..., 3.5984e-01,
           3.8635e-01, 3.5788e-01]],

         [[5.7494e+00, 0.0000e+00, 0.0000e+00,  ..., 5.3329e-01,
           8.8991e-01, 3.4730e-01],
          [9.5474e+00, 0.0000e+00, 0.0000e+00,  ..., 9.0328e-02,
           9.6410e-02, 1.5726e-01],
          [1.0745e+01, 0.0000e+00, 0.0000e+00,  ..., 2.2157e-01,
           2.4789e-01, 3.0047e-01],
          ...,
          [9.6561e+00, 0.0000e+00,

RuntimeError: Given input size: (784x1x256). Calculated output size: (784x1x0). Output size is too small

In [10]:
get_channels(conv,img).shape

torch.Size([10, 256, 28, 28])

In [None]:
img

In [None]:
torch.save(part, save_part_name)

In [None]:
count=0
for img, _ in testloader:
    img = Variable(img)
    channels = get_channels(conv, img)
    output = part(channels)
    count += img.size(0)
    print(count, Loss.get_max_index(output))