In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
import warnings
import pickle
import os

warnings.filterwarnings('ignore')


In [2]:
DATA_PATH = "../../data/clean_data.500k.csv"
SAVE_DIR = "../../data/graph-save"
OBSERVATION_DAYS = 30
PREDICTION_DAYS = 31

# Data split

In [3]:

def create_temporal_dataset(
    df,
    observation_days=30,
    prediction_days=31
):
    df = df.copy()
    df["event_dt"] = pd.to_datetime(df["event_dt"])

    max_date = df["event_dt"].max()
    min_date = df["event_dt"].min()

    observation_end = max_date - pd.Timedelta(days=prediction_days)
    observation_start = observation_end - pd.Timedelta(days=observation_days)

    if observation_start < min_date:
        observation_start = min_date

    print(f"\nObservation: {observation_start.date()} ‚Üí {observation_end.date()}")
    print(f"Prediction: {observation_end.date()} ‚Üí {max_date.date()}")

    # --- –¥–∞–Ω–Ω—ã–µ –æ–∫–Ω–∞ –Ω–∞–±–ª—é–¥–µ–Ω–∏—è ---
    observation_data = df[
        (df["event_dt"] >= observation_start) &
        (df["event_dt"] < observation_end)
    ].copy()

    # --- device_id set (–ø–æ—è–≤–∏–ª–∏—Å—å —Ö–æ—Ç—è –±—ã —Ä–∞–∑ –¥–æ prediction window) ---
    device_id_set = set(
        df[df["event_dt"] < observation_end]["device_id"].unique()
    )

    print(f"Nodes in observation: {observation_data['node_id'].nunique():,}")
    print(f"Devices before prediction: {len(device_id_set):,}")

    # --- —Ç–∞—Ä–≥–µ—Ç—ã –ø–æ node_id (–±–µ—Ä—ë–º –∫–∞–∫ –µ—Å—Ç—å) ---
    node_labels = (
        observation_data
        .groupby("node_id")["is_churn"]
        .max()          # –µ—Å–ª–∏ —É node –±—ã–ª–∏ —Ä–∞–∑–Ω—ã–µ —Å–æ–±—ã—Ç–∏—è ‚Äî churn=True –ø–æ–±–µ–∂–¥–∞–µ—Ç
        .to_dict()
    )

    return observation_data, node_labels, device_id_set


def create_unified_split(node_labels, save_dir):
    node_ids = list(node_labels.keys())
    labels = [node_labels[nid] for nid in node_ids]

    train_val_ids, test_ids, train_val_labels, test_labels = train_test_split(
        node_ids,
        labels,
        test_size=0.2,
        random_state=42,
        stratify=labels
    )

    train_ids, val_ids, train_labels, val_labels = train_test_split(
        train_val_ids,
        train_val_labels,
        test_size=0.25,
        random_state=42,
        stratify=train_val_labels
    )

    split_data = {
        "train_ids": train_ids,
        "val_ids": val_ids,
        "test_ids": test_ids,
        "train_labels": train_labels,
        "val_labels": val_labels,
        "test_labels": test_labels,
    }

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "data_split.pkl"), "wb") as f:
        pickle.dump(split_data, f)

    print(
        f"\nSplit:"
        f" Train {len(train_ids):,}"
        f" | Val {len(val_ids):,}"
        f" | Test {len(test_ids):,}"
    )

    return split_data

# Baseline

In [4]:
import pandas as pd
import numpy as np


