In [1]:
import numpy as np

def create_train_data():
    """Return the toy "play tennis" training table as a 2-D string array.

    Each of the 10 rows holds four categorical feature values (presumably
    Outlook, Temperature, Humidity, Wind) followed by the class label
    ('yes' / 'no') in the last column.
    """
    rows = (
        ('Sunny',    'Hot',  'High',   'Weak',   'no'),
        ('Sunny',    'Hot',  'High',   'Strong', 'no'),
        ('Overcast', 'Hot',  'High',   'Weak',   'yes'),
        ('Rain',     'Mild', 'High',   'Weak',   'yes'),
        ('Rain',     'Cool', 'Normal', 'Weak',   'yes'),
        ('Rain',     'Cool', 'Normal', 'Strong', 'no'),
        ('Overcast', 'Cool', 'Normal', 'Strong', 'yes'),
        ('Overcast', 'Mild', 'High',   'Weak',   'no'),
        ('Sunny',    'Cool', 'Normal', 'Weak',   'yes'),
        ('Rain',     'Mild', 'Normal', 'Weak',   'yes'),
    )
    return np.array(rows)

# Build the toy training set once and show it; the following cells reuse `train_data`.
train_data = create_train_data()
print(train_data)

[['Sunny' 'Hot' 'High' 'Weak' 'no']
 ['Sunny' 'Hot' 'High' 'Strong' 'no']
 ['Overcast' 'Hot' 'High' 'Weak' 'yes']
 ['Rain' 'Mild' 'High' 'Weak' 'yes']
 ['Rain' 'Cool' 'Normal' 'Weak' 'yes']
 ['Rain' 'Cool' 'Normal' 'Strong' 'no']
 ['Overcast' 'Cool' 'Normal' 'Strong' 'yes']
 ['Overcast' 'Mild' 'High' 'Weak' 'no']
 ['Sunny' 'Cool' 'Normal' 'Weak' 'yes']
 ['Rain' 'Mild' 'Normal' 'Weak' 'yes']]


In [2]:
def compute_prior_probability(train_data):
    """Estimate the class priors P(y) from the label column.

    Parameters
    ----------
    train_data : array-like
        2-D table whose last column holds the class label ('no' / 'yes').

    Returns
    -------
    np.ndarray
        ``[P(y='no'), P(y='yes')]`` as float64, in the same class order
        (row 0 = 'no', row 1 = 'yes') used by the likelihood tables and the
        predictor.

    Raises
    ------
    ValueError
        If ``train_data`` has no rows (the priors would be 0/0).
    """
    y_unique = ['no', 'yes']
    labels = np.asarray(train_data)[:, -1]
    total_samples = len(labels)
    if total_samples == 0:
        raise ValueError("train_data must contain at least one sample")

    # Count every class explicitly rather than deriving 'no' as the complement
    # of 'yes', so an unexpected label is not silently counted as 'no'.
    prior_probability = np.array(
        [np.count_nonzero(labels == y) / total_samples for y in y_unique],
        dtype=float,
    )
    return prior_probability

prior_probability = compute_prior_probability(train_data)
# Unpack the two class priors (index 0 = 'no', index 1 = 'yes') for display.
prob_no, prob_yes = prior_probability
print("P(play tennis = No)", prob_no)
print("P(play tennis = Yes)", prob_yes)

P(play tennis = No) 0.4
P(play tennis = Yes) 0.6


In [3]:
def compute_conditional_probability(train_data, alpha=0.0):
    """Estimate the per-feature likelihoods P(x_i = v | y) for both classes.

    Parameters
    ----------
    train_data : array-like
        2-D table; every column except the last is a categorical feature and
        the last column is the class label ('no' / 'yes').
    alpha : float, optional
        Additive (Laplace) smoothing constant. The default ``0.0`` reproduces
        plain relative frequencies. ``1.0`` gives Laplace smoothing, which keeps
        a feature value never seen with a class from zeroing that class's score.

    Returns
    -------
    conditional_probability : list of np.ndarray
        One array per feature with shape ``(2, n_values)``. Row 0 is y='no',
        row 1 is y='yes', and columns follow the matching ``list_x_name`` entry.
        A class with no training rows (and ``alpha == 0``) gets an all-zero row
        instead of nan.
    list_x_name : list of np.ndarray
        Sorted unique values of each feature (from ``np.unique``).
    """
    y_unique = ['no', 'yes']
    train_data = np.asarray(train_data)
    labels = train_data[:, -1]
    # Split rows by class once, instead of re-filtering for every feature.
    samples_by_class = [train_data[labels == y] for y in y_unique]

    conditional_probability = []
    list_x_name = []
    for i in range(train_data.shape[1] - 1):
        x_unique = np.unique(train_data[:, i])
        list_x_name.append(x_unique)

        x_conditional_probability = np.zeros((len(y_unique), len(x_unique)))
        for j, y_samples in enumerate(samples_by_class):
            counts = np.array(
                [np.count_nonzero(y_samples[:, i] == x) for x in x_unique]
            )
            denominator = len(y_samples) + alpha * len(x_unique)
            # Without rows for this class (and no smoothing) the likelihood is
            # undefined; leave the row at 0 rather than producing 0/0 = nan.
            if denominator > 0:
                x_conditional_probability[j] = (counts + alpha) / denominator

        conditional_probability.append(x_conditional_probability)

    return conditional_probability, list_x_name

