In [None]:
import numpy as np
from matplotlib import animation, pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.linalg import inv
import scipy.stats as stats
import scipy.interpolate as interp
from scipy.optimize import minimize, check_grad, approx_fprime
from sympy import *

# import warnings
# warnings.filterwarnings('ignore')

from sympy.interactive import printing
printing.init_printing(use_latex=True)

In [None]:
# --- Simulation 1 -------------------------------------------------------------
# Latent signal y, activity a = Da^{-1} y, smooth motion m drawn from a GP prior,
# and two noisy channels: rfp = m + noise, gcamp = m*a + noise.
# NOTE(review): no random seed is set, so every run draws different data.
N = 100

# sigma_y / sigma_r / sigma_g are variances: noise draws are scaled by their sqrt.
sigma_y = 1
mu_y = 2*np.ones(N) #4 + 2*np.sin(np.arange(N)/5 + 2)
y = mu_y + np.random.randn(N)*np.sqrt(sigma_y)

# Da = identity with -alpha on the first sub-diagonal, so a = Da^{-1} y solves
# a[i] - alpha*a[i-1] = y[i], i.e. a accumulates y with decay factor alpha.
alpha = 0.9
Da = np.eye((N))
for i in range(N-1):
    Da[i+1,i] = -alpha

Da_inv = np.linalg.inv(Da)
a = Da_inv @ y


# Motion prior: mean mu_m, squared-exponential covariance exp(-(i-j)^2 / lambda_m).
mu_m = 5.0*np.ones((N))
lambda_m = 25
sigma_m = np.zeros((N,N))

# Fill entries with j <= i and mirror them, producing a symmetric matrix.
for i in np.arange(N):
    for j in np.arange(i+1):
        sigma_m[i][j] = np.exp(-(i - j)**2/lambda_m)
        sigma_m[j][i] = np.exp(-(i - j)**2/lambda_m)

        
# Add 0.01 diagonal jitter before inverting (presumably for numerical conditioning).
sigma_m = sigma_m + np.diag(np.ones(N) * 0.01)
Sm_inv = np.linalg.inv(sigma_m)
        
m = np.random.multivariate_normal(mu_m, sigma_m)

sigma_r = 1
sigma_g = 1

rfp = m + np.random.randn(N)*np.sqrt(sigma_r)
gcamp = m*a + np.random.randn(N)*np.sqrt(sigma_g)

# Initial guess for y: the channel ratio estimates a, and Da maps it back to y.
y_guess = Da @ (gcamp/rfp)

In [None]:
# NOTE(review): overlays the activity a with y_guess, which estimates y (not a) —
# confirm this comparison is intended.
plt.plot(a)
plt.plot(y_guess)
plt.show()

In [None]:
def Dalpha_negloglike(y, g, r, sr, sg, sy, Sm_inv, Da_inv, mu_m, mu_y):
    """Negative log posterior of the latent signal ``y`` with the motion ``m`` integrated out.

    Generative model (as simulated in the first cell)::

        y ~ N(mu_y, sy * I),    a = Da_inv @ y
        m ~ N(mu_m, Sm)
        r = m + N(0, sr * I)
        g = a * m + N(0, sg * I)

    Completing the square in ``m`` gives, up to terms that do not depend on ``y``::

        log p(r, g | y) = 0.5 * b^T P^{-1} b - 0.5 * log|P|
        P = I / sr + diag(a)^2 / sg + Sm_inv
        b = r / sr + diag(a) @ g / sg + Sm_inv @ mu_m

    This is the same quadratic form that ``step1`` uses. All of ``b`` (including the
    ``r / sr + Sm_inv @ mu_m`` part, whose contribution still depends on ``y`` through
    ``P``) enters the quadratic, with the ``1/sg`` scaling applied to ``g``.

    Parameters
    ----------
    y : (N,) array, the quantity being optimised.
    g, r : (N,) arrays, observed green and red channels.
    sr, sg, sy : variances of the red noise, green noise and the prior on y.
    Sm_inv : (N, N) precision matrix of the motion prior.
    Da_inv : (N, N) matrix mapping y to the activity a.
    mu_m, mu_y : (N,) prior means of m and y.

    Returns
    -------
    float
        Negative log posterior with y-independent constants dropped.
    """
    N = len(y)
    a = Da_inv @ y

    # Posterior precision of m and the linear coefficient of the completed square.
    P = np.eye(N) / sr + np.diag(a**2 / sg) + Sm_inv
    b = r / sr + a * g / sg + Sm_inv @ mu_m

    # 0.5 * b^T P^{-1} b via a linear solve rather than an explicit inverse.
    quad = 0.5 * b @ np.linalg.solve(P, b)
    # Gaussian prior on y; the constant -0.5 * mu_y^T mu_y / sy is dropped.
    prior = -(1.0 / (2 * sy)) * y @ y + (1.0 / sy) * y @ mu_y
    sign, logdet = np.linalg.slogdet(P)
    log_norm = -0.5 * sign * logdet

    return -(quad + prior + log_norm)

