In [10]:
import torch
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.loader import NeighborLoader
from torch_geometric.transforms import RandomNodeSplit, RandomLinkSplit

In [2]:
data = Planetoid(".", name='Cora')[0]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [3]:
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [14]:
data = Data(x=data.x, edge_index=data.edge_index)

In [23]:
data = Data(x=data.x, edge_index=data.edge_index, conditions=torch.tensor(list(range(len(data.x)))))

In [24]:
if "conditions" in data:
    print("Yes")

Yes


## Random Node Split

In [25]:
random_node_split = RandomNodeSplit(
    num_val=0.1,
    num_test=0,
    key="x")
data = random_node_split(data)

In [26]:
data

Data(x=[2708, 1433], edge_index=[2, 10556], conditions=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [27]:
train_data = data.clone()
val_data = data.clone()

In [28]:
train_node_indices = data.train_mask.nonzero()

pure_train_node_edge = torch.isin(data.edge_index, train_node_indices)[0] & torch.isin(data.edge_index, train_node_indices)[1]

train_data.edge_label_index = torch.vstack((data.edge_index[0][pure_train_node_edge],
                                            data.edge_index[1][pure_train_node_edge]))

# Remove duplicate edge indices
edge_label_index_zero = train_data.edge_label_index[0].clone()
edge_label_index_one = train_data.edge_label_index[1].clone()
train_data.edge_label_index[0][edge_label_index_zero > edge_label_index_one] = edge_label_index_one[edge_label_index_zero > edge_label_index_one]
train_data.edge_label_index[1][edge_label_index_zero > edge_label_index_one] = edge_label_index_zero[edge_label_index_zero > edge_label_index_one]
train_data.edge_label_index = torch.unique(train_data.edge_label_index, dim=1)

train_data.edge_label = torch.ones(len(train_data.edge_label_index[0]))

In [29]:
train_data

Data(x=[2708, 1433], edge_index=[2, 10556], conditions=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label_index=[2, 4255], edge_label=[4255])

In [30]:
val_data

Data(x=[2708, 1433], edge_index=[2, 10556], conditions=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [511]:
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [411]:
train_data.edge_index

tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])

In [None]:
torch.isin(data.edge_index, train_node_indices)[0] & torch.isin(data.edge_index, train_node_indices)[1]

In [None]:
result = [vec for vec in tensor1 if not torch.all(torch.eq(tensor, tensor2))]

In [None]:
torch.eq(

In [427]:
len(data.edge_index[0])

10556

In [453]:
test = torch.tensor([   4, 1762])

In [466]:
len(data.edge_index[0])

10556

In [469]:
len(torch.stack([vec for vec in data.edge_index.T if not torch.all(torch.eq(vec, val_data.edge_label_index.T))]))

10556

9525

In [471]:
val_data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label_index=[2, 1031], edge_label=[1031])

In [None]:
torch.eq(test, val_data.edge_label_index.T)[:,0]

tensor([False, False, False,  ..., False, False, False])

In [432]:
val_data.edge_label_index.T

tensor([[   4, 1761],
        [   9,  723],
        [  10,  476],
        ...,
        [2671, 2673],
        [2671, 2674],
        [2671, 2675]])

