In [1]:
# Importing necessary libraries.
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
warnings.filterwarnings("ignore")

# Reading the raw survey responses from disk.
df = pd.read_csv('../data/raw/survey.csv')

# Building the working frame for feature engineering.
# `drop` returns a new DataFrame, so the raw `df` stays untouched while
# `df_fe` starts out without the columns that are not used for modeling.
columns_to_discard = ['Timestamp', 'comments']
df_fe = df.drop(columns=columns_to_discard)

# Standardizing gender entries to reduce inconsistency.
# Creating a function to clean gender column.
def clean_gender(gender):
    """Map a free-text gender entry onto 'Male', 'Female' or 'Other'.

    The entry is converted to a string, lower-cased, and has its whitespace
    normalised (leading/trailing blanks removed, inner runs collapsed to one
    space) before it is compared against the known spellings. Without the
    whitespace step, padded entries such as 'Male ' would be classified as
    'Other'.

    Parameters
    ----------
    gender : object
        Raw value from the 'Gender' column. Non-string values are passed
        through ``str()`` first, so a missing value (NaN) becomes 'nan'
        and is reported as 'Other'.

    Returns
    -------
    str
        One of 'Male', 'Female', 'Other'.
    """
    male_spellings = {
        'male', 'm', 'male-ish', 'maile', 'mal', 'man', 'cis male',
        'male (cis)', 'make', 'msle', 'mail', 'cis man',
    }
    female_spellings = {
        'female', 'f', 'woman', 'cis female', 'female (cis)', 'femake',
        'cis-female/femme', 'femail',
    }

    # split()/join() strips the ends and collapses repeated inner whitespace.
    normalized = ' '.join(str(gender).split()).lower()

    if normalized in male_spellings:
        return 'Male'
    if normalized in female_spellings:
        return 'Female'
    return 'Other'

# Applying the cleaning function to the 'Gender' column.
df_fe['Gender'] = df_fe['Gender'].apply(clean_gender)

# Handling missing values in key categorical columns by filling with mode.
# The filled column is assigned back instead of calling
# `df_fe[col].fillna(..., inplace=True)`: that form is chained in-place
# assignment, which pandas deprecates and which does not update `df_fe`
# when Copy-on-Write is enabled.
categorical_columns = ['self_employed', 'work_interfere', 'state']
for col in categorical_columns:
    col_modes = df_fe[col].mode()
    if not col_modes.empty:  # an all-missing column has no mode to fill with
        df_fe[col] = df_fe[col].fillna(col_modes[0])

# Cleaning 'Age': coerce non-numeric entries to NaN, impute, then drop outliers.
age = pd.to_numeric(df_fe['Age'], errors='coerce')
plausible_age = age.between(16, 100)  # inclusive bounds; NaN compares as False

# Imputing BEFORE the range filter. Previously the filter ran first, and since
# NaN fails both comparisons every missing age was dropped there, leaving the
# median fill with nothing to fill. The median is taken over plausible ages
# only so that extreme outliers cannot influence it.
age = age.fillna(age[plausible_age].median())

# Removing outliers. `.copy()` makes the filtered frame independent, so the
# column assignment below does not write into a slice of the old frame.
df_fe = df_fe.assign(Age=age).loc[age.between(16, 100)].copy()

# Creating new feature: age group.
# pd.cut uses right-closed bins here: (15, 25], (25, 35], (35, 50], (50, 100].
# NOTE(review): the first bin also contains ages 16-17 although its label
# reads '18-25'; the label is kept unchanged so existing values stay stable.
df_fe['age_group'] = pd.cut(df_fe['Age'], bins=[15, 25, 35, 50, 100], labels=['18-25', '26-35', '36-50', '51+'])

# Converting all categorical variables to category dtype.
# A single astype call with a column -> dtype mapping replaces the
# column-by-column loop.
object_columns = df_fe.select_dtypes(include='object').columns
df_fe = df_fe.astype({col: 'category' for col in object_columns})

# Encoding categorical variables using one-hot encoding.
# drop_first=True omits the first level of every encoded column.
df_encoded = pd.get_dummies(df_fe, drop_first=True)

# Printing the shape of the dataset after feature engineering.
print(f"Shape of dataset before encoding: {df_fe.shape}")
print(f"Shape of dataset after encoding: {df_encoded.shape}")

# Saving the processed data to a new file.
# The target directory is created first: to_csv does not create missing
# parent directories and would otherwise raise on a fresh checkout.
os.makedirs('../data/processed', exist_ok=True)
df_encoded.to_csv('../data/processed/mental_health_cleaned.csv', index=False)

# Displaying first few rows of the processed dataset
df_encoded.head()

Shape of dataset before encoding: (1251, 26)
Shape of dataset after encoding: (1251, 137)


Unnamed: 0,Age,Gender_Male,Gender_Other,Country_Austria,Country_Belgium,Country_Bosnia and Herzegovina,Country_Brazil,Country_Bulgaria,Country_Canada,Country_China,...,mental_health_interview_No,mental_health_interview_Yes,phys_health_interview_No,phys_health_interview_Yes,mental_vs_physical_No,mental_vs_physical_Yes,obs_consequence_Yes,age_group_26-35,age_group_36-50,age_group_51+
0,37,False,False,False,False,False,False,False,False,False,...,True,False,False,False,False,True,False,False,True,False
1,44,True,False,False,False,False,False,False,False,False,...,True,False,True,False,False,False,False,False,True,False
2,32,True,False,False,False,False,False,False,True,False,...,False,True,False,True,True,False,False,True,False,False
3,31,True,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,True,True,False,False
4,31,True,False,False,False,False,False,False,False,False,...,False,True,False,True,False,False,False,True,False,False


In [2]:
# Saving the correlation matrix figure.
# The output directory is created first; savefig does not create missing
# parent directories.
os.makedirs('../outputs/figures', exist_ok=True)

# Explicit figure/axes objects instead of the implicit pyplot state.
fig, ax = plt.subplots(figsize=(12, 10))
# annot=False draws no per-cell numbers, so no number format is passed.
sns.heatmap(df_encoded.corr(), annot=False, cmap='coolwarm', ax=ax)
ax.set_title("Correlation Matrix")
fig.tight_layout()
fig.savefig('../outputs/figures/correlation_matrix.png')
# Closing this specific figure so it is written to disk only and releases memory.
plt.close(fig)

In [3]:
# Saving the class distribution figure of the target variable.
# The output directory is created first; savefig does not create missing
# parent directories.
os.makedirs('../outputs/figures', exist_ok=True)

# Casting the one-hot target to int so the tick labels read 0 / 1, matching
# the axis label below. The encoded column is boolean in the displayed
# output, which would otherwise produce "False" / "True" ticks.
treatment = df_encoded['treatment_Yes'].astype(int)

fig, ax = plt.subplots(figsize=(6, 4))
# hue mirrors x so every bar keeps its own 'Set2' colour without relying on
# the deprecated palette-without-hue form; dodge=False keeps the bars
# centred on their ticks.
sns.countplot(x=treatment, hue=treatment, palette='Set2', dodge=False, ax=ax)
# The hue legend only repeats the x axis, so it is removed when present.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

ax.set_title("Class Distribution of Target Variable")
ax.set_xlabel("Treatment (0 = No, 1 = Yes)")
ax.set_ylabel("Count")
fig.tight_layout()
fig.savefig('../outputs/figures/class_distribution.png')
plt.close(fig)
