In [None]:
import numpy as np

class LinearDiscriminantAnalysis:
    """Linear Discriminant Analysis classifier with a shared (pooled) covariance.

    Each class ``k`` is modeled as a Gaussian with its own mean ``mu_k`` and a
    covariance ``Sigma`` shared by all classes. The discriminant for class ``k``

        delta_k(x) = x @ Sigma^{-1} mu_k - 0.5 * mu_k @ Sigma^{-1} mu_k + log(pi_k)

    is linear in ``x``; the predicted class is the one with the largest value.

    Attributes set by :meth:`fit` (per-class entries follow the sorted order
    of ``classes_``):
        classes_: sorted unique labels seen in ``y``.
        priors_: empirical class frequencies, shape ``(n_classes,)``.
        means_: class mean vectors, shape ``(n_classes, n_features)``.
        covariance_: pooled within-class covariance,
            shape ``(n_features, n_features)``.
        coef_: ``Sigma^{-1} mu_k`` per class, shape ``(n_classes, n_features)``.
        intercept_: per-class bias, shape ``(n_classes,)``.
    """

    def __init__(self):
        # Class priors; None until fit() has been called.
        self.priors_ = None

    @staticmethod
    def _class_means(X, y):
        """Return the mean of ``X`` over each class, one row per class in
        ``np.unique(y)`` order."""
        return np.array([X[y == c].mean(axis=0) for c in np.unique(y)])

    def _class_cov(self, X, y, priors):
        """Return the prior-weighted average of the per-class covariances.

        Each class covariance is the biased (divide-by-``n_k``) estimate around
        ``self.means_``. ``priors`` and ``self.means_`` are indexed by class
        position in ``np.unique(y)``, not by raw label value. With empirical
        priors ``n_k / n`` this equals the total within-class scatter over ``n``.
        """
        n_features = X.shape[1]
        cov = np.zeros((n_features, n_features))
        for idx, c in enumerate(np.unique(y)):
            Xc = X[y == c]
            diff = Xc - self.means_[idx]
            # Normalise the class scatter by its size so the prior weights the
            # class exactly once (otherwise class size is counted twice).
            cov += priors[idx] * (diff.T @ diff) / Xc.shape[0]
        return cov

    def fit(self, X, y):
        """Estimate class priors, means and the shared covariance from data.

        Args:
            X: array-like of shape ``(n_samples, n_features)``.
            y: array-like of ``n_samples`` class labels (any sortable values,
                not necessarily ``0..K-1``).

        Returns:
            self, to allow chaining such as ``LDA().fit(X, y).predict(X2)``.

        Raises:
            ValueError: if ``X`` is not 2-D, is empty, or its sample count
                does not match ``y``.
            numpy.linalg.LinAlgError: if the pooled covariance is singular.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y).reshape(-1)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got shape {X.shape}")
        if X.shape[0] == 0:
            raise ValueError("cannot fit on an empty dataset")
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X has {X.shape[0]} samples but y has {y.shape[0]} labels"
            )

        # Map arbitrary labels to positions 0..K-1 so every per-class array
        # shares the same ordering as classes_.
        self.classes_, y_idx = np.unique(y, return_inverse=True)
        self.means_ = self._class_means(X, y)
        self.priors_ = np.bincount(
            y_idx.ravel(), minlength=len(self.classes_)
        ) / len(y)
        self.covariance_ = self._class_cov(X, y, self.priors_)

        # Sigma is symmetric, so solve(Sigma, M.T).T == M @ Sigma^{-1};
        # solving avoids forming the explicit inverse.
        self.coef_ = np.linalg.solve(self.covariance_, self.means_.T).T
        # Row-wise mu_k . (Sigma^{-1} mu_k), without building a K x K matrix.
        self.intercept_ = (
            -0.5 * np.sum(self.means_ * self.coef_, axis=1)
            + np.log(self.priors_)
        )
        return self

    def decision_function(self, X):
        """Return the linear discriminant scores, shape ``(n_samples, n_classes)``."""
        X = np.asarray(X, dtype=float)
        return X @ self.coef_.T + self.intercept_

    def predict(self, X):
        """Return the class label with the highest discriminant score per row."""
        return self.classes_[np.argmax(self.decision_function(X), axis=1)]