In [None]:
len(torch.stack(

In [438]:
len(val_data.edge_label_index[0])

1031

In [439]:
len(data.edge_index[0])

10556

In [435]:
len([tensor for tensor in data.edge_index.T if not tensor in val_data.edge_label_index.T])
    

5070

## Random Link Split

In [516]:
random_link_split = RandomLinkSplit(
    num_val=0.1,
    num_test=0.,
    is_undirected=True, 
    neg_sampling_ratio=0.)
train_data_link, val_data_link, test_data_link = random_link_split(data)

In [517]:
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [518]:
train_data_link

Data(x=[2708, 1433], edge_index=[2, 9502], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[4751], edge_label_index=[2, 4751])

In [519]:
val_data_link

Data(x=[2708, 1433], edge_index=[2, 9502], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[527], edge_label_index=[2, 527])

In [321]:
val_data_link.edge_label_index

tensor([[2024,  306,  415,  ..., 1090,    6, 1218],
        [2178,  542,  525,  ..., 2367, 1416, 2116]])

In [322]:
val_data_link.edge_index

tensor([[ 219, 1408,  484,  ..., 1072, 2046,  488],
        [ 507, 1826,  542,  ..., 1070,  236,  444]])

In [200]:
torch.unique(torch.flatten(train_data.edge_label_index))

tensor([   0,    1,    2,  ..., 2705, 2706, 2707])

In [201]:
torch.unique(torch.flatten(val_data.edge_label_index))

tensor([   0,   11,   18,   21,   22,   24,   25,   33,   35,   37,   38,   39,
          41,   43,   48,   55,   59,   61,   65,   67,   72,   73,   74,   75,
          78,   83,   88,   89,   94,   95,   97,  100,  102,  105,  109,  113,
         114,  118,  119,  121,  131,  132,  133,  137,  139,  140,  143,  151,
         152,  153,  154,  156,  158,  162,  166,  169,  172,  173,  175,  180,
         181,  184,  185,  189,  195,  202,  205,  227,  229,  235,  239,  240,
         242,  253,  261,  263,  270,  275,  277,  279,  281,  286,  290,  295,
         297,  301,  303,  306,  310,  319,  323,  324,  328,  330,  331,  334,
         341,  344,  350,  353,  354,  359,  364,  365,  370,  371,  374,  387,
         391,  392,  394,  397,  398,  404,  406,  409,  411,  415,  426,  428,
         429,  433,  436,  440,  441,  443,  444,  453,  455,  456,  457,  458,
         460,  464,  470,  479,  480,  483,  484,  487,  490,  492,  497,  505,
         506,  507,  510,  511,  512,  5

In [6]:
data.train_mask.sum()

tensor(140)

In [7]:
data.val_mask.sum()

tensor(500)

In [8]:
data.test_mask.sum()

tensor(1000)

In [12]:
data.edge_index

tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])

In [190]:
link_loader = LinkNeighborLoader(
    data,
    # Sample 30 neighbors for each node for 2 iterations
    num_neighbors=[20] * 0,
    # Use a batch size of 128 for sampling training nodes
    batch_size=4,
    # edge_label_index=data.edge_index,
    disjoint=True
)

TypeError: __init__() got an unexpected keyword argument 'disjoint'

In [183]:
link_loader_iter = iter(link_loader)

In [184]:
sampled_data = next(link_loader_iter)
print(sampled_data)

Data(x=[6, 1433], edge_index=[2, 0], y=[6], train_mask=[6], val_mask=[6], test_mask=[6], edge_label_index=[2, 4], edge_label=[4])


In [185]:
sampled_data.x[:4]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [186]:
sampled_data.edge_label_index

tensor([[0, 0, 0, 1],
        [3, 4, 5, 2]])

In [172]:
import torch
max_idx_node_with_edge = torch.max(sampled_data.edge_label_index)

In [176]:
sampled_data.x[:max_idx_node_with_edge + 1]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [168]:
sampled_data.edge_index

tensor([], size=(2, 0), dtype=torch.int64)

In [188]:
node_loader = NeighborLoader(
    data,
    # Sample 30 neighbors for each node for 2 iterations
    num_neighbors=[20] * 2,
    # Use a batch size of 128 for sampling training nodes
    batch_size=4,
    input_nodes=data.train_mask,
)

sampled_data_node = next(iter(node_loader))
sampled_data_node

Data(x=[48, 1433], edge_index=[2, 58], y=[48], train_mask=[48], val_mask=[48], test_mask=[48], batch_size=4)

In [None]:
sampled_data_node

In [None]:
sampled_data_node.x.shape