In [None]:
# Reload imported modules automatically before each cell is executed,
# so edits to the local project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from itertools import product

import numpy as np
import matplotlib.pyplot as plt

from matplotlib import animation, rc
from IPython.display import HTML

import torch
from torch import Tensor
from torch import Tensor as T
from torch.utils.data import DataLoader, Subset
from torch import nn
import torch.optim
from torch.optim import SGD
import torch.nn.functional as F

import keras
L = keras.layers

import sys
sys.path.append('../code')
import toy_data as toy
from vae import VAE, loss_function, cVAE

from sklearn import gaussian_process

from tqdm_utils import TqdmProgressCallback
from sklearn.decomposition import PCA
from tqdm import tqdm_notebook as tqdm

def plot_reconstruction(recon, orig):
    """Show a reconstruction and its original image side by side.

    recon: image drawn in the left panel (anything `imshow` accepts).
    orig: image drawn in the right panel (anything `imshow` accepts).
    """
    _, (recon_ax, orig_ax) = plt.subplots(ncols=2)
    panels = (
        (recon_ax, recon, 'Reconstruction of an validation image.'),
        (orig_ax, orig, 'Original'),
    )
    for axis, picture, caption in panels:
        axis.imshow(picture)
        axis.set_title(caption)
    plt.show()

# Global configuration shared by the cells below.
N = 3600000  # number of observations
batch_size = 32
d = 64  # image edge length in pixels
img_shape = (d, d)
D = d * d  # number of pixels per image
print(f"{N} points with {D} dimensions.")

# (row, column) index pairs of a 3x3 subplot grid, in row-major order
to_ind = np.array([(row, col) for row in range(3) for col in range(3)])

# Run torch on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Train with {device}')

# Toy data: Images of hierarchical structures without time dependency

Generate (d, d)-pixel images from 7 parameters:
- 3 angles, drawn separately for each image,
- 3 bone lengths and 1 keypoint marker width, shared by the whole dataset.

The origin of the 3-bone hierarchy is the central pixel.
The end of each bone is marked by a squared exponential.

In [None]:
# Parameters shared by the whole dataset.
eps = np.random.rand(3)
bone_lengths = d//6 * (eps/2+1-1/3)
print("Bone lengths:", bone_lengths)
key_marker_width = 1.5 * d/32

# Per-image angles: uniform in [-pi/4, pi/4); the first angle is stretched to [-pi, pi).
labels = 1/2*np.pi*(np.random.rand(N, 3)-0.5)
labels[:, 0] = labels[:, 0] * 4

# generate training data
h = toy.HierarchyImages(angles=labels, bone_lengths=bone_lengths,
                        key_marker_width=key_marker_width,
                        img_shape=img_shape)

# data loader for easy batching
# NOTE: the loaders use a batch size of 64, not the `batch_size` constant defined above.
data_loader = DataLoader(h, batch_size=64, shuffle=True, num_workers=4)

# generate validation data from its own, independently drawn angles
labels_val = 1/2*np.pi*(np.random.rand(N, 3)-0.5)
labels_val[:, 0] = labels_val[:, 0] * 4
# Fix: build the validation set from labels_val (it was built from the training labels).
h_val = toy.HierarchyImages(angles=labels_val, bone_lengths=bone_lengths,
                            key_marker_width=key_marker_width,
                            img_shape=img_shape)

# Fix: the validation loader must iterate over h_val, not the training set h.
val_loader = DataLoader(h_val, batch_size=64, shuffle=False, num_workers=4)

# pick one random validation sample that is encoded and decoded by the models below
idx = np.random.randint(0, len(h_val))
test_img = h_val[idx]['image']
test_angles = h_val[idx]['angles']

In [None]:
# Show one randomly chosen training image.
h.plot_image(np.random.randint(len(labels)))

## PCA

In [None]:
# Use a random subset of the data because of memory restrictions.
# Passing the length directly avoids materialising range(len(h)) as an array first.
idxs = np.random.choice(len(h), replace=False, size=D)
imgs = np.array([h[i]['image'] for i in idxs])
X = np.reshape(imgs, (len(idxs), d**2))
pca = PCA().fit(X)

