# 2.0 - Model Training

Train a simple uplift model on the simulated data and save it (pickled with `joblib`) to `src/uplift_model.pkl`. A Qini curve for the model is also saved to `docs/qini_curve.png`.
Note: If `causalml` is installed, its CausalForest implementation is preferred. If `causalml` is not available in the environment (or fitting the CausalForest fails), we fall back, for quick prototyping, to a simple two-model approach (T-learner) with RandomForest classifiers as a proxy.

In [None]:
import os
import pandas as pd
import numpy as np
import joblib
from sklearn.ensemble import RandomForestClassifier

# All paths are relative to the notebook's working directory (src/notebooks/).
data_path = os.path.join('..', '..', 'data', 'sample_data.csv')
# Trained model is written to src/uplift_model.pkl.
out_path = os.path.join('..', 'uplift_model.pkl')

df = pd.read_csv(data_path)

# Covariates (X), treatment indicator (T) and outcome (y).
feature_cols = ['age', 'income', 'number_of_transactions']
X, T, y = df[feature_cols], df['treatment'], df['conversion']

def extract_uplift_scores(pred, n_rows):
    """Normalise a CausalForest ``predict`` result into a flat 1-D float array.

    Accepted shapes:
      * a dict carrying the estimates under ``'ate'``;
      * a tuple whose first element is the estimate, e.g. ``(te, lb, ub)``;
      * any array-like of per-row estimates (``(n,)`` or ``(n, 1)``).

    Args:
        pred: raw return value of ``predict``.
        n_rows: number of rows that were scored; the result must match it.

    Returns:
        numpy array of length ``n_rows``.

    Raises:
        ValueError: if no estimate can be extracted or its length is wrong.
    """
    if isinstance(pred, dict):
        if 'ate' not in pred:
            raise ValueError("predict() returned a dict without an 'ate' entry")
        pred = pred['ate']
    elif isinstance(pred, tuple):
        if not pred:
            raise ValueError('predict() returned an empty tuple')
        pred = pred[0]  # (te, lb, ub)-style return: keep the point estimate
    scores = np.asarray(pred, dtype=float).ravel()
    if scores.shape[0] != n_rows:
        raise ValueError(
            f'expected {n_rows} uplift scores from predict(), got {scores.shape[0]}'
        )
    return scores


# Try using causalml's CausalForest; fall back to a T-learner with RandomForest.
# NOTE(review): recent causalml releases expose CausalRandomForestRegressor in
# causalml.inference.tree rather than a class named CausalForest — confirm the
# import name; if it is missing, this cell always uses the T-learner fallback.
use_causalml = False
try:
    from causalml.inference.tree import CausalForest
    use_causalml = True
    print('causalml detected: will use CausalForest')
except Exception as e:  # deliberate best-effort: any import failure means "no causalml"
    print('causalml not available, falling back to T-learner RandomForest:', e)

if use_causalml:
    # Construction, fit and predict all live inside the try: the CausalForest API
    # varies between versions, and any failure should route to the T-learner
    # instead of silently producing all-zero uplift scores.
    try:
        cf = CausalForest(random_state=42)
        cf.fit(X.values, T.values, y.values)
        uplift_scores = extract_uplift_scores(cf.predict(X.values), len(X))
    except Exception as e:
        print('CausalForest fit/predict failed, falling back to T-learner:', e)
        use_causalml = False

class SimpleUpliftModel:
    """Two-model (T-learner) uplift estimator.

    The uplift for a row is P(y=1 | x) from the classifier fitted on treated
    rows minus P(y=1 | x) from the classifier fitted on control rows.

    NOTE: pickle/joblib store classes by reference (module + qualified name).
    Because this class is defined in the notebook's ``__main__``, loading the
    saved model from another process only works if a class with the same name
    is importable from that process's ``__main__`` — consider moving it into a
    module under ``src/``.
    """

    def __init__(self, model_t, model_c, feature_cols):
        self.model_t = model_t  # classifier fitted on treated rows
        self.model_c = model_c  # classifier fitted on control rows
        self.feature_cols = feature_cols  # columns selected from X before predicting

    @staticmethod
    def _positive_proba(model, features):
        """Return P(y == 1) from ``model``; zeros if class 1 was absent at fit time."""
        # predict_proba has one column per entry of ``classes_``. A model fitted
        # on a single outcome class has only one column, so a hard-coded
        # ``[:, 1]`` would raise IndexError.
        classes = list(model.classes_)
        if 1 not in classes:
            return np.zeros(len(features))
        return model.predict_proba(features)[:, classes.index(1)]

    def predict(self, X):
        """Return the estimated uplift for each row of DataFrame ``X``."""
        features = X[self.feature_cols]
        return (
            self._positive_proba(self.model_t, features)
            - self._positive_proba(self.model_c, features)
        )


