In [38]:
import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
import torch.optim as optim

import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

In [39]:
class MnistBags(data_utils.Dataset):
    """Multiple-instance dataset: "bags" of MNIST digits with a 3-way multi-hot label.

    Every bag is a tensor of images drawn (with replacement) from one MNIST
    split. The bag label is a float tensor ``[has_9, has_8, neither]``:

    * ``[1, 1, 0]`` - the bag contains at least one 9 and at least one 8
    * ``[1, 0, 0]`` - at least one 9, no 8
    * ``[0, 1, 0]`` - at least one 8, no 9
    * ``[0, 0, 1]`` - neither digit is present

    Args:
        mean_bag_length: mean of the normal distribution bag sizes are drawn from.
        var_bag_length: spread of that distribution. NOTE: despite the name it
            is passed to ``RandomState.normal`` as the *scale* argument, i.e.
            it acts as a standard deviation, not a variance.
        num_bag: number of bags to generate.
        seed: seed for the private ``numpy.random.RandomState`` used for
            sampling bag sizes and image indices.
        train: build bags from the MNIST training split when true, otherwise
            from the test split.

    All bags are materialised eagerly in ``__init__`` (this loads, and if
    necessary downloads, MNIST under ``'../datasets'``).
    """

    def __init__(self, mean_bag_length=10, var_bag_length=2,
                 num_bag=250, seed=1, train=True):
        self.mean_bag_length = mean_bag_length
        self.var_bag_length = var_bag_length
        self.num_bag = num_bag
        self.train = train

        self.r = np.random.RandomState(seed)

        # Sizes of the two MNIST splits; used both as the loader batch size
        # (so one batch holds the whole split) and as the index upper bound.
        self.num_in_train = 60000
        self.num_in_test = 10000

        if self.train:
            self.train_bags_list, self.train_labels_list = self._create_bags()
        else:
            self.test_bags_list, self.test_labels_list = self._create_bags()

    def _pool_size(self):
        """Return the number of images in the split selected by ``self.train``."""
        return self.num_in_train if self.train else self.num_in_test

    def _load_split(self):
        """Load the whole selected MNIST split as ``(images, labels)`` tensors.

        Images go through ``ToTensor`` and ``Normalize((0.1307,), (0.3081,))``.
        The tensor shapes are printed, as in the original notebook cell.
        """
        mnist = datasets.MNIST(
            '../datasets',
            train=bool(self.train),
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]),
        )
        # batch_size == split size, so the first (and only) batch is everything.
        loader = data_utils.DataLoader(mnist, batch_size=self._pool_size(),
                                       shuffle=False)
        all_imgs, all_labels = next(iter(loader))
        print(all_imgs.shape)
        print(all_labels.shape)
        return all_imgs, all_labels

    def _sample_bags(self, all_imgs, all_labels):
        """Draw ``self.num_bag`` bags from the given image/label pool.

        Args:
            all_imgs: tensor whose first dimension indexes the images.
            all_labels: 1-D tensor of digit labels aligned with ``all_imgs``.

        Returns:
            ``(bags_list, labels_list)`` - parallel lists of image tensors and
            3-element float label tensors.
        """
        pool_size = self._pool_size()
        bags_list = []
        labels_list = []

        for _ in range(self.num_bag):
            # ``np.int`` was removed in NumPy 1.24; extract the scalar and use
            # the builtin ``int`` (same truncation toward zero, same RNG draw).
            drawn = self.r.normal(self.mean_bag_length, self.var_bag_length, 1)[0]
            bag_length = max(int(drawn), 1)  # a bag always has >= 1 instance

            picks = self.r.randint(0, pool_size, bag_length)
            indices = torch.as_tensor(picks, dtype=torch.long)

            labels_in_bag = all_labels[indices]
            has_nine = bool((labels_in_bag == 9).any())
            has_eight = bool((labels_in_bag == 8).any())
            has_neither = not (has_nine or has_eight)

            label = torch.FloatTensor(
                [float(has_nine), float(has_eight), float(has_neither)])

            bags_list.append(all_imgs[indices])
            labels_list.append(label)

        return bags_list, labels_list

    def _create_bags(self):
        """Load the selected MNIST split and sample all bags from it."""
        all_imgs, all_labels = self._load_split()
        return self._sample_bags(all_imgs, all_labels)

    def __len__(self):
        if self.train:
            return len(self.train_labels_list)
        else:
            return len(self.test_labels_list)

    def __getitem__(self, index):
        """Return ``(bag, label)`` for the bag at ``index``."""
        if self.train:
            bag = self.train_bags_list[index]
            label = self.train_labels_list[index]
        else:
            bag = self.test_bags_list[index]
            label = self.test_labels_list[index]

        return bag, label

In [40]:
# Training bags: 200 bags sampled from the MNIST training split, served one
# bag per batch (bags differ in length) in shuffled order.
train_loader = data_utils.DataLoader(
    dataset=MnistBags(mean_bag_length=10,
                      var_bag_length=2,
                      num_bag=200,
                      seed=1,
                      train=True),
    batch_size=1,
    shuffle=True,
)

# Test bags: 50 bags sampled from the MNIST test split, served one bag per
# batch in a fixed order.
test_loader = data_utils.DataLoader(
    dataset=MnistBags(mean_bag_length=10,
                      var_bag_length=2,
                      num_bag=50,
                      seed=1,
                      train=False),
    batch_size=1,
    shuffle=False,
)