def extract_node_baseline_features(df, device_id_set):
    df = df.copy()
    df["event_dt"] = pd.to_datetime(df["event_dt"])

    # --- —Ñ–∏–ª—å—Ç—Ä–∞—Ü–∏—è –ø–æ device_id_set ---
    df = df[df["device_id"].isin(device_id_set)].copy()

    # ============================================================
    # steps_count
    # ============================================================
    steps_count = (
        df.groupby("node_id")
        .size()
        .rename("steps_count")
    )

    # ============================================================
    # sessions_count (—É–Ω–∏–∫–∞–ª—å–Ω—ã–µ device_id + session_id)
    # ============================================================
    sessions_count = (
        df.groupby("node_id")[["device_id", "session_id"]]
        .apply(lambda x: x.drop_duplicates().shape[0])
        .rename("sessions_count")
    )

    # ============================================================
    # steptime_ms
    # ============================================================
    df = df.sort_values(["device_id", "session_id", "event_dt"])

    df["next_event_dt"] = (
        df.groupby(["device_id", "session_id", "node_id"])["event_dt"]
        .shift(-1)
    )

    df["steptime_ms"] = (
        (df["next_event_dt"] - df["event_dt"])
        .dt.total_seconds() * 1000
    )

    steptime = (
        df.groupby("node_id")["steptime_ms"]
        .mean()
        .rename("steptime_ms")
    )

    # ============================================================
    # feature_diversity_avg
    # ============================================================
    feature_diversity_avg = (
        df.groupby(["node_id", "device_id"])["feature"]
        .nunique()
        .groupby("node_id")
        .mean()
        .rename("feature_diversity_avg")
    )

    # ============================================================
    # age_avg
    # ============================================================
    age_avg = (
        df.groupby("node_id")["age"]
        .mean()
        .rename("age_avg")
    )

    # ============================================================
    # male_rate
    # ============================================================
    gender_df = df[df["gender"].notna()]

    male_rate = (
        gender_df.groupby("node_id")["gender"]
        .apply(lambda x: (x == "–ú").mean())
        .rename("male_rate")
    )

    # ============================================================
    # churn_rate
    # ============================================================
    churn_rate = (
        df.groupby("node_id")["is_churn"]
        .mean()
        .rename("churn_rate")
    )

    # ============================================================
    # –æ–±—ä–µ–¥–∏–Ω–µ–Ω–∏–µ
    # ============================================================
    features = pd.concat(
        [
            steps_count,
            sessions_count,
            steptime,
            feature_diversity_avg,
            age_avg,
            male_rate,
            churn_rate,
        ],
        axis=1,
    )

    return features.reset_index()


# GNN (node-level regression)

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class GraphSAGEChurn(nn.Module):
    """
    Node-level regression:
    y = churn_rate
    """
    def __init__(self, in_channels, hidden_channels=128, num_layers=2, dropout=0.3):
        super().__init__()

        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))

        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))

        self.dropout = dropout

        self.regressor = nn.Sequential(
            nn.Linear(hidden_channels, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 1)
        )

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.regressor(x)
        return out.squeeze(-1)   # [num_nodes]


### –û–±—É—á–µ–Ω–∏–µ

In [21]:
import os
import numpy as np
import torch
from sklearn.metrics import mean_absolute_error, root_mean_squared_error


