In [30]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sqlalchemy as sa

# 0. Save data to a SQL database

## 0.1 Create a connection to the PostgreSQL server

In [31]:
# Connect to the PostgreSQL server that holds the HR tables.
# The connection URL embeds credentials, so it can be supplied through the
# HR_DATABASE_URL environment variable instead of being committed with the
# notebook; the literal below is only the local-development fallback.
database_url = os.environ.get(
    "HR_DATABASE_URL",
    "postgresql://postgres:pass@localhost/postgres",
)
engine = sa.create_engine(
    database_url,
    # Put the "hr" schema on the session's search path.
    connect_args={"options": "-c search_path=hr"},
)
# Connection object passed to the pandas SQL readers in the cells below.
conn = engine.connect()


## 0.2 Save data to SQL tables
> This step can be skipped if the data has already been loaded into the database by other means.

In [32]:
# employee_df = pd.read_csv("Datasets/employee_survey_data.csv", index_col=0)
# employee_df.to_sql("employee_survey_data", conn, if_exists="replace")
# general_df = pd.read_csv("Datasets/general_data.csv", index_col=8)
# general_df.to_sql("general_data", conn, if_exists="replace")
# manager_df = pd.read_csv("Datasets/manager_survey_data.csv", index_col=0)
# manager_df.to_sql("manager_survey_data", conn, if_exists="replace")

# 1. Importing the data

In [33]:
# Read the three HR tables through the open connection, in the same order
# as the names they are unpacked into.
employee_df, general_df, manager_df = (
    pd.read_sql_table(table_name, conn)
    for table_name in (
        "employee_survey_data",
        "general_data",
        "manager_survey_data",
    )
)

#######################################
# Alternatively import from csv files #
#######################################
# employee_df = pd.read_csv("Datasets/employee_survey_data.csv", index_col=0)
# general_df = pd.read_csv("Datasets/general_data.csv", index_col=8)
# manager_df = pd.read_csv("Datasets/manager_survey_data.csv", index_col=0)

In [34]:
# Preview the first rows of each table, one rich display per frame.
for frame in (employee_df, general_df, manager_df):
    display(frame.head())

Unnamed: 0,EmployeeID,EnvironmentSatisfaction,JobSatisfaction,WorkLifeBalance
0,1,3.0,4.0,2.0
1,2,3.0,2.0,4.0
2,3,2.0,2.0,1.0
3,4,4.0,4.0,3.0
4,5,4.0,1.0,3.0


Unnamed: 0,EmployeeID,Age,Attrition,BusinessTravel,Department,DistanceFromHome,Education,EducationField,EmployeeCount,Gender,...,NumCompaniesWorked,Over18,PercentSalaryHike,StandardHours,StockOptionLevel,TotalWorkingYears,TrainingTimesLastYear,YearsAtCompany,YearsSinceLastPromotion,YearsWithCurrManager
0,1,51,No,Travel_Rarely,Sales,6,2,Life Sciences,1,Female,...,1.0,Y,11,8,0,1.0,6,1,0,0
1,2,31,Yes,Travel_Frequently,Research & Development,10,1,Life Sciences,1,Female,...,0.0,Y,23,8,1,6.0,3,5,1,4
2,3,32,No,Travel_Frequently,Research & Development,17,4,Other,1,Male,...,1.0,Y,15,8,3,5.0,2,5,0,3
3,4,38,No,Non-Travel,Research & Development,2,5,Life Sciences,1,Male,...,3.0,Y,11,8,3,13.0,5,8,7,5
4,5,32,No,Travel_Rarely,Research & Development,10,1,Medical,1,Male,...,4.0,Y,12,8,2,9.0,2,6,0,4


Unnamed: 0,EmployeeID,JobInvolvement,PerformanceRating
0,1,3,3
1,2,2,4
2,3,3,3
3,4,2,3
4,5,3,3


# 2. Data Preprocessing

## 2.1 Checking for null values

We will check for any null values, and decide if we want to drop them or impute them.

In [48]:
def na_row_ratio(df):
    """Return the fraction of rows in ``df`` that contain at least one null.

    Dropping nulls removes whole rows, so the share of affected rows is the
    quantity that tells us how much data a drop would cost. Dividing the
    number of null *cells* by the number of rows is not a proportion and can
    exceed 1 when a row has several nulls.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to inspect.

    Returns
    -------
    float
        A value between 0 and 1; 0.0 for a frame with no rows.
    """
    if len(df) == 0:
        # Avoid a NaN result from averaging an empty mask.
        return 0.0
    return float(df.isna().any(axis=1).mean())


general_df_na_ratio = na_row_ratio(general_df)
employee_df_na_ratio = na_row_ratio(employee_df)
manager_df_na_ratio = na_row_ratio(manager_df)

print(f"General Dataframe NA Ratio: {general_df_na_ratio:.3f}")
print(f"Employee Dataframe NA Ratio: {employee_df_na_ratio:.3f}")
print(f"Manager Dataframe NA Ratio: {manager_df_na_ratio:.3f}")


General Dataframe NA Ratio: 0.006
Employee Dataframe NA Ratio: 0.019
Manager Dataframe NA Ratio: 0.000


Null values are rare: every ratio above is below 0.02, so the affected rows can be dropped with little effect on the final analysis.

## 2.2 Checking for duplicate values

In [52]:
# Report how many fully duplicated rows each table contains.
for label, frame in (
    ("General", general_df),
    ("Employee", employee_df),
    ("Manager", manager_df),
):
    duplicate_count = frame.duplicated().sum()
    print(f"{label} Dataframe Duplicate Rows: {duplicate_count}")

General Dataframe Duplicate Rows: 0
Employee Dataframe Duplicate Rows: 0
Manager Dataframe Duplicate Rows: 0


There are no duplicate rows in the dataset.

# 3. Exploratory Data Analysis

Create helper functions to visualise numerical and categorical data.

In [64]:
def numerical_vis(data, variable):
    fig, ax = plt.

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4410 entries, 0 to 4409
Data columns (total 24 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   EmployeeID               4410 non-null   int64  
 1   Age                      4410 non-null   int64  
 2   Attrition                4410 non-null   object 
 3   BusinessTravel           4410 non-null   object 
 4   Department               4410 non-null   object 
 5   DistanceFromHome         4410 non-null   int64  
 6   Education                4410 non-null   int64  
 7   EducationField           4410 non-null   object 
 8   EmployeeCount            4410 non-null   int64  
 9   Gender                   4410 non-null   object 
 10  JobLevel                 4410 non-null   int64  
 11  JobRole                  4410 non-null   object 
 12  MaritalStatus            4410 non-null   object 
 13  MonthlyIncome            4410 non-null   int64  
 14  NumCompaniesWorked      