<a href="https://colab.research.google.com/github/annechris13/Master-Thesis/blob/master/mpc_class.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import numpy as np
import pandas as pd
from sklearn.datasets import make_spd_matrix
import random
import warnings
warnings.filterwarnings('error')
import time

In [0]:
def check_Q_pd(Q):
  """Raise if any matrix in the batch Q has an eigenvalue whose real part is <= 0.

  Args:
    Q: tensor whose first dimension indexes the batch; each Q[i] is a
      square matrix.

  Raises:
    RuntimeError: with message "Q is not PD" when the check fails.

  Note:
    Only the real parts of the eigenvalues are inspected; symmetry of Q is
    not verified here.
  """
  linalg = getattr(torch, "linalg", None)
  if linalg is not None and hasattr(linalg, "eigvals"):
    # torch.eig is no longer available in recent PyTorch releases;
    # torch.linalg.eigvals handles the whole batch in a single call.
    if not torch.all(linalg.eigvals(Q).real > 0):
      raise RuntimeError("Q is not PD")
    return

  # Legacy PyTorch without torch.linalg.eigvals: check one matrix at a time.
  nbatch = Q.size()[0]
  for i in range(nbatch):
    e, _ = torch.eig(Q[i])
    # Column 0 of e holds the real parts of the eigenvalues.
    if not torch.all(e[:, 0] > 0):
      raise RuntimeError("Q is not PD")

In [0]:
def lu_hack(x):
    """LU-factorize x and return (LU data, pivots).

    On CPU the factorization uses partial pivoting. On CUDA the
    factorization is requested without pivoting and the returned pivots are
    replaced by identity pivots (1..n), built on the same device as x.

    Args:
        x: a 2-D matrix or a 3-D batch of matrices.

    Returns:
        Tuple (data, pivots) in the format consumed by lu_solve.

    Raises:
        AssertionError: if x is a CUDA tensor that is neither 2-D nor 3-D.
    """
    use_pivot = not x.is_cuda
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "lu_factor"):
        # Preferred API; Tensor.lu is deprecated in recent PyTorch releases.
        data, pivots = linalg.lu_factor(x, pivot=use_pivot)
    else:
        data, pivots = x.lu(pivot=use_pivot)

    if x.is_cuda:
        # No pivoting was done, so the pivots are the identity permutation
        # (1-based). Create them on x's own device, not the default GPU.
        if x.ndimension() == 2:
            pivots = torch.arange(
                1, 1 + x.size(0), dtype=torch.int32, device=x.device,
            )
        elif x.ndimension() == 3:
            pivots = torch.arange(
                1, 1 + x.size(1), dtype=torch.int32, device=x.device,
            ).unsqueeze(0).repeat(x.size(0), 1)
        else:
            # Explicit raise so the check survives `python -O`.
            raise AssertionError("lu_hack supports only 2-D or 3-D CUDA input")
    return (data, pivots)

In [0]:
def bdiag(d):
    """Return a batch of diagonal matrices whose diagonals are given by d.

    Args:
        d: tensor of size (nBatch, sz, 1).

    Returns:
        Tensor of size (nBatch, sz, sz), same dtype and device as d, with
        d[b, :, 0] on the diagonal of matrix b and zeros elsewhere.
    """
    nBatch, sz, _ = d.size()
    # reshape (rather than view) also accepts non-contiguous input, and it
    # raises if the trailing dimension is not 1.
    return torch.diag_embed(d.reshape(nBatch, sz))

