In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [12]:
PATH = 'titanic.csv'

In [13]:
df = pd.read_csv(PATH)

In [14]:
df

Unnamed: 0,Passengerid,Age,Fare,Sex,Pclass,Embarked,survived
0,1,22.0,7.2500,0,3,2.0,0
1,2,38.0,71.2833,1,1,0.0,1
2,3,26.0,7.9250,1,3,2.0,1
3,4,35.0,53.1000,1,1,2.0,1
4,5,35.0,8.0500,0,3,2.0,0
...,...,...,...,...,...,...,...
1304,1305,28.0,8.0500,0,3,2.0,0
1305,1306,39.0,108.9000,1,1,0.0,0
1306,1307,38.5,7.2500,0,3,2.0,0
1307,1308,28.0,8.0500,0,3,2.0,0


In [31]:
class GiniImpurity:
    
    def __init__(self, leaves: list): # list of list 
        self.leaves = leaves
        self.total_samples = sum([sum(leaf) for leaf in self.leaves])
    
    def probability(self,value,total):
        return value/total
    
    def get_value(self,target_list : list):
        self.target_list = target_list
        self.number_of_samples = sum(self.target_list)
        # formula is 1 - square of probability of class A - square of probability of class B - ....
        list_of_probabilites = []
        for target in self.target_list:
            list_of_probabilites.append(self.probability(target,self.number_of_samples)**2)
        sum_of_probabilities = sum(list_of_probabilites)
        return 1-sum_of_probabilities

    # total gini impurity is weighted average of all leaves gini impurity
    def get_total_gini_impurity(self):
        weights_x_impurity = []
        for i in range(len(self.leaves)):
            leaf = self.leaves[i]
            leaf_total = sum(leaf)
            impurity = self.get_value(leaf)
            weight = leaf_total / self.total_samples
            weights_x_impurity.append(weight*impurity)
        return sum(weights_x_impurity)

In [120]:
class GiniImpurityForCategoricalColumn:
    
    def __init__(self,categorical_column : str, target_column : str, dataframe):
        self.categorical_column = categorical_column
        self.target_column = target_column
        self.dataframe = dataframe
        self.branches = list(self.dataframe[self.categorical_column].unique())
        self.targets = list(self.dataframe[self.target_column].unique())
        
    def get(self):
        gini_impurity = []
        for branch in self.branches:
            filtered_df = self.dataframe[self.dataframe[self.categorical_column]==branch]
            counts = dict(filtered_df[self.target_column].value_counts())
            follow_target = [counts.get(tar,0) for tar in self.targets]
            gini_impurity.append(follow_target)
        gini = GiniImpurity(gini_impurity)
        impurity = gini.get_total_gini_impurity()
        return {impurity, 'classes':}

In [121]:
class GiniImpurityForNumericalColumn:
    
    def __init__(self,numerical_column : str, target_column : str, dataframe):
        self.numerical_column = numerical_column
        self.target_column = target_column
        self.dataframe = dataframe
        self.dataframe = self.dataframe.sort_values(by=self.numerical_column)
        self.dataframe.reset_index(inplace = True,drop = True)
        self.max_index = max(self.dataframe.index)
        self.targets = list(self.dataframe[self.target_column].unique())
        
    def get(self) -> dict:
        gini_impurity = []
        for index,row in self.dataframe.iterrows():
            if index!=self.max_index:
                next_ = self.dataframe.loc[index+1,self.numerical_column]
                current = row[self.numerical_column]
                average = (next_+current)/2
                if average not in [i['average'] for i in gini_impurity]:
                    filtered_df = self.dataframe[self.dataframe[self.numerical_column]<=average]
                    filtered_df_2 = self.dataframe[self.dataframe[self.numerical_column]>average]
                    counts = dict(filtered_df[self.target_column].value_counts())
                    counts_2 = dict(filtered_df_2[self.target_column].value_counts())
                    follow_target = [counts.get(tar,0) for tar in self.targets]
                    follow_target_2 =[counts_2.get(tar,0) for tar in self.targets]
                    gini = GiniImpurity([follow_target,follow_target_2])
                    impurity = gini.get_total_gini_impurity()
                    gini_impurity.append({'impurity':impurity, 'average':average})
        impurity = min(gini_impurity,key = lambda x : x['impurity'])
        return impurity

In [None]:
class Node:
    
    def __init__(self,gini):
        self.gini = gini
        
        
            
        

In [113]:
x = GiniImpurityForNumericalColumn('Age','love',temp_df)

In [114]:
x.get()

{'impurity': 0.34285714285714286, 'average': 15.0}

Unnamed: 0,Passengerid,Age,Fare,Sex,Pclass,Embarked,survived
0,1246,0.17,20.5750,1,3,2.0,0
1,1093,0.33,14.4000,0,3,2.0,0
2,804,0.42,8.5167,0,3,0.0,1
3,756,0.67,14.5000,0,2,2.0,1
4,645,0.75,19.2583,1,3,0.0,1
...,...,...,...,...,...,...,...
1304,97,71.00,34.6542,0,1,0.0,0
1305,494,71.00,49.5042,0,1,0.0,0
1306,852,74.00,7.7750,0,3,2.0,0
1307,988,76.00,78.8500,1,1,2.0,0


In [52]:
df['Pclass'].unique()

array([3, 1, 2], dtype=int64)

In [122]:
categorical_columns = ['pop','soda']
numerical_columns = ['age']
target_column = 'love'
gini = []
for category in categorical_columns:
    x = GiniImpurityForCategoricalColumn(category,target_column,temp_df)
    y = x.get()
    gini.append({'category':category, 'gini_value':y})
    
for category in numerical_columns:
    x = GiniImpurityForNumericalColumn(category,target_column,temp_df)
    y = x.get()
    gini.append({'category':category, 'gini_value':y})

In [123]:
gini

[{'category': 'pop', 'gini_value': 0.40476190476190477},
 {'category': 'soda', 'gini_value': 0.21428571428571427},
 {'category': 'age',
  'gini_value': {'impurity': 0.34285714285714286, 'average': 15.0}}]

In [117]:
temp_df = pd.DataFrame({'pop':['yes','yes','no','no','yes','yes','no'],
                        'soda': ['yes','no','yes','yes','yes','no','no'],
                        'age':[7,12,18,35,38,50,83],
                       'love':['no','no','yes','yes','yes','no','no']})

In [118]:
temp_df

Unnamed: 0,pop,soda,age,love
0,yes,yes,7,no
1,yes,no,12,no
2,no,yes,18,yes
3,no,yes,35,yes
4,yes,yes,38,yes
5,yes,no,50,no
6,no,no,83,no