def train_gnn(
    data,
    train_mask,
    val_mask,
    test_mask,
    save_dir,
    device="cuda",
    config=None,
):
    hidden_channels = getattr(config, "hidden_channels", 128)
    num_layers = getattr(config, "num_layers", 2)
    dropout = getattr(config, "dropout", 0.3)
    learning_rate = getattr(config, "learning_rate", 0.01)
    weight_decay = getattr(config, "weight_decay", 5e-4)
    max_epochs = getattr(config, "epochs", 100)

    model = GraphSAGEChurn(
        in_channels=data.x.size(1),
        hidden_channels=hidden_channels,
        num_layers=num_layers,
        dropout=dropout,
    ).to(device)

    data = data.to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    best_val_rmse = float("inf")
    patience_counter = 0

    for epoch in range(max_epochs):
        # -------------------- train --------------------
        model.train()
        optimizer.zero_grad()

        out = model(data.x, data.edge_index)
        loss = criterion(out[train_mask], data.y[train_mask])

        loss.backward()
        optimizer.step()

        # -------------------- validation --------------------
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)

            train_pred = out[train_mask].cpu().numpy()
            train_true = data.y[train_mask].cpu().numpy()
            val_pred = out[val_mask].cpu().numpy()
            val_true = data.y[val_mask].cpu().numpy()

            train_rmse = root_mean_squared_error(train_true, train_pred)
            val_rmse = root_mean_squared_error(val_true, val_pred)

        if (epoch + 1) % 10 == 0:
            print(
                f"Epoch {epoch+1:03d} | "
                f"Train RMSE {train_rmse:.4f} | "
                f"Val RMSE {val_rmse:.4f}"
            )

        # -------------------- early stopping --------------------
        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            patience_counter = 0
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "config": {
                        "in_channels": data.x.size(1),
                        "hidden_channels": hidden_channels,
                        "num_layers": num_layers,
                        "dropout": dropout,
                    },
                },
                os.path.join(save_dir, "gnn_model.pth"),
            )
        else:
            patience_counter += 1
            if patience_counter >= 15:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # -------------------- test --------------------
    checkpoint = torch.load(os.path.join(save_dir, "gnn_model.pth"))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    with torch.no_grad():
        out = model(data.x, data.edge_index)
        test_pred = out[test_mask].cpu().numpy()
        test_true = data.y[test_mask].cpu().numpy()

    test_rmse = root_mean_squared_error(test_true, test_pred)
    test_mae = mean_absolute_error(test_true, test_pred)

    print(
        f"\nGNN node-regression:"
        f" Best Val RMSE {best_val_rmse:.4f}"
        f" | Test RMSE {test_rmse:.4f}"
        f" | Test MAE {test_mae:.4f}"
    )

    return {
        "model": "GNN",
        "val_rmse": best_val_rmse,
        "test_rmse": test_rmse,
        "test_mae": test_mae,
    }


### Load edges

In [22]:

def build_graph_from_edges(edges_df, all_ids, undirected=False):
    user_id_to_idx = {uid: i for i, uid in enumerate(all_ids)}

    edges_df = edges_df[
        edges_df['source_id'].isin(user_id_to_idx) &
        edges_df['target_id'].isin(user_id_to_idx)
    ]

    edge_index = torch.tensor([
        edges_df['source_id'].map(user_id_to_idx).values,
        edges_df['target_id'].map(user_id_to_idx).values
    ], dtype=torch.long)

    if undirected:
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    edge_index = torch.unique(edge_index, dim=1)
    return edge_index

In [30]:
from sklearn.preprocessing import StandardScaler
import torch
import joblib
from torch_geometric.data import Data


