In [33]:
# TODO: As in "Online Batch Selection for Faster Training of Neural Networks" (Loshchilov & Hutter, 2015),
# make one pass over the whole training set to calculate the loss of each sample (using any method: prioritized loss or validated),
# and upon sampling and training, recalculate the new loss that would be induced, but don't apply it.

# This seems difficult for the validated variant: one would need to sample a validation batch, evaluate on it, train,
# evaluate on the validation batch again, and undo the changes. So it would require three forward passes and one backward pass.
# For the prioritized variant it would be one forward and one backward pass to get the precise gradient for every sample in the batch.

# TODO: implement the upper bound of the gradient norm as in Katharopoulos & Fleuret, "Not All Samples Are Created Equal: Deep Learning with Importance Sampling" (ICML 2018)

import random

import numpy as np
import skorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from exp_rep import PrioritizedReplayBuffer, ReplayBuffer

In [34]:
# Raw (untransformed) MNIST train and test splits, downloaded into ./data if missing.
mnist_trainset, mnist_testset = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=None)
    for is_train in (True, False)
)

In [35]:
class Net(nn.Module):
    """LeNet-style CNN for 1x28x28 digit images that returns per-class log-probabilities."""

    def __init__(self):
        super().__init__()
        # Two 5x5 stride-1 convolutions followed by two fully connected layers.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, twice: 28 -> 24 -> 12 -> 8 -> 4 spatially.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        flat = x.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [42]:
