In [1]:
import copy
import numpy as np

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import torch.optim as optim
import os

In [2]:
%cd datasets
!bash get_datasets.sh
%cd ..

/home/acauligi/cs_231n/cs231n-project/datasets
/home/acauligi/cs_231n/cs231n-project


In [3]:
from datasets.data_utils import get_CIFAR10_data
data = get_CIFAR10_data(num_training=100, num_validation=10, num_test=100)
data['X_train'][:1].shape

(1, 3, 32, 32)

In [4]:
from src.configs import CifarEncoder, CifarDecoder, CifarTransition
from src.e2c import E2C

In [5]:
dim_in = data['X_train'][0].shape
dim_z = 6
dim_u = 0

enc = CifarEncoder(dim_in, dim_z)
enc

CifarEncoder(
  (conv_layers): ModuleList(
    (0): Conv2d(3, 8, kernel_size=(2, 2), stride=(2, 2), padding=(2, 2))
    (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2))
  )
  (pool_layers): ModuleList(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (ff_layers): ModuleList(
    (0): Linear(in_features=72, out_features=32, bias=True)
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): Linear(in_features=32, out_features=12, bias=True)
  )
  (conv_activation): ReLU()
  (ff_activation): ReLU()
)

In [6]:
inp = torch.from_numpy(data['X_train'][1:4]).float()
mu, var = enc(inp)
mu

tensor([[0.0000, 0.0000, 3.1496, 0.0000, 2.9314, 0.0000],
        [0.0000, 0.0000, 2.0170, 0.0000, 0.8094, 0.0000],
        [0.0000, 0.0000, 0.6289, 0.0000, 0.5156, 0.0000]],
       grad_fn=<SplitBackward>)

In [7]:
dec =  CifarDecoder(dim_z, dim_in)
dec

CifarDecoder(
  (ff_layers): ModuleList(
    (0): Linear(in_features=6, out_features=32, bias=True)
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): Linear(in_features=32, out_features=968, bias=True)
  )
  (conv_layers): ModuleList(
    (0): ConvTranspose2d(8, 8, kernel_size=(2, 2), stride=(2, 2), padding=(2, 2))
    (1): ConvTranspose2d(8, 3, kernel_size=(2, 2), stride=(2, 2), padding=(2, 2))
  )
  (ff_activation): ReLU()
  (conv_activation): ReLU()
)

In [8]:
trans = CifarTransition(dim_z, dim_u)
trans

CifarTransition(
  (trans): Sequential(
    (0): Linear(in_features=6, out_features=100, bias=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=100, out_features=100, bias=True)
    (4): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=100, out_features=12, bias=True)
  )
  (fc_B): Linear(in_features=6, out_features=0, bias=True)
  (fc_o): Linear(in_features=6, out_features=6, bias=True)
)

In [9]:
model = E2C(dim_in, dim_z, dim_u, config='cifar')
inp = torch.from_numpy(data['X_train'][:2]).float()

mean, logvar = model.encode(inp)
mean

tensor([[0.0000, 0.0811, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.7130, 0.2189, 0.0000, 0.0000]],
       grad_fn=<SplitBackward>)

In [10]:
model.decode(mean)

tensor([[[[0.0000, 0.0127, 0.0000,  ..., 0.0193, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0007, 0.0000, 0.0016],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0018, 0.0000,  ..., 0.0038, 0.0000, 0.0058],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.2045, 0.1869, 0.2040,  ..., 0.2089, 0.2115, 0.1947],
          [0.1792, 0.1384, 0.1828,  ..., 0.1189, 0.1782, 0.1160],
          [0.2170, 0.2146, 0.2209,  ..., 0

# Training loop

In [11]:
# criterion = nn.MSELoss()
# optimizer = optim.Adam(e2c.parameters(), lr=1e-4)

# # training parameters
# TRAINING_ITERATIONS = int(100)
# BATCH_SIZE = 32
# CHECKPOINT_AFTER = int(10)
# SAVEPOINT_AFTER = int(20)
# TEST_BATCH_SIZE = int(200)

# rand_idx = list(np.arange(0, X_train.shape[0]-1))
# indices = [rand_idx[ii * BATCH_SIZE:(ii + 1) * BATCH_SIZE] \
#     for ii in range((len(rand_idx) + BATCH_SIZE - 1)     // BATCH_SIZE)]

# for epoch in range(TRAINING_ITERATIONS):
#     for ii, idx in enumerate(indices):
#         optimizer.zero_grad()

#         next_pre_rec = model(x, action, x_next)
#         loss_rec, loss_trans = model.compute_loss(\
#             model.x_dec, model.x_next_pred_dec, \
#             x, x_next, \
#             model.Qz, model.Qz_next_pred, model.Qz_nex, mse=False)
        
#         ll = latent_loss(vae.z_mean, vae.z_sigma)
#         loss = criterion(dec, inputs) + ll
#         loss.backward()
#         optimizer.step()

In [12]:
# from IPython import display
# import matplotlib.pyplot as plt

# %matplotlib inline

# import gym
# from gym import wrappers

# env = gym.make('CartPole-v0')
# # env = wrappers.Monitor(env, "./gym-results", force=True)
# env.reset()

# # plt.figure(figsize=(9,9))
# # img = plt.imshow(env.render(mode='rgb_array')) # only call this once

# for _ in range(10):
# #     img.set_data(env.render(mode='rgb_array')) # just update the data
# #     display.display(plt.gcf())
# #     display.clear_output(wait=True)

#     obs, reward, done, info = env.step(env.action_space.sample())
# #     env.render()
#     if done:
#         env.reset()

# env.close()