In [3]:
import numpy as np
# x1 is atmospheric pressure level (attribute x_0)
x1 = ['high', 'high', 'low', 'low', 'low', 'high']
# x2 is weather type (attribute x_1)
x2 = ['partly cloudy', 'sunny', 'sunny', 'cloudy', 'cloudy', 'cloudy']
# Feature matrix: one row per observation, one column per attribute
X = np.column_stack([x1, x2])
# Labels: rain (True) and no-rain (False)
y = np.array([False, False, True, False, False, True])


# Splitting a set
def partition(a):
    """Group the positions of ``a`` by value.

    Parameters
    ----------
    a : array-like
        Observed values of a single feature.

    Returns
    -------
    dict
        Maps each distinct value of ``a`` (in ``np.unique``'s sorted order)
        to the array of indices at which that value occurs.
    """
    groups = {}
    for value in np.unique(a):
        groups[value] = np.nonzero(a == value)[0]
    return groups

# Picking which attribute to split
def entropy(s):
    """Shannon entropy, in bits, of the value distribution in ``s``.

    H(s) = -sum_c p_c * log2(p_c), where p_c is the relative frequency of
    each distinct value c in ``s``.

    Parameters
    ----------
    s : array-like
        Observed labels.

    Returns
    -------
    The entropy; 0 for an empty input.
    """
    _, counts = np.unique(s, return_counts=True)
    h = 0
    for prob in counts.astype('float') / len(s):
        if prob == 0.0:
            continue  # 0 * log2(0) is taken as 0
        h -= prob * np.log2(prob)
    return h

# Calculate decrease in impurity (information gain)
def mutual_information(y, x):
    """Information gain obtained by splitting the labels ``y`` on attribute ``x``.

    gain = H(y) - sum over distinct values v of x of P(x == v) * H(y[x == v])

    ``x`` and ``y`` are expected to be NumPy arrays of equal length, because
    ``x == v`` is used as a boolean mask into ``y``.
    """
    gain = entropy(y)
    distinct, occurrences = np.unique(x, return_counts=True)
    weights = occurrences.astype('float') / len(x)
    # Subtract the weighted entropy of each child subset from the parent's.
    for value, weight in zip(distinct, weights):
        child_labels = y[x == value]
        gain -= weight * entropy(child_labels)
    return gain

# Checks for pureness of elements in a list
def is_pure(s):
    """Return True when ``s`` contains exactly one distinct value.

    An empty input is not considered pure (it has zero distinct values).
    """
    distinct_values = set(s)
    return len(distinct_values) == 1

# Helper function to print decision tree
def print_tree(d, depth = 0):
    """Print a nested-dict decision tree, one node per line.

    Each nesting level adds one space of indentation. Values that are plain
    ``dict`` objects are printed as ``key:`` followed by their children;
    any other value is a leaf printed as ``key: value``.
    """
    indent = ' ' * depth
    for key, value in d.items():
        if type(value) is dict:
            print(indent, key, sep='', end=':\n')
            print_tree(value, depth + 1)
        else:
            print(indent, key, sep='', end=': ')
            print(value)
            
    
# Get the most common element of an array
def most_common(a):
    """Return the most frequent value in ``a``.

    Ties go to the smallest value (the first in ``np.unique``'s sorted
    order). An empty input raises ValueError from ``argmax``.
    """
    uniques, tallies = np.unique(a, return_counts=True)
    return uniques[tallies.argmax()]

