In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision.datasets as dset

from lib.VisdomWrapper import *
from lib.GANs import *
from lib.DataCreationWrapper import *
from lib.DataManager import *

In [2]:
torch.manual_seed(2)
batch_size = 512
num_epochs = 30
img_width = 28 #hardcoded
n_features = img_width**2
n_noise_features = 50
n_classes = 10

In [3]:
bal_gen = Conv_GeneratorNetwork(n_noise_features, n_features, n_classes)
bal_gen.load_state_dict(torch.load("models\gen_nn_directed_noise"))
bal_gen = bal_gen.cuda()
bal_gen.eval()

Conv_GeneratorNetwork(
  (label_embedding): Sequential(
    (0): Embedding(10, 50)
  )
  (to_input_form): Sequential(
    (0): Linear(in_features=50, out_features=25088, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): BatchNorm1d(25088, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (hidden0): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.4, inplace=False)
  )
  (hidden1): Sequential(
    (0): ConvTranspose2d(64, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1)
  )
  (out): Sequential(
    (0): Conv2d(16, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1)
    (2): Tanh()
  )
)

In [4]:
compose = transforms.Compose(
        [
            transforms.ToTensor(),
            # AddNormalNoise(0, .2),
            transforms.Normalize((.5,),(.5))
        ])

In [5]:
mnist = dset.MNIST('input', train=True, download=True, transform=compose)

In [6]:
bal_train = DataLoader(mnist, batch_size =1000)
X, Y = data_loader_to_tensor(bal_train)

In [7]:
from sklearn.neighbors import NearestNeighbors
import numpy as np

In [8]:
print(X.shape)
X = X.view(len(X), -1)
print(X.shape)

torch.Size([60000, 1, 28, 28])
torch.Size([60000, 784])


In [9]:
nbrs = NearestNeighbors(n_neighbors=3, algorithm='auto').fit(X)

In [10]:
syn_data, syn_labels =  synthesize_data_of_each_label(bal_gen, gaussian_noise, 100 * np.ones(10).astype('int'))

syn_data = syn_data.view(len(syn_data), -1)
dist, labels = nbrs.kneighbors(syn_data.detach().cpu(), return_distance=True)

In [11]:
(syn_data[0].detach().cpu() - X[labels[0,0]].detach().cpu()).sum()

tensor(60.5558)

In [12]:
np.average(dist,axis=1)

array([ 9.86580697,  8.84034223,  7.80373661, 11.35775365,  8.16321847,
        9.69102724,  7.91509274,  8.37901917,  8.88025252,  8.83584831,
        8.72835356,  8.48714777, 11.41911669,  8.12662837, 10.01038805,
       13.06321583,  9.19606371,  7.51189754,  9.78850821,  8.89423105,
        8.62390071,  8.74511394,  9.39960866,  9.32069563,  9.09295514,
        8.90196635,  7.80213193,  6.91374246,  8.92491918,  9.05604395,
        9.38407212,  7.3745541 ,  8.21845276,  7.63789783, 10.12641923,
        6.85564199,  7.96941969,  8.91400219,  8.37532421,  8.89576452,
        8.33005715,  9.28417654,  7.93261272,  7.02463537, 12.25427115,
        7.91845167,  7.56743466,  8.50936206,  8.56547264,  9.98994436,
       10.26257382,  9.13781191, 10.67386405,  9.98721787,  9.85440067,
       10.66047038,  8.80863468,  7.5996183 ,  9.97208814,  9.2445579 ,
       10.93819749,  9.64895834,  8.72563329,  9.68417006,  9.7317385 ,
        8.31291805,  7.77615808, 11.7579121 ,  8.47362119,  9.38

In [13]:
(dist == 0).sum()

0

In [14]:
np.average(dist,axis=1).min()

2.5152797998785705