In [1]:
import pdb
import numpy as np
from scipy import stats

import h5py

import time
import random
import string
from datetime import datetime

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import torch.optim as optim
import os

from src.configs import CifarEncoder, CifarDecoder, CifarTransition
from src.e2c import E2C

### Method to create (3,32,32) tensor

In [2]:
def create_img(x, y, pos_bounds, radius=0.5, W=32):
  # Check if center of ball outside image frame
  if x < pos_bounds[0] or x > pos_bounds[1]:
    return None
  elif y < pos_bounds[0] or y > pos_bounds[1]:
    return None

  x_px = int(round(W * x / posbounds[1]))
  y_px = int(round(W * y / posbounds[1]))
  r_px = int(round(radius / pos_bounds[1] * W))

  # Check if perimeter of ball outside image frame
  if x_px+r_px > W or x_px-r_px < 0:
    return None
  elif y_px+r_px > W or y_px-r_px < 0:
    return None

  img = np.ones((3,W,W))
  xx,yy = np.mgrid[:W, :W]
  circle = (xx-x_px)**2 + (yy-y_px)**2
  img[:, circle < r_px**2] = 0.

  return img

### PWA single integrator kinematics

In [3]:
def step(x0, Aks, add_noise=False):
  if x0[1] >= 0.5*posbounds[0]:
    Ak = Aks[0]
  else:
    Ak = Aks[1]
  update = Ak @ x0
  if add_noise:
    mn = np.array([0.1, 0.1])
    cov = np.diag([0.05, 0.05])
    frzn = stats.multivariate_normal(mn, cov)
    update += frzn.rvs(1)
  return update

### Generate training data

In [4]:
n = 4 
dhs = [0.05, 0.1]

posbounds = np.array([0,4]) # 4x4m square
velmax = 0.10

Aks = []
for dh in dhs:
  Ak = np.eye(n)
  Ak[0:int(n/2), int(n/2):] = dh * np.eye(int(n/2))
  Aks.append(Ak)

np.random.seed(12)

W = 32
NUM_DATA = 50

X = np.zeros((NUM_DATA,3,W,W))
X_next = np.zeros((NUM_DATA,3,W,W))

count = 1 
while count < NUM_DATA:
  x0 = np.hstack((posbounds[1] * np.random.rand(2), velmax*np.random.rand(2)))

  img = create_img(x0[0], x0[1], posbounds)
  if img is None:
    continue

  x0_new = step(x0, Aks)
  img_new = create_img(x0_new[0], x0_new[1], posbounds)
  if img_new is None:
    continue

  X[count,:,:,:] = img
  X_next[count,:,:,:] = img_new

  count += 1

In [5]:
dim_in = X[0].shape
dim_z = 6
dim_u = 0

enc = CifarEncoder(dim_in, dim_z)
enc

CifarEncoder(
  (conv_layers): ModuleList(
    (0): Conv2d(3, 8, kernel_size=(2, 2), stride=(2, 2), padding=(2, 2))
    (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2))
  )
  (pool_layers): ModuleList(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (ff_layers): ModuleList(
    (0): Linear(in_features=72, out_features=32, bias=True)
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): Linear(in_features=32, out_features=12, bias=True)
  )
  (conv_activation): ReLU()
  (ff_activation): ReLU()
)

In [6]:
inp = torch.from_numpy(X[1:4]).float()
mu, var = enc(inp)
print(mu.shape)
mu

torch.Size([3, 6])


tensor([[0.0000, 0.1456, 0.1175, 0.0000, 0.0655, 0.0000],
        [0.0000, 0.1387, 0.1220, 0.0000, 0.0597, 0.0000],
        [0.0000, 0.1405, 0.1278, 0.0000, 0.0596, 0.0000]],
       grad_fn=<SplitBackward>)

In [7]:
dec =  CifarDecoder(dim_z, dim_in)
dec

