## Part 1: Entropy and the best threshold split on feature 2

In [27]:
import numpy as np
from math import log2
import copy

# Parse the comma-separated records. Any line containing '?' is dropped
# before the remaining fields are converted to integers.
with open('breast-cancer-wisconsin.data', 'r') as f:
    usable_lines = (line for line in f if '?' not in line)
    records = [line.strip('\n').split(',') for line in usable_lines]
a = np.array(records).astype(int)   # training data

In [2]:
# Number of rows whose last column equals 2, then the number whose last column equals 4.
tuple(sum(row[-1] == label for row in a) for label in (2, 4))

(444, 239)

In [3]:
def entropy(data):
    """Return the binary Shannon entropy, in bits, of the labels in ``data``.

    The last element of every row is read as its class label. Rows labelled
    ``2`` form one class; all remaining rows are treated as the other class.

    Parameters
    ----------
    data : sequence of rows
        For example a 2-D NumPy array or a list of lists. May be empty.

    Returns
    -------
    ``0`` when ``data`` is empty or every row falls in the same class,
    otherwise the entropy as a float in ``(0, 1]``.
    """
    count = len(data)
    # An empty set has no uncertainty. Returning here also avoids the
    # ZeroDivisionError the proportion below would otherwise raise.
    if count == 0:
        return 0
    p0 = sum(row[-1] == 2 for row in data) / count
    # log2(0) is undefined, so a pure set is handled explicitly.
    if p0 == 0 or p0 == 1:
        return 0
    p1 = 1 - p0
    return -p0 * log2(p0) - p1 * log2(p1)

In [4]:
# Label entropy (in bits) of the whole training array `a`.
entropy(a)

0.9340026588217948

In [12]:
def infogain(data, fea, threshold):
    """Information gain of splitting ``data`` with the test ``x_fea <= threshold``.

    ``fea`` is a 1-based column number: column ``fea - 1`` of ``data`` is read.
    The original author's note gives the intended ranges as fea = 2..10 and
    threshold = 1..9.

    Returns ``0`` if either side of the split is empty; otherwise the entropy
    of ``data`` minus the size-weighted entropies of the two sides.
    """
    feature_values = data[:, fea - 1]
    lower = data[feature_values <= threshold]
    upper = data[feature_values > threshold]
    if not len(lower) or not len(upper):
        return 0
    total = len(data)
    parent_entropy = entropy(data)
    children_entropy = len(lower) / total * entropy(lower) + len(upper) / total * entropy(upper)
    return parent_entropy - children_entropy

In [24]:
# Try thresholds 1..9 on feature 2 and keep the one with the largest gain
# (list position k corresponds to threshold k + 1).
gains = [infogain(a, 2, t) for t in range(1, 10)]
threshold = np.argmax(gains) + 1
feature_2 = a[:, 1]
d1 = a[feature_2 <= threshold]
d2 = a[feature_2 > threshold]
# Displayed as (label 2 in d1, label 2 in d2, label 4 in d1, label 4 in d2).
tuple(sum(row[-1] == label for row in part) for label in (2, 4) for part in (d1, d2))

(439, 5, 94, 145)

In [25]:
# Information gain obtained on feature 2 at the threshold selected in the previous cell.
infogain(a, 2, threshold)

0.36324318392472243

## Part 2: Decision tree on features 3, 4, 5, 8, 9, 10 (full and depth-limited)

In [32]:
# 1-based feature (column) numbers that the tree-building code is allowed to
# split on; infogain() reads column `fea - 1` for each entry.
F =  (3, 4, 5, 8, 9, 10)
def entropy(data):
    """Binary Shannon entropy (bits) of the class labels stored in ``data``.

    Every row's last element is its label. Label ``2`` is one class and any
    other label is counted in the second class.

    Parameters
    ----------
    data : sequence of rows (2-D NumPy array or list of lists); may be empty.

    Returns
    -------
    ``0`` for an empty or single-class ``data``; otherwise a float in ``(0, 1]``.
    """
    count = len(data)
    if count == 0:
        # Nothing to measure; also prevents dividing by zero below.
        return 0
    p0 = sum(row[-1] == 2 for row in data) / count
    if p0 == 0 or p0 == 1:
        # Pure set: skip log2(0).
        return 0
    p1 = 1 - p0
    return -p0 * log2(p0) - p1 * log2(p1)


def infogain(data, fea, threshold):
    """Entropy reduction achieved by the split ``x_fea <= threshold`` on ``data``.

    ``fea`` counts columns from 1 (the code indexes column ``fea - 1``). Per the
    original author's note the intended ranges are fea = 2..10, threshold = 1..9.
    A split that leaves one side empty yields ``0``.
    """
    col = fea - 1
    at_most = data[data[:, col] <= threshold]
    above = data[data[:, col] > threshold]
    if len(at_most) == 0 or len(above) == 0:
        return 0
    n = len(data)
    h_before = entropy(data)
    h_after = len(at_most) / n * entropy(at_most) + len(above) / n * entropy(above)
    return h_before - h_after


