In [2]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
#from tensorflow.examples.tutorials.mnist import input_data

import gym

In [18]:
## Get data to use for training
# Drive Pong with uniformly random actions and keep every frame returned by
# env.step as training data.
env = gym.make("Pong-v0")
observation = env.reset()
t = 300  # number of environment steps (= frames) to record
obs = []
for i in range(t):
    observation, _, done, _ = env.step(env.action_space.sample())
    obs.append(observation)
    if done:
        # Stepping an environment after the episode has ended is undefined
        # in gym, so start a fresh episode before the next step.
        observation = env.reset()

In [23]:
# Sizes and optimizer settings shared by the cells below.
X_dim = observation.size  # total number of values in one frame
Z_dim, h_dim = 100, 128   # latent dimensionality, hidden-layer width
c = 0                     # NOTE(review): not referenced in the visible cells
lr = 1e-3                 # step size handed to the Adam solver

In [24]:
def xavier_init(size):
    """Create a trainable weight tensor with fan-in-scaled Gaussian entries.

    Each entry is a standard-normal draw multiplied by
    ``1 / sqrt(size[0] / 2)``, so the spread shrinks as the first
    dimension (the fan-in) grows.

    Args:
        size: sequence of dimensions; ``size[0]`` is taken as the fan-in.

    Returns:
        A Variable of shape ``size`` created with ``requires_grad=True``.
    """
    fan_in = size[0]
    scale = 1. / np.sqrt(fan_in / 2.)
    weights = torch.randn(*size) * scale
    return Variable(weights, requires_grad=True)

In [25]:
# Encoder input layer (read by Q): weight of shape (X_dim, h_dim) plus a
# zero-initialised bias of length h_dim. Both are trainable.
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

In [26]:
# Encoder output heads (read by Q). Both map the h_dim hidden layer to Z_dim
# values and start with zero biases.
# "mu" head: produces the first value Q returns (z_mu).
Whz_mu = xavier_init(size=[h_dim, Z_dim])
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)

# "var" head: produces the second value Q returns (z_var), which sample_z
# receives as its log_var argument.
Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)

In [29]:
def Q(X):
    """Encoder: map a batch of inputs to the two latent parameter tensors.

    Uses the module-level weights Wxh/bxh for the ReLU hidden layer and
    Whz_mu/bhz_mu, Whz_var/bhz_var for the two linear output heads. Each
    bias is tiled to one row per leading-dimension entry before being added.

    Args:
        X: input batch; ``X.size(0)`` is used as the number of rows.

    Returns:
        Tuple ``(z_mu, z_var)``. The training loop passes ``z_var`` to
        sample_z as its ``log_var`` argument.
    """
    hidden = nn.relu(torch.matmul(X, Wxh) + bxh.repeat(X.size(0), 1))
    n_rows = hidden.size(0)
    z_mu = torch.matmul(hidden, Whz_mu) + bhz_mu.repeat(n_rows, 1)
    z_var = torch.matmul(hidden, Whz_var) + bhz_var.repeat(n_rows, 1)
    return z_mu, z_var

In [31]:
def sample_z(mu, log_var):
    """Draw latent samples with the reparameterization trick.

    Computes ``mu + exp(log_var / 2) * eps`` where ``eps`` is standard-normal
    noise, so the sample stays differentiable with respect to ``mu`` and
    ``log_var``.

    Args:
        mu: tensor of means.
        log_var: tensor of log-variances, same shape as ``mu``.

    Returns:
        Tensor with the same shape as ``mu``.
    """
    # Take the noise shape from mu itself rather than from the globals
    # mb_size / Z_dim, so the function works for any batch size and does not
    # depend on a cell that is defined later in the notebook.
    eps = Variable(torch.randn(mu.size()))
    return mu + torch.exp(log_var / 2) * eps

In [32]:
# Decoder parameters (read by P), all trainable.
# First layer: latent code (Z_dim) -> hidden units (h_dim), zero bias.
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

# Second layer: hidden units (h_dim) -> reconstruction (X_dim), zero bias.
Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

In [33]:
def P(z):
    """Decoder: map latent codes to sigmoid-squashed reconstructions.

    Uses the module-level weights Wzh/bzh for the ReLU hidden layer and
    Whx/bhx for the output layer. Each bias is tiled to one row per
    leading-dimension entry before being added.

    Args:
        z: batch of latent codes; ``z.size(0)`` is used as the number of rows.

    Returns:
        Tensor holding the sigmoid of the output layer.
    """
    hidden = nn.relu(torch.matmul(z, Wzh) + bzh.repeat(z.size(0), 1))
    logits = torch.matmul(hidden, Whx) + bhx.repeat(hidden.size(0), 1)
    return nn.sigmoid(logits)

In [34]:
# Every trainable tensor, encoder first and decoder second, collected so a
# single Adam optimizer updates all of them.
params = [
    Wxh, bxh,
    Whz_mu, bhz_mu,
    Whz_var, bhz_var,
    Wzh, bzh,
    Whx, bhx,
]

solver = optim.Adam(params, lr=lr)

In [35]:
# np.random.choice only accepts a 1-D population; passing the list of 3-D
# frames directly raised "ValueError: a must be 1-dimensional". Draw a random
# index instead and look the frame up, showing only its shape to keep the
# cell output small.
random_frame = obs[np.random.choice(len(obs))]
random_frame.shape

In [38]:
# Sanity check: shape of one recorded frame. Index directly instead of going
# through the loop variable `i` left over from the data-collection cell.
obs[0].shape

(210, 160, 3)

In [42]:
# Minibatch size: number of rows in each training batch.
mb_size = 1

In [None]:
for it in range(100000):
    # Build a (mb_size, X_dim) minibatch that cycles through every recorded
    # frame. Index with the loop counter `it` and wrap at len(obs): the stale
    # `i` from the data cell would pick the same frame each iteration, and a
    # modulus of len(obs) - 1 could never select the last frame.
    frame_ids = [(it * mb_size + k) % len(obs) for k in range(mb_size)]
    X = np.stack([obs[j] for j in frame_ids]).reshape(mb_size, -1)
    # Raw Atari frames hold 8-bit pixel values (0-255); binary_cross_entropy
    # expects targets in [0, 1], so rescale. Keeping X two-dimensional matters
    # because Q tiles its biases X.size(0) times: a flat 1-D input would
    # create X_dim rows instead of mb_size.
    X = Variable(torch.from_numpy(X).float() / 255.)

    # Forward
    z_mu, z_var = Q(X)
    z = sample_z(z_mu, z_var)
    X_sample = P(z)

    # Loss
    # size_average=False sums over all elements (newer PyTorch spells this
    # reduction='sum').
    recon_loss = nn.binary_cross_entropy(X_sample, X, size_average=False)
    # KL divergence between N(z_mu, exp(z_var)) and the standard normal.
    kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var)
    loss = recon_loss + kl_loss

    # Backward
    # Clear gradients from the previous step through the optimizer instead of
    # zeroing each p.grad by hand afterwards.
    solver.zero_grad()
    loss.backward()

    # Update
    solver.step()