In [40]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm_notebook as tqdm
from PIL import Image
from skimage.transform import rescale

In [48]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(1024, 256)
        self.fc21 = nn.Linear(256, 32)
        self.fc22 = nn.Linear(256, 32)
        self.fc3 = nn.Linear(32, 256)
        self.fc4 = nn.Linear(256, 1024)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 1024))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [51]:
size = 32

transform = transforms.Compose([transforms.Resize((size, size)),
                                transforms.ToTensor()])

train_set = MNIST(root='./',
              train=True,
             transform=transform)

batch_size = 100 

train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          shuffle=True)

In [37]:
to_img = transforms.ToPILImage()

In [62]:
logvar.shape

torch.Size([100, 32])

In [71]:
mu

tensor([[ 0.1139,  0.0384,  0.0203,  ..., -0.0359,  0.0816,  0.0689],
        [-0.0492,  0.0333, -0.0025,  ...,  0.0588,  0.0308,  0.0934],
        [ 0.0143,  0.0627,  0.0517,  ...,  0.0518,  0.0173,  0.0838],
        ...,
        [ 0.2248,  0.0438, -0.0335,  ..., -0.0233,  0.0889, -0.0233],
        [ 0.0026, -0.0706, -0.0113,  ...,  0.0094,  0.0566,  0.0636],
        [ 0.0038, -0.0367,  0.0395,  ...,  0.0116,  0.0043,  0.0595]],
       grad_fn=<AddmmBackward>)

In [73]:
(mu**2).sum(-1)

tensor([0.2217, 0.1009, 0.1298, 0.1605, 0.1374, 0.0888, 0.2262, 0.0978, 0.1270,
        0.2646, 0.1067, 0.1095, 0.1202, 0.1228, 0.1478, 0.1544, 0.1641, 0.2316,
        0.1441, 0.1382, 0.1635, 0.1452, 0.0959, 0.2571, 0.2392, 0.1011, 0.1348,
        0.1726, 0.1010, 0.1503, 0.1191, 0.1645, 0.1279, 0.3042, 0.1353, 0.2881,
        0.1529, 0.1330, 0.1536, 0.2174, 0.2443, 0.1068, 0.1450, 0.2344, 0.2110,
        0.1354, 0.3052, 0.0794, 0.1703, 0.1530, 0.1711, 0.2341, 0.0985, 0.2592,
        0.0653, 0.1511, 0.1311, 0.1784, 0.1687, 0.1330, 0.2624, 0.1879, 0.1245,
        0.1673, 0.2844, 0.0920, 0.2895, 0.2434, 0.2279, 0.2616, 0.1640, 0.1322,
        0.1752, 0.4424, 0.1567, 0.1364, 0.1620, 0.0702, 0.2677, 0.1488, 0.1574,
        0.2225, 0.1052, 0.0871, 0.1124, 0.1010, 0.1332, 0.0618, 0.1478, 0.0963,
        0.1710, 0.1172, 0.1350, 0.0708, 0.2225, 0.2662, 0.1215, 0.3163, 0.1970,
        0.0958], grad_fn=<SumBackward1>)

In [79]:
 -1/2 * ( 1 - (mu**2).sum(-1) + logvar.sum(-1) + torch.exp(logvar).sum(-1))

tensor([-17.3637, -17.2111, -17.0776, -16.8850, -17.6429, -17.1321, -16.7539,
        -16.4896, -16.7945, -16.8663, -17.4297, -16.6156, -17.0123, -16.7586,
        -17.4490, -16.6814, -17.3434, -16.7411, -17.0585, -17.0527, -16.8665,
        -16.8231, -17.1150, -16.6741, -16.4968, -16.5078, -16.6908, -16.8936,
        -16.4616, -16.5188, -17.2821, -16.7395, -16.6372, -16.5404, -17.0022,
        -17.1990, -16.4781, -16.6016, -16.3806, -16.8696, -17.2510, -16.6367,
        -17.0998, -17.0497, -16.2499, -17.2078, -17.0475, -16.5706, -16.5114,
        -16.6305, -16.8924, -16.9757, -16.1624, -16.2294, -16.7861, -17.0978,
        -16.9312, -16.8050, -16.7978, -16.6826, -16.8348, -16.9122, -16.7768,
        -16.8415, -16.2919, -16.3914, -17.2325, -16.8541, -17.0492, -16.9515,
        -16.4518, -17.1926, -16.7693, -16.6878, -16.6501, -16.5239, -17.2847,
        -17.2895, -16.8728, -16.9369, -16.9488, -16.5486, -16.8963, -17.4071,
        -16.9811, -16.5721, -16.8534, -16.5212, -17.0297, -17.37

In [53]:
vae = VAE()
optimizer = optim.Adam(VAE.parameters())

num_epochs = 20

t = tqdm(range(num_epochs), desc="Epoch : ", leave=False)
tt = tqdm(train_loader, desc="Batch loss : ", leave=True)


for epoch in t:
    for batch_idx, (img, labels) in enumerate(tt):
        optimizer.zero_grad()
        img_recons, mu, logvar = vae.forward(img)
        reg_loss = -1/2 * ( 1 - (mu**2).sum(-1) + logvar.sum(-1) + torch.exp(logvar).sum(-1))
        reconstruct_loss = torch.norm(img_recons - img, 2)
        loss = reg_loss + reconstruct_loss
        loss.backward()
        optimizer.step()
        
        

HBox(children=(IntProgress(value=0, description='Epoch : ', max=20, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Batch loss : ', max=600, style=ProgressStyle(description_widt…

KeyboardInterrupt: 

In [33]:
math.sqrt(2)

NameError: name 'math' is not defined

In [35]:
(img_recons.reshape(batch_size, 28, 28)[0])

tensor([[0.5485, 0.5248, 0.4799, 0.5572, 0.4985, 0.4818, 0.5234, 0.4000, 0.5339,
         0.4057, 0.5957, 0.3809, 0.5804, 0.5066, 0.5144, 0.5499, 0.5960, 0.4966,
         0.4732, 0.5259, 0.4774, 0.4316, 0.5986, 0.4345, 0.4578, 0.4974, 0.5415,
         0.4138],
        [0.4521, 0.4829, 0.5307, 0.5876, 0.4198, 0.5240, 0.5369, 0.4386, 0.5195,
         0.5282, 0.5321, 0.4821, 0.4364, 0.5054, 0.4895, 0.6017, 0.5718, 0.4702,
         0.4975, 0.4521, 0.4642, 0.3850, 0.4494, 0.5334, 0.5009, 0.4993, 0.4255,
         0.5243],
        [0.5001, 0.4421, 0.4244, 0.4362, 0.4842, 0.4538, 0.4889, 0.5403, 0.5547,
         0.4529, 0.4994, 0.4702, 0.4916, 0.4163, 0.5031, 0.5142, 0.4522, 0.4403,
         0.4729, 0.4252, 0.4312, 0.4072, 0.4786, 0.4211, 0.5337, 0.3507, 0.5280,
         0.4558],
        [0.3616, 0.5337, 0.5603, 0.5664, 0.5870, 0.5302, 0.4875, 0.4570, 0.3876,
         0.5015, 0.3995, 0.4873, 0.4776, 0.5709, 0.4411, 0.5334, 0.4862, 0.5252,
         0.5137, 0.5228, 0.4299, 0.4706, 0.6051, 0.5275