# Customer Segmentation & LTV Prediction
Segment customers and predict lifetime value.

In [None]:
# 1. Import libraries
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error

# Default size for every figure drawn below (matplotlib takes width, height in inches).
plt.rcParams.update({'figure.figsize': (10, 5)})

In [None]:
# 2. Load dataset
# invoice_date is parsed to datetime while reading so the date arithmetic in
# the RFM cell can operate on it directly.
transactions_csv = 'customer_data.csv'
df = pd.read_csv(transactions_csv, parse_dates=['invoice_date'])
df.head()

In [None]:
# 3. Compute RFM features
import datetime as dt

# Reference date: one day after the newest invoice, so the most recent
# purchaser gets Recency 1 rather than 0.
snapshot = df['invoice_date'].max() + dt.timedelta(days=1)


def _days_since_last_invoice(invoice_dates):
    """Return whole days between a customer's latest invoice and the snapshot."""
    return (snapshot - invoice_dates.max()).days


# One row per customer:
#   Recency   - days since the customer's last invoice
#   Frequency - number of distinct invoices
#   Monetary  - summed invoice amount
rfm = df.groupby('customer_id').agg(
    Recency=('invoice_date', _days_since_last_invoice),
    Frequency=('invoice_no', 'nunique'),
    Monetary=('amount', 'sum'),
)
rfm.head()

In [None]:
# 4. RFM Clustering
# Select the three RFM columns explicitly. This cell writes a 'Cluster' column
# into rfm, so scaling the whole frame would feed the previous labels back in
# as a fourth feature whenever the cell is re-run.
rfm_features = ['Recency', 'Frequency', 'Monetary']

# Standardise so that no single feature dominates the Euclidean distances
# K-Means uses, then fit 4 clusters with a fixed seed.
scaler = StandardScaler()
rfm_scaled = scaler.fit_transform(rfm[rfm_features])
kmeans = KMeans(n_clusters=4, random_state=42).fit(rfm_scaled)

# labels_ is in the same row order as rfm, so it can be assigned positionally.
rfm['Cluster'] = kmeans.labels_
rfm['Cluster'].value_counts()

In [None]:
# 5. Visualise Clusters
# Recency vs. Monetary scatter, coloured by the cluster label assigned above.
fig, ax = plt.subplots()
ax.scatter(rfm['Recency'], rfm['Monetary'], c=rfm['Cluster'], cmap='viridis')
ax.set_xlabel('Recency')
ax.set_ylabel('Monetary')
ax.set_title('RFM Clusters')
plt.show()

In [None]:
# 6. Prepare data for LTV prediction
# Per-customer spend aggregates. The flattened column names below follow the
# order of the agg spec: amount -> (sum, mean), then invoice_no -> nunique.
data = df.groupby('customer_id').agg({
    'amount': ['sum', 'mean'],
    'invoice_no': 'nunique'
})
data.columns = ['total_amount', 'avg_amount', 'frequency']

# Index-on-index join: both frames are indexed by customer_id.
data = data.merge(rfm[['Recency', 'Cluster']], left_index=True, right_index=True)
data['LTV'] = data['total_amount'] * 1.2  # placeholder for actual LTV

# The placeholder target is an exact multiple of total_amount, so that column
# must not be a feature: with it the model only has to learn "multiply by 1.2"
# and the reported MAE says nothing about predictive power.
# NOTE(review): the remaining features are still computed from the same
# transactions as the target (Cluster is fitted on Monetary, which is the same
# sum), so some leakage remains until LTV is derived from a later holdout
# period than the features.
X = data.drop(columns=['LTV', 'total_amount'])
y = data['LTV']

In [None]:
# 7. Train XGBoost Model
# Hold out 20% of the customers for evaluation; random_state fixes the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

xgb_params = {
    'objective': 'reg:squarederror',
    'n_estimators': 100,
    'random_state': 42,
}
model = XGBRegressor(**xgb_params)
model.fit(X_train, y_train)

# Score on the held-out split only.
preds = model.predict(X_test)
mae = mean_absolute_error(y_test, preds)
print('LTV MAE:', mae)

In [None]:
# 8. Save Model
import pickle

# Serialise the fitted regressor to bytes first, then write them in one call.
# NOTE: only unpickle this file in a trusted environment; loading a pickle
# executes code embedded in it.
model_bytes = pickle.dumps(model)
with open('ltv_model.pkl', 'wb') as model_file:
    model_file.write(model_bytes)
print('Model saved to ltv_model.pkl')