In the following we will use the derivations from the simple FA model
to obtain a probabilistic version of CCA.

In [2]:
%matplotlib inline

In [4]:
import numpy as np
import matplotlib.pyplot as plt

Let $X$ and $Y$ be two aligned data sets of dimensions $p_x \times n$ and $p_y \times n$.
We refer to the corresponding $i_{th}$ observations as $x_i$, $y_i$ -- these are column
vectors of dimensions $p_x$, and $p_y$ respectively. We consider a generative model under
which there exist latent variables $z_i \in \mathbb{R}^k$ with $i \in \{1, 2, \ldots n\}$, 
such that:

\begin{align}
p(x_i | z_i)  &= \mathcal{N} (W_x z_i + \mu_x, \Psi_x) \\
p(y_i | z_i)  &= \mathcal{N} (W_y z_i + \mu_y, \Psi_y) \\
p(z_i) &=  \mathcal{N}(0, I_{k})
\end{align}



Here $W_x \in \mathbb{R}^{p_x \times k}$, and $W_y \in \mathbb{R}^{p_y \times k}$.


The updates for this model follow from a factor model with

\begin{align}
\Lambda = \begin{bmatrix}
    W_x \\
    W_y
\end{bmatrix},
\Psi = \begin{bmatrix}
    \Psi_x & 0 \\
    0 & \Psi_y
    \end{bmatrix},\
 v_i = (x_i; y_i) \in \mathbb{R}^{p_x + p_y}
\end{align}

In particular, letting, 
\begin{align}
\beta = \Lambda^\top (\Lambda \Lambda^\top + \Psi)^{-1},
\end{align}

The latent expectation and covariance are given by the formulas

\begin{align}
\mathbb{E}[z | v_i] &= \beta v_i \\
\mathbb{E}[z z^\top | v_i] &= I_{k} - \beta \Lambda + \beta v_i v_i^\top \beta^\top,
\end{align}

Putting everything together,

\begin{align}
\Lambda^* & = \left(\sum_{i=1}^n v_i \mathbb{E}[z|v_i]^\top \right) \left(\sum_{i=1}^n \mathbb{E}[z z^\top | v_i]\right)^{-1} \\
\Psi^* & = \frac{1}{n} diag(\sum_{i=1}^n v_i v_i^\top - \Lambda^* \mathbb{E}[z|v_i] v_i^\top)
\end{align}

In [5]:
def gen_simple_dataset(p_x, p_y, k, n, sigma_x, sigma_y):
    
    X     = np.zeros((p_x, n))
    Y     = np.zeros((p_y, n))
    
    W_x   = np.random.random((p_x, k))
    W_y   = np.random.random((p_y, k))
    Z     = np.random.random((k, n))
    
    m_x   = np.dot(W_x, Z)
    m_y   = np.dot(W_y, Z)
    
    
    Psi_x = sigma_x * np.eye(p_x)
    Psi_y = sigma_y * np.eye(p_y)
    
    for i in range(n):
        X[:,i] = np.random.multivariate_normal(mean = m_x[:,i], cov = Psi_x)
        Y[:,i] = np.random.multivariate_normal(mean = m_y[:,i], cov = Psi_y)
        
    return X, Y, Z, W_x, W_y

In [6]:
def E_z_given_v_i_Murphy(Lambda, Psi, vi):
    
    beta = np.linalg.inv(np.dot(Lambda, Lambda.T) + Psi)
    beta = np.dot(Lambda.T, beta)
    
    return np.dot(beta, vi)

def E_zzT_give_v_i_Murphy(Lambda, Psi, vi):
    
    beta  = np.linalg.inv(np.dot(Lambda, Lambda.T) + Psi)
    beta  = np.dot(Lambda.T, beta)
    
    _, k = Lambda.shape
    
    bv    = np.dot(beta, vi)
    bvvb  = np.dot(bv, bv.T)
    
    return  np.eye(k) - np.dot(beta, Lambda) + bvvb