In [None]:
# Fit y by minimising Dalpha_negloglike from the ratio-based starting point.
# NOTE(review): no method/jac is given, so scipy uses its default method with a
# finite-difference gradient over all N components — each evaluation inverts an
# N x N matrix, so this is slow for large N.
q = minimize(Dalpha_negloglike, y_guess, args=(gcamp,rfp,sigma_r, sigma_g, sigma_y, Sm_inv, Da_inv, mu_m, mu_y))#, method='CG', options={'maxiter':10000})

In [None]:
# y: blue = true, red = ratio-based initial guess, cyan = optimised estimate.
plt.plot(y, color = 'blue')
plt.plot(y_guess, color = 'red')
plt.plot(q.x, color = 'cyan')
plt.show()

In [None]:
# Same comparison in activity space: blue = true a, red = raw channel ratio,
# cyan = optimised y mapped through Da_inv.
plt.plot(a, color = 'blue')
plt.plot(gcamp/rfp, color = 'red')
plt.plot(Da_inv @ q.x, color = 'cyan')
plt.show()

In [None]:
# Scratch: .T is a no-op on a 1-D array, so this is just sum(qq**2).
qq = np.arange(4)
np.dot(qq,qq.T)

In [None]:
# --- Simulation 2 -------------------------------------------------------------
# Activity a is now drawn directly from a GP prior (no AR filter) and the channel
# noise variances are smaller (0.2).
# NOTE(review): no random seed; this cell rebinds N, mu_m, sigma_m, m, a, rfp and
# gcamp from simulation 1, so re-running earlier cells afterwards changes them.
N = 100

mu_m = 3.0*np.ones((N))
lambda_m = 10
sigma_m = np.zeros((N,N))

mu_a = 3.0*np.ones((N))
lambda_a = 10
sigma_a = np.zeros((N,N))

# Symmetric squared-exponential covariances for motion and activity.
for i in np.arange(N):
    for j in np.arange(i+1):
        sigma_m[i][j] = np.exp(-(i - j)**2/lambda_m)
        sigma_m[j][i] = np.exp(-(i - j)**2/lambda_m)
        sigma_a[i][j] = np.exp(-(i - j)**2/lambda_a)
        sigma_a[j][i] = np.exp(-(i - j)**2/lambda_a)

# 0.01 diagonal jitter on both covariance matrices.
sigma_m = sigma_m + np.diag(np.ones(N) * 0.01)        
sigma_a = sigma_a + np.diag(np.ones(N) * 0.01)        
        
m = np.random.multivariate_normal(mu_m, sigma_m)
a = np.random.multivariate_normal(mu_a, sigma_a)

# Noise variances (draws are scaled by sqrt).
sigma_r = 0.2
sigma_g = 0.2

rfp = m + np.random.randn(N)*np.sqrt(sigma_r)
gcamp = m*a + np.random.randn(N)*np.sqrt(sigma_g)

In [None]:
def step1(a, g, r, sr, sg, Sm, Sa, mu_m, mu_a):
    """Negative log posterior of the activity vector ``a`` with the motion integrated out.

    Prior ``a ~ N(mu_a, Sa)``; likelihood from ``r = m + noise(sr)``,
    ``g = a*m + noise(sg)`` with ``m ~ N(mu_m, Sm)`` marginalised by completing
    the square. Terms that do not depend on ``a`` are dropped.

    Parameters
    ----------
    a : (N,) activity being optimised.
    g, r : (N,) observed green / red channels.
    sr, sg : noise variances of r and g.
    Sm, Sa : (N, N) prior covariances of motion and activity.
    mu_m, mu_a : (N,) prior means of motion and activity.
    """
    size = len(a)
    Sm_inv = inv(Sm)

    dev = a - mu_a
    log_prior = -0.5 * dev.T @ inv(Sa) @ dev

    # Linear coefficient and precision of the Gaussian in m after completing the square.
    lin = r/sr + np.diag(a) @ g / sg + Sm_inv @ mu_m
    precision = np.eye(size)/sr + np.diag(a**2 / sg) + Sm_inv
    quad = 0.5 * lin.T @ inv(precision) @ lin

    sgn, log_det = np.linalg.slogdet(precision)
    log_norm = -0.5 * sgn * log_det

    return -(log_prior + quad + log_norm)

