In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Function
from torch.autograd.functional import hessian, jacobian

import einops

import matplotlib.pyplot as plt

import casadi as ca

import numpy as np

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [35]:
def patchify(images, patch_size=5):
    """Splitting images into patches.
    Args:
        images: Input tensor with size (batch, channels, height, width)
    Returns:
        A batch of image patches with size (
          batch, (height / patch_size) * (width / patch_size), 
        channels * patch_size * patch_size)
    """
    return einops.rearrange(
        images,
        'b (h p1) (w p2) -> b (h w) (p1 p2)',
        p1=patch_size,
        p2=patch_size
    )

def unpatchify(patches, patch_size=5):
    """Combining patches into images.
    Args:
        patches: Input tensor with size (
        batch, (height / patch_size) * (width / patch_size), 
        channels * patch_size * patch_size)
    Returns:
        A batch of images with size (batch, channels, height, width)
    """
    return einops.rearrange(
        patches,
        'b (h w) (p1 p2) -> b (h p1) (w p2)',
        p1=patch_size,
        p2=patch_size,
        h=int(patches.shape[1] ** 0.5),
        w=int(patches.shape[1] ** 0.5),
    )


In [36]:
grids, commands = torch.load('data/robot_field_data.pt')

In [37]:
grid_patchified = patchify([grids[0]])
grid_patchified.shape

torch.Size([1, 400, 25])

## Define Transformer Encoder

In [38]:
class Transformer(nn.Module):
    """Transformer Encoder 
    Args:
        embedding_dim: dimension of embedding
        n_heads: number of attention heads
        n_layers: number of attention layers
        feedforward_dim: hidden dimension of MLP layer
    Returns:
        Transformer embedding of input
    """
    # TODO embedding_dim? -> size of Q,p ???
    def __init__(self, embedding_dim=256, n_heads=1, n_layers=3, feedforward_dim=64):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.feedforward_dim = feedforward_dim
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=self.n_heads,
                dim_feedforward=self.feedforward_dim,
                activation=F.gelu,
                batch_first=True,
                dropout=0.1,
            ),
            num_layers=n_layers,
        )

    def forward(self, x):
        return self.transformer(x)

### Define Optimization Problem

In [39]:
# Use it by calling the apply method:
x0 = torch.Tensor([0, 0, 0])
x0.requires_grad=True
u0 = torch.Tensor([0, 0, 0])
u0.requires_grad=True
u_des = torch.Tensor([0.4, 0, 0])
u_des.requires_grad=True


In [40]:
opt_N=20
opt_dt=0.1

opti = ca.Opti()

opt_x=opti.variable(3, opt_N+1)
opt_u=opti.variable(3, opt_N)

opt_x0 = opti.parameter(3)
opt_u0 = opti.parameter(3)

opt_u_des = opti.parameter(3)

opt_P = opti.parameter(6, 6)
opt_q = opti.parameter(6)

# stage cost
cost = 0
for i in range(opt_N):
  cost += 0.1*(ca.vertcat(opt_x[:, i], opt_u[:, i]).T @ opt_P.T @ opt_P @ ca.vertcat(opt_x[:, i], opt_u[:, i])
                   + opt_q.T @ ca.vertcat(opt_x[:, i], opt_u[:, i]))
  cost += 10*(opt_u_des - opt_u[:, i]).T @ (opt_u_des - opt_u[:, i])
  
opti.minimize(cost)

# system dynamics
for i in range(opt_N):
  opti.subject_to(opt_x[0, i+1] == opt_x[0, i] + opt_dt * (ca.cos(opt_x[2, i]) * opt_u[0, i] - ca.sin(opt_x[2, i]) * opt_u[1, i] + opt_u[2, i]) )
  opti.subject_to(opt_x[1, i+1] == opt_x[1, i] + opt_dt * (ca.sin(opt_x[2, i]) * opt_u[0, i] + ca.cos(opt_x[2, i]) * opt_u[1, i] + opt_u[2, i]))
  opti.subject_to(opt_x[2, i+1] == opt_x[2, i] + opt_dt * (opt_u[2, i]))
  
  # constraints on control rate
  # for i in range(N):
  #   opti.subject_to( u[:,i+1] - u[:,i] <= np.array([0.1, 0.1, 0.1]) )
  #   opti.subject_to( u[:,i+1] - u[:,i] >= -np.array([0.1, 0.1, 0.1]) )
  # opti.subject_to( u[:,0] - u0 <= np.array([0.1, 0.1, 0.1]) )
  # opti.subject_to( u[:,0] - u0 >= -np.array([0.1, 0.1, 0.1]) )
  
  # initial condition
  opti.subject_to(opt_x[:, 0] == opt_x0)
  
opti.solver('ipopt', {'ipopt.print_level':0, 'print_time':0})

