In [232]:
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
import math
from mpl_toolkits.axes_grid1 import ImageGrid

In [261]:
def imshow_grid(images, shape=(10, 4), name='default', save=False):
    """
    Plot images in a grid of a given shape.

    Initial code from: https://github.com/pumpikano/tf-dann/blob/master/utils.py

    Parameters
    ----------
    images : sequence of 2-D array-likes
        Images to draw, filling the grid cell by cell in grid order. When
        fewer images than grid cells are given the remaining cells stay
        blank; images beyond the grid size are ignored.
    shape : sequence of two ints
        ``(rows, cols)`` of the grid.
    name : object
        Base name of the output file; the figure is written to
        ``./nn/<name>.png`` when ``save`` is true.
    save : bool
        If true, write the figure to disk and close it; otherwise show it.
    """
    import os  # local import: only needed to create the output directory

    n_rows, n_cols = shape
    # A fresh figure per call: reusing figure number 1 would draw on top of
    # whatever a previous call (or another cell) left in that figure.
    fig = plt.figure(figsize=(15, 6))
    grid = ImageGrid(fig, 111, nrows_ncols=(n_rows, n_cols), axes_pad=0.05)

    n_images = len(images)
    for idx in range(n_rows * n_cols):
        ax = grid[idx]  # the ImageGrid object works as a list of axes
        ax.axis('off')
        if idx < n_images:  # leave surplus cells empty instead of raising IndexError
            ax.imshow(images[idx], cmap='gray')

    if save:
        os.makedirs('./nn', exist_ok=True)  # savefig does not create missing directories
        fig.savefig('./nn/' + str(name) + '.png', bbox_inches='tight', transparent=True, pad_inches=0)
        # Close rather than clear: clf() keeps an empty figure alive, which the
        # notebook then reports as "<Figure ... with 0 Axes>".
        plt.close(fig)
    else:
        plt.show()

In [223]:
# Load the preprocessed MNIST training split. Element 0 is the image tensor
# (its shape is inspected in a later cell: 60000 x 28 x 28).
training_data = torch.load("./data/mnist/MNIST/processed/training.pt")

In [108]:
# Flatten every training image into a single row vector, giving a
# (num_images, height * width) matrix for the nearest-neighbour index.
X = training_data[0].flatten(start_dim=1)

In [110]:
# Nearest-neighbour searcher that returns 3 neighbours per query by default;
# n_jobs=-1 lets scikit-learn use all available processors for the search.
neigh = NearestNeighbors(n_neighbors=3, n_jobs=-1)

In [111]:
# Index the flattened training images (one row of X per image).
neigh.fit(X)

NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
         metric_params=None, n_jobs=-1, n_neighbors=3, p=2, radius=1.0)

In [224]:
# Inspect the layout of the training image tensor: (num_images, height, width).
training_data[0].shape

torch.Size([60000, 28, 28])

## Fetch nearest training neighbours for random generated samples

### Change the path below to evaluate a different model

In [175]:
# Generated samples of the model under evaluation; the loop below reads the
# entries stored under the keys "samples_0" ... "samples_9".
# NOTE(review): torch.load unpickles this file, which can execute arbitrary
# code — only load sample files that come from a trusted source.
sampler = torch.load("./mnist_004/mnist/mnistsamples.pkl")

In [246]:
# Images for the final grid: for each sample group, the generated image
# followed by its nearest training images (filled by the loop below).
images_nn = []

In [247]:
# Build the rows of the comparison grid: for each of the 10 sample groups,
# one randomly chosen generated image followed by its 3 nearest training images.
images_nn = []  # reset here so re-running this cell does not append duplicates
MAX_CANDIDATES = 2000  # draw only from the first 2000 samples of each group
N_NEIGHBORS = 3

for group_idx in range(10):
    sample = np.array(sampler["samples_" + str(group_idx)])

    # Cap the upper bound at the number of samples actually present so a
    # shorter array cannot trigger an IndexError.
    random_idx = np.random.randint(0, min(MAX_CANDIDATES, len(sample)))
    random_sample = sample[random_idx]
    images_nn.append(random_sample[0])

    # Min-max rescale the query to the 0..255 range and cast the *result* to
    # uint8. (Previously the cast was applied to the scale factor only, which
    # truncated the factor and left the query as unquantised floats.)
    random_sample = random_sample.reshape(28 * 28).astype('float64')
    lowest = random_sample.min()
    value_range = random_sample.max() - lowest
    if value_range > 0:
        random_sample = (random_sample - lowest) / value_range * 255
    else:
        # Constant image: avoid dividing by zero, map everything to 0.
        random_sample = np.zeros_like(random_sample)
    # Round before casting so that e.g. 254.9999 does not truncate to 254.
    random_sample = np.rint(random_sample).astype('uint8')

    # kneighbors returns (distances, indices); only the indices are needed.
    nei = neigh.kneighbors(X=[random_sample], n_neighbors=N_NEIGHBORS)
    for train_idx in nei[1][0]:  # distinct name: do not shadow the outer loop index
        images_nn.append(training_data[0][train_idx])
    
        

In [248]:
# Sanity check: 10 groups x (1 generated sample + 3 neighbours) = 40 images.
len(images_nn)

40

In [262]:
# Render the images in the default 10 x 4 grid and save it as ./nn/005.png.
# NOTE(review): the samples were loaded from ./mnist_004 but the figure is
# named "005" — confirm the name matches the model being evaluated.
imshow_grid(images_nn, name="005", save=True)

<Figure size 1080x432 with 0 Axes>