In [69]:
import os

import imageio
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from worldmodels.params import home
from worldmodels.params import vae_params
from worldmodels.vision.vae import VAE

matplotlib.rcParams.update({'font.size': 8})

def load_episode(seed, root=None, latent_dim=32):
    """Load every ``.npy`` array saved for one controller rollout.

    Reads ``<root>/<seed>/*.npy`` into a dict keyed by file stem
    (``latent.npy`` -> ``data['latent']``) and adds memory-prediction
    diagnostics computed from the ``latent`` and ``pred-latent`` arrays.

    Args:
        seed: rollout identifier; ``str(seed)`` is the episode directory name.
        root: directory containing the episode directories. Defaults to
            ``os.path.join(home, 'controller-samples')``.
        latent_dim: width each latent row is reshaped to before comparison.

    Returns:
        ``collections.defaultdict(list)`` with one entry per ``.npy`` file plus:
          * ``labels`` -- ``latent[1:]`` reshaped to ``(-1, latent_dim)``
          * ``preds``  -- ``pred-latent[:-1]`` reshaped the same way
          * ``error``  -- per-row mean absolute difference of the two
        Row t of ``preds`` is paired with row t + 1 of ``latent`` (presumably
        the memory's next-step prediction scored against the next latent).
        Keys with no file read as ``[]`` because of the defaultdict.

    Raises:
        AttributeError: if ``latent`` or ``pred-latent`` is missing (the
            defaultdict yields a list, which has no ``reshape``).
    """
    from collections import defaultdict

    if root is None:
        root = os.path.join(home, 'controller-samples')
    path = os.path.join(root, str(seed))

    data = defaultdict(list)
    for fi in os.listdir(path):
        stem, ext = os.path.splitext(fi)
        # Only .npy files are arrays. Anything else in the episode directory
        # ('.DS_Store', sub-directories such as 'gif/', stray exports) would
        # make np.load fail, so it is skipped.
        if ext == '.npy' and stem:
            data[stem] = np.load(os.path.join(path, fi))

    data['labels'] = data['latent'][1:].reshape(-1, latent_dim)
    data['preds'] = data['pred-latent'][:-1].reshape(-1, latent_dim)
    data['error'] = np.mean(np.abs(data['labels'] - data['preds']), axis=1)

    return data

In [None]:
# Render one dashboard frame per step of a single rollout, then stitch the
# frames into an animated GIF stored next to them.
seed = '2357136044'
max_steps = 1000  # upper bound on frames; shorter episodes render fewer

# Load once: load_episode reads every array from disk, so calling it per
# frame repeated that I/O up to max_steps times.
data = load_episode(seed)
n_steps = min(max_steps, len(data['observation']))

dpi = 120
base = (256, 128)
scale = 5

out_dir = os.path.join(home, 'controller-samples', str(seed), 'gif')
os.makedirs(out_dir, exist_ok=True)


def _plot_trace(ax, series, idx, title, **line_kw):
    """Plot a whole-episode ``series`` on ``ax`` and mark its value at ``idx``.

    The marker is skipped when ``idx`` is past the end of ``series``; e.g.
    'error' is built from latent[1:], so it is presumably one step shorter
    than the episode.
    """
    ax.plot(series, linewidth=0.3, **line_kw)
    if idx < len(series):
        ax.plot(idx, series[idx], marker='o', alpha=0.5)
    ax.set_title(title)
    return ax


image_files = []
for idx in range(n_steps):
    fig = plt.figure(constrained_layout=False,
                     figsize=(scale * base[0] / dpi, scale * base[1] / dpi), dpi=dpi)
    gs = fig.add_gridspec(ncols=3, nrows=3, width_ratios=[3, 3, 3],
                          height_ratios=[3, 1, 1])

    # Top row: raw observation and the two stored reconstructions.
    ax0 = fig.add_subplot(gs[0, 0])
    ax0.imshow(data['observation'][idx])
    ax0.set_title('observation')

    ax1 = fig.add_subplot(gs[0, 1])
    ax1.imshow(data['reconstruct'][idx].reshape(64, 64, 3))
    ax1.set_title('reconstruction')

    ax2 = fig.add_subplot(gs[0, 2])
    ax2.imshow(data['pred-reconstruct'][idx].reshape(64, 64, 3))
    ax2.set_title('pred-reconstruction')

    # Lower rows: whole-episode traces with the current step highlighted.
    ax31 = _plot_trace(fig.add_subplot(gs[1, 0]), data['vae-loss-reconstruct'],
                       idx, 'reconstruction loss')
    ax31.set_ylim(bottom=0)
    _plot_trace(fig.add_subplot(gs[2, 0]), data['vae-loss-unclipped-kl'],
                idx, 'kl loss')
    _plot_trace(fig.add_subplot(gs[1:, 1]), data['error'], idx, 'memory mae',
                color='red', label='memory mae')
    _plot_trace(fig.add_subplot(gs[1:, 2]), data['total-reward'], idx,
                'total-reward')

    fig.subplots_adjust(wspace=0.4, hspace=0.5)
    fig.suptitle('step {} - rew {:3.1f} - seed {}'.format(
        idx, data['total-reward'][-1], seed))

    f_name = os.path.join(out_dir, '{}.png'.format(idx))
    fig.savefig(f_name)
    # Close every frame; hundreds of open figures exhaust memory.
    plt.close(fig)
    image_files.append(imageio.imread(f_name))
    print(f_name)

