In [1]:
import sys
import os
from pathlib import Path

# Make the project root importable so `from src...` imports work from this
# notebook. Assumes the notebook is run with its own directory (notebooks/) as
# the current working directory, so the project root is the CWD's parent.
notebook_dir = Path.cwd()
project_dir = notebook_dir.parent
# Only add the entry once: re-executing the cell would otherwise grow sys.path
# with duplicates. Prepending (rather than appending) lets the local `src`
# package take precedence over any same-named package installed elsewhere.
if str(project_dir) not in sys.path:
    sys.path.insert(0, str(project_dir))


# notebooks/02_model_training.ipynb

import pandas as pd
import numpy as np
from src.data_processor import DataProcessor
from src.models.expenditure_predictor import ExpenditurePredictor
import plotly.graph_objects as go

# Prepare data
processor = DataProcessor()
df = processor.load_household_data('../data/sample_household_data.csv')

# Fraction of each household's samples used for training; the remainder is
# held out for evaluation. The split is positional (no shuffling) --
# presumably because rows are chronological; confirm that
# prepare_training_data preserves row order.
TRAIN_FRACTION = 0.8

# Train separate models for each household
household_models = {}
household_metrics = {}

for household in df['HouseholdID'].unique():
    print(f"\nTraining model for {household}")

    # Filter data for current household
    household_df = df[df['HouseholdID'] == household].copy()
    X, y = processor.prepare_training_data(household_df)

    # Split data
    train_size = int(len(X) * TRAIN_FRACTION)
    if train_size == 0 or train_size >= len(X):
        # Not enough samples for both a non-empty training set and a non-empty
        # test set: training on nothing would fail, and the MSE of an empty
        # test set is NaN.
        print(f"Skipping {household}: only {len(X)} sample(s), cannot split")
        continue
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    # Train model
    predictor = ExpenditurePredictor()
    history = predictor.train(X_train, y_train)

    # Evaluate model. Flatten BOTH sides: if y is a column vector (n, 1),
    # subtracting a flat (n,) prediction would silently broadcast to an
    # (n, n) matrix and yield a meaningless MSE.
    y_pred = predictor.predict(X_test)
    y_true = np.asarray(y_test, dtype=float).ravel()
    y_hat = np.asarray(y_pred, dtype=float).ravel()
    mse = float(np.mean((y_true - y_hat) ** 2))

    # Store model and metrics
    household_models[household] = predictor
    household_metrics[household] = {
        'mse': mse,
        'history': history
    }

    print(f"MSE for {household}: {mse:.4f}")


# Save one model file per household. Create the target directory first:
# saving would otherwise fail when ../src/models/saved/ does not exist yet
# (e.g. on a fresh checkout).
save_dir = Path('../src/models/saved')
save_dir.mkdir(parents=True, exist_ok=True)
for household, model in household_models.items():
    model_path = str(save_dir / f'model_{household}.keras')
    model.save(model_path)
    print(f"Saved model for {household} at {model_path}")

# Visualize training history for each household: one solid training-loss
# trace and (when available) one dashed validation-loss trace per household.
fig = go.Figure()
for household, metrics in household_metrics.items():
    history = metrics['history']
    # Presumably a Keras-style History object whose .history maps metric
    # names to per-epoch value lists -- confirm against
    # ExpenditurePredictor.train.
    losses = history.history
    train_loss = losses['loss']
    # Plotly's default x is the 0-based point index; use explicit 1-based
    # epochs so the axis matches the "Epoch k/N" numbering in the log.
    fig.add_trace(go.Scatter(
        x=list(range(1, len(train_loss) + 1)),
        y=train_loss,
        name=f'{household} Training Loss',
        mode='lines'
    ))
    # 'val_loss' is only recorded when training used validation data; skip
    # the trace rather than raising KeyError when it is absent.
    val_loss = losses.get('val_loss')
    if val_loss is not None:
        fig.add_trace(go.Scatter(
            x=list(range(1, len(val_loss) + 1)),
            y=val_loss,
            name=f'{household} Validation Loss',
            mode='lines',
            line=dict(dash='dash')
        ))

fig.update_layout(
    title='Training History by Household',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    hovermode='x unified'
)
fig.show()

# Compare predictions vs actual for each household over the full series.
# Most points are training data, so a vertical line marks where held-out
# samples begin; fit left of it is in-sample and optimistic.
for household, model in household_models.items():
    household_df = df[df['HouseholdID'] == household].copy()
    X, y = processor.prepare_training_data(household_df)

    # Get predictions
    y_pred = model.predict(X)

    # Index of the first held-out sample; must match the 80/20 positional
    # split used when the models were trained.
    split_idx = int(len(X) * 0.8)

    # Plot actual vs predicted. Flatten both series so a column-vector y
    # (n, 1) is drawn as a single line like the predictions.
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        y=np.asarray(y).ravel(),
        name='Actual',
        mode='lines'
    ))
    fig.add_trace(go.Scatter(
        y=np.asarray(y_pred).ravel(),
        name='Predicted',
        mode='lines'
    ))
    fig.add_vline(
        x=split_idx,
        line_dash='dot',
        annotation_text='train/test split'
    )

    fig.update_layout(
        title=f'Actual vs Predicted Expenditure - {household}',
        xaxis_title='Time',
        yaxis_title='Expenditure ($)',
        hovermode='x unified'
    )
    fig.show()



Training model for H1
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/10









