In [7]:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [8]:
data = pd.read_csv('churn_prediction.csv')
data.head()

Unnamed: 0,customer_id,vintage,age,gender,dependents,occupation,city,customer_nw_category,branch_code,current_balance,...,average_monthly_balance_prevQ,average_monthly_balance_prevQ2,current_month_credit,previous_month_credit,current_month_debit,previous_month_debit,current_month_balance,previous_month_balance,churn,last_transaction
0,1,2101,66,Male,0.0,self_employed,187.0,2,755,1458.71,...,1458.71,1449.07,0.2,0.2,0.2,0.2,1458.71,1458.71,0,2019-05-21
1,2,2348,35,Male,0.0,self_employed,,2,3214,5390.37,...,7799.26,12419.41,0.56,0.56,5486.27,100.56,6496.78,8787.61,0,2019-11-01
2,4,2194,31,Male,0.0,salaried,146.0,2,41,3913.16,...,4910.17,2815.94,0.61,0.61,6046.73,259.23,5006.28,5070.14,0,NaT
3,5,2329,90,,,self_employed,1020.0,2,582,2291.91,...,2084.54,1006.54,0.47,0.47,0.47,2143.33,2291.91,1669.79,1,2019-08-06
4,6,1579,42,Male,2.0,self_employed,1494.0,3,388,927.72,...,1643.31,1871.12,0.33,714.61,588.62,1538.06,1157.15,1677.16,1,2019-11-03


In [9]:
#extract the month , year and date from 'last_transaction' attribute

In [10]:
data['last_transaction'] = pd.to_datetime(data['last_transaction'])

# Extract relevant features from the datetime column
data['year'] = data['last_transaction'].dt.year
data['month'] = data['last_transaction'].dt.month
data['day'] = data['last_transaction'].dt.day


# Drop the original string date column if no longer needed
data.drop(columns=['last_transaction'], inplace=True)

In [11]:
#delete the rows which having the null values

In [12]:
data.dropna(inplace=True)
data.head()

Unnamed: 0,customer_id,vintage,age,gender,dependents,occupation,city,customer_nw_category,branch_code,current_balance,...,current_month_credit,previous_month_credit,current_month_debit,previous_month_debit,current_month_balance,previous_month_balance,churn,year,month,day
0,1,2101,66,Male,0.0,self_employed,187.0,2,755,1458.71,...,0.2,0.2,0.2,0.2,1458.71,1458.71,0,2019.0,5.0,21.0
4,6,1579,42,Male,2.0,self_employed,1494.0,3,388,927.72,...,0.33,714.61,588.62,1538.06,1157.15,1677.16,1,2019.0,11.0,3.0
5,7,1923,42,Female,0.0,self_employed,1096.0,2,1666,15202.2,...,0.36,0.36,857.5,286.07,15719.44,15349.75,0,2019.0,11.0,1.0
6,8,2048,72,Male,0.0,retired,1020.0,1,1,7006.93,...,0.64,0.64,1299.64,439.26,7076.06,7755.98,0,2019.0,9.0,24.0
7,9,2009,46,Male,0.0,self_employed,623.0,2,317,10096.58,...,0.27,0.27,443.13,5688.44,8563.84,5317.04,0,2019.0,7.0,12.0


In [13]:
#shape of the preprocessed dataset
data.shape

(22067, 23)

In [14]:
#getting information about each attribute
data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 22067 entries, 0 to 28381
Data columns (total 23 columns):
 #   Column                          Non-Null Count  Dtype  
---  ------                          --------------  -----  
 0   customer_id                     22067 non-null  int64  
 1   vintage                         22067 non-null  int64  
 2   age                             22067 non-null  int64  
 3   gender                          22067 non-null  object 
 4   dependents                      22067 non-null  float64
 5   occupation                      22067 non-null  object 
 6   city                            22067 non-null  float64
 7   customer_nw_category            22067 non-null  int64  
 8   branch_code                     22067 non-null  int64  
 9   current_balance                 22067 non-null  float64
 10  previous_month_end_balance      22067 non-null  float64
 11  average_monthly_balance_prevQ   22067 non-null  float64
 12  average_monthly_balance_prevQ2  

In [16]:
#no. of values of each kind in each categorical attributes
print(data.occupation.value_counts())
print('\n',data.gender.value_counts())

self_employed    13434
salaried          5602
retired           1638
student           1369
company             24
Name: occupation, dtype: int64

 Male      13421
Female     8646
Name: gender, dtype: int64


