In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pdb
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

import h5py

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import torch.optim as optim
import os

import cvxpy as cp

from mie2c.e2c import E2C, compute_loss, PWATransition, train_vae
from mie2c.losses import SigmoidAnneal
from mie2c.cartpole_model import (get_cartpole_encoder, get_cartpole_decoder,
    get_cartpole_transition, get_cartpole_linear_transition, get_cartpole_pwa_transition)

# Expose only GPU 0 to CUDA-backed libraries started from this process.
# NOTE(review): torch is imported before this line; presumably the setting still
# applies because nothing above touches the GPU yet - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

### Before running the cells below, generate the dataset: run `python generate_cartpole_data.py N --torque_control`, where `N` is the number of desired samples

In [None]:
# Read the cartpole arrays written by the data-generation script.
X0 = np.load('data/cartpole_X0.npy')
X_all = np.load('data/cartpole_X.npy')
U_all = np.load('data/cartpole_U.npy')
X_next_all = np.load('data/cartpole_X_next.npy')


def _as_float64(array):
    """Wrap a numpy array in a float64 torch tensor."""
    return torch.tensor(array, dtype=torch.float64)


# The first `test_indx` rows are held out for testing; the rest are training data.
# (A ratio-based split was used previously:
#  test_ratio = .01; test_indx = int(test_ratio * X_all.shape[0]))
test_indx = 100

# NOTE(review): only X0 is divided by 255 here; the other image tensors keep the
# value range stored on disk - confirm this asymmetry is intended.
X0 = _as_float64(X0) / 255.

# Held-out split.
X_test = _as_float64(X_all[:test_indx, :, :, :])
U_test = _as_float64(U_all[:test_indx, :])
X_next_test = _as_float64(X_next_all[:test_indx, :, :, :])

# Training split.
X = _as_float64(X_all[test_indx:, :, :, :])
U = _as_float64(U_all[test_indx:, :])
X_next = _as_float64(X_next_all[test_indx:, :, :, :])

NUM_TEST = X_test.size(0)
NUM_DATA = X.size(0)

# (image, control, next image) triples over the training split.
dataset = torch.utils.data.TensorDataset(X, U, X_next)

In [None]:
# convert to grayscale
gray_w = [0.3, 0.59, 0.11]
def to_gray(Y):
    """Collapse a 4-D image batch to a single luminance channel.

    The channel axis is axis 1. The first three channels are combined with
    the weights in ``gray_w``; any channels beyond the third are ignored.

    A batch that already has exactly one channel is returned unchanged, so
    applying the function twice (e.g. re-running a notebook cell) is a no-op.
    Without this guard the channel slices ``1:2`` and ``2:3`` would be empty
    and broadcasting would silently produce a zero-channel tensor.

    Args:
        Y: tensor indexed as ``Y[n, channel, :, :]``.

    Returns:
        Tensor with the same leading/trailing sizes as ``Y`` and one channel.

    Raises:
        ValueError: if ``Y`` has two channels (neither gray nor RGB).
    """
    num_channels = Y.shape[1]
    if num_channels == 1:
        # Already grayscale.
        return Y
    if num_channels < 3:
        raise ValueError(
            'to_gray expects 1 or at least 3 channels, got {}'.format(num_channels))
    return gray_w[0] * Y[:, 0:1, :, :] + gray_w[1] * Y[:, 1:2, :, :] + gray_w[2] * Y[:, 2:3, :, :]
# Replace every image tensor by its grayscale version. The names are rebound,
# so this cell is meant to run once per kernel session; `dataset`, built
# earlier, still references the original colour tensors.
X, X_next, X_test, X_next_test = [
    to_gray(batch) for batch in (X, X_next, X_test, X_next_test)
]

In [None]:
# Per-image sizes (everything after the batch axis).
# NOTE(review): PyTorch image batches are usually laid out (N, C, H, W), so the
# W/H names may be swapped - they are only used as sizes below; confirm.
C, W, H = X.size()[1:]

In [None]:
def show_samples(X_samples, X_next_samples=None):
    """Plot a batch of frames in grayscale, one row per sample.

    Args:
        X_samples: tensor indexed as ``X_samples[k, :, :, :]``; each entry is
            cast to uint8 and squeezed before plotting.
        X_next_samples: optional tensor of the same layout. When given, each
            row shows the frame and its successor side by side.

    The subplot grid has one more column than is filled (3 columns with
    successors, 2 without), so the last column stays empty.
    """
    def _as_image(frame):
        # Move to host memory, truncate to uint8 and drop singleton axes.
        return frame.to('cpu').type(torch.uint8).detach().numpy().squeeze()

    with_next = X_next_samples is not None
    n_cols = 3 if with_next else 2
    n_rows = X_samples.shape[0]

    fig = plt.figure(figsize=(10,10))
    for row in range(n_rows):
        first_slot = row * n_cols + 1
        fig.add_subplot(n_rows, n_cols, first_slot)
        plt.imshow(_as_image(X_samples[row,:,:,:]), cmap='gray')
        if with_next:
            fig.add_subplot(n_rows, n_cols, first_slot + 1)
            plt.imshow(_as_image(X_next_samples[row,:,:,:]), cmap='gray')
    plt.show()

