# Fisher linear discriminant analysis

In [13]:
import numpy as np

In [64]:
def get_fisher_cost(X, y, w, verbose=True):
    """Fisher's LDA criterion J(w) for a two-class problem.

    Computes

        J(w) = (w . (mu_0 - mu_1))**2
               / sum_k sum_{x in class k} (w . (x - mu_k))**2

    i.e. the squared distance between the projected class means divided by
    the (unnormalised) within-class scatter of the projected samples.
    A larger value means ``w`` separates the two classes better.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Feature vectors.
    y : array-like, shape (n_samples,)
        Class label of each row of ``X``; must contain exactly two
        distinct labels. Classes are ordered as returned by ``np.unique``.
    w : array-like, shape (n_features,)
        Projection (weight) vector to evaluate.
    verbose : bool, default True
        If True, print the intermediate quantities (class means, between-
        and within-class scatter, class subsets, cost) as before.

    Returns
    -------
    float
        The Fisher cost J(w). If the within-class scatter is zero, NumPy
        float division yields ``inf``/``nan`` (with a RuntimeWarning).

    Raises
    ------
    ValueError
        If ``y`` does not contain exactly two distinct classes.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    w = np.asarray(w)

    classes = np.unique(y)
    if len(classes) != 2:
        raise ValueError(
            f"Fisher's two-class criterion needs exactly 2 classes, got {len(classes)}"
        )

    # Keep the per-class subsets in a plain list: classes generally differ in
    # size, and np.array() on such ragged arrays raises in NumPy >= 1.24.
    subsets = [X[y == c] for c in classes]
    mu = np.array([subset.mean(axis=0) for subset in subsets])

    # between-class scatter: squared projected distance of the class means
    sb = np.dot(w, mu[0] - mu[1]) ** 2

    # within-class scatter: squared projected deviations from each class mean
    sw = sum(np.sum(((subset - m) @ w) ** 2) for subset, m in zip(subsets, mu))

    cost = sb / sw

    if verbose:
        print('mu:', mu)
        print('sb: ', sb)
        print('subsets')
        print(subsets)
        print('sw:', sw)
        print('cost:', cost)

    return cost

## Examples

In [65]:
# Toy two-class dataset in 2-D: three samples of class 1, two of class 2.
X = np.array([[1, 2], [2, 1], [3, 3],   # class 1
              [6, 5], [7, 8]])          # class 2
y = np.array([1] * 3 + [2] * 2)

# Two candidate projection vectors whose Fisher costs are compared below.
w_1, w_2 = np.array([-1, 5]), np.array([2, -3])

In [66]:
get_fisher_cost(X=X, y=y, w=w_1)

mu: [[2.  2. ]
 [6.5 6.5]]
sb:  324.0
subsets
[array([[1, 2],
       [2, 1],
       [3, 3]])
 array([[6, 5],
       [7, 8]])]
sw: 140.0
cost: 2.3142857142857145


In [67]:
get_fisher_cost(X=X, y=y, w=w_2)

mu: [[2.  2. ]
 [6.5 6.5]]
sb:  20.25
subsets
[array([[1, 2],
       [2, 1],
       [3, 3]])
 array([[6, 5],
       [7, 8]])]
sw: 38.5
cost: 0.525974025974026


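Since w_1's Fisher cost is higher (≈2.31 vs ≈0.53), it is the more effective projection vector: it pushes the projected class means further apart relative to the spread of each class around its own mean.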

In [71]:
# Project every sample onto w_1 and print one scalar per line.
for projection in X @ w_1:
    print(projection)

9
3
12
19
33


In [72]:
# Project every sample onto w_2 and print one scalar per line.
for projection in X @ w_2:
    print(projection)

-4
1
-3
-3
-10


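Projecting onto w_1 makes the two classes linearly separable: class 1 maps to 9, 3, 12 and class 2 to 19, 33, so any threshold between 12 and 19 splits them. Projecting onto w_2 does not: class 1 maps to -4, 1, -3 and class 2 to -3, -10, so the value -3 occurs in both classes.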