# %% [markdown]
# # Healthcare Spending Analysis
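# **Notebook Purpose**: Load healthcare-survey responses from MongoDB, flatten them into a DataFrame, chart average income by age and average spending by gender, and export the processed data to CSV.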

# %% [markdown]
# ## 1. Set Up the Environment

# %%


# Cell 1: Import required packages
# (install them beforehand, e.g. `%pip install pandas pymongo matplotlib python-dotenv`)
import os

import matplotlib.pyplot as plt
import pandas as pd
from dotenv import load_dotenv
from pymongo import MongoClient


# %% [markdown]
# ## 2. Connect to MongoDB
# The connection string is read from `MONGO_URI` (loaded from `../.env`), so no credentials live in the notebook.

# %%
# Cell 2: Connect to MongoDB
# Load environment variables from the .env file in the parent directory.
# NOTE: the path is relative to the kernel's working directory, assumed to be
# this notebook's folder.
load_dotenv('../.env')

mongo_uri = os.getenv('MONGO_URI')
if not mongo_uri:
    # MongoClient(None) silently falls back to localhost:27017, which would
    # make the notebook analyse the wrong (or an empty) database.
    raise RuntimeError("MONGO_URI is not set; add it to ../.env or the environment")

client = MongoClient(mongo_uri)
db = client.healthcare_survey
collection = db.users

print("Connected to MongoDB. Collection size:", collection.count_documents({}))


# %% [markdown]
# ## 3. Data Loading & Processing

# %%

# Cell 3: Define User Class and Load Data
class User:
    """One survey respondent, flattened from a MongoDB document.

    Reads the top-level ``age``, ``gender`` and ``income`` fields plus the
    nested ``expenses`` mapping (category -> amount). Missing top-level
    fields become ``None``; a missing or null ``expenses`` becomes ``{}``.
    """

    def __init__(self, data):
        self.age = data.get('age')
        self.gender = data.get('gender')
        self.income = data.get('income')
        # `or {}` also covers documents storing `expenses: null`, which
        # `.get('expenses', {})` would pass through as None. Copying avoids
        # aliasing the source document's dict.
        self.expenses = dict(data.get('expenses') or {})

    def to_dict(self):
        """Return a flat record: age, gender, income, then one key per expense category.

        Expense categories never overwrite the ``age``/``gender``/``income`` fields.
        """
        record = {
            'age': self.age,
            'gender': self.gender,
            'income': self.income,
        }
        for category, amount in self.expenses.items():
            # setdefault keeps the demographic fields authoritative if an
            # expense category happens to share one of their names.
            record.setdefault(category, amount)
        return record

# %%
# Load data from MongoDB: one flat record per survey document.
users = [User(user).to_dict() for user in collection.find()]

# Clean data: title-case the expense categories for chart labels.
# Chained rather than `inplace=True`, so `df` is only ever bound once here.
df = pd.DataFrame(users).rename(columns={
    'utilities': 'Utilities',
    'entertainment': 'Entertainment',
    'school_fees': 'School Fees',
    'shopping': 'Shopping',
    'healthcare': 'Healthcare'
})

# Guarantee every column the later cells group on / plot exists, even when no
# document carries that field (e.g. an empty collection or an unused expense
# category). Absent values are NaN, which pandas' mean() skips.
for column in ('age', 'gender', 'income', 'Utilities', 'Entertainment',
               'School Fees', 'Shopping', 'Healthcare'):
    if column not in df.columns:
        df[column] = float('nan')

print("Data loaded successfully. First 5 rows:")
df.head()

# %% [markdown]
# ## 4. Data Visualization

# %%

# Cell 4: Top Ages by Income
CHARTS_DIR = '../app/static/charts'
# savefig does not create missing folders; make sure the target exists.
os.makedirs(CHARTS_DIR, exist_ok=True)

# Mean income per distinct age, keeping the 10 highest (sorted descending).
top_ages = df.groupby('age')['income'].mean().nlargest(10)

# Explicit figure/axes so the size and labels apply to exactly this chart.
fig, ax = plt.subplots(figsize=(12, 6))
top_ages.plot(kind='bar', color='#1f77b4', rot=45, ax=ax)
ax.set_title('Top 10 Ages by Average Income', fontsize=14)
ax.set_xlabel('Age', fontsize=12)
ax.set_ylabel('Average Income ($)', fontsize=12)
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
fig.savefig(os.path.join(CHARTS_DIR, 'top_ages.png'), dpi=300)
plt.show()

# %%
# Cell 5: Gender Spending Distribution
CHARTS_DIR = '../app/static/charts'
# savefig does not create missing folders; make sure the target exists.
os.makedirs(CHARTS_DIR, exist_ok=True)

expense_columns = ['Utilities', 'Entertainment', 'School Fees', 'Shopping', 'Healthcare']
# Mean spend per category within each gender; stacking the bars shows the
# per-category breakdown of each gender's average spending.
gender_spending = df.groupby('gender')[expense_columns].mean()

# Pass `ax=`: DataFrame.plot() without it opens its own new figure, leaving a
# separately created sized figure blank and the chart at the default size.
fig, ax = plt.subplots(figsize=(14, 8))
gender_spending.plot(kind='bar', stacked=True, colormap='Pastel2', rot=0, ax=ax)
ax.set_title('Average Spending by Gender', fontsize=14)
ax.set_xlabel('Gender', fontsize=12)
ax.set_ylabel('Total Spending ($)', fontsize=12)
ax.legend(title='Category', bbox_to_anchor=(1.05, 1), loc='upper left')
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
# bbox_inches='tight' keeps the legend (placed outside the axes) in the saved PNG.
fig.savefig(os.path.join(CHARTS_DIR, 'gender_spending.png'), dpi=300, bbox_inches='tight')
plt.show()

# %% [markdown]
# ## 5. Export Processed Data

# %%
# Cell 6: Save Processed Data to CSV
output_dir = '../data'
# Create the output folder on first run; a no-op if it already exists.
os.makedirs(output_dir, exist_ok=True)

csv_path = os.path.join(output_dir, 'processed_healthcare_data.csv')
# index=False: the RangeIndex carries no information worth persisting.
df.to_csv(csv_path, index=False)

print(f"Data exported to {csv_path}")
print("Final DataFrame Summary:")
df.describe()