In [365]:
class CFTOC(Function):
  @staticmethod
  def forward(ctx, embedding, x0, u0, u_des):     
    # split and reshape to get P and q for CFTOC cost function
    P, q = torch.split(embedding, [36, 6], dim=1)
    P = P.view(P.shape[0], 6, 6)
    batch_dim = P.shape[0]
    
    u_opt = []
    x_opt = []
    
    # loop through batch
    for i in range(batch_dim):
      opti.set_value(opt_x0, x0.detach().numpy())
      opti.set_value(opt_u0, u0.detach().numpy())
    
      opti.set_value(opt_u_des, u_des.detach().numpy())

      opti.set_value(opt_P, P[i].detach().numpy())
      opti.set_value(opt_q, q[i].detach().numpy())

      sol = opti.solve()

      u_opt.append(sol.value(opt_u))        
      x_opt.append(sol.value(opt_x))        
    
    # u_opt = torch.Tensor(u_opt).flatten(start_dim=1)
    
    u_opt = torch.Tensor(u_opt).T.flatten(end_dim=1).T
    
    
    # print(u_opt.shape)
    # x_opt = torch.Tensor(x_opt).flatten()
    
    # TODO make devectoring of embedding form P, q part of the cost function to make it
    # differentiable towards output
        
    ctx.save_for_backward(embedding, u_opt, u_des)
    
    return u_opt
  
  @staticmethod
  def backward(ctx, grad_output):
    '''Eq (11) from paper.'''
    
    # print('DEBUG Upstream gradient:', grad_output.shape)

    embedding, u_opt, u_des = ctx.saved_tensors

    # Ensure that the tensors have requires_grad set to True
    assert embedding.requires_grad, "Embedding tensor does not require grad"
    assert u_opt.requires_grad, "u_opt tensor does not require grad"

    batch_dim = embedding.shape[0]

    def cost(embedding, u, u_des):        
        P, q = torch.split(embedding, [36, 6], dim=0)
        P = P.reshape(6, 6)
        
        print(P, q)

        x = torch.zeros(3*(opt_N+1))
        x[0:3] = x0.T 








        i = 0
        x[3*(i+1)+0] = x[3*i+0] + opt_dt * (torch.cos(x0[2]) * u[3*i] - torch.sin(x0[2]) * u[3*i+1] + u[3*i+2])
        x[3+(i+1)+1] = x[3*i+1] + opt_dt * (torch.sin(x0[2]) * u[3*i] + torch.cos(x0[2]) * u[3*i+1] + u[3*i+2])
        x[3*(i+1)+2] = x0[2] + opt_dt * (u[3*i+2])  
        
        i = 1
        x[3*(i+1)+0] = x[3*i+0] + opt_dt * (torch.cos(x0[2] + opt_dt * (u[3*i+2])) * u[3*i] - torch.sin(x0[2] + u[i] + opt_dt * (u[3*i+2])) * u[3*i+1] + u[3*i+2])
        x[3+(i+1)+1] = x[3*i+1] + opt_dt * (torch.sin(x0[2] + opt_dt * (u[3*i+2])) * u[3*i] + torch.cos(x0[2] + u[i] + opt_dt * (u[3*i+2])) * u[3*i+1] + u[3*i+2])
        x[3*(i+1)+2] = x[3*i+2] + opt_dt * (u[3*i+2])      
        
        i = 2
        x[3*(i+1)+0] = x[3*i+0] + opt_dt * (torch.cos(x0[2] + opt_dt * (u[3*(i-1)+2]) + opt_dt * (u[3*i+2])) * u[3*i] - torch.sin(x0[2] + u[i] + opt_dt * (u[3*i+2])) * u[3*i+1] + u[3*i+2])
        x[3+(i+1)+1] = x[3*i+1] + opt_dt * (torch.sin(x0[2] + opt_dt * (u[3*(i-1)+2]) + opt_dt * (u[3*i+2])) * u[3*i] + torch.cos(x0[2] + u[i] + opt_dt * (u[3*i+2])) * u[3*i+1] + u[3*i+2])
        x[3*(i+1)+2] = x[3*i+2] + opt_dt * (u[3*i+2])  
        
        
        # torch.sum(u[0:3*i+2::3])
        
            
        
        
        # for i in range(opt_N):
        #   x[3*(i+1)+0] = x[3*i+0] # + opt_dt * (torch.cos(x[3*i+2]) * u[3*i] - torch.sin(x[3*i+2]) * u[3*i+1] + u[3*i+2])
        #   x[3+(i+1)+1] = x[3*i+1] # + opt_dt * (torch.sin(x[3*i+2]) * u[3*i] + torch.cos(x[3*i+2]) * u[3*i+1] + u[3*i+2])
        #   x[3*(i+1)+2] = x[3*i+2] # + opt_dt * (u[3*i+2])

        cost = 0
        for i in range(opt_N):
            cost += 0.1 * (torch.cat([x[3*i:3*i+3], u[3*i:3*i+3]], dim=-1).T @ P.T
                           @ P @ torch.cat([x[3*i:3*i+3], u[3*i:3*i+3]], dim=-1)
                           + q.T @ torch.cat([x[3*i:3*i+3], u[3*i:3*i+3]], dim=-1))
            cost += 10 * (u_des - u[3*i:3*i+3]).T @ (u_des - u[3*i:3*i+3])       
        
        # for i in range(opt_N):
        #     x[:, i+1] = x[:, i] + opt_dt * (torch.cos(x[2, i]) * u[0, i] - torch.sin(x[2, i]) * u[1, i] + u[2, i])

        # cost = 0
        # for i in range(opt_N):
        #     cost += 0.1 * (torch.cat([x[:, i], u[:, i]], dim=-1).T @ P.T
        #                    @ P @ torch.cat([x[:, i], u[:, i]], dim=-1)
        #                    + q.T @ torch.cat([x[:, i], u[:, i]], dim=-1))
        #     cost += 10 * (u_des - u[:, i]).T @ (u_des - u[:, i])

        # cost = u[:-1].T@u[:-1]*torch.cos(x.T@x)
        
        return cost

    grad_embedding_list = []
    
    for i in range(batch_dim):

      cost_hessian = hessian(lambda u: cost(embedding[i], u, u_des), u_opt[i])
      cost_hessian_inv = torch.inverse(cost_hessian)

      
      print(cost_hessian_inv.shape)

      # grad_embedding_list.append(grad_embedding_i.unsqueeze(0))

    # grad_embedding = torch.cat(grad_embedding_list, dim=0)

    return None, None, None, None

