In [31]:
'''This is a note book.'''

In [69]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging

import math 
import matplotlib.pyplot as plt
import numpy as np 
from numpy.random import rand 
import scipy 
from scipy.special import gamma 
from scipy.special import gammaincc


def getallobj(p,n,R):
  # calculate all objective values for different shifts
  allobj=np.zeros(n)
  l=len(p)
    for k in range(1,n+1):
    o1=np.sum((p[1:l-k-1] - p[1+k:l-1])*np.log(p[1:l-k-1]/p[1+k:l-1]))           
    q_plus= p[l-1]*np.power(R, range(0,k))
    q_minus=p[0]*np.power(R, range(k-1,-1,-1))
    o2=np.sum((p[l-k:l]-q_plus)*np.log(p[l-k:l]/q_plus)) + np.sum((p[1:1+k]-q_minus)*np.log(p[1:1+k]/q_minus))
    o3=(p[0] + p[l-1])*k*(1-R**(k+1))/(1-R)*(- np.log(R))
    allobj[k-1]=o1 + o2 + o3
  return allobj


def getgrad(p,k,R,opt_type):
  # Calculate gradient for a shift of k
  l=len(p)
  rat1=np.concatenate((p[1+k:l],p[l-1]*np.power(R, range(1,k)))) / p[1:l-1]
  rat2=np.concatenate((p[0]*np.power(R,np.arange(k,0,-1)),p[0:l-k-2])) / p[1:l-1]
  gradmid=(- np.log(rat1) + 1 - rat1) + (- np.log(rat2) + 1 - rat2)

  end=len(rat1)
  Nfactor=k*(1-R**k)/(1-R)*(-np.log(R))
  gradN=sum(np.power(R,np.arange(0,k))*(np.log(rat1[end-k:end])+1-1.0/rat1[end-k:end])) + Nfactor
  gradmN=sum(np.power(R,np.arange(k-1,-1,-1))*((np.log(rat2[:k])+1-1.0/rat2[:k]))) + Nfactor

  grad=np.concatenate(([gradmN],gradmid,[gradN]))

  return grad

def getH(p,n,evals,R,opt_type):
  # Calculate the Hessian matrix
  l=len(p)

  d=np.zeros(l)
  for k in range(n):
    dN = sum(np.power(R, np.arange(0,k+1)))/p[l-1] + sum(p[-k-1:]) / p[l-1] ** 2
    dmN = sum(np.power(R,np.arange(k-1,-1,-1))) / p[0] + sum(p[1:1+k]) / p[0] ** 2
    dmid = 2.0/p[1:l-1] + (np.concatenate((p[2+k:], np.power(R, np.arange(1,k+1))*p[l-1])) + np.concatenate((np.power(R, np.arange(k,0,-1))*p[0], p[:l-2-k] )))/ p[1:l-1]**2
    d=d + evals[k]*np.concatenate(([dmN],dmid,[dN]))

  H=np.diagflat(d)

  q=np.zeros((n,1))
  for k in range(n):
    j = 1 + k
    q[k]=sum(evals[k:n]*(-np.power(R,np.arange(0,n-k))/p[j]-1.0/p[0]))
  H[0,1:n+1]=q.reshape(-1)
  H[1:n+1,0]=q.reshape(-1)

  # deal with i=l
  q=np.zeros((n,1))
  for k in range(n):
    i = l-k-1
    q[k]=sum(evals[k:n]*(-np.power(R,np.arange(0,n-k))/p[i]-1.0/p[l-1]))
  H[l-1,range(l-2,l-n-2,-1) ]=q.reshape(-1)
  H[range(l-2,l-n-2,-1),l-1]=q.reshape(-1)  

  for i in range(1,l-2):
    kmax = min(n,l-i-1)
    q=evals[:kmax]*(-1/p[i]-1.0/p[i+1:i+1+kmax])
    H[i,i+1:i+1+kmax]=q.reshape(-1)
    H[i+1:i+1+kmax,i]=q.reshape(-1)        
  
  return H