# Written inside gif/ (not the episode directory) so the GIF never sits among
# the episode's .npy arrays.
anim_file = os.path.join(out_dir, 'rollout.gif')
print('saving to gif')
imageio.mimsave(anim_file, image_files, duration=0.2)

/Users/adam/world-models-experiments/controller-samples/2357136044/gif/0.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/1.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/2.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/3.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/4.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/5.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/6.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/7.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/8.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/9.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/10.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/11.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/12.



/Users/adam/world-models-experiments/controller-samples/2357136044/gif/20.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/21.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/22.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/23.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/24.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/25.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/26.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/27.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/28.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/29.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/30.png
/Users/adam/world-models-experiments/controller-samples/2357136044/gif/31.png
/Users/adam/world-models-experiments/controller-samples/23571360

In [None]:
# Load the observation and reconstruction arrays of one rollout.
root = os.path.join(home, 'controller-samples')
# Skip macOS '.DS_Store' metadata so only rollout entries are listed.
rollouts = [r for r in os.listdir(root) if 'DS_Store' not in r]
print(rollouts)

episode_dir = os.path.join(root, '2357136044')
obser = np.load(os.path.join(episode_dir, 'observation.npy'))
# np.squeeze drops singleton axes (presumably a per-step batch axis -- confirm)
# so reconstruct[i] can go straight to imshow.
reconstruct = np.squeeze(np.load(os.path.join(episode_dir, 'reconstruct.npy')))

In [None]:

# Construct the VAE with load_model switched on (presumably restoring saved
# weights -- see VAE). Note: this updates the shared vae_params dict in place.
vae_params.update(load_model=True)
vae = VAE(**vae_params)

In [None]:
# Compare observations with their stored reconstructions at evenly spaced
# steps, titling each reconstruction with the VAE losses for that frame.
n_panels = 4  # steps to show; two images per step, two steps per row -> keep even

dpi = 120
base = (256, 128)
scale = 5

samples = np.linspace(0, obser.shape[0] - 1, n_panels).astype(int)
f, axes = plt.subplots(len(samples) // 2, 4, squeeze=False,
                       figsize=(scale * base[0] / dpi, scale * base[1] / dpi), dpi=dpi)

# axes.reshape(-1, 2) walks the grid row-major as (true, reconstruction) pairs.
for sample, (ax_true, ax_recon) in zip(samples, axes.reshape(-1, 2)):
    obs = obser[sample]
    losses = vae.loss(np.expand_dims(obs, 0))

    ax_true.imshow(obs)
    ax_true.set_title('step {} \n true'.format(sample))

    ax_recon.imshow(reconstruct[sample])
    # Convert every loss with .numpy() before applying a float format spec.
    ax_recon.set_title('losses \n recon {:2.1f} kld {:2.1f}'.format(
        losses['reconstruction-loss'].numpy(), losses['unclipped-kl-loss'].numpy()[0]))

for ax in axes.flatten():
    ax.set_yticklabels([])
    ax.set_xticklabels([])

f.subplots_adjust(wspace=0.1, hspace=0.3)
plt.show()

## Distribution of episode rewards

In [None]:
# Distribution of the final total-reward value across all saved rollouts.
root = os.path.join(home, 'controller-samples')
# Only directories are rollouts; skips stray files such as '.DS_Store'.
rollouts = sorted(r for r in os.listdir(root)
                  if os.path.isdir(os.path.join(root, r)))

rews = []
for rollout in rollouts:
    total_reward = np.load(os.path.join(root, rollout, 'total-reward.npy'))
    # The last entry is taken as the episode's reward (assumes total-reward is
    # a running total -- confirm).
    rews.append(total_reward[-1])

print(np.mean(rews))

fig, ax = plt.subplots()
ax.hist(rews)
ax.set_xlabel('final total reward')
ax.set_ylabel('number of rollouts')
ax.set_title('{} rollouts'.format(len(rews)))
plt.show()

