In [None]:
%run include.ipynb
%run Net.ipynb
%run Data.ipynb
%run viewer.ipynb

from torch.autograd import Variable

class CNN_3Dmnist(object):
    """Wrapper around a 3D CNN built from an architecture description.

    The network itself is built by the project helpers ``Net.parse_layers`` /
    ``Network_template`` (defined in other notebooks). This class handles
    device placement, SGD training with periodic test accuracy, and
    gradient-based saliency maps.

    Every input sample is reshaped to ``SAMPLE_SHAPE``.
    """

    # Shape of one sample after reshaping: 3 leading planes on a 16^3 grid.
    SAMPLE_SHAPE = (3, 16, 16, 16)

    def __init__(self, arch):
        """Build and initialise the network described by ``arch``.

        Args:
            arch: architecture description accepted by ``Net.parse_layers``
                (its format is defined outside this file).
        """
        cudnn.benchmark = FLAGS.cudnn_benchmark
        gpu_num = FLAGS.gpu_num
        use_cuda = torch.cuda.is_available() and FLAGS.gpu_enable
        self.device = torch.device("cuda:0" if use_cuda else "cpu")
        # Seed torch from Python's `random`, so weight init varies per instance
        # unless `random` itself has been seeded.
        torch.manual_seed(random.randint(1, 10000))

        self.input_dims, layers = Net.parse_layers(arch)
        self.net = Network_template(gpu_num, layers).to(self.device)
        Net.init_weights(self.net, "normal")

    def _to_tensors(self, X, targets):
        """Return ``(float inputs of shape (N, *SAMPLE_SHAPE), long labels)``."""
        x = torch.from_numpy(X.reshape(-1, *self.SAMPLE_SHAPE)).float()
        y = torch.from_numpy(targets).long()
        return x, y

    def _accuracy(self, loader):
        """Return classification accuracy (in %) of the network over ``loader``.

        Runs in eval mode without autograd, then restores the previous mode.
        Returns NaN for an empty loader.
        """
        was_training = self.net.training
        self.net.eval()
        correct = total = 0
        try:
            with torch.no_grad():
                for images, labels in loader:
                    predicted = self.net(images.to(self.device)).argmax(dim=1)
                    correct += (predicted.cpu() == labels).sum().item()
                    total += len(labels)
        finally:
            self.net.train(was_training)
        return 100.0 * correct / total if total else float("nan")

    def test(self, X_test, targets_test, batch_size=100):
        """Compute saliency maps for the first ``batch_size`` test samples.

        For each sample, the gradient of its highest output score with respect
        to the input is taken. The saliency map is the maximum absolute
        gradient over axis 1 of the ``(N, 3, 16, 16, 16)`` input.

        Args:
            X_test: numpy array reshapeable to ``(N, 3, 16, 16, 16)``.
            targets_test: numpy array of labels. It is converted for
                validation only and does not affect the result.
            batch_size: number of leading samples to analyse (default 100).

        Returns:
            numpy array of shape ``(min(N, batch_size), 16, 16, 16)``.

        Raises:
            ValueError: if ``X_test`` holds no samples.
        """
        test_x, _ = self._to_tensors(X_test, targets_test)
        if len(test_x) == 0:
            raise ValueError("X_test contains no samples")

        self.net.eval()
        # Only the first batch is analysed, as in the original implementation.
        inputs = test_x[:batch_size].to(self.device).clone().requires_grad_()
        outputs = self.net(inputs)
        top_class = outputs.argmax(dim=1, keepdim=True)
        # Gradients of a sum add up, so one backward pass over the summed
        # per-sample top scores gives the same input gradient as one
        # backward per sample. autograd.grad also leaves the network
        # parameters' .grad untouched.
        top_scores = outputs.gather(1, top_class).sum()
        (grad,) = torch.autograd.grad(top_scores, inputs)
        saliency, _ = grad.abs().max(dim=1)
        return saliency.detach().cpu().numpy()

    def train(self, X_train, X_test, targets_train, targets_test,
              n_iters=9000, batch_size=100, learning_rate=0.001,
              eval_every=50, print_every=500):
        """Train the network with SGD and cross-entropy loss.

        Args:
            X_train, X_test: numpy arrays reshapeable to
                ``(N, 3, 16, 16, 16)``.
            targets_train, targets_test: numpy arrays of integer class labels.
            n_iters: approximate number of optimiser steps. It is converted
                to ``int(n_iters / (N_train / batch_size))`` full epochs.
            batch_size: mini-batch size for training and evaluation.
            learning_rate: SGD learning rate.
            eval_every: evaluate test accuracy every this many steps.
            print_every: print progress every this many steps.

        Returns:
            ``(iteration_list, loss_list, accuracy_list)``, one entry per
            evaluation (losses and accuracies as Python floats).

        Raises:
            ValueError: if ``X_train`` holds no samples.
        """
        train_x, train_y = self._to_tensors(X_train, targets_train)
        test_x, test_y = self._to_tensors(X_test, targets_test)
        if len(train_x) == 0:
            raise ValueError("X_train contains no samples")

        train_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(train_x, train_y),
            batch_size=batch_size, shuffle=False)
        test_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(test_x, test_y),
            batch_size=batch_size, shuffle=False)

        num_epochs = int(n_iters / (len(train_x) / batch_size))

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.net.parameters(), lr=learning_rate)

        count = 0
        loss_list, iteration_list, accuracy_list = [], [], []
        accuracy = None
        # test() leaves the net in eval mode; make sure we really train.
        self.net.train()
        for _ in range(num_epochs):
            for images, labels in train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                loss = criterion(self.net(images), labels)
                loss.backward()
                optimizer.step()

                count += 1
                if count % eval_every == 0:
                    accuracy = self._accuracy(test_loader)
                    loss_list.append(loss.item())
                    iteration_list.append(count)
                    accuracy_list.append(accuracy)
                if count % print_every == 0:
                    print('Iteration: {}  Loss: {}  Accuracy: {} %'.format(
                        count, loss.item(), accuracy))
        return iteration_list, loss_list, accuracy_list