def step2(sigmas, a, g, r, Sm, Sa, mu_m, mu_a, verbose=False):
    """Negative log marginal likelihood of the noise variances ``(sr, sg)``.

    The motion ``m ~ N(mu_m, Sm)`` is integrated out of ``r = m + noise(sr)``,
    ``g = a*m + noise(sg)`` with the activity ``a`` held fixed. Terms that do not
    depend on ``(sr, sg)`` (e.g. ``log|Sm|``, ``mu_m^T Sm^{-1} mu_m``) are dropped.

    Parameters
    ----------
    sigmas : pair ``(sr, sg)`` of noise variances (optimised).
    a : (N,) fixed activity.
    g, r : (N,) observed green / red channels.
    Sm : (N, N) motion prior covariance.
    Sa, mu_a : unused; accepted for signature parity with ``step1`` / ``step12``.
    mu_m : (N,) motion prior mean.
    verbose : bool, default False
        If True, print the individual terms and the total on every call. Off by
        default because ``minimize`` evaluates the objective many times.

    Returns
    -------
    float
        Negative log marginal likelihood (constants dropped).
    """
    N = len(a)
    sr, sg = sigmas
    Sm_inv = inv(Sm)

    # Gaussian normalisers and data quadratics of r | m and g | m.
    p1 = -0.5 * (N*np.log(sr) + N*np.log(sg) + (r.T @ r)/sr + (g.T @ g)/sg)
    # Completing the square in m: b is the linear coefficient, P the precision.
    b = r/sr + np.diag(a) @ g / sg + Sm_inv @ mu_m
    P = np.eye(N)/sr + np.diag(a**2 / sg) + Sm_inv
    p2 = 0.5 * b.T @ np.linalg.solve(P, b)
    (sign, logdet) = np.linalg.slogdet(P)
    p3 = -0.5 * sign * logdet

    if verbose:
        print(p1, p2, p3, -(p1 + p2 + p3))
    return -(p1 + p2 + p3)

def step12(sigmas_a, g, r, Sm, Sa, mu_m, mu_a):
    """Joint negative log posterior of the noise parameters and the activity vector.

    ``sigmas_a`` packs ``[sr, sg, a_0, ..., a_{N-1}]`` so a single ``minimize`` call
    can fit both. It combines the prior ``a ~ N(mu_a, Sa)`` with the likelihood of
    ``r``/``g`` after integrating out ``m ~ N(mu_m, Sm)``.

    NOTE(review): the normaliser is written as ``0.5*N*log(s**2)`` (= ``N*log|s|``),
    which stays finite if the unconstrained optimiser makes ``sr``/``sg`` negative.
    The ``r.T@r / sr`` and ``I / sr`` terms still flip sign in that case, so
    confirm whether bounds were intended.
    """
    sr, sg = sigmas_a[0], sigmas_a[1]
    a = sigmas_a[2:]
    size = len(g)
    Sm_inv = inv(Sm)

    dev = a - mu_a
    log_prior_a = -0.5 * dev.T @ inv(Sa) @ dev
    log_noise = -0.5 * (0.5*size*np.log(sr**2) + 0.5*size*np.log(sg**2)
                        + (r.T @ r)/sr + (g.T @ g)/sg)

    lin = r/sr + np.diag(a) @ g / sg + Sm_inv @ mu_m
    precision = np.eye(size)/sr + np.diag(a**2 / sg) + Sm_inv
    quad = 0.5 * lin.T @ inv(precision) @ lin

    sgn, log_det = np.linalg.slogdet(precision)
    log_norm = -0.5 * sgn * log_det

    return -(log_prior_a + log_noise + quad + log_norm)
    

In [None]:
# NOTE(review): noisy_a is only referenced by a commented-out plot below.
noisy_a = a + np.random.randn(N)*0.5
# MAP estimate of the activity a, starting from the channel ratio gcamp/rfp.
q = minimize(step1, gcamp/rfp, args=(gcamp,rfp,sigma_r, sigma_g, sigma_m, sigma_a, mu_m, mu_a))#, method='CG', options={'maxiter':10000})

In [None]:
# Joint fit of [sr, sg, a...], starting from the channel variances and the ratio.
sigmas = np.hstack(([np.var(rfp), np.var(gcamp)], gcamp/rfp))
q3 = minimize(step12, sigmas, args=(gcamp,rfp,sigma_m, sigma_a, mu_m, mu_a))#, method='CG', options={'maxiter':10000})

In [None]:
# Fit the noise variances (sr, sg) with the activity fixed to the TRUE simulated a;
# L-BFGS-B bounds keep both variances >= 1e-4.
q2 = minimize(step2, [0.01,0.01], args=(a,gcamp,rfp, sigma_m, sigma_a, mu_m, mu_a), method='L-BFGS-B', bounds=((.0001,None),(.0001,None)))#, method='CG', options={'maxiter':10000})

In [None]:
# True activity (default colour) vs the step1 estimate (red).
plt.plot(a)
# plt.plot(noisy_a, color= 'green')
plt.plot(q.x, color = 'red')
plt.show()

In [None]:
q2

In [None]:
np.var(gcamp)

In [None]:
# Scratch cell.
qq = np.arange(10)
qq[2:]

In [None]:
rfp.T @ rfp

In [None]:
plt.plot(m)
plt.show()

In [None]:
np.eye(3)*4

