In [14]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

In [62]:
class NBM:
    """Multinomial Naive Bayes classifier for bag-of-words features.

    Parameters
    ----------
    vectorizer : str, default 'count'
        Default feature representation used by ``transform`` when no
        ``method`` is passed. ``'Tfid'`` / ``'tfidf'`` (any case) selects
        TF-IDF weighting; any other value keeps raw token counts.
    laplace : float, default 1
        Additive (Laplace) smoothing constant added to every feature count.
    """

    def __init__(self, vectorizer='count', laplace=1):
        self.laplace = laplace
        self.vectorizer = vectorizer

    def transform(self, X_train, X_test, method=None):
        """Vectorize raw train/test documents using statistics fit on train only.

        Parameters
        ----------
        X_train, X_test : iterable of str
            Raw documents.
        method : str or None
            'Tfid' / 'tfidf' (any case) applies TF-IDF weighting on top of
            the counts; any other value keeps raw counts. ``None`` falls back
            to ``self.vectorizer``.

        Returns
        -------
        (X_train, X_test)
            ``X_train`` as the scipy sparse matrix produced by scikit-learn,
            ``X_test`` converted to a dense ndarray.
        """
        if method is None:
            method = self.vectorizer
        # Learn the vocabulary from the training documents only and reuse the
        # same fitted vectorizer for the test documents, so both matrices share
        # identical columns (no dependency on any notebook-global vectorizer).
        counts = CountVectorizer()
        X_train = counts.fit_transform(X_train)
        X_test = counts.transform(X_test)
        if str(method).lower() in ('tfid', 'tfidf'):
            # IDF weights must also be estimated on the training split only;
            # the test counts are re-weighted with the same fitted transformer.
            tfidf = TfidfTransformer()
            X_train = tfidf.fit_transform(X_train)
            X_test = tfidf.transform(X_test)
        return X_train, X_test.toarray()

    def fit(self, X, y):
        """Estimate class priors and smoothed per-class feature likelihoods.

        Parameters
        ----------
        X : (m, n) array-like or scipy sparse matrix of non-negative features.
        y : length-m sequence of class labels (any type ``np.unique`` accepts).

        Returns
        -------
        self
            Enables ``NBM().fit(X, y).predict(X_new)`` chaining.
        """
        # A plain list would make ``y == label`` a single bool, not a row mask.
        y = np.asarray(y)
        n_features = X.shape[1]
        self.classes = np.unique(y)
        n_classes = len(self.classes)
        self.likelihoods = np.zeros((n_classes, n_features))
        self.priors = np.zeros(n_classes)
        for idx, label in enumerate(self.classes):
            X_classed = X[y == label]
            self.likelihoods[idx, :] = self.likelihood_fn(X_classed)
            self.priors[idx] = self.prior_fn(y, label)
        return self

    def likelihood_fn(self, X_class):
        """Laplace-smoothed multinomial probabilities P(feature | class).

        P(w | c) = (count(w, c) + laplace) / (sum_w count(w, c) + n * laplace)

        Returns a 1-D float array with one entry per feature column.
        """
        # ``.sum(axis=0)`` on a scipy sparse matrix yields a (1, n) np.matrix;
        # flatten so dense and sparse inputs produce the same 1-D result.
        smoothed = np.asarray(X_class.sum(axis=0), dtype=float).ravel() + self.laplace
        return smoothed / smoothed.sum()

    def prior_fn(self, y, label):
        """Fraction of samples in ``y`` whose label equals ``label``."""
        return np.mean(np.asarray(y) == label)

    def predict(self, X_test):
        """Return the most probable class label for each row of ``X_test``."""
        # Score in log space, log P(c) + sum_w x_w * log P(w | c), to avoid
        # underflow from multiplying thousands of small probabilities.
        log_posteriors = np.log(self.priors) + X_test @ np.log(self.likelihoods.T)
        best = np.argmax(np.asarray(log_posteriors), axis=1)
        # Map argmax column indices back to the labels seen during fit().
        return self.classes[best]
    


In [50]:
# Load the default (training) split only to list every available newsgroup;
# `fetch_20newsgroups` comes from the imports cell at the top of the notebook.
data = fetch_20newsgroups(subset='train')
data.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [54]:
# Four-newsgroup subset used for the rest of the notebook; train and test are
# fetched with the same category filter so their label sets line up.
categories = [
    'talk.religion.misc',
    'soc.religion.christian',
    'sci.space',
    'comp.graphics',
]
train, test = (
    fetch_20newsgroups(subset=split, categories=categories)
    for split in ('train', 'test')
)

In [56]:
# Baseline bag-of-words features: vocabulary is learned from the training
# documents and reused for the test documents.
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train.data)
# CountVectorizer returns a scipy sparse matrix; convert the test split to a
# dense ndarray.
X_test = vectorizer.transform(test.data).toarray()

y_train = train.target
y_test = test.target

# Both splits must share the vocabulary fitted on the training data.
assert X_train.shape[1] == X_test.shape[1]

print(X_test.shape)
print(X_train.shape)

(1432, 35329)
(2153, 35329)


In [64]:
model = NBM()

# Vectorize through the model helper: the vocabulary is fit on train.data only.
X_train, X_test = model.transform(train.data, test.data, 'Count')
y_train = train.target
y_test = test.target

print(X_train.shape)
print(X_test.shape)

model.fit(X_train, y_train)

# Keep the training-set predictions as `yhat`, and also score the held-out
# split so the loaded test data is actually used for evaluation.
yhat = model.predict(X_train)
yhat_test = model.predict(X_test)

{'train_accuracy': np.mean(yhat == y_train),
 'test_accuracy': np.mean(yhat_test == y_test)}

(2153, 35329)
(1432, 35329)
