# Intuit Refund ETA Prediction

This notebook demonstrates how to:
- Load mock tax refund data (generated ahead of time into `../generated/training_data.csv`)
- Train an XGBoost regression model
- Predict refund delay (in days)
- Explain predictions with SHAP

In [1]:
import numpy as np
import pandas as pd
import xgboost as xgb
from xgboost import XGBRegressor
import seaborn as sns
import shap
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error
from sklearn.inspection import PartialDependenceDisplay
from train_model import train_model, categorical_cols

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

SyntaxError: invalid syntax (3982017575.py, line 8)

In [None]:
# Load the pre-generated mock training data and preview the first rows
df = pd.read_csv(
    "../generated/training_data.csv"
)
df.head(5)

In [None]:
# Train the booster via the project helper and unpack everything later cells use.
# NOTE(review): the unpacking order mirrors train_model's return tuple, which is
# not visible here — confirm against train_model if it changes.
(
    model,
    X_encoded,
    X_test,
    X_train,
    dtest,
    y_test,
    y_train,
) = train_model("../generated/training_data.csv")

print("X_train types:")
print(X_train.dtypes.value_counts())

# Hold-out predictions; y_pred is also reused by the error plots further down
y_pred = model.predict(dtest)
rmse = root_mean_squared_error(y_test, y_pred)
print(f"RMSE: {rmse:.2f} days")

In [None]:
# Use the model to predict a tax refund time

# Define a test user record (raw, not encoded)
test_record = {
    "filing_method": "efile_direct_deposit",
    "filing_time_category": "early",
    "bank_deposit_type": "traditional_bank",
    "geo_region": "west",
    "prior_credits_claimed": "few",

    "has_return_errors": 0,
    "requires_id_verification": 0,
    "is_selected_for_manual_review": 0,
    "claimed_eitc": 1,
    "claimed_actc": 0,
    "is_amended_return": 0,
    "has_injured_spouse_claim": 0,
    "has_offset_debts": 0,
    "prior_refund_delayed": 0,
    "prior_id_verification_flagged": 0,
    "has_bank_info_on_file": 1,

    "num_days_since_filed": 12,
    "return_completeness_score": 0.95,
    "prior_refund_processing_time": 10
}

# Convert to a one-row DataFrame
sample_user = pd.DataFrame([test_record])

# One-hot encode the same categorical columns used for training
sample_user_encoded = pd.get_dummies(sample_user, columns=categorical_cols)

model_features = list(X_encoded.columns)

# A dummy column the model never saw (a category value absent from the
# training data) would be dropped silently by the alignment below, leaving
# that feature all-zero. Surface it instead of hiding it.
unseen_cols = sorted(set(sample_user_encoded.columns) - set(model_features))
if unseen_cols:
    warnings.warn(f"Columns not seen during training will be ignored: {unseen_cols}")

# Align to the training layout in one step:
# - add missing dummy columns as 0;
# - drop unknown columns;
# - order the columns exactly as in X_encoded.
# Then cast each column to its X_encoded dtype. get_dummies yields bool columns
# (pandas >= 2.0), while the filled-in columns would otherwise be int, so the
# cast keeps the DMatrix feature types consistent.
# NOTE(review): assumes X_encoded is the frame the booster was trained on —
# confirm in train_model.
sample_user_encoded = (
    sample_user_encoded
    .reindex(columns=model_features, fill_value=0)
    .astype(X_encoded.dtypes.to_dict())
)

# Check dtypes
print(sample_user_encoded.dtypes.value_counts())

# Create DMatrix
d_sample = xgb.DMatrix(sample_user_encoded)

# Predict
pred_delay = model.predict(d_sample)

print(f"Predicted refund delay: {pred_delay[0]:.2f} days")

In [None]:
# Explain the single-record prediction above with a tree SHAP explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(sample_user_encoded)

shap.initjs()  # load the JS needed to render the interactive force plot