In [None]:
# Symbolic stationarity equation for the activity a in the precision form:
# expr = -1/2*log(vr + vg*a^2) + 1/2*(vr*r + vg*g*a)^2/(vr + vg*a^2) (per sample).
# NOTE(review): this rebinds a, r, g (numpy arrays above) to sympy symbols.
# Rational(1/2) relies on 1/2 == 0.5 being exact in binary; Rational(1, 2) is clearer.
a, v_r, v_g, r, g = symbols("a v_r v_g r g")
expr = -Rational(1/2)*log(a**2*v_g + v_r) + Rational(1/2)*(v_r*r + v_g*g*a)**2/(a**2*v_g + v_r)
deriv = diff(expr,a)
numer = fraction(factor(deriv))[0]

In [None]:
roots = solve(Eq(numer,0),a)

In [None]:
# NOTE(review): roots[2] depends on the order in which sympy returns the roots —
# confirm it is the intended branch.
f_compl = lambdify((v_r, v_g, r, g), roots[2], "numpy")

def f(v_r, v_g, r, g):
    """Evaluate the lambdified root ``roots[2]`` (``f_compl``) of the activity equation.

    The scalars are cast to ``complex`` and the arrays to complex dtype before the
    call, presumably so the radicals in the closed-form root stay defined for
    negative arguments; callers take ``np.real`` of the result.
    """
    scalar_args = complex(v_r), complex(v_g)
    array_args = r.astype(complex), g.astype(complex)
    return f_compl(*scalar_args, *array_args)

In [None]:
def loglike(sigma, r, g, a, hyper, verbose=False):
    """Negative log posterior of the noise precisions ``sigma = (vr, vg)``.

    The per-sample data terms

        -1/2*(vg*g^2 + vr*r^2) - 1/2*log(vr + vg*a^2) + 1/2*(vr*r + vg*g*a)^2/(vr + vg*a^2)

    appear to come from integrating a per-sample motion out of
    ``r ~ N(m, 1/vr)``, ``g ~ N(a*m, 1/vg)`` under a flat prior. The ``T/2*log``
    normalisers are folded into ``logL0`` together with a Gamma-shaped prior
    ``(ar - 1)*log(vr) - br*vr`` (and likewise for vg).

    Parameters
    ----------
    sigma : pair ``(vr, vg)`` of noise precisions.
    r, g : 1-D arrays, red and green channels.
    a : 1-D array, current activity estimate (length T).
    hyper : ``(ar, ag, br, bg)`` prior parameters.
    verbose : bool, default False
        If True, print ``vr, vg``, the prior term and the data term on each call.
        Off by default because ``minimize`` and the heat-map sweep call this
        objective many times.

    Returns
    -------
    float
        Negative log posterior (additive constants dropped).
    """
    vr, vg = sigma
    ar, ag, br, bg = hyper
    T = len(a)

    logL0 = -(br*vr + bg*vg) + (ar - 1.0 + T/2)*np.log(vr) + (ag - 1.0 + T/2)*np.log(vg)
    logL1 = np.sum(-0.5*(vg*g**2 + vr*r**2))
    logL2 = np.sum(-0.5*np.log(vr + vg*a**2))
    logL3 = np.sum(0.5*((vr*r + vg*g*a)**2)/(vr + vg*a**2))

    all_logL = logL0 + logL1 + logL2 + logL3
    if verbose:
        print(vr , vg, logL0, logL1 + logL2 + logL3)
    return -all_logL

def loglike_grad(sigma, r, g, a, hyper):
    """Analytic gradient of the hyper-prior ``loglike`` objective w.r.t. ``(vr, vg)``.

    Returns the NEGATED gradient of the log posterior, matching ``loglike``, which
    returns the negated log posterior, so the pair can be passed to
    ``scipy.optimize.minimize`` as ``fun`` / ``jac``.

    Returns
    -------
    np.ndarray, shape (2,)
        ``[-dlogL/dvr, -dlogL/dvg]``.
    """
    vr, vg = sigma
    ar, ag, br, bg = hyper
    T = len(a)

    # Shared pieces of the per-sample completed-square term.
    denom = vr + vg*a**2
    num = vr*r + vg*g*a

    d_vr = (-br + (ar - 1.0 + T/2)/vr
            + np.sum(-0.5*r**2)
            + np.sum(-0.5*(1/denom))
            + np.sum(0.5*(denom*(2*vr*r**2 + 2*r*vg*g*a) - num**2)/denom**2))

    d_vg = (-bg + (ag - 1.0 + T/2)/vg
            + np.sum(-0.5*g**2)
            + np.sum(-0.5*(a**2/denom))
            + np.sum(0.5*(denom*(2*vr*r*g*a + 2*vg*(g**2)*(a**2)) - (a**2)*num**2)/denom**2))

    return np.array([-d_vr, -d_vg])

In [None]:
# --- Simulation 3 -------------------------------------------------------------
# Deterministic sinusoidal activity and motion; noise is N(0, 1/nu), so nu_r and
# nu_g are precisions (matching the vr/vg parameterisation of loglike).
# NOTE(review): no random seed is set.
N = 10000

activity = np.array([np.sin(i/10.0) for i in range(N)]) + 3
motion = 3*np.array([np.sin(i/10.0 + np.pi/4) for i in range(N)]) + 6
nu_r = 50.0
nu_g = 50.0

