<a href="https://colab.research.google.com/github/annechris13/Master-Thesis/blob/master/mpc_class.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import numpy as np
import pandas as pd
from sklearn.datasets import make_spd_matrix
import random
import warnings
warnings.filterwarnings('error')
import time

In [0]:
class mpc:
    """Batched Mehrotra predictor-corrector interior-point QP solver.

    For every batch entry this solves::

        minimize    0.5 * x^T Q x + q^T x
        subject to  G x <= h
                    A x  = b

    using slack variables ``s`` (``G x + s = h``), inequality duals ``z`` and
    equality duals ``y``.  Each Newton step solves the full KKT system
    ``J [dx; ds; dz; dy] = [rx; rs; rz; ry]`` via an LU factorization of::

            | Q  0  G^T  A^T |
        J = | 0  D  I    0   |      D = diag(z / s)  (D = I initially)
            | G  I  0    0   |
            | A  0  0    0   |

    Batched shapes: Q (nbatch, nx, nx), q (nbatch, nx), G (nbatch, nineq, nx),
    h (nbatch, nineq), A (nbatch, neq, nx), b (nbatch, neq).  Unbatched inputs
    (2-D matrices / 1-D vectors) get a leading batch dimension, and inputs
    with batch size 1 are broadcast to the batch size of the others.  ``A``
    and ``b`` may be ``None`` when there are no equality constraints.
    """

    def __init__(self, max_iter=20):
        # Maximum number of predictor-corrector iterations in mpc_opt.
        self.max_iter = max_iter

    def solve(self, Q, q, G, h, A, b):
        """Solve the QP(s) defined by ``Q, q, G, h, A, b``.

        Returns:
            ``(x, op_val)``: ``x`` has shape (nbatch, nx, 1) and
            ``op_val = 0.5 x^T Q x + q^T x`` has shape (nbatch, 1, 1).

        Raises:
            RuntimeError: if some ``Q`` is not positive definite, or the
                iterates become NaN.
        """
        self.Q, self.q, self.G, self.h, self.A, self.b = Q, q, G, h, A, b
        # Normalize to batched form first: the transposes below index dim 2,
        # which only exists once the batch dimension has been added.
        self.nbatch, self.nx, self.nineq, self.neq = self.get_sizes()
        self.G_T = self.G.transpose(1, 2)
        self.A_T = self.A.transpose(1, 2)
        self.is_Q_pd()

        self.J = self.get_Jacobian()
        self.J = self.get_lu_J()
        # Initial point: solve the KKT system with D = I (its second block row
        # then reads s + z = 0), and shift s and z to be strictly positive.
        self.x, self.s, self.z, self.y = self.solve_kkt(
            -self.q.unsqueeze(-1),
            self.Q.new_zeros(self.nbatch, self.nineq, 1),
            self.h.unsqueeze(-1),
            self.b.unsqueeze(-1),
        )
        alpha_p = self.get_initial(-self.z)
        alpha_d = self.get_initial(self.z)
        self.s = -self.z + alpha_p
        self.z = self.z + alpha_d

        # Main predictor-corrector iterations.
        self.x, self.s, self.z, self.y = self.mpc_opt()
        op_val = (0.5 * torch.bmm(self.x.transpose(1, 2), torch.bmm(self.Q, self.x))
                  + torch.bmm(self.q.unsqueeze(1), self.x))
        return self.x, op_val

    def get_sizes(self):
        """Bring the problem data to batched form; return (nbatch, nx, nineq, neq).

        Rebinds ``self.Q, self.q, self.G, self.h, self.A, self.b`` to their
        batched (and batch-broadcast) versions.
        """
        def batched(t, ndim):
            # Add a leading batch dimension to an unbatched input.
            return t.unsqueeze(0) if t.dim() == ndim - 1 else t

        self.Q = batched(self.Q, 3)
        self.q = batched(self.q, 2)
        self.G = batched(self.G, 3)
        self.h = batched(self.h, 2)
        nx = self.Q.size(-1)
        if self.A is None:
            # No equality constraints: use 0-row A / b so the KKT block layout
            # (and every slice into it) still works with neq == 0.
            self.A = self.Q.new_zeros(1, 0, nx)
            self.b = self.Q.new_zeros(1, 0)
        else:
            self.A = batched(self.A, 3)
            self.b = batched(self.b, 2)

        data = (self.Q, self.q, self.G, self.h, self.A, self.b)
        nbatch = max(t.size(0) for t in data)
        # Broadcast batch-size-1 inputs; incompatible batch sizes raise here.
        self.Q, self.q, self.G, self.h, self.A, self.b = (
            t.expand(nbatch, *t.shape[1:]) for t in data)
        return nbatch, nx, self.G.size(1), self.A.size(1)

    def is_Q_pd(self):
        """Raise RuntimeError unless every Q's eigenvalues have positive real part."""
        eigvals = torch.linalg.eigvals(self.Q).real
        if not bool((eigvals > 0).all()):
            raise RuntimeError("Q is not PD")

    def lu_factorize(self, x):
        """LU-factorize a batch of square matrices; return ``(LU, pivots)``.

        Partial pivoting is used on CPU.  On CUDA pivoting is skipped and the
        1-based identity permutation is returned as the pivots.
        """
        pivot = not x.is_cuda
        data, pivots, _ = torch.linalg.lu_factor_ex(x, pivot=pivot)
        if not pivot:
            n = x.size(-1)
            pivots = torch.arange(1, n + 1, dtype=torch.int32, device=x.device)
            pivots = pivots.repeat(x.size(0), 1)
        return data, pivots

    def get_diag_matrix(self, d):
        """Return the batch of diagonal matrices whose diagonals are ``d`` (nbatch, n, 1)."""
        return torch.diag_embed(d.squeeze(-1))

    def _kkt_slices(self):
        """Return the row/column slices of the x, s, z and y blocks of the KKT system."""
        i1 = self.nx
        i2 = i1 + self.nineq
        i3 = i2 + self.nineq
        return slice(0, i1), slice(i1, i2), slice(i2, i3), slice(i3, None)

    def get_Jacobian(self):
        """Build the KKT Jacobian with ``D = I``; store it in ``self.J`` and return it."""
        xs, ss, zs, ys = self._kkt_slices()
        n = self.nx + 2 * self.nineq + self.neq
        J = self.Q.new_zeros(self.nbatch, n, n)
        eye = torch.eye(self.nineq, dtype=self.Q.dtype, device=self.Q.device)
        self.D = eye.repeat(self.nbatch, 1, 1)
        J[:, xs, xs] = self.Q
        J[:, xs, zs] = self.G.transpose(1, 2)
        J[:, xs, ys] = self.A.transpose(1, 2)
        J[:, ss, ss] = self.D
        J[:, ss, zs] = eye
        J[:, zs, xs] = self.G
        J[:, zs, ss] = eye
        J[:, ys, xs] = self.A
        self.J = J
        return self.J

    def get_lu_J(self, d=None):
        """Refactorize ``self.J``, first replacing its D block by diag(d) if ``d`` is given."""
        if d is not None:
            _, ss, _, _ = self._kkt_slices()
            self.D = self.get_diag_matrix(d)
            self.J[:, ss, ss] = self.D
        self.J_lu, self.J_piv = self.lu_factorize(self.J)
        return self.J

    def solve_kkt(self, rx, rs, rz, ry):
        """Solve ``J [dx; ds; dz; dy] = [rx; rs; rz; ry]`` with the cached LU factors."""
        F = torch.cat((rx, rs, rz, ry), dim=1)
        step = torch.linalg.lu_solve(self.J_lu, self.J_piv, F)
        xs, ss, zs, ys = self._kkt_slices()
        # Positive-offset slices so that neq == 0 yields an empty dy.
        return step[:, xs], step[:, ss], step[:, zs], step[:, ys]

    @staticmethod
    def _max_abs(t):
        """Largest absolute entry of ``t`` as a float (0.0 for an empty tensor)."""
        return t.abs().max().item() if t.numel() else 0.0

    def mpc_opt(self):
        """Run predictor-corrector iterations from the current iterate.

        Returns the final ``(x, s, z, y)``.  Exits early once every residual
        and the average complementarity ``mu`` are below 1e-12; prints a
        warning if residuals are still above 1e-10 on the last iteration.

        Raises:
            RuntimeError: if any residual becomes NaN.
        """
        q_col = self.q.unsqueeze(-1)
        h_col = self.h.unsqueeze(-1)
        b_col = self.b.unsqueeze(-1)
        for i in range(self.max_iter):
            x, s, z, y = self.x, self.s, self.z, self.y
            # Negated KKT residuals = right-hand side of the Newton system.
            rx = -(torch.bmm(self.A_T, y) + torch.bmm(self.G_T, z)
                   + torch.bmm(self.Q, x) + q_col)
            rs = -z
            rz = -(torch.bmm(self.G, x) + s - h_col)
            ry = -(torch.bmm(self.A, x) - b_col)
            d = z / s
            # Average complementarity gap s^T z / nineq, shape (nbatch, 1).
            mu = torch.abs((s * z).sum(1)) / self.nineq

            nan_rows = torch.zeros(self.nbatch, dtype=torch.bool, device=x.device)
            for r in (rx, rz, ry, mu):
                nan_rows |= torch.isnan(r).flatten(1).any(1)
            if nan_rows.any():
                bad = nan_rows.nonzero().flatten().tolist()
                raise RuntimeError("invalid res: NaN residuals in batch entries %s" % bad)

            resids = (self._max_abs(rx), mu.max().item(),
                      self._max_abs(rz), self._max_abs(ry))
            if max(resids) < 1e-12:
                return self.x, self.s, self.z, self.y

            # Predictor (affine-scaling) direction with D = diag(z / s).
            self.J = self.get_lu_J(d)
            dx_aff, ds_aff, dz_aff, dy_aff = self.solve_kkt(rx, rs, rz, ry)
            # Affine step length, never longer than a full Newton step.
            alpha = torch.clamp(
                torch.min(self.get_step(z, dz_aff), self.get_step(s, ds_aff)), max=1.0)
            s_aff = s + alpha * ds_aff
            z_aff = z + alpha * dz_aff
            mu_aff = torch.abs((s_aff * z_aff).sum(1)) / self.nineq
            # Mehrotra's centering heuristic.
            sigma = (mu_aff / mu) ** 3

            # Centering + second-order correction; reuses the LU factors.
            rs_cor = ((sigma * mu).unsqueeze(-1) - ds_aff * dz_aff) / s
            dx_cor, ds_cor, dz_cor, dy_cor = self.solve_kkt(
                torch.zeros_like(rx), rs_cor, torch.zeros_like(rz), torch.zeros_like(ry))

            dx = dx_aff + dx_cor
            ds = ds_aff + ds_cor
            dz = dz_aff + dz_cor
            dy = dy_aff + dy_cor
            # Step to 99% of the boundary, capped at a full step.
            alpha = torch.clamp(
                0.99 * torch.min(self.get_step(z, dz), self.get_step(s, ds)), max=1.0)
            self.x = x + alpha * dx
            self.s = s + alpha * ds
            self.z = z + alpha * dz
            self.y = y + alpha * dy

            if i == self.max_iter - 1 and max(resids) > 1e-10:
                print("no of mu not converged: ", int((mu > 1e-10).sum()))
                print("mpc warning: Residuals not converged, need more itrations")
        return self.x, self.s, self.z, self.y

    def get_step(self, v, dv):
        """Per-batch largest ``a`` with ``v + a * dv >= 0``, shape (nbatch, 1, 1).

        Components with ``dv > 0`` do not bound the step; they are set to
        ``max(1, largest ratio over the whole tensor)``.  Non-finite ratios
        (from ``dv == 0``) are replaced by 1.
        """
        v = v.squeeze(2)
        dv = dv.squeeze(2)
        ratio = -v / dv
        ratio = torch.where(torch.isfinite(ratio), ratio, torch.ones_like(ratio))
        ratio = torch.where(dv > 0, torch.clamp(ratio.max(), min=1.0), ratio)
        return ratio.min(1)[0].view(v.size(0), 1, 1)

    def get_initial(self, z):
        """Return per-batch shifts ``alpha`` (nbatch, 1, 1) making ``z + alpha > 0``.

        Closed form of a line search over the grid ``step = -0.1, 0, 0.1, ...``
        for the first ``step`` with ``z + step > 0`` everywhere: ``alpha`` is 0
        if that step is negative (min(z) > 0.1), otherwise ``1 + step``.

        Raises:
            RuntimeError: if ``z`` contains NaN (the search would never end).
        """
        if torch.isnan(z).any():
            raise RuntimeError("invalid initial point: NaN in KKT solution")
        z_min = z.min(dim=1, keepdim=True)[0]
        step = 0.1 * torch.floor(1 - 10 * z_min)
        return torch.where(step < 0, torch.zeros_like(step), 1 + step)

