In [1]:
import sys
sys.path.append('..')

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchvision
from torchvision.utils import save_image
from torch.utils.data.dataset import Dataset
from torchvision import transforms
import torch.nn.functional as F
from tqdm import tqdm_notebook as tqdm

# Expose only GPU 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

### MNIST data

In [2]:
# A wrapper dataset over MNIST to return images and indices
class DatasetMNIST(Dataset):
    """MNIST training images paired with their position in the dataset.

    The auto-decoder keeps one latent vector per training sample, so every
    item carries its own index in addition to the pixels.

    Args:
        root_dir: Directory where torchvision stores / downloads MNIST.
        latent_size: Unused; accepted only so existing callers keep working.
        transform: Optional callable applied to each scaled image tensor
            before it is flattened. ``None`` (the default) applies nothing.
    """

    def __init__(self, root_dir, latent_size, transform=None):
        super(DatasetMNIST, self).__init__()
        mnist = torchvision.datasets.MNIST(root=root_dir, train=True, download=True)
        # `.data` replaces the deprecated `.train_data` alias.
        # Pixels are rescaled from [0, 255] to [0, 1].
        self.data = mnist.data.float() / 255.0
        self.transform = transform

    def __len__(self):
        """Return the number of images held by the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(flattened_image, index)`` for the sample at ``index``."""
        image = self.data[index]
        if self.transform is not None:
            image = self.transform(image)

        return image.flatten(), index

In [3]:
# Folder that receives the per-epoch latent-space images; created if missing.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

# Model dimensions
image_size = 28 * 28  # length of one flattened image
h_dim = 512
latent_size = 2       # dimensionality of each per-sample latent code

# Optimisation settings
num_epochs = 12
batch_size = 128
learning_rate = 0.002

### Auto-decoder

In [4]:
class AD(nn.Module):
    """Auto-decoder: a decoder plus one learnable latent vector per sample.

    There is no encoder. Each training sample owns a row of
    ``latent_vectors`` which is optimised jointly with the decoder weights.

    Args:
        image_size: Length of the flattened image the decoder produces.
        z_dim: Dimensionality of each latent vector.
        data_shape: Number of samples, i.e. rows in the latent table.
    """

    def __init__(self, image_size=784, z_dim=latent_size, data_shape=60000):
        super(AD, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            # Output width follows `image_size` instead of a hard-coded 28 * 28.
            nn.Linear(512, image_size),
            nn.Tanh(),
        )

        # Size the table with `z_dim` (not the module-level `latent_size`) so
        # it always matches the decoder's input width.
        self.latent_vectors = nn.Parameter(torch.empty(data_shape, z_dim))

        # In-place initialiser; `init.xavier_normal` is the deprecated alias.
        init.xavier_normal_(self.latent_vectors)

    def forward(self, ind):
        """Decode the latent vectors stored at the sample indices ``ind``."""
        x = self.latent_vectors[ind]
        return self.decoder(x)

    def predict(self, x):
        """Decode latent vector(s) ``x`` supplied directly by the caller."""
        return self.decoder(x)

In [5]:
# Preprocessing pipeline (defined here but not handed to the dataset below).
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split wrapped so that every item is (flattened image, index).
dataset = DatasetMNIST(root_dir='../../data', latent_size=latent_size)

# Shuffled mini-batches over the wrapped dataset.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [6]:
# Auto-decoder with its default sizes, moved onto the GPU.
model = AD()
model = model.cuda()

# Reconstruction loss: mean squared error.
# (nn.BCEWithLogitsLoss() was the alternative left commented out.)
criterion = nn.MSELoss()

# Adam over every model parameter; `latent_vectors` is an nn.Parameter,
# so the per-sample latent codes are optimised along with the decoder.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

  


In [7]:
# Pull a single shuffled mini-batch to check what the loader yields:
# x is the batch of flattened images, y the dataset indices of those images.
batch_iter = iter(data_loader)
x, y = next(batch_iter)
print(x.size(), y.size())
y

torch.Size([128, 784]) torch.Size([128])