# (ar, ag, br, bg) for the Gamma-shaped prior terms in loglike.
hyper = [1.0 ,1.0, 1.0, 1.0]

rfp = motion + np.random.randn(N)/np.sqrt(nu_r)
gcamp = motion*activity + np.random.randn(N)/np.sqrt(nu_g)

In [None]:
# Alternate: activity from the closed-form root, then precisions by L-BFGS-B.
# NOTE(review): the trailing `break` means the loop body runs only once.
vr_guess = [50.0]
vg_guess = [50.0]
activity_guess = []

for i in range(2):
    
    act = np.real(f(vr_guess[-1], vg_guess[-1], rfp, gcamp))
    q = minimize(loglike, [vr_guess[-1], vg_guess[-1]], jac=loglike_grad, args=(rfp,gcamp,act,hyper), 
                 method='L-BFGS-B', bounds=((.0001,None),(.0001,None)))
    
    vr_guess.append(q.x[0])
    vg_guess.append(q.x[1])
    activity_guess.append(act)
    break


In [None]:
q

In [None]:
# Estimated activity (red) vs ground truth (blue).
plt.plot(act, color='red')
plt.plot(activity, color='blue')
plt.show()

In [None]:
# Objective at fixed (vr, vg) = (50, 50) over a grid of prior parameters:
# rows vary ar = ag = i, columns vary br = bg = j.
heat = np.zeros((50,50))

minval = -100
maxval = 1

vr_test = np.linspace(minval,maxval,50)
vg_test = np.linspace(minval,maxval,50)

for i_ind, i in enumerate(vr_test):
    for j_ind, j in enumerate(vg_test):
        test_hyper=[i,i,j,j]
        heat[i_ind,j_ind] = loglike([50.0,50.0], rfp,gcamp,act,test_hyper)

In [None]:
plt.imshow(heat, cmap='hot', extent=[minval,maxval,maxval,minval], interpolation='none')
plt.colorbar()
plt.show()

In [None]:
# NOTE(review): mid-notebook import; it shadows sympy's `gamma` from the star
# import. xx is not used in the visible cells.
from scipy.stats import gamma

xx = np.arange(0,1,.001)

In [None]:
# Finite-difference check of loglike_grad at a hard-coded point.
check_grad(loglike, loglike_grad, [ 4.43415611,  4.06254183], rfp,gcamp,act,hyper)

In [None]:
# NOTE(review): rebinds the global `hyper`; later cells that use it see these values.
hyper = [5.0 ,5.0, 1.0, 1.0]

approx_fprime(np.array([ 4.43415611,  4.06254183]), loglike, .0000001, rfp,gcamp,act, hyper)

In [None]:
loglike_grad([ 4.43415611,  4.06254183], rfp, gcamp, act, hyper)

In [None]:
# NOTE(review): verbatim repeat of the earlier symbolic cells.
a, v_r, v_g, r, g = symbols("a v_r v_g r g")
expr = -Rational(1/2)*log(a**2*v_g + v_r) + Rational(1/2)*(v_r*r + v_g*g*a)**2/(a**2*v_g + v_r)
deriv = diff(expr,a)
numer = fraction(factor(deriv))[0]

In [None]:
roots = solve(Eq(numer,0),a)

In [None]:
f_compl = lambdify((v_r, v_g, r, g), roots[2], "numpy")

def f(v_r, v_g, r, g):
    """Evaluate the lambdified root ``f_compl`` for precisions ``v_r``, ``v_g``.

    NOTE(review): identical redefinition of an earlier ``f``; the later one wins.
    Inputs are promoted to complex (scalars via ``complex``, arrays via
    ``astype``) before calling ``f_compl``; callers take ``np.real`` of the result.
    """
    v_r_c = complex(v_r)
    v_g_c = complex(v_g)
    r_c = r.astype(complex)
    g_c = g.astype(complex)
    return f_compl(v_r_c, v_g_c, r_c, g_c)

In [None]:
def loglike(sigma, r, g, a):
    """Negative log-likelihood of the noise precisions ``sigma = (vr, vg)`` (no prior).

    Per sample, the summed terms are

        -(vg*g^2 + vr*r^2) + (vr*r + vg*g*a)^2/(vr + vg*a^2) + log(vr*vg/(vr + vg*a^2)),

    halved and negated. The simulated data use noise / sqrt(nu), so ``vr`` and
    ``vg`` play the role of precisions.
    """
    vr, vg = sigma
    denom = vr + vg*a**2
    per_sample = (-(vg*g**2 + vr*r**2)
                  + ((vr*r + vg*g*a)**2)/denom
                  + np.log((vr*vg)/denom))
    return -(0.5*np.sum(per_sample))

