Let us create synthetic data.
We consider 

In [1]:
import numpy as np
import torch

In [2]:
n = 500
# here we define 2 Gaussian latents variables z = (l_1, l_2)
l1 = np.random.normal(size=n)
l2 = np.random.normal(size=n)

latents = np.array([l1, l2]).T

# We define two random transformations from the latent space to the space of X and Y respectively
transform_x = np.random.randint(-8,8, size = 10).reshape([2,5])

# We compute data X = z w_x, and Y = z w_y
X = latents.dot(transform_x) 

# We we add some random Gaussian noise
X = X + 2*np.random.normal(size = n*5).reshape((n, 5))

X_normalized=(X-X.min())/(X.max()-X.min())
data=torch.Tensor([[[x]] for x in X_normalized])

print('The latent space has dimension ' + str(latents.shape))
print('The transformation for X has dimension ' + str(transform_x.shape))

print('X has dimension ' + str(X.shape))
print('data has shape '+str(data.shape))

The latent space has dimension (500, 2)
The transformation for X has dimension (2, 5)
X has dimension (500, 5)
data has shape torch.Size([500, 1, 1, 5])


In [3]:
l1.mean(),l1.std()

(-0.0680884379982985, 1.0058423287428864)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

In [5]:
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, z_dim):
        super(VAE, self).__init__()
        
        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc21 = nn.Linear(h_dim1, z_dim)
        self.fc22 = nn.Linear(h_dim1, z_dim)
        # decoder part
        self.fc3 = nn.Linear(z_dim, h_dim1)
        self.fc4 = nn.Linear(h_dim1, x_dim)
        
    def encoder(self, x):
        h = F.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h) # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        h = F.relu(self.fc3(z))
        return F.sigmoid(self.fc4(h)) 
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 5))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# build model
vae = VAE(x_dim=5, h_dim1= 10, z_dim=2)
    
vae

VAE(
  (fc1): Linear(in_features=5, out_features=10, bias=True)
  (fc21): Linear(in_features=10, out_features=2, bias=True)
  (fc22): Linear(in_features=10, out_features=2, bias=True)
  (fc3): Linear(in_features=2, out_features=10, bias=True)
  (fc4): Linear(in_features=10, out_features=5, bias=True)
)

In [6]:
optimizer = optim.Adam(vae.parameters())
# return reconstruction error + KL divergence losses
def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1,5), reduction='sum')
    
    KLD = 0.5 * torch.sum(-1 - log_var +mu.pow(2) + log_var.exp())
    
    print(BCE.item(),KLD.item())
    return BCE + KLD



In [7]:
def train(epoch):
    vae.train()
        
    optimizer.zero_grad()

    recon_batch, mu, log_var = vae(data)
    loss = loss_function(recon_batch, data, mu, log_var)

    loss.backward()
    optimizer.step()

    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, loss.item()))

In [8]:
def test():
    vae.eval()
    test_loss= 0
    with torch.no_grad():
        for data, _ in test_loader:
            recon, mu, log_var = vae(data)
            
            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()
        
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [9]:
for epoch in range(1, 1001):
    train(epoch)

1756.6485595703125 43.83906173706055
====> Epoch: 1 Average loss: 1800.4877
1757.895263671875 42.47222137451172
====> Epoch: 2 Average loss: 1800.3674
1756.3258056640625 41.13226318359375
====> Epoch: 3 Average loss: 1797.4580
1755.493896484375 39.81956100463867
====> Epoch: 4 Average loss: 1795.3135
1754.2979736328125 38.534244537353516
====> Epoch: 5 Average loss: 1792.8323
1755.954833984375 37.27609634399414
====> Epoch: 6 Average loss: 1793.2310
1754.404541015625 36.04560089111328
====> Epoch: 7 Average loss: 1790.4502
1756.052978515625 34.842010498046875
====> Epoch: 8 Average loss: 1790.8950
1751.0726318359375 33.6667594909668
====> Epoch: 9 Average loss: 1784.7394
1751.5906982421875 32.51831817626953
====> Epoch: 10 Average loss: 1784.1090
1750.40185546875 31.398618698120117
====> Epoch: 11 Average loss: 1781.8004
1749.197265625 30.30644416809082
====> Epoch: 12 Average loss: 1779.5037
1750.4755859375 29.244592666625977
====> Epoch: 13 Average loss: 1779.7202
1745.8111572265625 