tensor([13540, 21457, 10410, 32233, 41246, 54087,  4344, 59169, 10640, 27006,
        17653, 56291, 13409, 11766,  5502, 39906, 13918, 13974,  3461, 35348,
        37722, 51551,   544,  8944, 22313, 49235, 53887, 59039, 42327, 12350,
        22135, 15129, 14211, 20127, 51002, 24538, 58044,  5788, 22465, 37846,
        18978, 55127, 46041,  6216, 21545, 20958, 12630, 58545, 19814, 54990,
        13217,  7113, 10594,  2188, 27824, 50855,  3840,  8330, 15817, 45277,
        44903, 11846, 24572, 14835, 44232, 44413, 43247,  8501, 44076, 22922,
         9182,  6270,  6106, 10432, 23935, 36022, 56633,   460, 38529, 16794,
        46296, 14553, 40087, 25712, 16402, 38071, 53925, 28276,  7727, 48537,
        13171,  9558, 11496, 12218, 33071, 39886, 56109, 14406, 49940, 21222,
        31373, 17619, 40946, 29179, 50905,  3964, 53350, 32685, 10145, 54054,
        24827, 10277, 30692, 28237, 44877, 13184, 46068, 18332, 26410, 56684,
        32720, 33955, 36564, 54490, 59078, 15826,  9466, 57649])

In [8]:
# Sanity check: display the per-sample latent codes as initialised, before the training cell runs.
model.latent_vectors

Parameter containing:
tensor([[-4.9269e-03, -1.6008e-02],
        [ 3.6784e-03, -5.6042e-03],
        [ 2.8870e-03,  2.9110e-03],
        ...,
        [ 1.9106e-03, -1.6223e-03],
        [-5.9483e-03, -8.7236e-03],
        [-2.4812e-03,  2.8096e-06]], device='cuda:0', requires_grad=True)

### Training

In [9]:
model.train()

# Latent-space visualisation settings (loop-invariant, so defined once).
steps = 20   # grid points along each latent axis
bound = 0.4  # the grid spans [-bound, bound] on both axes
size = 28    # side length in pixels of one decoded tile

# Every (l1, l2) grid coordinate in row-major order: l1 varies along image
# rows, l2 along image columns. The dtype is explicit so the decoder input
# is float32 whatever torch would infer from numpy scalars.
axis = np.linspace(-bound, bound, steps)
grid_latents = torch.tensor([[l1, l2] for l1 in axis for l2 in axis],
                            dtype=torch.float32).cuda()

for epoch in range(num_epochs):
    # The bar is a context manager, so it is closed at the end of each epoch.
    # Unclosed bars accumulate and can make tqdm's monitor thread raise
    # "Set changed size during iteration".
    with tqdm(data_loader) as tq:
        tq.set_description('Epoch {}'.format(epoch))
        for x, ind in tq:
            # Forward pass: decode the latent codes belonging to this batch.
            x = x.cuda()
            x_reconst = model(ind)

            loss = criterion(x_reconst, x)

            # Update decoder weights and the batch's latent codes together.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tq.set_postfix(loss='{:.3f}'.format(loss.item()))

    with torch.no_grad():
        # Visualize 2D latent space: decode the whole grid in one batch
        # instead of steps * steps single-vector forward passes.
        tiles = model.predict(grid_latents)
        # (steps*steps, size*size) -> (row, col, y, x) -> (row, y, col, x)
        # -> one (steps*size, steps*size) mosaic.
        out = (tiles.view(steps, steps, size, size)
                    .permute(0, 2, 1, 3)
                    .reshape(steps * size, steps * size)
                    .cpu())
        save_image(out, os.path.join(sample_dir, 'latent_space-{}.png'.format(epoch + 1)))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))




HBox(children=(IntProgress(value=0, max=469), HTML(value='')))




HBox(children=(IntProgress(value=0, max=469), HTML(value='')))




HBox(children=(IntProgress(value=0, max=469), HTML(value='')))




HBox(children=(IntProgress(value=0, max=469), HTML(value='')))




HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))




Exception in thread Thread-4:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.6/site-packages/tqdm/_monitor.py", line 62, in run
    for instance in self.tqdm_cls._instances:
  File "/opt/conda/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration



HBox(children=(IntProgress(value=0, max=469), HTML(value='')))




In [10]:
# Display the per-sample latent codes again, now that the training cell has run.
model.latent_vectors

Parameter containing:
tensor([[-0.1604, -0.0163],
        [-0.3056,  0.1926],
        [-0.0478,  0.0599],
        ...,
        [-0.1557, -0.1057],
        [-0.0999, -0.0220],
        [-0.0740, -0.1179]], device='cuda:0', requires_grad=True)

In [11]:
# Largest and smallest learned latent coordinate; compare with the
# visualisation grid, which spans [-bound, bound] with bound = 0.4.
model.latent_vectors.max(), model.latent_vectors.min()

(tensor(0.4014, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(-0.3995, device='cuda:0', grad_fn=<MinBackward1>))