<a href="https://colab.research.google.com/github/annechris13/Master-Thesis/blob/master/newton_method.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import numpy as np
import pandas as pd

In [0]:
# Example problem: a batch of 2 QPs, each with 2 variables, 2 inequality
# constraints and 1 equality constraint.
nbatch = 2
nBatch = nbatch
nx = 2
nineq = 2
neq = 1
# TODO: extract dimensions from problem parameters + check/add batch dimension
# Every parameter is a double-precision tensor with a leading batch dimension.
Q = torch.tensor([[4, 1, 1, 2], [6, 2, 2, 2]], dtype=torch.double).view(nbatch, nx, nx)
p = torch.tensor([[1, 1], [1, 6]], dtype=torch.double).view(nbatch, nx)
G = torch.tensor([[-1, 0, 0, -1], [-1, 0, 0, -1]], dtype=torch.double).view(nbatch, nineq, nx)
h = torch.tensor([[0, 0], [0, 0]], dtype=torch.double).view(nbatch, nineq)
A = torch.tensor([[1, 1], [2, 3]], dtype=torch.double).view(nbatch, neq, nx)
b = torch.tensor([[1], [4]], dtype=torch.double).view(nbatch, neq)

In [0]:
# Check that every Q in the batch is positive definite (strictly positive
# eigenvalues), not merely positive semi-definite.
# `torch.eig` has been removed from PyTorch, so use the batched symmetric
# eigenvalue routine instead. x^T Q x > 0 for all x != 0 holds exactly when the
# symmetric part of Q is positive definite, so that is the matrix tested.
e = torch.linalg.eigvalsh(0.5 * (Q + Q.transpose(1, 2)))
if not torch.all(e > 0):
    raise RuntimeError("Q is not PD")

In [0]:
def lu_hack(x):
    """LU-factorize ``x`` and return ``(data, pivots)`` for use with ``lu_solve``.

    On CPU this is a plain ``x.lu()`` with partial pivoting. On CUDA the
    factorization is done without pivoting and the returned pivots are
    replaced by identity pivots ``1..n`` (int32), one row per matrix in the
    batch.

    Args:
        x: a matrix ``(n, n)`` or a batch of matrices ``(..., n, n)``.

    Returns:
        Tuple ``(data, pivots)``: the packed LU factors and the pivot indices.
    """
    data, pivots = x.lu(pivot=not x.is_cuda)
    if x.is_cuda:
        # Build the identity pivots on the same device as ``x`` (not on the
        # current default GPU) and broadcast them over all batch dimensions,
        # so any number of leading batch dimensions is supported.
        n_rows = x.size(-2)
        pivots = torch.arange(
            1, 1 + n_rows, dtype=torch.int32, device=x.device,
        ).expand(x.shape[:-1]).contiguous()
    return (data, pivots)

In [0]:
def bdiag(d):
    """Turn a batch of column vectors into a batch of diagonal matrices.

    Args:
        d: tensor of shape ``(nBatch, sz, 1)``.

    Returns:
        Tensor of shape ``(nBatch, sz, sz)`` with the same dtype and device as
        ``d``; ``out[b]`` has ``d[b, :, 0]`` on its diagonal and zeros elsewhere.
    """
    nBatch, sz, _ = d.size()
    # ``reshape`` (rather than ``view``) also accepts non-contiguous input,
    # e.g. the result of slicing or transposing another tensor.
    return torch.diag_embed(d.reshape(nBatch, sz))