# Test Problem


In [0]:

# nb=2
# # 
# xy=2
# ineq=2
# eq=1
# #to do: extract dimensions from problem parameters + check/add batch dimension
# Q_=torch.tensor([[4,1,1,2],[6,2,2,2]]).view(nb,xy,xy).type(torch.DoubleTensor)
# q_=torch.tensor([[1,1],[1,6]]).view(nb,xy).type(torch.DoubleTensor)
# G_=torch.tensor([[-1,0,0,-1],[-1,0,0,-1]]).view(nb,ineq,xy).type(torch.DoubleTensor)
# h_=torch.tensor([[0,0],[0,0]]).view(nb,ineq).type(torch.DoubleTensor)
# A_=torch.tensor([[1,1],[2,3]]).view(nb,eq,xy).type(torch.DoubleTensor)
# b_=torch.tensor([[1],[4]]).view(nb,eq).type(torch.DoubleTensor)
# solver=mpc()
# solver.solve(Q_,q_,G_,h_,A_,b_)
# # # do,po=lu_hack(Q)
# # # print(po)
# # # print(do)

# Non-class Implementation 


In [0]:
# def check_Q_pd(Q):
#   #check if Q is pd:
#   nbatch=Q.size()[0]
#   for i in range(nbatch):
#     e,_=torch.eig(Q[i])
#     if not torch.all(e[:,0]>0):
#       raise RuntimeError("Q is not PD")

