In [109]:
import jax
import jax.numpy as jnp
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from functools import partial
from jax import vmap, jit, tree

from sklearn.model_selection import train_test_split

In [110]:
# Case 3: Dataset with Redundant Features
data3 = {
    'Feature1': ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A',
                 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A', 'B'],
    'Feature2': ['W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',
                 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W'],  # Redundant feature
    'Feature3': ['L', 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'N', 'L',
                 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'N', 'L', 'M'],
    'Target':   [0, 0, 1, 0, 1, 1, 2, 2, 2, 0,
                 0, 1, 0, 1, 1, 2, 2, 2, 0, 0]
}

data = pd.DataFrame(data3)
print("\nCase 3: Dataset with Redundant Features")
print(data)


Case 3: Dataset with Redundant Features
   Feature1 Feature2 Feature3  Target
0         A        W        L       0
1         B        W        M       0
2         C        W        N       1
3         A        W        L       0
4         B        W        M       1
5         C        W        N       1
6         A        W        L       2
7         B        W        M       2
8         C        W        N       2
9         A        W        L       0
10        B        W        M       0
11        C        W        N       1
12        A        W        L       0
13        B        W        M       1
14        C        W        N       1
15        A        W        L       2
16        B        W        M       2
17        C        W        N       2
18        A        W        L       0
19        B        W        M       0


In [111]:
from sklearn.preprocessing import LabelEncoder

label_encoders = {}
for column in ['Feature1', 'Feature2', 'Feature3']:
    le = LabelEncoder()
    data[column] = le.fit_transform(data[column])
    label_encoders[column] = le

# Display the updated DataFrame
print(data.head())

   Feature1  Feature2  Feature3  Target
0         0         0         0       0
1         1         0         1       0
2         2         0         2       1
3         0         0         0       0
4         1         0         1       1


In [112]:
X = jnp.asarray(data.drop(columns=['Target']).to_numpy(dtype=jnp.int32))
y = jnp.asarray(data['Target'].to_numpy(dtype=jnp.int32))

In [113]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=12)

In [114]:
X_train.shape, y_train.shape, X.shape, y.shape

((16, 3), (16,), (20, 3), (20,))

In [115]:
unique_classes = jnp.unique(y) 
unique_categories = list(map(jnp.unique, X.T))

print(f'Classes: {unique_classes.tolist()}')
print(f'Categories in each feature column: {unique_categories}')

Classes: [0, 1, 2]
Categories in each feature column: [Array([0, 1, 2], dtype=int32), Array([0], dtype=int32), Array([0, 1, 2], dtype=int32)]


In [116]:
@jit
def compute_priors(y):
    return jnp.unique(y, return_counts=True, size=len(unique_classes))[1] / jnp.sum(jnp.unique(y, return_counts=True, size=len(unique_classes))[1])

priors = compute_priors(y_train)
priors

Array([0.375 , 0.3125, 0.3125], dtype=float32)

In [117]:
indices_for_the_classes = [jnp.where(y_train == class_) for class_ in unique_classes]

def restructure_matrix_into_blocks(X:jax.Array):
    @jit
    def restructure_by_indices(indices:jax.Array):
        return X[indices]
    return restructure_by_indices

X_train_restructured = tree.flatten(tree.map(restructure_matrix_into_blocks(X_train), indices_for_the_classes))[0]

In [118]:
X_train_restructured[0]

Array([[1, 0, 1],
       [1, 0, 1],
       [0, 0, 0],
       [0, 0, 0],
       [0, 0, 0],
       [0, 0, 0]], dtype=int32)

In [119]:
def return_likelihoods_for_feature_column(column:jax.Array):
    counts_of_feature_in_block = jnp.unique(column, return_counts=True)[1]
    return counts_of_feature_in_block / jnp.sum(counts_of_feature_in_block)

likelihoods_for_block_0 = list(map(return_likelihoods_for_feature_column, X_train_restructured[0].T))
print(likelihoods_for_block_0)

[Array([0.6666667 , 0.33333334], dtype=float32), Array([1.], dtype=float32), Array([0.6666667 , 0.33333334], dtype=float32)]


In [120]:
def compute_likelihoods_for_blocks(block:jax.Array):
    return list(map(return_likelihoods_for_feature_column, block.T))

