In [3]:
import torch
import torch.nn as nn
from mpc import mpc
from mpc import util
import torch.nn.functional as F
import pdb

# Cost net (simple running cost and MLP terminal cost)
g = x^2 + u^2


F = xT * MLP(x) * x

In [4]:
class MPCCostNetwork(nn.Module):
    def __init__(self, n_state, n_ctrl):
        super().__init__()
        self.n_state = n_state
        self.n_ctrl = n_ctrl

        # 运行损失 xTQx + uTRu
        self.Q = torch.eye(n_state)
        self.R = torch.eye(n_ctrl)

        # 终端损失网络 xTFx
        self.terminal_net = nn.Sequential(
            nn.Linear(n_state, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_state * n_state)
        )
        
    def forward(self, tau, t=None, T=None):
        """
        Args:
            tau: [batch_size, n_state + n_ctrl], concatenated form of state and action
            t: Current timestamp
            T: Total prediction horizon
        Returns:
            cost: [batch_size] 
        """
        batch_size = tau.size(0)
        
        # 分离状态和控制
        # 如果t == T-1，ctrl将会是是零向量，表示不存在，所以不会index out of range
        state = tau[:, :self.n_state]
        ctrl = tau[:, self.n_state:]
        
        # 如果是最后一个时间步，只计算终端损失
        # if t is not None and T is not None and t == T-1:
        #     terminal_F = self.terminal_net(state).view(batch_size, self.n_state, self.n_state)
        #     state = state.unsqueeze(1)
        #     terminal_cost = torch.bmm(torch.bmm(state, terminal_F), state.transpose(1, 2))
        #     # pdb.set_trace()
        #     return terminal_cost.squeeze(-1).squeeze(-1) # Return shape: (batch_size, )
        # else:
        #     return self.running_cost(state, ctrl)
        return self.running_cost(state, ctrl)
    
    def running_cost(self, state, ctrl): 
        # 简单二次型运行代价
        batch_size = state.shape[0]
        state = state.unsqueeze(1) # (batch_size, 1, n_state)
        ctrl = ctrl.unsqueeze(1)
        Q = self.Q.repeat(batch_size, 1, 1) # (batch_size, n_state, n_state)
        R = self.R.repeat(batch_size, 1, 1)
        xTQx = torch.bmm(torch.bmm(state, Q), state.transpose(1, 2))
        uTRx = torch.bmm(torch.bmm(ctrl, R), ctrl.transpose(1, 2))
        return (xTQx + uTRx).squeeze(-1).squeeze(-1) # Return shape: (batch_size,)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

In [5]:
# Bolza to Lagrange, so that the cost function can fit the mpc api.
class MPCCostNetwork_Lagrange(nn.Module):
    def __init__(self, n_state, n_ctrl):
        super().__init__()
        self.n_state = n_state
        self.n_ctrl = n_ctrl

        # 运行损失 xTQx + uTRu
        self.Q = torch.eye(n_state)
        self.R = torch.eye(n_ctrl)

        # 终端损失网络 xTFx
        self.terminal_net = nn.Sequential(
            nn.Linear(n_state, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_state * n_state)
        )
        
    def forward(self, tau, last_tau=None):
        """
        Args:
            tau: [batch_size, n_state + n_ctrl], concatenated form of state and action
            t: Current timestamp
            T: Total prediction horizon
        Returns:
            cost: [batch_size] 
        """
        batch_size = tau.size(0)
        
        # 分离状态和控制
        # 如果t == T-1，ctrl将会是是零向量，表示不存在，所以不会index out of range
        state = tau[:, :self.n_state]
        ctrl = tau[:, self.n_state:]
        state_last = last_tau[:, :self.n_state]
        ctrl_last = last_tau[:, self.n_state:]
        
        return self.running_cost(state, ctrl) + self.terminal_bias(state, state_last)
    
    def running_cost(self, state, ctrl): 
        # 简单二次型运行代价
        batch_size = state.shape[0]
        state = state.unsqueeze(1) # (batch_size, 1, n_state)
        ctrl = ctrl.unsqueeze(1)
        Q = self.Q.repeat(batch_size, 1, 1) # (batch_size, n_state, n_state)
        R = self.R.repeat(batch_size, 1, 1)
        xTQx = torch.bmm(torch.bmm(state, Q), state.transpose(1, 2))
        uTRx = torch.bmm(torch.bmm(ctrl, R), ctrl.transpose(1, 2))
        return (xTQx + uTRx).squeeze(-1).squeeze(-1) # Return shape: (batch_size,)
    
    def terminal_bias(self, state, last_state):
        x_detached = state.detach()
        term_cost = self.terminal_net(x_detached)
        term_cost.backward()
        dhdx = x_detached.grad # Shape: (batch_size, n_state)
        dxdt = state - last_state # Shape: (batch_size, n_state)
        print(dhdx.unsqueeze(1).shape, dxdt.unsqueeze(-1).shape)
        # dh/dx * dx/dt
        term_bias = torch.bmm(dhdx.unsqueeze(1), (state - last_state).unsqueeze(-1)).squeeze(-1).squeeze(-1)
        return term_bias


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

