In [1]:
from collections import Counter

import os
import numpy as np

In [2]:
def get_X_Y_from(file, encoding='utf16'):
    """Read a whitespace-tokenised, labelled corpus file.

    Each non-blank line of ``file`` is one document. Its whitespace-separated
    tokens are the words, and the LAST token is the document's class label.

    Args:
        file: path of the corpus file.
        encoding: text encoding used to open the file (default ``'utf16'``,
            the encoding this notebook's corpus files were read with).

    Returns:
        ``(X, Y)`` where ``X[i]`` is the list of word tokens of document ``i``
        (label removed) and ``Y[i]`` is its label string.
    """
    X, Y = [], []
    with open(file, 'r', encoding=encoding) as f:
        for line in f:  # stream lines instead of materialising readlines()
            tokens = line.split()
            if not tokens:
                # Blank / whitespace-only lines (e.g. an extra empty line at the
                # end of the file) carry no label; indexing tokens[-1] would raise.
                continue
            X.append(tokens[:-1])
            Y.append(tokens[-1])
    return X, Y

In [32]:
# Data. The real corpus can be loaded with get_X_Y_from, e.g.
#   train_file = os.path.join(os.getcwd(), "train.in")  # or "validation.in" / "small.in"
#   X_train, Y_train = get_X_Y_from(train_file)
# A small hand-written corpus is used instead so every probability can be checked by hand.
X_train = [
    ["cricket", "very", "small", "insect"],
    ["play", "music"],
    ["play", "play", "cricket", "football"],
    ["like", "singing"],
    ["insect", "small", "live"],
]
Y_train = ["Biology", "Music", "Sports", "Music", "Biology"]

X_test = [["want", "play", "cricket"]]
Y_test = ["Sports"]

# Vocabulary: every distinct word appearing anywhere in the training documents.
unique_words = set().union(*X_train)
print(unique_words)

{'singing', 'live', 'very', 'music', 'small', 'like', 'football', 'insect', 'play', 'cricket'}


In [55]:
class ClassWordCounter:
    """Word statistics for one class label.

    As filled in by get_class_word_counters:
        n_docs: number of training documents carrying this label.
        n_words: total number of word tokens across those documents.
        word_counter: Counter mapping each word to its frequency in those documents.
    """

    def __init__(self, n_docs, n_words, word_counter):
        self.n_docs = n_docs
        self.n_words = n_words
        self.word_counter = word_counter

    def __str__(self):
        # One "field: value" line per attribute.
        lines = (
            f"n_docs: {self.n_docs}",
            f"n_words: {self.n_words}",
            f"word_counter: {self.word_counter}",
        )
        return "\n".join(lines)

    def __repr__(self):
        return "ClassWordCounter({},{},{})".format(self.n_docs, self.n_words, self.word_counter)
    
    
    
def get_class_word_counters(X_train, Y_train) -> "dict[str, ClassWordCounter]":
    """Aggregate word statistics per class label.

    Args:
        X_train: sequence of tokenised documents (each a list of word strings).
        Y_train: sequence of class labels, aligned index-by-index with ``X_train``.

    Returns:
        dict mapping each label (in order of first appearance) to a
        ClassWordCounter with the label's document count, total token count and
        per-word frequencies.

    Raises:
        ValueError: if ``X_train`` and ``Y_train`` differ in length; ``zip``
            would otherwise silently drop the unmatched tail.
    """
    if len(X_train) != len(Y_train):
        raise ValueError(
            f"X_train and Y_train must have the same length, got {len(X_train)} and {len(Y_train)}"
        )
    class_word_counters = {}
    for cls, doc in zip(Y_train, X_train):
        if cls not in class_word_counters:
            class_word_counters[cls] = ClassWordCounter(0, 0, Counter())
        stats = class_word_counters[cls]
        stats.n_docs += 1
        stats.n_words += len(doc)
        # update() counts in place; `+= Counter(doc)` built a temporary Counter and
        # re-scanned the whole accumulated counter on every document.
        stats.word_counter.update(doc)
    return class_word_counters