CifarDecoder(
  (ff_layers): ModuleList(
    (0): Linear(in_features=6, out_features=32, bias=True)
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): Linear(in_features=32, out_features=968, bias=True)
  )
  (conv_layers): ModuleList(
    (0): ConvTranspose2d(8, 8, kernel_size=(2, 2), stride=(2, 2), padding=(2, 2))
    (1): ConvTranspose2d(8, 3, kernel_size=(2, 2), stride=(2, 2), padding=(2, 2))
  )
  (ff_activation): ReLU()
  (conv_activation): ReLU()
)

In [8]:
trans = CifarTransition(dim_z, dim_u)
trans

CifarTransition(
  (trans): Sequential(
    (0): Linear(in_features=6, out_features=100, bias=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=100, out_features=100, bias=True)
    (4): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=100, out_features=12, bias=True)
  )
  (fc_B): Linear(in_features=6, out_features=0, bias=True)
  (fc_o): Linear(in_features=6, out_features=6, bias=True)
)

In [9]:
model = E2C(dim_in, dim_z, dim_u, config='cifar')
inp = torch.from_numpy(X[:2]).float()

mean, logvar = model.encode(inp)
mean

tensor([[0.0000, 0.0000, 0.1161, 0.0796, 0.0416, 0.0000],
        [0.0000, 0.0000, 0.1184, 0.0800, 0.0352, 0.0000]],
       grad_fn=<SplitBackward>)

In [10]:
model.decode(mean)

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.1148, 0.1170, 0.1049,  ..., 0.1222, 0.1034, 0.1141],
          [0.0736, 0.1106, 0.0905,  ..., 0.1169, 0.0890, 0.0913],
          [0.1110, 0.1217, 0.1030,  ..., 0

## Training loop

In [61]:
before = torch.from_numpy(X[:1]).float()
after = torch.from_numpy(X_next[:1]).float()

ctrl = torch.from_numpy(np.empty((NUM_DATA,0))).float()
ctrl = torch.tensor([])
ctrl = torch.empty(NUM_DATA, dim_u)

model.forward(before, ctrl, after)

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 100])

In [11]:
# criterion = nn.MSELoss()
# optimizer = optim.Adam(e2c.parameters(), lr=1e-4)

# # training parameters
# TRAINING_ITERATIONS = int(100)
# BATCH_SIZE = 32
# CHECKPOINT_AFTER = int(10)
# SAVEPOINT_AFTER = int(20)
# TEST_BATCH_SIZE = int(200)

# rand_idx = list(np.arange(0, X_train.shape[0]-1))
# indices = [rand_idx[ii * BATCH_SIZE:(ii + 1) * BATCH_SIZE] \
#     for ii in range((len(rand_idx) + BATCH_SIZE - 1)     // BATCH_SIZE)]

# for epoch in range(TRAINING_ITERATIONS):
#     for ii, idx in enumerate(indices):
#         optimizer.zero_grad()

#         next_pre_rec = model(x, action, x_next)
#         loss_rec, loss_trans = model.compute_loss(\
#             model.x_dec, model.x_next_pred_dec, \
#             x, x_next, \
#             model.Qz, model.Qz_next_pred, model.Qz_nex, mse=False)
        
#         ll = latent_loss(vae.z_mean, vae.z_sigma)
#         loss = criterion(dec, inputs) + ll
#         loss.backward()
#         optimizer.step()

In [12]:
# from IPython import display
# import matplotlib.pyplot as plt

# %matplotlib inline

# import gym
# from gym import wrappers

# env = gym.make('CartPole-v0')
# # env = wrappers.Monitor(env, "./gym-results", force=True)
# env.reset()

# # plt.figure(figsize=(9,9))
# # img = plt.imshow(env.render(mode='rgb_array')) # only call this once

# for _ in range(10):
# #     img.set_data(env.render(mode='rgb_array')) # just update the data
# #     display.display(plt.gcf())
# #     display.clear_output(wait=True)

#     obs, reward, done, info = env.step(env.action_space.sample())
# #     env.render()
#     if done:
#         env.reset()

# env.close()