# Per-feature likelihood tables P(x_i | y) plus the sorted value vocabulary of each feature.
conditional_probability, list_x_name = compute_conditional_probability(train_data)

In [4]:
def get_index_from_value(feature_name, list_features):
    """Return the position of ``feature_name`` inside ``list_features``.

    Parameters
    ----------
    feature_name : str
        The categorical value to look up.
    list_features : array-like
        Candidate values, e.g. one entry of ``list_x_name``. A NumPy array or a
        plain Python list/tuple.

    Returns
    -------
    int
        Index of the first match.

    Raises
    ------
    IndexError
        If ``feature_name`` does not occur in ``list_features``.
    """
    # np.asarray makes the element-wise == work for plain lists too; comparing
    # a Python list to a str would just yield a single False.
    matches = np.flatnonzero(np.asarray(list_features) == feature_name)
    if matches.size == 0:
        raise IndexError(
            f"{feature_name!r} not found in {list(list_features)!r}"
        )
    return int(matches[0])

In [5]:
def train_naive_bayes(train_data):
    """Fit the categorical Naive Bayes model on ``train_data``.

    Returns
    -------
    tuple
        ``(prior_probability, conditional_probability, list_x_name)``, taken
        directly from ``compute_prior_probability`` and
        ``compute_conditional_probability``.
    """
    priors = compute_prior_probability(train_data)
    likelihoods, feature_values = compute_conditional_probability(train_data)
    return priors, likelihoods, feature_values

In [6]:
def prediction_play_tennis(X, list_x_name, prior_probability, conditional_probability):
    """Predict the class of one sample with the trained Naive Bayes model.

    The unnormalised score of class ``c`` is
    ``P(y=c) * prod_i P(x_i = X[i] | y=c)``. The class with the larger score wins.

    Parameters
    ----------
    X : sequence of str
        One categorical value per feature, in training-column order.
    list_x_name : list of array-like
        Unique values of each feature, as returned by training.
    prior_probability : np.ndarray
        ``[P(y='no'), P(y='yes')]``.
    conditional_probability : list of np.ndarray
        Per-feature likelihood tables; row 0 is 'no', row 1 is 'yes'.

    Returns
    -------
    int
        ``0`` for 'no', ``1`` for 'yes'. Ties resolve to ``1``.

    Raises
    ------
    ValueError
        If ``X`` does not supply exactly one value per feature.
    IndexError
        If a value in ``X`` was never seen during training
        (raised by ``get_index_from_value``).
    """
    if len(X) != len(list_x_name):
        raise ValueError(
            f"expected {len(list_x_name)} feature values, got {len(X)}"
        )

    # Look up every feature instead of hard-coding exactly four of them.
    value_indices = [
        get_index_from_value(value, names) for value, names in zip(X, list_x_name)
    ]

    p0 = prior_probability[0]
    p1 = prior_probability[1]
    # Multiply in feature order so the floating-point products are unchanged.
    for table, k in zip(conditional_probability, value_indices):
        p0 *= table[0, k]
        p1 *= table[1, k]

    # Strict '>' keeps ties resolving to class 1 ('yes').
    return 0 if p0 > p1 else 1

# Predict for day D11. The model is retrained from a fresh copy of the data
# so this cell does not depend on the globals produced by earlier cells.
X = ['Sunny', 'Cool', 'High', 'Strong']
data = create_train_data()
prior_probability, conditional_probability, list_x_name = train_naive_bayes(data)
pred = prediction_play_tennis(X, list_x_name, prior_probability, conditional_probability)

print("Ad should go!" if pred else "Ad should not go!")

Ad should not go!