In [0]:
# def lu_hack(x):
#     #do lu factorization of x
#     data, pivots = x.lu(pivot=not x.is_cuda)
#     if x.is_cuda:
#         if x.ndimension() == 2:
#             pivots = torch.arange(1, 1+x.size(0)).int().cuda()
#         elif x.ndimension() == 3:
#             pivots = torch.arange(
#                 1, 1+x.size(1),
#             ).unsqueeze(0).repeat(x.size(0), 1).int().cuda()
#         else:
#             assert False
#     return (data, pivots)

In [0]:
# def bdiag(d):
#     #return diagonal matrix with diagonal entries d
#     nBatch, sz, _ = d.size()
#     D = torch.zeros(nBatch, sz, sz).type_as(d)
#     I = torch.eye(sz).repeat(nBatch, 1, 1).type_as(d).bool()
#     D[I] = d.squeeze().view(-1)
#     return D

In [0]:
# def get_Hessian(Q,G,A):
#     #get the hessian kkt matrix
#     nbatch,nineq,nx=G.size()
#     neq=A.size()[1]
#     B1=torch.zeros(nbatch,nx+nineq,nx+nineq).type_as(Q)
#     B3=torch.zeros(nbatch,neq+nineq,nx+nineq).type_as(Q)
#     B4=torch.zeros(nbatch,neq+nineq,neq+nineq).type_as(Q)