In [0]:
def get_Hessian(Q,G,A):
    """Assemble the batched KKT matrix used by the solver.

    The result has the block layout

        [ Q   0   G^T  A^T ]
        [ 0   I   I    0   ]
        [ G   I   0    0   ]
        [ A   0   0    0   ]

    with block sizes nx, nineq, nineq, neq. The identity in the second
    diagonal block is a placeholder for D, which solve_kkt overwrites when
    it is given d.

    Args:
        Q: (nbatch, nx, nx) tensor.
        G: (nbatch, nineq, nx) tensor.
        A: (nbatch, neq, nx) tensor.

    Returns:
        Tensor of size (nbatch, nx+2*nineq+neq, nx+2*nineq+neq), typed as Q.
    """
    nbatch,nineq,nx=G.size()
    neq=A.size()[1]
    eye=torch.eye(nineq).repeat(nbatch,1,1).type_as(Q)

    B1=torch.zeros(nbatch,nx+nineq,nx+nineq).type_as(Q)
    B3=torch.zeros(nbatch,neq+nineq,nx+nineq).type_as(Q)
    B4=torch.zeros(nbatch,neq+nineq,neq+nineq).type_as(Q)

    # Explicit start offsets (instead of negative indices) stay correct
    # when nineq or neq is 0.
    B1[:,:nx,:nx]=Q
    # D here is unit identity matrix
    B1[:,nx:,nx:]=eye

    B3[:,:nineq,:nx]=G
    # The slack identity sits right after the nx columns of G.
    B3[:,:nineq,nx:]=eye
    B3[:,nineq:,:nx]=A

    B2=torch.transpose(B3, dim0=2, dim1=1)

    H=torch.cat((torch.cat((B1,B2),dim=2),torch.cat((B3,B4),dim=2)),dim=1)

    return H

In [0]:
def solve_kkt(H,rx,rs,rz,ry,d=None):
    """Solve the batched linear system H * step = [rx; rs; rz; ry].

    Args:
        H: batched KKT matrix. NOTE: it is modified in place when d is
            given (its D block is overwritten).
        rx, rs, rz, ry: right-hand-side blocks of size (nbatch, nx, 1),
            (nbatch, nineq, 1), (nbatch, nineq, 1) and (nbatch, neq, 1).
        d: optional (nbatch, nineq, 1) tensor; when given, diag(d) is
            written into H[:, nx:nx+nineq, nx:nx+nineq] before solving.

    Returns:
        Tuple of the solution blocks, split in the same order and with the
        same sizes as (rx, rs, rz, ry).
    """
    nx=rx.size()[1]
    nineq=rz.size()[1]
    neq=ry.size()[1]
    if d is not None:
      D=bdiag(d)
      H[:,nx:nx+nineq,nx:nx+nineq]=D
    # print("H: ",H)
    H_lu,H_piv= lu_hack(H)
    F=torch.cat((rx,rs,rz,ry), dim=1)
    linalg=getattr(torch,"linalg",None)
    if linalg is not None and hasattr(linalg,"lu_solve"):
      # Preferred API; note the argument order differs from Tensor.lu_solve.
      step=linalg.lu_solve(H_lu,H_piv,F)
    else:
      step=F.lu_solve(H_lu,H_piv)

    # Explicit offsets (rather than negative indices) so the split is
    # still correct when neq == 0.
    s_end=nx+nineq
    z_end=s_end+nineq
    rx=step[:,:nx,:]
    rs=step[:,nx:s_end,:]
    rz=step[:,s_end:z_end,:]
    ry=step[:,z_end:z_end+neq,:]
    return(rx,rs,rz,ry)

In [0]:
def get_initial(z):
    """Per-batch initialization step found by a coarse line search.

    For every batch element the search starts at step = -0.1 and raises
    step by 0.1 until every entry of z[b] + step is strictly positive.
    The result for that element is 0 if the very first (negative) step
    already works, otherwise 1 + step.

    Args:
        z: tensor of size (nbatch, n, 1) with finite entries.

    Returns:
        Tensor of size (nbatch, 1, 1), typed as z.

    Raises:
        RuntimeError: if z contains NaN or infinite entries, for which the
            search would never terminate.
    """
    nbatch,_,_=z.size()
    if not bool(torch.isfinite(z).all()):
        raise RuntimeError("get_initial: z contains non-finite values")

    alphas=[]
    for b in range(nbatch):
        z_=z[b,:,:]
        step=torch.tensor([-0.1]).type_as(z)
        # The search direction is all ones, so the trial point is z_ + step.
        while not (z_+step >0).all():
            step=step+0.1
        if step<0:
            alphas.append(torch.tensor([0]).type_as(z))
        else:
            alphas.append(1+step)

    if alphas:
        alpha=torch.cat(alphas)
    else:
        alpha=torch.tensor([]).type_as(z)
    return alpha.view(nbatch,1,1)

