In [9]:
import numpy as np
import matplotlib.pyplot as plt

def gradient_ascent_sticky(ptc, pt0, alpha=None):
    """Take one gradient step that "sticks" to an edge once a coordinate is non-positive.

    The three coordinates are treated as a point on the plane r + p + s = 1.
    (Presumably rock/paper/scissors probabilities -- confirm against the caller.)

    Behaviour by starting position (first matching rule wins):

    * ``rc <= 0``: move only along the edge ``r = 0`` (p changes, s = 1 - p).
    * ``pc <= 0``: move only along the edge ``p = 0`` (s changes, r = 1 - s).
    * ``sc <= 0``: move only along the edge ``s = 0`` (r changes, p = 1 - r).
    * otherwise: free step in the plane; the result is NOT clamped here.

    On an edge, a step that overshoots an endpoint is clamped to that vertex
    and ``flag`` is set to True.

    Parameters
    ----------
    ptc : sequence of 3 numbers
        Current point ``(rc, pc, sc)`` that is being updated.
    pt0 : sequence of 3 numbers
        Reference point ``(r0, p0, s0)``; the step direction depends only on it.
    alpha : float, optional
        Step size. When omitted, the module-level ``alpha`` is used, which keeps
        existing two-argument calls working. Passing it explicitly is preferred
        because it removes the dependence on notebook cell execution order.

    Returns
    -------
    tuple
        ``(rc, pc, sc, flag)`` -- the updated point and whether a vertex clamp
        happened.

    Raises
    ------
    NameError
        If ``alpha`` is not passed and no module-level ``alpha`` exists.
    """
    rc, pc, sc = ptc[0], ptc[1], ptc[2]
    r0, p0, s0 = pt0[0], pt0[1], pt0[2]

    if alpha is None:
        # Legacy behaviour: fall back to the notebook-global step size.
        try:
            alpha = globals()["alpha"]
        except KeyError:
            raise NameError("name 'alpha' is not defined") from None

    flag = False

    if rc <= 0:
        # Stuck on the edge r = 0: only p moves, s follows from the constraint.
        rc = 0
        grad_y = 2*r0 - s0 - p0
        pc += grad_y*alpha
        sc = 1 - pc
        if pc < 0:
            pc, sc, flag = 0, 1, True
        if sc < 0:
            sc, pc, flag = 0, 1, True
    elif pc <= 0:
        # Stuck on the edge p = 0: only s moves, r follows from the constraint.
        pc = 0
        grad_z = 2*p0 - r0 - s0
        sc += grad_z*alpha
        rc = 1 - sc
        if sc < 0:
            sc, rc, flag = 0, 1, True
        if rc < 0:
            rc, sc, flag = 0, 1, True
    elif sc <= 0:
        # Stuck on the edge s = 0: only r moves, p follows from the constraint.
        sc = 0
        grad_x = 2*s0 - p0 - r0
        rc += grad_x*alpha
        pc = 1 - rc
        if rc < 0:
            rc, pc, flag = 0, 1, True
        if pc < 0:
            pc, rc, flag = 0, 1, True
    else:
        # Strictly inside: unconstrained step in the plane (no clamping here).
        grad_x = s0 + r0 - 2*p0
        grad_y = 2*r0 - s0 - p0
        rc += grad_x*alpha
        pc += grad_y*alpha
        sc = 1 - rc - pc

    return rc, pc, sc, flag

def gradient_ascent(ptc, pt0, alpha):
    """Take one gradient step of size ``alpha`` and keep the result on the triangle.

    The point lives on the plane r + p + s = 1. A free step is tried first.
    If it would make one coordinate negative, the step is split in two:

    1. travel the fraction of the step (``lr``) needed to reach the edge where
       that coordinate is zero;
    2. spend the remaining step size (``alpha - lr``) moving along that edge,
       clamping to a vertex if the edge end is overshot.

    Only the first violated coordinate is handled, in the order r, p, s.

    Parameters
    ----------
    ptc : sequence of 3 numbers
        Current point ``(rc, pc, sc)``; expected to satisfy rc + pc + sc = 1
        with non-negative entries (not checked).
    pt0 : sequence of 3 numbers
        Reference point ``(r0, p0, s0)`` that determines the step direction.
    alpha : float
        Total step size.

    Returns
    -------
    tuple
        ``(rc, pc, sc)`` after the step.
    """
    rc, pc, sc = ptc[0], ptc[1], ptc[2]
    r0, p0, s0 = pt0[0], pt0[1], pt0[2]

    grad_x = s0 + r0 - 2*p0
    grad_y = 2*r0 - s0 - p0

    # Tentative unconstrained step.
    rc_p = rc + grad_x*alpha
    pc_p = pc + grad_y*alpha
    sc_p = 1 - rc_p - pc_p

    if rc_p < 0:
        # Crosses r = 0: reach the edge, then continue in p along it.
        lr = -rc/grad_x
        rc_p = 0
        pc_p = pc + grad_y*lr
        pc_p += grad_y*(alpha - lr)
        sc_p = 1 - pc_p
        if sc_p < 0 or pc_p >= 1:
            sc_p, pc_p = 0, 1
        elif pc_p < 0 or sc_p >= 1:
            pc_p, sc_p = 0, 1
    elif pc_p < 0:
        # Crosses p = 0: reach the edge, then continue in r along it.
        lr = -pc/grad_y
        rc_p = rc + grad_x*lr
        # Pin the hit coordinate to exactly 0, like the r and s branches do;
        # pc + grad_y*lr can otherwise leave a tiny floating-point residue.
        pc_p = 0
        rc_p += grad_x*(alpha - lr)
        sc_p = 1 - rc_p
        if sc_p < 0 or rc_p >= 1:
            sc_p, rc_p = 0, 1
        elif rc_p < 0 or sc_p >= 1:
            rc_p, sc_p = 0, 1
    elif sc_p < 0:
        # Crosses s = 0: reach the edge, then continue in r along it using the
        # edge-specific direction.
        lr = sc/(grad_x + grad_y)
        rc_p = rc + grad_x*lr
        sc_p = 0
        grad_edge = 2*s0 - r0 - p0
        rc_p += grad_edge*(alpha - lr)
        pc_p = 1 - rc_p
        if rc_p < 0 or pc_p >= 1:
            rc_p, pc_p = 0, 1
        elif pc_p < 0 or rc_p >= 1:
            pc_p, rc_p = 0, 1

    return rc_p, pc_p, sc_p