#     B1[:,:nx,:nx]=Q
#     #D here is unit identity matrix
#     B1[:,-nineq:,-nineq:]=torch.eye(nineq).repeat(nbatch,1,1).type_as(Q)

#     B3[:,:nineq,:nx]=G
#     B3[:,-neq:,:nx]=A
#     B3[:,:nineq,nineq:]=torch.eye(nineq).repeat(nbatch,1,1).type_as(Q)

#     B2=torch.transpose(B3, dim0=2, dim1=1)

#     H=torch.cat((torch.cat((B1,B2),dim=2),torch.cat((B3,B4),dim=2)),dim=1)
  
#     return H

In [0]:
# def solve_kkt(H,rx,rs,rz,ry,d=None):
#     # solve the KKT system with hessian H and F specified by rx,rs,rz,ry
#     # the hessian H is modified when d is specified
#     nx=rx.size()[1]
#     nineq=rz.size()[1]
#     neq=ry.size()[1]
#     if d!=None:
#       D=bdiag(d)
#       H[:,nx:nx+nineq,nx:nx+nineq]=D
#     # print("H: ",H)
#     H_lu,H_piv= lu_hack(H)
#     F=torch.cat((rx,rs,rz,ry), dim=1)
#     step=F.lu_solve(H_lu,H_piv)

