#### Show some samples

In [None]:
def show_some(dataset):
    """Display a 5x5 grid of randomly drawn samples together with their labels.

    Every cell independently picks one ``(img, label)`` pair with
    ``random.choice`` (so the same sample may appear more than once), renders
    it in grayscale and hides the axis ticks.

    Args:
        dataset: indexable collection of ``(img, label)`` pairs; ``img`` is
            presumably a tensor that squeezes down to a 2-D image — confirm.
    """
    plt.close('all')  # start from a clean slate so earlier figures don't accumulate
    _, grid = plt.subplots(5, 5, figsize=(8,8))
    for cell in grid.flat:
        picture, digit = random.choice(dataset)
        cell.imshow(picture.squeeze().numpy(), cmap='gist_gray')
        cell.set_title('Label: %d' % digit)
        cell.set_xticks([])
        cell.set_yticks([])
    plt.tight_layout()
    plt.show()

#### Load the datasets

In [None]:
def load_datasets(data_dir):
    """Return the MNIST ``(train_dataset, test_dataset)`` pair rooted at *data_dir*.

    Both splits are downloaded when missing. Each split gets its own
    transform pipeline that only applies ``transforms.ToTensor()``, with no
    augmentation or normalization, so either pipeline can later be changed
    independently.
    """
    def make_split(is_train):
        # A fresh pipeline per split (currently identical for both).
        pipeline = transforms.Compose([
            transforms.ToTensor(),
        ])
        return MNIST(data_dir, train=is_train, download=True, transform=pipeline)

    # Tuple elements are evaluated left to right: train split first, then test.
    return make_split(True), make_split(False)

#### Training function

In [None]:
def train_epoch(net, dataloader, loss_fn, optimizer, show_steps=1, use_noise=False,
                *, device='cuda', noise_std=0.5):
    """Run one training epoch of the autoencoder ``net``.

    The target passed to ``loss_fn`` is always the clean image batch. When
    ``use_noise`` is set, the network input is a corrupted copy of it, which
    gives a denoising-autoencoder setup.

    Args:
        net: model to train; it is switched to training mode.
        dataloader: iterable of batches whose first element is the image batch.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        optimizer: optimizer over ``net``'s parameters.
        show_steps: print the current batch loss every ``show_steps`` batches.
        use_noise: if True, each input image first gets one random rectangle
            erased (``RandomErasing`` with p=1, area fraction 1%-25%). It then
            gets additive Gaussian noise of std ``noise_std``.
        device: device the batches are moved to (default ``'cuda'``, as
            before); it must match the device ``net`` lives on.
        noise_std: standard deviation of the additive Gaussian noise.
    """
    net.train()

    corrupt = None
    if use_noise:
        # Built once per epoch instead of once per batch.
        erase = transforms.RandomErasing(p=1, scale=(0.01, 0.25), ratio=(0.5, 2))

        def corrupt(img):
            erased = erase(img)
            # randn_like follows the image's device/dtype, unlike the legacy
            # CUDA-only torch.cuda.FloatTensor constructor.
            return erased + noise_std * torch.randn_like(erased)

    for step, sample_batch in enumerate(dataloader, start=1):
        clean = sample_batch[0].to(device)
        if corrupt is None:
            # Feed a copy so the target stays untouched even if net modifies
            # its input in place.
            image_batch = clean.clone()
        else:
            # Build a new batch rather than writing into `clean`. When .to() is
            # a no-op, `clean` is the dataloader's own tensor.
            image_batch = torch.stack([corrupt(img) for img in clean])
        output = net(image_batch)
        loss = loss_fn(output, clean)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % show_steps == 0:
            print('\t partial train loss: %f' % loss.item())

#### Evaluating the test set

