<a href="https://colab.research.google.com/github/ashwin642/poisonous_mushroom_decision_tree/blob/main/decisiontree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import pandas as pd
import numpy as np
import math

# Column labels for the data file, in file order. The first label is the
# class column used as the prediction target later in this notebook; the
# rest are the candidate split attributes.
col_names = [
    'Edibility',
    'Cap Shape',
    'Cap Surface',
    'Cap Color',
    'Bruises',
    'Odor',
    'Gill Attachment',
    'Gill Spacing',
    'Gill Size',
    'Gill Color',
    'Stalk Shape',
    'Stalk root',
    'Stalk Surface above ring',
    'Stalk Surface below ring',
    'Stalk color above ring',
    'Stalk color below ring',
    'Veil Type',
    'Veil Color',
    'Ring Number',
    'Ring Type',
    'Spore Print Color',
    'Population',
    'Habitat',
]

# The labels are supplied explicitly through `names=` rather than being
# read from the file.
df = pd.read_csv('agaricus-lepiota.data', names=col_names)

def entropy(attribute):
    """Return the Shannon entropy, in bits, of the values in ``attribute``.

    Args:
        attribute: A pandas Series (anything offering ``value_counts()`` and
            ``len()``) holding the observed values.

    Returns:
        The sum over distinct values of ``-p * log2(p)``, where ``p`` is the
        value's count divided by ``len(attribute)``. Returns ``0`` when there
        are no values to count.
    """
    n_observations = len(attribute)
    proportions = (
        occurrences / n_observations
        for occurrences in attribute.value_counts()
    )
    return sum(-p * math.log2(p) for p in proportions)

def conditional_entropy(df, attribute, target_attribute):
    """Return the entropy of ``target_attribute`` conditioned on ``attribute``.

    For every distinct value of ``attribute`` the rows holding that value are
    selected, the entropy of their ``target_attribute`` column is computed,
    and the results are averaged weighted by each subset's share of the rows.

    Args:
        df: DataFrame containing both columns.
        attribute: Name of the column to condition on.
        target_attribute: Name of the column whose entropy is measured.

    Returns:
        The weighted sum of per-subset entropies (``0`` if ``df`` has no rows).
    """
    # Row count is loop-invariant, so read it once.
    total_rows = len(df)
    weighted_entropy = 0
    for value in df[attribute].unique():
        subset = df[df[attribute] == value]
        weight = len(subset) / total_rows
        weighted_entropy += weight * entropy(subset[target_attribute])
    return weighted_entropy

def information_gain(df, attribute, target_attribute):
    """Return how much splitting on ``attribute`` reduces target entropy.

    Args:
        df: DataFrame containing both columns.
        attribute: Name of the candidate split column.
        target_attribute: Name of the label column.

    Returns:
        Entropy of ``target_attribute`` over all rows minus its conditional
        entropy given ``attribute``.
    """
    return entropy(df[target_attribute]) - conditional_entropy(
        df, attribute, target_attribute
    )

def split_dataset(df, attribute):
    """Partition ``df`` by the distinct values of one column.

    Args:
        df: DataFrame to partition.
        attribute: Name of the column to split on.

    Returns:
        A dict mapping each distinct value of ``attribute`` (in order of
        first appearance) to the rows holding that value, with the
        ``attribute`` column itself removed.
    """
    return {
        value: df[df[attribute] == value].drop(columns=[attribute])
        for value in df[attribute].unique()
    }

def create_decision_tree(df, target_attribute, attributes):
    """Recursively build a decision tree as nested dictionaries.

    At each node the attribute with the largest information gain is chosen,
    the rows are partitioned by that attribute's values, and a subtree is
    built for every partition from the attributes that remain.

    Args:
        df: DataFrame holding the target column and every column named in
            ``attributes``.
        target_attribute: Name of the label column.
        attributes: Sequence of column names still available for splitting.
            It is not modified.

    Returns:
        Either a bare label (a leaf) or a dict of the form
        ``{attribute: {value: subtree_or_label, ...}}``.
    """
    # Nothing left to split on: fall back to the most frequent label.
    if len(attributes) == 0:
        return df[target_attribute].mode()[0]

    # All rows agree on the label, so no further split is needed.
    if len(df[target_attribute].unique()) == 1:
        return df[target_attribute].iloc[0]

    information_gains = [
        information_gain(df, attribute, target_attribute)
        for attribute in attributes
    ]
    # np.argmax picks the first maximum, so ties go to the earliest attribute.
    best_attribute = attributes[np.argmax(information_gains)]

    # Build a fresh list instead of calling attributes.remove(): the caller's
    # list must survive the call unchanged, and the chosen column is dropped
    # from every subset below, so it can never be split on again.
    remaining_attributes = [
        attribute for attribute in attributes if attribute != best_attribute
    ]

    branches = {}
    for value, subset in split_dataset(df, best_attribute).items():
        branches[value] = create_decision_tree(
            subset, target_attribute, remaining_attributes
        )
    return {best_attribute: branches}

# Example usage: learn a tree that predicts the 'Edibility' column from
# every other column, then show the raw nested-dict result.
target_attribute = 'Edibility'
# Every column except the target is a candidate split attribute.
attributes = [name for name in col_names if name != target_attribute]
decision_tree = create_decision_tree(df, target_attribute, attributes)
print(decision_tree)


{'Odor': {'p': 'p', 'a': 'e', 'l': 'e', 'n': {'Spore Print Color': {'n': 'e', 'k': 'e', 'w': {'Habitat': {'w': 'e', 'l': {'Cap Color': {'c': 'e', 'n': 'e', 'w': 'p', 'y': 'p'}}, 'd': {'Gill Size': {'n': 'p', 'b': 'e'}}, 'g': 'e', 'p': 'e'}}, 'h': 'e', 'r': 'p', 'o': 'e', 'y': 'e', 'b': 'e'}}, 'f': 'p', 'c': 'p', 'y': 'p', 's': 'p', 'm': 'p'}}


In [7]:
# Pretty-print the decision tree
def print_decision_tree(tree, indent=''):
    """Print a nested-dict decision tree as an indented outline.

    Each dict key is printed on its own line followed by ``:``. A dict value
    is printed beneath its key, indented by two more spaces; any other value
    is treated as a leaf and printed on the same line as its key.

    Args:
        tree: The tree to print. Either a dict as produced by
            ``create_decision_tree`` or a bare leaf label.
        indent: Prefix prepended to every line printed at this level.
    """
    # create_decision_tree returns a bare label when no split is made at the
    # root; such a value has no .items(), so print it directly.
    if not isinstance(tree, dict):
        print(f'{indent}{tree}')
        return

    for attribute, subtree in tree.items():
        # f-strings keep non-string keys (e.g. numeric attribute values)
        # printable, where `indent + attribute` would raise TypeError.
        if isinstance(subtree, dict):
            print(f'{indent}{attribute}:')
            print_decision_tree(subtree, indent + '  ')
        else:
            print(f'{indent}{attribute}:', subtree)

# Render the tree built in the previous cell as an indented outline.
print_decision_tree(decision_tree)

Odor:
  p: p
  a: e
  l: e
  n:
    Spore Print Color:
      n: e
      k: e
      w:
        Habitat:
          w: e
          l:
            Cap Color:
              c: e
              n: e
              w: p
              y: p
          d:
            Gill Size:
              n: p
              b: e
          g: e
          p: e
      h: e
      r: p
      o: e
      y: e
      b: e
  f: p
  c: p
  y: p
  s: p
  m: p
