In [6]:
import sys
import os
from pathlib import Path

# Make the project root importable so the `src.*` packages below resolve.
# Assumes the kernel's working directory is the notebooks/ folder, so its
# parent is the project root — TODO confirm if the notebook is launched elsewhere.
notebook_dir = Path.cwd()
project_dir = notebook_dir.parent
# Only add the entry once: re-running this cell would otherwise append the
# same path to sys.path on every execution.
if str(project_dir) not in sys.path:
    sys.path.append(str(project_dir))


# notebooks/02_model_training.ipynb

import pandas as pd
import numpy as np
from src.data_processor import DataProcessor
from src.models.expenditure_predictor import ExpenditurePredictor
import plotly.graph_objects as go

# Prepare data
processor = DataProcessor()
df = processor.load_household_data('../data/sample_household_data.csv')

# Fraction of each household's samples used for training; the remaining
# (later) samples are held out for evaluation.
TRAIN_FRACTION = 0.8

# Train separate models for each household
household_models = {}
household_metrics = {}

for household in df['HouseholdID'].unique():
    print(f"\nTraining model for {household}")

    # Filter data for current household
    household_df = df[df['HouseholdID'] == household].copy()
    X, y = processor.prepare_training_data(household_df)

    # Split without shuffling, so the test rows come after the training rows.
    # NOTE(review): assumes prepare_training_data keeps rows in time order — confirm.
    train_size = int(len(X) * TRAIN_FRACTION)
    if train_size == 0 or train_size == len(X):
        # Need at least one sample on each side of the split; otherwise
        # training gets no data or the test MSE below would be NaN.
        print(f"Skipping {household}: not enough samples ({len(X)}) to split")
        continue
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    # Train model
    predictor = ExpenditurePredictor()
    history = predictor.train(X_train, y_train)

    # Evaluate model. Ravel BOTH sides: if y_test were a column vector (n, 1),
    # subtracting a flat (n,) prediction would broadcast to an (n, n) matrix
    # and silently yield a wrong MSE.
    y_pred = predictor.predict(X_test)
    y_true = np.asarray(y_test).ravel()
    mse = np.mean((y_true - np.asarray(y_pred).ravel()) ** 2)

    # Store model and metrics
    household_models[household] = predictor
    household_metrics[household] = {
        'mse': mse,
        'history': history
    }

    print(f"MSE for {household}: {mse:.4f}")

# Visualize training history for each household.
# A household's training (solid) and validation (dashed) curves share one
# colour and one legend group, so they can be paired and toggled together.
_LOSS_PALETTE = [
    '#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
    '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52',
]
fig = go.Figure()
for idx, (household, metrics) in enumerate(household_metrics.items()):
    history = metrics['history']
    losses = history.history
    colour = _LOSS_PALETTE[idx % len(_LOSS_PALETTE)]
    group = str(household)

    # Number epochs from 1 to match the "Epoch k/N" numbering in the log.
    fig.add_trace(go.Scatter(
        x=list(range(1, len(losses['loss']) + 1)),
        y=losses['loss'],
        name=f'{household} Training Loss',
        mode='lines',
        legendgroup=group,
        line=dict(color=colour)
    ))
    # Keras records 'val_loss' only when validation data was used; skip the
    # trace instead of raising KeyError when it is absent.
    if 'val_loss' in losses:
        fig.add_trace(go.Scatter(
            x=list(range(1, len(losses['val_loss']) + 1)),
            y=losses['val_loss'],
            name=f'{household} Validation Loss',
            mode='lines',
            legendgroup=group,
            line=dict(color=colour, dash='dash')
        ))

fig.update_layout(
    title='Training History by Household',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    hovermode='x unified'
)
fig.show()