1748.5364990234375 25.271419525146484
====> Epoch: 17 Average loss: 1773.8079
1747.47998046875 24.3465576171875
====> Epoch: 18 Average loss: 1771.8265
1748.0042724609375 23.447969436645508
====> Epoch: 19 Average loss: 1771.4523
1747.6268310546875 22.575830459594727
====> Epoch: 20 Average loss: 1770.2026
1745.9959716796875 21.729503631591797
====> Epoch: 21 Average loss: 1767.7255
1745.990966796875 20.90899658203125
====> Epoch: 22 Average loss: 1766.8999
1744.9615478515625 20.113727569580078
====> Epoch: 23 Average loss: 1765.0753
1744.100830078125 19.34296989440918
====> Epoch: 24 Average loss: 1763.4438
1743.893798828125 18.596553802490234
====> Epoch: 25 Average loss: 1762.4904
1742.546142578125 17.873699188232422
====> Epoch: 26 Average loss: 1760.4198
1742.3121337890625 17.174970626831055
====> Epoch: 27 Average loss: 1759.4871
1743.041259765625 16.500621795654297
====> Epoch: 28 Average loss: 1759.5419
1742.158447265625 15.849422454833984
====> Epoch: 29 Average loss: 1758.007

1729.62890625 0.13965433835983276
====> Epoch: 160 Average loss: 1729.7686
1729.485107421875 0.1394142210483551
====> Epoch: 161 Average loss: 1729.6245
1729.38720703125 0.13918423652648926
====> Epoch: 162 Average loss: 1729.5264
1729.640869140625 0.13896843791007996
====> Epoch: 163 Average loss: 1729.7798
1729.161376953125 0.13875818252563477
====> Epoch: 164 Average loss: 1729.3002
1729.342529296875 0.13855642080307007
====> Epoch: 165 Average loss: 1729.4811
1729.3974609375 0.13836178183555603
====> Epoch: 166 Average loss: 1729.5358
1728.9842529296875 0.13817346096038818
====> Epoch: 167 Average loss: 1729.1224
1729.3916015625 0.13798663020133972
====> Epoch: 168 Average loss: 1729.5295
1729.70947265625 0.13780701160430908
====> Epoch: 169 Average loss: 1729.8473
1729.1036376953125 0.13763070106506348
====> Epoch: 170 Average loss: 1729.2412
1728.992919921875 0.13745668530464172
====> Epoch: 171 Average loss: 1729.1304
1729.16845703125 0.1372832953929901
====> Epoch: 172 Average 

====> Epoch: 301 Average loss: 1728.9424
1728.935791015625 0.11562570929527283
====> Epoch: 302 Average loss: 1729.0514
1729.00927734375 0.1154724657535553
====> Epoch: 303 Average loss: 1729.1248
1728.7474365234375 0.11532041430473328
====> Epoch: 304 Average loss: 1728.8628
1728.9954833984375 0.11516916751861572
====> Epoch: 305 Average loss: 1729.1106
1728.6422119140625 0.11501896381378174
====> Epoch: 306 Average loss: 1728.7572
1728.9345703125 0.11486998200416565
====> Epoch: 307 Average loss: 1729.0494
1728.9290771484375 0.1147221028804779
====> Epoch: 308 Average loss: 1729.0438
1728.893798828125 0.11457276344299316
====> Epoch: 309 Average loss: 1729.0084
1728.8763427734375 0.1144268810749054
====> Epoch: 310 Average loss: 1728.9907
1728.9251708984375 0.11427980661392212
====> Epoch: 311 Average loss: 1729.0394
1728.911376953125 0.11413440108299255
====> Epoch: 312 Average loss: 1729.0255
1728.602783203125 0.11399012804031372
====> Epoch: 313 Average loss: 1728.7168
1728.899658

1728.805419921875 0.09758582711219788
====> Epoch: 448 Average loss: 1728.9030
1728.7821044921875 0.09748765826225281
====> Epoch: 449 Average loss: 1728.8796
1728.718505859375 0.09738728404045105
====> Epoch: 450 Average loss: 1728.8159
1728.93505859375 0.0972890853881836
====> Epoch: 451 Average loss: 1729.0323
1729.0400390625 0.09718987345695496
====> Epoch: 452 Average loss: 1729.1372
1728.76416015625 0.09708940982818604
====> Epoch: 453 Average loss: 1728.8612
1728.6917724609375 0.09699153900146484
====> Epoch: 454 Average loss: 1728.7888
1728.6390380859375 0.09689316153526306
====> Epoch: 455 Average loss: 1728.7360
1728.718505859375 0.09679299592971802
====> Epoch: 456 Average loss: 1728.8153
1728.7947998046875 0.09669330716133118
====> Epoch: 457 Average loss: 1728.8915
1729.009521484375 0.0965927243232727
====> Epoch: 458 Average loss: 1729.1061
1728.779541015625 0.09649428725242615
====> Epoch: 459 Average loss: 1728.8760
1728.7098388671875 0.09639418125152588
====> Epoch: 46

