In [7]:
import math

from lib.learn_problem import Learner, error_example
from lib.learn_no_inputs import point_prediction, target_counts, selections


class DT_learner(Learner):
    """Greedy decision-tree learner.

    Input features are conditions: functions of an example whose truth
    value decides which branch the example goes down.  At each node the
    condition whose split gives the lowest summed training error (as
    measured by ``to_optimize``) is chosen.  Growth stops, and a leaf with a
    point prediction is produced, when no condition improves the error,
    all examples agree on the target, fewer than ``min_number_examples``
    remain, or no conditions are left.

    A learned tree is a function ``example -> prediction`` whose ``__doc__``
    is a readable ``if ... then (...) else (...)`` description of the tree.
    """

    def __init__(self,
        dataset,
        to_optimize='sum-of-squares',
        leaf_selection='mean',   # what to use for point prediction at leaves
        train=None,              # used for cross validation
        min_number_examples=10
    ):
        """Configure the learner.

        dataset             -- provides .target, .train and .input_features
        to_optimize         -- criterion used to score candidate splits
                               ('sum-of-squares', 'sum_absolute' or 'logloss')
        leaf_selection      -- selection passed to point_prediction at leaves
        train               -- examples to learn from; defaults to dataset.train
                               (pass a fold here for cross validation)
        min_number_examples -- a node with fewer examples becomes a leaf
        """
        self.dataset = dataset
        self.target = dataset.target
        self.to_optimize = to_optimize
        self.leaf_selection = leaf_selection
        self.min_number_examples = min_number_examples
        self.train = self.dataset.train if train is None else train

    def learn(self):
        """Return a decision tree learned from all input features on self.train."""
        return self.learn_tree(self.dataset.input_features, self.train)

    def learn_tree(self, input_features, data_subset):
        """Return a decision tree built from data_subset.

        input_features -- the candidate split conditions still available
        data_subset    -- the examples used to build this (sub)tree

        The returned tree is a function that takes an example and returns a
        prediction for the target feature; its __doc__ describes the tree.
        """
        # Only consider splitting when there are conditions left and enough
        # examples.  The explicit truthiness test on data_subset keeps the
        # data_subset[0] lookup safe even when min_number_examples <= 0.
        if (input_features and data_subset
                and len(data_subset) >= self.min_number_examples):
            first_target_val = self.target(data_subset[0])
            allagree = all(self.target(inst) == first_target_val for inst in data_subset)
            if not allagree:
                split, partn = self.select_split(input_features, data_subset)
                if split is not None:  # some split improved on not splitting
                    false_examples, true_examples = partn
                    # a condition is used at most once along any path
                    rem_features = [fe for fe in input_features if fe != split]
                    self.display(
                        2, 'Splitting on', split.__doc__, 'with examples split',
                        len(true_examples), ':', len(false_examples)
                    )
                    true_tree = self.learn_tree(rem_features, true_examples)
                    false_tree = self.learn_tree(rem_features, false_examples)

                    def fun(e):
                        if split(e):
                            return true_tree(e)
                        else:
                            return false_tree(e)

                    fun.__doc__ = (
                        'if ' + split.__doc__ + ' then (' + true_tree.__doc__ +
                        ') else (' + false_tree.__doc__ + ')'
                    )
                    return fun
        # don't expand the tree; return a point prediction for this leaf
        return point_prediction(self.target, data_subset, selection=self.leaf_selection)

    def select_split(self, input_features, data_subset):
        """Find the best feature to split on.

        input_features is a non-empty list of features.
        Returns (feature, partition), where feature is the input feature
        with the smallest error as judged by to_optimize, or None if no
        split improves on the error of not splitting.  partition is the
        pair (false_examples, true_examples) when feature is not None, and
        None otherwise.
        """
        best_feat = None  # best feature
        # Start from the error of not splitting, so a feature is chosen only
        # if it strictly reduces the training error.
        best_error = training_error(self.dataset, data_subset, self.to_optimize)
        best_partition = None
        for feat in input_features:
            false_examples, true_examples = partition(data_subset, feat)
            if false_examples and true_examples:  # both partitions are non-empty
                err = (
                    training_error(self.dataset, false_examples, self.to_optimize)
                    +
                    training_error(self.dataset, true_examples, self.to_optimize)
                )
                self.display(
                    3, '   split on', feat.__doc__, 'has err=', err,
                    'splits into', len(true_examples), ':', len(false_examples)
                )
                if err < best_error:
                    best_feat = feat
                    best_error = err
                    best_partition = false_examples, true_examples
        # Don't read None.__doc__: newer Pythons give NoneType a docstring,
        # which would be printed in place of "None".
        best_doc = best_feat.__doc__ if best_feat is not None else None
        self.display(3, 'best split is on', best_doc, 'with err=', best_error)
        return best_feat, best_partition