In [None]:
def test_epoch(net, dataloader, loss_fn, *, device='cuda'):
    """Compute the reconstruction loss of ``net`` over a whole dataloader.

    All reconstructions and their input images are gathered on the CPU, and
    ``loss_fn`` is evaluated once on the full concatenation. For a
    mean-reduced loss this gives the mean over the whole dataset rather than
    an average of per-batch means.

    Args:
        net: autoencoder to evaluate; it is switched to eval mode.
        dataloader: iterable of batches whose first element is the image batch.
        loss_fn: callable ``loss_fn(output, target)``.
        device: device the batches are moved to (default ``'cuda'``, as
            before); it must match the device ``net`` lives on.

    Returns:
        The loss as a detached tensor.
    """
    net.eval()
    outputs, targets = [], []
    with torch.no_grad():
        for sample_batch in dataloader:
            image_batch = sample_batch[0].to(device)
            outputs.append(net(image_batch).cpu())
            targets.append(image_batch.cpu())
        # Concatenate once: calling torch.cat inside the loop would re-copy
        # everything gathered so far, which is quadratic in the dataset size.
        empty = torch.Tensor().float()
        conc_out = torch.cat(outputs) if outputs else empty
        conc_label = torch.cat(targets) if targets else empty
        val_loss = loss_fn(conc_out, conc_label)
    return val_loss.detach()


# The name starts with "test_" but this is not a pytest test; stop pytest from
# collecting it when the module is imported into a test file.
test_epoch.__test__ = False

#### Show the progress of the training

In [None]:
def show_progress(index, dataset, net, epoch):
    """Plot sample ``index`` of ``dataset`` next to its reconstruction by ``net``.

    The figure is saved as
    ``progress/autoencoder_progress_<net.encoded_space_dim>_features/epoch_<epoch+1>.png``
    (the directory is created if missing), then shown and closed.

    Args:
        index: position of the sample in ``dataset``.
        dataset: indexable collection whose items' first element is an image tensor.
        net: autoencoder exposing ``encoded_space_dim``. The sample is moved
            to ``'cuda'``, so ``net`` is expected to live there.
        epoch: zero-based epoch number; it is displayed and saved as ``epoch + 1``.
    """
    original = dataset[index][0].unsqueeze(0).to('cuda')
    net.eval()
    with torch.no_grad():
        reconstruction = net(original)

    shown_epoch = epoch + 1
    _, (left, right) = plt.subplots(1, 2, figsize=(8,4))
    panels = (
        (left, original, 'Original image'),
        (right, reconstruction, 'Reconstructed image (EPOCH %d)' % shown_epoch),
    )
    for axis, tensor, title in panels:
        axis.imshow(tensor.cpu().squeeze().numpy(), cmap='gist_gray')
        axis.set_title(title)
    plt.tight_layout()
    plt.pause(0.1)

    # Persist a snapshot of this epoch's reconstruction.
    feature_dim = net.encoded_space_dim
    os.makedirs('progress/autoencoder_progress_%d_features' % feature_dim, exist_ok=True)
    plt.savefig('progress/autoencoder_progress_%d_features/epoch_%d.png' % (feature_dim, shown_epoch))
    plt.show()
    plt.close()

#### Show the structure of the latent space (currently supported only for a 2-D latent space)

In [None]:
def _encode_dataset(net, dataset):
    """Encode every sample of ``dataset`` one at a time with ``net.encode``.

    Returns a list of ``(code, label)`` pairs. ``code`` is the flattened NumPy
    encoding of ``sample[0]``, given a leading batch dimension of 1, and
    ``label`` is ``sample[1]``.
    """
    net.eval()
    with torch.no_grad():
        return [(net.encode(sample[0].unsqueeze(0)).flatten().numpy(), sample[1])
                for sample in dataset]


def _decode_latent_grid(net, lim, n_digits, digit_size):
    """Tile the decodings of an ``n_digits`` x ``n_digits`` lattice over ``lim`` x ``lim``.

    Tile ``(i, j)`` is the decoding of the 2-D latent point ``(x_j, y_i)``,
    reshaped to ``digit_size`` x ``digit_size``. Rows run from the largest y
    down to the smallest, so the image lines up with ``imshow``'s default
    upper-left origin.
    """
    grid_x = np.linspace(lim[0], lim[1], n_digits)
    grid_y = np.linspace(lim[1], lim[0], n_digits)
    figure = np.zeros((digit_size * n_digits, digit_size * n_digits))
    with torch.no_grad():  # pure inference: don't build an autograd graph per tile
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                digit = net.decode(torch.Tensor([xi, yi])).view(digit_size, digit_size)
                figure[i * digit_size:(i + 1) * digit_size,
                       j * digit_size:(j + 1) * digit_size] = digit.detach().numpy()
    return figure


