In [9]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
#from tensorflow.examples.tutorials.mnist import input_data

import gym

In [12]:
def prepro(I):
    """ prepro 210x160x3 uint8 frame into 6400 (80x80) 1D float vector """
    I = I[35:195] # crop
    I = I[::2,::2,0] # downsample by factor of 2
    I[I == 144] = 0 # erase background (background type 1)
    I[I == 109] = 0 # erase background (background type 2)
    #I[I != 0] = 1 # everything else (paddles, ball) just set to 1
    #I = I[:-1,:,0]
    return I.astype(np.float)

In [13]:
## Get data to use for training
env = gym.make("Pong-v0")
observation = env.reset()
t = 300
obs = []
for i in range(t):
    observation,_,_,_ = env.step(env.action_space.sample())
    obs.append(prepro(observation))

In [72]:
X_dim = prepro(observation).size
Z_dim = 100
h_dim = 128
c = 0
lr = 1e-4

In [73]:
X_dim

6400

In [74]:
mb_size=1

In [75]:
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)

In [76]:
Wxh = xavier_init(size=[X_dim, h_dim])
#bxh = Variable(torch.zeros(h_dim), requires_grad=True)

In [77]:
Whz_mu = xavier_init(size=[h_dim, Z_dim])
#bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)

Whz_var = xavier_init(size=[h_dim, Z_dim])
#bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)

In [78]:
def Q(X):
    #h = nn.relu(X.matmul(Wxh) + bxh.repeat(X.size(0), 1))
    h = nn.relu(X.matmul(Wxh))
    #print(h.shape)
    #z_mu = h.matmul(Whz_mu) + bhz_mu.repeat(h.size(0), 1)
    #z_var = h.matmul(Whz_var) + bhz_var.repeat(h.size(0), 1)
    z_mu = h.matmul(Whz_mu)
    z_var = h.matmul(Whz_var)
    return z_mu, z_var

In [79]:
def sample_z(mu, log_var):
    # Using reparameterization trick to sample from a gaussian
    eps = Variable(torch.randn(mb_size, Z_dim))
    #print(eps.shape)
    return mu + torch.exp(log_var / 2) * eps

In [80]:
Wzh = xavier_init(size=[Z_dim, h_dim])
#bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
#bhx = Variable(torch.zeros(X_dim), requires_grad=True)

In [81]:
def P(z):
    #print(z.shape)
    #h = nn.relu(z.matmul(Wzh) + bzh.repeat(z.size(0), 1))
    h = nn.relu(z.matmul(Wzh))
    #print(h.shape)
    #X = torch.sigmoid(h.matmul(Whx) + bhx.repeat(h.size(0), 1))
    X = torch.sigmoid(h.matmul(Whx))
    #print(X.shape)
    return X

In [82]:
#params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var, Wzh, bzh, Whx, bhx]
params = [Wxh, Whz_mu, Whz_var, Wzh, Whx]

solver = optim.Adam(params, lr=lr)

In [83]:
for it in range(100000):
    #print("begin")
    X = obs[it%(len(obs)-1)]
    X = Variable(torch.from_numpy(X).float().view(-1))

    # Forward
    #print("forward")
    z_mu, z_var = Q(X)
    #print(z_mu.shape, z_var.shape)
    z = sample_z(z_mu, z_var)
    #print(z.shape)
    X_sample = P(z)


    # Loss
    #print("loss")
    recon_loss = nn.binary_cross_entropy(X_sample, X, size_average=False)
    kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var)
    loss = recon_loss + kl_loss
    if it % 2 == 0:
        print(it, loss)

    # Backward
    #print("back")
    loss.backward()

    # Update
    #print("up")
    solver.step()

    # Housekeeping
    #print("house")
    for p in params:
        p.grad.data.zero_()

