In [2]:
import numpy as np

In [208]:
def normal(x, mu, Sigma):
    d = Sigma.shape[0]
    if d != len(mu):
        raise ValueError
    return np.exp(-(x-mu).T.dot(np.linalg.inv(Sigma)).dot(x-mu)/ 2) / np.sqrt((2*np.pi)**d * np.linalg.det(Sigma))

def v(a,b):
    u = np.array([a-b])
    return u.T.dot(u)

def loss(Y, z, mu, Sigma, alpha=0.5, lamb=1.0):
    return -sum([normal(y, mu, Sigma) for y in Y]) + lamb * alpha * normal(z, mu, Sigma)

def MLE(Y):
    mu = np.mean(Y)
    Sigma = np.var(Y)
    if type(Sigma) in (np.float32, np.float64):
        return (np.array([mu]), np.array([[Sigma]]))
    else:
        return (np.array([mu]), np.array([Sigma]))
    
def penMLE(Y, z, alpha=0.5, lamb=1.0):
    mu, Sigma = MLE(Y)
    for n_steps in range(10):
        gamma = lamb * alpha * normal(z, mu, Sigma)
        Sigma_new = (np.sum([v(y,mu) for y in Y]) - gamma*v(z,mu)) / (len(Y)-gamma)
        mu_new = (np.sum(Y)-gamma*z) / (len(Y)-gamma)
        if np.linalg.norm(mu_new-mu) < 1e-16 and np.linalg.norm(Sigma_new-Sigma, ord='fro') < 1e-16:
            break
        mu, Sigma = mu_new, Sigma_new
    print('steps', n_steps)
    return (mu_new, Sigma_new)

def check(Y, z, mu, Sigma, alpha=0.5, lamb=1.0, times=10):
    l = loss(Y, z, mu, Sigma, alpha=alpha, lamb=lamb)
    for _ in range(times):
        mu_shift = (np.random.rand(len(mu))-0.5)*0.00001
        Sigma_shift = (np.random.rand(len(mu), len(mu))-0.5)*0.00001
        l_rand = loss(Y, z, mu+mu_shift, Sigma+Sigma_shift, alpha=alpha, lamb=lamb)
        if l_rand < l-0.0001:
            print(mu+mu_shift, Sigma+Sigma_shift, l_rand)
            return False
    return True

In [187]:
Y = [np.array([y]) for y in np.random.normal(size=20)]
Y

[array([0.01934399]),
 array([-0.52152182]),
 array([0.43895457]),
 array([-0.41209333]),
 array([-0.8876564]),
 array([0.46572342]),
 array([1.22547919]),
 array([0.69125323]),
 array([-1.37494949]),
 array([-1.37951803]),
 array([-1.99603114]),
 array([0.07860545]),
 array([-0.03780545]),
 array([0.26214566]),
 array([0.18003701]),
 array([-0.24868734]),
 array([-0.36876575]),
 array([-0.97047288]),
 array([0.22908998]),
 array([1.48644767])]

In [209]:
z = np.array([2.0])
alpha = 0.5
lamb = 1.0
mu_MLE, Sigma_MLE = MLE(Y)
mu_penMLE, Sigma_penMLE = penMLE(Y, z, alpha=alpha, lamb=lamb)
print(mu_MLE, Sigma_MLE)
print(mu_penMLE, Sigma_penMLE)

steps 9
[-0.15602107] [[0.72102852]]
[-0.15701992] [[0.71920803]]


In [210]:
print(loss(Y, z, mu_MLE, Sigma_MLE, alpha=alpha, lamb=lamb))
print(loss(Y, z, mu_penMLE, Sigma_penMLE, alpha=alpha, lamb=lamb))
print(check(Y, z, mu_penMLE, Sigma_penMLE, alpha=alpha, lamb=lamb, times=40))

-6.651718775974659
-6.655803168194078
True