def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever one ends.

    Keeps one live iterator instead of calling ``iter(loader)`` once per update,
    which with ``num_workers > 0`` would respawn worker processes every step.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no batches
            (otherwise this generator would spin forever).
    """
    while True:
        empty_pass = True
        for batch in loader:
            empty_pass = False
            yield batch
        if empty_pass:
            raise ValueError("loader yielded no batches")


def _improvement_score(improvement, improvement_mean, writer, step):
    """Log several transforms of a validated improvement and return the one used as priority.

    Args:
        improvement: validation-batch loss before the update minus the loss after it.
        improvement_mean: running mean of past improvements (value before this step).
        writer: TensorBoard writer (anything with ``add_scalar``).
        step: global step used for logging.

    Returns:
        ``(score, new_mean)``. ``score`` is the sigmoid-like transform of the improvement
        relative to ``improvement_mean``. ``new_mean`` is the exponential running mean
        (decay 0.99) updated with ``improvement``.
    """
    def linear_boosted_clip(x):
        return np.clip(x + (1 if x > 0 else 0), 0.1, None)

    def relative(x):
        return (x - improvement_mean) / (improvement_mean if improvement_mean != 0 else 1)

    def sigmoid_like(x):
        # Not a standard logistic: 1 / (0.1 + e^(0.5 - x)) with x clipped to [-10, 10].
        return 1 / (0.1 + np.e ** (0.5 + -np.clip(x, -10, 10)))

    relative_improvement = relative(improvement)
    sigmoid_relative = sigmoid_like(relative_improvement)

    writer.add_scalar("Validated Improvement", improvement, global_step=step)
    writer.add_scalar("Validated Linear boosted Improvement", linear_boosted_clip(improvement), global_step=step)
    writer.add_scalar("Validated Sigmoid Improvement", sigmoid_like(improvement), global_step=step)
    writer.add_scalar("Validated Relative Improvement", relative_improvement, global_step=step)
    writer.add_scalar("Validated Relative Sigmoid Improvement", sigmoid_relative, global_step=step)
    writer.add_scalar("Validated Relative Linear Boosted improvement",
                      linear_boosted_clip(relative_improvement), global_step=step)
    writer.add_scalar("Running Mean of Validated Improvement", improvement_mean, global_step=step)

    new_mean = improvement_mean * 0.99 + 0.01 * improvement
    return sigmoid_relative, new_mean


def train(model, device, train_loader, optimizer, epoch, writer, test_loader, sampling_procedure,
          buffer, train_batch_size, val_batch_size, num_updates, total_updates, sampling_beta,
          improvement_mean, sampling_running_avg, sampling_anneal_beta, total_epochs, eval_uniform,
          log_interval=10):
    """Run one epoch of ``num_updates`` SGD steps on batches drawn from a replay buffer.

    Before each step, the summed NLL on a validation batch from ``test_loader`` is logged
    as "Test loss". With ``sampling_procedure == "prioritized"``, the raw per-sample
    training losses become the new buffer priorities. With ``"validated_prioritized"``,
    the improvement of the validation loss caused by the step becomes the priority.

    Args:
        model: network returning log-probabilities.
        device: device that training and validation tensors are moved to.
        train_loader: only used for progress printing (``len(train_loader.dataset)``).
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used for printing.
        writer: TensorBoard writer.
        test_loader: iterable of ``(data, target)`` validation batches; it is cycled.
        sampling_procedure: "uniform", "prioritized" or "validated_prioritized".
        buffer: replay buffer. ``sample`` returns a tuple with images at [0], labels at
            [2], per-sample loss weights at [-2] and indices at [-1] (as used below).
        train_batch_size: batch size; shrunk when the buffer holds fewer samples.
        val_batch_size: unused here; kept for interface compatibility.
        num_updates: number of optimizer steps in this epoch.
        total_updates: global step counter at the start of the epoch.
        sampling_beta: starting beta passed to ``buffer.sample``.
        improvement_mean: running mean of validated improvements.
        sampling_running_avg: forwarded to ``buffer.update_priorities`` (validated mode).
        sampling_anneal_beta: if true, anneal beta linearly from ``sampling_beta`` to 1
            over ``total_epochs * num_updates`` global updates; otherwise keep it fixed.
        total_epochs: total number of epochs, used for annealing.
        eval_uniform: also compute and log validated improvement for other procedures.
        log_interval: print progress every ``log_interval`` updates.

    Returns:
        ``(total_updates, improvement_mean)`` after this epoch.
    """
    model.train()
    remove_samples_from_buffer = (sampling_procedure == "uniform")
    start_beta = sampling_beta
    total_planned_updates = max(1, total_epochs * num_updates)
    val_batches = _infinite_batches(test_loader)

    for batch_idx in range(num_updates):
        # Validation loss before the update: logged, and the baseline for the improvement.
        test_data, test_target = next(val_batches)
        test_data, test_target = test_data.to(device), test_target.to(device)
        with torch.no_grad():
            test_loss = F.nll_loss(model(test_data), test_target, reduction='sum').item()  # sum up batch loss
        writer.add_scalar("Test loss", test_loss, global_step=total_updates)

        # Uniform sampling removes drawn samples, so the buffer can run low or empty.
        batch_size = min(train_batch_size, len(buffer))
        if batch_size == 0:
            print('Train Epoch: {} - replay buffer is empty, stopping epoch early'.format(epoch))
            break

        if sampling_anneal_beta:
            # Use the global update count so progress carries across epochs.
            progress = min(1.0, total_updates / total_planned_updates)
            sampling_beta = start_beta + progress * (1 - start_beta)
        samples = buffer.sample(batch_size, beta=sampling_beta,
                                remove_samples_from_buffer=remove_samples_from_buffer)
        data = torch.from_numpy(samples[0]).float().view((batch_size, 1, 28, 28)).to(device)
        target = torch.from_numpy(samples[2]).long().to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target, reduction='none')
        if sampling_procedure != "uniform":
            weights = torch.from_numpy(samples[-2]).float().to(device)
            idxes = samples[-1]
            # Raw (unweighted) per-sample losses. The weighting below is out-of-place:
            # an in-place multiply would also rewrite these values through shared storage.
            # NOTE(review): converted to numpy for the buffer; confirm exp_rep accepts arrays.
            loss_per_sample = loss.detach().cpu().numpy()
            loss = loss * weights
            if sampling_procedure == "prioritized":
                buffer.update_priorities(idxes, loss_per_sample)

        loss = loss.mean()
        loss.backward()
        optimizer.step()

        if sampling_procedure == "validated_prioritized" or eval_uniform:
            with torch.no_grad():
                updated_test_loss = F.nll_loss(model(test_data), test_target, reduction='sum').item()
            improvement, improvement_mean = _improvement_score(
                test_loss - updated_test_loss, improvement_mean, writer, total_updates)
            if sampling_procedure == "validated_prioritized":
                prioritization_weights = np.ones_like(idxes) * improvement
                buffer.update_priorities(idxes, prioritization_weights, running_avg=sampling_running_avg)

        total_updates += 1
        writer.add_scalar("Train loss", loss.item(), global_step=total_updates)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / num_updates, loss.item()))
        # TODO: log buffer histograms ("Sampling distribution counts" / sum-tree contents)
        # periodically during training; currently this is done once per epoch by the caller.

    return total_updates, improvement_mean


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))

In [47]:
# Training settings
seed = 1
log_interval = 10
save_model = False
use_cuda = torch.cuda.is_available()
eval_uniform = True
sampling_procedure = "prioritized"  # one of "uniform", "prioritized", "validated_prioritized"
sampling_alpha = 1
sampling_beta = 0.4
sampling_max_priority = 2
sampling_running_avg = 0.5
sampling_anneal_beta = False
train_batch_size = 8
val_batch_size = 128
lr = 0.005
improvement_mean = 0
total_epochs = 3
data_root = './data'  # one root for both MNIST splits


# Seed torch (weight init, DataLoader shuffling) and numpy / random. The replay buffer
# (exp_rep, not shown here) may draw from either of the latter.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform=mnist_transform),
    batch_size=train_batch_size, shuffle=True, **kwargs)
# Same root and download=True as the train split, so a fresh checkout can run this cell.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform=mnist_transform),
    batch_size=val_batch_size, shuffle=True, **kwargs)

num_train_samples = len(train_loader.dataset)
if sampling_procedure == "uniform":
    buffer = ReplayBuffer(num_train_samples)
else:
    buffer = PrioritizedReplayBuffer(num_train_samples, sampling_alpha, max_priority=sampling_max_priority)

# Fill the replay buffer with every (image, label) pair of the training set.
for batch in train_loader:
    for idx in range(len(batch[0])):
        buffer.add(batch[0][idx], None, batch[1][idx], None, False)

total_updates = 0
num_updates = num_train_samples // train_batch_size
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.5)
writer = SummaryWriter()

weight_values = np.zeros(num_train_samples)
try:
    for epoch in range(1, total_epochs + 1):
        total_updates, improvement_mean = train(model, device, train_loader, optimizer, epoch, writer,
                                                test_loader, sampling_procedure, buffer, train_batch_size,
                                                val_batch_size, num_updates, total_updates, sampling_beta,
                                                improvement_mean, sampling_running_avg, sampling_anneal_beta,
                                                total_epochs, eval_uniform=eval_uniform)
        test(model, device, test_loader)

        if sampling_procedure != "uniform":
            writer.add_histogram("Sampling distribution counts", buffer.counts, global_step=epoch,
                                 bins='tensorflow', max_bins=None)
            for i in range(num_train_samples):
                weight_values[i] = buffer._it_sum[i]
            writer.add_histogram("Sampling distribution", weight_values, global_step=epoch,
                                 bins='tensorflow', max_bins=None)
finally:
    # Flush pending TensorBoard events even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

if save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")





KeyboardInterrupt: 

In [None]:
# Peek at the first training batch: print each image tensor next to its label.
for images, labels_batch in train_loader:
    for image, label in zip(images, labels_batch):
        print(image, label)
    break
    

In [None]:
# Draw one batch from the replay buffer; as in the training loop, field 0 holds the
# images and field 2 the labels.
samples = buffer.sample(train_batch_size)
imgs, labels = (torch.from_numpy(samples[field]) for field in (0, 2))

In [None]:
# Number of samples currently held by the replay buffer (presumably shrinks under
# uniform sampling, which passes remove_samples_from_buffer=True — verify in exp_rep).
len(buffer)

In [None]:
# `labels` is already a torch tensor (built with torch.from_numpy in the sampling cell);
# calling torch.from_numpy on it again raises a TypeError, so display it directly.
labels