In [None]:
def main(n, xmax, C, c_exp=2, opt_type=1, r=0.9, tol=1e-08, verbose=2):

  N = math.ceil(n*xmax)
  xmax = N/n

  x=np.arange(-N,N+1)/n
  l=len(x)

  # The two constraints (probability normalization, and the variance constraint) are handled by A*p=b as follows. 
  # Note that even though the variance constraint is an inequality, it will always be tight at optimal, so we treat it as an equality constraint.

  c=np.zeros(l)
  for i in range(l):
    if x[i]>0:
      c[i] = ((x[i]+1/2/n)**(c_exp+1)-(x[i]-1/2/n)**(c_exp + 1))*(n/(c_exp+1))
    elif x[i]==0:
      c[i] = (1 / 2 / n) ** c_exp / (1 + c_exp)
    else:
      c[i]=np.dot(((abs(x[i]) + 1 / 2 / n) ** (c_exp + 1) - (abs(x[i]) - 1 / 2 / n) ** (c_exp + 1)),(n / (c_exp + 1)))

  # slight over-estimate, but getting the exact value requires a nasty infinite sum. This upper bounds the sum by an integral. 
  # This still gives a true bound, without much loss
  cN = n*r**(-N-1/2)*(-n*np.log(r))**(-c_exp-1)*gamma(c_exp+1)*gammaincc(c_exp+1,-np.log(r)*(N-1/2))
  c[0]=cN
  c[l-1]=cN
  d = np.ones(l)
  d[0] = d[l-1] = 1/(1-r)  
  A = np.stack((c,d),axis = 0) 

  b = np.stack(([C],[1]),axis = 0)
  # We start with a guess for the distribution p.
  # This is a distribution designed to have farily heavy tails (which seems to make the solver work better), and to roughly satisfy the cost constraint.
  q = C ** (- 1 - 2 / c_exp)
  p = 1/(1+q*abs(x)**(c_exp+2))
  # make sure that p is normalized (it actually doesn't need to satisfy the cost constraint exactly initially -- the algorithm with ensure that)
  p = p/np.dot(A[1],p) 
  
  iter=0

  t=1

  isfeasible = False

  allobj = getallobj(p,n,r,opt_type)
  
  primobj=np.amax(allobj)
  fval=np.log(sum(np.exp(t*(allobj - primobj)))) / t + primobj

  while 1:

    iter = iter + 1
    eta = np.exp(t*(allobj - primobj))
    eta=eta / sum(eta)
    allgrads=np.zeros((l,n))

    for k in range(n):
      allgrads[:,k]=getgrad(p,k+1,r,opt_type).reshape(-1)
    print(np.shape(allgrads))
    print(np.shape(eta.reshape(n,1)))
    grad=np.matmul(allgrads,eta.reshape(n,1))
    print(grad[0:15])

    rpri=np.matmul(A,p.reshape(len(p),1))-b

    # The next few lines are typically the majority of the run-time.
    # Compute Hessian matrix
    H=getH(p,n,eta,r,opt_type)
    # strictly positive definite. This doesn't much change overall performance
    fullH = H + np.matmul(np.matmul(allgrads,np.diagflat(t*eta)),allgrads.T)-np.matmul(t*(1-1e-05)*grad,grad.T)
    
    try:
      R=np.linalg.cholesky(fullH)

      gtilde=np.linalg.solve(R.T,grad)

      AR=np.matmul(A, np.linalg.inv(R))
      U,S,V = np.linalg.svd(AR,full_matrices=False)
      s=S
      vtilde = - gtilde - np.matmul(V.T,(1.0/s.reshape(2,1)*np.matmul(U.T,rpri))) + np.matmul(V.T,(np.matmul(V,gtilde)))
      v=np.linalg.solve(R,vtilde)

    except np.linalg.LinAlgError as err:
      if verbose >=1:
          print('Chlesky failed')
      A01=np.concatenate((fullH, A.T),axis=1)
      A02=np.concatenate((A, np.zeros((2,2))),axis=1)
      A1 = np.concatenate((A01,A02),axis=0)
      b1=np.concatenate((-grad, -rpri))
      vw = np.linalg.solve(A1,b1)
      v = vw[0:l]
    nwt_dec=-sum(grad*v) / 2

    # implementing a line search. We move to p+dst*v
    dst=1
    while 1:
      pnew = p + dst*v.reshape(-1)
      if min(pnew) > 0:
        allobj=getallobj(pnew,n,r,opt_type)
        primobj=max(allobj)
        newfval=np.log(sum(np.exp(t*(allobj - primobj)))) / t + primobj
        if not isfeasible:
          break
        if newfval < fval + 0.1*dst*(np.matmul(grad.T,v)):
          break
        if dst < 1e-08:
          break
      dst=dst / 2

    if not isfeasible and dst == 1:
      isfeasible=True
      if verbose >= 1:
          print('Feasible!')
    p=pnew
    fval=newfval
    if verbose >= 1:
      print('iter=%d  dst=%1.1e  primobj=%1.4f  nwt_dec=%1.2e  t=%1.1e \n' %(iter,dst,primobj,nwt_dec,t));

    if isfeasible:
      # n/e/t is an estimate for the duality gap, if we are at exact optimality for the soft-max problem. 
      # Thus nwt_dec+n/e/1 is an estimate for the total duality gap
      if nwt_dec + n / exp(1) / t < tol:
          break
      # if we take a full Newton step or get extremely close to optimal, then increase t.
      if dst == 1 or nwt_dec < tol / 2:
          # This is a fairly modest increase in t.
          # This seems to make the solver the most robust, if not necessarily the most efficient.
          t=t*1.25
    
  plt.plot(x,p)
  filename=('cactus_x_d1.0v%.2f.csv' %C)
  np.savetxt(filename, x, delimiter=",")
      
  filename=('cactus_p_d1.0_v%.2f.csv' %C)
  np.savetxt(filename, p, delimiter=",")
      
  return x, p 

