In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import LinearLocator



In [None]:
nx = 20
ny = nx
lx = 2
ly = 2

xt = np.array([0.5, 0.5])
a = np.array([0.2, 0.8])

# Two leading singleton axes so xt and a broadcast against the (ny, nx, 2) grid X below.
xt = np.expand_dims(xt, axis=[0, 1])
a = np.expand_dims(a, axis=[0, 1])

# nx (resp. ny) evenly spaced points in (0, lx] (resp. (0, ly]); 0 is excluded, presumably
# so that log(x) in the surrogates stays finite. linspace guarantees the point count, whereas
# a float-step arange can gain or lose an endpoint to rounding.
x1 = np.linspace(lx / nx, lx, nx)
x2 = np.linspace(ly / ny, ly, ny)
X1, X2 = np.meshgrid(x1, x2)
# Stack along the last axis so X[i, j] == (X1[i, j], X2[i, j]). Anything computed from X then
# has the same (ny, nx) layout as X1 and X2, which is what plot_surface/plot_wireframe require.
# (np.array([X1, X2]).T would transpose the grid and swap the roles of x1 and x2 in the plots.)
X = np.stack([X1, X2], axis=-1)



In [None]:
tuple(arr.shape for arr in (X, a, xt))  # shapes of the grid, the coefficients and the anchor point

In [None]:
def f(x, a):
    """Objective f(x) = <a, x> - log(<a, x>), reduced over the last axis.

    ``a`` and ``x`` are broadcast together and the inner product is taken along axis -1.
    """
    inner = (a * x).sum(axis=-1)
    return inner - np.log(inner)

def MU_surrogate(x, a, xt):
    """Jensen (multiplicative-update) surrogate of f around ``xt``.

    The weights u = a*xt / sum(a*xt) are fixed at ``xt``. The surrogate is
        sum_i [ a_i x_i - u_i (log(a_i x_i) - log(u_i)) ]
    reduced over the last axis.
    """
    weighted_anchor = a * xt
    weights = weighted_anchor / np.sum(weighted_anchor, axis=-1, keepdims=True)
    a_x = a * x
    terms = a_x - weights * (np.log(a_x) - np.log(weights))
    return terms.sum(axis=-1)

def EM_surrogate(x, a, xt):
    """Placeholder for an EM-derived surrogate of f around ``xt``; not implemented yet.

    Raises
    ------
    NotImplementedError
        Always. It raises rather than returning None so that an accidental call fails here
        instead of producing a confusing error later (e.g. inside a plotting call).
    """
    # TODO: implement with the same (x, a, xt) signature as the other surrogates.
    raise NotImplementedError("EM_surrogate is not implemented yet")

def Bregman_surrogate(x, a, xt, L=1):
    """Bregman surrogate of f(x) = -log(<a, x>) + <a, x> around ``xt``, up to a constant.

    It is the linearization of f at ``xt`` plus ``L`` times the Bregman divergence of the
    Burg entropy h(x) = -sum(log x_i). Only the x-dependent terms are kept:
        <a, x> - <a, x> / <a, xt>  +  L * sum_i ( -log x_i + x_i / xt_i )
    Because constant terms are dropped, shift the result to agree with f at ``x = xt``
    before comparing it with f.

    Parameters
    ----------
    x : candidate point(s); reduced over the last axis.
    a : coefficient vector, broadcast against ``x``.
    xt : anchor point of the surrogate.
    L : weight of the Bregman-divergence term (default 1, the value previously hard-coded).
    """
    ax = np.sum(a * x, axis=-1)
    axt = np.sum(a * xt, axis=-1)
    linear_part = ax - ax / axt
    divergence_part = np.sum(-np.log(x) + x / xt, axis=-1)
    return linear_part + L * divergence_part

def Weighted_Bregman_surrogate(x, a, xt, L=1):
    """Weighted Bregman surrogate of f around ``xt``, up to a constant.

    It works like the plain Bregman surrogate, but the Burg-entropy divergence is weighted by
    u = a*xt / sum(a*xt), the Jensen weights fixed at ``xt``:
        <a, x> - <a, x> / <a, xt>  +  L * sum_i u_i ( -log x_i + x_i / xt_i )
    Because constant terms are dropped, shift the result to agree with f at ``x = xt``
    before comparing it with f.

    Parameters
    ----------
    x : candidate point(s); reduced over the last axis.
    a : coefficient vector, broadcast against ``x``.
    xt : anchor point of the surrogate.
    L : weight of the divergence term (default 1, the value previously hard-coded).
    """
    ax = np.sum(a * x, axis=-1)
    axt = np.sum(a * xt, axis=-1)
    ut = a * xt
    ut = ut / np.sum(ut, axis=-1, keepdims=True)
    linear_part = ax - ax / axt
    divergence_part = np.sum(-ut * np.log(x) + ut * x / xt, axis=-1)
    return linear_part + L * divergence_part


In [None]:
# Evaluate f and each surrogate on the 2-D grid, all anchored at the point xt.
fxk = f(xt, a)
Z = f(X, a)
Z_MU = MU_surrogate(X, a, xt)
# Shift the Bregman-type surrogates so each one equals f(xt) at x = xt.
Z_Bregman = Bregman_surrogate(X, a, xt) - Bregman_surrogate(xt, a, xt) + fxk
Z_Weighted_Bregman = (
    Weighted_Bregman_surrogate(X, a, xt)
    - Weighted_Bregman_surrogate(xt, a, xt)
    + fxk
)

In [None]:
tuple(arr.shape for arr in (Z, X1))  # the surface must have the same shape as the meshgrid arrays

