# LDA 线性判别分析 (Linear Discriminant Analysis)

<a
   href='https://www.bilibili.com/video/BV1k5411T7u5/?spm_id_from=333.337.search-card.all.click&vd_source=36a09b82c71dff08b1927ad8f0c1d3e3'
   target='_blank'>关于西瓜书上LDA推导式的由来</a>
   

In [1]:
import numpy as np

In [2]:
class LDA():
    """Two-class Fisher linear discriminant analysis (LDA).

    Learns a projection vector ``w = Sw^+ (u0 - u1)``, where ``u0``/``u1`` are
    the class means and ``Sw`` is the within-class scatter (sum of the two
    per-class covariance matrices). Along ``w``, class 0 projects above
    class 1.

    Attributes:
        w: projection vector of shape (n_features,), or None before fitting.
        threshold: decision threshold on the projected value ``X @ w``;
            samples projecting strictly below it are labelled 1. ``fit`` sets
            it to the midpoint of the two projected class means.
    """

    def __init__(self):
        self.w = None
        # 0.0 reproduces the historical "h < 0" rule if ``w`` is assigned by hand.
        self.threshold = 0.0

    def calculate_covariance_matrix(self, X, Y=None):
        """Return the covariance matrix (normalized by 1/m) between X and Y.

        Args:
            X: array of shape (m, n_features).
            Y: optional array of shape (m, k); defaults to X, which gives the
                covariance matrix of X with itself.

        Returns:
            Array of shape (n_features, k).
        """
        X = np.asarray(X, dtype=float)
        m = X.shape[0]
        X = X - np.mean(X, axis=0)
        # ``is None`` (not ``== None``, which is elementwise on arrays); a
        # missing Y means "covariance of X with itself".
        if Y is None:
            Y = X
        else:
            Y = np.asarray(Y, dtype=float)
            Y = Y - np.mean(Y, axis=0)
        return 1 / m * np.matmul(X.T, Y)

    # LDA fitting procedure
    def fit(self, X, y):
        """Fit the projection vector ``w`` and decision ``threshold``.

        Args:
            X: array-like of shape (n_samples, n_features).
            y: array-like of shape (n_samples,) with labels 0 or 1.

        Raises:
            ValueError: if either class has no samples.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)

        # 1 - split samples by class
        X0 = X[y == 0]
        X1 = X[y == 1]
        if len(X0) == 0 or len(X1) == 0:
            raise ValueError("LDA.fit requires samples from both class 0 and class 1")

        # 2 - class means
        u0, u1 = X0.mean(0), X1.mean(0)
        mean_diff = np.atleast_1d(u0 - u1)

        # 3 - within-class scatter matrix
        sigma0 = self.calculate_covariance_matrix(X0)
        sigma1 = self.calculate_covariance_matrix(X1)
        Sw = sigma0 + sigma1

        # 4 - singular value decomposition: Sw = U @ diag(S) @ Vh
        U, S, Vh = np.linalg.svd(Sw)

        # 5 - pseudo-inverse Sw^+ = Vh^T @ diag(S)^+ @ U^T. ``S`` is 1-D, so it
        # is turned into a diagonal matrix first; ``pinv`` discards negligible
        # singular values, so a singular Sw does not produce inf/nan.
        Sw_ = Vh.T.dot(np.linalg.pinv(np.diag(S))).dot(U.T)

        # 6 - projection vector
        self.w = Sw_.dot(mean_diff)
        # Midpoint of the projected class means: the equal-prior decision
        # boundary. Without it, predictions only work for centered data.
        self.threshold = float((u0 + u1).dot(self.w) / 2)

    # Project data onto the discriminant direction
    def transform(self, X, y):
        """Fit on (X, y), then return the 1-D projection ``X @ w``."""
        self.fit(X, y)
        X_transform = X.dot(self.w)
        return X_transform

    # LDA classification
    def predict(self, X):
        """Return a list of predicted labels (0 or 1) for the rows of X.

        A sample is labelled 1 when its projection onto ``w`` falls below
        ``threshold`` (class 1 projects lower along ``w``), otherwise 0.
        """
        h = np.asarray(X, dtype=float).dot(self.w)
        return [int(v) for v in (h < self.threshold)]