In [2]:
import itertools
import operator

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from tools import *

In [32]:
# Fetch the MNIST train and test splits into the current directory (downloaded
# on first use). Each image is converted to a tensor by the shared transform.
# No DataLoader is built: mini-batches are indexed by hand further down.
to_tensor = torchvision.transforms.ToTensor()
train_data = torchvision.datasets.MNIST('./', train=True, download=True, transform=to_tensor)
test_data = torchvision.datasets.MNIST('./', train=False, download=True, transform=to_tensor)

In [33]:
def _split_samples(dataset):
    """Walk ``dataset`` once and return ``(images, labels)`` as two parallel lists."""
    images, labels = [], []
    for image, target in dataset:
        images.append(image)
        labels.append(target)
    return images, labels


# Materialise both splits in memory so they can be sliced as plain tensors.
train_data_list, train_label_list = _split_samples(train_data)
test_data_list, test_label_list = _split_samples(test_data)

train_data_tensor = torch.stack(train_data_list)
train_label_tensor = torch.tensor(train_label_list)
test_data_tensor = torch.stack(test_data_list)
test_label_tensor = torch.tensor(test_label_list)

# Report the four shapes in the order: train data, train labels, test data, test labels.
for split_tensor in (train_data_tensor, train_label_tensor,
                     test_data_tensor, test_label_tensor):
    print(split_tensor.size())

torch.Size([60000, 1, 28, 28])
torch.Size([60000])
torch.Size([10000, 1, 28, 28])
torch.Size([10000])


In [34]:
# Sanity check: a slice of the first two samples keeps the per-image dimensions.
print(train_data_tensor[:2].shape)

torch.Size([2, 1, 28, 28])