In [None]:
def subplot3d(ax, X1, X2, Z, title=None):
    """Draw Z over the (X1, X2) grid as a colored 3-D surface on ``ax``, with a colorbar.

    Parameters
    ----------
    ax : 3-D axes (created with ``projection='3d'``) to draw into.
    X1, X2 : 2-D coordinate arrays, e.g. from ``np.meshgrid``.
    Z : 2-D array of heights with the same shape as X1 and X2.
    title : optional axes title; left unset when None.

    Returns
    -------
    The surface artist returned by ``ax.plot_surface``.
    """
    surf = ax.plot_surface(X1, X2, Z, cmap=plt.cm.coolwarm,
                           linewidth=0, antialiased=False)
    # Ten evenly spaced z ticks; a format string is turned into a StrMethodFormatter.
    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter('{x:.02f}')
    if title is not None:
        ax.set_title(title)
    # Attach the colorbar to this axes explicitly rather than relying on pyplot's "current" axes.
    ax.figure.colorbar(surf, ax=ax, shrink=0.5, aspect=5)
    return surf

# One 3-D panel per function: the objective and the two surrogates anchored at xt.
surface_panels = [
    ("$f(x)$", Z),
    ("MU surrogate", Z_MU),
    ("Bregman surrogate", Z_Bregman),
]

plt.figure(figsize=(18, 5))
for panel_index, (panel_title, surface) in enumerate(surface_panels, start=1):
    ax = plt.subplot(1, len(surface_panels), panel_index, projection='3d')
    subplot3d(ax, X1, X2, surface)
    ax.set_title(panel_title)
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")

plt.show()

In [None]:


plt.figure(figsize=(8, 8))
rstride = 2  # draw every 2nd grid row/column so the overlaid wireframes stay readable
cstride = 2
ax = plt.subplot(projection='3d')

ax.plot_wireframe(X1, X2, Z, rstride=rstride, cstride=cstride, color="black", label="$f(x)$")
ax.plot_wireframe(X1, X2, Z_MU, rstride=rstride, cstride=cstride, color="green", label="MU surrogate")
ax.plot_wireframe(X1, X2, Z_Bregman, rstride=rstride, cstride=cstride, color="blue", label="Bregman surrogate")
# scatter3D's 4th positional parameter is `zdir`, not a marker, so the marker is passed by keyword.
ax.scatter3D(xt[:, :, 0], xt[:, :, 1], fxk[0], marker="o", color="red", label="$x^t$", linewidth=6)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.legend();

In [None]:
nx = 500
lx = 2
y = 0.5  # fixed value of the second coordinate: the plots below are a 1-D slice at x2 = y

xt = np.array([0.5, y])
a = np.array([0.2, 0.8])

# Two leading singleton axes so xt and a broadcast against the (nx, 1, 2) grid X below.
xt = np.expand_dims(xt, axis=[0, 1])
a = np.expand_dims(a, axis=[0, 1])

# linspace guarantees exactly nx points in (0, lx]; a float-step arange does not.
x1 = np.linspace(lx / nx, lx, nx)
# The slice has a single x2 value; this no longer depends on ly/ny from an earlier cell.
x2 = np.array([y])
X1, X2 = np.meshgrid(x1, x2)
# Transposing gives shape (nx, 1, 2), so functions of X come out as (nx, 1) columns that
# plt.plot(x1, ...) accepts directly.
X = np.array([X1, X2]).T

In [None]:
# Evaluate f and each surrogate along the 1-D slice, all anchored at the point xt.
fxk = f(xt, a)
z = f(X, a)
z_MU = MU_surrogate(X, a, xt)
# Shift the Bregman-type surrogates so each one equals f(xt) at x = xt.
z_Bregman = Bregman_surrogate(X, a, xt) - Bregman_surrogate(xt, a, xt) + fxk
z_Weighted_Bregman = (
    Weighted_Bregman_surrogate(X, a, xt)
    - Weighted_Bregman_surrogate(xt, a, xt)
    + fxk
)

In [None]:
def argmin(x, z):
    """Return ``(x[k], z[k])``, where ``k`` is the flat index of the smallest entry of ``z``.

    For a column-shaped ``z`` such as (n, 1), the flat index equals the row index, so it
    lines up with a length-n ``x``; in that case the returned z value is a length-1 row.
    """
    flat_index = int(np.argmin(z))
    x_best = x[flat_index]
    z_best = z[flat_index]
    return x_best, z_best

In [None]:
SHOW_WEIGHTED_BREGMAN = False  # set True to also overlay the weighted Bregman surrogate

fig, ax = plt.subplots(figsize=(6, 4))

ax.plot(x1, z, color="black", label="$f(x)$")
ax.plot(x1, z_MU, color="green", label="MU surrogate")
ax.plot(*argmin(x1, z_MU), "x", color="green", label="MU minimum")
ax.plot(x1, z_Bregman, color="blue", label="Bregman surrogate")
ax.plot(*argmin(x1, z_Bregman), "x", color="blue", label="Bregman minimum")
if SHOW_WEIGHTED_BREGMAN:
    ax.plot(x1, z_Weighted_Bregman, color="orange", label="Weighted Bregman surrogate")
    ax.plot(*argmin(x1, z_Weighted_Bregman), "x", color="orange", label="Weighted Bregman minimum")

# Anchor point: first coordinate of xt, with f(xt) as height.
ax.plot(xt[0, 0, :1], fxk, "o", color="red", label="$x^t$")
ax.set_ylim([1, 2])
ax.set_xlim([0, 2])
ax.set_xlabel("$x_1$")
ax.set_ylabel("value")
ax.set_title(f"Slice at $x_2 = {y}$")
ax.legend(loc=1);