(0, tensor(32682552., grad_fn=<ThAddBackward>))
(2, tensor(1045214.8750, grad_fn=<ThAddBackward>))
(4, tensor(16366830., grad_fn=<ThAddBackward>))
(6, tensor(8841957., grad_fn=<ThAddBackward>))
(8, tensor(112620.6172, grad_fn=<ThAddBackward>))
(10, tensor(27842.1699, grad_fn=<ThAddBackward>))
(12, tensor(18083.3105, grad_fn=<ThAddBackward>))
(14, tensor(1017455.1250, grad_fn=<ThAddBackward>))
(16, tensor(1181759.3750, grad_fn=<ThAddBackward>))
(18, tensor(897115.5000, grad_fn=<ThAddBackward>))
(20, tensor(98178377337113435038547968., grad_fn=<ThAddBackward>))
(22, tensor(31062418628800771323968094208., grad_fn=<ThAddBackward>))
(24, tensor(580841382550343516160., grad_fn=<ThAddBackward>))
(26, tensor(5492958165419672720293448646656., grad_fn=<ThAddBackward>))
(28, tensor(128818253222848336886759424., grad_fn=<ThAddBackward>))
(30, tensor(102731679410547532211683328., grad_fn=<ThAddBackward>))
(32, tensor(99284510205475291136., grad_fn=<ThAddBackward>))
(34, tensor(147426635623740473344

(274, tensor(1185553580032., grad_fn=<ThAddBackward>))
(276, tensor(66931040256., grad_fn=<ThAddBackward>))
(278, tensor(3411168512., grad_fn=<ThAddBackward>))
(280, tensor(212433159323648., grad_fn=<ThAddBackward>))
(282, tensor(105286266257408., grad_fn=<ThAddBackward>))
(284, tensor(21807048640676705075200., grad_fn=<ThAddBackward>))
(286, tensor(6283699734434352726016., grad_fn=<ThAddBackward>))
(288, tensor(1299314193580687360., grad_fn=<ThAddBackward>))
(290, tensor(3728774406743822214103040., grad_fn=<ThAddBackward>))
(292, tensor(8824555770478592., grad_fn=<ThAddBackward>))
(294, tensor(1418393984., grad_fn=<ThAddBackward>))
(296, tensor(44894433280., grad_fn=<ThAddBackward>))
(298, tensor(11573479161394840435949568., grad_fn=<ThAddBackward>))
(300, tensor(1090992.6250, grad_fn=<ThAddBackward>))
(302, tensor(434258.6250, grad_fn=<ThAddBackward>))
(304, tensor(177655.6562, grad_fn=<ThAddBackward>))
(306, tensor(94244.7188, grad_fn=<ThAddBackward>))
(308, tensor(12089.5918, grad_

(554, tensor(2359598660976640., grad_fn=<ThAddBackward>))
(556, tensor(72449262069974302720., grad_fn=<ThAddBackward>))
(558, tensor(359584740144700194816., grad_fn=<ThAddBackward>))
(560, tensor(6108696263554617632733790208., grad_fn=<ThAddBackward>))
(562, tensor(208772870728615221369241600., grad_fn=<ThAddBackward>))
(564, tensor(3171457171456., grad_fn=<ThAddBackward>))
(566, tensor(10764572164096., grad_fn=<ThAddBackward>))
(568, tensor(153574555648., grad_fn=<ThAddBackward>))
(570, tensor(6620892160., grad_fn=<ThAddBackward>))
(572, tensor(94997472., grad_fn=<ThAddBackward>))
(574, tensor(1872995712., grad_fn=<ThAddBackward>))
(576, tensor(2007746560., grad_fn=<ThAddBackward>))
(578, tensor(175907749888., grad_fn=<ThAddBackward>))
(580, tensor(107712327188480., grad_fn=<ThAddBackward>))
(582, tensor(1384666049282048., grad_fn=<ThAddBackward>))
(584, tensor(1451678622675072188416., grad_fn=<ThAddBackward>))
(586, tensor(350817995456512., grad_fn=<ThAddBackward>))
(588, tensor(8707

(836, tensor(8697132049896067563520., grad_fn=<ThAddBackward>))
(838, tensor(548791056859136., grad_fn=<ThAddBackward>))
(840, tensor(14867446345361973575680., grad_fn=<ThAddBackward>))
(842, tensor(2252265280765952., grad_fn=<ThAddBackward>))
(844, tensor(5623514112., grad_fn=<ThAddBackward>))
(846, tensor(416905227665408., grad_fn=<ThAddBackward>))
(848, tensor(125726450365346852372480., grad_fn=<ThAddBackward>))
(850, tensor(9806093408366755840., grad_fn=<ThAddBackward>))
(852, tensor(972710215680., grad_fn=<ThAddBackward>))
(854, tensor(178037769371648., grad_fn=<ThAddBackward>))
(856, tensor(24273295763343303170326528., grad_fn=<ThAddBackward>))
(858, tensor(2491274501746023464960., grad_fn=<ThAddBackward>))
(860, tensor(1484825115932519038976., grad_fn=<ThAddBackward>))
(862, tensor(11936741720064., grad_fn=<ThAddBackward>))
(864, tensor(1310520149803008., grad_fn=<ThAddBackward>))
(866, tensor(956228409362556649472., grad_fn=<ThAddBackward>))
(868, tensor(1964964577280., grad_fn

(1114, tensor(26135414964224., grad_fn=<ThAddBackward>))
(1116, tensor(61347370893312., grad_fn=<ThAddBackward>))
(1118, tensor(1148939900323786391552., grad_fn=<ThAddBackward>))
(1120, tensor(78125899710464., grad_fn=<ThAddBackward>))
(1122, tensor(8309085184., grad_fn=<ThAddBackward>))
(1124, tensor(77080420352., grad_fn=<ThAddBackward>))
(1126, tensor(29595048., grad_fn=<ThAddBackward>))
(1128, tensor(1136444288., grad_fn=<ThAddBackward>))
(1130, tensor(66212566401024., grad_fn=<ThAddBackward>))
(1132, tensor(74016210794053632., grad_fn=<ThAddBackward>))
(1134, tensor(1451677496775165345792., grad_fn=<ThAddBackward>))
(1136, tensor(735911119322940964864., grad_fn=<ThAddBackward>))
(1138, tensor(1608207036241953177993216., grad_fn=<ThAddBackward>))
(1140, tensor(4200267420660315319173120., grad_fn=<ThAddBackward>))
(1142, tensor(23149225705472., grad_fn=<ThAddBackward>))
(1144, tensor(19023590850560., grad_fn=<ThAddBackward>))
(1146, tensor(31795951173632., grad_fn=<ThAddBackward>))


(1388, tensor(20372726397009920., grad_fn=<ThAddBackward>))
(1390, tensor(208484069099438080., grad_fn=<ThAddBackward>))
(1392, tensor(125896588266863329280., grad_fn=<ThAddBackward>))
(1394, tensor(1505405042688., grad_fn=<ThAddBackward>))
(1396, tensor(385449295872., grad_fn=<ThAddBackward>))
(1398, tensor(8738405613568., grad_fn=<ThAddBackward>))
(1400, tensor(36100206755840., grad_fn=<ThAddBackward>))
(1402, tensor(2631343768506640367616., grad_fn=<ThAddBackward>))
(1404, tensor(441017663488., grad_fn=<ThAddBackward>))
(1406, tensor(11811632401001153036288., grad_fn=<ThAddBackward>))
(1408, tensor(29222889708087263887360., grad_fn=<ThAddBackward>))
(1410, tensor(419495414731868952667881472., grad_fn=<ThAddBackward>))
(1412, tensor(8800341605506470641664., grad_fn=<ThAddBackward>))
(1414, tensor(25867028950679552., grad_fn=<ThAddBackward>))
(1416, tensor(656006929448960., grad_fn=<ThAddBackward>))
(1418, tensor(253433331712., grad_fn=<ThAddBackward>))
(1420, tensor(1449525116928., g

KeyboardInterrupt: 

In [None]:
X_sample.shape