#     rx=step[:,:nx,:]
#     rs=step[:,nx:nx+nineq,:]
#     rz=step[:,nx+nineq:-neq,:]
#     ry=step[:,-neq:,:]
#     return(rx,rs,rz,ry)

In [0]:
# def get_initial(z):
#       #get step size using line search for initialization 
#       nbatch,_,_=z.size()
#       dz=torch.ones(z.size()).type_as(z)
#       alpha=torch.tensor([]).type_as(z)
#       for b in range(nbatch):
#         step=torch.tensor([-0.1]).type_as(z)
#         z_=z[b,:,:]
#         dz_=dz[b,:,:]
#         while True:
#           if (z_+step*dz_ >0).all():
#             if step<0:
#               alpha=torch.cat((alpha,torch.tensor([0]).type_as(z)))
#             else:
#               alpha=torch.cat((alpha,1+step))
#             break
#           else:
#             step=step+0.1
#       return alpha.view(nbatch,1,1)

In [0]:
# def get_step(v,dv):
#       #get step sizes for each iteration
#       #TO DO: find efficient and accurate line search algorithm
#       nbatch,_,_=v.size()
#       alpha=torch.tensor([]).type_as(v)
#       for b in range(nbatch):
#         step=torch.tensor([1]).type_as(v)
#         v_=v[b,:,:]
#         dv_=dv[b,:,:]
#         while step>0:
#           if (v_+step*dv_ >=0).all() or step==0:
#             alpha=torch.cat((alpha,step))
#             break
#           else:
#             step=step-0.01
#         if(step<0):
#           alpha=torch.cat((alpha,torch.tensor([0]).type_as(v)))
#       return alpha.view(nbatch,1,1)

# def get_step(v, dv):
#     #qpth version of get_step
#     a = -v / dv
#     a[dv > 0] = max(1.0, a.max())
#     return a.min(1)[0].squeeze()

# def get_step(v,dv):
#   v=v.squeeze(2)
#   dv=dv.squeeze(2)
#   div= -v/dv
#   ones=torch.ones_like(div)
#   div=torch.where(torch.isinf(div),ones,div)
#   div=torch.where(torch.isnan(div),ones,div)
#   div[dv>0]=max(1.0,div.max())
#   return (div.min(1)[0]).view(v.size()[0],1,1)

In [0]:
# def mpc(Q,G,A,p,b,h,x,s,z,y, max_iter=20):
#     nbatch,nx,_=Q.size()
#     nineq=G.size()[1]
#     neq=A.size()[1]
#     H=get_Hessian(Q,G,A)
#     A_T=torch.transpose(A,dim0=2,dim1=1)
#     G_T=torch.transpose(G,dim0=2,dim1=1)
#     count=0
#     bat=np.array([i for i in range(nbatch)])
#     for i in range(max_iter):
#         # print("iteration: ",i)
#         rx= -(torch.bmm(A_T,y)+torch.bmm(G_T,z)+torch.bmm(Q,x)+p.unsqueeze(2))
#         rs=-z
#         rz=-(torch.bmm(G,x)+s-h.unsqueeze(2))
#         ry=-(torch.bmm(A,x)-b.unsqueeze(2))
#         d=z/s
#         mu=torch.abs(torch.bmm(torch.transpose(s,dim0=2,dim1=1),z).sum(1))
#         pri_resid=torch.abs(rx)
#         dual_1_resid=torch.abs(rz)
#         dual_2_resid=torch.abs(ry)
#         if (i%4==0):
#           log=((pri_resid.sum(1)<1e-12)*(dual_1_resid.sum(1)<1e-12)*(dual_2_resid.sum(1)<1e-12)*(mu<1e-12)).squeeze(1).numpy().astype(bool)
#           # print("already converged: ",bat[log] )

