In [13]:
import csv
import math
from collections import Counter

In [14]:
def load_csv(filename):
    """Read a CSV file and split it into data rows and the header row.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A ``(dataset, headers)`` tuple. ``headers`` is the first row of the
        file and ``dataset`` is the list of all remaining rows; every row is
        a list of strings as produced by ``csv.reader``.

    Raises:
        IndexError: If the file contains no rows, so there is no header row
            to remove.
    """
    # The with-block guarantees the handle is closed even if parsing fails.
    # newline="" is the mode the csv module expects: it keeps line breaks
    # embedded in quoted fields exactly as they appear in the file.
    with open(filename, "r", newline="") as handle:
        dataset = list(csv.reader(handle))
    headers = dataset.pop(0)
    return dataset, headers

In [15]:
class Node:
    """One vertex of the decision tree.

    Attributes are plain public fields:
      * ``attribute`` - the value passed to the constructor.
      * ``children``  - starts as an empty list.
      * ``answer``    - starts as the empty string.
    """

    def __init__(self, attribute):
        # Assigned left to right, so the fields are created in the order
        # attribute, children, answer.
        self.attribute, self.children, self.answer = attribute, [], ""

In [21]:
def subtables(data, col, delete):
    """Partition ``data`` into one sub-table per distinct value of a column.

    Args:
        data: List of rows (each row a list) to partition. It is never
            modified.
        col: Index of the column whose values define the partition.
        delete: When true, the rows stored in the sub-tables are copies with
            column ``col`` removed. When false, the sub-tables hold the
            original row objects unchanged.

    Returns:
        ``(attr, dic)`` where ``attr`` is the list of distinct values found in
        column ``col`` in first-seen order, and ``dic`` maps each of those
        values to the list of rows carrying it (in input order). Empty
        ``data`` yields ``([], {})``.
    """
    dic = {}
    for row in data:
        key = row[col]
        if delete:
            # Trim a copy rather than the caller's row: deleting in place
            # corrupts the input and shifts later columns into position
            # ``col``, where they could be mistaken for another value of
            # this column.
            kept = list(row)
            del kept[col]
        else:
            kept = row
        dic.setdefault(key, []).append(kept)

    # dicts keep insertion order, so the keys are already in first-seen
    # order, which makes the result reproducible from run to run.
    return list(dic), dic


def entropy(S):
    """Return the Shannon entropy, in bits, of the labels in ``S``.

    Args:
        S: Sequence of hashable class labels. Any number of distinct
            labels is supported.

    Returns:
        ``-sum(p * log2(p))`` over the relative frequency ``p`` of each
        distinct label. Returns 0 when ``S`` is empty or contains a single
        distinct label.
    """
    total = len(S)
    counts = Counter(S)
    # Zero or one distinct label carries no uncertainty; returning early also
    # avoids dividing by zero for an empty sequence.
    if len(counts) <= 1:
        return 0

    result = 0.0
    for count in counts.values():
        p = count / total
        result -= p * math.log2(p)
    return result


def compute_gain(data, col):
    """Return the information gain obtained by splitting ``data`` on ``col``.

    The gain is the entropy of the last column of ``data`` minus the
    size-weighted entropy of the last column within each sub-table produced
    by ``subtables(data, col, delete=False)``.

    Args:
        data: List of rows; the last element of each row is the label.
        col: Index of the column to evaluate as a split.

    Returns:
        The gain as a number.
    """
    values, partitions = subtables(data, col, delete=False)
    row_count = float(len(data))

    # Start from the entropy of the whole table and subtract each
    # partition's weighted share, in the order the values were returned.
    gain = entropy([record[-1] for record in data])
    for value in values:
        subset = partitions[value]
        weight = len(subset) / row_count
        gain -= weight * entropy([record[-1] for record in subset])
    return gain


def build_tree(data, features):
    """Recursively build a decision tree by repeatedly splitting on the
    column with the highest information gain.

    Args:
        data: Non-empty list of rows. The last element of each row is the
            class label; every earlier element is an attribute value. The
            rows are not modified.
        features: Names of the columns; ``features[i]`` names column ``i``
            of ``data``.

    Returns:
        The root ``Node``. A leaf has ``answer`` set to a label. An internal
        node has ``attribute`` set to the name of the column it splits on and
        ``children`` holding ``(value, subtree)`` tuples.

    Raises:
        IndexError: If ``data`` is empty.
    """
    lastcol = [row[-1] for row in data]
    if len(set(lastcol)) == 1:
        node = Node("")
        node.answer = lastcol[0]
        return node

    n = len(data[0]) - 1
    if n == 0:
        # Every attribute has been used but the labels still disagree (rows
        # with identical attributes and different labels). There is nothing
        # left to split on, so fall back to the most common label; ties go
        # to the label seen first.
        node = Node("")
        node.answer = Counter(lastcol).most_common(1)[0][0]
        return node

    gains = [compute_gain(data, col) for col in range(n)]

    split = gains.index(max(gains))
    node = Node(features[split])
    fea = features[:split] + features[split + 1:]

    # Partition a copy of the rows: the split column is removed from the rows
    # handed to the children, and that must never shorten the caller's data.
    attr, dic = subtables([list(row) for row in data], split, delete=True)

    for value in attr:
        node.children.append((value, build_tree(dic[value], fea)))
    return node