def E_z_given_v_i_Bishop(Lambda, Psi, vi):
    
    LT_P_L = np.dot(Lambda.T, np.dot(np.linalg.inv(Psi), Lambda)) #12.66, 12.67, 12.68
    G      = np.linalg.inv(np.eye(LT_P_L.shape[0]) + LT_P_L) 
    
    beta   = np.dot(G, np.dot(Lambda.T, np.linalg.inv(Psi))) 
    
    return np.dot(beta, vi)

def E_zzT_give_v_i_Bishop(Lambda, Psi, vi):
    
    LT_P_L = np.dot(Lambda.T, np.dot(np.linalg.inv(Psi), Lambda))
    G      = np.linalg.inv(np.eye(LT_P_L.shape[0]) + LT_P_L) 
    
    E_z    =  E_z_given_v_i_Bishop(Lambda, Psi, vi)
    
    return G + np.dot(E_z, E_z.T)

In [7]:
def M_step(Lambda, Psi, V):
    
    '''
    Psi   : (p_x + p_y) by (p_x + p_y)  
    Lambda: (p_x + p_y) by k 
    V     : (p_x + p_y) by n
    '''
    
    p, k         = Lambda.shape
    
    # update lambda
    # -------------------
    
    # These are the two terms in the Lambda update
    Lambda_new_1 = np.zeros((p, k))
    Lambda_new_2 = np.zeros((k, k))

    for i in range(n):
        
        # Expectation terms
        # The difference between Bishop and Murphy is that 
        # Bishop's derivation uses the Woodbury identity (see G&H)
        # while Murphy implementation just uses Numpy's built-in inverse
        # function.
        Exp_i        = E_z_given_v_i_Bishop(Lambda, Psi, V[:,i,None])
        Cov_i        = E_zzT_give_v_i_Bishop(Lambda, Psi, V[:,i,None])
        
        Lambda_new_1 += np.dot(V[:,i,None], Exp_i.T)
        Lambda_new_2 += Cov_i
    
    Lambda_star  = np.dot(Lambda_new_1, np.linalg.inv(Lambda_new_2)) 
    
    # update psi
    # -------------------
    Psi_new      = np.zeros(Psi.shape)
    for i in range(n):
        Exp_i        = E_z_given_v_i_Bishop(Lambda, Psi, V[:,i,None])
        Psi_new      = Psi_new + np.dot(V[:,i,None], V[:,i,None].T) - np.dot(Lambda_star, np.dot(Exp_i, V[:,i,None].T))

    Psi_star     = 1./n * np.diag(np.diag(Psi_new))
    
    return Lambda_star, Psi_star

In [9]:
#test
#generate dataset

p_x = 50
p_y = 30
k = 10
n = 1000
sigma_x = 1.0
sigma_y = 1.5

X, Y, Z, W_x, W_y = gen_simple_dataset(p_x, p_y, k, n, sigma_x, sigma_y)
print(X.shape)

(50, 1000)


In [7]:
#initialize
sigma_init = 0.5
W_x_init   = np.random.random((p_x, k))
W_y_init   = np.random.random((p_y, k))

Psi_x_init = sigma_x * np.eye(p_x)
Psi_y_init = sigma_y * np.eye(p_y)

In [8]:
V = np.concatenate((X, Y), axis=0)
V.shape

(80, 1000)

In [9]:
Lambda_init = np.concatenate((W_x_init, W_y_init), axis=0)
Lambda_init.shape

(80, 10)

In [10]:
Psi_init = np.block([[Psi_x_init, np.zeros((p_x, p_y))], [np.ones((p_y, p_x)), Psi_y_init]])
Psi_init.shape

(80, 80)

In [17]:
iters      = 100
Lambda_old = Lambda_init
Psi_old    = Psi_init

for _ in range(iters):
    print(_, " iteration")
    
    Lambda_new, Psi_new = M_step(Lambda_old, Psi_old, V)
    Lambda_old = Lambda_new
    Psi_old    = Psi_new