def loglike_grad(sigma, r, g, a):
    """Analytic gradient of the prior-free precision ``loglike`` w.r.t. ``(vr, vg)``.

    Returns the NEGATED gradient, ``[-dlogL/dvr, -dlogL/dvg]``, to pair with
    ``loglike`` (which returns the negated log-likelihood) in ``minimize``.
    """
    vr, vg = sigma
    denom = vr + vg*a**2
    num = vr*r + vg*g*a

    per_vr = (-r**2
              + (denom*(2*vr*r**2 + 2*r*vg*g*a) - num**2)/denom**2
              + (vg*a**2)/(vr*denom))
    per_vg = (-g**2
              + (denom*(2*vr*r*g*a + 2*vg*(g**2)*(a**2)) - (a**2)*num**2)/denom**2
              + vr/(vg*denom))

    return np.array([-(0.5*np.sum(per_vr)), -(0.5*np.sum(per_vg))])

In [None]:
# Finite-difference check of the prior-free gradient (act/rfp/gcamp from the
# previous simulation).
check_grad(loglike, loglike_grad, [10.0,10.0], rfp,gcamp,act)

In [None]:
# --- Simulation 4: same signals, N = 100, noise precisions 100 ------------------
# NOTE(review): no random seed is set.
N = 100

activity = np.array([np.sin(i/10.0) for i in range(N)]) + 3
motion = 3*np.array([np.sin(i/10.0 + np.pi/4) for i in range(N)]) + 6
nu_r = 100
nu_g = 100

rfp = motion + np.random.randn(N)/np.sqrt(nu_r)
gcamp = motion*activity + np.random.randn(N)/np.sqrt(nu_g)

In [None]:
# NOTE(review): minimize is given the TRUE `activity`, not the estimate `act`, so
# the activity estimate never feeds back into the precision fit — confirm this
# ground-truth check is intended.
vr_guess = [100.0]
vg_guess = [100.0]
activity_guess = []

for i in range(2):
    
    act = np.real(f(vr_guess[-1], vg_guess[-1], rfp, gcamp))
    q = minimize(loglike, [vr_guess[-1], vg_guess[-1]], jac=loglike_grad, args=(rfp,gcamp,activity))
    
    vr_guess.append(q.x[0])
    vg_guess.append(q.x[1])
    activity_guess.append(act)


In [None]:
q

In [None]:
# Estimated activity (green) vs ground truth (blue).
plt.plot(act, color='green')
plt.plot(activity, color = 'blue')
#plt.plot(rfp, color='red')

plt.tight_layout()
plt.show()

In [None]:
# Negative log-likelihood over a (vr, vg) grid: rows index vr, columns vg.
heat = np.zeros((50,50))

minval = 50
maxval = 1250

vr_test = np.linspace(minval,maxval,50)
vg_test = np.linspace(minval,maxval,50)

for i_ind, i in enumerate(vr_test):
    for j_ind, j in enumerate(vg_test):
        heat[i_ind,j_ind] = loglike([i,j], rfp,gcamp,act)

In [None]:
plt.imshow(heat, cmap='hot', extent=[minval,maxval,maxval,minval], interpolation='none')
plt.colorbar()
plt.show()

In [None]:
# Variance-form model with a motion mean mu and variance s_m: g ~ N(a*mu, a^2*s_m + s_g).
# NOTE(review): `r` is declared but does not appear in expr.
a, s_r, s_g, s_m, mu, r, g = symbols("a s_r s_g s_m mu r g")
expr = -Rational(1/2)*log(a**2*s_m + s_g) - Rational(1/2)*(g - a*mu)**2/(a**2*s_m + s_g)
deriv = diff(expr,a)
numer = fraction(factor(deriv))[0]
roots = solve(Eq(numer,0),a)

In [None]:
f_compl = lambdify((s_g, s_m, mu, r, g), roots[2], "numpy")

def f(s_g, s_m, mu, r, g):
    """Evaluate ``f_compl`` (lambdified over ``(s_g, s_m, mu, r, g)``) with complex inputs.

    NOTE(review): ``r`` is forwarded for signature parity; the symbolic expression it
    was built from (g-only variance model) does not contain ``r``.
    """
    scalar_args = complex(s_g), complex(s_m), complex(mu)
    return f_compl(*scalar_args, r.astype(complex), g.astype(complex))

In [None]:
def loglike(sigma, rfp, gcamp, act):
    """Negative log-likelihood of ``sigma = (sr, sg, sm, mu)`` given an activity trace.

    ``sr``, ``sg`` are the red / green noise variances, ``sm`` the motion variance and
    ``mu`` the motion mean. Each rfp sample is scored as ``N(mu, sr + sm)`` and each
    gcamp sample as ``N(act*mu, act**2*sm + sg)``; the ``log(2*pi)`` constants are
    dropped.

    NOTE(review): rfp and gcamp share the same motion draw, so they are correlated
    (covariance ``act*sm``). Scoring them with independent marginals ignores that;
    confirm this approximation is intended.
    """
    sr, sg, sm, mean_m = sigma[0], sigma[1], sigma[2], sigma[3]
    red_var = sr + sm
    green_var = act**2 * sm + sg

    terms = (len(act)*np.log(red_var)
             + np.sum(((rfp - mean_m)**2)/red_var)
             + np.sum(np.log(green_var))
             + np.sum(((gcamp - act*mean_m)**2)/green_var))
    return -np.real(-0.5*terms)