def print_tree(node, level):
    """Print the subtree rooted at ``node`` to stdout, one item per line.

    Each line is prefixed by ``level`` spaces (plus the separator space that
    ``print`` inserts). A node with a non-empty ``answer`` prints only that
    answer. Otherwise the node's ``attribute`` is printed, followed by each
    branch value one level deeper and that branch's subtree two levels
    deeper.

    Args:
        node: Object exposing ``answer``, ``attribute`` and ``children``
            (an iterable of ``(value, child)`` pairs).
        level: Current indentation depth.
    """
    indent = " " * level
    is_leaf = node.answer != ""
    print(indent, node.answer if is_leaf else node.attribute)

    if not is_leaf:
        branch_indent = " " * (level + 1)
        for branch_value, subtree in node.children:
            print(branch_indent, branch_value)
            print_tree(subtree, level + 2)


def classify(node, x_test, features, default=None):
    """Predict the label of one instance by walking the decision tree.

    The predicted label is printed (as before) and also returned, so callers
    can use the result programmatically.

    Args:
        node: Root of the (sub)tree to walk. A node with a non-empty
            ``answer`` is treated as a leaf.
        x_test: The instance as a list of values, indexed like ``features``.
        features: Column names; used to find the position in ``x_test`` of
            the attribute a node splits on.
        default: Value printed and returned when the instance has an
            attribute value for which the tree has no branch.

    Returns:
        The label stored in the leaf that was reached, or ``default`` when no
        branch matches.

    Raises:
        ValueError: If a node's attribute name is not present in
            ``features``.
    """
    if node.answer != "":
        print(node.answer)
        return node.answer

    pos = features.index(node.attribute)

    for value, child in node.children:
        if x_test[pos] == value:
            # Stop at the first matching branch.
            return classify(child, x_test, features, default)

    # No branch for this value: still print something so a caller that wrote
    # a prefix with end=" " gets its line terminated.
    print(default)
    return default

In [22]:
'''Main program'''
# Read the data rows and the header row from the CSV, then build the tree
# from them; the header names are what label the tree's internal nodes.
dataset,features=load_csv("Social_Network_Ads.csv")
node1=build_tree(dataset,features)

In [23]:
print("The decision tree for the dataset using ID3 algorithm is")
print_tree(node1,0)
# NOTE(review): this reloads the same file the tree was built from, so the
# "test" rows classified later are the training rows themselves and the
# results say nothing about accuracy on unseen data.
testdata,features=load_csv("Social_Network_Ads.csv")

The decision tree for the dataset using ID3 algorithm is
 User ID
  15776733
   0
  15584320
   1
  15694879
   1
  15792008
   0
  15727696
   1
  15807837
   1
  15807909
   0
  15577514
   1
  15697020
   0
  15638963
   0
  15711218
   0
  15755018
   0
  15789863
   0
  15764604
   1
  15782530
   0
  15688172
   1
  15699247
   0
  15708196
   0
  15766289
   0
  15595917
   0
  15809347
   0
  15715541
   0
  15582066
   0
  15668521
   1
  15694395
   0
  15800061
   0
  15693264
   0
  15674206
   0
  15663161
   1
  15672330
   1
  15750056
   1
  15617482
   1
  15791373
   1
  15678168
   0
  15741049
   0
  15625395
   1
  15767871
   0
  15631912
   0
  15723373
   0
  15793813
   0
  15720745
   1
  15753102
   1
  15697574
   0
  15574305
   0
  15578006
   0
  15624510
   0
  15646091
   1
  15753861
   1
  15628972
   0
  15778830
   1
  15636428
   0
  15813113
   1
  15603246
   0
  15675949
   1
  15801247
   0
  15744279
   1
  15779529
   1
  15748589
   1
  1569

In [24]:
# Classify every loaded row. classify() prints the predicted label itself,
# which is why the preceding print ends with a space instead of a newline.
for xtest in testdata:
    print("The test instance:",xtest)
    print("The label for test instance:",end=" ")
    classify(node1,xtest,features)

The test instance: ['15624510', 'Male', '19', '19000', '0']
The label for test instance: 0
The test instance: ['15810944', 'Male', '35', '20000', '0']
The label for test instance: 0
The test instance: ['15668575', 'Female', '26', '43000', '0']
The label for test instance: 0
The test instance: ['15603246', 'Female', '27', '57000', '0']
The label for test instance: 0
The test instance: ['15804002', 'Male', '19', '76000', '0']
The label for test instance: 0
The test instance: ['15728773', 'Male', '27', '58000', '0']
The label for test instance: 0
The test instance: ['15598044', 'Female', '27', '84000', '0']
The label for test instance: 0
The test instance: ['15694829', 'Female', '32', '150000', '1']
The label for test instance: 1
The test instance: ['15600575', 'Male', '25', '33000', '0']
The label for test instance: 0
The test instance: ['15727311', 'Female', '35', '65000', '0']
The label for test instance: 0
The test instance: ['15570769', 'Female', '26', '80000', '0']
The label for tes

The label for test instance: 0
The test instance: ['15646227', 'Male', '46', '79000', '1']
The label for test instance: 1
The test instance: ['15660541', 'Male', '40', '57000', '0']
The label for test instance: 0
The test instance: ['15753874', 'Female', '37', '80000', '0']
The label for test instance: 0
The test instance: ['15617877', 'Female', '46', '82000', '0']
The label for test instance: 0
The test instance: ['15772073', 'Female', '53', '143000', '1']
The label for test instance: 1
The test instance: ['15701537', 'Male', '42', '149000', '1']
The label for test instance: 1
The test instance: ['15736228', 'Male', '38', '59000', '0']
The label for test instance: 0
The test instance: ['15780572', 'Female', '50', '88000', '1']
The label for test instance: 1
The test instance: ['15769596', 'Female', '56', '104000', '1']
The label for test instance: 1
The test instance: ['15586996', 'Female', '41', '72000', '0']
The label for test instance: 0
The test instance: ['15722061', 'Female', '5