def gradient_ascent_infinite(ptc, pt0, alpha):
    """Take one unconstrained gradient step on the plane r + p + s = 1.

    No clipping is applied, so the returned coordinates may leave the
    [0, 1] range.

    Parameters
    ----------
    ptc : sequence of 3 numbers
        Current point ``(rc, pc, sc)``; only the first two entries enter the
        arithmetic, the third is recomputed from the constraint.
    pt0 : sequence of 3 numbers
        Reference point ``(r0, p0, s0)`` that determines the step direction.
    alpha : float
        Step size.

    Returns
    -------
    tuple
        ``(rc, pc, sc)`` after the step, with ``sc = 1 - rc - pc``.
    """
    cur_r, cur_p, _cur_s = ptc[0], ptc[1], ptc[2]
    ref_r, ref_p, ref_s = pt0[0], pt0[1], pt0[2]
    step = alpha

    # Step direction for the two free coordinates.
    dir_r = ref_s + ref_r - 2*ref_p
    dir_p = 2*ref_r - ref_s - ref_p

    next_r = cur_r + dir_r*step
    next_p = cur_p + dir_p*step
    # The third coordinate is whatever keeps the sum at one.
    next_s = 1 - next_r - next_p

    return next_r, next_p, next_s

In [10]:
%matplotlib qt
fig = plt.figure()
sub = fig.add_subplot(1, 1, 1)
sub.plot(np.array([-1/np.sqrt(2), 0, 1/np.sqrt(2), -1/np.sqrt(2)]), np.array([-1/np.sqrt(6), np.sqrt(2/3), -1/np.sqrt(6), -1/np.sqrt(6)]))
plt.show()

rot_matrix = np.array([[-np.sqrt(1/2), np.sqrt(1/2), 0], [-np.sqrt(1/6), -np.sqrt(1/6), np.sqrt(2/3)], [np.sqrt(1/3), np.sqrt(1/3), np.sqrt(1/3)]])
alpha = 0.01
# ra, pa, sa = (1/3, 1/6, 0.5)
# rb, pb, sb = (0.3, 0.4, 0.3)
# ra, pa, sa = (1/3, 1/6, 0.5)
# rb, pb, sb = (1/3, 5/9, 1/9)
# ra, pa, sa = (0.34, 0.12, 0.54)
# rb, pb, sb = (0.3, 0.2, 0.5)
# ra, pa, sa = (0.12, 0.34, 0.54)
# rb, pb, sb = (0.2, 0.3, 0.5)
ra, pa, sa = (0.3, 0.45, 0.25)
rb, pb, sb = (0.3, 0.4, 0.3)
p1 = np.matmul(rot_matrix, np.array([ra, pa, sa]))
p0 = np.matmul(rot_matrix, np.array([rb, pb, sb]))
sub.scatter(p1[0], p1[1])
# plt.annotate("i1", (p1[0], p1[1]))
sub.scatter(p0[0], p0[1])
# plt.annotate("i0", (p0[0], p0[1]))

for i in range(500):
    (ra_t, pa_t, sa_t) = gradient_ascent_infinite((ra, pa, sa), (rb, pb, sb), alpha)
    (rb_t, pb_t, sb_t) = gradient_ascent_infinite((rb, pb, sb), (ra, pa, sa), alpha)
    ra = ra_t
    rb = rb_t
    pa = pa_t
    pb = pb_t
    sa = sa_t
    sb = sb_t
    
#     (ra, pa, sa) = gradient_ascent_infinite((ra, pa, sa), (rb, pb, sb), alpha)
#     (rb, pb, sb) = gradient_ascent_infinite((rb, pb, sb), (ra, pa, sa), alpha)

# A's projections:
#     (rb_t_proj, pb_t_proj, sb_t_proj) = gradient_ascent_infinite((rb, pb, sb), (ra, pa, sa), alpha)
# B's projections:
#     (ra_t_proj, pa_t_proj, sa_t_proj) = gradient_ascent_infinite((ra, pa, sa), (rb, pb, sb), alpha)
#     (ra, pa, sa) = gradient_ascent_infinite((ra, pa, sa), (rb_t_proj, pb_t_proj, sb_t_proj), alpha)
#     (rb, pb, sb) = gradient_ascent_infinite((rb, pb, sb), (ra_t_proj, pa_t_proj, sa_t_proj), alpha)
    
    
    p1 = np.matmul(rot_matrix, np.array([ra, pa, sa]))
    p0 = np.matmul(rot_matrix, np.array([rb, pb, sb]))
    sub.scatter(p1[0], p1[1])
    sub.scatter(p0[0], p0[1])
    sub.plot(np.array([-1/np.sqrt(2), 0, 1/np.sqrt(2), -1/np.sqrt(2)]), np.array([-1/np.sqrt(6), np.sqrt(2/3), -1/np.sqrt(6), -1/np.sqrt(6)]))

#     print(i, rc, pc, sc, rc+pc+sc)

sub.scatter(0, 0)
plt.show()    