In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from hessian import *
import operator
torch.manual_seed(1)


<torch._C.Generator at 0x106320d10>

In [2]:
train_data=torchvision.datasets.MNIST('./', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_data=torchvision.datasets.MNIST('./', train=False, download=True,transform=torchvision.transforms.ToTensor())
# train_loader = torch.utils.data.DataLoader(train_data, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=100, shuffle=True)

In [3]:
train_dict={}
for (data, target) in train_data:
    if  target not in train_dict:
        train_dict[ target]=[]
        train_dict[ target].append(data)
    else:
        train_dict[ target].append(data)

In [4]:
initial_train_data=[]
initial_train_label=[]
for i in range(5):
    initial_train_data.append(train_dict[i][0])
    initial_train_label.append(i)
    
print(initial_train_label)

[0, 1, 2, 3, 4]


In [5]:
initial_train_data_tensor=torch.stack(initial_train_data)
initial_train_label_tensor=torch.tensor(initial_train_label)
train_x=torch.stack([data for (data, target) in train_data])
train_label=torch.tensor([target for (data, target) in train_data])
print(initial_train_data_tensor.size())
print(initial_train_label_tensor.size())

torch.Size([5, 1, 28, 28])
torch.Size([5])


In [16]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.final_weight=torch.tensor(np.random.randn(20,10), requires_grad=True)
        params = list(self.final_weight)
        self.optimizer = optim.Adam(params, lr=0.001)


#         for name, param in self.named_parameters():
#             if param.requires_grad:
#                 print (name, param.data)

    def forward(self, x):
        x=x.view(-1,1,28,28)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x= torch.matmul(x,self.final_weight)
        return F.log_softmax(x,dim=-1)
    
    def predict(self,x):
        with torch.no_grad():
            output = torch.exp(self.forward(x))
            pred = output.data.max(dim=1, keepdim=True)[1]
            return pred
    
    def train(self,x,label):
        train_losses = []
        if x.size(0)>100:
            for it in range(0,1000):
                index=np.random.choice(x.size(0),100)
                self.optimizer.zero_grad()
                output = self.forward(x[index])
                loss = F.nll_loss(output,label[index])
                loss.backward()
                self.optimizer.step()
                train_losses.append(loss.item())
        
        else:    
            for it in range(0,1000):
                self.optimizer.zero_grad()
                output = self.forward(x)
                loss = F.nll_loss(output,label)
                loss.backward()
                self.optimizer.step()
                train_losses.append(loss.item())
                print(self.final_weight[0])
        plt.plot(train_losses)
        plt.show()
        
        
    def test(self):
        correct=0
        for data, target in test_loader:
            pred = self.predict(data)
            correct += pred.eq(target.data.view_as(pred)).sum()
            correct_ratio= float(correct)/len(test_loader.dataset)
        return correct_ratio
        
    def uncertainty(self,x,label):
        output = torch.exp(self.forward(x))
        loss = F.nll_loss(output,torch.tensor([label]))+torch.norm(network.final_weight,2)
        jacobian_w=jacobian(output,self.final_weight)
        hessian_w=hessian(loss,self.final_weight)+torch.eye(100)*1e-6
        hessian_inverse=torch.inverse(hessian_w)
        left=torch.matmul(jacobian_w,hessian_inverse)
        pos_cov=torch.matmul(left,jacobian_w.t())
        (sign, logdet) = np.linalg.slogdet(pos_cov.detach().numpy()) 
        entropy=5+5*np.log(2*np.pi)+0.5*logdet
        return entropy
    

In [17]:
active_bnn = Net()
active_bnn.train(initial_train_data_tensor,initial_train_label_tensor)

ValueError: can't optimize a non-leaf Tensor

In [None]:
bnn_test_list=[]
bnn_test_list.append(active_bnn.test())
print(bnn_test_list)

In [None]:
active_bnn = Net()
active_bnn.train(initial_train_data_tensor,initial_train_label_tensor)
bnn_test_list=[]
bnn_test_list.append(active_bnn.test())
print(bnn_test_list)
current_train=initial_train_data_tensor
current_label=initial_train_label_tensor
for epoch in range(60):
    index=np.random.choice(60000,10)
    current_train=torch.cat([current_train,train_x[index]],0)
    current_label=torch.cat([current_label,train_label[index]],0)
    active_bnn.train(current_train,current_label)
    ratio=active_bnn.test()
    print('ratio',ratio)
    bnn_test_list.append(ratio)

In [None]:
plt.plot(bnn_test_list,label='random sampling')
plt.plot(active_list,label='active learning')
plt.title('Sample efficiency comparison')
plt.xlabel('number of training data (*10)') 
plt.ylabel('accuracy on test dataset')
plt.legend()
plt.show()

In [None]:
current_train=initial_train_data_tensor
current_label=initial_train_label_tensor
for epoch in range(600):
    entropy_dict={}
    for i in range(epoch*1000,(epoch+1)*1000):
        train_epoch_x=train_x[i]
        train_epoch_label=train_label[i]
        entropy_dict[i]=active_bnn.uncertainty(train_x[i],train_label[i])
    sorted_dict = sorted(entropy_dict.items(), key=operator.itemgetter(1), reverse=True)
    sorted_index=[i[0] for i in sorted_dict[:10]]
    current_train=torch.cat([current_train,train_x[sorted_index]],0)
    current_label=torch.cat([current_label,train_label[sorted_index]],0)
    print(current_train.size())  
    print(current_label.size()) 
    active_bnn.train(current_train,current_label)
    ratio=active_bnn.test()
    print('ratio',ratio)
    bnn_test_list.append(ratio)

In [None]:
active_list=[]
active_list.append(bnn_test_list[0])
active_list.extend(bnn_test_list[3:])
x_list=np.arange(1,62)*10
plt.plot(x_list,active_list)
plt.show()


In [None]:
[0.4308, 0.6717, 0.8553, 0.6185, 0.7087, 0.7191, 0.7344, 0.7656, 0.7898, 0.831, 0.834, 0.827, 0.8636, 0.8821, 0.8717, 0.8792, 0.8892, 0.9056, 0.9103, 0.9135, 0.9116, 0.9102, 0.9268, 0.9285, 0.9287, 0.9325, 0.9387, 0.9399, 0.933, 0.93, 0.9265, 0.9374, 0.9414, 0.9447, 0.9484, 0.9461, 0.9448, 0.9361, 0.9435, 0.9456, 0.9457, 0.9497, 0.9506, 0.9527, 0.951, 0.9475, 0.9404, 0.9412, 0.9361, 0.939, 0.9435, 0.9462, 0.9504, 0.9496, 0.9539, 0.9542, 0.9573, 0.9563, 0.9587, 0.963, 0.9591, 0.9611, 0.9601]