In [0]:
def get_Hessian(Q,G,A):
    """Assemble the batched KKT matrix for the variable ordering (x, s, z, y).

    The returned matrix has the block layout::

        [ Q   0   G^T  A^T ]
        [ 0   I   I    0   ]
        [ G   I   0    0   ]
        [ A   0   0    0   ]

    The (s, s) block is filled with the identity here.

    Args:
        Q: ``(nbatch, nx, nx)`` tensor.
        G: ``(nbatch, nineq, nx)`` tensor.
        A: ``(nbatch, neq, nx)`` tensor.

    Returns:
        Tensor of shape ``(nbatch, n, n)`` with ``n = nx + 2*nineq + neq``,
        with the dtype and device of ``Q``.
    """
    nbatch,nineq,nx=G.size()
    neq=A.size()[1]
    n_primal=nx+nineq   # size of the (x, s) part
    n_dual=nineq+neq    # size of the (z, y) part
    eye=torch.eye(nineq,dtype=Q.dtype,device=Q.device)

    # Upper-left block: diag(Q, I).
    B1=Q.new_zeros(nbatch,n_primal,n_primal)
    B1[:,:nx,:nx]=Q
    B1[:,nx:,nx:]=eye

    # Lower-left block: [[G, I], [A, 0]]. The identity sits in the slack
    # columns, which start at column nx (not nineq). Explicit start offsets
    # are used instead of negative slices so that nineq == 0 or neq == 0 work.
    B3=Q.new_zeros(nbatch,n_dual,n_primal)
    B3[:,:nineq,:nx]=G
    B3[:,nineq:,:nx]=A
    B3[:,:nineq,nx:]=eye

    # Upper-right block is the transpose of the lower-left one; the
    # lower-right block is all zeros.
    B2=torch.transpose(B3, dim0=2, dim1=1)
    B4=Q.new_zeros(nbatch,n_dual,n_dual)

    H=torch.cat((torch.cat((B1,B2),dim=2),torch.cat((B3,B4),dim=2)),dim=1)

    return H

In [0]:
def solve_kkt(H,rx,rs,rz,ry,d=None):
    """Solve the batched KKT system ``H * step = (rx, rs, rz, ry)``.

    Args:
        H: batched KKT matrix as built by ``get_Hessian``; it is not modified.
        rx, rs, rz, ry: right-hand-side blocks, each of shape
            ``(nbatch, n_block, 1)``. Their sizes along dim 1 define how the
            solution is split, so no notebook-level globals are needed.
        d: optional ``(nbatch, nineq, 1)`` tensor. When given, ``bdiag(d)``
            replaces the (s, s) block of a copy of ``H`` before solving.

    Returns:
        Tuple ``(dx, ds, dz, dy)`` of solution blocks with the same sizes as
        ``rx, rs, rz, ry``.
    """
    n_x,n_s,n_z,n_y=(r.size(1) for r in (rx,rs,rz,ry))
    if d is not None:
        # Work on a copy so the caller's H keeps its original (s, s) block.
        H=H.clone()
        H[:,n_x:n_x+n_s,n_x:n_x+n_s]=bdiag(d)
    F=torch.cat((rx,rs,rz,ry), dim=1)
    H_lu,H_piv= lu_hack(H)
    step=F.lu_solve(H_lu,H_piv)

    # Split by explicit block sizes; unlike negative slicing this also works
    # when a block is empty (e.g. no equality constraints).
    dx,ds,dz,dy=torch.split(step,[n_x,n_s,n_z,n_y],dim=1)
    return(dx,ds,dz,dy)

In [0]:
def get_initial(z):
      nbatch,_,_=z.size()
      dz=torch.ones(z.size()).type_as(z)
      alpha=torch.tensor([]).type_as(z)
      for b in range(nbatch):
        step=torch.tensor([-0.1]).type_as(z)
        z_=z[b,:,:]
        dz_=dz[b,:,:]
        while True:
          if (z_+step*dz_ >0).all():
            if step<0:
              alpha=torch.cat((alpha,torch.tensor([0]).type_as(z)))
            else:
              alpha=torch.cat((alpha,1+step))
            break
          else:
            step=step+0.1
      return alpha.view(nbatch,1,1)

In [0]:
def get_step(v,dv):
    """Backtracking step size keeping ``v + alpha * dv`` non-negative.

    For each batch element, ``alpha`` starts at 1 and is decreased by 0.1
    until ``v + alpha * dv >= 0`` holds elementwise. If no positive candidate
    is feasible, ``alpha`` is 0.

    Args:
        v: current iterate, shape ``(nbatch, n, 1)``.
        dv: search direction, same shape as ``v``.

    Returns:
        Tensor of shape ``(nbatch, 1, 1)`` with the dtype of ``v`` and values
        in ``[0, 1]``.
    """
    #TO DO: find efficient and accurate line search algorithm
    nbatch,_,_=v.size()
    alphas=[]
    for b in range(nbatch):
        step=torch.tensor([1]).type_as(v)
        v_=v[b,:,:]
        dv_=dv[b,:,:]
        while not (v_+step*dv_ >=0).all():
            step=step-0.1
            # Repeatedly subtracting 0.1 never lands exactly on 0 in floating
            # point, so stop as soon as the step is no longer positive and
            # clamp it to 0. This prevents negative steps and endless loops.
            if step<=0:
                step=torch.zeros_like(step)
                break
        alphas.append(step)
    alpha=torch.cat(alphas) if alphas else torch.tensor([]).type_as(v)
    return alpha.view(nbatch,1,1)


