In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import copy

In [None]:
# Location of the Play-Tennis CSV (Colab upload path); change when running elsewhere.
DATA_PATH = '/content/Tennis.csv'

dataset_raw = pd.read_csv(DATA_PATH)
# A CSV written together with its index yields a leading "Unnamed: 0" column.
# Drop it so it is not treated as a feature (the tree code maps column 0 to
# 'Outlook' and expects the Yes/No label in the last column).
dataset = dataset_raw.loc[:, ~dataset_raw.columns.str.startswith('Unnamed')]
X = dataset.to_numpy()
dataset

Unnamed: 0,Outlook,Temp,Humidity,Wind,Play
0,Sunny,Hot,High,Weak,No
1,Sunny,Hot,High,Strong,No
2,Overcast,Hot,High,Weak,Yes
3,Rain,Mild,High,Weak,Yes
4,Rain,Cool,Normal,Weak,Yes
5,Rain,Cool,Normal,Strong,No
6,Overcast,Cool,Normal,Strong,Yes
7,Sunny,Mild,High,Weak,No
8,Sunny,Cool,Normal,Weak,Yes
9,Rain,Mild,Normal,Weak,Yes


In [None]:
# Feature names, in the same order as the feature columns of X (index -> name).
attribute = 'Outlook Temp Humidity Wind'.split()

In [None]:
class Node(object):
    """A node of the decision tree grown by ``buildTree``.

    Attributes:
        value: for an internal node, the name of the attribute it splits on;
            for a leaf, the predicted label ('Yes' or 'No').
        decision: the attribute value on the edge from the parent to this
            node (calculate() sets it to 'Start' on the root).
        childs: list of child ``Node`` objects; empty for a leaf.
        child: legacy attribute kept for backward compatibility; not read by
            the tree code in this notebook.
    """

    def __init__(self):
        self.value = None
        self.decision = None
        # Each node gets its own fresh list, so siblings never share children.
        self.childs = []
        self.child = None

In [None]:
def findEntropy(data, rows):
    """Binary (Yes vs. not-Yes) entropy of the label over a subset of rows.

    Args:
        data: 2-D indexable table (e.g. the NumPy array ``X``) whose last
            column holds the class label.
        rows: row indices of ``data`` to include.

    Returns:
        (entropy, ans):
            entropy -- Shannon entropy in bits; 0 for a pure or empty subset.
            ans -- 1 if every selected row is 'Yes', 0 if none is 'Yes',
                   -1 otherwise (mixed labels or no rows at all).
    """
    yes = 0
    no = 0
    for i in rows:
        if data[i][-1] == 'Yes':
            yes += 1
        else:
            no += 1

    total = yes + no
    if total == 0:
        # Empty subset: no information and no majority label to report.
        return 0, -1

    entropy = 0
    if yes and no:
        p_yes = yes / total
        p_no = no / total
        entropy = -(p_yes * math.log2(p_yes) + p_no * math.log2(p_no))

    if no == 0:
        ans = 1
    elif yes == 0:
        ans = 0
    else:
        ans = -1
    return entropy, ans

In [None]:
def findMaxGain(data, rows, columns):
    """Choose the column with the highest ID3 information gain.

    For the subset S given by ``rows`` and a candidate column j:

        Gain(S, j) = H(S) - sum_v (|S_v| / |S|) * H(S_v)

    where S_v are the rows of S whose value in column j is v, and H is the
    binary Yes/not-Yes entropy of the label in the last column.

    Args:
        data: 2-D indexable table (e.g. the NumPy array ``X``); the label is
            its last column.
        rows: row indices of ``data`` forming the subset S.
        columns: candidate feature column indices.

    Returns:
        (maxGain, retidx, ans):
            maxGain -- best gain found; 0 if S is pure or no column improves.
            retidx -- index of the winning column, or -1 if there is none.
            ans -- findEntropy's purity flag for S (1 = all 'Yes',
                   0 = none 'Yes', -1 = mixed).
    """
    rows = list(rows)  # iterated once per candidate column below
    maxGain = 0
    retidx = -1
    entropy, ans = findEntropy(data, rows)
    if entropy == 0:
        # S is already pure: there is nothing to split on.
        return maxGain, retidx, ans

    total = len(rows)
    # With correct normalisation a non-informative column has a true gain of
    # 0, but float rounding can leave something like 1e-17; require a real
    # improvement so such columns never trigger a useless split.
    eps = 1e-12
    for j in columns:
        # value of column j -> [yes_count, no_count] within S
        counts = {}
        for i in rows:
            tally = counts.setdefault(data[i][j], [0, 0])
            tally[0 if data[i][-1] == 'Yes' else 1] += 1

        gain = entropy
        for yes, no in counts.values():
            if yes and no:  # pure partitions have zero entropy
                n = yes + no
                x = yes / n
                y = no / n
                # Weight by |S_v| / |S|: the size of the *current* subset,
                # not the size of the full dataset.
                gain += n * (x * math.log2(x) + y * math.log2(y)) / total

        if gain > maxGain + eps:
            maxGain = gain
            retidx = j

    return maxGain, retidx, ans

