In [23]:
import time

import matplotlib.pyplot as plt
import torch
from torch import optim
from torch_geometric.data import HeteroData
from torch_scatter import scatter_mean
from modeling.sampling import train_val_test_split_user_stratified

from modeling.losses import BPR_loss
from modeling.metrics import calculate_metrics
from modeling.sampling import prepare_training_data, sample_minibatch
from modeling.layers.bipartite_gcn import BipartiteGCN
from modeling.models.TB_simple import TBBaselineModel
from modeling.utils import get_coauthor_edges

torch.manual_seed(1)

# Load data
data: HeteroData = torch.load("data/hetero_data_no_coauthor.pt", weights_only=False)

paper_ids = data["paper"].node_id
paper_embeddings = data["paper"].x
author_ids = data["author"].node_id
author_embeddings = torch.ones((data["author"].num_nodes, paper_embeddings.shape[1]))
num_authors = data["author"].num_nodes
num_papers = data["paper"].num_nodes

edge_index = data["author", "writes", "paper"].edge_index

print(f"Number of authors: {len(author_ids)}")
print(f"Number of papers: {len(paper_ids)}")
print(f"Number of edges: {edge_index.shape[1]}")

'''
# Train/val/test split and message-passing vs supervision edges
(
    message_passing_edge_index,
    supervision_edge_index,
    val_edge_index_raw,
    test_edge_index_raw,
) = prepare_training_data(edge_index)

train_edge_index_raw = torch.cat([message_passing_edge_index, supervision_edge_index], dim=1)

message_passing_edge_index = message_passing_edge_index# + edge_index_offset.view(2, 1)
supervision_edge_index = supervision_edge_index# +.view(2, 1)
val_edge_index = val_edge_index_raw #+ edge_index_offset.view(2, 1)
test_edge_index = test_edge_index_raw# + edge_index_offset.view(2, 1)'''


message_passing_edge_index, supervision_edge_index, val_edge_index, test_edge_index = train_val_test_split_user_stratified(
    edge_index,
    num_authors,
    random_seed=42)


Number of authors: 90941
Number of papers: 63854
Number of edges: 320187
Train message passing edges: 179304
Train supervision edges: 76845
Validation edges: 32018
Test edges: 32020
Total edges: 320187


In [24]:
# list degrees for all authors on message_passing edges
author_degrees = torch.zeros(num_authors, dtype=torch.long)
for author_id in message_passing_edge_index[0]:
    author_degrees[author_id] += 1
min(author_degrees), max(author_degrees)

# How many authors have degree 0?

(author_degrees == 0).sum()

tensor(5668)

In [25]:
len(author_degrees)

90941

In [26]:
test_author_edges = test_edge_index_raw[0]
test_edge_author_degrees = author_degrees[test_author_edges]

In [27]:
test_edge_author_degrees.shape

torch.Size([32020])

In [28]:
# How many test edges have authors with degree 0?
(test_edge_author_degrees == 0).sum()

tensor(1365)

In [29]:
6700/32000

0.209375

In [30]:
user_links = edge_index[0]

In [31]:
# check degree of users in the edge_index
user_degrees = torch.zeros(num_authors, dtype=torch.long)
for user_id in user_links:
    user_degrees[user_id] += 1

In [35]:
# how many users have degree 0?
(user_degrees == 1).sum(), len(user_degrees)

(tensor(47140), 90941)