In [None]:
# Automatically reload edited imported modules before each cell executes,
# and render matplotlib figures inline below each cell.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from itertools import product

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML

import torch

import keras
from sklearn import gaussian_process

L = keras.layers

# Project-local modules live in ../code; it must be on sys.path before they are imported.
sys.path.append('../code')

import toy_data as toy
from vae import VAE
from utils import TqdmProgressCallback

# --- configuration ---------------------------------------------------------
SEED = 0
N = 2500              # number of random poses / rendered images per experiment
batch_size = 32       # batch size passed to toy.make_batch_generator
img_shape = (28, 28)  # rendered image size (height, width)

# Seed the global NumPy RNG (also used by sklearn when random_state=None) and
# torch, so that Restart & Run All reproduces the same figures.
# NOTE(review): Keras weight initialisation is not covered by these seeds.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# (row, col) coordinates of a 3x3 subplot grid, enumerated row by row.
to_ind = np.array([(row, col) for row in range(3) for col in range(3)])

In [None]:
def plot_mean_variance_imgs(bone_lengths=None):
    """Plot the per-pixel mean and variance of randomly posed toy images.

    Draws ``N`` joint-angle pairs uniformly from [-pi, pi), renders them through
    ``toy.make_batch_generator`` in batches of ``batch_size`` and shows the
    pixel-wise mean (left) and variance (right) of channel 0 of the batches.

    Parameters
    ----------
    bone_lengths : sequence of two floats, optional
        Bone lengths passed to ``toy.make_batch_generator``. Defaults to the
        notebook globals ``[a, b]``, which must already be defined.

    Raises
    ------
    ValueError
        If the batch generator yields no images.
    """
    if bone_lengths is None:
        bone_lengths = [a, b]
    labels = 2*np.pi*(np.random.rand(N, 2)-0.5)

    total = 0.0
    total_sq = 0.0
    n_imgs = 0
    # NOTE(review): assumes the generator is finite (one pass over `labels`).
    for batch_img, _ in toy.make_batch_generator(labels, bone_lengths, batch_size):
        # Cast to float so squaring cannot overflow integer pixel dtypes.
        channel = np.asarray(batch_img[:, 0], dtype=float)
        total = total + channel.sum(axis=0)
        total_sq = total_sq + (channel**2).sum(axis=0)
        n_imgs += channel.shape[0]
    if n_imgs == 0:
        raise ValueError("batch generator yielded no images")

    # Normalise by the number of images, not the number of batches.
    mean = total / n_imgs
    # Var[x] = E[x^2] - E[x]^2; clip tiny negative values from round-off.
    var = np.maximum(total_sq / n_imgs - mean**2, 0.0)

    fig, ax = plt.subplots(ncols=2)
    ax[0].imshow(mean)
    ax[0].set_title("per-pixel mean")
    ax[1].imshow(var)
    ax[1].set_title("per-pixel variance")
# Call once bone lengths `a, b` exist (defined further below), e.g.:
# plot_mean_variance_imgs([a, b])

# Hierarchical images without time dependency
## PCA

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Independent across time: two bone lengths in [0.5, 1) and N joint-angle
# pairs drawn uniformly from [-pi, pi).
a, b = 0.5 * (1 + np.random.rand(2))
print("Bone lengths:", a, b)
labels = (np.random.rand(N, 2) - 0.5) * (2 * np.pi)

In [None]:
# Render all N poses as one batch. The generator also returns labels, which
# replace `labels` for every later cell.
batch_generator = toy.make_batch_generator(labels, [a, b], N)
imgs, labels = next(batch_generator)
imgs.shape

In [None]:
# One row per image, one column per pixel: the design matrix for PCA.
# The row width comes from `img_shape` instead of a hard-coded 28**2.
X = np.reshape(imgs, (N, int(np.prod(img_shape))))
X.shape

In [None]:
pca = PCA().fit(X)

# Cumulative fraction of variance explained by the first k components, k = 1..D.
vexpl = np.cumsum(pca.explained_variance_ratio_)
n_components_axis = np.arange(1, len(vexpl) + 1)
thresh = 0.95

# Smallest number of components whose cumulative explained variance reaches
# `thresh`. flatnonzero yields 0-based indices, so add 1 to get a count.
# int() keeps it a plain Python int for use as a layer width later.
reached = np.flatnonzero(vexpl >= thresh)
n_comp = int(reached[0]) + 1 if reached.size else len(vexpl)

fig, ax = plt.subplots()
ax.plot(n_components_axis, vexpl)
ax.axvline(n_comp, linestyle='dotted')
ax.axhline(thresh, linestyle='dotted')
ax.set_xlabel('number of components')
ax.set_ylabel('cumulative explained variance ratio')
ax.set_title(f'{n_comp} components explain {thresh} of the variance')
plt.show()

In [None]:
# Show the first nine principal components as 28x28 images, row by row.
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
for panel, k in zip(ax.flat, range(9)):
    panel.imshow(pca.components_[k].reshape((28, 28)))
plt.tight_layout()

In [None]:
# Sanity check: render one fixed pose with three unit-length bones.
test_img = toy.keypoint_to_image(
    toy.forward([-np.pi, -np.pi / 2, -np.pi / 2], [1, 1, 1])
)
plt.imshow(test_img)

## PCA by autoencoder