def find_best_split(data):
    """Decide what a tree node holding ``data`` should be.

    Returns
    -------
    ``(label, None)``
        A leaf. ``label`` is 2 or 4: the only label present, or, when no
        candidate split has non-zero gain, the more frequent one (ties go to 2).
    ``(fea, threshold)``
        The feature from ``F`` and the threshold in 1..9 with the largest
        information gain.
    """
    total = len(data)
    twos = sum(row[-1] == 2 for row in data)

    # Single-class data needs no further splitting.
    if twos == total:
        return (2, None)
    if twos == 0:
        return (4, None)

    # Rows follow the order of F, columns are thresholds 1..9.
    gains = np.array([[infogain(data, f, t) for t in range(1, 10)] for f in F])

    if gains.max() == 0:
        # No split helps: fall back to the majority label.
        return (2, None) if twos >= total - twos else (4, None)

    row, col = np.unravel_index(np.argmax(gains, axis=None), gains.shape)
    return (F[row], col + 1)


def split(data, node):
    """Partition ``data`` with ``node``'s test ``x_fea <= threshold``.

    Reads ``node.fea`` (1-based column number) and ``node.threshold``.
    Returns ``(rows passing the test, rows failing it)``.
    """
    column = data[:, node.fea - 1]
    passes = column <= node.threshold
    fails = column > node.threshold
    return (data[passes], data[fails])


class Node:
    """Internal decision-tree node representing the test ``x_fea <= threshold``."""

    def __init__(self, fea, threshold):
        # Both children start unset; create_tree() later stores either another
        # Node or a class label (2 or 4) in each of them.
        self.left = self.right = None
        # fea: 1-based feature number; threshold: the split value.
        self.fea, self.threshold = fea, threshold


# Choose the root split on the same threshold grid (1..9) that
# find_best_split() uses. Column j of `ig` holds the gain for threshold j + 1,
# so the `+ 1` below maps the arg-max column back to a threshold value.
# The grid used to start at 0 while the mapping still added 1, which made the
# root threshold one larger than the best-scoring one. Thresholds 0 and 10
# never split the data when feature values lie in 1..10, so dropping them
# loses nothing.
# NOTE(review): the classify() functions further down look hand-transcribed
# from a previously printed tree; regenerate them after re-running this cell.
ig = [[infogain(a, fea, t) for t in range(1, 10)] for fea in F]
ig = np.array(ig)
ind = np.unravel_index(np.argmax(ig, axis=None), ig.shape)
root = Node(F[ind[0]], ind[1] + 1)


def create_tree(data, node):
    """Recursively grow the subtree below ``node`` from the rows in ``data``.

    ``data`` is partitioned with ``node``'s own test. For each side,
    find_best_split() either reports a leaf, in which case the class label
    (2 or 4) is stored directly as the child, or a further split, in which case
    a new Node is attached and grown from that side's rows.

    Parameters
    ----------
    data : rows that reach ``node``.
    node : Node whose ``left`` / ``right`` attributes are filled in place.
    """
    for side, subset in zip(('left', 'right'), split(data, node)):
        value, child_threshold = find_best_split(subset)
        # A missing threshold marks a leaf. Test with `is None` (identity)
        # rather than `== None`, as recommended for singleton comparisons.
        if child_threshold is None:
            setattr(node, side, value)
        else:
            child = Node(value, child_threshold)
            setattr(node, side, child)
            create_tree(subset, child)

# Grow the full tree in place below `root`, using the training array `a`.
create_tree(a, root)

In [47]:
def print_tree(node, f, prefix=''):
    """Write the subtree rooted at ``node`` to ``f`` as nested if/else pseudocode.

    A child equal to 2 or 4 is a leaf and is emitted as ``return <label>`` on
    the same line; any other child is treated as a node and printed
    recursively with ``prefix`` extended by one space.
    """
    fea, t, yes_child, no_child = node.fea, node.threshold, node.left, node.right
    test_head = prefix+'if (x'+str(fea)+' <= '+str(t)
    deeper = prefix+' '

    # Branch taken when the test holds.
    if yes_child in (2, 4):
        f.write(test_head+') return '+str(yes_child)+'\n')
    else:
        f.write(test_head+')\n')
        print_tree(yes_child, f, deeper)

    # Branch taken when the test fails.
    if no_child in (2, 4):
        f.write(prefix+'else return '+str(no_child)+'\n')
    else:
        f.write(prefix+'else\n')
        print_tree(no_child, f, deeper)