likelihoods = tree.map(compute_likelihoods_for_blocks, X_train_restructured)
likelihoods

[[Array([0.6666667 , 0.33333334], dtype=float32),
  Array([1.], dtype=float32),
  Array([0.6666667 , 0.33333334], dtype=float32)],
 [Array([0.2, 0.8], dtype=float32),
  Array([1.], dtype=float32),
  Array([0.2, 0.8], dtype=float32)],
 [Array([0.4, 0.2, 0.4], dtype=float32),
  Array([1.], dtype=float32),
  Array([0.4, 0.2, 0.4], dtype=float32)]]

In [121]:
X_test

Array([[1, 0, 1],
       [1, 0, 1],
       [0, 0, 0],
       [1, 0, 1]], dtype=int32)

In [122]:
block = 0
j = 1
x_j = X_test.T[j]
likelihoods[block][j][x_j], x_j

(Array([1., 1., 1., 1.], dtype=float32), Array([0, 0, 0, 0], dtype=int32))

In [123]:
def retrieve_likelihood_for_block_i_feature_j_xij(x_ij:jax.Array, i:int, j:int):
    return likelihoods[i][j][x_ij]

def retrieve_likelihood_for_block_i_feature_j(feature_column:jax.Array, i:int, j:int):
    return vmap(retrieve_likelihood_for_block_i_feature_j_xij, in_axes=(0, None, None))(feature_column, i, j)

def retrieve_likelihood_for_block_i(X:jax.Array, i:int, j:int):
    return vmap(retrieve_likelihood_for_block_i_feature_j, in_axes=(0, None, None))(X, i, j)

block_of_likelihoods = [[] for _ in range(len(unique_classes))]
posteriors_array = []
for i in range(unique_classes.shape[0]):
    for j in range(X_test.shape[1]):
        v_array = retrieve_likelihood_for_block_i_feature_j(X_test.T[j], i, j)
        block_of_likelihoods[i].append(v_array)
    posteriors = jnp.prod(jnp.vstack(block_of_likelihoods[i]), axis=0)*priors[i]
    posteriors_array.append(posteriors)

y_pred = jnp.vstack(posteriors_array).argmin(axis=0)
jnp.where(y_pred == y_test, 1, 0).mean()
#jnp.prod(jnp.vstack(block_of_likelihoods[0]), axis=0)*priors[0], block_of_likelihoods[0]

Array(0.25, dtype=float32)

In [124]:
from sklearn.naive_bayes import MultinomialNB

model_sk = MultinomialNB()
model_sk_fitted = model_sk.fit(X_train, y_train)
y_pred_sk = model_sk_fitted.predict(X_test)
jnp.where(y_pred_sk == y_test, 1, 0).mean()

Array(0.5, dtype=float32)

In [125]:
class MultinomialNaiveBayes():
    def fit(self, X:jax.Array, y:jax.Array):
        # Computing priors
        self.priors = compute_priors(y)

        # Computing likelihoods
        self.unique_classes = jnp.unique(y)
        self.num_classes = len(self.unique_classes)

        indices_for_the_classes = [jnp.where(y == class_) for class_ in self.unique_classes]

        self.X_restructured = tree.flatten(tree.map(restructure_matrix_into_blocks(X), indices_for_the_classes))[0]
        
        self.blocks_of_likelihoods = tree.map(compute_likelihoods_for_blocks, self.X_restructured)

        return self
    
    def predict(self, X:jax.Array):
        self.log_posteriors = jnp.zeros(shape=(X.shape[1], self.num_classes))
        for i in range(self.unique_classes.shape[0]):
            for j in range(X.shape[1]):
                array_of_prior_and_likelihoods = jnp.hstack((priors[i], retrieve_likelihood_for_block_i_feature_j(X.T[j], i, j)))
                log_posterior = jnp.sum(jnp.log(array_of_prior_and_likelihoods))
                self.log_posteriors = self.log_posteriors.at[j, i].set(log_posterior)

        return self.log_posteriors.argmin(axis=1)
    
model = MultinomialNaiveBayes()
model_fitted = model.fit(X_train, y_train)
model_fitted.predict(X_test), X_test

(Array([2, 1, 2], dtype=int32),
 Array([[1, 0, 1],
        [1, 0, 1],
        [0, 0, 0],
        [1, 0, 1]], dtype=int32))