# Cumulative fraction of variance explained by the first k+1 components (index k).
vexpl = np.cumsum(pca.explained_variance_ratio_)
thresh = 0.95
# argmax yields the 0-based index of the first entry above the threshold;
# +1 converts that index into the number of components that are needed.
n_comp = int(np.argmax(vexpl > thresh)) + 1

plt.plot(vexpl)
# The curve is plotted against the 0-based component index, hence n_comp - 1.
plt.vlines(n_comp - 1, 0, 1, linestyles='dotted')
# Draw the threshold across the whole curve instead of a hard-coded width.
plt.hlines(thresh, 0, len(vexpl) - 1, linestyles='dotted')
plt.title(f'{n_comp} components explain {thresh} of the variance')
plt.show()

In [None]:
# Show the nine leading principal axes as images on a 3x3 grid.
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
ax[0, 1].set_title('First 9 principle components are fourier decomposition on a circle.')
for i, (x, y) in enumerate(to_ind):
    ax[x, y].imshow(pca.components_[i].reshape(img_shape))
plt.show()

In [None]:
# Project the validation image onto the first n_comp principal axes ...
w = pca.transform(test_img.reshape(1, -1))[:, :n_comp]

# ... and map the weights back to pixel space, adding the mean image again.
recon_img = (w @ pca.components_[:n_comp] + pca.mean_).reshape(img_shape)

plot_reconstruction(recon_img, test_img)

In [None]:
# generate image from noise
W = pca.transform(X)[:, :n_comp]

# Using uniform noise because it matches the
# statistics of w better than gauss

# per-component range of the weights observed on the data
min_w = W.min(axis=0)
span = W.max(axis=0) - min_w

n_examples = 6
eps = np.random.rand(n_examples, n_comp)
w_rand = span[None] * eps + min_w[None]

# Fix: the weights come from pca.transform(), which removes the mean image, so the
# mean has to be added back when mapping to pixel space (as done for the reconstruction).
generated = np.dot(w_rand, pca.components_[:n_comp]) + pca.mean_

