In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
import warnings
import pickle
import os

warnings.filterwarnings('ignore')


In [2]:
DATA_PATH = "../../data/clean_data.100k.csv"
SAVE_DIR = "../../data/graph-save"
OBSERVATION_DAYS = 30
PREDICTION_DAYS = 31

# Data split

In [3]:
def create_temporal_dataset(df, observation_days=30, prediction_days=31):
    df = df.copy()
    df['event_dt'] = pd.to_datetime(df['event_dt'])

    max_date = df['event_dt'].max()
    min_date = df['event_dt'].min()

    observation_end = max_date - pd.Timedelta(days=prediction_days)
    observation_start = observation_end - pd.Timedelta(days=observation_days)

    if observation_start < min_date:
        observation_start = min_date

    print(f"\nObservation: {observation_start.date()} to {observation_end.date()}")
    print(f"Prediction: {observation_end.date()} to {max_date.date()}")

    observation_data = df[
        (df['event_dt'] >= observation_start) &
        (df['event_dt'] < observation_end)
    ].copy()

    prediction_data = df[df['event_dt'] >= observation_end]

    active_in_observation = set(observation_data['device_id'].unique())
    active_in_prediction = set(prediction_data['device_id'].unique())
    churned_users = active_in_observation - active_in_prediction

    print(f"Users: {len(active_in_observation):,} | Churned: {len(churned_users):,} ({len(churned_users)/len(active_in_observation):.1%})")

    churn_labels = {uid: 1 if uid in churned_users else 0 for uid in active_in_observation}

    return observation_data, churn_labels

def create_unified_split(observation_data, churn_labels, save_dir):
    user_ids = list(churn_labels.keys())
    labels = [churn_labels[uid] for uid in user_ids]

    train_val_ids, test_ids, train_val_labels, test_labels = train_test_split(
        user_ids, labels, test_size=0.2, random_state=42, stratify=labels
    )

    train_ids, val_ids, train_labels, val_labels = train_test_split(
        train_val_ids, train_val_labels, test_size=0.25, random_state=42, stratify=train_val_labels
    )

    split_data = {
        'train_ids': train_ids,
        'val_ids': val_ids,
        'test_ids': test_ids,
        'train_labels': train_labels,
        'val_labels': val_labels,
        'test_labels': test_labels
    }

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'data_split.pkl'), 'wb') as f:
        pickle.dump(split_data, f)

    print(f"\nSplit: Train {len(train_ids):,} | Val {len(val_ids):,} | Test {len(test_ids):,}")

    return split_data

# Baseline

In [4]:
def extract_baseline_features(df, user_ids):
    event_features = df.groupby('device_id').agg({
        'event_dt': ['min', 'max', 'count'],
        'session_id': 'nunique',
        'screen': 'nunique',
        'feature': 'nunique',
        'action': 'count'
    })
    event_features.columns = ['_'.join(col).strip() for col in event_features.columns]

    event_features['days_in_window'] = (
        event_features['event_dt_max'] - event_features['event_dt_min']
    ).dt.total_seconds() / 86400

    event_features['events_per_day'] = event_features['event_dt_count'] / event_features['days_in_window'].clip(lower=1)
    event_features['sessions_per_day'] = event_features['session_id_nunique'] / event_features['days_in_window'].clip(lower=1)
    event_features['events_per_session'] = event_features['event_dt_count'] / event_features['session_id_nunique'].clip(lower=1)
    event_features['screen_diversity'] = event_features['screen_nunique'] / event_features['event_dt_count']
    event_features['feature_diversity'] = event_features['feature_nunique'] / event_features['event_dt_count']

    first_day = event_features['event_dt_min'].min()
    last_day = event_features['event_dt_max'].max()
    event_features['days_since_first_seen'] = (event_features['event_dt_min'] - first_day).dt.total_seconds() / 86400
    event_features['days_until_window_end'] = (last_day - event_features['event_dt_max']).dt.total_seconds() / 86400
    event_features['recency_in_window'] = event_features['days_until_window_end'] / event_features['days_in_window'].clip(lower=1)

    demographic = df.groupby('device_id')[['age', 'gender']].first()
    features = event_features.join(demographic, how='left')
    features = features.drop(columns=['event_dt_min', 'event_dt_max'])

    if 'gender' in features.columns:
        features = pd.get_dummies(features, columns=['gender'], prefix='gender', drop_first=False)

    features = features.fillna(features.median())
    features = features.loc[features.index.intersection(user_ids)]

    return features


# GNN

