# Oja's learning rule

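<img src="img/oja.png" alt="Oja's learning rule" style="height: 300px">

The cells below apply the rule in batch form to mean-centred data: with $y_i = \mathbf{w}^\top \mathbf{x}_i$, each epoch performs $\mathbf{w} \leftarrow \mathbf{w} + \alpha \sum_i y_i\,(\mathbf{x}_i - y_i\,\mathbf{w})$.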

In [1]:
import numpy as np

In [2]:
def oja(X, alpha=np.float64(0.01), w=((-1.,), (0.,)), epochs=2, verbose=True):
    """Train a single linear neuron with Oja's learning rule (batch form).

    The data are mean-centred. Then, in every epoch, the per-sample Oja
    updates ``alpha * y_i * (x_i - y_i * w)`` (with ``y_i = x_i . w``) are
    summed over all samples and added to ``w`` in a single step.

    Parameters
    ----------
    X : array_like, shape (n_samples, n_features)
        Feature vectors, one per row.
    alpha : float, optional
        Learning rate.
    w : array_like, shape (n_features, 1) or (n_features,), optional
        Initial weights. The default ``[[-1.], [0.]]`` only fits 2-feature data.
    epochs : int, optional
        Number of training epochs (full passes over ``X``).
    verbose : bool, optional
        If True (default), print the intermediate quantities of every epoch.

    Returns
    -------
    np.ndarray, shape (n_features, 1)
        The weight vector after the last epoch.
    """
    X = np.asarray(X)
    # Mean-centre only; no variance scaling is applied.
    norm_X = X - np.mean(X, axis=0)
    if verbose:
        print('norm_X:', norm_X)
    # Accept a column vector or a flat vector; always work with a column.
    w = np.array(w).reshape(-1, 1)

    for epoch in range(epochs):
        if verbose:
            print('#################')
            print('EPOCH ', epoch)
            print('#################')

        # Neuron output for every sample: y_i = x_i . w, shape (n_samples, 1).
        y_hat = np.dot(norm_X, w)
        # Each sample minus its projection onto w: x_i - y_i * w.
        x_t_by_yw = norm_X - y_hat * w.T
        # Per-sample Oja update: alpha * y_i * (x_i - y_i * w).
        alpha_by_y_by_x_t_by_yw = alpha * y_hat * x_t_by_yw
        if verbose:
            print('y_hat:', y_hat)
            print('x.t-yw', x_t_by_yw)
            print('ny(x.t-yw)', alpha_by_y_by_x_t_by_yw)

        # Batch step: sum the per-sample updates and apply them at once.
        w = w + alpha_by_y_by_x_t_by_yw.sum(axis=0).reshape(-1, 1)
        if verbose:
            print('w:', w)
    return w

## Examples

In [3]:
# Toy dataset: six samples (rows) with two features (columns), as float64.
data = np.array(
    [[0, 1], [3, 5], [5, 4], [5, 6], [8, 7], [9, 7]],
    dtype=float,
)

In [4]:
# Two epochs of batch Oja updates on the toy data, starting from w = [-1, 0]^T.
oja(data, epochs=2, alpha=0.01, w=[[-1], [0]])

norm_X: [[-5. -4.]
 [-2.  0.]
 [ 0. -1.]
 [ 0.  1.]
 [ 3.  2.]
 [ 4.  2.]]
#################
EPOCH  0
#################
y_hat: [[ 5.]
 [ 2.]
 [ 0.]
 [ 0.]
 [-3.]
 [-4.]]
x.t-yw [[ 0. -4.]
 [ 0.  0.]
 [ 0. -1.]
 [ 0.  1.]
 [ 0.  2.]
 [ 0.  2.]]
ny(x.t-yw) [[ 0.   -0.2 ]
 [ 0.    0.  ]
 [ 0.   -0.  ]
 [ 0.    0.  ]
 [-0.   -0.06]
 [-0.   -0.08]]
w: [[-1.  ]
 [-0.34]]
#################
EPOCH  1
#################
y_hat: [[ 6.36]
 [ 2.  ]
 [ 0.34]
 [-0.34]
 [-3.68]
 [-4.68]]
x.t-yw [[ 1.36   -1.8376]
 [ 0.      0.68  ]
 [ 0.34   -0.8844]
 [-0.34    0.8844]
 [-0.68    0.7488]
 [-0.68    0.4088]]
ny(x.t-yw) [[ 0.086496   -0.11687136]
 [ 0.          0.0136    ]
 [ 0.001156   -0.00300696]
 [ 0.001156   -0.00300696]
 [ 0.025024   -0.02755584]
 [ 0.031824   -0.01913184]]
w: [[-0.854344  ]
 [-0.49597296]]