In [8]:
def partition(data_subset, feature):
    """Split data_subset into the examples where feature is falsy and truthy.

    Each example's relative order is preserved within its group.
    Returns the pair (false_examples, true_examples) of new lists.
    """
    buckets = {False: [], True: []}
    for ex in data_subset:
        # bool() applies the same truth test an `if` statement would
        buckets[bool(feature(ex))].append(ex)
    return buckets[False], buckets[True]


In [9]:
def training_error(dataset, data_subset, to_optimize):
    """Return the summed training error of a single point prediction on data_subset.

    The prediction is built by point_prediction with the selection paired
    with the criterion: 'mean' for 'sum-of-squares', 'median' for
    'sum_absolute' and 'Laplace' for 'logloss'.  This assumes that
    selection gives the best constant prediction for that criterion.  The
    error of that prediction is then summed over the examples using
    error_example.

    dataset     -- provides .target, mapping an example to its target value
    data_subset -- the examples to evaluate (may be empty)
    to_optimize -- one of 'sum-of-squares', 'sum_absolute', 'logloss'

    Returns the summed error; an empty data_subset has error 0.
    Raises KeyError if to_optimize is not one of the supported criteria.
    """
    select_dict = {
        'sum-of-squares': 'mean',
        'sum_absolute': 'median',
        'logloss': 'Laplace'
    }
    try:
        selection = select_dict[to_optimize]
    except KeyError:
        # keep the KeyError type callers may rely on, but say what is valid
        raise KeyError(
            "unsupported to_optimize {!r}; expected one of {}".format(
                to_optimize, sorted(select_dict))
        ) from None
    if not data_subset:
        # the sum over no examples is 0; don't ask for a prediction from no data
        return 0
    target = dataset.target
    predictor = point_prediction(target, data_subset, selection=selection)
    return sum(
        error_example(predictor(example), target(example), to_optimize)
        for example in data_subset
    )


In [10]:
from lib.learn_problem import Data_set, Data_from_file

def test(data):
    """Learn and report a tree for every (split criterion, leaf selection) pair.

    For each criterion used to choose splits and each leaf point-prediction
    method, print the learned tree's description.  When data has a test set,
    also print the error that data.evaluate_dataset reports for the tree on
    that test set under each criterion.
    """
    # named so it doesn't shadow the module-level `selections` import
    leaf_selections = ["median", "mean", "Laplace"]
    criteria = ["sum-of-squares","sum_absolute","logloss"]

    for crit in criteria:
        for leaf in leaf_selections:
            learner = DT_learner(data, to_optimize=crit, leaf_selection=leaf)
            tree = learner.learn()
            print('For', crit, 'using', leaf, 'at leaves, tree built is:', tree.__doc__)
            if not data.test:
                continue
            for ecrit in criteria:
                print('    Average error for', ecrit, 'using', leaf, 'at leaves is',
                      data.evaluate_dataset(data.test, tree, ecrit))


In [12]:
if __name__ == "__main__":
    # Other datasets that can be run instead (uncomment the pair of calls):
    # print("carbool.csv"); test(data=Data_from_file('data/carbool.csv', target_index=-1))
    # print("SPECT.csv"); test(data=Data_from_file('data/SPECT.csv', target_index=0))
    # print("holiday.csv"); test(data=Data_from_file('data/holiday.csv', num_train=19, target_index=-1))
    print("mail_reading.csv")
    mail_data = Data_from_file('data/mail_reading.csv', target_index=-1)
    test(data=mail_data)
    # print("holiday.csv"); test(data = Data_from_file('data/holiday.csv', num_train=19, target_index=-1))


