In [119]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.tree import export_text

# Load the dataset from depression.csv (path relative to the working directory)
# into a DataFrame; later cells read its columns A, B, C and D.
data=pd.read_csv('./depression.csv')

In [97]:
# Name of the outcome column; its values 0 and 1 are counted as negative/positive.
OUTCOME='D'

def attr_probs(data,attr,outcome=None):
    """Return the share of positive outcomes for every distinct value of ``attr``.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding both the ``attr`` column and the outcome column.
    attr : str
        Column whose distinct values define the groups.
    outcome : str, optional
        Name of the outcome column. When omitted, the module-level
        ``OUTCOME`` is used (looked up at call time).

    Returns
    -------
    dict
        Maps each distinct value of ``data[attr]`` to
        ``positive / (positive + negative)`` within that group, where
        ``positive`` counts labels equal to 1 and ``negative`` counts labels
        equal to 0.

    Raises
    ------
    ZeroDivisionError
        If a group contains no row labelled 0 or 1.
    """
    if outcome is None:
        outcome=OUTCOME

    probs={}
    for value in np.unique(data[attr]):
        # Only the outcome column of the matching rows is needed.
        group_labels=data.loc[data[attr]==value,outcome]
        negative=np.count_nonzero(group_labels==0)
        positive=np.count_nonzero(group_labels==1)

        # Labels other than 0/1 do not contribute to the denominator.
        probs[value]=positive/(positive+negative)

    return probs



In [98]:
# Show, for each distinct value of attribute A, the share of rows with D == 1
# (among the rows labelled 0 or 1).
print(attr_probs(data,'A'))

{0: 0.6711111111111111, 1: 0.4052631578947368}


In [99]:
def entropy(probs):
    """Entropy, in bits, of the distribution given by ``probs``.

    ``probs`` is any iterable of probabilities. Entries equal to zero are
    skipped, so they contribute nothing to the result. When every entry is
    skipped (or ``probs`` is empty) the integer ``0`` is returned.
    """
    bits=0
    for p in probs:
        if p==0:
            # Skip zero entries: log2(0) is undefined.
            continue
        bits-=p*np.log2(p)
    return bits

In [100]:
# Sanity check: the distribution (1/4, 0, 1/2, 1/4) has an entropy of 1.5 bits;
# the zero entry is skipped by entropy().
print(entropy([2/8, 0/8, 4/8, 2/8]))

1.5


In [101]:
def gain(data,attr):
    """Information gain obtained by splitting ``data`` on column ``attr``.

    The gain is the entropy of the outcome column 'D' over the whole table
    minus the size-weighted entropy of 'D' inside each group of rows that
    share the same value of ``attr``.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing ``attr`` and the outcome column 'D'.
    attr : str
        Column to evaluate as a split candidate.

    Returns
    -------
    float
        Entropy before the split minus the weighted entropy after it.
    """
    n_rows=data.shape[0]

    # Entropy remaining after the split: each group weighted by its share of rows.
    remaining=0
    for level,p_positive in attr_probs(data,attr).items():
        weight=np.count_nonzero(data[attr]==level)/n_rows
        group_entropy=entropy([p_positive,1-p_positive])
        remaining+=group_entropy*weight

    # Entropy of the outcome before splitting.
    p_overall=np.count_nonzero(data['D']==1)/n_rows
    return entropy([p_overall,1-p_overall])-remaining
        


In [132]:
# Print the information gain of each candidate attribute, formatted to two decimals.
# NOTE(review): `TTRS` looks like a truncated `ATTRS` — confirm before renaming,
# since other cells may reference this name.
TTRS = ['A', 'B', 'C']
for attr in TTRS:
    print('Gain {attr}: {gain:.2f}'.format(attr=attr, gain=gain(data, attr)))

Gain A: 0.05
Gain B: 0.02
Gain C: 0.07


In [133]:
# Split the table into the label (last column) and the features (all other columns).
# NOTE(review): assumes the outcome column is the last column of `data` — confirm.
labels=data.iloc[:,-1]
samples=data.iloc[:,:-1]

# Hold out 30% of the rows for evaluation; the fixed seed keeps the split repeatable.
x_train,x_test,y_train,y_test=train_test_split(samples,labels,test_size=0.3,random_state=7)

# scikit-learn permutes the candidate features at every split, so equally good
# splits can be resolved differently between runs unless random_state is fixed.
tree=DecisionTreeClassifier(criterion="entropy",random_state=7)
tree.fit(x_train,y_train)

y_pred = tree.predict(x_test)
print(classification_report(y_test,y_pred))

# Take the feature names from the training frame so they always match the
# columns the tree was fitted on, instead of a hard-coded list.
print(export_text(decision_tree=tree,feature_names=[str(col) for col in x_train.columns]))

              precision    recall  f1-score   support

           0       0.72      0.62      0.66       154
           1       0.66      0.76      0.71       152

    accuracy                           0.69       306
   macro avg       0.69      0.69      0.68       306
weighted avg       0.69      0.69      0.68       306

|--- C <= 1.50
|   |--- A <= 0.50
|   |   |--- B <= 0.50
|   |   |   |--- C <= 0.50
|   |   |   |   |--- class: 1
|   |   |   |--- C >  0.50
|   |   |   |   |--- class: 1
|   |   |--- B >  0.50
|   |   |   |--- C <= 0.50
|   |   |   |   |--- class: 1
|   |   |   |--- C >  0.50
|   |   |   |   |--- class: 1
|   |--- A >  0.50
|   |   |--- C <= 0.50
|   |   |   |--- B <= 0.50
|   |   |   |   |--- class: 0
|   |   |   |--- B >  0.50
|   |   |   |   |--- class: 0
|   |   |--- C >  0.50
|   |   |   |--- B <= 0.50
|   |   |   |   |--- class: 0
|   |   |   |--- B >  0.50
|   |   |   |   |--- class: 1
|--- C >  1.50
|   |--- B <= 0.50
|   |   |--- A <= 0.50
|   |   |   |--