# Per-class training statistics; the bare name on the last line displays them.
class_word_counters = get_class_word_counters(X_train=X_train, Y_train=Y_train)
class_word_counters

{'Biology': ClassWordCounter(2,7,Counter({'small': 2, 'insect': 2, 'cricket': 1, 'very': 1, 'live': 1})),
 'Music': ClassWordCounter(2,4,Counter({'play': 1, 'music': 1, 'like': 1, 'singing': 1})),
 'Sports': ClassWordCounter(1,4,Counter({'play': 2, 'cricket': 1, 'football': 1}))}

In [56]:
def naive_bayes_predict(document, class_word_counters, n_docs, n_unique_words, alpha, verbose=True):
    """Predict the most likely class of ``document`` with multinomial Naive Bayes.

    Each class is scored as

        log P(cls) + sum over words of log P(word | cls)

    with additive smoothing

        P(word | cls) = (count(word, cls) + alpha) / (n_words(cls) + alpha * n_unique_words)

    Scores are accumulated in log space. A plain product of probabilities
    underflows to 0.0 for every class on long documents, which would make the
    arg-max arbitrary.

    Args:
        document: sequence of word tokens to classify.
        class_word_counters: mapping label -> object with ``n_docs``, ``n_words``
            and ``word_counter`` (Counter of word frequencies) attributes.
        n_docs: total number of training documents (denominator of the prior).
        n_unique_words: vocabulary size used in the smoothing denominator.
        alpha: smoothing factor. With 0, a word unseen in a class gives that
            class probability 0.
        verbose: when True (default), print the document, every likelihood
            factor and the final per-class probabilities.

    Returns:
        The label with the highest score. Ties go to the label that comes first
        in ``class_word_counters``. An empty document is classified by the
        priors alone.
    """
    import math  # local import keeps this cell self-contained

    def _log(p):
        # math.log(0) raises; -inf keeps a zero-probability class at the bottom.
        return math.log(p) if p > 0 else -math.inf

    if verbose:
        print(document)

    # Priors P(cls) are set before any word is seen, so an empty document still scores.
    log_scores = {
        cls: _log(counter.n_docs / n_docs)
        for cls, counter in class_word_counters.items()
    }

    for word in document:
        for cls, counter in class_word_counters.items():
            n_word_cls = counter.word_counter[word]
            n_cls = counter.n_words
            # Words never seen in training (e.g. "want") still contribute
            # alpha / (n_cls + alpha * n_unique_words) to every class.
            if verbose:
                print(f"P({word}|{cls}) *= ({n_word_cls} + {alpha})/({n_cls} + {alpha}*{n_unique_words})")
            log_scores[cls] += _log((n_word_cls + alpha) / (n_cls + alpha * n_unique_words))

    if verbose:
        # exp() is for display only; long documents may still print 0.0 here,
        # but the prediction below compares log scores and is unaffected.
        print("Probabilities: ", {cls: math.exp(score) for cls, score in log_scores.items()})
    return max(log_scores, key=log_scores.get)


# Classify the single test document with add-one (Laplace) smoothing.
n_docs, n_unique_words = len(X_train), len(unique_words)
document = X_test[0]
naive_bayes_predict(document, class_word_counters, n_docs, n_unique_words, alpha=1)

['want', 'play', 'cricket']
P(want|Biology) *= (0 + 1)/(7 + 1*10)
P(want|Music) *= (0 + 1)/(4 + 1*10)
P(want|Sports) *= (0 + 1)/(4 + 1*10)
P(play|Biology) *= (0 + 1)/(7 + 1*10)
P(play|Music) *= (1 + 1)/(4 + 1*10)
P(play|Sports) *= (2 + 1)/(4 + 1*10)
P(cricket|Biology) *= (1 + 1)/(7 + 1*10)
P(cricket|Music) *= (0 + 1)/(4 + 1*10)
P(cricket|Sports) *= (1 + 1)/(4 + 1*10)
Probabilities:  {'Biology': 0.00016283329940972927, 'Music': 0.0002915451895043731, 'Sports': 0.0004373177842565597}


'Sports'