In [0]:
# def get_step(v,dv):
#       #get step sizes for each iteration
#       #TO DO: find efficient and accurate line search algorithm
#       nbatch,_,_=v.size()
#       alpha=torch.tensor([]).type_as(v)
#       for b in range(nbatch):
#         step=torch.tensor([1]).type_as(v)
#         v_=v[b,:,:]
#         dv_=dv[b,:,:]
#         while step>0:
#           if (v_+step*dv_ >=0).all() or step==0:
#             alpha=torch.cat((alpha,step))
#             break
#           else:
#             step=step-0.01
#         if(step<0):
#           alpha=torch.cat((alpha,torch.tensor([0]).type_as(v)))
#       return alpha.view(nbatch,1,1)

# def get_step(v, dv):
#     #qpth version of get_step
#     a = -v / dv
#     a[dv > 0] = max(1.0, a.max())
#     return a.min(1)[0].squeeze()

def get_step(v,dv):
  """Largest per-batch step t for which v + t*dv stays non-negative.

  v and dv are (nbatch, n, 1) tensors. For each entry the candidate step
  is -v/dv; infinite or NaN candidates (from dv == 0) are replaced by 1.
  Entries with dv > 0 do not limit the step, so they are set to
  max(1.0, largest candidate over the whole batch). The minimum over the
  n entries is returned with size (nbatch, 1, 1).
  """
  batch=v.size()[0]
  current=v.squeeze(2)
  direction=dv.squeeze(2)

  ratio=-current/direction
  # Division by zero yields inf/NaN; treat those entries as a unit step.
  unusable=torch.isinf(ratio) | torch.isnan(ratio)
  ratio=torch.where(unusable,torch.ones_like(ratio),ratio)

  # Increasing entries can never hit zero: give them a non-binding value.
  cap=max(1.0,ratio.max())
  ratio[direction>0]=cap

  smallest,_=ratio.min(1)
  return smallest.view(batch,1,1)