In [None]:
def build_pca_autoencoder(img_shape, code_size):
    """
    Build a linear (PCA-like) autoencoder as separate encoder/decoder models.

    The encoder flattens an image and maps it to ``code_size`` values with a
    single Dense layer (no activation). The decoder maps a code back to
    ``prod(img_shape)`` values with a single Dense layer and reshapes the result
    to ``img_shape``.

    Parameters
    ----------
    img_shape : tuple of int
        Shape of one input image, e.g. ``(28, 28)``.
    code_size : int
        Width of the bottleneck. Integer-like values such as NumPy integers are
        accepted.

    Returns
    -------
    (encoder, decoder) : tuple of keras.models.Sequential
    """
    # Cast to plain ints: np.prod / np.where return NumPy integers, which some
    # Keras versions reject as Dense `units`.
    img_shape = tuple(int(d) for d in img_shape)
    code_size = int(code_size)
    n_pixels = int(np.prod(img_shape))

    encoder = keras.models.Sequential()
    encoder.add(L.InputLayer(img_shape))
    encoder.add(L.Flatten())            # image -> vector
    encoder.add(L.Dense(code_size))     # linear encoder

    decoder = keras.models.Sequential()
    decoder.add(L.InputLayer((code_size,)))
    decoder.add(L.Dense(n_pixels))      # linear decoder, one unit per pixel
    decoder.add(L.Reshape(img_shape))   # vector -> image

    return encoder, decoder

In [None]:
# Linear autoencoder whose bottleneck matches the PCA component count above.
encoder, decoder = build_pca_autoencoder(img_shape, n_comp)
inp = L.Input(img_shape)
code = encoder(inp)
reconstruction = decoder(code)

autoencoder = keras.models.Model(inputs=inp, outputs=reconstruction)
autoencoder.compile(optimizer='adamax', loss='mse')

# Reconstruct channel 0 of each image (input == target).
# No held-out set yet; pass validation_data=(X_test, X_test) once one exists.
history = autoencoder.fit(x=imgs[:, 0], y=imgs[:, 0], epochs=15,
                          callbacks=[TqdmProgressCallback()],
                          verbose=0)

In [None]:
# Visualise learned encoder filters: column k of the kernel W holds the input
# weights of code unit k, reshaped back to an image. The scalar bias is left
# out because imshow rescales each panel, so a constant offset is invisible.
W, bias = encoder.get_weights()
n_units = W.shape[1]

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
# Draw distinct units (without replacement) so no filter appears twice.
units = np.random.choice(n_units, size=min(ax.size, n_units), replace=False)
for panel, unit in zip(ax.flat, units):
    panel.imshow(np.reshape(W[:, unit], img_shape))
    panel.set_title(f'unit {unit}')
for panel in ax.flat[len(units):]:
    panel.axis('off')
plt.show()

## VAE

In [None]:
# Train the VAE on batches rendered from `labels` with bone lengths `a, b`.
# NOTE: `labels` at this point is the array returned by the single-batch
# rendering cell, not the one drawn directly from np.random.
batch_generator = toy.make_batch_generator(labels, [a, b], batch_size)
device = torch.device("cpu")
bottleneck = 10  # latent dimension: prior samples below are drawn with this width
model = VAE(bottleneck=bottleneck).to(device)
train_loss = model.fit(batch_generator, max_iter=200)

fig, ax = plt.subplots()
ax.plot(train_loss)
ax.set_xlabel('iteration')  # presumably one entry per fit iteration — TODO confirm
ax.set_ylabel('loss')
ax.set_title('VAE training loss');

In [None]:
# Decode nine draws from a standard-normal latent and show them as a 3x3 grid.
with torch.no_grad():
    sample = model.decode(torch.randn(9, bottleneck).to(device)).cpu()

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
fig.set_size_inches((10, 10))
for k, img in enumerate(np.array(sample)):
    ax.flat[k].imshow(img.reshape(28, 28))
plt.tight_layout()
plt.show()

# Hierarchical image date with time dependencies

In [None]:
# Introduce time dependency: each of the three label columns is an independent
# sample from an (unfitted, i.e. prior) GP with an RBF kernel over time t.
rbf = gaussian_process.kernels.RBF(length_scale=2)
GP = gaussian_process.GaussianProcessRegressor(kernel=rbf)

t = np.linspace(0, 120, N)
# A single call with n_samples=3 builds and factorises the N x N prior
# covariance once instead of three times; the result has shape (N, 3).
labels = GP.sample_y(t[:, None], n_samples=3, random_state=None)

fig, ax = plt.subplots()
ax.plot(t, labels)
ax.set_xlabel('t')
ax.set_ylabel('label (joint angle)')
ax.set_title('GP-sampled labels over time')
ax.legend(['label 0', 'label 1', 'label 2']);

In [None]:
# Fresh bone lengths in [0.5, 1) (this overwrites the earlier `a, b`), then
# render one image per time step of `labels`.
a, b, c = 0.5 * (1 + np.random.rand(3))
imgs = [toy.keypoint_to_image(toy.forward(label, [a, b, c])) for label in labels]

In [None]:
# Animate the rendered sequence, showing every `frame_step`-th image.
# Requires ffmpeg to be available for to_html5_video().
frame_step = 2
fig, ax = plt.subplots()
img = imgs[0]
mimg = ax.imshow(img)

def init():
    """Reset the displayed frame to the first image (used for blitting)."""
    mimg.set_data(img)
    return (mimg,)

def animate(i):
    """Display image number ``frame_step * i``."""
    mimg.set_data(imgs[frame_step * i])
    return (mimg,)

anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=len(imgs) // frame_step,
                               interval=25,  # ms between frames
                               blit=True)
# Close the figure so the inline backend does not also render a static frame
# beneath the video.
plt.close(fig)

HTML(anim.to_html5_video())