# Compare predictions vs actual for each household.
# Predictions cover the FULL series. Samples left of the dotted line were
# used for training (in-sample fit); only those to the right are held out.
for household in household_models:
    household_df = df[df['HouseholdID'] == household].copy()
    X, y = processor.prepare_training_data(household_df)

    # Get predictions
    y_pred = household_models[household].predict(X)

    # Same 80% split as used when training the per-household models.
    train_size = int(len(X) * 0.8)

    # Plot actual vs predicted. Ravel so a column-vector target plots as a
    # 1-D series.
    # NOTE(review): if prepare_training_data scales y, these values are in
    # scaled units rather than dollars — confirm before trusting the axis label.
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        y=np.asarray(y).ravel(),
        name='Actual',
        mode='lines'
    ))
    fig.add_trace(go.Scatter(
        y=np.asarray(y_pred).ravel(),
        name='Predicted',
        mode='lines'
    ))
    # Mark the train/test boundary between the last training and first test sample.
    fig.add_vline(
        x=train_size - 0.5,
        line_dash='dot',
        line_color='gray',
        annotation_text='train | test'
    )

    fig.update_layout(
        title=f'Actual vs Predicted Expenditure - {household}',
        xaxis_title='Time',
        yaxis_title='Expenditure ($)',
        hovermode='x unified'
    )
    fig.show()



Training model for H1
Epoch 1/100


  super().__init__(**kwargs)


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 326ms/step - loss: 0.3319 - val_loss: 0.1039
Epoch 2/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step - loss: 0.2594 - val_loss: 0.0724
Epoch 3/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step - loss: 0.1977 - val_loss: 0.0444
Epoch 4/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step - loss: 0.1207 - val_loss: 0.0301
Epoch 5/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step - loss: 0.0679 - val_loss: 0.0469
Epoch 6/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step - loss: 0.0443 - val_loss: 0.1018
Epoch 7/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step - loss: 0.0496 - val_loss: 0.1237
Epoch 8/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 95ms/step - loss: 0.0590 - val_loss: 0.0927
Epoch 9/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m

  super().__init__(**kwargs)


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 287ms/step - loss: 0.4006 - val_loss: 0.1529
Epoch 2/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step - loss: 0.3192 - val_loss: 0.1150
Epoch 3/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step - loss: 0.2625 - val_loss: 0.0859
Epoch 4/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - loss: 0.2111 - val_loss: 0.0617
Epoch 5/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step - loss: 0.1570 - val_loss: 0.0447
Epoch 6/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step - loss: 0.1126 - val_loss: 0.0420
Epoch 7/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step - loss: 0.0662 - val_loss: 0.0644
Epoch 8/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step - loss: 0.0455 - val_loss: 0.1053
Epoch 9/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m

  super().__init__(**kwargs)


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 327ms/step - loss: 0.3706 - val_loss: 0.1181
Epoch 2/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step - loss: 0.3010 - val_loss: 0.0914
Epoch 3/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step - loss: 0.2626 - val_loss: 0.0735
Epoch 4/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step - loss: 0.2174 - val_loss: 0.0584
Epoch 5/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step - loss: 0.1762 - val_loss: 0.0473
Epoch 6/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - loss: 0.1402 - val_loss: 0.0449
Epoch 7/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - loss: 0.0961 - val_loss: 0.0582
Epoch 8/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - loss: 0.0775 - val_loss: 0.0925
Epoch 9/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m

  super().__init__(**kwargs)


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 355ms/step - loss: 0.3270 - val_loss: 0.0795
Epoch 2/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step - loss: 0.2297 - val_loss: 0.0506
Epoch 3/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - loss: 0.1620 - val_loss: 0.0386
Epoch 4/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 59ms/step - loss: 0.1031 - val_loss: 0.0520
Epoch 5/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - loss: 0.0624 - val_loss: 0.0919
Epoch 6/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step - loss: 0.0553 - val_loss: 0.1233
Epoch 7/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 57ms/step - loss: 0.0601 - val_loss: 0.1352
Epoch 8/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 61ms/step - loss: 0.0644 - val_loss: 0.1266
Epoch 9/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m

  super().__init__(**kwargs)


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 328ms/step - loss: 0.2656 - val_loss: 0.0501
Epoch 2/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - loss: 0.2112 - val_loss: 0.0322
Epoch 3/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step - loss: 0.1603 - val_loss: 0.0264
Epoch 4/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 65ms/step - loss: 0.0942 - val_loss: 0.0387
Epoch 5/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step - loss: 0.0590 - val_loss: 0.0754
Epoch 6/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step - loss: 0.0416 - val_loss: 0.1197
Epoch 7/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - loss: 0.0506 - val_loss: 0.1255
Epoch 8/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - loss: 0.0555 - val_loss: 0.1006
Epoch 9/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m

[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 13ms/step 


[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 19ms/step 


[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step 


[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 19ms/step 


[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 14ms/step 