def run_gnn_single(
    train_ids,
    val_ids,
    test_ids,
    observation_data,
    device_id_set,
    edges_df,
):
    print("\n" + "=" * 80)
    print("TRAINING GNN (NODE-LEVEL)")
    print("=" * 80)

    # --------------------------------------------------
    # 1. all node_ids
    # --------------------------------------------------
    all_ids = train_ids + val_ids + test_ids

    # --------------------------------------------------
    # 2. node-level features
    # --------------------------------------------------
    features = (
        extract_node_baseline_features(
            observation_data,
            device_id_set,
        )
        .set_index("node_id")
    )

    # -------------------------
    # NaN handling (baseline-style)
    # -------------------------
    features = features.fillna(features.median())

    # –ø–æ—Ä—è–¥–æ–∫ —Å—Ç—Ä–æ–≥–æ –∫–∞–∫ –≤ all_ids
    X = features.loc[all_ids].values
    # X_scaled = StandardScaler().fit_transform(X)
    scaler_x = StandardScaler()
    X_scaled = scaler_x.fit_transform(X)

    # --------------------------------------------------
    # 3. graph
    # --------------------------------------------------
    edge_index = build_graph_from_edges(
        edges_df,
        all_ids,
        undirected=True,
    )

    # --------------------------------------------------
    # 4. masks
    # --------------------------------------------------
    node_id_to_idx = {nid: idx for idx, nid in enumerate(all_ids)}

    train_mask = torch.zeros(len(all_ids), dtype=torch.bool)
    val_mask = torch.zeros(len(all_ids), dtype=torch.bool)
    test_mask = torch.zeros(len(all_ids), dtype=torch.bool)

    for nid in train_ids:
        train_mask[node_id_to_idx[nid]] = True
    for nid in val_ids:
        val_mask[node_id_to_idx[nid]] = True
    for nid in test_ids:
        test_mask[node_id_to_idx[nid]] = True

    # --------------------------------------------------
    # 5. tensors
    # --------------------------------------------------
    x = torch.tensor(X_scaled, dtype=torch.float)

    # y = churn_rate (node-level regression)
    y = torch.tensor(
        features.loc[all_ids]["churn_rate"].values,
        dtype=torch.float,
    )

    data = Data(x=x, edge_index=edge_index, y=y)

    # --------------------------------------------------
    # 6. train
    # --------------------------------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    result = train_gnn(
        data,
        train_mask,
        val_mask,
        test_mask,
        SAVE_DIR,
        device,
    )

    joblib.dump(scaler_x, os.path.join(SAVE_DIR, "scaler_x.pkl"))

    # –°–æ—Ö—Ä–∞–Ω—è–µ–º mapping –¥–ª—è ID -> index
    import json
    with open(os.path.join(SAVE_DIR, "node_mapping.json"), "w") as f:
        json.dump({
            "node_id_to_idx": node_id_to_idx,
            "all_ids": all_ids,
        }, f)

    # –í–æ–∑–≤—Ä–∞—â–∞–µ–º –í–°–Å —á—Ç–æ –Ω—É–∂–Ω–æ –¥–ª—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π
    return {
        "result": result,
        "features_raw": features,  # –Ω–µ–æ–±—Ä–∞–±–æ—Ç–∞–Ω–Ω—ã–µ —Ñ–∏—á–∏ (—Å churn_rate)
        "features_scaled": X_scaled,  # —Å–∫–∞–ª–∏—Ä–æ–≤–∞–Ω–Ω—ã–µ —Ñ–∏—á–∏
        "edge_index": edge_index,  # —Å—Ç—Ä—É–∫—Ç—É—Ä–∞ –≥—Ä–∞—Ñ–∞
        "node_id_to_idx": node_id_to_idx,  # mapping ID -> –∏–Ω–¥–µ–∫—Å
        "all_ids": all_ids,  # –ø–æ—Ä—è–¥–æ–∫ —É–∑–ª–æ–≤
        "scaler_x": scaler_x,  # —Å–∫–∞–ª–µ—Ä
    }

# Split and save

In [24]:
df = pd.read_csv(DATA_PATH)
print(f"Loaded {len(df):,} events, {df['node_id'].nunique():,} nodes")
edges_df = pd.read_csv("../../data/links_graph.csv")


Loaded 499,999 events, 157 nodes


In [25]:
# ---------------------------------------------
# temporal window + device_id_set
# ---------------------------------------------
observation_data, node_labels, device_id_set = create_temporal_dataset(
    df,
    OBSERVATION_DAYS,
    PREDICTION_DAYS,
)

# ---------------------------------------------
# split –ø–æ node_id
# ---------------------------------------------
split_data = create_unified_split(
    node_labels,
    SAVE_DIR,
)

train_ids = split_data["train_ids"]
val_ids = split_data["val_ids"]
test_ids = split_data["test_ids"]

print(
    f"Train: {len(train_ids)} nodes, "
    f"Val: {len(val_ids)} nodes, "
    f"Test: {len(test_ids)} nodes"
)



Observation: 2025-09-01 ‚Üí 2025-09-25
Prediction: 2025-09-25 ‚Üí 2025-10-26
Nodes in observation: 99
Devices before prediction: 37,899