if __name__ == '__main__':
  app.run(main(n=200, xmax=8, C=0.25, c_exp=2, opt_type=1, r=0.9, tol=1e-05, verbose=2))

(3201, 200)
(200, 1)
[[-63.87662019]
 [ 17.53846502]
 [ 17.4344836 ]
 [ 17.3305397 ]
 [ 17.22663321]
 [ 17.12276406]
 [ 17.01893217]
 [ 16.91513749]
 [ 16.81137996]
 [ 16.70765954]
 [ 16.6039762 ]
 [ 16.50032993]
 [ 16.39672071]
 [ 16.29314856]
 [ 16.18961349]]
iter=1  dst=3.9e-03  primobj=4.4176  nwt_dec=4.63e+00  t=1.0e+00 

(3201, 200)
(200, 1)
[[-17.58267142]
 [ 17.29930371]
 [ 17.19863331]
 [ 17.09805757]
 [ 16.99740764]
 [ 16.89668677]
 [ 16.79589814]
 [ 16.69504484]
 [ 16.59412987]
 [ 16.49315617]
 [ 16.3921266 ]
 [ 16.29104393]
 [ 16.18991088]
 [ 16.08873008]
 [ 15.98750412]]
iter=2  dst=2.0e-03  primobj=4.4023  nwt_dec=2.93e+00  t=1.0e+00 

(3201, 200)
(200, 1)
[[ 2.94294713]
 [17.18536414]
 [17.08595163]
 [16.98668725]
 [16.8873046 ]
 [16.78780819]
 [16.68820244]
 [16.58849162]
 [16.48867991]
 [16.38877136]
 [16.28876994]
 [16.18867947]
 [16.08850372]
 [15.98824632]
 [15.88791082]]
iter=3  dst=9.8e-04  primobj=4.3955  nwt_dec=2.57e+00  t=1.0e+00 

(3201, 200)
(200, 1)
[[12.23