torch.Size([60000, 1, 28, 28])
torch.Size([60000])
torch.Size([10000, 1, 28, 28])
torch.Size([10000])


In [41]:
class Attention(nn.Module):
    """Attention-pooled bag classifier with three independent sigmoid outputs.

    Pipeline for one bag of N single-channel 28x28 images:

    1. a two-layer conv feature extractor followed by a linear layer maps
       every instance to an ``L``-dimensional embedding (N x L);
    2. a small tanh MLP scores every instance, and a softmax over the N
       instances turns the scores into attention weights (K x N);
    3. the attention-weighted sum of embeddings (K x L) is fed to a linear
       layer + sigmoid producing 3 probabilities.
    """

    def __init__(self):
        super(Attention, self).__init__()
        self.L = 500  # instance embedding size
        self.D = 128  # hidden size of the attention scorer
        self.K = 1    # number of attention heads / pooled vectors

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        # 28x28 input -> 50 feature maps of 4x4 after the two conv/pool stages.
        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(50 * 4 * 4, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 3),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-label probabilities for one bag.

        Args:
            x: a bag of images, either ``(1, N, 1, 28, 28)`` (as produced by a
                ``DataLoader`` with ``batch_size=1``) or ``(N, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(K, 3)`` (``(1, 3)`` with the default ``K = 1``)
            holding sigmoid probabilities. Callers threshold these themselves.
        """
        x = x.squeeze(0)  # drop the DataLoader batch dimension

        H = self.feature_extractor_part1(x)
        H = H.view(-1, 50 * 4 * 4)
        H = self.feature_extractor_part2(H)  # NxL

        A = self.attention(H)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, H)  # KxL  i.e. 1x500

        Y_prob = self.classifier(M)

        return Y_prob

    #def calculate_classification_error(self,X,Y):
    
    #def calculate_objective(self,X,Y):

In [None]:
epochs = 40

model = Attention()

# NOTE: weight_decay=10e-5 evaluates to 1e-4.
optimizer = optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999), weight_decay=10e-5)

for i in range(epochs):
    model.train()
    train_loss = 0
    for data, label in train_loader:
        predict = model(data)

        # Multi-label loss: binary cross-entropy summed over the three
        # outputs (predict and label are both 1x3 for a single bag).
        loss = F.binary_cross_entropy(predict, label, reduction='sum')
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Mean per-bag loss for this epoch.
    train_loss /= len(train_loader)
    print('Epoch : {}, Loss : {:.4f}'.format(i, train_loss))
#data = torch.zeros([13,1,28,28])
#label = torch.torch.LongTensor([1,0,0])
#model(data)

Epoch : 0, Loss : 1.8987
Epoch : 1, Loss : 1.8052
Epoch : 2, Loss : 1.7416
Epoch : 3, Loss : 1.5637
Epoch : 4, Loss : 1.3572
Epoch : 5, Loss : 1.0965
Epoch : 6, Loss : 0.9558
Epoch : 7, Loss : 0.7184
Epoch : 8, Loss : 0.5702
Epoch : 9, Loss : 0.4610
Epoch : 10, Loss : 0.3579
Epoch : 11, Loss : 0.2022
Epoch : 12, Loss : 0.2069
Epoch : 13, Loss : 0.3348
Epoch : 14, Loss : 0.1342
Epoch : 15, Loss : 0.0671
Epoch : 16, Loss : 0.0571
Epoch : 17, Loss : 0.0178
Epoch : 18, Loss : 0.0113
Epoch : 19, Loss : 0.0066
Epoch : 20, Loss : 0.0048
Epoch : 21, Loss : 0.0033
Epoch : 22, Loss : 0.0024
Epoch : 23, Loss : 0.0019
Epoch : 24, Loss : 0.0016
Epoch : 25, Loss : 0.0013
Epoch : 26, Loss : 0.0011
Epoch : 27, Loss : 0.0010
Epoch : 28, Loss : 0.0008
Epoch : 29, Loss : 0.0007
Epoch : 30, Loss : 0.0006
Epoch : 31, Loss : 0.0005
Epoch : 32, Loss : 0.0005
Epoch : 33, Loss : 0.0005
Epoch : 34, Loss : 0.0004
Epoch : 35, Loss : 0.0004
Epoch : 36, Loss : 0.0004
Epoch : 37, Loss : 0.0003


In [37]:
model.eval()
test_loss = 0
num_samples = 0
num_correct = 0

# Probability cut-off used to binarise each of the three outputs. It is
# stricter than the conventional 0.5; the value is kept from the original run.
DECISION_THRESHOLD = 0.7

# Inference only: skip autograd bookkeeping for the forward passes.
with torch.no_grad():
    for data, label in test_loader:
        predict = model(data)
        # Binary cross-entropy summed over the three outputs (1x3 vs 1x3).
        loss = F.binary_cross_entropy(predict, label, reduction='sum')
        test_loss += loss.item()
        # A bag counts as correct only when all three thresholded outputs
        # match the label exactly.
        if torch.equal(torch.ge(predict[0], DECISION_THRESHOLD).float(), label[0]):
            num_correct += 1
        num_samples += 1

test_loss /= len(test_loader)
print('Test Loss : {:.4f}, {}/{} Correct'.format(test_loss,num_correct,num_samples))


Test Loss : 3.1674, 39/50 Correct