In [None]:
def buildTree(data, rows, columns):
    """Recursively grow an ID3 decision tree over ``rows`` of ``data``.

    Args:
        data: 2-D indexable table (e.g. the NumPy array ``X``) with feature
            columns first and the Yes/No label in the last column.
        rows: row indices of ``data`` to learn from.
        columns: feature column indices still available for splitting. The
            name shown for column ``idx`` is taken from the global
            ``attribute[idx]``.

    Returns:
        The root ``Node`` of the (sub)tree. Leaves carry 'Yes'/'No' in
        ``value``. Each child's ``decision`` is the attribute value that
        leads to it. Children are stored in ``root.childs``.
    """
    rows = list(rows)
    columns = list(columns)
    maxGain, idx, ans = findMaxGain(data, rows, columns)
    root = Node()
    root.childs = []

    if maxGain == 0:
        # No split helps (or the rows are pure): make a leaf.
        if ans == 1:
            root.value = 'Yes'
        elif ans == 0:
            root.value = 'No'
        else:
            # Mixed labels with no useful attribute left: predict the
            # majority label (ties go to 'No', the previous default).
            yes = sum(1 for i in rows if data[i][-1] == 'Yes')
            root.value = 'Yes' if 2 * yes > len(rows) else 'No'
        return root

    root.value = attribute[idx]
    # Partition the rows by their value in the chosen column. dict keeps
    # insertion order, so children appear in first-seen order of the values.
    partitions = {}
    for i in rows:
        partitions.setdefault(data[i][idx], []).append(i)

    remaining = [c for c in columns if c != idx]
    for key, subset in partitions.items():
        child = buildTree(data, subset, remaining)
        child.decision = key
        root.childs.append(child)
    return root

In [None]:
def traverse(root, level=0, prefix=""):
    """Print ``root`` and then, depth-first, every node below it.

    Each line shows the edge value that leads into the node ("Parent") and
    the node's own value. Every child line starts with "│   " repeated
    ``level`` times, followed by "├── ".
    """
    print(f"{prefix}Parent: {root.decision}, Value: {root.value}")
    child_prefix = "│   " * level + "├── "
    for child in root.childs:
        traverse(child, level + 1, child_prefix)


In [None]:
def calculate():
    """Build the decision tree over the whole dataset ``X`` and print it.

    Uses every row of ``X``. The feature columns are the first
    ``len(attribute)`` columns of ``X``; the label is the last column.
    """
    # Derive sizes from the data instead of hard-coding 14 rows / 4 columns.
    rows = list(range(len(X)))
    columns = list(range(len(attribute)))
    root = buildTree(X, rows, columns)
    root.decision = 'Start'
    traverse(root)

In [None]:
# Grow the tree on the full dataset and print it.
calculate()

Parent: Start, Value: Outlook
├── Parent: Sunny, Value: Humidity
│   ├── Parent: High, Value: No
│   ├── Parent: Normal, Value: Yes
├── Parent: Overcast, Value: Yes
├── Parent: Rain, Value: Wind
│   ├── Parent: Weak, Value: Yes
│   ├── Parent: Strong, Value: No


In [None]:
from graphviz import Source


def tree_to_dot(root):
    """Return Graphviz DOT source that draws a tree built by ``buildTree``.

    Internal nodes become bold boxes labelled with their split attribute.
    Leaves (nodes with no children) become plaintext labels. Each edge is
    labelled with the child's ``decision`` value, and all labels are
    upper-cased for display.

    Every tree node gets its own DOT id (n0, n1, ... in pre-order). Two
    leaves with the same label therefore stay separate boxes instead of
    being merged into one node with two parents.
    """
    def quote(text):
        return '"' + str(text).upper().replace('"', '\\"') + '"'

    lines = [
        'digraph G {',
        '    edge [dir=forward]',
        '    node [shape=box, style=bold]',
    ]
    next_id = 0

    def emit(node):
        nonlocal next_id
        node_id = f'n{next_id}'
        next_id += 1
        children = node.childs
        shape = '' if children else ', shape=plaintext'
        lines.append(f'    {node_id} [label={quote(node.value)}{shape}]')
        for child in children:
            child_id = emit(child)
            lines.append(f'    {node_id} -> {child_id} [label={quote(child.decision)}]')
        return node_id

    emit(root)
    lines.append('}')
    return '\n'.join(lines)


# Rebuild the tree (calculate() only prints it) and draw that exact tree,
# so the diagram can never drift from what the algorithm learned.
tree_root = buildTree(X, list(range(len(X))), list(range(len(attribute))))
s = Source(tree_to_dot(tree_root), filename="decision_tree", format="png")
s.view()


'decision_tree.png'