In [121]:
class Net(nn.Module):
    """Small CNN classifier for 28x28 single-channel images.

    Layers: two 5x5 convolutions (each followed by 2x2 max-pooling and ReLU),
    a 320 -> 100 ReLU layer, a 100 -> ``feature_dim`` tanh feature layer and a
    final ``feature_dim`` -> 10 linear layer. ``forward`` returns
    log-probabilities over the 10 classes.

    The module owns its Adam optimizer (lr=0.001), so optimizer state carries
    over between successive ``train`` calls on the same instance.
    """

    def __init__(self, feature_dim):
        """Build the layers and the optimizer.

        Args:
            feature_dim: width of the tanh feature layer that feeds the final
                10-way linear classifier.
        """
        super(Net, self).__init__()
        self.feature_dim = feature_dim
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # 28 -> 24 -> 12 (conv1 + pool) -> 8 -> 4 (conv2 + pool); 20 * 4 * 4 = 320.
        self.fc1 = nn.Linear(320, 100)
        self.fc2 = nn.Linear(100, feature_dim)
        self.fc3 = nn.Linear(feature_dim, 10)

        self.optimizer = optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Return class log-probabilities of shape ``(N, 10)``.

        ``x`` is reshaped to ``(-1, 1, 28, 28)``, so any tensor whose element
        count is a multiple of 784 is accepted.
        """
        x = x.view(-1, 1, 28, 28)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=-1)

    def predictive_distribution_entropy(self, x):
        """Return the entropy (in nats) of the predicted class distribution.

        Computed per sample as ``-sum_c p_c * log p_c`` with gradient tracking
        disabled; the result has shape ``(N,)``.
        """
        with torch.no_grad():
            batch_logit = self.forward(x)
            batch_probs = torch.exp(batch_logit)
            batch_entropy = -torch.sum(batch_logit * batch_probs, dim=-1)
        return batch_entropy

    def train(self, x=True, label=None, num_epochs=1000, batch_size=100):
        """Fit the network on ``(x, label)``, or switch train/eval mode.

        This method shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
        invokes as ``self.train(False)``. To keep ``eval()`` / ``train(True)``
        working, a call without ``label`` is forwarded to ``nn.Module.train``
        with ``x`` interpreted as the boolean mode flag.

        Args:
            x: training images (anything ``forward`` accepts), or a bool mode
                flag when ``label`` is omitted.
            label: integer class targets aligned with ``x`` along dim 0.
            num_epochs: number of passes; each pass performs
                ``x.size(0) // batch_size`` optimizer steps.
            batch_size: samples per step, drawn uniformly WITH replacement.

        Returns:
            The list of per-step summed NLL losses when fitting, or ``self``
            when used as the mode switch.
        """
        if label is None:
            # nn.Module.train(mode) / nn.Module.eval() code path.
            return super(Net, self).train(x)

        train_losses = []
        num_samples = x.size(0)
        steps_per_epoch = num_samples // batch_size
        for epoch in range(num_epochs):
            if epoch % 100 == 0:
                print('learning_epoch:', epoch)
            for _ in range(steps_per_epoch):
                index = np.random.choice(num_samples, batch_size)
                self.optimizer.zero_grad()
                output = self.forward(x[index])
                # Summed (not averaged) over the mini-batch.
                nll_loss = F.nll_loss(output, label[index], reduction='sum')
                nll_loss.backward()
                self.optimizer.step()
                train_losses.append(nll_loss.item())
        # The curve shows the loss, so label it as such.
        plt.title('training_loss')
        plt.plot(train_losses)
        plt.show()
        return train_losses

    def test(self, x, label):
        """Return the fraction of samples in ``x`` classified as ``label``.

        The forward pass runs under ``torch.no_grad()`` so that evaluating a
        large set in one call does not build an autograd graph.
        """
        with torch.no_grad():
            pred = self.forward(x).argmax(dim=1)
        accuracy = (pred == label).sum().item() / label.size(0)
        return accuracy
    

In [122]:
# Classifier with a 20-dimensional tanh feature layer, plus the initial
# labelled set: the first 100 training samples and their labels.
nn_tanh = Net(feature_dim=20)
init_train_data, init_train_label = train_data_tensor[:100], train_label_tensor[:100]

In [None]:
# Active-learning driver. Each round: (1) train on the current labelled set,
# (2) measure test accuracy, (3) scan the pool in fixed-size batches and move
# the single highest-entropy sample of every batch into the labelled set.
#
# NOTE: this cell grows init_train_data / init_train_label in place, so it is
# not idempotent. Re-running it without first re-running the cell that creates
# the initial 100-sample set continues from the already-grown set (the recorded
# output below starts at train_data_size 699 for that reason).
POOL_BATCH_SIZE = 100
# 600 batches for the 60000-image training split; batch 0 is the initial
# labelled set, so the scan starts at batch 1.
num_pool_batches = train_data_tensor.size(0) // POOL_BATCH_SIZE

accuracy_list = []
for epoch in range(0, 100):
    print('big_epoch:', epoch, 'start training...')
    print('train_data_size', init_train_label.size(0))
    nn_tanh.train(init_train_data, init_train_label)

    accuracy = nn_tanh.test(test_data_tensor, test_label_tensor)
    print('epoch:', epoch, 'test_accuracy', accuracy)
    # Record before plotting so the curve includes the current round.
    accuracy_list.append(accuracy)
    plt.title('test_accuracy')
    plt.plot(accuracy_list)
    plt.show()

    ### active part
    print('epoch:', epoch, 'start active learning...')
    # The model is fixed during the scan, so picks can be collected and
    # concatenated once instead of re-copying the labelled set per batch.
    picked_data = []
    picked_label = []
    for i in range(1, num_pool_batches):
        if i % 60 == 0:
            print(i, 'active_iterations')
        batch_start = i * POOL_BATCH_SIZE
        active_batch_data = train_data_tensor[batch_start:batch_start + POOL_BATCH_SIZE]
        entropy_list = nn_tanh.predictive_distribution_entropy(active_batch_data)
        index = entropy_list.argmax(0)  # position of the most uncertain sample in this batch
        picked_data.append(active_batch_data[index].view(1, 1, 28, 28))
        picked_label.append(train_label_tensor[index + batch_start].view(-1))
    init_train_data = torch.cat([init_train_data] + picked_data, 0)
    init_train_label = torch.cat([init_train_label] + picked_label, 0)

big_epoch: 0 start training...
train_data_size 699
learning_epoch: 0
learning_epoch: 100
learning_epoch: 200
learning_epoch: 300
learning_epoch: 400
learning_epoch: 500