In [366]:
# embedding = torch.ones((1,42), requires_grad=True)
# u_opt = CFTOC.apply(embedding, x0, u0, u_des)
# print(u_opt)

In [367]:
class MPCTransformer(nn.Module):
    """MPC transformer
    Args:
        TODO
        embedding_dim: dimension of embedding
        patch_size: image patch size
        num_patches: number of image patches
    Returns:
        TODO
    """
    def __init__(self, embedding_dim=256, patch_size=5, num_patches=20):
        super().__init__()
        
        self.patch_size = patch_size
        self.num_patches = num_patches
        
        self.embedding_dim = embedding_dim

        self.transformer = Transformer(embedding_dim)
        
        self.position_encoding = nn.Parameter(
            torch.randn(1, num_patches * num_patches, embedding_dim) * 0.02
        )
        
        self.patch_projection = nn.Linear(patch_size * patch_size, embedding_dim)
        
        self.output_head = nn.Sequential(
            nn.LayerNorm(embedding_dim), 
            nn.Linear(embedding_dim, 6*6+6)  # TODO P,q
        )
        
        self.cftoc = CFTOC()

    def forward(self, images):
        """ 
        (1) Splitting images into fixed-size patches; 
        (2) Linearly embed each image patch, prepend CLS token; 
        (3) Add position embeddings;
        (4) Feed the resulting sequence of vectors to Transformer encoder.
        (5) Extract the embeddings corresponding to the CLS token.
        (6) Apply output head to the embeddings to obtain the logits
        """
        patches = patchify(images, self.patch_size)
        
        patch_embeddings = self.patch_projection(patches)
        
        embeddings = patch_embeddings + self.position_encoding
        
        transformer_embeddings = self.transformer(embeddings)
        transformer_embeddings = transformer_embeddings[:, 0, :]
        
        output_embeddings = self.output_head(transformer_embeddings)
        
        u_opt = CFTOC.apply(output_embeddings, x0, u0, u_des)

        return u_opt

In [368]:
torch.autograd.set_detect_anomaly(True)

model = MPCTransformer()

idx = 1000

input_grid = grids[idx].to(torch.float32).unsqueeze(0)
input_grid.requires_grad=True

# get expert control input
u_expert = commands[idx:idx+opt_N].T[[0,1,5],:].unsqueeze(0).flatten()

u_model = model(input_grid)

# print(u_model.shape)

loss_func = nn.MSELoss()
loss = loss_func(u_expert, u_model)
# print(loss)
loss.backward()


  return F.mse_loss(input, target, reduction=self.reduction)


tensor([[ 0.3270, -0.8240,  0.4814, -0.3080, -0.1347, -0.1071],
        [-0.3478, -0.2319, -0.6237,  0.4094, -0.1325,  0.7680],
        [ 0.3273, -1.1287,  0.0561,  0.1408,  0.2156,  0.6102],
        [-0.9771,  0.7120, -0.0619, -0.2835,  0.1496,  0.0249],
        [ 1.0286,  0.1238, -0.4658, -0.0342,  0.3236,  0.2138],
        [-0.5620, -0.1156,  0.0703,  0.0972, -0.3066, -0.5429]],
       grad_fn=<ReshapeAliasBackward0>) tensor([ 0.3416, -0.9555,  0.2199, -0.1279,  0.1443,  0.3250],
       grad_fn=<SplitWithSizesBackward0>)
torch.Size([60, 60])


In [None]:
test = torch.zeros(1,3,20)
test = torch.tensor([[1,2,3,4,5],[6,7,8,9,10]])

print(test)
print((test.T.flatten(end_dim=1)).T.shape)
print(test.T.flatten(end_dim=1).T)

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
torch.Size([10])
tensor([ 1,  6,  2,  7,  3,  8,  4,  9,  5, 10])


In [203]:
test.unsqueeze(dim=1)

tensor([[[ 1,  2,  3,  4,  5]],

        [[ 6,  7,  8,  9, 10]]])