carbool.csv
Tuples read. 
Training set 1206 examples. Number of columns: {7} 
Test set 522 examples. Number of columns: {7}
There are 20 input features
   split on e[0]==vhigh has err= 244.15732872296726 splits into 302 : 904
   split on e[0]==high has err= 249.48442485444755 splits into 303 : 903
   split on e[0]==low has err= 247.15864055299477 splits into 310 : 896
   split on e[0]==med has err= 247.59312714776723 splits into 291 : 915
   split on e[1]==vhigh has err= 241.4802731213603 splits into 295 : 911
   split on e[1]==high has err= 249.42385785870127 splits into 301 : 905
   split on e[1]==low has err= 246.50722453222406 splits into 296 : 910
   split on e[1]==med has err= 246.47766416269042 splits into 314 : 892
   split on e[2]<3 has err= 249.19361057953324 splits into 314 : 892
   split on e[2]<4 has err= 249.83147517703 splits into 619 : 587
   split on e[2]<5 has err= 250.40411835566698 splits into 916 : 290
   split on e[3]==2 has err= 198.83519206939107 splits into 399

   split on e[1]==high has err= 32.13370733987835 splits into 93 : 207
   split on e[1]==low has err= 32.87592341323694 splits into 99 : 201
   split on e[1]==med has err= 32.99074074074066 splits into 108 : 192
   split on e[2]<3 has err= 30.835555555555576 splits into 75 : 225
   split on e[2]<4 has err= 32.136806080270055 splits into 151 : 149
   split on e[2]<5 has err= 32.64888888888895 splits into 225 : 75
   split on e[3]==more has err= 32.59933259176867 splits into 145 : 155
   split on e[3]==4 has err= 32.59933259176867 splits into 155 : 145
   split on e[4]==big has err= 30.95609756097559 splits into 95 : 205
   split on e[4]==med has err= 32.985000000000184 splits into 100 : 200
   split on e[4]==small has err= 29.57509157509168 splits into 105 : 195
   split on e[5]==med has err= 30.56088429310024 splits into 142 : 158
   split on e[5]==high has err= 30.56088429310024 splits into 158 : 142
best split is on e[4]==small with err= 29.57509157509168
Splitting on e[4]==small wit

   split on e[1]==vhigh has err= 185.34892812105937 splits into 195 : 610
   split on e[1]==high has err= 197.07432646065214 splits into 198 : 607
   split on e[1]==low has err= 192.92855371900953 splits into 200 : 605
   split on e[1]==med has err= 192.87119380190293 splits into 212 : 593
   split on e[2]<3 has err= 196.72240802675543 splits into 207 : 598
   split on e[2]<4 has err= 197.57466004668595 splits into 411 : 394
   split on e[2]<5 has err= 198.14123376623394 splits into 616 : 189
   split on e[3]==2 has err= 120.75278810409094 splits into 267 : 538
   split on e[3]==more has err= 185.6744946782556 splits into 266 : 539
   split on e[3]==4 has err= 171.20061251517473 splits into 272 : 533
   split on e[4]==big has err= 197.04841713221708 splits into 268 : 537
   split on e[4]==med has err= 198.00263066805056 splits into 270 : 535
   split on e[4]==small has err= 195.01566350611898 splits into 267 : 538
   split on e[5]==med has err= 192.86493894534058 splits into 398 : 407


   split on e[0]==med has err= 5.3125 splits into 16 : 40
   split on e[1]==high has err= 5.299999999999999 splits into 20 : 36
   split on e[1]==low has err= 5.2865497076023376 splits into 18 : 38
   split on e[1]==med has err= 5.356725146198829 splits into 18 : 38
   split on e[2]<3 has err= 3.4285714285714284 splits into 14 : 42
   split on e[2]<4 has err= 4.714285714285717 splits into 28 : 28
   split on e[2]<5 has err= 5.1219512195121935 splits into 41 : 15
   split on e[3]==more has err= 4.5 splits into 24 : 32
   split on e[3]==4 has err= 4.5 splits into 32 : 24