In [17]:
#preprocess the categorical attributes
data.replace({'occupation':{'self_employed':0,'salaried':1,'retired':2,'student':3,'company':4}},inplace=True)
data.replace({'gender':{'Male':0,'Female':1}},inplace=True)

In [18]:
#dataset after preprocess
data

Unnamed: 0,customer_id,vintage,age,gender,dependents,occupation,city,customer_nw_category,branch_code,current_balance,...,current_month_credit,previous_month_credit,current_month_debit,previous_month_debit,current_month_balance,previous_month_balance,churn,year,month,day
0,1,2101,66,0,0.0,0,187.0,2,755,1458.71,...,0.20,0.20,0.20,0.20,1458.71,1458.71,0,2019.0,5.0,21.0
4,6,1579,42,0,2.0,0,1494.0,3,388,927.72,...,0.33,714.61,588.62,1538.06,1157.15,1677.16,1,2019.0,11.0,3.0
5,7,1923,42,1,0.0,0,1096.0,2,1666,15202.20,...,0.36,0.36,857.50,286.07,15719.44,15349.75,0,2019.0,11.0,1.0
6,8,2048,72,0,0.0,2,1020.0,1,1,7006.93,...,0.64,0.64,1299.64,439.26,7076.06,7755.98,0,2019.0,9.0,24.0
7,9,2009,46,0,0.0,0,623.0,2,317,10096.58,...,0.27,0.27,443.13,5688.44,8563.84,5317.04,0,2019.0,7.0,12.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28375,30295,2398,42,0,0.0,0,146.0,2,286,7493.69,...,0.51,90.10,1103.20,1183.04,7956.03,7431.36,0,2019.0,11.0,4.0
28377,30297,2325,10,1,0.0,3,1020.0,2,1207,1076.43,...,0.30,0.30,0.30,0.30,1076.43,1076.43,0,2019.0,10.0,22.0
28378,30298,1537,34,1,0.0,0,1046.0,2,223,3844.10,...,1.71,2.29,901.00,1014.07,3738.54,3690.32,0,2019.0,12.0,17.0
28379,30299,2376,47,0,0.0,1,1096.0,2,588,65511.97,...,4666.84,3883.06,168.23,71.80,61078.50,57564.24,1,2019.0,12.0,31.0


In [19]:
#to check whether any attribute contains null values
data.isna().sum()

customer_id                       0
vintage                           0
age                               0
gender                            0
dependents                        0
occupation                        0
city                              0
customer_nw_category              0
branch_code                       0
current_balance                   0
previous_month_end_balance        0
average_monthly_balance_prevQ     0
average_monthly_balance_prevQ2    0
current_month_credit              0
previous_month_credit             0
current_month_debit               0
previous_month_debit              0
current_month_balance             0
previous_month_balance            0
churn                             0
year                              0
month                             0
day                               0
dtype: int64

In [20]:
#assigning the input features and output features
x=data.drop(['customer_id','churn','year','month','day'],axis=1)
y=data['churn']


In [21]:
#splitting dataset into training and testing data
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.3,random_state=42)

In [33]:
#assign the 'DecisionTreeClassifier' instance to the 'model' variable
model=DecisionTreeClassifier(
    criterion='entropy',  # Change criterion to entropy
    max_depth=5
)


In [34]:
#getting the pasrameters of the model
model.get_params()

{'ccp_alpha': 0.0,
 'class_weight': None,
 'criterion': 'entropy',
 'max_depth': 5,
 'max_features': None,
 'max_leaf_nodes': None,
 'min_impurity_decrease': 0.0,
 'min_samples_leaf': 1,
 'min_samples_split': 2,
 'min_weight_fraction_leaf': 0.0,
 'random_state': None,
 'splitter': 'best'}

In [35]:
#fit or train the model on training dataset
model.fit(x_train,y_train)


In [36]:
#take the predictions  of model on testing data
pred=model.predict(x_test)

In [37]:
#check the accuracy of the model on testing data
print(accuracy_score(y_test,pred))

0.84700196344963


In [38]:
'''check the model performance on sample values of a customer
   if output as '0' --> non churn
      output as '1' --> churn'''
samp=[[2101,66,0,0.0,1,187.0,2,755,1458.71,1458.71,1458.71,1449.07,0.2,0.2,0.2,0.2,1458.71,1458.71]]
pred1 = model.predict(np.array(samp))
print(pred1)

[0]




In [39]:
#checking the model accuracy on training data
pred2=model.predict(x_train)
print(accuracy_score(y_train,pred2))

0.8552376019681471