0  iteration
1  iteration
2  iteration
3  iteration
4  iteration
5  iteration
6  iteration
7  iteration
8  iteration
9  iteration
10  iteration
11  iteration
12  iteration
13  iteration
14  iteration
15  iteration
16  iteration
17  iteration
18  iteration
19  iteration
20  iteration
21  iteration
22  iteration
23  iteration
24  iteration
25  iteration
26  iteration
27  iteration
28  iteration
29  iteration
30  iteration
31  iteration
32  iteration
33  iteration
34  iteration
35  iteration
36  iteration
37  iteration
38  iteration
39  iteration
40  iteration
41  iteration
42  iteration
43  iteration
44  iteration
45  iteration
46  iteration
47  iteration
48  iteration
49  iteration
50  iteration
51  iteration
52  iteration
53  iteration
54  iteration
55  iteration
56  iteration
57  iteration
58  iteration
59  iteration
60  iteration
61  iteration
62  iteration
63  iteration
64  iteration
65  iteration
66  iteration
67  iteration
68  iteration
69  iteration
70  iteration
71  iteration
72

In [18]:
#Psi_new[:p_x, :p_x]
#Psi_new[p_x:, p_x:]

In [19]:
Z_est = np.zeros(Z.shape)
for i in range(Z.shape[1]):
    Z_est[:,i] = E_z_given_v_i_Bishop(Lambda_new, Psi_new, V[:,i])

In [20]:
Z[:5,:10]

array([[ 0.61800372,  0.67704729,  0.2996038 ,  0.93629477,  0.21101635,
         0.71937384,  0.84181225,  0.65311834,  0.89782692,  0.00519197],
       [ 0.66799676,  0.62337837,  0.03383087,  0.4436514 ,  0.31945389,
         0.60760997,  0.15943079,  0.12650056,  0.54836049,  0.61637373],
       [ 0.46972741,  0.476861  ,  0.69091939,  0.18208955,  0.91634602,
         0.13338135,  0.83745667,  0.21151824,  0.5781412 ,  0.23121   ],
       [ 0.34601919,  0.12237393,  0.83612163,  0.50926918,  0.44824805,
         0.8002443 ,  0.28162397,  0.32801907,  0.94693596,  0.88070894],
       [ 0.20771045,  0.62730002,  0.66165557,  0.48189134,  0.71685325,
         0.42670958,  0.62194144,  0.31624676,  0.5961284 ,  0.46497062]])

In [21]:
Z_est[:5,:10]

array([[-0.42169936,  0.30292641,  0.30286639,  1.15133226, -0.8116444 ,
         0.03960725, -0.07976182,  1.37369668,  0.0299327 , -0.19885088],
       [ 0.93851271,  0.72200708,  1.22059855,  1.02445259,  1.27681264,
         0.88396021, -0.41274453,  0.86941553, -0.62026521,  0.21252015],
       [ 0.16143269,  0.1501104 ,  0.30418442,  0.23048693,  0.82848035,
         0.22267908,  0.69063287,  0.25169303, -0.15031354,  0.89182338],
       [ 0.66459168,  0.23214191, -0.54130598,  0.73900007,  1.5053348 ,
         0.46537987, -0.33090441, -0.47185377,  0.81399271, -0.17601385],
       [-0.14242478,  0.81189462, -0.01998319,  1.58702391,  0.31319139,
         1.03841325,  1.81867443,  0.88254865,  0.9427338 ,  0.6998384 ]])

In [22]:
#fig, ax = plt.subplots(figsize=(16, 2))
#ax.imshow(Z[:,:30], interpolation='nearest')
#plt.show()
#fig, ax = plt.subplots(figsize=(16, 2))
#ax.imshow(Z_est[:,:30], interpolation='nearest')
#plt.tight_layout()

#fig, ax = plt.subplots(figsize=(16, 2))
#ax.imshow(Z_est[:,:30] - Z[:,:30], interpolation='nearest')
#plt.tight_layout()

In [23]:
np.linalg.norm(Z_est - Z)

70.774693707826302