In [19]:
import jax
import jax.numpy as jnp
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from jax import jit, vmap

In [20]:
# Define the data
data = {
    'Feature1': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C',
                 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C', 'B',
                 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A',
                 'C', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C',
                 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C', 'B'],
    'Feature2': ['X', 'Y', 'X', 'Z', 'X', 'Y', 'X', 'Z', 'Y', 'X',
                 'X', 'Z', 'Y', 'X', 'Y', 'Z', 'Y', 'X', 'Y', 'X',
                 'X', 'Y', 'X', 'Z', 'Y', 'X', 'Z', 'Y', 'X', 'Y',
                 'Z', 'X', 'Y', 'Z', 'X', 'Y', 'X', 'Z', 'X', 'Y',
                 'X', 'Y', 'X', 'Z', 'X', 'Y', 'X', 'Z', 'Y', 'X'],
    'Feature3': ['L', 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'N', 'L',
                 'L', 'M', 'N', 'L', 'M', 'L', 'N', 'M', 'L', 'N',
                 'L', 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'L', 'N',
                 'L', 'M', 'L', 'M', 'N', 'L', 'M', 'L', 'N', 'L',
                 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'N', 'L', 'M'],
    'Target': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
               0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
               0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
               0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
               0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
}

# Create DataFrame
data = pd.DataFrame(data)

# Display the DataFrame
print(data.head())

  Feature1 Feature2 Feature3  Target
0        A        X        L       0
1        B        Y        M       1
2        A        X        N       0
3        C        Z        L       1
4        B        X        M       0


In [21]:
from sklearn.preprocessing import LabelEncoder

# Initialize LabelEncoder
label_encoders = {}
for column in ['Feature1', 'Feature2', 'Feature3']:
    le = LabelEncoder()
    data[column] = le.fit_transform(data[column])
    label_encoders[column] = le

# Display the updated DataFrame
print(data.head())

   Feature1  Feature2  Feature3  Target
0         0         0         0       0
1         1         1         1       1
2         0         0         2       0
3         2         2         0       1
4         1         0         1       0


In [22]:
X = data.drop(columns=['Target']).to_numpy()
y = data['Target'].to_numpy()

X, y = map(jnp.array, (
    X, y
))


In [23]:
def split_data(data, val_size=0.1, test_size=0.2):
    """ 
    Splits data.
    """
    split_index_test = int(len(data) * (1-test_size))

    data_non_test = data[:split_index_test]
    data_test = data[split_index_test:]

    split_index_val = int(len(data_non_test) * (1-val_size))

    data_train = data_non_test[:split_index_val]
    data_val = data_non_test[split_index_val:]

    return data_train, data_val, data_test

In [24]:
(X_train, X_val, X_test), (y_train, y_val, y_test) = map(
    split_data,
    (X, y)
)

In [25]:
unique_classes = jnp.unique(y).tolist()
indices_for_each_class = [jnp.where(y_train==class_) for class_ in unique_classes]

dictionary_of_stds = dict(zip(unique_classes,
                                  [[jnp.std(X_train[collection_of_indices][:,j]).item() for j in range(X_train.shape[1])] for collection_of_indices in indices_for_each_class]))

dictionary_of_stds[0]

[0.8089011311531067, 0.6781419515609741, 0.8975274562835693]

In [26]:
val = [1,2,3]
keys = ['a', 'b', 'c']

dict(zip(keys, [i for i in range(len(keys))]))

{'a': 0, 'b': 1, 'c': 2}

In [27]:
def compute_priors(y:jax.Array):
    """
    Obtain prior probabilities.

    Args:
        y (jax.Array): Label vector.
    
    Returns:
        prior_probabilities (jax.Array): Vector of prior probabilities.
    """
    unique_classes = jnp.unique_values(y)
    prior_probabilities = []
    
    for index, class_ in enumerate(unique_classes.tolist()):
        prior_probabilities.append(jnp.mean(jnp.where(y==class_, 1, 0)))

    return jnp.array(prior_probabilities)

In [28]:
compute_priors(y)

Array([0.5, 0.5], dtype=float32)

In [29]:
def gaussian_pdf(x, mean, std):
    return jnp.exp(-0.5 * ((x-mean)/std)**2 )/(std*jnp.sqrt(2*jnp.pi))

In [30]:
def compute_means(X:jax.Array, y:jax.Array, random_state=12):
    """ 
    Computes means.
    """
    np.random.seed(random_state)

    unique_classes = jnp.unique(y).tolist()
    indices_for_each_class = [jnp.where(y==class_) for class_ in unique_classes]

    dictionary_of_means = dict(zip(unique_classes, 
                                   [[jnp.mean(X[collection_of_indices][:,j]).item() for j in range(X.shape[1])] for collection_of_indices in indices_for_each_class]))
    
    return dictionary_of_means

In [31]:
def compute_stds(X:jax.Array, y:jax.Array, random_state=12):
    """ 
    Compute stds.
    """
    np.random.seed(random_state)

    unique_classes = jnp.unique(y).tolist()
    indices_for_each_class = [jnp.where(y==class_) for class_ in unique_classes]

    dictionary_of_stds = dict(zip(unique_classes, 
                                   [[jnp.std(X[collection_of_indices][:,j]).item() for j in range(X.shape[1])] for collection_of_indices in indices_for_each_class]))
    
    return dictionary_of_stds

In [32]:
def compute_posterior(X:jax.Array, y:jax.Array):
    """ 
    Computes posteriors to compute predictions.
    """
    posteriors = []
    
    dictionary_of_means = compute_means(X, y)
    dictionary_of_stds = compute_stds(X, y)
    
    prior_probabilites = compute_priors(y)

    for x in X:
        likelihoods = jnp.array([gaussian_pdf(x, jnp.array(means), jnp.array(stds)) for means, stds in zip(dictionary_of_means.values(), dictionary_of_stds.values())])
        vector_of_posteriors = jnp.log(jnp.dot(likelihoods, prior_probabilites))
        posteriors.append(vector_of_posteriors)

    return posteriors

In [33]:
posteriors = compute_posterior(X_test, y_test)
y_pred = jnp.argmax(jnp.array(posteriors), axis=1)

TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (2,).

In [None]:
jnp.mean(y_pred == y_test) * 100

Array(87.43718, dtype=float32)

In [None]:
from sklearn.naive_bayes import GaussianNB

model = GaussianNB()
model_fitted = model.fit(X_train, y_train)
y_pred_2 = model_fitted.predict(X_test)

jnp.mean(y_pred_2 == y_test) * 100

Array(92.46231, dtype=float32)