In [None]:
# --- Simulation 5: iid motion N(10, sigma_m2), variances given directly ----------
# NOTE(review): no random seed. `mu = 10` rebinds the sympy symbol `mu` from the
# previous symbolic cell; motion below uses the literal 10 rather than `mu`.
N = 500

sigma_r2 = 1.8
sigma_g2 = 1.5
sigma_m2 = 2.3
mu = 10

activity = np.array([np.sin(i/10.0) for i in range(N)]) + 3
motion = np.sqrt(sigma_m2)*np.random.randn(N) + 10


rfp = motion + np.sqrt(sigma_r2)*np.random.randn(N)
gcamp = motion*activity + np.sqrt(sigma_g2)*np.random.randn(N)

In [None]:
# Start from the true parameter values.
# NOTE(review): minimize is given the TRUE `activity` rather than `act`, and the
# trailing `break` means the loop body runs only once.
sr_guess = [1.8]
sg_guess = [1.5]
sm_guess = [2.3]
mu_guess = [10]
activity_guess = []

for i in range(5):
    
    act = np.real(f(sg_guess[-1], sm_guess[-1], mu_guess[-1], rfp, gcamp))
    q = minimize(loglike, [sr_guess[-1], sg_guess[-1], sm_guess[-1], mu_guess[-1]],
                 args=(rfp,gcamp,activity))
    
    sr_guess.append(q.x[0])
    sg_guess.append(q.x[1])
    sm_guess.append(q.x[2])
    mu_guess.append(q.x[3])
    activity_guess.append(act)
    break


In [None]:
# Residual model: g - a*r ~ N(0, a^2*s_r^2 + s_g^2), so s_r and s_g are
# standard deviations here.
a, s_r, s_g, r, g = symbols("a s_r s_g r g")
expr = -Rational(1/2)*log(a**2*s_r**2 + s_g**2) - Rational(1/2)*(g - a*r)**2/(a**2*s_r**2 + s_g**2)
deriv = diff(expr,a)
numer = fraction(factor(deriv))[0]

In [None]:
roots = solve(Eq(numer,0),a)

In [None]:
roots[2]

In [None]:
f_compl = lambdify((s_r, s_g, r, g), roots[2], "numpy")

def f(s_r, s_g, r, g):
    """Evaluate the residual-model activity root ``f_compl`` with complex inputs.

    ``s_r``, ``s_g`` enter the preceding symbolic expression squared (standard
    deviations). They are cast to ``complex``, and ``r``, ``g`` to complex dtype,
    before the call; callers take ``np.real`` of the result.
    """
    complex_args = (complex(s_r), complex(s_g), r.astype(complex), g.astype(complex))
    return f_compl(*complex_args)

In [None]:
def loglike(sigma, rfp, gcamp, act):
    """Negative log-likelihood of the channel residual ``gcamp - act*rfp``.

    With ``rfp = m + s_r*eps_r`` and ``gcamp = act*m + s_g*eps_g`` the motion cancels,
    and the residual is ``N(0, act**2 * s_r**2 + s_g**2)``. This is the same variance
    as in the symbolic ``expr`` used to build the activity root ``f``.
    ``sigma = (s_r, s_g)`` are therefore noise STANDARD DEVIATIONS, the same
    quantities the fitting loop feeds back into ``f``. (A previous version used
    ``act**2/s_r + 1/s_g``, a precision parameterisation inconsistent with ``f``.)

    Parameters
    ----------
    sigma : pair ``(s_r, s_g)``; only their squares enter, so the sign is irrelevant.
    rfp, gcamp : 1-D arrays of channel samples.
    act : 1-D array, activity trace (same length).

    Returns
    -------
    float
        Negative log-likelihood with the ``0.5*log(2*pi)`` constants dropped.
    """
    s_r, s_g = sigma[0], sigma[1]
    var = act**2 * s_r**2 + s_g**2
    all_loglike = -0.5*(np.log(var) + ((gcamp - act*rfp)**2)/var)
    return -np.real(np.sum(all_loglike))

In [None]:
# --- Simulation 6: residual model ------------------------------------------------
# NOTE(review): sigma_r / sigma_g here are standard deviations (they scale randn
# directly), unlike the variances of the same names in simulation 1. No seed set.
N = 100

activity = np.array([np.sin(i/10.0) for i in range(N)]) + 3
motion = 3*np.array([np.sin(i/10.0 + np.pi/4) for i in range(N)]) + 6
sigma_r = 0.8
sigma_g = 0.5

rfp = motion + sigma_r*np.random.randn(N)
gcamp = motion*activity + sigma_g*np.random.randn(N)

In [None]:
# Activity estimate with both noise parameters set to 1.
act = np.real(f(1,1, rfp, gcamp))

In [None]:
# Alternate activity estimate and noise-parameter fit, starting from the truth.
# NOTE(review): the trailing `break` means the loop body runs only once.
sr_guess = [0.8]
sg_guess = [0.5]
activity_guess = []