## Next observation vs. reconstructed prediction

In [None]:
# Pair each next observation with the decoded prediction stored one step
# earlier: obser[t + 1] is shown next to pred-reconstruct[t].
pred_reconstruct = np.load(os.path.join(root, '2357136044', 'pred-reconstruct.npy'))
# One 64x64x3 image per step, the same per-step shape the animation cell uses.
pred_reconstruct = pred_reconstruct.reshape(-1, 64, 64, 3)

# New names rather than re-slicing `obser` in place: the old in-place shift
# compounded on every re-run and broke the reconstruction-grid cell above.
next_obs = obser[1:]
pred_prev = pred_reconstruct[:-1]

dpi = 120
base = (256, 128)
scale = 5
samples = np.linspace(0, next_obs.shape[0] - 1, 4).astype(int)
f, axes = plt.subplots(len(samples) // 2, 4, squeeze=False,
                       figsize=(scale * base[0] / dpi, scale * base[1] / dpi), dpi=dpi)

# axes.reshape(-1, 2) walks the grid row-major as (true, prediction) pairs.
for sample, (ax_true, ax_pred) in zip(samples, axes.reshape(-1, 2)):
    ax_true.imshow(next_obs[sample])
    # next_obs[sample] is obser[sample + 1]; label it with its episode step.
    ax_true.set_title('step {} \n true'.format(sample + 1))
    ax_pred.imshow(pred_prev[sample])
    ax_pred.set_title(' \n reconstructed pred.')

for ax in axes.flatten():
    ax.set_yticklabels([])
    ax.set_xticklabels([])

plt.show()


## Sample the VAE decoder from Gaussian noise

In [None]:
# Decode 16 latent vectors drawn from a standard normal and show them in a 4x4 grid.
# A fixed seed keeps the sampled grid identical across re-runs.
rng = np.random.RandomState(0)
noise = rng.normal(0, 1, size=(16, 32))  # 32 = latent width used throughout

dpi = 120
base = (128, 128)
scale = 5
f, axes = plt.subplots(4, 4, figsize=(scale * base[0] / dpi, scale * base[1] / dpi), dpi=dpi)

decoded = vae.decode(noise)

for ax, dec in zip(axes.flatten(), decoded):
    ax.imshow(dec)

for ax in axes.flatten():
    ax.set_yticklabels([])
    ax.set_xticklabels([])

f.subplots_adjust(wspace=0.0, hspace=0.1)

In [None]:
from sklearn.manifold import TSNE

# 2-D t-SNE embedding of one rollout's per-step 'latent' array.
latent = np.load(os.path.join(root, '2357136044', 'latent.npy'))
# squeeze drops singleton axes so TSNE gets one row per step
# (assumes the file is (steps, 1, 32) or (steps, 32) -- confirm).
latent_rows = np.squeeze(latent)
# t-SNE is stochastic; a fixed random_state makes the plot reproducible.
embedded = TSNE(n_components=2, random_state=0).fit_transform(latent_rows)

fig, ax = plt.subplots()
points = ax.scatter(embedded[:, 0], embedded[:, 1], c=np.arange(len(embedded)))
fig.colorbar(points, ax=ax, label='row (step) index')
ax.set_title('t-SNE of latents, rollout 2357136044')
plt.show()

## ES learning curve (reward per generation)

In [None]:
from worldmodels.dataset.sample_policy import get_max_gen

# Stack the saved 'epoch-results.npy' of generations 0 .. ma-1 into one array.
ma = get_max_gen()
generations_dir = os.path.join(home, 'control', 'generations')
res = np.array([
    np.load(os.path.join(generations_dir, 'generation_{}'.format(g), 'epoch-results.npy'))
    for g in range(ma)
])

res.shape

In [None]:
# Reward per generation: mean with a +/- 1 standard-deviation band (the old
# band was a constant +/- 100 despite its "1 std" label), plus the maximum.
fig, ax = plt.subplots()

generations = np.arange(len(res))
mean_rew = np.mean(res, axis=1)
std_rew = np.std(res, axis=1)

ax.plot(generations, mean_rew, label='mean +/- 1 std')
ax.fill_between(generations, mean_rew - std_rew, mean_rew + std_rew, alpha=0.5)
ax.plot(generations, np.max(res, axis=1), label='maximum')
# Dashed gray so the 900 reference line is not confused with the maximum curve.
ax.axhline(900, color='gray', linestyle='--', label='900')

ax.set_xlabel('generation')
ax.set_ylabel('total episode reward')
ax.set_ylim((-200, 1000))
ax.legend()
plt.show()