# Left as the last expression so the notebook renders the plot
shap.force_plot(explainer.expected_value, shap_values, sample_user_encoded)

In [None]:
# Output reasons for delay: rank this record's features by SHAP contribution.
# The SHAP values were computed on sample_user_encoded, so they line up with its
# (one-hot-encoded) columns. The raw sample_user columns have a different count
# and order and would mislabel every value.
feature_names = list(sample_user_encoded.columns)
assert len(feature_names) == len(shap_values[0]), "SHAP values / feature names misaligned"

reasons = {
    feature: float(value)
    for feature, value in zip(feature_names, shap_values[0])
}

reasons_df = (
    pd.DataFrame(list(reasons.items()), columns=["Feature", "SHAP Value"])
    # Strongest influence first, regardless of direction
    .sort_values("SHAP Value", key=lambda s: s.abs(), ascending=False)
)

reasons_df

In [None]:
# Global explanation: SHAP values for every row of the hold-out set
shap_values_all = explainer.shap_values(X_test)

# Summary plot across all hold-out rows
shap.summary_plot(
    shap_values_all,
    X_test,
)

In [None]:
# XGBoost's built-in importance for the booster ("weight" is plot_importance's default)
xgb.plot_importance(model, importance_type="weight")
plt.show()

In [None]:
# Mean |SHAP| per feature over the hold-out set. This must use the SHAP matrix
# computed for X_test (shap_values_all). shap_values only holds the single
# sample record, so its shape does not match X_test.
shap.summary_plot(shap_values_all, X_test, plot_type="bar")

In [None]:
# Static force plot for the first hold-out row. The SHAP values and the feature
# values must come from the same row, so both are taken from the X_test
# explanation. shap_values[0] belongs to the sample user, not X_test.iloc[0].
shap.force_plot(
    explainer.expected_value,
    shap_values_all[0],
    X_test.iloc[0],
    matplotlib=True,
)

In [None]:
# Train a model through XGBoost's scikit-learn API so scikit-learn's
# partial-dependence tooling can use it.
# NOTE(review): these hyperparameters are set here independently of
# train_model, so this model may not match `model` exactly — confirm if that matters.
trained_model_reg = XGBRegressor(
    objective="reg:squarederror",
    max_depth=4,
    learning_rate=0.1,  # documented sklearn-API name for XGBoost's `eta`
    n_estimators=100,
)

trained_model_reg.fit(X_train, y_train)

# Partial dependence of the predicted delay on days since filing
PartialDependenceDisplay.from_estimator(
    trained_model_reg,
    X_test,
    ["num_days_since_filed"],
)

plt.show()

In [None]:
# Pairwise correlations between hold-out features
corr = X_test.corr()

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr, annot=False, cmap="coolwarm", ax=ax)
ax.set_title("Feature Correlations")
plt.show()

In [None]:
# Residuals (actual minus predicted) on the hold-out set
errors = y_test - y_pred

fig, ax = plt.subplots()
ax.hist(errors, bins=30)
ax.set_xlabel("Prediction Error (days)")
ax.set_ylabel("Frequency")
ax.set_title("Distribution of Prediction Errors")
plt.show()

In [None]:
# Pairwise SHAP interaction values over the hold-out set.
# NOTE(review): much more expensive than plain SHAP values (a features x
# features matrix per row) — consider subsampling X_test if this is slow.
shap_interaction_values = explainer.shap_interaction_values(X_test)
shap.summary_plot(
    shap_interaction_values,
    X_test,
)

In [None]:
# Predicted vs. actual delay; perfect predictions fall on the red y = x line
lo, hi = y_test.min(), y_test.max()

fig, ax = plt.subplots()
ax.scatter(y_test, y_pred, alpha=0.5)
ax.plot([lo, hi], [lo, hi], color='red')
ax.set_xlabel("Actual refund delay")
ax.set_ylabel("Predicted refund delay")
ax.set_title("Predicted vs Actual Refund Delay")
plt.show()