if not use_causalml:
    # T-learner: fit one outcome model per treatment arm.
    treated = T == 1
    control = T == 0
    if not treated.any() or not control.any():
        raise ValueError(
            'T-learner needs both treated (treatment == 1) and control '
            '(treatment == 0) rows'
        )

    X_treated = X[treated]
    y_treated = y[treated]
    X_control = X[control]
    y_control = y[control]

    model_t = RandomForestClassifier(n_estimators=100, random_state=42)
    model_c = RandomForestClassifier(n_estimators=100, random_state=42)

    model_t.fit(X_treated, y_treated)
    model_c.fit(X_control, y_control)

    uplift_model = SimpleUpliftModel(model_t, model_c, feature_cols)
    uplift_scores = uplift_model.predict(X)

    # Save the wrapper as the model
    joblib.dump(uplift_model, out_path)
    print('Saved T-learner uplift model to', out_path)
else:
    # Save causalml model
    joblib.dump(cf, out_path)
    print('Saved CausalForest model to', out_path)

import matplotlib.pyplot as plt


def qini_curve(scores, treatment, outcome):
    """Return the Qini curve for customers ranked by descending uplift score.

    Point k (k = 0..n) is ``Y_t(k) - Y_c(k) * N_t(k) / N_c(k)``, where ``N_*``
    and ``Y_*`` count treated/control customers and their conversions among
    the top-k ranked customers. The control term is 0 until at least one
    control customer has been seen.

    Args:
        scores: uplift score per customer.
        treatment: treatment indicator per customer (1 = treated).
        outcome: outcome per customer (e.g. 0/1 conversion).

    Returns:
        numpy array of length ``len(scores) + 1`` starting at 0.
    """
    scores = np.asarray(scores, dtype=float).ravel()
    is_treated = np.asarray(treatment).ravel() == 1
    outcome = np.asarray(outcome, dtype=float).ravel()

    order = np.argsort(-scores, kind='stable')  # highest predicted uplift first
    is_treated = is_treated[order]
    outcome = outcome[order]

    n_t = np.cumsum(is_treated)
    n_c = np.cumsum(~is_treated)
    y_t = np.cumsum(outcome * is_treated)
    y_c = np.cumsum(outcome * ~is_treated)
    # Scale control conversions to the treated group size; avoid 0/0 early on.
    ratio = np.divide(n_t, n_c, out=np.zeros(len(n_t)), where=n_c > 0)
    return np.concatenate(([0.0], y_t - y_c * ratio))


# Plot and save the Qini curve, preferring causalml.metrics.plot_qini.
# NOTE: uplift_scores are predictions on the same rows the model was trained
# on, so this curve is optimistic; use a held-out split for an honest estimate.
qini_path = os.path.join('..', '..', 'docs', 'qini_curve.png')
os.makedirs(os.path.dirname(qini_path), exist_ok=True)

try:
    from causalml.metrics import plot_qini
    # plot_qini takes a DataFrame: every column other than the outcome,
    # treatment and true-effect columns is plotted as a model's score.
    qini_df = pd.DataFrame({
        'uplift_model': np.asarray(uplift_scores, dtype=float).ravel(),
        'w': np.asarray(T),
        'y': np.asarray(y),
    })
    plot_qini(qini_df, outcome_col='y', treatment_col='w',
              treatment_effect_col='tau', figsize=(6, 4))
    plt.savefig(qini_path, bbox_inches='tight')
    plt.close()
    print('Saved Qini curve to', qini_path)
except Exception as e:
    print('Could not plot Qini with causalml:', e)
    plt.close('all')  # discard any partially drawn figure from the failed attempt
    qini = qini_curve(uplift_scores, T, y)
    ranked = np.arange(len(qini))
    plt.figure(figsize=(6, 4))
    plt.plot(ranked, qini, label='Model')
    # Random targeting accrues the overall incremental gain linearly.
    plt.plot(ranked, np.linspace(0.0, qini[-1], len(qini)), linestyle='--', label='Random')
    plt.title('Qini curve (numpy fallback)')
    plt.xlabel('Ranked customers')
    plt.ylabel('Incremental conversions')
    plt.legend()
    plt.savefig(qini_path, bbox_inches='tight')
    plt.close()
    print('Saved fallback Qini curve to', qini_path)

# Quick sanity check: show the first handful of uplift estimates.
preview = uplift_scores[:5]
print('Example uplift scores (first 5):', preview)