# Dump the full tree. The context manager closes the file even if
# print_tree() raises part-way through.
with open("decision_tree.txt", "w") as f:
    print_tree(root, f, prefix='')

In [54]:
def classify(i):
    """Classify one sample with a hard-coded decision tree.

    ``i`` is an indexable record; positions 2, 3, 4, 7, 8 and 9 are read as
    the features x3, x4, x5, x8, x9 and x10. Every path through the nested
    conditionals returns the class label 2 or 4.

    NOTE(review): the branch structure looks like a manual transcription of
    the tree written to decision_tree.txt -- confirm, and regenerate if the
    tree is rebuilt. Several sibling branches return the same label, so the
    tests guarding them have no effect on the result.
    """
    x3, x4, x5, x8, x9, x10 = i[2], i[3], i[4], i[7], i[8], i[9]
    # Subtree for x3 <= 3.
    if x3 <= 3:
        if x9 <= 2:
            if x8 <= 3:
                if x5 <= 2:
                    if x10 <= 1:
                        return 2
                    elif x3 <= 1: 
                        return 2
                    elif x5 <= 1: 
                        return 4
                    else: 
                        return 2
                elif x3 <= 2:
                    if x8 <= 1:
                        return 2
                    elif x8 <= 2:
                        if x4 <= 1:
                            if x5 <= 3:
                                if x3 <= 1:
                                    return 2
                                else: 
                                    return 2
                            else: 
                                return 2
                        else: 
                            return 2
                    else: 
                        return 2
                elif x4 <= 2: 
                    return 2
                else: 
                    return 4
            else:
                if x4 <= 3:
                    if x3 <= 1:
                        return 2
                    elif x9 <= 1:
                        if x8 <= 5: 
                            return 4
                        elif x5 <= 3:
                            if x4 <= 2: 
                                return 2
                            elif x3 <= 2: 
                                return 4
                            else: 
                                return 2
                        else: 
                            return 4
                    else: 
                        return 2
                else: 
                    return 4
        elif x4 <= 2:
            if x8 <= 4:
                return 2
            elif x4 <= 1:
                if x5 <= 1:
                    return 4
                else: 
                    return 2
            else: 
                return 4
        elif x4 <= 4:
            if x8 <= 4:
                if x3 <= 1: 
                    return 2
                elif x3 <= 2: 
                    return 4
                elif x9 <= 7:
                    if x9 <= 5:
                        if x4 <= 3:
                            if x5 <= 4:
                                if x5 <= 1:
                                    if x10 <= 1:
                                        return 2
                                    else: 
                                        return 4
                                else: 
                                    return 4
                            else: 
                                return 2
                        else: 
                            return 4
                    else: 
                        return 2
                else: 
                    return 4
            else: 
                return 4
        else: 
            return 4
    # Subtree for x3 > 3.
    else:
        if x3 <= 4:
            if x5 <= 5:
                if x4 <= 6:
                    if x9 <= 5:
                        if x8 <= 3:
                            if x9 <= 2: 
                                return 2
                            else: 
                                return 4
                        elif x5 <= 2:
                            return 4
                        elif x4 <= 4: 
                            return 2
                        elif x5 <= 3: 
                            return 2
                        else: 
                            return 4
                    elif x9 <= 8: 
                        return 2
                    else: 
                        return 4
                else: 
                    return 4
            else: 
                return 4
        elif x8 <= 4:
            if x10 <= 1:
                if x4 <= 6: 
                    return 4
                elif x3 <= 9:
                    if x5 <= 1: 
                        return 2
                    elif x3 <= 8: 
                        return 4
                    else: 
                        return 2
                else: 
                    return 4
            else: 
                return 4
        else: 
            return 4

In [56]:
# ndmin=2 keeps `test` two-dimensional even when test.txt holds a single
# sample. Without it loadtxt returns a 1-D array for a one-line file and the
# loop would hand scalars, not rows, to classify().
test = np.loadtxt('test.txt', dtype=np.int32, delimiter=',', ndmin=2)
result = [classify(i) for i in test]
# All predictions go on one comma-separated line.
np.savetxt("result.txt", np.reshape(result, (1,-1)), fmt='%d', delimiter=',')

In [60]:
def _majority_label(rows):
    """Return 2 if label 2 is at least as frequent as label 4 in ``rows``, else 4."""
    labels = rows[:, -1]
    twos = np.count_nonzero(labels == 2)
    fours = np.count_nonzero(labels == 4)
    return 2 if twos >= fours else 4