def show_result(net, dataset):
    """Plot the 2-D latent space of ``net`` over a background of decoded digits.

    Every sample in ``dataset`` is encoded with ``net.encode`` and drawn as a
    dot coloured by its label. Behind the dots, a 25x25 lattice of latent
    points spanning [-8, 8] x [-8, 8] is decoded with ``net.decode`` and
    tiled into one image. The figure is saved to ``report/latent_2.pdf``, and
    the ``report`` directory is created if missing.

    Only the first two encoded coordinates are plotted, so this is meant for a
    2-D latent space. No device transfer is done: ``net`` is applied directly
    to the tensors yielded by ``dataset``.

    Args:
        net: autoencoder exposing ``encode`` and ``decode``.
        dataset: iterable of ``(img, label)`` pairs; only labels 0-9 get a
            colour. Decoded outputs are assumed to hold 28*28 values — confirm.
    """
    import os

    lim = (-8, 8)  # latent window shown on both axes
    digit_size = 28
    n_digits = 25
    color_map = {
        0: '#1f77b4',
        1: '#ff7f0e',
        2: '#2ca02c',
        3: '#d62728',
        4: '#9467bd',
        5: '#8c564b',
        6: '#e377c2',
        7: '#7f7f7f',
        8: '#bcbd22',
        9: '#17becf'
        }

    encoded_samples = _encode_dataset(net, dataset)
    figure = _decode_latent_grid(net, lim, n_digits, digit_size)

    fig = plt.figure(figsize=(12,12))
    plt.imshow(figure, cmap='Greys', alpha=0.7, extent=[lim[0], lim[1], lim[0], lim[1]])
    # Draw one artist per label instead of one per sample: thousands of
    # single-point Line2D objects make rendering very slow. ls='' keeps the
    # points unconnected.
    points = np.array([code[:2] for code, _ in encoded_samples], dtype=float).reshape(-1, 2)
    labels = np.array([label for _, label in encoded_samples])
    for label, color in color_map.items():
        mask = labels == label
        plt.plot(points[mask, 0], points[mask, 1], ls='', marker='.', color=color, alpha=0.2)
    plt.grid(True)
    plt.legend([plt.Line2D([0], [0], ls='', marker='.', color=c, label=l) for l, c in color_map.items()],
               color_map.keys())
    plt.tight_layout()
    plt.xlim(*lim)
    plt.ylim(*lim)
    plt.title('Latent Space')
    os.makedirs('report', exist_ok=True)  # savefig fails if the directory is missing
    fig.savefig('report/latent_2.pdf')

#### Functions for computing the reconstruction loss under different levels of noise and cropping (random erasing)

In [None]:
def _reconstruction_loss(net, dataloader, loss_fn, make_input):
    """Return ``loss_fn(net(make_input(x)), x)`` over every batch ``x``, as a float.

    ``x`` is the first element of each batch. Network outputs and clean
    targets are moved to the CPU, collected, and concatenated once at the
    end, because calling ``torch.cat`` inside the loop would re-copy
    everything gathered so far. ``loss_fn`` is then evaluated on the full
    concatenation.
    """
    outputs, targets = [], []
    with torch.no_grad():
        for sample_batch in dataloader:
            clean = sample_batch[0]
            outputs.append(net(make_input(clean)).cpu())
            targets.append(clean.cpu())
        empty = torch.Tensor().float()
        conc_out = torch.cat(outputs) if outputs else empty
        conc_label = torch.cat(targets) if targets else empty
        return loss_fn(conc_out, conc_label).item()


def compute_limits_loss_noise(net, dataloader, loss_fn):
    """Return ``[min_val, max_val]`` reference losses for the noise test.

    ``min_val`` is the loss when ``net`` reconstructs the clean images.
    ``max_val`` is the loss when ``net`` is fed pure standard-normal noise
    (same shape as each batch) but is still scored against the clean images.
    """
    min_val = _reconstruction_loss(net, dataloader, loss_fn, lambda x: x)
    max_val = _reconstruction_loss(net, dataloader, loss_fn, torch.randn_like)
    return [min_val, max_val]