In [5]:
import os
class GraphSAGEChurn(nn.Module):
    def __init__(self, in_channels, hidden_channels=128, num_layers=2, dropout=0.3):
        super(GraphSAGEChurn, self).__init__()

        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))

        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))

        self.dropout = dropout

        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 1)
        )

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.classifier(x)
        return torch.sigmoid(out)

def build_random_graph(n_users, edges_per_node=5):
    edge_list = []
    for i in range(n_users):
        neighbors = np.random.choice(n_users, size=min(edges_per_node, n_users-1), replace=False)
        neighbors = neighbors[neighbors != i]
        for j in neighbors:
            edge_list.append([i, j])

    return torch.tensor(edge_list, dtype=torch.long).t().contiguous()

def train_gnn(data, train_mask, val_mask, test_mask, save_dir, device='cuda', config=None):

    hidden_channels = getattr(config, 'hidden_channels', 128)
    num_layers = getattr(config, 'num_layers', 2)
    dropout = getattr(config, 'dropout', 0.3)
    learning_rate = getattr(config, 'learning_rate', 0.01)
    weight_decay = getattr(config, 'weight_decay', 5e-4)
    max_epochs = getattr(config, 'epochs', 100)

    model = GraphSAGEChurn(in_channels=data.x.size(1), hidden_channels=hidden_channels, num_layers=num_layers, dropout=dropout)
    model = model.to(device)
    data = data.to(device)

    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    best_val_auc = 0
    patience_counter = 0

    for epoch in range(max_epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index).squeeze()
        loss = criterion(out[train_mask], data.y[train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index).squeeze()

            train_pred = out[train_mask].cpu().numpy()
            train_true = data.y[train_mask].cpu().numpy()
            val_pred = out[val_mask].cpu().numpy()
            val_true = data.y[val_mask].cpu().numpy()

            train_auc = roc_auc_score(train_true, train_pred)
            val_auc = roc_auc_score(val_true, val_pred)


        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:03d} | Train {train_auc:.3f} | Val {val_auc:.3f}")

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            patience_counter = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'config': {'in_channels': data.x.size(1), 'hidden_channels': hidden_channels, 'num_layers': num_layers, 'dropout': dropout}
            }, os.path.join(save_dir, 'gnn_model.pth'))
        else:
            patience_counter += 1
            if patience_counter >= 15:
                print(f"Early stopping at epoch {epoch+1}")
                break

    model.load_state_dict(torch.load(os.path.join(save_dir, 'gnn_model.pth'))['model_state_dict'])
    model.eval()

    with torch.no_grad():
        out = model(data.x, data.edge_index).squeeze()
        test_pred = out[test_mask].cpu().numpy()
        test_true = data.y[test_mask].cpu().numpy()

    test_pred_binary = (test_pred > 0.5).astype(int)
    test_auc = roc_auc_score(test_true, test_pred)
    test_precision = precision_score(test_true, test_pred_binary)
    test_recall = recall_score(test_true, test_pred_binary)
    test_f1 = f1_score(test_true, test_pred_binary)
    print(f"\nGNN: Best Val {best_val_auc:.3f} | Test {test_auc:.3f}")


    return {
        'model': 'GNN',
        'val_auc': best_val_auc,
        'test_auc': test_auc,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'test_f1': test_f1
    }

GNN_SWEEP_CONFIG = {
    'method': 'bayes',
    'metric': {'name': 'val_auc', 'goal': 'maximize'},
    'parameters': {
        'hidden_channels': {'values': [64, 128, 256]},
        'num_layers': {'values': [1, 2, 3]},
        'dropout': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5},
        'learning_rate': {'distribution': 'log_uniform_values', 'min': 0.001, 'max': 0.1},
        'weight_decay': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-3},
        'epochs': {'value': 100}
    }
}

### Load edges

In [6]:

# all_ids = train_ids + val_ids + test_ids

def build_graph_from_edges(edges_df, all_ids, undirected=False):
    user_id_to_idx = {uid: i for i, uid in enumerate(all_ids)}

    edges_df = edges_df[
        edges_df['source_id'].isin(user_id_to_idx) &
        edges_df['target_id'].isin(user_id_to_idx)
    ]

    edge_index = torch.tensor([
        edges_df['source_id'].map(user_id_to_idx).values,
        edges_df['target_id'].map(user_id_to_idx).values
    ], dtype=torch.long)

    if undirected:
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    edge_index = torch.unique(edge_index, dim=1)
    return edge_index

# edge_index = build_graph_from_edges(
#     edges_df,
#     all_ids
# )