# Recursive split of observations
def recursive_split(x, y, min_gain=1e-6):
    """Grow a decision tree by greedily splitting on the most informative attribute.

    Parameters
    ----------
    x : np.ndarray
        2-D array of categorical attribute values, one row per observation
        and one column per attribute.
    y : np.ndarray
        Label of each observation (same length as ``x``).
    min_gain : float, optional
        A node is split only when at least one attribute reaches this
        information gain; otherwise it becomes a majority-vote leaf.
        Defaults to 1e-6.

    Returns
    -------
    dict or label
        A nested dict mapping ``"x_<attr> = <value>"`` to sub-trees, a leaf
        label (the most common value of ``y``), or ``None`` when ``y`` is empty.
    """
    # An empty subset has no majority class. Return None explicitly instead of
    # calling most_common(), whose np.argmax fails on an empty sequence.
    if len(y) == 0:
        return None

    # A pure subset is already a leaf.
    if is_pure(y):
        return most_common(y)

    # Information gain of every attribute (column of x).
    gain = np.array([mutual_information(y, x_attr) for x_attr in x.T])

    # Not informative enough, or no attributes at all: majority-vote leaf.
    # This check runs before np.argmax, so an x with zero columns (an empty
    # ``gain``, for which np.all is True) does not crash.
    if np.all(gain < min_gain):
        return most_common(y)

    # Split on the attribute with the largest gain.
    selected_attr = np.argmax(gain)
    sets = partition(x[:, selected_attr])

    # Recurse into each subset; each key records the test leading to it.
    res = {}
    for key, value in sets.items():
        y_subset = y.take(value, axis=0)
        x_subset = x.take(value, axis=0)
        res["x_%d = %s" % (selected_attr, key)] = recursive_split(
            x_subset, y_subset, min_gain
        )

    return res

# Inspect the example feature matrix: one row per observation,
# column x_0 is atmospheric pressure, column x_1 is weather type
print(X)

[['high' 'partly cloudy']
 ['high' 'sunny']
 ['low' 'sunny']
 ['low' 'cloudy']
 ['low' 'cloudy']
 ['high' 'cloudy']]


In [6]:
# Grow the decision tree from the whole example dataset and display it
d = recursive_split(x=X, y=y)
d

{'x_1 = cloudy': {'x_0 = high': True, 'x_0 = low': False},
 'x_1 = partly cloudy': False,
 'x_1 = sunny': {'x_0 = high': False, 'x_0 = low': True}}

In [7]:
# Human-readable view of the learned tree (one extra space of indent per level)
print_tree(d, depth=0)

x_1 = cloudy:
 x_0 = high: True
 x_0 = low: False
x_1 = partly cloudy: False
x_1 = sunny:
 x_0 = high: False
 x_0 = low: True


In [13]:
# Recursive labelling of the samples
def predict_rains(d, sample):
    """Walk the decision tree ``d`` for one sample; print and return its label.

    Each key of ``d`` has the form ``"x_<attr> = <value>"`` (as built by
    ``recursive_split``) and matches when ``sample[attr] == value``. Matching
    values are either nested dicts (sub-trees, followed recursively) or leaf
    labels. A truthy leaf prints "Rains", a falsy one prints "No Rains".
    In the example data, attribute 0 is atmospheric pressure and attribute 1
    is weather type.

    Parameters
    ----------
    d : dict
        Decision tree (or sub-tree).
    sample : sequence
        Attribute values of one observation, indexed by attribute number.

    Returns
    -------
    The leaf label reached by the sample, or None when the sample carries a
    value for which the tree has no branch.
    """
    for key, value in d.items():
        # Split only on the " = " separator inserted by recursive_split, so
        # values containing spaces (e.g. "partly cloudy") stay intact.
        attr_name, _, cond_value = key.partition(" = ")
        # Parse the full attribute index so "x_2", "x_10", ... are routed to
        # the right column, not just attributes 0 and 1.
        selected_attr = int(attr_name.strip().split("_", 1)[1])
        if str(sample[selected_attr]) != cond_value:
            continue

        if type(value) is dict:
            return predict_rains(value, sample)

        if value:
            print('\nPrediction: Rains')
        else:
            print('\nPrediction: No Rains')
        return value

    return None

In [14]:
# Run every training sample through the tree; predict_rains prints each prediction
for i in X:
    a = predict_rains(d, i)


Prediction: No Rains

Prediction: No Rains

Prediction: Rains

Prediction: No Rains

Prediction: No Rains

Prediction: Rains
