# Data Preparation for GNN Estimation

This notebook prepares data from `processed_data.parquet` and exports it as `gnn_data.pt` for GNN training.

### Train / validation / test split by burst

We split the data by time within each burst, so that validation and test always come **after** training in calendar time.

There are three bursts:

| Burst | Calendar range        | # unique days | Train days (earliest) | Validation days | Test days (latest) | Rule                                      |
|-------|-----------------------|---------------|------------------------|-----------------|--------------------|-------------------------------------------|
| 1     | 2020-07-05 – 2020-07-12 | 8             | first 4 days           | next 2 days      | last 2 days        | 4 / 2 / 2 split over the 8 days           |
| 2     | 2020-08-01 – 2020-08-10 | 10            | first 6 days           | next 2 days      | last 2 days        | 6 / 2 / 2 split over the 10 days          |
| 3     | 2020-08-27 – 2020-09-05 | 10            | first 6 days           | next 2 days      | last 2 days        | 6 / 2 / 2 split over the 10 days          |

Implementation details:

- For each burst, I select the interactions whose `time` falls inside the burst's calendar range and extract the sorted unique calendar days.
- I assign the earliest days to the training set, the middle days to the validation set, and the latest days to the test set, according to the ratios above.
- The same time-based split is used for all models (gradient boosting, standard NN, and GNN), so their performance is directly comparable.
- The training set has about 8.0 million observations; the validation and test sets have about 2.4 million and 2.1 million, respectively.
- A tiny subsample containing 5% of the edges is generated for a test run. It uses its own random 80/10/10 split rather than the time-based split.


### Features used for prediction

| Feature                     | Description                                                                                 |
|-----------------------------|---------------------------------------------------------------------------------------------|
| burst_id                    | Burst window identifier (1st, 2nd, or 3rd data burst).                                     |
| session                     | Session ID for the user’s viewing session.                                                 |
| u_user_active_degree        | Categorical user activity level (e.g., low / medium / high).                              |
| u_is_lowactive_period       | Indicator if the user is in a low-activity period.                                         |
| u_is_live_streamer          | Indicator if the user has ever done live streaming.                                        |
| u_is_video_author           | Indicator if the user is also a content creator (has uploaded videos).                    |
| u_follow_user_num           | Number of users this user follows.                                                         |
| u_follow_user_num_range     | Binned range of `u_follow_user_num`.                                                       |
| u_fans_user_num             | Number of followers (fans) this user has.                                                  |
| u_fans_user_num_range       | Binned range of `u_fans_user_num`.                                                         |
| u_friend_user_num           | Number of friends (mutual relationships).                                                  |
| u_friend_user_num_range     | Binned range of `u_friend_user_num`.                                                       |
| u_register_days             | Days since user registration at the time of interaction.                                   |
| u_register_days_range       | Binned range of `u_register_days`.                                                         |
| u_onehot_feat0-17           | Encrypted user categorical features.                                                       |
| u_follow_user_num_log1p     | Log1p-transformed number of users the user follows.                                        |
| u_fans_user_num_log1p       | Log1p-transformed number of followers (fans).                                              |
| u_friend_user_num_log1p     | Log1p-transformed number of friends.                                                       |
| u_register_days_log1p       | Log1p-transformed days since registration.                                                 |
| i_aspect_ratio              | Video aspect ratio (height / width).                                                       |
| i_author_id                 | Encoded ID of the video’s author.                                                          |
| i_video_type                | Encoded video type (e.g., normal / live / other).                                          |
| i_upload_type               | Encoded upload type (e.g., original / re-upload).                                          |
| i_visible_status            | Encoded visibility status (e.g., public / private / limited).                             |
| i_music_id                  | Encoded ID of background music used in the video.                                          |
| i_video_tag_id              | Encoded tag ID associated with the video.                                                  |
| i_video_tag_name            | Encoded tag-name category for the video.                                                   |
| i_video_duration            | Video duration (seconds).                                                                  |
| i_age_since_upload_days     | Days since the video was uploaded.                                                         |
| i_cat_level1_id             | Encoded top-level content category ID.                                                     |
| i_cat_level2_id             | Encoded mid-level content category ID.                                                     |
| i_cat_level3_id             | Encoded fine-grained content category ID.                                                  |
| ctx_hour_sin                | Sine transform of local watch hour (time-of-day feature).                                  |
| ctx_hour_cos                | Cosine transform of local watch hour (time-of-day feature).                                |
| ctx_is_weekend              | Indicator if the interaction happens on a weekend.                                         |
| hist_ema_y_complete         | Per-user EMA of past completion events before this interaction.                           |
| hist_ema_y_long             | Per-user EMA of past long-watch events before this interaction.                           |
| hist_ema_y_rewatch          | Per-user EMA of past rewatch events before this interaction.                              |
| hist_ema_y_neg              | Per-user EMA of past negative-feedback events before this interaction.                    |
| hist_ema_watchratio         | Per-user EMA of past watch ratios before this interaction.                                 |
| hist_cat_ema_complete       | Per-user & category EMA of past completion events.                                         |
| hist_cat_entropy_l2         | Entropy-based measure of user’s category diversity (L2-regularized).                      |
| hist_author_recency_days    | Days since the user last watched this author before this interaction.                      |
| hist_last_complete_author   | Indicator if the last video from this author was completed by the user.                   |
| hist_has_author_history     | Indicator if the user has any prior interaction history with this author.                 |
| hist_prev_sess_len          | Length of the previous session for this user.                                              |
| hist_intersess_gap_h        | Time gap (hours) between the previous session end and current interaction.                 |
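
The feature construction itself happens upstream and is not shown in this notebook. As a rough illustration only, the sketch below shows one plausible way to build three of the feature families in the table: the cyclical hour encoding, a `log1p` transform, and a per-user EMA that only looks at past events. The toy column names and the smoothing factor `alpha` are assumptions, not the values used to produce `processed_data.parquet`.

```python
import numpy as np
import pandas as pd

# Toy interaction log; columns are illustrative, not the notebook's schema.
toy = pd.DataFrame({
    "user_id": [0, 0, 0, 1, 1],
    "time": pd.to_datetime([
        "2020-07-05 00:00", "2020-07-05 06:00", "2020-07-05 12:00",
        "2020-07-05 18:00", "2020-07-06 09:00",
    ]),
    "follow_num": [0, 0, 0, 9, 9],
    "y_complete": [1, 0, 1, 1, 0],
})

# Cyclical hour encoding: hour h maps to the angle 2*pi*h/24 on the unit
# circle, so 23:00 and 00:00 end up close together.
angle = 2 * np.pi * toy["time"].dt.hour / 24
toy["ctx_hour_sin"] = np.sin(angle)
toy["ctx_hour_cos"] = np.cos(angle)

# log1p compresses heavy-tailed counts and is well defined at zero.
toy["follow_num_log1p"] = np.log1p(toy["follow_num"])

# Past-only per-user EMA: shift by one row within each user so that row t
# only sees labels from earlier rows. `alpha` is a hypothetical choice.
alpha = 0.5
toy = toy.sort_values(["user_id", "time"])
toy["hist_ema_y_complete"] = (
    toy.groupby("user_id")["y_complete"]
    .transform(lambda s: s.shift(1).ewm(alpha=alpha, adjust=False).mean())
    .fillna(0.0)  # a user's first interaction has no history
)
```

The `shift(1)` is the important part: without it, the EMA at a given row would include that row's own label and leak the target into the features.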


In [1]:
import pandas as pd
import numpy as np
import torch
from pathlib import Path

BASE = Path("/Users/haozhangao/Desktop/RecSys Research/KuaiRec 2.0/data/processed/")

INPUT_PATH = BASE / "processed_data.parquet"
OUTPUT_PATH = BASE / "gnn_data.pt"

df = pd.read_parquet(INPUT_PATH)

In [2]:
# Basic counts
n_users = df["user_id"].max() + 1
n_videos = df["video_id"].max() + 1
print("number of users:", n_users, "number of videos", n_videos)

# Labels (4 heads)
label_cols = ["y_complete", "y_long", "y_rewatch", "y_neg"]
y = df[label_cols].to_numpy().astype("float32")

# Edge index
src = df["user_id"].to_numpy() # source
dst = n_users + df["video_id"].to_numpy() # destination (avoid index overlapping)
edge_index = np.vstack([src, dst]) # [2, num_edges]

# Encode categorical columns to integers
cat_cols = df.select_dtypes(include="category").columns.tolist()
for c in cat_cols:
    df[c] = df[c].cat.codes.astype("int32")

# Exclude irrelevant columns
cols_to_exclude = [
    "user_id", "video_id",
    "time",
    "date", "timestamp",
    "play_duration", "watch_ratio", "sess_rank", "sess_len",
] + label_cols

numeric_feature_cols = [c for c in df.columns if c not in cols_to_exclude and pd.api.types.is_numeric_dtype(df[c]) ]

print("Feature columns used:\n", numeric_feature_cols)

edge_features = df[numeric_feature_cols].to_numpy().astype("float32")
num_edges = edge_index.shape[1]



print("number of edges:", num_edges)
print("edge_index shape:", edge_index.shape)
print("edge_features shape:", edge_features.shape)
print("y shape:", y.shape)

number of users: 7176 number of videos 10728
Feature columns used:
 ['burst_id', 'session', 'u_user_active_degree', 'u_is_lowactive_period', 'u_is_live_streamer', 'u_is_video_author', 'u_follow_user_num', 'u_follow_user_num_range', 'u_fans_user_num', 'u_fans_user_num_range', 'u_friend_user_num', 'u_friend_user_num_range', 'u_register_days', 'u_register_days_range', 'u_onehot_feat0', 'u_onehot_feat1', 'u_onehot_feat2', 'u_onehot_feat3', 'u_onehot_feat4', 'u_onehot_feat5', 'u_onehot_feat6', 'u_onehot_feat7', 'u_onehot_feat8', 'u_onehot_feat9', 'u_onehot_feat10', 'u_onehot_feat11', 'u_onehot_feat12', 'u_onehot_feat13', 'u_onehot_feat14', 'u_onehot_feat15', 'u_onehot_feat16', 'u_onehot_feat17', 'u_follow_user_num_log1p', 'u_fans_user_num_log1p', 'u_friend_user_num_log1p', 'u_register_days_log1p', 'i_aspect_ratio', 'i_author_id', 'i_video_type', 'i_upload_type', 'i_visible_status', 'i_music_id', 'i_video_tag_id', 'i_video_tag_name', 'i_video_duration', 'i_age_since_upload_days', 'i_cat_leve
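
The edge index above places users and videos in one shared node space by shifting every video ID by `n_users`. The toy sketch below illustrates that convention and how to map a destination node back to its video ID. It assumes, as the cell above does, that both ID columns are 0-based integers; the toy arrays are made up.

```python
import numpy as np

# Toy IDs, assumed to be 0-based integers.
user_id = np.array([0, 2, 1, 2])
video_id = np.array([0, 0, 3, 1])

n_users = user_id.max() + 1     # 3 user nodes: 0..2
n_videos = video_id.max() + 1   # 4 video nodes: 3..6 after the shift

src = user_id
dst = n_users + video_id             # video v becomes node n_users + v
edge_index = np.vstack([src, dst])   # shape [2, num_edges]

# Map destination nodes back to the original video IDs.
recovered_video_id = edge_index[1] - n_users

# User nodes and video nodes occupy disjoint index ranges.
assert edge_index[0].max() < n_users <= edge_index[1].min()
```

Because `n_users` is needed to undo the shift, it has to travel with the edge index to any downstream code that wants to look up a video by its original ID.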

In [3]:
# Define training/validation/test sets

time_dates = df["time"].dt.normalize()

train_mask = np.zeros(num_edges, dtype=bool)
valid_mask = np.zeros(num_edges, dtype=bool)
test_mask = np.zeros(num_edges, dtype=bool)

burst_ranges = [
    ("2020-07-05", "2020-07-12"),  # burst 1 (8 days)
    ("2020-08-01", "2020-08-10"),  # burst 2 (10 days)
    ("2020-08-27", "2020-09-05"),  # burst 3 (10 days)
]

for start_str, end_str in burst_ranges:
    start = pd.to_datetime(start_str)
    end = pd.to_datetime(end_str)

    in_burst = (time_dates >= start) & (time_dates <= end)
    burst_indx = np.where(in_burst)[0]

    # Count unique days inside each burst
    burst_days = np.sort(time_dates[in_burst].unique())
    n_days = len(burst_days)

    # Data split: 4/2/2 or 6/2/2
    if n_days == 8:
        n_train_days, n_valid_days, n_test_days = 4, 2, 2
    elif n_days == 10:
        n_train_days, n_valid_days, n_test_days = 6, 2, 2
    else:
        raise ValueError(f"Burst {start_str} to {end_str} has {n_days} unique days, expected 8 or 10.")

    train_days = set(burst_days[:n_train_days])
    valid_days = set(burst_days[n_train_days:n_train_days+n_valid_days])
    test_days = set(burst_days[n_train_days+n_valid_days:n_train_days+n_valid_days+n_test_days])

    # Update masks
    train_mask |= in_burst & time_dates.isin(train_days)
    valid_mask |= in_burst & time_dates.isin(valid_days)
    test_mask |= in_burst & time_dates.isin(test_days)

# Convert masks to index

train_idx = np.where(train_mask)[0]
valid_idx = np.where(valid_mask)[0]
test_idx = np.where(test_mask)[0]

print("train/valid/test set sizes:", len(train_idx), len(valid_idx), len(test_idx))

# Sanity check by testing overlapping
print("Overlap between training and valid sets", np.sum(train_mask & valid_mask))
print("Overlap between training and test sets:", np.sum(train_mask & test_mask))
print("Overlap between valid and test sets:", np.sum(valid_mask & test_mask))


train/valid/test set sizes: 8016340 2382507 2129065
Overlap between training and valid sets 0
Overlap between training and test sets: 0
Overlap between valid and test sets: 0
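
The overlap check confirms that the three masks are disjoint, but not that they are ordered in time. A possible extra check, sketched below, verifies that within every burst the last training day precedes the first validation day, and the last validation day precedes the first test day. The helper name `check_temporal_order` and the toy inputs are hypothetical.

```python
import numpy as np
import pandas as pd

def check_temporal_order(days, burst_id, train_mask, valid_mask, test_mask):
    """Return True if, in every burst, train days < valid days < test days."""
    days = pd.Series(pd.to_datetime(days)).reset_index(drop=True)
    burst_id = np.asarray(burst_id)
    for b in np.unique(burst_id):
        in_b = burst_id == b
        last_train = days[in_b & train_mask].max()
        first_valid = days[in_b & valid_mask].min()
        last_valid = days[in_b & valid_mask].max()
        first_test = days[in_b & test_mask].min()
        # Comparisons with NaT (an empty split) evaluate to False.
        if not (last_train < first_valid and last_valid < first_test):
            return False
    return True

# Toy example: two bursts of three days each, one row per day.
toy_days = ["2020-07-05", "2020-07-06", "2020-07-07",
            "2020-08-01", "2020-08-02", "2020-08-03"]
toy_burst = [1, 1, 1, 2, 2, 2]
tr = np.array([True, False, False, True, False, False])
va = np.array([False, True, False, False, True, False])
te = np.array([False, False, True, False, False, True])

ok = check_temporal_order(toy_days, toy_burst, tr, va, te)
```

On the real data the call would look something like `check_temporal_order(time_dates, df["burst_id"], train_mask, valid_mask, test_mask)`, assuming `burst_id` labels each row with its burst.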


In [4]:
# Convert everything to torch tensors
edge_index_t = torch.from_numpy(edge_index).long() # [2, E]
edge_attr_t = torch.from_numpy(edge_features) # [E, D]
y_t = torch.from_numpy(y) # [E, 4]

train_idx_t = torch.from_numpy(train_idx).long()
valid_idx_t = torch.from_numpy(valid_idx).long()
test_idx_t = torch.from_numpy(test_idx).long()

print("edge_idx_t:", edge_index_t.shape)
print("edge_attr_t:", edge_attr_t.shape)
print("y_t:", y_t.shape)

edge_idx_t: torch.Size([2, 12527912])
edge_attr_t: torch.Size([12527912, 64])
y_t: torch.Size([12527912, 4])


In [5]:
# Number of nodes = users + videos
num_nodes = int(n_users + n_videos)
num_edge_features = edge_attr_t.shape[1]

gnn_data = {
    "edge_index": edge_index_t,       # [2, E]
    "edge_attr": edge_attr_t,         # [E, D]
    "y": y_t,                         # [E, 4]
    "train_idx": train_idx_t,         # [E_train]
    "val_idx": valid_idx_t,           # [E_val]
    "test_idx": test_idx_t,           # [E_test]
    "n_users": int(n_users),
    "n_items": int(n_videos),
    "num_nodes": num_nodes,
    "num_edge_features": num_edge_features,
    "feature_names": numeric_feature_cols
}

torch.save(gnn_data, OUTPUT_PATH)
print("num_nodes:", num_nodes, "num_edge_features:", num_edge_features)
print("Saved GNN data to", OUTPUT_PATH)

num_nodes: 17904 num_edge_features: 64
Saved GNN data to /Users/haozhangao/Desktop/RecSys Research/KuaiRec 2.0/data/processed/gnn_data.pt


In [6]:
# Create a tiny sample for sanity check

D_edge = edge_attr_t.shape[1]
E = edge_index_t.shape[1]
print("Original num edges:", E)

# 1) Choose a fraction of edges to keep
frac = 0.05   # 5% of edges
k = max(1, int(frac * E))

# Randomly pick k edges
perm_edges = torch.randperm(E)[:k]

edge_index_small = edge_index_t[:, perm_edges]   # [2, k]
edge_attr_small  = edge_attr_t[perm_edges]       # [k, D_edge]
y_small          = y_t[perm_edges]               # [k, 4]

print("Subsampled num edges:", k)

# 2) Make NEW random train/val/test split for this tiny graph
num_edges_small = k
perm = torch.randperm(num_edges_small)

n_train = int(0.8 * num_edges_small)
n_val   = int(0.1 * num_edges_small)
n_test  = num_edges_small - n_train - n_val

train_idx_small = perm[:n_train]
val_idx_small   = perm[n_train:n_train + n_val]
test_idx_small  = perm[n_train + n_val:]

print("tiny train/val/test sizes:",
      train_idx_small.shape[0],
      val_idx_small.shape[0],
      test_idx_small.shape[0])

# 3) New num_nodes = largest surviving node ID + 1 (IDs are not relabeled, so some indices may have no edges)
num_nodes_small = int(edge_index_small.max().item() + 1)
print("num_nodes_small:", num_nodes_small)

# 4) Save a tiny dataset
tiny_data = {
    "edge_index": edge_index_small,
    "edge_attr": edge_attr_small,
    "y": y_small,
    "train_idx": train_idx_small,
    "val_idx": val_idx_small,
    "test_idx": test_idx_small,
    "num_nodes": num_nodes_small,
    "num_edge_features": D_edge,
    "feature_names": numeric_feature_cols
}

OUT_PATH_tiny = BASE / "gnn_data_tiny.pt"
torch.save(tiny_data, OUT_PATH_tiny)
print("Saved tiny data to:", OUT_PATH_tiny)


Original num edges: 12527912
Subsampled num edges: 626395
tiny train/val/test sizes: 501116 62639 62640
num_nodes_small: 17898
Saved tiny data to: /Users/haozhangao/Desktop/RecSys Research/KuaiRec 2.0/data/processed/gnn_data_tiny.pt
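
The tiny dataset keeps the original node IDs, so `num_nodes_small` is simply the largest surviving ID plus one, and the random draw changes on every run. If a compact, reproducible subsample is preferred, one option is to fix a seed and relabel the surviving nodes to `0..n-1`. The sketch below shows the idea with NumPy on a made-up edge list; it is an alternative to the cell above, not what the notebook does.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the subsample is reproducible

# Toy edge list in the shared node space (users first, then videos).
edge_index = np.array([
    [0, 1, 2, 3, 3, 4],
    [5, 6, 7, 8, 9, 9],
])
E = edge_index.shape[1]

# Keep a fraction of the edges, sampled without replacement.
frac = 0.5
k = max(1, int(frac * E))
keep = rng.choice(E, size=k, replace=False)
edge_index_small = edge_index[:, keep]

# Relabel surviving nodes to 0..n-1. `old_ids[new_id]` gives the original ID.
old_ids, inverse = np.unique(edge_index_small, return_inverse=True)
edge_index_compact = inverse.reshape(edge_index_small.shape)
num_nodes_compact = len(old_ids)
```

Since `np.unique` returns sorted IDs, user nodes still come before video nodes after relabeling; the number of surviving users is `(old_ids < n_users).sum()` for the original `n_users`.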