1728.8126220703125 0.08480063080787659
====> Epoch: 597 Average loss: 1728.8975
1728.7264404296875 0.08472928404808044
====> Epoch: 598 Average loss: 1728.8112
1728.80078125 0.08465588092803955
====> Epoch: 599 Average loss: 1728.8855
1728.7811279296875 0.08458119630813599
====> Epoch: 600 Average loss: 1728.8657
1728.59912109375 0.08450895547866821
====> Epoch: 601 Average loss: 1728.6836
1728.7713623046875 0.08443456888198853
====> Epoch: 602 Average loss: 1728.8558
1728.73486328125 0.08436229825019836
====> Epoch: 603 Average loss: 1728.8192
1728.750244140625 0.08428746461868286
====> Epoch: 604 Average loss: 1728.8345
1728.7445068359375 0.08421555161476135
====> Epoch: 605 Average loss: 1728.8287
1728.7562255859375 0.08414214849472046
====> Epoch: 606 Average loss: 1728.8403
1728.762939453125 0.08406996726989746
====> Epoch: 607 Average loss: 1728.8470
1728.8841552734375 0.08399593830108643
====> Epoch: 608 Average loss: 1728.9681
1728.8233642578125 0.08392107486724854
====> Epoch:

1728.7645263671875 0.07717332243919373
====> Epoch: 730 Average loss: 1728.8417
1728.7679443359375 0.07712393999099731
====> Epoch: 731 Average loss: 1728.8451
1728.8843994140625 0.07707327604293823
====> Epoch: 732 Average loss: 1728.9614
1728.58056640625 0.0770207941532135
====> Epoch: 733 Average loss: 1728.6576
1728.684326171875 0.07696837186813354
====> Epoch: 734 Average loss: 1728.7614
1728.7972412109375 0.07691588997840881
====> Epoch: 735 Average loss: 1728.8741
1728.7825927734375 0.07686305046081543
====> Epoch: 736 Average loss: 1728.8595
1728.729248046875 0.07681122422218323
====> Epoch: 737 Average loss: 1728.8060
1728.576171875 0.07675826549530029
====> Epoch: 738 Average loss: 1728.6530
1728.6644287109375 0.07670512795448303
====> Epoch: 739 Average loss: 1728.7411
1728.6680908203125 0.07665213942527771
====> Epoch: 740 Average loss: 1728.7448
1728.649658203125 0.07659795880317688
====> Epoch: 741 Average loss: 1728.7262
1728.8721923828125 0.07654377818107605
====> Epoch

1728.65576171875 0.07140174508094788
====> Epoch: 851 Average loss: 1728.7272
1728.7403564453125 0.0713576078414917
====> Epoch: 852 Average loss: 1728.8118
1728.6790771484375 0.07131442427635193
====> Epoch: 853 Average loss: 1728.7504
1728.6427001953125 0.07127127051353455
====> Epoch: 854 Average loss: 1728.7140
1728.6590576171875 0.07122629880905151
====> Epoch: 855 Average loss: 1728.7302
1728.6180419921875 0.0711817741394043
====> Epoch: 856 Average loss: 1728.6892
1728.7410888671875 0.07113805413246155
====> Epoch: 857 Average loss: 1728.8123
1728.7796630859375 0.07109248638153076
====> Epoch: 858 Average loss: 1728.8507
1728.651611328125 0.07104906439781189
====> Epoch: 859 Average loss: 1728.7227
1728.6729736328125 0.07100322842597961
====> Epoch: 860 Average loss: 1728.7440
1728.776123046875 0.0709608793258667
====> Epoch: 861 Average loss: 1728.8470
1728.778076171875 0.07091712951660156
====> Epoch: 862 Average loss: 1728.8490
1728.7918701171875 0.07087227702140808
====> Epo