# Sys Dyna (F = x + u)

In [6]:
class DynamicsF(nn.Module):
    def forward(self, state, action):
        # Dimension of state & action: (batch_size, C)
        assert state.shape[-1] == action.shape[-1]
        return state + action

# Modify the original MPC lib
To support terminal cost.

In [7]:
from torch.autograd import Variable
import sys

class myMPC(mpc.MPC):
    def approximate_cost(self, x, u, Cf, diff=True):
        with torch.enable_grad():
            tau = torch.cat((x, u), dim=2).data
            tau = Variable(tau, requires_grad=True)

            if self.slew_rate_penalty is not None:
                print("""
                MPC Error: Using a non-convex cost with a slew rate penalty is not yet implemented.
                The current implementation does not correctly do a line search.
                More details: https://github.com/locuslab/mpc.pytorch/issues/12
                """)
                sys.exit(-1)
            
            costs = list()
            hessians = list()
            grads = list()
            for t in range(self.T):
                tau_t = tau[t]
                cost = Cf(tau_t, t=t, T=self.T) #
                grad = torch.autograd.grad(cost.sum(), tau_t,
                                           create_graph=True, retain_graph=True)[0]
                hessian = list()
                for v_i in range(tau.shape[2]):
                    hessian.append(
                        torch.autograd.grad(grad[:, v_i].sum(), tau_t,
                                            retain_graph=True)[0]
                    )
                hessian = torch.stack(hessian, dim=-1)
                costs.append(cost)
                grads.append(grad - util.bmv(hessian, tau_t))
                hessians.append(hessian)
            
            costs = torch.stack(costs, dim=0)
            grads = torch.stack(grads, dim=0)
            hessians = torch.stack(hessians, dim=0)
            if not diff:
                return hessians.data, grads.data, costs.data
            return hessians, grads, costs

# Solve optimization problem use MPC

In [9]:
from mpc.mpc import QuadCost

LQR_ITER = 100
batch_size, T, mpc_T = 1, 5, 5
DTYPE = torch.float
nx, nu = 1, 1 # 用1, 1会出问题，最后得到的控制都是nan
n_sc = nx + nu
dynamics = DynamicsF()
cost = MPCCostNetwork_Lagrange(nx, nu)

torch.manual_seed(43)

init_state = torch.randn(batch_size, nx, dtype=DTYPE)
u_init = None
x_now = init_state

C = torch.randn(T*batch_size, n_sc, n_sc) # shape (T*n_batch, n_sc, n_sc)
C = torch.bmm(C, C.transpose(1, 2)).view(T, batch_size, n_sc, n_sc) # shape (T, n_batch, n_sc, n_sc) # 二次项损失
c = torch.randn(T, batch_size, n_sc) # 一次项损失
Qd_cost = QuadCost(C, c)

VN_list = []
x_list = []
u_list = []
for t in range(T):
    ctrl = LagrangeMPC(nx, nu, mpc_T, lqr_iter=LQR_ITER, verbose=1,
                exit_unconverged=False, eps=1e-2, n_batch=batch_size, backprop=False, u_init=u_init,
                grad_method=mpc.GradMethods.AUTO_DIFF)

    x_seq, u_seq, objs = ctrl(x_now, cost, dynamics)
    action = u_seq[0]

    VN_list.append(objs)
    x_list.append(x_now)
    u_list.append(action)

    x_now = dynamics(x_now, action)
    u_init = torch.cat((u_seq[1:], torch.zeros(1, batch_size, nu, dtype=DTYPE)), dim=0)
    # print(x_seq) # Shape: (mpc_T, batch_size, nx)
    # print(u_seq) # Shape: (mpc_T, batch_size, nu)
    # print(objs)  # Shape: (batch_size,)
    # Verbose > 0时，打印出来的log数据分别表示：
    # iLQR迭代次数、最优轨迹的平均代价、action更新的最大范数(表示action的变化量)、线搜索步长均值、QP子问题的总迭代次数

TypeError: 'NoneType' object is not subscriptable

# Verify the RDP inequality

In [None]:
print(VN_list)
print(x_list)
print(u_list)

In [None]:
def cal_RDP_ineq(VN_list, x_list, u_list, cost, alpha):
    RDP_ineq = []
    for t in range(T-1):
        tau = torch.cat((x_list[t], u_list[t]), dim=-1)
        RDP = VN_list[t] - (VN_list[t+1] + alpha * cost(tau, t, T))
        RDP_ineq.append(RDP)
    return RDP_ineq

RDP_ineq = cal_RDP_ineq(VN_list, x_list, u_list, cost, 1)
print(RDP_ineq)

In [None]:
def interior_point_loss(RDP_ineq):
    loss = 0
    for RDP in RDP_ineq:
        loss += -torch.log(RDP)
    return loss

loss = interior_point_loss(RDP_ineq)
print(loss)
loss.backward()

In [None]:
alpha = 1
multiplier = 1

In [None]:
for t in range(T-1):
    tau = torch.cat((x_list[t], u_list[t]), dim=-1)
    RDP_ineq = VN_list[t+1] + alpha * cost(tau, t, T) - VN_list[t]
    print(RDP_ineq)