In [None]:
# Pick one training pair at random and display it. No seed is set, so the
# chosen sample changes from run to run.
sample_row = np.random.randint(NUM_DATA)
idx = [sample_row]  # list index keeps the leading batch axis
show_samples(X[idx], X_next[idx])

# Train PWA model

In [None]:
# --- Model dimensions and PWA structure ---
dim_z = 6                # latent state size
dim_u = 1                # control size
num_modes = 3            # number of affine pieces in the transition model
use_low_rank = False  # True if A = I + r*v^T
dim_in = X.shape[1:]     # per-image input shape

# --- Optimisation / bookkeeping settings passed to train_vae ---
use_cuda = False
num_epochs = 500
batch_size = 2
learning_rate = 5e-4
checkpoint_every = 1
savepoint_every = 1
use_l2 = False

# --- Annealing schedules ---
# Argument order follows the SigmoidAnneal calls below: dtype, low value, high
# value, centre step, steps from low to high. NOTE(review): the exact schedule
# semantics live in mie2c.losses - confirm there.
kl_lo, kl_up = 1e-3, 1.
kl_center_step, kl_steps_lo_to_up = 20, 10
temp_lo, temp_up = 1e-3, 100.
temp_center_step, temp_steps_lo_to_up = 100, 50

kl_lambda = SigmoidAnneal(torch.float32, kl_lo, kl_up, kl_center_step, kl_steps_lo_to_up)
temp_lambda = SigmoidAnneal(torch.float32, temp_lo, temp_up, temp_center_step, temp_steps_lo_to_up)

# Training state threaded through train_vae (it returns updated values).
writer = None
itr = 0

In [None]:
# Name passed to train_vae for this model; the disabled snippet at the bottom
# uses the same name to locate a saved state dict.
fn_pwa = 'model_pwa'

# Build the three E2C components. Construction order is kept as encoder ->
# transition -> decoder so parameter initialisation happens in the same sequence.
encoder = get_cartpole_encoder(dim_in, dim_z)
pwa_transition = get_cartpole_pwa_transition(
    num_modes, dim_z, dim_u, low_rank=use_low_rank)
decoder = get_cartpole_decoder(dim_z, dim_in)

e2c_parts = (encoder, pwa_transition, decoder)
model_pwa = E2C(*e2c_parts)

# Optional warm start from a saved state dict (disabled):
# if os.path.exists('pytorch/{}.pt'.format(fn_pwa)):
#     model_pwa.load_state_dict(torch.load('pytorch/{}.pt'.format(fn_pwa)))

In [None]:
# Collect the keyword arguments first so the training call stays readable.
train_kwargs = dict(
    verbose=True,
    use_cuda=use_cuda,
    num_epochs=num_epochs,
    batch_size=batch_size,
    checkpoint_every=checkpoint_every,
    savepoint_every=savepoint_every,
    learning_rate=learning_rate,
    kl_lambda=kl_lambda,
    temp_lambda=temp_lambda,
    use_l2=use_l2,
    writer=writer,
    itr=itr,
    device_id=0,
)

# train_vae returns the (writer, itr) pair, which is fed back in on the next
# call so that re-running this cell continues from the previous state.
writer, itr = train_vae(model_pwa, X, U, X_next, fn_pwa, **train_kwargs)

# Run MICP controller

In [None]:
# Pull the learned piecewise-affine model out of the network as plain numpy
# data for the cvxpy controller.
pwa_trans = model_pwa.trans
mode_classifier = pwa_trans.mode_classifier

num_modes = len(pwa_trans.As)
N = 6 # Horizon for controller

# Per-mode affine dynamics z' = A_k z + B_k u + o_k, and per-mode classifier
# score w_k . z + b_k.
Aks, Bks, oks = [], [], []
Ws, bs = [], []
for ii in range(num_modes):
    # Index with ii: each mode must get its own (A, B, o), not a copy of mode 0.
    Aks.append(pwa_trans.As[ii].detach().cpu().numpy())
    Bks.append(pwa_trans.Bs[ii].detach().cpu().numpy())
    oks.append(pwa_trans.os[ii].detach().cpu().numpy().flatten())
    Ws.append(mode_classifier.weight[ii].detach().cpu().numpy().flatten())
    # The offset comes from the classifier bias (not the weight row again).
    # Stored as a Python float so it adds cleanly to the scalar w . z.
    # NOTE(review): assumes mode_classifier has one bias entry per mode, as for
    # a linear layer - confirm against mie2c.e2c.PWATransition.
    bs.append(mode_classifier.bias[ii].detach().cpu().item())

In [None]:
# Big-M constant used to switch constraints off for inactive modes. It is the
# sum of the largest classifier weight and bias magnitudes, floored at 1e4.
M = np.maximum((model_pwa.trans.mode_classifier.weight.abs().max() + model_pwa.trans.mode_classifier.bias.abs().max()).detach().numpy(), 1e4)