====> Epoch: 977 Average loss: 1728.8151
1728.7421875 0.06583493947982788
====> Epoch: 978 Average loss: 1728.8080
1728.7196044921875 0.06579989194869995
====> Epoch: 979 Average loss: 1728.7854
1728.640625 0.06576496362686157
====> Epoch: 980 Average loss: 1728.7064
1728.5242919921875 0.06573125720024109
====> Epoch: 981 Average loss: 1728.5900
1728.7734375 0.06569728255271912
====> Epoch: 982 Average loss: 1728.8391
1728.6536865234375 0.06566417217254639
====> Epoch: 983 Average loss: 1728.7194
1728.7467041015625 0.06563270092010498
====> Epoch: 984 Average loss: 1728.8124
1728.73486328125 0.06560066342353821
====> Epoch: 985 Average loss: 1728.8004
1728.821044921875 0.06556981801986694
====> Epoch: 986 Average loss: 1728.8866
1728.6622314453125 0.06553864479064941
====> Epoch: 987 Average loss: 1728.7278
1729.0008544921875 0.06550788879394531
====> Epoch: 988 Average loss: 1729.0664
1728.8369140625 0.06547865271568298
====> Epoch: 989 Average loss: 1728.9023
1728.6630859375 0.065449

In [10]:
latents

array([[-2.46302855e+00, -3.01149248e-01],
       [-3.93871233e-01, -4.74494984e-01],
       [-3.41833964e+00,  1.30048731e+00],
       [-2.66616676e-01, -7.15614557e-01],
       [-8.13329838e-01, -4.79312550e-01],
       [-2.03814592e+00, -2.31928933e-01],
       [-1.80596358e+00, -2.72493535e-01],
       [ 8.24112654e-01, -5.51960053e-01],
       [ 1.04319232e+00, -1.26620675e+00],
       [-2.01638533e+00, -5.46127191e-02],
       [-1.05960006e+00,  5.27872540e-01],
       [-9.92382986e-01,  5.67340134e-01],
       [ 6.75747315e-01, -1.73234232e+00],
       [ 1.44628421e-02,  5.08113517e-01],
       [ 4.87751986e-01, -1.04233685e-01],
       [-4.56448260e-01, -9.98659237e-01],
       [-1.75859615e+00,  1.72952256e-01],
       [ 8.89235717e-01,  1.20044321e+00],
       [-4.34446940e-01,  5.68728598e-01],
       [-7.85308551e-01,  3.25522574e-01],
       [-1.59728222e-01,  6.88992413e-01],
       [ 9.05940084e-02, -8.02697053e-01],
       [-1.66201506e+00, -1.64854816e+00],
       [-3.

In [11]:
vae.encoder(data)[0]

tensor([[[[-8.5719e-05,  1.4682e-02]]],


        [[[-7.5200e-03,  2.2104e-03]]],


        [[[ 9.1834e-03,  3.0187e-02]]],


        [[[-3.6969e-03, -1.9089e-03]]],


        [[[ 2.5406e-03,  8.8944e-03]]],


        [[[ 4.1707e-04,  1.1453e-02]]],


        [[[ 9.7147e-03,  2.2670e-02]]],


        [[[-6.5098e-03, -8.3263e-03]]],


        [[[ 1.0492e-02,  3.9531e-03]]],


        [[[ 2.8953e-03,  9.8300e-03]]],


        [[[-5.3638e-04,  1.4362e-02]]],


        [[[-1.4781e-02, -6.9339e-03]]],


        [[[ 1.0903e-02,  8.8270e-03]]],


        [[[-1.5220e-02, -2.0072e-02]]],


        [[[ 1.4465e-02,  9.4692e-03]]],


        [[[ 6.8946e-03,  1.1395e-02]]],


        [[[ 6.1718e-03,  1.0560e-02]]],


        [[[ 3.4476e-03, -1.1343e-02]]],


        [[[-9.8971e-03, -4.3321e-03]]],


        [[[ 2.0688e-03,  2.6793e-03]]],


        [[[ 1.0467e-03, -2.2670e-03]]],


        [[[ 2.7473e-03,  3.7903e-03]]],


        [[[-1.4962e-03,  2.0669e-02]]],


        [[[-3.4516e-03, -8.8762e-0

In [12]:
z=vae.encoder(data)
z[0].mean()

tensor(6.2959e-05, grad_fn=<MeanBackward0>)

In [13]:
latents.mean(),latents.std()

(-0.04911039556263489, 0.997719187773147)