best split is on e[2]<3 with err= 3.4285714285714284
Splitting on e[2]<3 with examples split 14 : 42
   split on e[0]==high has err= 3.4222222222222225 splits into 5 : 9
   split on e[0]==low has err= 3.375 splits into 6 : 8
   split on e[0]==med has err= 3.3939393939393936 splits into 3 : 11
   split on e[1]==high has err= 3.2 splits into 5 : 9
   split on e[1]==low has err= 3.25 splits into 4 : 10
   split on e[1]==med

   split on e[0]==high has err= 13.548387096774196 splits into 36 : 62
   split on e[0]==med has err= 6.206896551724142 splits into 29 : 69
   split on e[2]<3 has err= 15.90610328638498 splits into 27 : 71
   split on e[2]<4 has err= 15.918367346938776 splits into 49 : 49
   split on e[2]<5 has err= 15.891014492753625 splits into 75 : 23
   split on e[3]==more has err= 15.83729662077597 splits into 51 : 47
   split on e[3]==4 has err= 15.83729662077597 splits into 47 : 51
   split on e[4]==big has err= 15.818181818181815 splits into 32 : 66
   split on e[4]==med has err= 15.885714285714288 splits into 35 : 63
   split on e[4]==small has err= 15.66297544535387 splits into 31 : 67
   split on e[5]==med has err= 15.501886792452835 splits into 45 : 53
   split on e[5]==high has err= 15.501886792452835 splits into 53 : 45
best split is on e[0]==med with err= 6.206896551724142
Splitting on e[0]==med with examples split 29 : 69
   split on e[2]<3 has err= 6.08421052631579 splits into 10 : 19


   split on e[0]==low has err= 4.235294117647058 splits into 6 : 17
   split on e[0]==med has err= 4.984126984126984 splits into 9 : 14
   split on e[1]==high has err= 4.7142857142857135 splits into 7 : 16
   split on e[1]==low has err= 5.178571428571429 splits into 7 : 16
   split on e[1]==med has err= 4.984126984126984 splits into 9 : 14
   split on e[3]==more has err= 5.212121212121213 splits into 12 : 11
   split on e[3]==4 has err= 5.212121212121213 splits into 11 : 12
   split on e[4]==big has err= 3.733333333333334 splits into 8 : 15
   split on e[4]==med has err= 3.733333333333334 splits into 15 : 8
best split is on e[0]==high with err= 3.233333333333332
Splitting on e[0]==high with examples split 8 : 15
   split on e[0]==low has err= 1.5555555555555556 splits into 6 : 9
   split on e[0]==med has err= 1.5555555555555556 splits into 9 : 6
   split on e[1]==high has err= 1.2000000000000002 splits into 5 : 10
   split on e[1]==low has err= 1.6000000000000005 splits into 5 : 10
   

   split on e[3]==more has err= 142 splits into 267 : 532
   split on e[3]==4 has err= 142 splits into 261 : 538
   split on e[4]==big has err= 142 splits into 264 : 535
   split on e[4]==med has err= 142 splits into 273 : 526
   split on e[4]==small has err= 142 splits into 262 : 537
   split on e[5]==med has err= 142 splits into 398 : 401
   split on e[5]==low has err= 142 splits into 401 : 398