def print_tree(node, f, d=0, prefix='', data=None, max_depth=6):
    """Write the tree rooted at ``node`` to ``f``, cut off at ``max_depth`` levels.

    Output format is nested if/else pseudocode, one extra space of ``prefix``
    per level. A child equal to 2 or 4 is a leaf. A child subtree that would
    start below ``max_depth`` is replaced by ``return <label>``, where the
    label is the majority class (ties go to 2) of the training rows that
    actually reach that child.

    Parameters
    ----------
    node : object with ``fea``, ``threshold``, ``left`` and ``right`` attributes.
    f : writable text file object.
    d : depth of the caller; incremented on entry. Nothing is written once the
        incremented depth exceeds ``max_depth``.
    prefix : indentation written before every line of this level.
    data : rows reaching ``node``. Defaults to the module-level training
        array ``a`` (the rows that reach the root).
    max_depth : number of levels to print; 6 reproduces the previous
        hard-coded limit.
    """
    if data is None:
        data = a
    fea = node.fea
    t = node.threshold
    d += 1
    if d > max_depth:
        return

    # Rows are narrowed at every level so that a cut-off branch is labelled
    # from the samples on its own path. Filtering the whole training set by
    # this node's test alone would mix in rows that never reach this node.
    column = data[:, fea - 1]
    branches = (
        ('if (x'+str(fea)+' <= '+str(t)+')', node.left, data[column <= t]),
        ('else', node.right, data[column > t]),
    )
    for opener, child, subset in branches:
        if child == 2 or child == 4:
            f.write(prefix+opener+' return '+str(child)+'\n')
        elif d < max_depth:
            f.write(prefix+opener+'\n')
            print_tree(child, f, d, prefix+' ', subset, max_depth)
        else:
            f.write(prefix+opener+' return '+str(_majority_label(subset))+'\n')
# Dump the depth-limited tree. The context manager closes the file even if
# print_tree() raises part-way through.
with open("decision_tree1.txt", "w") as f:
    print_tree(root, f, d=0, prefix='')

In [61]:
def classify(i):
    """Classify one sample with a smaller hard-coded decision tree.

    ``i`` is an indexable record; positions 2, 3, 4, 7, 8 and 9 are read as
    the features x3, x4, x5, x8, x9 and x10. Every path returns the class
    label 2 or 4. This definition replaces the earlier ``classify`` once the
    cell has run.

    NOTE(review): the branches look like a manual transcription of the
    depth-limited tree written to decision_tree1.txt -- confirm, and
    regenerate if that file changes. Some sibling branches return the same
    label, so the tests guarding them do not affect the result.
    """
    x3, x4, x5, x8, x9, x10 = i[2], i[3], i[4], i[7], i[8], i[9]
    # Subtree for x3 <= 3.
    if x3 <= 3:
        if x9 <= 2:
            if x8 <= 3:
                if x5 <= 2:
                    if x10 <= 1: 
                        return 2
                    elif x3 <= 1:
                        return 2
                    else: 
                        return 4
                elif x3 <= 2:
                    if x8 <= 1: 
                        return 2
                    else: 
                        return 2
                elif x4 <= 2: 
                    return 2
                else: 
                    return 4
            elif x4 <= 3:
                if x3 <= 1: 
                    return 2
                elif x9 <= 1: 
                    return 2
                else:
                    return 2
            else: 
                return 4
        elif x4 <= 2:
            if x8 <= 4: 
                return 2
            elif x4 <= 1:
                if x5 <= 1: 
                    return 4
                else:
                    return 2
            else:
                return 4
        elif x4 <= 4:
            if x8 <= 4:
                if x3 <= 1: 
                    return 2
                else:
                    return 4
            else:
                return 4
        else: 
            return 4
    # Subtree for 3 < x3 <= 4.
    elif x3 <= 4:
        if x5 <= 5:
            if x4 <= 6:
                if x9 <= 5:
                    if x8 <= 3: 
                        return 2
                    else:
                        return 4
                elif x9 <= 8: 
                    return 2
                else: 
                    return 4
            else: 
                return 4
        else: 
            return 4
    # Remaining cases have x3 > 4.
    elif x8 <= 4:
        if x10 <= 1:
            if x4 <= 6: 
                return 4
            elif x3 <= 9: 
                return 2
            else: 
                return 4
        else: 
            return 4
    else: 
        return 4

In [None]:
# ndmin=2 guarantees a 2-D array, so a test.txt containing a single sample
# is still iterated row by row instead of value by value.
test = np.loadtxt('test.txt', dtype=np.int32, delimiter=',', ndmin=2)
result = [classify(i) for i in test]
# Predictions are written as one comma-separated line. This writes the same
# result.txt path as the earlier classification cell.
np.savetxt("result.txt", np.reshape(result, (1,-1)), fmt='%d', delimiter=',')