# Decision variables: latent trajectory, controls, and one binary mode
# indicator per (mode, time step).
z = cp.Variable((dim_z, N))
u = cp.Variable((dim_u, N-1))
y = cp.Variable((num_modes, N-1), boolean=True)

# Problem parameters: initial and goal latent states, set before each solve.
z0 = cp.Parameter(dim_z)
zg = cp.Parameter(dim_z)

cons = []

# Initial condition
cons += [z[:,0] == z0]

# Dynamics constraints: exactly one mode is active per step, and the active
# mode's affine dynamics must hold with equality (two big-M inequalities).
for ii in range(N-1):
    cons += [cp.sum(y[:,ii]) == 1]
    for jj in range(num_modes):
        Ak, Bk, ok = Aks[jj], Bks[jj], oks[jj]
        residual = Ak @ z[:,ii] + Bk @ u[:,ii] + ok - z[:,ii+1]
        slack = M*(1 - y[jj,ii])  # 0 when mode jj is active at step ii
        cons += [residual <= slack]
        cons += [-residual <= slack]

# Piecewise affine constraints: if mode ii is active at step kk, its
# classifier score must be at least every other mode's score at z[:,kk].
# The relaxation is gated by y[ii,kk]; indexing y with a mode index in the
# time slot (as before) tied the constraint to the wrong indicator.
for kk in range(N-1):
    for ii in range(num_modes):
        w_ii, b_ii = Ws[ii], bs[ii]
        for jj in range(num_modes):
            if ii == jj:
                continue
            w_jj, b_jj = Ws[jj], bs[jj]
            cons += [w_jj @ z[:,kk] + b_jj - (w_ii @ z[:,kk] + b_ii) <= M*(1 - y[ii,kk])]

# Control constraints
force_max = 10.
cons += [cp.abs(u) <= force_max]

# Tracking cost: quadratic state error to the goal plus an R-weighted
# absolute-value penalty on the control.
# NOTE(review): a textbook LQR cost would use a quadratic control term;
# confirm the absolute-value penalty is intended.
lqr_cost = 0.
Q = np.eye(dim_z)
R = 0.1
for ii in range(1,N):
    lqr_cost += cp.quad_form(z[:,ii]-zg, Q)
for ii in range(N-1):
    lqr_cost += R*cp.abs(u[0,ii])

bin_prob = cp.Problem(cp.Minimize(lqr_cost), cons)
bin_prob_params = {'z0':z0, 'zg':zg}

In [None]:
def _latent_of(image):
    """Encode a single image and return the first encoder output as a flat numpy vector.

    NOTE(review): element [0] of encode()'s result is presumably the latent
    mean - confirm against mie2c.e2c.E2C.encode.
    """
    batch = image.unsqueeze(0)  # add the batch axis the encoder is called with
    return model_pwa.encode(batch)[0].detach().numpy().flatten()


# Start from the first held-out frame and steer towards the first X0 frame.
# NOTE(review): X0 was scaled by 1/255 at load time and is not passed through
# to_gray, unlike X_test - confirm both inputs match what the encoder expects.
bin_prob_params['z0'].value = _latent_of(X_test[0,:,:,:])
bin_prob_params['zg'].value = _latent_of(X0[0,:,:,:])

In [None]:
n_rollouts = 10

# Image strip: slot 0 is the start frame, slots 1..n_rollouts hold decoded
# predictions, and the last slot is the goal frame.
mpc_imgs = torch.zeros((n_rollouts+2,C,W,H))

mpc_imgs[0] = X_test[0,:,:,:]
mpc_imgs[-1] = X0[0,:,:,:]

# Solver outcomes that leave a usable solution in z.value. cvxpy has no
# 'feasible' status string, so the constants are used instead of literals.
usable_statuses = (cp.OPTIMAL, cp.OPTIMAL_INACCURATE)

for ii in range(n_rollouts):
    bin_prob.solve(solver=cp.GUROBI)
    print(bin_prob.value)
    if bin_prob.status not in usable_statuses:
        # z.value is not usable here; stop and say why.
        print('MICP solve stopped with status: {}'.format(bin_prob.status))
        break

    # First predicted latent state of the current plan.
    z_next = z.value[:,1]

    # Decode without tracking gradients so mpc_imgs stays a plain tensor.
    with torch.no_grad():
        img_out = model_pwa.decode(torch.tensor(z_next, dtype=torch.float).unsqueeze(0))
    mpc_imgs[ii+1] = img_out

    # Update initial condition for MICP: the next solve starts from the
    # predicted state (the model's own prediction, not a new observation).
    bin_prob_params['z0'].value = z_next

In [None]:
# Show the last slot of mpc_imgs (the goal frame). mpc_imgs has C channels,
# taken from the grayscale X, so slicing channels 3: would select nothing;
# plot the whole frame as a 2-D grayscale image instead. The uint8 cast is
# dropped so imshow scales float data by its own min/max - NOTE(review): the
# goal frame was scaled by 1/255 at load time, which a uint8 cast would
# presumably truncate to zeros; confirm.
plt.imshow(mpc_imgs[-1,:,:,:].to('cpu').detach().numpy().squeeze(), cmap='gray')