Split: Train 59 | Val 20 | Test 20
Train: 59 nodes, Val: 20 nodes, Test: 20 nodes


# Test 3 models and compare

In [40]:
def run_all_single():
    print("\n" + "=" * 80)
    print("TRAINING MODEL (SINGLE RUN)")
    print("=" * 80)

    gnn_output = run_gnn_single(
        train_ids=train_ids,
        val_ids=val_ids,
        test_ids=test_ids,
        observation_data=observation_data,
        device_id_set=device_id_set,
        edges_df=edges_df,
    )

    result = gnn_output["result"]  # ‚Üê —Ä–µ–∑—É–ª—å—Ç–∞—Ç –æ–±—É—á–µ–Ω–∏—è
    features_raw = gnn_output["features_raw"]  # ‚Üê —Ñ–∏—á–∏
    edge_index = gnn_output["edge_index"]  # ‚Üê —Å—Ç—Ä—É–∫—Ç—É—Ä–∞ –≥—Ä–∞—Ñ–∞
    scaler_x = gnn_output["scaler_x"]  # ‚Üê —Å–∫–∞–ª–µ—Ä

    results_df = pd.DataFrame([result])

    print("\n" + "=" * 80)
    print("FINAL RESULTS")
    print("=" * 80 + "\n")
    print(results_df.to_string(index=False))

    results_df.to_csv(
        os.path.join(SAVE_DIR, "final_results.csv"),
        index=False,
    )

    # best –ø–æ RMSE (–º–µ–Ω—å—à–µ = –ª—É—á—à–µ)
    best_model = results_df.loc[results_df["test_rmse"].idxmin()]
    print(
        f"\nüèÜ Best: {best_model['model']} "
        f"(Test RMSE: {best_model['test_rmse']:.4f})"
    )

    # –°–æ—Ö—Ä–∞–Ω—è–µ–º edge_index –æ—Ç–¥–µ–ª—å–Ω–æ –¥–ª—è —É–¥–æ–±—Å—Ç–≤–∞
    torch.save(edge_index, os.path.join(SAVE_DIR, "edge_index.pt"))
    
    # –°–æ—Ö—Ä–∞–Ω—è–µ–º features
    # features_without_target = features_raw.drop(columns=["churn_rate"])
    features_raw.to_csv(os.path.join(SAVE_DIR, "node_features.csv"))
    
    print(f"\n‚úÖ –î–ª—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π —Å–æ—Ö—Ä–∞–Ω–µ–Ω–æ:")
    print(f"   - edge_index.pt (—Å—Ç—Ä—É–∫—Ç—É—Ä–∞ –≥—Ä–∞—Ñ–∞)")
    print(f"   - node_features.csv (–ø—Ä–∏–∑–Ω–∞–∫–∏ —É–∑–ª–æ–≤)")
    print(f"   - scaler_x.pkl (—Å–∫–∞–ª–µ—Ä –ø—Ä–∏–∑–Ω–∞–∫–æ–≤)")
    print(f"   - gnn_model.pth (–º–æ–¥–µ–ª—å)")

    return results_df, gnn_output


In [41]:
run_all_single()


TRAINING MODEL (SINGLE RUN)

TRAINING GNN (NODE-LEVEL)
Using device: cpu
Epoch 010 | Train RMSE 0.0811 | Val RMSE 0.0965
Epoch 020 | Train RMSE 0.0735 | Val RMSE 0.0932
Epoch 030 | Train RMSE 0.0778 | Val RMSE 0.0752
Epoch 040 | Train RMSE 0.0519 | Val RMSE 0.0598
Epoch 050 | Train RMSE 0.0343 | Val RMSE 0.0518
Epoch 060 | Train RMSE 0.0697 | Val RMSE 0.0541
Epoch 070 | Train RMSE 0.0619 | Val RMSE 0.0434
Epoch 080 | Train RMSE 0.1110 | Val RMSE 0.0591
Early stopping at epoch 89