fig, ax = plt.subplots(ncols=n_examples, sharey=True)
fig.set_size_inches(n_examples*3, 3)
ax[n_examples//2].set_title('Using weights sampled from \nuniform distribution to generate image.')
for i in range(n_examples):
    ax[i].imshow(generated[i].reshape(*img_shape))
plt.show()

## PCA by autoencoder

In [None]:
def build_pca_autoencoder(img_shape, code_size):
    """
    Build the two halves of a purely linear autoencoder.

    The encoder flattens an image of shape `img_shape` and maps it through a
    single dense layer to `code_size` values; the decoder maps a code through a
    single dense layer back to one value per pixel and restores `img_shape`.

    Returns the tuple (encoder, decoder).
    """
    encoder = keras.models.Sequential([
        L.InputLayer(img_shape),
        L.Flatten(),              # image -> vector
        L.Dense(code_size),       # the actual encoder
    ])

    decoder = keras.models.Sequential([
        L.InputLayer((code_size,)),
        L.Dense(np.prod(img_shape)),  # the actual decoder, one unit per pixel
        L.Reshape(img_shape),         # vector -> image
    ])

    return encoder, decoder

In [None]:
# Chain encoder and decoder into one trainable model with n_comp code units.
encoder, decoder = build_pca_autoencoder(img_shape, n_comp)

inp = L.Input(img_shape)
code = encoder(inp)
reconstruction = decoder(code)
autoencoder = keras.models.Model(inputs=inp, outputs=reconstruction)
autoencoder.compile(loss='mse', optimizer='adamax')

# Train the model to reproduce its own input; progress is reported by the tqdm callback.
autoencoder.fit(
    x=imgs,
    y=imgs,
    epochs=7,
    verbose=False,
    callbacks=[TqdmProgressCallback()],
)

In [None]:
# Kernel and bias of the encoder's dense layer.
W, bias = encoder.get_weights()

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
ax[0, 1].set_title('9 autoencoder components.')
for i in range(9):
    x, y = to_ind[i]
    # Fix: draw the code unit from the columns that actually exist. The index used
    # to come from a hard-coded range(100), which raises an IndexError for smaller
    # codes and never shows units beyond the first 100 for larger ones.
    idx = np.random.randint(0, W.shape[1])
    ax[x, y].imshow(np.reshape(W[:, idx] + bias[idx], img_shape))
plt.show()

```python
    s = keras.backend.get_session()
    s.run(hidden, {input_1: test_img})
    w = encoder(L.Input(test_img))
    recon_img = reconstruction(w)
    recon_img
```

-> I could not work out how to access the hidden layer with Keras (the attempt above does not run).
Therefore I switch to PyTorch.

# Linear Autoencoder: Pytorch

In [None]:
class LinearAutoencoder(nn.Module):
    """Autoencoder made of one linear encoder layer and one linear decoder layer.

    code_size: number of units in the bottleneck.
    input_dim: length of a flattened input sample. When omitted, the notebook-wide
        pixel count `D` is used, which is the previous behaviour.
    """

    def __init__(self, code_size=10, input_dim=None):
        super().__init__()
        # Resolve the input size once, so forward() does not depend on a global.
        self.input_dim = D if input_dim is None else input_dim
        self.encoder = nn.Linear(self.input_dim, code_size)
        self.decoder = nn.Linear(code_size, self.input_dim)

    def encode(self, x):
        """Map flattened inputs of shape (batch, input_dim) to codes."""
        return self.encoder(x)

    def decode(self, h):
        """Map codes of shape (batch, code_size) back to flattened outputs."""
        return self.decoder(h)

    def forward(self, x):
        """Flatten `x` to (batch, input_dim), encode it and decode it again."""
        w = self.encode(x.reshape(-1, self.input_dim))
        return self.decode(w)

def loss(recon_x, x):
    """Return the mean squared difference between `recon_x` and `x`."""
    squared_error = (recon_x - x) ** 2
    return torch.mean(squared_error)

In [None]:
# Train the linear autoencoder with as many code units as PCA components.
la = LinearAutoencoder(code_size=n_comp)
la.to(device)

optimizer = torch.optim.Adam(la.parameters(), lr=1e-3)

train_losses = []  # one python float per batch
for epoch in range(1):
    with tqdm(data_loader, leave=True) as pbar:
        for batch in pbar:
            data = batch['image'].float().to(device)
            flat = data.view(-1, D)
            # a single zero_grad per step is enough; the second call was redundant
            optimizer.zero_grad()
            recon = la(flat)
            l = F.mse_loss(recon, flat)
            l.backward()
            optimizer.step()
            # Fix: store a plain float. Appending the loss tensor kept one device
            # tensor alive per batch and produced a list that cannot be plotted.
            batch_mse = l.item()
            train_losses.append(batch_mse)
            pbar.set_description(f'Epoch {epoch}: MSE = {batch_mse:.5f}')

plt.plot(train_losses)
plt.xlabel('batch')
plt.ylabel('MSE')
plt.show()

In [None]:
# Pass the validation image through the trained linear autoencoder.
flat_input = Tensor(test_img).view(-1, D).to(device)
recon = la(flat_input).detach().cpu().numpy().reshape(img_shape)
plot_reconstruction(recon, test_img)

## VAE

In [None]:
# Fit a VAE with a 3-dimensional latent space for one epoch.
bottleneck = 3
model = VAE(input_dim=D, bottleneck=bottleneck, inter_dim=150)
model = model.to(device)
train_loss = model.fit(data_loader, epochs=1, verbose=True)

# Plot the recorded losses, leaving out the first 100 entries.
plt.plot(train_loss[100:])
plt.show()

In [None]:
# Decode nine latent vectors drawn from a standard normal distribution.
with torch.no_grad():
    latent = torch.randn(9, bottleneck).to(device)
    sample = model.decode(latent).cpu().numpy()

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
fig.set_size_inches((10, 10))
for (row, col), decoded in zip(to_ind, sample):
    ax[row, col].imshow(decoded.reshape(*img_shape))
plt.tight_layout()
plt.show()

In [None]:
# Encode and decode the validation image with the VAE.
# `mu` is reused by the latent-sweep animation further down.
test_tensor = Tensor(test_img).to(device)
recon, mu, logvar = model(test_tensor)
print(mu)

recon = recon.detach().cpu().numpy().reshape(img_shape)
plot_reconstruction(recon, test_img)

In [None]:
def get_img(ind):
    """Yield decoded images while sweeping a single latent dimension.

    Starts from a private copy of the global latent mean `mu`, sets latent
    dimension `ind` to each value of np.linspace(-4, 4) in turn and decodes the
    result with `model`.

    Yields tuples (x, gen): the latent value and the decoded image reshaped to
    `img_shape`.
    """
    # Fix: copy the array. On the CPU `.numpy()` shares memory with `mu`, so the
    # in-place write below used to overwrite the global `mu` and every later
    # sweep started from an already modified latent vector.
    mu_np = mu.detach().cpu().numpy().copy()
    # (The figure that used to be created here was never used and has been dropped.)
    for x in np.linspace(-4, 4):
        mu_np[:, ind] = x
        mu_star = Tensor(mu_np).to(device)
        # no_grad is limited to the decode call so it is not active across the yield
        with torch.no_grad():
            gen = model.decode(mu_star)
        gen = gen.cpu().numpy().reshape(img_shape)
        yield (x, gen)

# Figure and image artist that every animation frame draws into.
fig, ax = plt.subplots()
img = h[0]['image']

mimg = ax.imshow(img)
# Close the figure so the notebook does not show it as a static plot.
plt.close(fig)

def init():
    """Put the initial image `img` into the shared artist and return the artists to redraw."""
    mimg.set_data(img)
    artists = (mimg,)
    return artists

def animate(img):
    """Draw one frame of the latent sweep.

    img: an (x, image) pair as yielded by `get_img`; x is shown in the axes
        title and the image replaces the data of the shared artist.

    Returns the tuple of artists that were modified.
    """
    value, frame = img
    mimg.set_data(frame)
    ax.set_title(f'{value:.3}')
    return (mimg,)

# Render one HTML5 video per latent dimension.
vids = []
for i in range(mu.shape[1]):
    g = get_img(i)
    anim = animation.FuncAnimation(
        fig,
        animate,
        frames=g,
        init_func=init,
        interval=120,
        blit=True,
    )
    vids.append(anim.to_html5_video())
    plt.close()

In [None]:
# Display the video that sweeps latent dimension 0.
HTML(vids[0])

In [None]:
# Display the video that sweeps latent dimension 1.
HTML(vids[1])

In [None]:
# Display the video that sweeps latent dimension 2.
HTML(vids[2])

## Conditional VAE

In [None]:
def fit(model, data_loader, epochs=5, max_iter=200, verbose=True):
    """Train `model` on images concatenated with their angle labels.

    model: module whose forward pass returns (reconstruction, mu, logvar).
    data_loader: yields dict batches with the keys 'image' and 'angles'.
    epochs: number of passes over `data_loader`.
    max_iter: maximal number of batches processed per epoch.
    verbose: when true, the current loss is shown in the progress bar.

    Returns a list with one entry per processed batch: the batch loss divided
    by the number of samples in that batch.
    """
    # Keep stepping the optimizer attached to the model when there is one (this is
    # what the loop has always used); otherwise fall back to plain SGD. Previously
    # the SGD optimizer was created unconditionally and then never used.
    optimizer = getattr(model, 'optimizer', None)
    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    model.train()
    train_loss = []
    for epoch in range(epochs):
        with tqdm(data_loader) as pbar:
            for batch_idx, batch in enumerate(pbar, start=1):
                if batch_idx > max_iter:
                    print("Reached maximal number of iterations.")
                    break
                images = batch['image']
                # Flatten using the real batch length, so a shorter final batch
                # no longer breaks the reshape.
                images = images.reshape(images.shape[0], -1)
                angles = batch['angles']
                data = torch.cat((images, angles), dim=1).float().to(device)
                optimizer.zero_grad()
                recon_batch, mu, logvar = model(data)
                # `data` already has one row per sample, no further reshape needed
                loss = loss_function(recon_batch, data, mu, logvar)
                loss.backward()
                optimizer.step()
                sample_loss = loss.item() / len(data)
                train_loss.append(sample_loss)
                if verbose:
                    pbar.set_description(
                        f"Epoch {epoch}, Loss at batch {batch_idx:05d}: {sample_loss:.1f}"
                    )
    return train_loss

In [None]:
# Conditional VAE: the three angles are appended to the D pixel values.
bottleneck = 3
cvae = cVAE(input_dim=D+3, bottleneck=bottleneck, cond_data_len=3, hidden=150)
cvae = cvae.to(device)

hist = fit(cvae, data_loader, max_iter=100_000)

In [None]:
# One sample: flattened image followed by its three angles, with a leading batch axis.
test_np = np.concatenate((h[0]['image'].ravel(), h[0]['angles']))[None]

# Fix: move the input to the model's device and bring the output back to the CPU
# before converting to numpy; otherwise this cell fails when training ran on CUDA.
recon, mu, logvar = cvae(T(test_np).to(device))
recon_np = recon.detach().cpu().numpy()
image = recon_np[0, :D].reshape(d, d)
angles = recon_np[0, D:]
plot_reconstruction(image, h[0]['image'])
# Compare true and reconstructed angles, each normalised by its own sum.
a, b = h[0]['angles'], angles
a/a.sum(), b/b.sum()

# Hierarchical image 

In [None]:
# introduce time dependency: every angle follows a sample of a Gaussian process
rbf = gaussian_process.kernels.RBF(length_scale=2)
rbf_slow = gaussian_process.kernels.RBF(length_scale=5)
GP = gaussian_process.GaussianProcessRegressor(kernel=rbf)
GP_slow = gaussian_process.GaussianProcessRegressor(kernel=rbf_slow)

# NOTE: this overwrites the dataset size N defined at the top of the notebook.
N = 250
t = np.linspace(0, 120, N)
y = np.empty((N, len(bone_lengths)))
# first angle: longer length scale, larger amplitude
y[:, 0] = GP_slow.sample_y(t[:, None], random_state=None)[:, 0]*3
for i in range(1, len(bone_lengths)):
    y[:, i] = GP.sample_y(t[:, None], random_state=None)[:, 0]*0.7
labels = y  # same array as y, so the wrapping below also applies to labels
plt.plot(t, labels)  # drawn before wrapping, i.e. shows the raw GP samples

# angles can not escape [-np.pi, np.pi]
idx = abs(y) > np.pi
# Wrap with a modulo so values more than one full turn out of range are handled
# too; subtracting a single 2*pi only worked for |y| <= 3*pi.
y[idx] = np.mod(y[idx] + np.pi, 2*np.pi) - np.pi

In [None]:
# Render one image per time step of the angle trajectories.
h = toy.HierarchyImages(
    labels,
    bone_lengths,
    key_marker_width=key_marker_width,
    img_shape=img_shape,
)
imgs = [h[step]['image'] for step in range(len(labels))]

In [None]:
# Figure and image artist that the sequence animation draws into.
fig, ax = plt.subplots()
img = imgs[0]
mimg = ax.imshow(img)
# Close the figure so the notebook does not show it as a static plot.
plt.close(fig)

def init():
    """Put the first frame `img` into the shared artist and return the artists to redraw."""
    mimg.set_data(img)
    artists = (mimg,)
    return artists

def animate(i):
    """Show frame number `i` of `imgs` and return the modified artists."""
    frame = imgs[i]
    mimg.set_data(frame)
    return (mimg,)

# Animate the whole image sequence, one frame per entry of imgs.
anim = animation.FuncAnimation(
    fig,
    animate,
    frames=len(imgs),
    init_func=init,
    interval=60,
    blit=True,
)
# To write the animation to disk instead:
# anim.save(f'Toyproblem_unambiguous_{d}x{d}.mp4')
html_video = anim.to_html5_video()

In [None]:
# Display the rendered animation of the time-dependent image sequence.
HTML(html_video)