best split is on None with err= 142
For sum_absolute using median at leaves, tree built is: if e[5]==high then (if e[3]==2 then (0) else (if e[1]==vhigh then (if e[0]==low then (1) else (if e[0]==med then (1) else (0))) else (1))) else (0)
    Average error for sum-of-squares using median at leaves is 0.16475095785440613
    Average error for sum_absolute using median at leaves is 0.16475095785440613
    Average error for logloss using median at leaves is inf
   split on e[0]==vhigh has err= 355 splits into 302 : 904
   split on e[0]==high has err= 355 splits into 303 : 903
   split on e[0]==l

   split on e[0]==vhigh has err= 142 splits into 200 : 599
   split on e[0]==high has err= 142 splits into 194 : 605
   split on e[0]==low has err= 142 splits into 209 : 590
   split on e[0]==med has err= 142 splits into 196 : 603
   split on e[1]==vhigh has err= 142 splits into 196 : 603
   split on e[1]==high has err= 142 splits into 201 : 598
   split on e[1]==low has err= 142 splits into 198 : 601
   split on e[1]==med has err= 142 splits into 204 : 595
   split on e[2]<3 has err= 142 splits into 211 : 588
   split on e[2]<4 has err= 142 splits into 412 : 387
   split on e[2]<5 has err= 142 splits into 606 : 193
   split on e[3]==2 has err= 142 splits into 271 : 528
   split on e[3]==more has err= 142 splits into 267 : 532
   split on e[3]==4 has err= 142 splits into 261 : 538
   split on e[4]==big has err= 142 splits into 264 : 535
   split on e[4]==med has err= 142 splits into 273 : 526
   split on e[4]==small has err= 142 splits into 262 : 537
   split on e[5]==med has err= 142 

   split on e[4]==big has err= 143.22281480581137 splits into 95 : 205
   split on e[4]==med has err= 163.15223956562772 splits into 100 : 200
   split on e[4]==small has err= 141.95881160756778 splits into 105 : 195
   split on e[5]==med has err= 146.22055045740035 splits into 142 : 158
   split on e[5]==high has err= 146.22055045740035 splits into 158 : 142
best split is on e[4]==small with err= 141.95881160756778
Splitting on e[4]==small with examples split 105 : 195
   split on e[0]==high has err= 81.36391016567731 splits into 32 : 73
   split on e[0]==low has err= 83.57874384829638 splits into 40 : 65
   split on e[0]==med has err= 89.10609930282341 splits into 33 : 72
   split on e[1]==high has err= 84.89759919951692 splits into 38 : 67
   split on e[1]==low has err= 87.70117864101911 splits into 33 : 72
   split on e[1]==med has err= 88.38909450174884 splits into 34 : 71
   split on e[2]<3 has err= 82.74451182526477 splits into 25 : 80
   split on e[2]<4 has err= 86.656522666457

   split on e[0]==high has err= 791.1366449804995 splits into 207 : 598
   split on e[0]==low has err= 780.215018479713 splits into 203 : 602
   split on e[0]==med has err= 782.5702870593851 splits into 191 : 614
   split on e[1]==vhigh has err= 756.2758834323604 splits into 195 : 610
   split on e[1]==high has err= 792.8343045124396 splits into 198 : 607
   split on e[1]==low has err= 780.8165627967292 splits into 200 : 605
   split on e[1]==med has err= 780.6446724791058 splits into 212 : 593
   split on e[2]<3 has err= 791.78849142376 splits into 207 : 598
   split on e[2]<4 has err= 794.3396193226471 splits into 411 : 394
   split on e[2]<5 has err= 796.0016497327633 splits into 616 : 189
   split on e[3]==2 has err= 499.06422221223886 splits into 267 : 538
   split on e[3]==more has err= 759.4414020767845 splits into 266 : 539
   split on e[3]==4 has err= 715.9965497923251 splits into 272 : 533
   split on e[4]==big has err= 792.8129255106672 splits into 268 : 537
   split on e[4]

   split on e[3]==more has err= 18.34114239687853 splits into 11 : 12
   split on e[3]==4 has err= 18.34114239687853 splits into 12 : 11
   split on e[4]==med has err= 13.270249391619295 splits into 12 : 11
   split on e[4]==small has err= 13.270249391619295 splits into 11 : 12