In [7]:
def run_gnn_single(train_ids, val_ids, test_ids, observation_data, churn_labels, edges_df):
    print("\n" + "="*80)
    print("TRAINING GNN (SINGLE RUN)")
    print("="*80)

    all_ids = train_ids + val_ids + test_ids
    features = extract_baseline_features(observation_data, all_ids)
    X_scaled = StandardScaler().fit_transform(features.loc[all_ids])

    # edge_index = build_random_graph(len(all_ids), edges_per_node=5)
    # print('edge_index', edge_index)
    edge_index = build_graph_from_edges(
        edges_df,
        all_ids
    )

    user_id_to_idx = {uid: idx for idx, uid in enumerate(all_ids)}
    train_mask = torch.zeros(len(all_ids), dtype=torch.bool)
    val_mask = torch.zeros(len(all_ids), dtype=torch.bool)
    test_mask = torch.zeros(len(all_ids), dtype=torch.bool)

    for uid in train_ids:
        train_mask[user_id_to_idx[uid]] = True
    for uid in val_ids:
        val_mask[user_id_to_idx[uid]] = True
    for uid in test_ids:
        test_mask[user_id_to_idx[uid]] = True

    x = torch.FloatTensor(X_scaled)
    y = torch.FloatTensor([churn_labels[uid] for uid in all_ids])
    data = Data(x=x, edge_index=edge_index, y=y)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    result = train_gnn(data, train_mask, val_mask, test_mask, SAVE_DIR, device)

    return result


# Split and save

In [8]:
df = pd.read_csv(DATA_PATH)
print(f"Loaded {len(df):,} events, {df['device_id'].nunique():,} users")
edges_df = pd.read_csv("../../data/links_graph.csv")


Loaded 99,999 events, 16,129 users


In [9]:
observation_data, churn_labels = create_temporal_dataset(df, OBSERVATION_DAYS, PREDICTION_DAYS)
split_data = create_unified_split(observation_data, churn_labels, SAVE_DIR)

train_ids = split_data['train_ids']
val_ids = split_data['val_ids']
test_ids = split_data['test_ids']
train_labels = split_data['train_labels']
val_labels = split_data['val_labels']
test_labels = split_data['test_labels']

print(f"Train: {len(train_ids)} users, Val: {len(val_ids)} users, Test: {len(test_ids)} users")


Observation: 2025-09-01 to 2025-09-24
Prediction: 2025-09-24 to 2025-10-25
Users: 8,501 | Churned: 6,230 (73.3%)

Split: Train 5,100 | Val 1,700 | Test 1,701
Train: 5100 users, Val: 1700 users, Test: 1701 users


# Test 3 models and compare

In [10]:
def run_all_single():
    print("\n" + "="*80)
    print("TRAINING ALL MODELS (SINGLE RUNS)")
    print("="*80)

    results = []

    results.append(run_gnn_single(train_ids, val_ids, test_ids, observation_data, churn_labels, edges_df))

    results_df = pd.DataFrame(results)

    print("\n" + "="*80)
    print("FINAL RESULTS")
    print("="*80 + "\n")
    print(results_df.to_string(index=False))

    results_df.to_csv(os.path.join(SAVE_DIR, 'final_results.csv'), index=False)

    best_model = results_df.loc[results_df['test_auc'].idxmax()]
    print(f"\nüèÜ Best: {best_model['model']} (Test AUC: {best_model['test_auc']:.4f})")

    return results_df

In [11]:
run_all_single()


TRAINING ALL MODELS (SINGLE RUNS)

TRAINING GNN (SINGLE RUN)
Using device: cpu
Epoch 010 | Train 0.682 | Val 0.694
Epoch 020 | Train 0.726 | Val 0.721
Epoch 030 | Train 0.752 | Val 0.727
Epoch 040 | Train 0.797 | Val 0.746
Epoch 050 | Train 0.828 | Val 0.761
Epoch 060 | Train 0.841 | Val 0.762
Epoch 070 | Train 0.859 | Val 0.759
Epoch 080 | Train 0.870 | Val 0.767
Epoch 090 | Train 0.885 | Val 0.761
Epoch 100 | Train 0.891 | Val 0.763

GNN: Best Val 0.772 | Test 0.754

FINAL RESULTS

model  val_auc  test_auc  test_precision  test_recall  test_f1
  GNN 0.772189  0.754134          0.8121     0.914996 0.860483

üèÜ Best: GNN (Test AUC: 0.7541)


Unnamed: 0,model,val_auc,test_auc,test_precision,test_recall,test_f1
0,GNN,0.772189,0.754134,0.8121,0.914996,0.860483