for i in range(5):
    
    act = np.real(f(sr_guess[-1], sg_guess[-1], rfp, gcamp))
    q = minimize(loglike, [sr_guess[-1], sg_guess[-1]], args=(rfp,gcamp,act))
    
    sr_guess.append(q.x[0])
    sg_guess.append(q.x[1])
    activity_guess.append(act)
    break


In [None]:
q

In [None]:
# Squared residuals with the true activity (green) vs the estimate (blue).
plt.plot((gcamp - activity*rfp)**2, color = 'green')
plt.plot((gcamp - act*rfp)**2, color = 'blue')
plt.show()

In [None]:
((gcamp - act*rfp)**2).shape

In [None]:
# NOTE(review): `sr_diff` is not defined anywhere in the visible notebook, so this
# cell raises NameError on Restart & Run All (perhaps diff(expr, s_r) was meant).
solve(Eq(sr_diff,0),s_g)

In [None]:
plt.plot(act, color='green')
plt.plot(activity, color = 'blue')
# plt.plot(act2, color = 'red')
plt.show()

In [None]:
plt.plot(rfp)
plt.plot(gcamp)
plt.show()

In [None]:
def root2(s_r, s_g, r, g):
    s_r, s_g, r, g = complex(s_r),complex(s_g),r.astype(complex),g.astype(complex)
    alpha = g*r/(3*s_r**2)
    beta = (-1 + np.sqrt(3)*1j)/2.0
    gamma = (1/(3*s_r**4))*(-g**2 * s_r**2 + r**2 * s_g**2 + s_g**2 * s_r**2)
    delta = -alpha**2 + gamma
    eta = alpha*(alpha**2 - (3*s_g**2)/(2*s_r**2) - (3*gamma)/2)
    
#     sub_mu = eta + (delta**3 + eta**2)**(1/2)
#     new_mu = np.zeros((sub_mu.shape))
    
#     for j,k in enumerate(sub_mu):
#         if k < 0:
#             new_mu_val = [i for i in np.roots([1,0,0,-k]) if np.imag(i) < 0][0]
#             sub_mu[j] = new_mu_val
#             print('Positive')
        
#     mu = beta*sub_mu**(1/3)

    mu = beta * (eta + (delta**3 + eta**2)**(1/2))**(1/3)
#     print("a",alpha,"b",beta,"g",gamma,"d",delta,"e",eta,"m",mu)
    
#     return alpha, beta, gamma, delta, eta, mu, -alpha + delta/mu - mu
    return -alpha + delta/mu - mu

In [None]:
heat = np.zeros((50,50))

sr_test = np.linspace(0.1,1.0,50)
sg_test = np.linspace(0.1,1.0,50)

for i_ind, i in enumerate(sr_test):
    for j_ind, j in enumerate(sg_test):
        heat[i_ind,j_ind] = loglike([i,j], rfp,gcamp,act)
        break
    break

In [None]:
plt.imshow(heat, cmap='hot')
plt.colorbar()
plt.show()

In [None]:
sr_guess

In [None]:
deriv

In [None]:
diff(expr,s_r)

In [None]:
rfp[38:42]

In [None]:
gcamp.shape

In [None]:
alpha, beta, gamma, delta, eta, mu, ans = root2(0.2,1.8,rfp[40],gcamp[40])

In [None]:
alpha2, beta2, gamma2, delta2, eta2, mu2, ans2 = root2(0.2,1.8,rfp[41],gcamp[41])

In [None]:
alpha3, beta3, gamma3, delta3, eta3, mu3, ans3 = root2(0.2,1.8,rfp[39],gcamp[39])

In [None]:
# Argument of the cube root inside root2 for sample 40.
((delta**3 + eta**2)**(1/2) + eta)

In [None]:
# Cube root of the NEGATED quantity, negated again: presumably to obtain the real
# cube root when the argument is a negative real (then rotated by beta2).
mu22 = (-(-((delta2**3 + eta2**2)**(1/2) + eta2))**(1./3))*beta2

In [None]:
mu33 = (-(-((delta3**3 + eta3**2)**(1/2) + eta3))**(1./3))*beta3

In [None]:
# NOTE(review): mu11 is only assigned in a later cell, so this raises NameError
# when the notebook is run top to bottom.
-alpha + delta/mu11 - mu11

In [None]:
# Cube roots of a hard-coded value (apparently copied from an earlier output).
np.roots([1,0,0,-2306.5378570205648])

In [None]:
# 13.21255665 appears to be the real cube root of 2306.5378... from the cell above.
mu11 = (13.21255665 +0.j)*beta

In [None]:
mu11

In [None]:
# NOTE(review): the list comprehension yields numpy complex scalars; the first [0]
# picks one, and the second [0] indexes a scalar and raises IndexError.
ww = np.array([1,2,3 + 4j, 4 - 5j, 6])
[i for i in ww if np.imag(i) < 0][0][0]

In [None]:
roots[2]