## Import libraries and load data

In [3]:
from sklearn import tree
import pandas as pd
from collections import Counter
from matplotlib import pyplot as plt
import graphviz

In [4]:
# Column names for census-income.data, in file order. The CSV has no header
# row, so these are supplied via `names=` when it is read (header=None).
# Commented-out names are deliberately excluded from the column list
# (presumably not present in this data file -- verify against the dataset docs).
# MARSUPWT (sample weight) and RESULT (target label) are dropped from the
# feature matrix before training.
ATTRIBUTES = [
    "AAGE", "ACLSWKR", "ADTIND", "ADTOCC",
    # "AGI",
    "AHGA", "AHRSPAY", "AHSCOL", "AMARITL", "AMJIND", "AMJOCC",
    "ARACE", "AREORGN", "ASEX", "AUNMEM", "AUNTYPE", "AWKSTAT",
    "CAPGAIN", "CAPLOSS", "DIVVAL",
    # "FEDTAX",
    "FILESTAT", "GRINREG", "GRINST", "HHDFMX", "HHDREL",
    "MARSUPWT",  # sample weight -- should be ignored, not a feature
    "MIGMTR1", "MIGMTR3", "MIGMTR4", "MIGSAME", "MIGSUN",
    "NOEMP", "PARENT",
    # "PEARNVAL",
    "PEFNTVTY", "PEMNTVTY", "PENATVTY", "PRCITSHP",
    # "PTOTVAL",
    "SEOTR",
    # "TAXINC",
    "VETQVA", "VETYN", "WKSWORK", "YEAR",
    "RESULT",  # target label, binarized after loading
]

In [5]:

# Load the raw CSV (no header row) with the column names defined above.
# Kept under its own name so the untouched input stays inspectable.
RAW_DATA = pd.read_csv('census-income.data', header=None, names=ATTRIBUTES)

# Binarize the target: labels containing '-' -> 0, all others -> 1.
# Vectorized str.contains (regex=False => literal '-') replaces a per-row
# Python lambda. A missing label would otherwise surface as an opaque error,
# so fail with an explicit message instead.
assert RAW_DATA['RESULT'].notna().all(), "RESULT column has missing labels"
RESULTS = (~RAW_DATA['RESULT'].str.contains('-', regex=False)).astype('int64')

# Feature matrix: everything except the sample weight and the label.
DATA = RAW_DATA.drop(columns=["MARSUPWT", "RESULT"])
ATTRIBUTE_COLUMNS = DATA.columns

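## Encode string columns as integers

scikit-learn decision trees require numeric features, so each text column is replaced by the integer codes produced by `pd.factorize`. The codes follow order of first appearance and carry no ordinal meaning.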

In [6]:
# Replace every non-numeric column with its pd.factorize integer codes
# (assigned in order of first appearance; missing values would become -1).
# Testing is_numeric_dtype instead of `dtype == 'O'` also catches pandas'
# dedicated string/category dtypes, which would otherwise reach sklearn as text.
print(Counter(DATA.dtypes))
CLEANED_DATA = DATA.apply(
    lambda col: col if pd.api.types.is_numeric_dtype(col) else pd.factorize(col)[0]
)
print(Counter(CLEANED_DATA.dtypes))

Counter({dtype('O'): 28, dtype('int64'): 12})
Counter({dtype('int64'): 40})


In [7]:
# Report sizes, then assert that features and labels line up row-for-row
# before fitting (rather than relying on eyeballing the printed shapes).
print(f"Data shape: {CLEANED_DATA.shape}")
print(f"Results shape: {RESULTS.shape}, Results type {RESULTS.dtype}")
assert len(CLEANED_DATA) == len(RESULTS), "feature/label row counts differ"
assert CLEANED_DATA.index.equals(RESULTS.index), "feature/label indexes differ"

Data shape: (199523, 40)
Results shape: (199523,), Results type int64


In [8]:
# Peek at the first five encoded rows.
CLEANED_DATA.head(5)

Unnamed: 0,AAGE,ACLSWKR,ADTIND,ADTOCC,AHGA,AHRSPAY,AHSCOL,AMARITL,AMJIND,AMJOCC,...,PARENT,PEFNTVTY,PEMNTVTY,PENATVTY,PRCITSHP,SEOTR,VETQVA,VETYN,WKSWORK,YEAR
0,73,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,2,0,95
1,58,1,4,34,1,0,0,1,1,1,...,0,0,0,0,0,0,0,2,52,94
2,18,0,0,0,2,0,1,2,0,0,...,0,1,1,1,1,0,0,2,0,95
3,9,0,0,0,3,0,0,2,0,0,...,1,0,0,0,0,0,0,0,0,94
4,10,0,0,0,3,0,0,2,0,0,...,1,0,0,0,0,0,0,0,0,94


In [9]:
# Shallow entropy-based decision tree on the encoded features.
# random_state is fixed because scikit-learn permutes the features at every
# split, so ties between equally good splits could otherwise resolve
# differently from run to run.
# NOTE(review): the tree is fit on all rows and later queried on those same
# rows; use a held-out split before drawing conclusions about accuracy.
MAX_DEPTH = 2
SEED = 42
clf = tree.DecisionTreeClassifier(
    criterion="entropy", max_depth=MAX_DEPTH, random_state=SEED
)
clf.fit(CLEANED_DATA, RESULTS);

In [10]:
# Sanity check: predict the first row. A one-row DataFrame slice keeps the 2-D
# shape and column names that predict expects (no chained slicing needed).
clf.predict(CLEANED_DATA.iloc[:1])

array([0], dtype=int64)

In [12]:
# Render the fitted tree to tree2.pdf via graphviz.
# feature_names labels each split with its column name instead of X[i]; the
# tree was fit on CLEANED_DATA, so its columns give the right names in order.
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=list(CLEANED_DATA.columns),
    filled=True,
    rounded=True,
)
graph = graphviz.Source(dot_data)
graph.render("tree2")

'tree2.pdf'