best split is on e[4]==med with err= 13.270249391619295
Splitting on e[4]==med with examples split 12 : 11
   split on e[1]==low has err= 11.768612239392446 splits into 5 : 7
   split on e[1]==med has err= 11.768612239392446 splits into 7 : 5
   split on e[2]<3 has err= 7.613507848646137 splits into 4 : 8
   split on e[2]<4 has err= 5.457307584432803 splits into 7 : 5
   split on e[2]<5 has err= 11.520857364321765 splits into 11 : 1
   split on e[3]==more has err= 11.084725238016475 splits into 6 : 6
   split on e[3]==4 has err= 11.084725238016475 splits into 6 : 6
best split is on e[2]<4 with err= 5.457307584432803
Splitting on e[2]<4 with examples split 7 : 5
   split on e[1]==low has err= 7.141404972557168 sp

   split on e[3]==4 has err= 7.648334012357209 splits into 23 : 24
   split on e[4]==big has err= 7.710970567031691 splits into 23 : 24
   split on e[4]==med has err= 7.710970567031691 splits into 24 : 23
   split on e[5]==med has err= 7.512605538869584 splits into 21 : 26
   split on e[5]==high has err= 7.512605538869584 splits into 26 : 21
best split is on e[2]<4 with err= 6.884451811173842
Splitting on e[2]<4 with examples split 14 : 33
   split on e[1]==high has err= 6.07630499784113 splits into 3 : 11
   split on e[1]==low has err= 6.07630499784113 splits into 3 : 11
   split on e[1]==med has err= 5.731295226753273 splits into 8 : 6
   split on e[3]==more has err= 5.5348204878467495 splits into 7 : 7
   split on e[3]==4 has err= 5.5348204878467495 splits into 7 : 7
   split on e[4]==big has err= 5.5348204878467495 splits into 7 : 7
   split on e[4]==med has err= 5.5348204878467495 splits into 7 : 7
   split on e[5]==med has err= 5.29121224395462 splits into 6 : 8
   split on e[5]=

   split on e[2]<4 has err= 51.22773323184623 splits into 31 : 29
   split on e[2]<5 has err= 51.677022446945074 splits into 47 : 13
   split on e[3]==more has err= 52.52342318397346 splits into 29 : 31
   split on e[3]==4 has err= 52.52342318397346 splits into 31 : 29
   split on e[4]==big has err= 44.26635033134544 splits into 16 : 44
   split on e[4]==med has err= 52.382877881614746 splits into 25 : 35
   split on e[4]==small has err= 42.734104775062775 splits into 19 : 41
   split on e[5]==med has err= 37.43924947472085 splits into 31 : 29
   split on e[5]==high has err= 37.43924947472085 splits into 29 : 31
best split is on e[5]==med with err= 37.43924947472085
Splitting on e[5]==med with examples split 31 : 29
   split on e[0]==low has err= 30.754485319132073 splits into 17 : 14
   split on e[0]==med has err= 30.754485319132073 splits into 14 : 17
   split on e[2]<3 has err= 30.29075503539005 splits into 9 : 22
   split on e[2]<4 has err= 29.732521144378488 splits into 17 : 14
  

   split on e[0]==high has err= 27.77530789656624 splits into 18 : 38
   split on e[0]==low has err= 27.52999126204417 splits into 22 : 34
   split on e[0]==med has err= 27.539309380140686 splits into 16 : 40
   split on e[1]==high has err= 27.339024070934514 splits into 20 : 36
   split on e[1]==low has err= 27.328162024305993 splits into 18 : 38
   split on e[1]==med has err= 27.77530789656624 splits into 18 : 38
   split on e[2]<3 has err= 15.189478741394248 splits into 14 : 42
   split on e[2]<4 has err= 22.399912924905106 splits into 28 : 28
   split on e[2]<5 has err= 25.997292978714533 splits into 41 : 15
   split on e[3]==more has err= 20.8819275615964 splits into 24 : 32
   split on e[3]==4 has err= 20.8819275615964 splits into 32 : 24
best split is on e[2]<3 with err= 15.189478741394248
Splitting on e[2]<3 with examples split 14 : 42
   split on e[0]==high has err= 13.789209293426154 splits into 5 : 9
   split on e[0]==low has err= 13.650612255493119 splits into 6 : 8
   spli