#         resids=np.array([pri_resid.max(),mu.max(),dual_1_resid.max(),dual_2_resid.max()])
#         try:
#           if (resids<1e-12).all():
#             # print("Early exit at iteration no:",i)
#             return(x,s,z,y)
#         except:
#           print(bat[torch.isnan(pri_resid.sum(1)).squeeze(1)])
#           raise RuntimeError("invalid res")
        
#         #affine step calculation
#         dx_aff,ds_aff,dz_aff,dy_aff=solve_kkt(H,rx,rs,rz,ry,d)
#         #affine step size calculation
#         alpha = torch.min(get_step(z, dz_aff),get_step(s, ds_aff))
        
#         #affine updates for s and z
#         s_aff=s+alpha*ds_aff
#         z_aff=z+alpha*dz_aff
#         mu_aff=torch.abs(torch.bmm(torch.transpose(s_aff,dim0=2,dim1=1),z_aff).sum(1))
        
#         #find sigma for centering in the direction of mu
#         sigma=(mu_aff/mu)**3

#         #find centering+correction steps
#         rx=torch.zeros(rx.size()).type_as(Q)
#         rs=((sigma*mu).unsqueeze(2).repeat(1,nineq,1)-ds_aff*dz_aff)/s
#         rz=torch.zeros(rz.size()).type_as(Q)
#         ry=torch.zeros(ry.size()).type_as(Q)
#         dx_cor,ds_cor,dz_cor,dy_cor=solve_kkt(H,rx,rs,rz,ry,d)

#         dx=dx_aff+dx_cor
#         ds=ds_aff+ds_cor
#         dz=dz_aff+dz_cor
#         dy=dy_aff+dy_cor
#         # find update step size
#         alpha = torch.min(torch.ones(nbatch).type_as(Q).view(nbatch,1,1),0.99*torch.min(get_step(z, dz),get_step(s, ds)))
#         # update
#         x+=alpha*dx
#         s+=alpha*ds
#         z+=alpha*dz
#         y+=alpha*dy

#         if(i==max_iter-1 and (resids>1e-10).any()):
#           # print("no of mu not converged: ",len(mu[mu>1e-10]))
#           # print("no of primal residual not converged: ",len(pri_resid[pri_resid>1e-10]))
#           # print("no of dual residual 1 not converged: ",len(dual_1_resid[dual_1_resid>1e-10]))
#           # print("no of dual residual 2 not converged: ",len(dual_2_resid[dual_2_resid>1e-10]))
#           print("mpc warning: Residuals not converged, need more itrations")

#     return(x,s,z,y)

In [0]:
# def opt(Q,p,G,h,A,b):
#   nbatch,nx,_=Q.size()
#   nineq=G.size()[1]
#   neq=A.size()[1]
#   check_Q_pd(Q)
#   H=get_Hessian(Q,G,A)
#   A_T=torch.transpose(A,dim0=2,dim1=1)
#   G_T=torch.transpose(G,dim0=2,dim1=1)
#   #initial solution
#   x,s,z,y=solve_kkt(H,-p.unsqueeze(2),torch.zeros(nbatch,nineq).unsqueeze(2).type_as(Q),h.unsqueeze(2),b.unsqueeze(2))
#   alpha_p=get_initial(-z)
#   alpha_d=get_initial(z)
#   s=-z+alpha_p*(torch.ones(z.size()).type_as(z))
#   z=z+alpha_d*(torch.ones(z.size()).type_as(z))
#   #main iterations
#   start = time.time()
#   x,s,z,y=mpc(Q,G,A,p,b,h,x,s,z,y,30)
#   op_val=0.5*torch.bmm(torch.transpose(x,dim0=2,dim1=1),torch.bmm(Q,x))+torch.bmm(torch.transpose(p.unsqueeze(2),dim0=2,dim1=1),x)
#   t = time.time() - start
#   # print(t)
#   return x