In [166]:
import numpy as np

In [167]:
class NaiveBayes:

    def fit(self, X, y):
        n_samples, n_features = X.shape
        self._classes = np.unique(y)
        n_classes = len(self._classes)

        self._mean = np.zeros((n_classes, n_features))
        self._var = np.zeros((n_classes, n_features))
        self._prior = np.zeros(n_classes)

        for index, c in enumerate(self._classes):
            X_c = X[y == c]
            self._mean[index, :] = X_c.mean(axis = 0)
            self._var[index, :] = X_c.var(axis = 0)
            self._prior[index] = X_c.shape[0] / float(n_samples)

    def predict(self, X):
        y_pred = [self._predict(x) for x in X]
        return np.array(y_pred)

    def _predict(self, x):
        posteriors = []
        for index, _ in enumerate(self._classes):
            prior = np.log(self._prior[index])
            posterior = np.sum(self._efficient_gaussian(index, x))
            posterior += prior
            posteriors.append(posterior)

        return self._classes[np.argmax(posteriors)]

    def _gaussian(self, index, x):
        mean = self._mean[index]
        var = self._var[index]
        numerator = np.exp(-((x-mean) ** 2)/(2*var))
        denominator = np.sqrt(2 * np.pi * var)
        return np.log(numerator / denominator)

    def _efficient_gaussian(self, index, x):
        mean = self._mean[index]
        var = self._var[index]
        return -(np.log(2 * np.pi * var)) - (((x - mean) ** 2) / (2 * var))



In [168]:
data = np.loadtxt("./diabetes.csv", delimiter=',', skiprows=1)
np.random.shuffle(data)

x = data[:, :-1].astype(float)
y = data[:, -1]

split_index = int(len(x) * 0.7)

X_train = x[:split_index]
X_test = x[split_index:]

y_train = y[:split_index]
y_test = y[split_index:]


In [169]:
nb = NaiveBayes()

nb.fit(X_train, y_train)

In [170]:
y_pred = nb.predict(X_test)
y_pred_train = nb.predict(X_train)

correct = np.sum(y_pred==y_test)
correct_train =  np.sum(y_train==y_pred_train)
print(correct / len(y_pred))
print(correct_train / len(y_pred_train))

0.7272727272727273
0.74487895716946
