# Churn prediction for an energy company 🔌💡

## Part 1 - Importing the data and performing data quality checks

Let's start by quickly inspecting the two datasets stored in `client_data_raw.csv` and `price_data_raw.csv`.

In [1]:
import pandas as pd

repo_path = "/workspaces/myfolder/energy-churn-prediction"

raw_client_df = pd.read_csv(f"{repo_path}/data/client_data_raw.csv")
raw_client_df.head(5)

Unnamed: 0,id,channel_sales,cons_12m,cons_gas_12m,cons_last_month,date_activ,date_end,date_modif_prod,date_renewal,forecast_cons_12m,...,has_gas,imp_cons,margin_gross_pow_ele,margin_net_pow_ele,nb_prod_act,net_margin,num_years_antig,origin_up,pow_max,churn
0,24011ae4ebbe3035111d65fa7c15bc57,foosdfpfkusacimwkcsosbicdxkicaua,0,54946,0,2013-06-15,2016-06-15,2015-11-01,2015-06-23,0.0,...,t,0.0,25.44,25.44,2,678.99,3,lxidpiddsbxsbosboudacockeimpuepw,43.648,1
1,d29c2c54acc38ff3c0614d0a653813dd,MISSING,4660,0,0,2009-08-21,2016-08-30,2009-08-21,2015-08-31,189.95,...,f,0.0,16.38,16.38,1,18.89,6,kamkkxfxxuwbdslkwifmmcsiusiuosws,13.8,0
2,764c75f661154dac3a6c254cd082ea7d,foosdfpfkusacimwkcsosbicdxkicaua,544,0,0,2010-04-16,2016-04-16,2010-04-16,2015-04-17,47.96,...,f,0.0,28.6,28.6,1,6.6,6,kamkkxfxxuwbdslkwifmmcsiusiuosws,13.856,0
3,bba03439a292a1e166f80264c16191cb,lmkebamcaaclubfxadlmueccxoimlema,1584,0,0,2010-03-30,2016-03-30,2010-03-30,2015-03-31,240.04,...,f,0.0,30.22,30.22,1,25.46,6,kamkkxfxxuwbdslkwifmmcsiusiuosws,13.2,0
4,149d57cf92fc41cf94415803a877cb4b,MISSING,4425,0,526,2010-01-13,2016-03-07,2010-01-13,2015-03-09,445.75,...,f,52.32,44.91,44.91,1,47.98,6,kamkkxfxxuwbdslkwifmmcsiusiuosws,19.8,0


In [2]:
raw_price_df = pd.read_csv(f"{repo_path}/data/price_data_raw.csv")
raw_price_df

Unnamed: 0,id,price_date,price_off_peak_var,price_peak_var,price_mid_peak_var,price_off_peak_fix,price_peak_fix,price_mid_peak_fix
0,038af19179925da21a25619c5a24b745,2015-01-01,0.151367,0.000000,0.000000,44.266931,0.00000,0.000000
1,038af19179925da21a25619c5a24b745,2015-02-01,0.151367,0.000000,0.000000,44.266931,0.00000,0.000000
2,038af19179925da21a25619c5a24b745,2015-03-01,0.151367,0.000000,0.000000,44.266931,0.00000,0.000000
3,038af19179925da21a25619c5a24b745,2015-04-01,0.149626,0.000000,0.000000,44.266931,0.00000,0.000000
4,038af19179925da21a25619c5a24b745,2015-05-01,0.149626,0.000000,0.000000,44.266931,0.00000,0.000000
...,...,...,...,...,...,...,...,...
192997,16f51cdc2baa19af0b940ee1b3dd17d5,2015-08-01,0.119916,0.102232,0.076257,40.728885,24.43733,16.291555
192998,16f51cdc2baa19af0b940ee1b3dd17d5,2015-09-01,0.119916,0.102232,0.076257,40.728885,24.43733,16.291555
192999,16f51cdc2baa19af0b940ee1b3dd17d5,2015-10-01,0.119916,0.102232,0.076257,40.728885,24.43733,16.291555
193000,16f51cdc2baa19af0b940ee1b3dd17d5,2015-11-01,0.119916,0.102232,0.076257,40.728885,24.43733,16.291555


In [3]:
print(raw_client_df.dtypes)

id                                 object
channel_sales                      object
cons_12m                            int64
cons_gas_12m                        int64
cons_last_month                     int64
date_activ                         object
date_end                           object
date_modif_prod                    object
date_renewal                       object
forecast_cons_12m                 float64
forecast_cons_year                  int64
forecast_discount_energy          float64
forecast_meter_rent_12m           float64
forecast_price_energy_off_peak    float64
forecast_price_energy_peak        float64
forecast_price_pow_off_peak       float64
has_gas                            object
imp_cons                          float64
margin_gross_pow_ele              float64
margin_net_pow_ele                float64
nb_prod_act                         int64
net_margin                        float64
num_years_antig                     int64
origin_up                          object
pow_max                           float64
churn                               int64
dtype: object

In [4]:
print(raw_price_df.dtypes)

id                     object
price_date             object
price_off_peak_var    float64
price_peak_var        float64
price_mid_peak_var    float64
price_off_peak_fix    float64
price_peak_fix        float64
price_mid_peak_fix    float64
dtype: object


As the previous two cells show, pandas read the date columns in both datasets as `object` (string) rather than datetime. We'll use `pd.to_datetime()` to convert them to the `datetime64[ns]` dtype.

In [5]:
client_df = raw_client_df.copy()
price_df = raw_price_df.copy()

client_date_cols = [col for col in client_df.columns if col.startswith('date')]
price_date_col = 'price_date'

print("'client_df':")
print()
for col in client_date_cols:
    client_df[col] = pd.to_datetime(client_df[col], format='%Y-%m-%d', errors='coerce')
    print(f"- Type of {col} variable after conversion: {client_df[col].dtypes}")

print("--------------------------------")
print()

print("'price_df':")
price_df[price_date_col] = pd.to_datetime(price_df[price_date_col], format='%Y-%m-%d', errors='coerce')
print(f"Type of {price_date_col} variable after conversion: {price_df[price_date_col].dtypes}")

'client_df':

- Type of date_activ variable after conversion: datetime64[ns]
- Type of date_end variable after conversion: datetime64[ns]
- Type of date_modif_prod variable after conversion: datetime64[ns]
- Type of date_renewal variable after conversion: datetime64[ns]
--------------------------------

'price_df':
Type of price_date variable after conversion: datetime64[ns]
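Because the conversion above uses `errors='coerce'`, any value that does not match the format is silently turned into `NaT` instead of raising an error. The sketch below, on hypothetical toy data (the `toy` series is not part of the datasets), shows one way to count such values so that coerced dates do not go unnoticed:

```python
import pandas as pd

# Hypothetical toy data: one valid date, one impossible date, one in the wrong format
toy = pd.Series(["2015-06-23", "2015-02-30", "23/06/2015"])

parsed = pd.to_datetime(toy, format="%Y-%m-%d", errors="coerce")

# Values that were present before parsing but are NaT afterwards failed to parse
failed = toy[parsed.isna() & toy.notna()]
n_failed = len(failed)

print(f"{n_failed} value(s) could not be parsed:")
print(failed.tolist())
```

Applying the same comparison to each converted column of `client_df` and `price_df` would confirm whether any date was coerced.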


The `channel_sales` and `origin_up` variables in `raw_client_df` have been encoded into hashed text strings for privacy reasons. To improve readability, we will map these values to more intuitive category labels. Additionally, we'll convert `has_gas` and `churn` values to `'Yes'` and `'No'` for the same reason.

In [6]:
hashed_cols = ['channel_sales', 'origin_up']
prefixes = ['Channel', 'Campaign']
mapping_dicts = {}

for col, prefix in zip(hashed_cols, prefixes):
    # Get unique values from the column:
    unique_values = [value for value in client_df[col].unique() if value!='MISSING']
    unique_values.append('MISSING')

    # Create new values
    new_values = [f"{prefix} {i}" for i in range(1,len(unique_values))]
    new_values.append(f'{prefix} Missing')

    # Create a mapping dictionary
    mapping_dicts[col] = dict(zip(unique_values, new_values))
    print(f"Mapping dictionary for '{col}':\n", mapping_dicts[col])
    print('---------------------------------------------')

mapping_dicts['churn'] = {0: 'No', 1: 'Yes'}
print(f"Mapping dictionary for 'churn':\n", mapping_dicts['churn'])
print('---------------------------------------------')
mapping_dicts['has_gas'] = {'f': 'No', 't': 'Yes'}
print(f"Mapping dictionary for 'has_gas':\n", mapping_dicts['has_gas'])
print('---------------------------------------------')

Mapping dictionary for 'channel_sales':
 {'foosdfpfkusacimwkcsosbicdxkicaua': 'Channel 1', 'lmkebamcaaclubfxadlmueccxoimlema': 'Channel 2', 'usilxuppasemubllopkaafesmlibmsdf': 'Channel 3', 'ewpakwlliwisiwduibdlfmalxowmwpci': 'Channel 4', 'epumfxlbckeskwekxbiuasklxalciiuu': 'Channel 5', 'sddiedcslfslkckwlfkdpoeeailfpeds': 'Channel 6', 'fixdbufsefwooaasfcxdxadsiekoceaa': 'Channel 7', 'MISSING': 'Channel Missing'}
---------------------------------------------
Mapping dictionary for 'origin_up':
 {'lxidpiddsbxsbosboudacockeimpuepw': 'Campaign 1', 'kamkkxfxxuwbdslkwifmmcsiusiuosws': 'Campaign 2', 'ldkssxwpmemidmecebumciepifcamkci': 'Campaign 3', 'usapbepcfoloekilkwsdiboslwaxobdp': 'Campaign 4', 'ewxeelcelemmiwuafmddpobolfuxioce': 'Campaign 5', 'MISSING': 'Campaign Missing'}
---------------------------------------------
Mapping dictionary for 'churn':
 {0: 'No', 1: 'Yes'}
---------------------------------------------
Mapping dictionary for 'has_gas':
 {'f': 'No', 't': 'Yes'}
---------------------------------------------

In [7]:
# Build a new list so that hashed_cols itself is not modified
mapping_cols = hashed_cols + ['churn', 'has_gas']

# Apply the mapping to the column
for col in mapping_cols:
    client_df[col] = client_df[col].map(mapping_dicts[col])
    print(f"New {col} values:\n", client_df[col].unique())
    print('--------------------------------------------------')

New channel_sales values:
 ['Channel 1' 'Channel Missing' 'Channel 2' 'Channel 3' 'Channel 4'
 'Channel 5' 'Channel 6' 'Channel 7']
--------------------------------------------------
New origin_up values:
 ['Campaign 1' 'Campaign 2' 'Campaign 3' 'Campaign Missing' 'Campaign 4'
 'Campaign 5']
--------------------------------------------------
New churn values:
 ['Yes' 'No']
--------------------------------------------------
New has_gas values:
 ['Yes' 'No']
--------------------------------------------------
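One property of `Series.map()` worth keeping in mind: a value with no entry in the mapping dictionary becomes `NaN` rather than raising an error. A minimal sketch on hypothetical toy data (`toy` and `toy_mapping` are illustrative names, not part of the notebook) of how unmapped values could be detected:

```python
import pandas as pd

# Hypothetical toy column and mapping; 'zzz' is deliberately absent from the mapping
toy = pd.Series(["aaa", "MISSING", "zzz"])
toy_mapping = {"aaa": "Channel 1", "MISSING": "Channel Missing"}

mapped = toy.map(toy_mapping)

# Original values that ended up as NaN had no entry in the dictionary
unmapped = toy[mapped.isna()].unique().tolist()
print("Unmapped values:", unmapped)
```

In this notebook the dictionaries are built from the columns' own unique values, so an empty list would be expected for `channel_sales` and `origin_up`.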


The following code runs quality checks on both tables. Rows that fail a check are saved to a CSV file in the `data` folder before being removed.

Both `client_df` and `price_df` are checked for negative values in columns that should be non-negative.

Additional checks on `client_df` are:

1. Removing duplicate rows.
2. Checking that ID values are unique.
3. Removing rows where `has_gas` is `'No'` but `cons_gas_12m` is not `0`.
4. Removing rows where `date_renewal` or `date_end` is earlier than `date_activ`. `date_modif_prod` is excluded from this check: we assume that a contract may be modified when it is signed, which can happen before the activation date.
5. Removing rows where `margin_gross_pow_ele` is less than `margin_net_pow_ele`.
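The row-level checks in this list all follow the same pattern: build a boolean mask of rule violations, then keep the rows where the mask is `False`. A minimal sketch of the date-ordering rule on hypothetical toy data (the `toy` frame only reuses the column names for illustration):

```python
import pandas as pd

# Hypothetical toy rows: the first is consistent, the other two violate the rule
toy = pd.DataFrame({
    "date_activ": pd.to_datetime(["2013-06-15", "2010-04-16", "2012-01-01"]),
    "date_end": pd.to_datetime(["2016-06-15", "2009-04-16", "2016-01-01"]),
    "date_renewal": pd.to_datetime(["2015-06-23", "2015-04-17", "2011-12-31"]),
})

# True where the end date or the renewal date precedes the activation date
invalid = (toy["date_end"] < toy["date_activ"]) | (toy["date_renewal"] < toy["date_activ"])

valid_toy = toy[~invalid]
print(f"{invalid.sum()} inconsistent row(s), {len(valid_toy)} row(s) kept")
```

Note that comparisons involving `NaT` evaluate to `False`, so rows with a missing date would pass this check and need to be handled separately.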

In [8]:
print(f"Initial number of rows in client_df: {client_df.shape[0]}")
print(f"Initial number of rows in price_df: {price_df.shape[0]}")
print("--------------------------------------------------------")

# Remove duplicate rows
client_df = client_df.drop_duplicates()

# Define function for removing rows with negative values
def remove_negative_values(df, cols, df_name, repo_path):
    """Save rows with negative values in `cols` to CSV, then drop them from `df` in place."""
    for col in cols:
        negative_values_rows = df[df[col] < 0]
        if not negative_values_rows.empty:
            print(f"Warning: '{col}' in {df_name} contains {negative_values_rows.shape[0]} negative values.")
            negative_values_rows.to_csv(f"{repo_path}/data/{col}_negative_values_{df_name}.csv", index=False)
            df.drop(index=negative_values_rows.index, inplace=True)

# Check that certain columns contain only non-negative values
client_positive_cols = [
    'cons_12m', 'cons_gas_12m', 'cons_last_month',
    'forecast_cons_12m', 'forecast_cons_year', 'forecast_discount_energy', 
    'forecast_meter_rent_12m', 'forecast_price_energy_off_peak',
    'forecast_price_energy_peak', 'forecast_price_pow_off_peak',
    'imp_cons', 'nb_prod_act', 'num_years_antig', 'pow_max'
]
price_positive_cols = [
    'price_off_peak_var', 'price_peak_var', 
    'price_mid_peak_var', 'price_off_peak_fix', 
    'price_peak_fix', 'price_mid_peak_fix'
]
remove_negative_values(client_df, client_positive_cols, 'client_df', repo_path)
remove_negative_values(price_df, price_positive_cols, 'price_df', repo_path)

# Check if 'id' column is unique and report if any duplicate ids are found
if client_df['id'].duplicated().any():
    print("Warning: Duplicate IDs found.")
else:
    print("All IDs are unique.")

# Remove rows where 'has_gas' is 'No' but 'cons_gas_12m' is not 0
invalid_gas_rows = client_df[
    (client_df['has_gas'] == 'No') & (client_df['cons_gas_12m'] != 0)
]
if not invalid_gas_rows.empty:
    print(f"There are {invalid_gas_rows.shape[0]} rows with inconsistent 'has_gas' and 'cons_gas_12m' values.")
    invalid_gas_rows.to_csv(f"{repo_path}/data/gas_inconsistent_rows_client_df.csv", index=False)
    # client_df = client_df.drop(invalid_gas_rows.index)
    client_df.drop(index=invalid_gas_rows.index, inplace=True)

# Check that date_end and date_renewal are >= date_activ
invalid_dates = client_df[
    (client_df['date_end'] < client_df['date_activ']) |
    (client_df['date_renewal'] < client_df['date_activ'])
]
if not invalid_dates.empty:
    print(f"There are {invalid_dates.shape[0]} rows with inconsistent dates.")
    invalid_dates.to_csv(f"{repo_path}/data/invalid_dates_client_df.csv", index=False)
    client_df.drop(index=invalid_dates.index, inplace=True)
    # client_df = client_df.drop(invalid_dates.index)

# Verify that margin_gross_pow_ele is >= margin_net_pow_ele
invalid_margin_rows = client_df[
    client_df['margin_gross_pow_ele'] < client_df['margin_net_pow_ele']
]
if not invalid_margin_rows.empty:
    print(f"There are {invalid_margin_rows.shape[0]} rows with inconsistent margins.")
    invalid_margin_rows.to_csv(f"{repo_path}/data/invalid_margin_rows_client_df.csv", index=False)
    client_df.drop(index=invalid_margin_rows.index, inplace=True)
    # client_df = client_df.drop(invalid_margin_rows.index)

print("--------------------------------------------------------")
print(f"Final number of rows in client_df: {client_df.shape[0]}")
print(f"Final number of rows in price_df: {price_df.shape[0]}")

Initial number of rows in client_df: 14606
Initial number of rows in price_df: 193002
--------------------------------------------------------
All IDs are unique.
There are 53 rows with inconsistent 'has_gas' and 'cons_gas_12m' values.
--------------------------------------------------------
Final number of rows in client_df: 14553
Final number of rows in price_df: 193002


In [9]:
invalid_gas_rows[['id', 'has_gas', 'cons_gas_12m']].head(10)

Unnamed: 0,id,has_gas,cons_gas_12m
211,d03b894570bbe809ab2ce610d52e4256,No,458306
931,9f6aa89b4d15b6b60062a51a6b56698d,No,10542
969,8c754ed545769094ac652456aa7d7110,No,298897
993,201b65b25599462f94946cf16b386cb9,No,193
1653,1c65d82e5ac151a43656de3fc026fc8e,No,191
1672,645588ce9410be0d47f4c63783487493,No,3132
1691,aef79ff04e0c1e0af1f028428060a5c4,No,1306
1874,2f52ef4f444bb56552c75ad3cb4385f6,No,21515
2075,d3cd0c17501d1d4c39dca734da32f4d5,No,1199
2401,c8e44781cf503ca69157b5c8474d5565,No,1270


Next, we perform a missing value analysis on `client_df` and `price_df`, displaying only columns with missing values. This helps identify areas requiring data imputation or cleaning.

**Note**: A more detailed missing value assessment for `price_df` will be conducted in the next notebook.

In [10]:
# Check for missing values in client_df
missing_values_client = client_df.isnull().sum()

# Display the columns with missing values
print("Missing values in 'client_df': ")
print(missing_values_client[missing_values_client > 0])

# Check for missing values in price_df
missing_values_price = price_df.isnull().sum()

# Display the columns with missing values
print("Missing values in 'price_df': ")
print(missing_values_price[missing_values_price > 0])

Missing values in 'client_df': 
Series([], dtype: int64)
Missing values in 'price_df': 
Series([], dtype: int64)


Since no column contains null values, data imputation is not needed.
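One caveat: `isnull()` only detects true nulls (`NaN`, `None`, `NaT`). The `'Channel Missing'` and `'Campaign Missing'` labels created earlier are ordinary strings, so they are not counted. A minimal sketch on hypothetical toy data (the `toy` frame is illustrative) of how such placeholders could be counted alongside the null check:

```python
import pandas as pd

# Hypothetical toy rows reusing the labels produced by the mapping step
toy = pd.DataFrame({
    "channel_sales": ["Channel 1", "Channel Missing", "Channel 2", "Channel Missing"],
    "origin_up": ["Campaign 1", "Campaign 2", "Campaign Missing", "Campaign 1"],
})

# True nulls: none here, because the placeholders are regular strings
null_counts = toy.isnull().sum()

# Placeholder labels: count values ending in 'Missing' in each column
placeholder_counts = toy.apply(lambda col: col.str.endswith("Missing").sum())

print("Null values per column:\n", null_counts)
print("Placeholder values per column:\n", placeholder_counts)
```

Whether these placeholder categories should be kept as their own level or treated as missing is a modelling decision that this check leaves open.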

We can now save the cleaned datasets as new CSV files for use in the `Churn Prediction - Part 2` notebook.

In [11]:
client_df.to_csv(f"{repo_path}/data/client_data_cleaned.csv", index=False)
price_df.to_csv(f"{repo_path}/data/price_data_cleaned.csv", index=False)
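CSV files store dates as plain text, so the `datetime64[ns]` dtype set earlier is not preserved on disk. A minimal sketch, using a hypothetical toy frame and an in-memory buffer instead of the real files, of how the dates could be restored on reload with the `parse_dates` argument of `pd.read_csv()`:

```python
import io

import pandas as pd

# Hypothetical toy frame with one datetime column
toy = pd.DataFrame({
    "id": ["a", "b"],
    "date_activ": pd.to_datetime(["2013-06-15", "2009-08-21"]),
})

# Write to an in-memory buffer rather than to disk
buffer = io.StringIO()
toy.to_csv(buffer, index=False)

# Reload without and with explicit date parsing
buffer.seek(0)
plain = pd.read_csv(buffer)
buffer.seek(0)
parsed = pd.read_csv(buffer, parse_dates=["date_activ"])

print("Without parse_dates:", plain["date_activ"].dtype)
print("With parse_dates:   ", parsed["date_activ"].dtype)
```

The same argument could be passed when reading `client_data_cleaned.csv` and `price_data_cleaned.csv`, listing the date columns converted in this notebook.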