In [None]:
# Automatically reload edited modules (e.g. toy_data / vae from ../code) before executing code.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [None]:
from itertools import product

import numpy as np
import matplotlib.pyplot as plt

from matplotlib import animation, rc
from IPython.display import HTML

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch import nn
from torch.optim import SGD

import keras
L = keras.layers

import sys
sys.path.append('../code')
import toy_data as toy
from vae import VAE

from sklearn import gaussian_process

from tqdm_utils import TqdmProgressCallback
from sklearn.decomposition import PCA

# Some parameters:
N = 2500  # number of observations
batch_size = 32
d = 28  # image edge length
D = d**2  # number of pixels per image
img_shape = (d, d)

# Fix the numpy and torch RNGs so Restart & Run All reproduces the same data and
# PyTorch models. NOTE(review): Keras/TensorFlow weight initialisation is not
# covered by these seeds.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# useful for plotting on a 3x3 grid: row-major (row, col) position of each of the 9 panels
to_ind = np.array(list(product(range(3), range(3))))

# Toy data: Images of hierarchical structures without time dependency

Generate (d, d)-pixel images from 7 parameters:
- 3 joint angles per image,
- 3 bone lengths and 1 keypoint-marker width, shared by the whole dataset.

The origin of the 3-bone hierarchy is the central pixel.
The end of each bone is marked by a squared-exponential (Gaussian-shaped) blob.

In [None]:
# Toy-data parameters.
# Bone lengths are drawn uniformly from [(d//3)/2, d//3) pixels.
bone_lengths = d//3*(np.random.rand(3)+1)/2
print("Bone lengths:", bone_lengths)
# Marker width shared by all images (the value the dataset is built with).
key_marker_width = 2
# One triple of joint angles per image, uniform in [-pi, pi).
labels = 2*np.pi*(np.random.rand(N, 3)-0.5)
h = toy.HierarchyImages(angles=labels, bone_lengths=bone_lengths,
                        key_marker_width=key_marker_width)

# Alternative for possibly huge datasets: a batch generator. Only one batch is
# needed for this demo, and separate names keep `labels` (used below) intact.
batch_generator = toy.make_batch_generator(labels, bone_lengths, batch_size)
gen_imgs, gen_labels = next(batch_generator)
print('imgs.shape:', gen_imgs.shape)
plt.imshow(gen_imgs[0, 0])
plt.title('Example of toy data image with 3 bones.')
plt.show()

In [None]:
h.plot_image(np.random.randint(low=0, high=len(labels)))  # show one randomly chosen training image

In [None]:
# Materialise the dataset: fetch each sample of `h` once (h[i] is presumably
# computed on demand, so indexing it twice per sample doubles the work) and
# split it into image array and joint-angle array.
samples = [h[i] for i in range(len(labels))]
imgs = np.array([sample['image'] for sample in samples])
labels = np.array([sample['angles'] for sample in samples])
assert len(imgs) == len(labels) == N

## PCA

In [None]:
# Flatten each (d, d) image into a D-dimensional row vector and fit a full PCA.
X = np.reshape(imgs, (N, D))
pca = PCA().fit(X)

# vexpl[k-1] = fraction of variance explained by the first k components.
vexpl = np.cumsum(pca.explained_variance_ratio_)
n_components_axis = np.arange(1, len(vexpl) + 1)
thresh = 0.95
# Smallest number of components whose cumulative explained variance reaches
# `thresh`. searchsorted gives a 0-based index; +1 turns it into a count.
n_comp = int(min(np.searchsorted(vexpl, thresh) + 1, len(vexpl)))

fig, ax = plt.subplots()
ax.plot(n_components_axis, vexpl)
ax.vlines(n_comp, 0, 1, linestyles='dotted')
ax.hlines(thresh, 0, len(vexpl), linestyles='dotted')
ax.set_xlabel('number of components')
ax.set_ylabel('cumulative explained variance ratio')
ax.set_title(f'{n_comp} components explain {thresh} of the variance')
plt.show()

In [None]:
# Plot the first 9 principal components, reshaped back to image form (row-major grid).
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
fig.suptitle('First 9 principal components are fourier decomposition on a circle.')
for axis, comp in zip(ax.flat, pca.components_[:9]):
    axis.imshow(comp.reshape(img_shape))
plt.show()

In [None]:
# Generate a new image (not in the PCA training set), encode it with the first
# `n_comp` principal components and decode it again.
# The dataset's bone lengths are used so the image is comparable to the training
# images. NOTE(review): keypoint_to_image's default marker width may differ from
# the key_marker_width used for `h` — confirm.
coords = toy.forward([-np.pi, -np.pi/2, -np.pi/2], bone_lengths)
test_img = toy.keypoint_to_image(coords)

# PCA weights of the training data (reused by the next cell to sample images).
# `transform` reuses the fitted components; `fit_transform` would refit the PCA.
w = pca.transform(X)[:, :n_comp]

# Encode/decode the new image. PCA centres the data, so the mean image must be
# added back after projecting onto the kept components.
w_test = pca.transform(np.reshape(test_img, (1, -1)))[:, :n_comp]
recon_img = (w_test @ pca.components_[:n_comp] + pca.mean_).reshape(img_shape)

fig, (ax_orig, ax_recon) = plt.subplots(1, 2, sharex=True, sharey=True)
ax_orig.imshow(np.reshape(test_img, img_shape))
ax_orig.set_title('Validation image')
ax_recon.imshow(recon_img)
ax_recon.set_title('Reconstruction of an validation image.')
plt.show()

In [None]:
# generate image from noise
# Generate an image from noise: draw one Gaussian weight per kept component with
# the empirical variance of the training weights `w`, then decode. PCA centres
# the data, so the mean image must be added back.
w_rand = np.random.multivariate_normal(np.zeros(n_comp), np.diag(w.var(axis=0)))
generated_img = (w_rand @ pca.components_[:n_comp] + pca.mean_).reshape(img_shape)
plt.imshow(generated_img)
plt.title('Using weights sampled from normal distribution to generate image.')
plt.show()

## PCA by autoencoder

In [None]:
def build_pca_autoencoder(img_shape, code_size):
    """Build a simple linear autoencoder as separate encoder and decoder models.

    Both halves are a single Dense layer with Keras' default (linear) activation.
    The encoder flattens each image to a vector before projecting it to the code;
    the decoder projects back and un-flattens to the image shape.

    Args:
        img_shape: shape of one input image, e.g. (28, 28).
        code_size: number of units in the code (hidden) layer.

    Returns:
        (encoder, decoder): two ``keras.models.Sequential`` models. The encoder maps
        ``img_shape -> (code_size,)``; the decoder maps ``(code_size,) -> img_shape``.
    """
    # Plain Python ints: callers may pass numpy integers.
    code_size = int(code_size)
    n_pixels = int(np.prod(img_shape))

    encoder = keras.models.Sequential()
    encoder.add(L.InputLayer(img_shape))
    encoder.add(L.Flatten())           # flatten image to a vector of n_pixels values
    encoder.add(L.Dense(code_size))    # actual encoder: linear projection to the code

    decoder = keras.models.Sequential()
    decoder.add(L.InputLayer((code_size,)))
    decoder.add(L.Dense(n_pixels))     # actual decoder: height*width units
    decoder.add(L.Reshape(img_shape))  # un-flatten to the image shape

    return encoder, decoder

In [None]:
# Linear autoencoder whose code size equals the number of PCA components kept above.
encoder, decoder = build_pca_autoencoder(img_shape, n_comp)
inp = L.Input(img_shape)
code = encoder(inp)
reconstruction = decoder(code)

autoencoder = keras.models.Model(inputs=inp, outputs=reconstruction)
autoencoder.compile(optimizer='adamax', loss='mse')

# Inputs are also the targets: the model learns to reproduce its input.
history = autoencoder.fit(x=imgs, y=imgs, epochs=15, verbose=0)

In [None]:
# Encoder Dense weights: kernel W has one column per code unit; each column is
# reshaped into an image below.
W, bias = encoder.get_weights()

# Show up to 9 distinct, randomly chosen encoder filters. Indices are drawn from
# range(W.shape[1]) so they are always valid code units. (A per-unit bias would only
# shift all pixels by a constant, which imshow's autoscaling hides, so it is omitted.)
n_show = min(9, W.shape[1])
unit_ids = np.random.choice(W.shape[1], size=n_show, replace=False)

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
ax[0, 1].set_title('9 autoencoder components.')
for axis, unit in zip(ax.flat, unit_ids):
    axis.imshow(np.reshape(W[:, unit], img_shape))
plt.show()

In [None]:
# generate, encode, and decode new image for validation
# Generate a new image (bone lengths 10) for validating the autoencoders.
coords = toy.forward([-np.pi, -np.pi/2, -np.pi/2], [10, 10, 10])
test_img = toy.keypoint_to_image(coords)

# `encoder` and `decoder` are Keras models themselves, so the hidden-layer code of
# an image is obtained by calling `predict` on the encoder with a batch of one.
test_code = encoder.predict(np.reshape(test_img, (1,) + img_shape), verbose=0)
keras_recon = decoder.predict(test_code, verbose=0)[0]

fig, (ax_orig, ax_recon) = plt.subplots(1, 2, sharex=True, sharey=True)
ax_orig.imshow(test_img)
ax_orig.set_title('Validation image')
ax_recon.imshow(keras_recon)
ax_recon.set_title('Keras linear autoencoder reconstruction')
plt.show()

```python
    s = keras.backend.get_session()
    s.run(hidden, {input_1: test_img})
    w = encoder(L.Input(test_img))
    recon_img = reconstruction(w)
    recon_img
```

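-> At the time of writing I did not know how to access the hidden layer in Keras, 
so the linear autoencoder is re-implemented in PyTorch below. (Note: the `encoder` returned by `build_pca_autoencoder` is itself a Keras model, so `encoder.predict(test_img[None])` returns the hidden-layer code of an image and `decoder.predict(...)` decodes it.)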

# Linear Autoencoder: PyTorch

In [None]:
class LinearAutoencoder(nn.Module):
    """Single-layer linear autoencoder: Linear(input_size -> code_size) -> Linear(code_size -> input_size).

    Args:
        code_size: dimension of the code (hidden) layer.
        input_size: number of features per sample; ``forward`` flattens its input to
            (batch, input_size). Defaults to 784, i.e. a flattened 28x28 image.
    """

    def __init__(self, code_size=10, input_size=784):
        super().__init__()
        self.input_size = input_size
        self.encoder = nn.Linear(input_size, code_size)
        self.decoder = nn.Linear(code_size, input_size)

    def encode(self, x):
        """Map flattened inputs (batch, input_size) to codes (batch, code_size)."""
        return self.encoder(x)

    def decode(self, h):
        """Map codes (batch, code_size) back to flattened inputs (batch, input_size)."""
        return self.decoder(h)

    def forward(self, x):
        """Flatten ``x`` to (batch, input_size), encode and decode it.

        Returns the flattened reconstruction of shape (batch, input_size).
        """
        code = self.encode(x.view(-1, self.input_size))
        return self.decode(code)

def loss(recon_x, x):
    """Mean squared error between a reconstruction and its target, averaged over all elements."""
    squared_error = (recon_x - x).pow(2)
    return squared_error.mean()

In [None]:
# Train the PyTorch linear autoencoder on the toy images with plain SGD.
la = LinearAutoencoder()
data_loader = DataLoader(h, batch_size=64)
optimizer = SGD(la.parameters(), lr=1e-3)
train_losses = []  # one Python float per optimisation step
for epoch in range(100):
    for batch in data_loader:
        # Flatten once: both the model output and the loss use (batch, 784) vectors.
        data = batch['image'].float().view(-1, 784)
        optimizer.zero_grad()
        recon = la(data)
        batch_loss = loss(recon, data)
        batch_loss.backward()
        optimizer.step()
        # Store a float: appending the tensor would keep every step's autograd
        # graph alive and make plotting the list fail.
        train_losses.append(batch_loss.item())

In [None]:
# float() also accepts 0-dim tensors (even ones requiring grad), so this works
# whether the list holds floats or loss tensors.
fig, ax = plt.subplots()
ax.plot([float(value) for value in train_losses])
ax.set_xlabel('training step (batch)')
ax.set_ylabel('MSE loss')
ax.set_title('Linear autoencoder (PyTorch) training loss')
plt.show()

In [None]:
# Reconstruct the validation image with the trained PyTorch autoencoder.
# Inference only: no_grad avoids building an autograd graph, so .numpy() works
# without .detach().
with torch.no_grad():
    la_recon = la(Tensor(test_img)).view(img_shape)
plt.imshow(la_recon.numpy())
plt.title('PyTorch linear autoencoder reconstruction')
plt.show()

## VAE

In [None]:
# Stream batches of the (time-independent) toy data with the dataset's bone
# lengths. The original [a, b, c] is only defined in a later cell, which raised
# NameError on Restart & Run All.
batch_generator = toy.make_batch_generator(labels, bone_lengths, batch_size)
device = torch.device("cpu")
bottleneck = 10  # latent dimension of the VAE
model = VAE(bottleneck=bottleneck).to(device)
train_loss = model.fit(batch_generator, max_iter=200,
                       verbose=False)

plt.plot(train_loss)
plt.show()

In [None]:
# Decode 9 standard-normal latent vectors and show the generated images on a 3x3 grid.
with torch.no_grad():
    z = torch.randn(9, bottleneck).to(device)
    sample = model.decode(z).cpu()

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
fig.set_size_inches((10, 10))
for axis, img in zip(ax.flat, np.array(sample)):
    axis.imshow(img.reshape(28, 28))
plt.tight_layout()
plt.show()

# Hierarchical image data with time dependencies

In [None]:
# introduce time dependency
rbf = gaussian_process.kernels.RBF(length_scale=2)
GP = gaussian_process.GaussianProcessRegressor(kernel=rbf)

t = np.linspace(0, 120, N)
y = np.empty((N, 3))
y[:, 0] = GP.sample_y(t[:, None], random_state=None)[:, 0]
y[:, 1] = GP.sample_y(t[:, None], random_state=None)[:, 0]
y[:, 2] = GP.sample_y(t[:, None], random_state=None)[:, 0]
labels = y
plt.plot(t, labels)

In [None]:
# Bone lengths for the time-dependent dataset, uniform in [(20//3)/2, 20//3) pixels.
a, b, c = 20//3*(np.random.rand(3)+1)/2
# One image (with the origin marker drawn) per time step's joint angles.
imgs = [
    toy.keypoint_to_image(toy.forward(label, [a, b, c]), include_origin=True)
    for label in labels
]

In [None]:
FRAME_STEP = 2  # show every FRAME_STEP-th image, shortening the video accordingly

fig, ax = plt.subplots()
mimg = ax.imshow(imgs[0])
# Close the figure so the inline backend does not also render the static first frame.
plt.close(fig)

def init():
    """Reset the animation to the first image."""
    mimg.set_data(imgs[0])
    return (mimg,)

def animate(i):
    """Show image FRAME_STEP * i as frame i of the animation."""
    mimg.set_data(imgs[FRAME_STEP * i])
    return (mimg,)

# len(imgs) // FRAME_STEP frames keep FRAME_STEP * i within range for every frame.
anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=len(imgs) // FRAME_STEP, interval=25,
                               blit=True)
# To write the video to disk instead, call: anim.save('Toyproblem.mp4')
# to_html5_video needs a movie writer (ffmpeg by default) to be installed.
html_video = anim.to_html5_video()

In [None]:
HTML(data=html_video)  # render the HTML5 video of the time-dependent toy data