GNN node-regression: Best Val RMSE 0.0333 | Test RMSE 0.0210 | Test MAE 0.0190

FINAL RESULTS

model  val_rmse  test_rmse  test_mae
  GNN  0.033261    0.02097  0.018964

üèÜ Best: GNN (Test RMSE: 0.0210)

‚úÖ –î–ª—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π —Å–æ—Ö—Ä–∞–Ω–µ–Ω–æ:
   - edge_index.pt (—Å—Ç—Ä—É–∫—Ç—É—Ä–∞ –≥—Ä–∞—Ñ–∞)
   - node_features.csv (–ø—Ä–∏–∑–Ω–∞–∫–∏ —É–∑–ª–æ–≤)
   - scaler_x.pkl (—Å–∫–∞–ª–µ—Ä –ø—Ä–∏–∑–Ω–∞–∫–æ–≤)
   - gnn_model.pth (–º–æ–¥–µ–ª—å)


(  model  val_rmse  test_rmse  test_mae
 0   GNN  0.033261    0.02097  0.018964,
 {'result': {'model': 'GNN',
   'val_rmse': 0.03326091915369034,
   'test_rmse': 0.020970314741134644,
   'test_mae': 0.01896384358406067},
  'features_raw':                                   steps_count  sessions_count    steptime_ms  \
  node_id                                                                        
  0072f89b60d46ef6f2094949d8831f13        17770           10979   26027.683699   
  02b207cc24a78c1942161bafc72fe532        11214            9790  207362.359551   
  05aa62cfe2beb31d4ecc652cddec5689            2               1  549000.000000   
  0ab7553a46130fe3b64fa66ae66e6ad1         3786            3577   78966.507177   
  0bcd42c9cba99c24662d526b8917a4b2            5               5   56011.904762   
  ...                                       ...             ...            ...   
  f2f9d242858a788cad0cd1e66264f25b          450             388   41709.677419   
  f38e3fd9c83a13ec4cc1da3

## –ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ

In [42]:
def load_gnn_for_predictions(save_dir, device="cpu"):
    """
    –ó–∞–≥—Ä—É–∂–∞–µ—Ç –≤—Å—ë –Ω–µ–æ–±—Ö–æ–¥–∏–º–æ–µ –¥–ª—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π
    """
    # 1. –ú–æ–¥–µ–ª—å
    checkpoint = torch.load(os.path.join(save_dir, "gnn_model.pth"), map_location=device)
    model = GraphSAGEChurn(
        in_channels=checkpoint["config"]["in_channels"],
        hidden_channels=checkpoint["config"]["hidden_channels"],
        num_layers=checkpoint["config"]["num_layers"],
        dropout=checkpoint["config"]["dropout"],
    ).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    
    # 2. Scaler
    scaler_x = joblib.load(os.path.join(save_dir, "scaler_x.pkl"))
    
    # 3. Edge index
    edge_index = torch.load(os.path.join(save_dir, "edge_index.pt"), map_location=device)
    
    # 4. Mapping (–æ–ø—Ü–∏–æ–Ω–∞–ª—å–Ω–æ)
    import json
    with open(os.path.join(save_dir, "node_mapping.json"), "r") as f:
        mapping = json.load(f)
    
    return {
        "model": model,
        "scaler_x": scaler_x,
        "edge_index": edge_index,
        "node_id_to_idx": mapping["node_id_to_idx"],
        "all_ids": mapping["all_ids"],
    }

In [43]:
def predict_new_nodes(model, scaler_x, edge_index, new_features_df, device="cpu"):
    """
    –ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ –¥–ª—è –Ω–æ–≤—ã—Ö —É–∑–ª–æ–≤
    
    new_features_df: DataFrame —Å –∫–æ–ª–æ–Ω–∫–∞–º–∏ –∫–∞–∫ –≤ features (–±–µ–∑ churn_rate)
                    –∏–Ω–¥–µ–∫—Å - node_id
    """
    # –ö–æ–Ω–≤–µ—Ä—Ç–∏—Ä—É–µ–º –≤ numpy
    X_new = new_features_df.values
    
    # –°–∫–∞–ª–∏—Ä—É–µ–º
    X_new_scaled = scaler_x.transform(X_new)
    
    # –í —Ç–µ–Ω–∑–æ—Ä
    x_tensor = torch.tensor(X_new_scaled, dtype=torch.float).to(device)
    
    # Edge_index –æ—Å—Ç–∞–µ—Ç—Å—è —Ç–µ–º –∂–µ (–µ—Å–ª–∏ —Å—Ç—Ä—É–∫—Ç—É—Ä–∞ –≥—Ä–∞—Ñ–∞ –Ω–µ –º–µ–Ω—è–µ—Ç—Å—è)
    
    # –ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ
    with torch.no_grad():
        predictions = model(x_tensor, edge_index)
        predictions = predictions.cpu().numpy().flatten()
    
    # –°–æ–∑–¥–∞–µ–º DataFrame —Å —Ä–µ–∑—É–ª—å—Ç–∞—Ç–∞–º–∏
    results_df = pd.DataFrame({
        "node_id": new_features_df.index,
        "predicted_churn_rate": predictions
    })
    
    return results_df

In [None]:
# –î–ª—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π –Ω–∞ —Ç–µ—Ö –∂–µ —É–∑–ª–∞—Ö:
loaded = load_gnn_for_predictions(SAVE_DIR, device="cpu")

model = loaded["model"]
scaler_x = loaded["scaler_x"]
edge_index = loaded["edge_index"].clone()
model.eval()

# –ó–∞–≥—Ä—É–∂–∞–µ–º –ø—Ä–∏–∑–Ω–∞–∫–∏
features_df = pd.read_csv(os.path.join(SAVE_DIR, "node_features.csv"), index_col="node_id")
mean_features = features_df.mean()
new_node_df = pd.DataFrame([mean_features], columns=features_df.columns)

X_old = features_df.values
X_new = new_node_df.values

X_all = np.vstack([X_old, X_new])

# –í–ê–ñ–ù–û: —Ç–æ–ª—å–∫–æ transform
X_all_scaled = scaler_x.transform(X_all)

x_tensor = torch.tensor(X_all_scaled, dtype=torch.float)

# TODO: –¥–æ–±–∞–≤–∏—Ç—å —Å–≤—è–∑—å –Ω–æ–≤–æ–π –Ω–æ–¥—ã, —Å –∫–∞–∫–æ–π–Ω–∏–±—É–¥—å —Å–ª—É—á–∞–π–Ω–æ–π –Ω–æ–¥–æ–π —É–∂–µ —Å—É—â–µ—Å—Ç–≤—É—é—â–µ–π
edge_index = loaded["edge_index"]

# –î–æ–±–∞–≤–ª—è–µ–º —Ä–µ–±—Ä–æ –∫ —Å–ª—É—á–∞–π–Ω–æ–π —Å—É—â–µ—Å—Ç–≤—É—é—â–µ–π –Ω–æ–¥–µ
num_old_nodes = features_df.shape[0]
new_node_idx = num_old_nodes

existing_node_idx = np.random.randint(0, num_old_nodes)

# –¥–æ–±–∞–≤–ª—è–µ–º —Ä–µ–±—Ä–∞ (undirected)
new_edges = torch.tensor(
    [
        [existing_node_idx, new_node_idx],
        [new_node_idx, existing_node_idx],
    ],
    dtype=torch.long
).t()

edge_index_extended = torch.cat([edge_index, new_edges], dim=1)

# –ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ
with torch.no_grad():
    out = model(x_tensor, edge_index_extended)

new_node_prediction = out[new_node_idx].item()

print(f"Predicted churn_rate for new node: {new_node_prediction:.6f}")

Predicted churn_rate for new node: 0.041186