In [0]:
def pure_newton(Q,G,A,p,b,h,x,s,z,y, max_iter=2):
    """Run ``max_iter`` damped Newton steps on the KKT residuals of the QP.

    Each iteration evaluates the residuals at the current ``(x, s, z, y)``,
    solves the KKT system with the (s, s) block set to ``diag(z / s)``, picks
    the largest step on a 0.1 grid in ``[0, 1]`` that keeps ``z`` and ``s``
    non-negative, and moves all four variables by that step.

    Args:
        Q, G, A: batched problem matrices, as accepted by ``get_Hessian``.
        p, b, h: batched problem vectors of shape ``(nbatch, n)``; a trailing
            dimension is added internally.
        x, s, z, y: starting point, each of shape ``(nbatch, n, 1)``. The
            input tensors are not modified.
        max_iter: number of Newton iterations to perform.

    Returns:
        Tuple ``(x, s, z, y)`` holding the final iterate.
    """
    H=get_Hessian(Q,G,A)
    A_T=torch.transpose(A,dim0=2,dim1=1)
    G_T=torch.transpose(G,dim0=2,dim1=1)

    for _ in range(max_iter):
        # Negated KKT residuals at the current iterate.
        rx= -(torch.bmm(A_T,y)+torch.bmm(G_T,z)+torch.bmm(Q,x)+p.unsqueeze(2))
        rs=-z
        rz=-(torch.bmm(G,x)+s-h.unsqueeze(2))
        ry=-(torch.bmm(A,x)-b.unsqueeze(2))
        d=z/s
        #affine step calculation
        dx,ds,dz,dy=solve_kkt(H,rx,rs,rz,ry,d)
        #step size calculation
        alpha = torch.min(get_step(z, dz),get_step(s, ds))
        #update step (out of place, so the caller's tensors are left untouched)
        x=x+alpha*dx
        s=s+alpha*ds
        z=z+alpha*dz
        y=y+alpha*dy
    return(x,s,z,y)

In [0]:
# KKT matrix of the example problem and the transposed constraint matrices.
H = get_Hessian(Q, G, A)
A_T = A.transpose(1, 2)
G_T = G.transpose(1, 2)
# Initial solution: a single KKT solve with right-hand side (-p, 0, h, b) and
# no diagonal scaling (d is left at its default).
x, s, z, y = solve_kkt(
    H,
    -p.unsqueeze(2),
    torch.zeros(nbatch, nineq, 1).type_as(Q),
    h.unsqueeze(2),
    b.unsqueeze(2),
)
# Shift s = -z and z along the all-ones vector by the amounts get_initial
# returns. Both shifts are computed from the unshifted z, and s must be
# assigned before z is overwritten.
alpha_p = get_initial(-z)
alpha_d = get_initial(z)
s = -z + alpha_p * torch.ones_like(z)
z = z + alpha_d * torch.ones_like(z)

In [74]:
# Run 20 Newton iterations from the initial point, then evaluate the
# objective 0.5 * x^T Q x + p^T x at the returned x.
x, s, z, y = pure_newton(Q, G, A, p, b, h, x, s, z, y, max_iter=20)
op_val = (
    0.5 * torch.bmm(x.transpose(1, 2), torch.bmm(Q, x))
    + torch.bmm(p.unsqueeze(1), x)
)
print("optimal point: \n", x.numpy().reshape(nbatch, -1, nx))
print("\noptimal objective value: \n", op_val.numpy().reshape(nbatch, -1, 1))


optimal point: 
 [[[0.25 0.75]]

 [[0.5  1.  ]]]

optimal objective value: 
 [[[1.875]]

 [[9.25 ]]]