def compute_limits_loss_crop(net, dataloader, loss_fn):
    """Return ``[min_val, max_val]`` reference losses for the crop test.

    ``min_val`` is the loss when ``net`` reconstructs the clean images.
    ``max_val`` is the loss when ``net`` is fed all-zero images (everything
    erased) but is still scored against the clean images.
    """
    min_val = _reconstruction_loss(net, dataloader, loss_fn, lambda x: x)
    max_val = _reconstruction_loss(net, dataloader, loss_fn, torch.zeros_like)
    return [min_val, max_val]

In [None]:
def _corrupted_loss(net, dataloader, loss_fn, corrupt):
    """Return the loss (0-d tensor) of ``net`` reconstructing clean batches from corrupted ones.

    ``corrupt`` receives a private copy of each image batch, so tensors
    yielded by ``dataloader`` are never modified in place. The old code erased
    directly into ``sample_batch[0]``, which corrupts a cached dataloader
    across sweeps. Outputs and clean targets are concatenated once at the end,
    and ``loss_fn`` is evaluated on the whole set.
    """
    outputs, targets = [], []
    with torch.no_grad():
        for sample_batch in dataloader:
            clean = sample_batch[0]
            outputs.append(net(corrupt(clean.clone())).cpu())
            targets.append(clean.cpu())
        empty = torch.Tensor().float()
        return loss_fn(torch.cat(outputs) if outputs else empty,
                       torch.cat(targets) if targets else empty)


def _noise_corruption(level):
    """Return a function adding Gaussian noise with standard deviation ``level`` to a batch."""
    return lambda batch: batch + level * torch.randn(batch.shape)


def _erase_corruption(scale):
    """Return a function erasing one random rectangle of area fraction ``scale`` per image.

    ``scale == 0`` yields the identity (no erasing).
    """
    if scale == 0:
        return lambda batch: batch
    erase = transforms.RandomErasing(p=1, scale=(scale, scale), ratio=(0.5, 2))
    return lambda batch: torch.stack([erase(img) for img in batch])


def _corruption_sweep(net, dataloader, loss_fn, levels, make_corruption, limits, message):
    """Print and plot the normalized loss for each corruption level; return the losses.

    Each loss is rescaled linearly so that ``limits[0]`` maps to 0 and
    ``limits[1]`` maps to 100. ``message`` is a %-format string that receives
    ``(level * 100, loss)``.
    """
    log = []
    for level in levels:
        val_loss = _corrupted_loss(net, dataloader, loss_fn, make_corruption(level)) - limits[0]
        val_loss *= 100/(limits[1]-limits[0])
        log.append(val_loss.item())
        print(message % (level*100, val_loss))
    plt.plot(levels, log)
    plt.show()
    return log


def test_noise_crop(net, dataloader, loss_fn):
    """Measure how reconstruction quality degrades under input noise and erasing.

    ``net`` is put in eval mode and moved to the CPU; both changes persist
    after the call. Two sweeps are run over ``dataloader``:

    * Noise: Gaussian noise with std ``noise_level`` is added to the inputs.
    * Crop: one random rectangle covering a fraction ``scale`` of each image
      is erased with ``transforms.RandomErasing``; ``scale == 0`` leaves the
      images unchanged.

    In both sweeps the network is scored against the clean images. Each loss
    is reported as a percentage between the clean-input loss (0) and the
    worst-case loss (100) from ``compute_limits_loss_noise`` /
    ``compute_limits_loss_crop``. The results are printed and plotted against
    the corruption level.
    """
    net.eval()
    net.to(torch.device('cpu'))

    # Noise
    limits_loss_noise = compute_limits_loss_noise(net, dataloader, loss_fn)
    print('\n@@@@@@@@@@@@@@@@@@@@@@@   Noise testing:\n')
    noise_try = [0, 0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.5, 2]
    _corruption_sweep(net, dataloader, loss_fn, noise_try, _noise_corruption,
                      limits_loss_noise, 'Test loss with a %.1f %% noise level: %.4f')

    # Crop
    limits_loss_crop = compute_limits_loss_crop(net, dataloader, loss_fn)
    print('\n@@@@@@@@@@@@@@@@@@@@@@@   Crop testing:\n')
    scale_try = [0, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    _corruption_sweep(net, dataloader, loss_fn, scale_try, _erase_corruption,
                      limits_loss_crop, 'Test loss with a %.1f %% crop: %.4f')