In [0]:
def mpc(Q,G,A,p,b,h,x,s,z,y, max_iter=20):
    """Run batched predictor-corrector iterations on the KKT system.

    Each iteration computes the residuals, solves for an affine step,
    derives a centering parameter sigma from the affine step, solves for a
    centering/correction step, and updates the iterates with a step size
    capped at 1.

    Args:
        Q: (nbatch, nx, nx) tensor.
        G: (nbatch, nineq, nx) tensor.
        A: (nbatch, neq, nx) tensor.
        p: (nbatch, nx) tensor.
        b: (nbatch, neq) tensor.
        h: (nbatch, nineq) tensor.
        x, s, z, y: initial iterates of size (nbatch, nx, 1),
            (nbatch, nineq, 1), (nbatch, nineq, 1) and (nbatch, neq, 1).
            NOTE: they are updated in place.
        max_iter: maximum number of iterations.

    Returns:
        Tuple (x, s, z, y) of the final iterates.

    Raises:
        RuntimeError: with message "invalid res" if any residual is NaN.
    """
    nbatch,_,_=Q.size()
    nineq=G.size()[1]
    H=get_Hessian(Q,G,A)
    A_T=torch.transpose(A,dim0=2,dim1=1)
    G_T=torch.transpose(G,dim0=2,dim1=1)
    # Batch indices, used only to report which elements produced NaNs.
    bat=np.arange(nbatch)
    for i in range(max_iter):
        # print("iteration: ",i)
        rx= -(torch.bmm(A_T,y)+torch.bmm(G_T,z)+torch.bmm(Q,x)+p.unsqueeze(2))
        rs=-z
        rz=-(torch.bmm(G,x)+s-h.unsqueeze(2))
        ry=-(torch.bmm(A,x)-b.unsqueeze(2))
        d=z/s
        mu=torch.abs(torch.bmm(torch.transpose(s,dim0=2,dim1=1),z).sum(1))
        pri_resid=torch.abs(rx)
        dual_1_resid=torch.abs(rz)
        dual_2_resid=torch.abs(ry)

        # Fail loudly on NaN residuals instead of relying on numpy warnings
        # being promoted to errors.
        if any(bool(torch.isnan(r).any()) for r in (pri_resid,mu,dual_1_resid,dual_2_resid)):
          print(bat[torch.isnan(pri_resid.sum(1)).squeeze(1).cpu().numpy()])
          raise RuntimeError("invalid res")

        # Worst residual over the whole batch, kept as a tensor so the
        # check also works for tensors that do not live on the CPU.
        resids=torch.stack([pri_resid.max(),mu.max(),dual_1_resid.max(),dual_2_resid.max()])
        if (resids<1e-12).all():
          # print("Early exit at iteration no:",i)
          return(x,s,z,y)

        #affine step calculation
        dx_aff,ds_aff,dz_aff,dy_aff=solve_kkt(H,rx,rs,rz,ry,d)
        #affine step size calculation
        alpha = torch.min(get_step(z, dz_aff),get_step(s, ds_aff))

        #affine updates for s and z
        s_aff=s+alpha*ds_aff
        z_aff=z+alpha*dz_aff
        mu_aff=torch.abs(torch.bmm(torch.transpose(s_aff,dim0=2,dim1=1),z_aff).sum(1))

        #find sigma for centering in the direction of mu
        sigma=(mu_aff/mu)**3

        #find centering+correction steps
        rx=torch.zeros(rx.size()).type_as(Q)
        rs=((sigma*mu).unsqueeze(2).repeat(1,nineq,1)-ds_aff*dz_aff)/s
        rz=torch.zeros(rz.size()).type_as(Q)
        ry=torch.zeros(ry.size()).type_as(Q)
        dx_cor,ds_cor,dz_cor,dy_cor=solve_kkt(H,rx,rs,rz,ry,d)

        dx=dx_aff+dx_cor
        ds=ds_aff+ds_cor
        dz=dz_aff+dz_cor
        dy=dy_aff+dy_cor
        # find update step size
        alpha = torch.min(torch.ones(nbatch).type_as(Q).view(nbatch,1,1),0.99*torch.min(get_step(z, dz),get_step(s, ds)))
        # update (in place: the caller's tensors are modified)
        x+=alpha*dx
        s+=alpha*ds
        z+=alpha*dz
        y+=alpha*dy

        # resids were measured before this iteration's update.
        if(i==max_iter-1 and (resids>1e-10).any()):
          # print("no of mu not converged: ",len(mu[mu>1e-10]))
          # print("no of primal residual not converged: ",len(pri_resid[pri_resid>1e-10]))
          # print("no of dual residual 1 not converged: ",len(dual_1_resid[dual_1_resid>1e-10]))
          # print("no of dual residual 2 not converged: ",len(dual_2_resid[dual_2_resid>1e-10]))
          print("mpc warning: Residuals not converged, need more itrations")

    return(x,s,z,y)

In [0]:
def opt(Q,p,G,h,A,b):
  """Solve a batch of problems defined by (Q, p, G, h, A, b) and return x.

  Steps: check Q with check_Q_pd, build the KKT matrix, solve it once for
  a starting point, shift s and z using get_initial, then run mpc with up
  to 30 iterations.

  Args:
    Q: (nbatch, nx, nx) tensor.
    p: (nbatch, nx) tensor.
    G: (nbatch, nineq, nx) tensor.
    h: (nbatch, nineq) tensor.
    A: (nbatch, neq, nx) tensor.
    b: (nbatch, neq) tensor.

  Returns:
    x of size (nbatch, nx, 1), as returned by mpc.

  Raises:
    RuntimeError: if check_Q_pd rejects Q.
  """
  nbatch,_,_=Q.size()
  nineq=G.size()[1]
  check_Q_pd(Q)
  H=get_Hessian(Q,G,A)
  #initial solution
  zero_rs=torch.zeros(nbatch,nineq).unsqueeze(2).type_as(Q)
  x,s,z,y=solve_kkt(H,-p.unsqueeze(2),zero_rs,h.unsqueeze(2),b.unsqueeze(2))
  alpha_p=get_initial(-z)
  alpha_d=get_initial(z)
  ones=torch.ones(z.size()).type_as(z)
  # s is rebuilt from z; the s returned by the initial solve is not used.
  s=-z+alpha_p*ones
  z=z+alpha_d*ones
  #main iterations
  x,s,z,y=mpc(Q,G,A,p,b,h,x,s,z,y,30)
  return x