In [12]:
import numpy as np

In [13]:
# Synthetic toy data set; the fixed seed keeps the draw reproducible across runs.
rng = np.random.RandomState(1)
# 6 samples x 100 features, integer counts drawn uniformly from {0, ..., 4}.
X = rng.randint(0, 5, size=(6, 100))
# One distinct class label (1..6) per sample.
y = np.arange(1, 7)

In [14]:
class MultiNB(object):
    """Multinomial naive Bayes classifier with additive (add-alpha) smoothing.

    After ``fit``:
      * ``unique_labels`` holds the sorted distinct labels seen in ``y``;
      * ``priors[k]`` is the fraction of training rows labelled ``unique_labels[k]``;
      * ``params[j, k]`` is the smoothed probability of token ``j`` given class ``k``
        (each column sums to 1).
    """

    def __init__(self):
        # 1D array of class prior probabilities, ordered like unique_labels
        self.priors = None

        # 2D array to be fitted, shape (n_tokens, n_classes).
        # Rows: tokens (columns of X); columns: classes.
        self.params = None

        # Sorted distinct labels; predictions are drawn from these values
        self.unique_labels = None

    def fit(self, X, y, alpha=1.0):
        """Estimate class priors and smoothed per-class token probabilities.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_tokens)
            Per-sample token counts (presumably non-negative integers).
        y : array-like, shape (n_samples,)
            Class label of each row of ``X``.
        alpha : float, default 1.0
            Additive smoothing strength; must satisfy ``0 < alpha <= 1``.

        Raises
        ------
        AssertionError
            If ``alpha`` is outside ``(0, 1]``.
        """
        assert ((alpha <= 1.0) and (alpha > 0.0)), "alpha must be in the interval (0, 1]"

        # Accept plain lists as well as arrays; the code below relies on
        # .shape and boolean-mask indexing.
        X = np.asarray(X)
        y = np.asarray(y)

        self.unique_labels = np.unique(y)
        n_tokens = X.shape[1]
        n_classes = len(self.unique_labels)
        self.params = np.zeros(shape=(n_tokens, n_classes))
        self.priors = np.zeros(shape=(n_classes,))

        for ix, label in enumerate(self.unique_labels):
            mask = (y == label)  # Boolean mask selecting the training rows of this class
            class_rows = X[mask, :]

            # Add-alpha smoothing: every token gets `alpha` pseudo-counts, so the
            # denominator grows by n_tokens * alpha and each column sums to 1.
            token_counts_in_label = np.sum(class_rows, axis=0) + alpha
            total_tokens_in_label = np.sum(class_rows) + n_tokens * alpha
            self.params[:, ix] = token_counts_in_label / total_tokens_in_label

            # Empirical class frequency
            self.priors[ix] = float(np.sum(mask)) / len(y)

    def predict_log_likelihood(self, X):
        """Return ``log P(x | class)`` (up to a class-independent constant).

        The result has one row per sample and one column per class, ordered
        like ``unique_labels``. Class priors are NOT included here.
        """
        log_params = np.log(self.params)  # finite because alpha > 0 keeps params > 0
        log_likelihoods = np.dot(np.asarray(X), log_params)
        return log_likelihoods

    def _joint_log_likelihood(self, X):
        """Return ``log P(class) + log P(x | class)`` for every sample/class pair."""
        # priors has shape (n_classes,) and broadcasts across the sample rows.
        return self.predict_log_likelihood(X) + np.log(self.priors)

    def predict(self, X):
        """Predict the most probable label for each row of ``X``.

        Picks the class maximising the posterior, i.e. the likelihood weighted
        by the class prior learned in ``fit``. Returns a 1D array of labels
        taken from ``unique_labels``.
        """
        joint = self._joint_log_likelihood(X)
        index_to_label = np.argmax(joint, axis=1)
        pred_y = self.unique_labels[index_to_label]
        return pred_y
        

In [15]:
# Fit the hand-written classifier on the synthetic data, then classify the third sample.
like=MultiNB()
like.fit(X,y)
like.predict(X[2:3])  # slicing (rather than X[2]) keeps a 2D, single-row input

array([3])

In [16]:
# Reference check: fit scikit-learn's MultinomialNB on the same data and
# predict the same sample, for comparison with the MultiNB result.
from sklearn.naive_bayes import MultinomialNB
multi=MultinomialNB()
multi.fit(X,y)
print(multi.predict(X[2:3]))

[3]
