In [1]:
from sklearn.datasets import *
from sklearn.model_selection import *
from sklearn.metrics import *
import numpy as np

In [2]:
# Load the iris dataset and pull out the feature matrix, the integer labels
# and the human-readable species names.
iris = load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# Hold out 30% of the samples for evaluation; the fixed seed keeps the split
# reproducible across kernel restarts.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=1,
)

In [6]:
class NaiveBayes:
    """Gaussian Naive Bayes classifier.

    Every feature is modelled as an independent normal distribution per class.
    Prediction picks the class with the largest log-posterior,
    ``log P(c) + sum_j log N(x_j | mean_cj, var_cj)``.

    Parameters
    ----------
    var_smoothing : float, default=1e-9
        Fraction of the largest feature variance of the training data that is
        added to every per-class variance. This keeps a feature that is
        constant within one class from producing a zero variance (and hence a
        division by zero), as long as at least one feature varies overall.

    Attributes (set by ``fit``)
    ---------------------------
    classes : ndarray of shape (n_classes,)
        Sorted unique labels seen during training.
    mean : ndarray of shape (n_classes, n_features)
        Per-class feature means.
    var : ndarray of shape (n_classes, n_features)
        Per-class feature variances, including the smoothing term.
    priors : ndarray of shape (n_classes,)
        Class frequencies in the training labels; they sum to 1.
    """

    def __init__(self, var_smoothing=1e-9):
        self.var_smoothing = var_smoothing

    def fit(self, X_train, y_train):
        """Estimate per-class means, variances and priors.

        Parameters
        ----------
        X_train : array-like of shape (n_samples, n_features)
        y_train : array-like of shape (n_samples,)
        """
        X_train = np.asarray(X_train, dtype=float)
        y_train = np.asarray(y_train)

        self.classes = np.unique(y_train)
        masks = [y_train == c for c in self.classes]

        # Smoothing is scaled by the largest overall feature variance so that
        # it stays negligible relative to the data's own spread.
        epsilon = self.var_smoothing * X_train.var(axis=0).max() if X_train.size else 0.0

        self.mean = np.array([X_train[m].mean(axis=0) for m in masks])
        self.var = np.array([X_train[m].var(axis=0) for m in masks]) + epsilon
        # Priors are relative to the training labels only; they must not
        # depend on any variable outside this method.
        self.priors = np.array([m.sum() / len(y_train) for m in masks])

    def predict(self, X_test):
        """Return the predicted label for every row of ``X_test``.

        Parameters
        ----------
        X_test : array-like of shape (n_samples, n_features)

        Returns
        -------
        ndarray of shape (n_samples,)
        """
        X_test = np.asarray(X_test, dtype=float)
        if X_test.size == 0:
            # Nothing to score; return an empty array of labels.
            return self.classes[:0]
        scores = self._joint_log_likelihood(X_test)
        return self.classes[np.argmax(scores, axis=1)]

    def _predict(self, x):
        """Return the predicted label for a single sample ``x``."""
        sample = np.asarray(x, dtype=float)[np.newaxis, :]
        return self.classes[np.argmax(self._joint_log_likelihood(sample)[0])]

    def _joint_log_likelihood(self, X):
        """Return log-posterior scores of shape (n_samples, n_classes).

        The Gaussian log-density is evaluated directly instead of taking the
        log of ``pdf``: for samples far from a class mean the density
        underflows to 0, which would turn every score into ``-inf``.
        """
        # (n_samples, 1, n_features) - (1, n_classes, n_features)
        diff = X[:, np.newaxis, :] - self.mean[np.newaxis, :, :]
        log_density = -0.5 * (np.log(2 * np.pi * self.var) + diff ** 2 / self.var)
        return np.log(self.priors) + log_density.sum(axis=2)

    def pdf(self, class_idx, x):
        """Per-feature Gaussian density of ``x`` under class ``class_idx``.

        Kept for callers that want the raw densities; prediction itself works
        in log space (see ``_joint_log_likelihood``).
        """
        mean, var = self.mean[class_idx], self.var[class_idx]
        numerator = np.exp(-(x - mean) ** 2 / (2 * var))
        denominator = np.sqrt(2 * np.pi * var)
        return numerator / denominator
        

In [7]:
# Fit the hand-written classifier on the training split and label the
# held-out samples.
nb = NaiveBayes()
nb.fit(X_train, y_train)
y_pred = nb.predict(X_test)

# Show the predicted species names, then the per-class precision/recall table.
print("Predictions: ", class_names[y_pred])
print(
    "Classification Report:",
    classification_report(y_test, y_pred),
    sep="\n",
)

Predictions:  ['setosa' 'versicolor' 'versicolor' 'setosa' 'virginica' 'virginica'
 'virginica' 'setosa' 'setosa' 'virginica' 'versicolor' 'setosa'
 'virginica' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'
 'setosa' 'setosa' 'versicolor' 'versicolor' 'virginica' 'setosa'
 'virginica' 'versicolor' 'setosa' 'setosa' 'versicolor' 'virginica'
 'versicolor' 'virginica' 'versicolor' 'virginica' 'virginica' 'setosa'
 'versicolor' 'setosa' 'versicolor' 'virginica' 'virginica' 'setosa'
 'versicolor' 'virginica' 'versicolor']
Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        14
           1       0.94      0.89      0.91        18
           2       0.86      0.92      0.89        13

    accuracy                           0.93        45
   macro avg       0.93      0.94      0.93        45
weighted avg       0.94      0.93      0.93        45

