In [1]:
import numpy as np
import pandas as pd

In [5]:
# Small labelled data set (feature columns plus 'Species') used to sanity-check
# the LDA implementation below; fields in beetles.txt are single-space separated.
beetles_df = pd.read_csv("beetles.txt", delimiter=" ")

In [10]:
class LinearDiscriminantAnalysis:
    """Gaussian linear discriminant analysis (LDA) classifier for DataFrames.

    Each class k is modelled as a multivariate normal with its own mean mu_k
    and a covariance matrix Sigma shared by all classes. A sample x is
    assigned to the class maximising the linear discriminant

        delta_k(x) = x^T Sigma^-1 mu_k - 1/2 mu_k^T Sigma^-1 mu_k + log(pi_k)

    where pi_k is the fraction of training rows that belong to class k.
    """

    def fit(self, df, label):
        """Estimate class means, class priors and the pooled covariance.

        Parameters
        ----------
        df : pandas.DataFrame
            Training data: the ``label`` column plus numeric feature columns.
        label : hashable
            Name of the column holding the class labels.

        Raises
        ------
        ValueError
            If there are not more rows than classes, in which case the pooled
            covariance (divided by N - K) cannot be estimated.
        """
        self._df = df
        self._label = label
        self._classes = df[label].unique()
        self._class_means = self._get_class_means()

        features = df.drop(columns=label)
        n_rows, n_classes = len(features), len(self._classes)
        if n_rows <= n_classes:
            raise ValueError(
                f"need more rows ({n_rows}) than classes ({n_classes}) "
                "to estimate the pooled covariance"
            )

        # LDA's shared covariance is estimated from each row's deviation from
        # its *own class mean* (pooled within-class covariance). The total
        # covariance, ``features.cov()``, would also absorb the spread between
        # the class means and distort the discriminant.
        centered = features - features.groupby(df[label]).transform("mean")
        self._cov_matrix = centered.T.dot(centered) / (n_rows - n_classes)

        # Invert once here rather than on every discriminant evaluation.
        self._cov_inv = np.linalg.inv(self._cov_matrix.to_numpy())

        # Class priors pi_k as proportions of the training rows.
        self._priors = df[label].value_counts(normalize=True)

    def predict(self, X):
        """Predict class labels for the rows of ``X``.

        Parameters
        ----------
        X : array-like
            Either a 2-D array/DataFrame of shape (n_samples, n_features) or a
            single sample given as a flat sequence of n_features values.
            Columns are matched to the training features by position, so they
            must be in the same order as in the DataFrame passed to ``fit``.

        Returns
        -------
        label or list of labels
            A single label when exactly one sample is given, otherwise a list
            of labels (an empty list for empty input).
        """
        X = np.asarray(X, dtype=float)
        if X.ndim == 1 and X.size:
            # A single sample passed as a flat vector -> one-row matrix.
            X = X.reshape(1, -1)
        if X.shape[0] == 0:
            return []

        classes = list(self._classes)
        means = self._class_means[classes].to_numpy()   # (n_features, n_classes)
        weights = self._cov_inv @ means                  # column k is Sigma^-1 mu_k
        offsets = -0.5 * np.sum(means * weights, axis=0) + np.log(
            self._priors.loc[classes].to_numpy()
        )
        scores = X @ weights + offsets                   # (n_samples, n_classes)

        # argmax returns the first maximum, i.e. ties go to the class that
        # appears first in ``self._classes`` (same as a strict ">" scan).
        predictions = [self._classes[i] for i in scores.argmax(axis=1)]

        if len(predictions) == 1:
            return predictions[0]
        return predictions

    def _get_class_means(self):
        """Return a DataFrame of per-class feature means.

        The index holds the feature names and there is one column per class
        label (in sorted order, as produced by ``groupby``).
        """
        features = self._df.drop(columns=self._label)
        return features.groupby(self._df[self._label]).mean().T

    def _discrim_func(self, X, k):
        """Evaluate the linear discriminant delta_k for the single sample ``X``.

        ``X`` is a flat sequence of feature values in training-column order;
        ``k`` is one of the class labels seen during ``fit``.
        """
        mu_k = self._class_means[k].to_numpy()
        weights = self._cov_inv @ mu_k
        return (
            np.asarray(X, dtype=float) @ weights
            - 0.5 * mu_k @ weights
            + np.log(self._priors.loc[k])
        )

In [11]:
# Train the classifier on every beetle; 'Species' is the class-label column.
lda = LinearDiscriminantAnalysis()
lda.fit(beetles_df, label='Species')

In [12]:
# Resubstitution check: classify the very rows the model was trained on.
full_X = beetles_df.drop(columns=['Species'])
lda.predict(full_X)

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2]