In [2]:
import torch
import numpy as np
from torch_geometric.datasets import Planetoid, Amazon, Actor, WikipediaNetwork, HeterophilousGraphDataset
from ogb.nodeproppred import PygNodePropPredDataset
from pathlib import Path
from tqdm import tqdm

In [6]:
#create own (pitfalls/geom gcn) split (per class 20 train, 30 val, others test)
def create_23_split(ds_name, n_train=20, n_val=30, n_seeds=10):
    """Create and save per-class random train/valid/test node masks.

    For every seed in ``range(n_seeds)`` the nodes of each class are shuffled with
    ``np.random.default_rng(seed)``; the first ``n_train`` become training nodes,
    the next ``n_val`` validation nodes and all remaining nodes test nodes.
    A class with fewer than ``n_train + n_val`` nodes gets no test nodes
    (and fewer validation nodes).

    Args:
        ds_name: "Photo" or "Computers" (loaded with ``Amazon``); any other name is
            loaded with ``Planetoid`` (e.g. "Cora", "CiteSeer", "PubMed").
        n_train: training nodes per class (default 20).
        n_val: validation nodes per class (default 30).
        n_seeds: number of random splits to generate (default 10).

    Side effects:
        Writes a dict with keys "train", "valid", "test" to
        ``dataset/<ds_name>/own_23_splits.pt``. Each value is a bool tensor built
        with ``torch.stack(..., 1)``, i.e. one column per seed.
    """
    root = "dataset/" + ds_name + "/"
    # Photo/Computers come from the Amazon loader, everything else from Planetoid.
    if ds_name.lower() in ("photo", "computers"):
        dataset = Amazon(root=root, name=ds_name)
    else:
        dataset = Planetoid(root=root, name=ds_name)
    # dataset[0] materializes a Data object on every access; fetch it once.
    labels = dataset[0].y
    y = labels.cpu().numpy()
    classes = np.unique(y)

    train_mask, val_mask, test_mask = [], [], []
    for seed in tqdm(range(n_seeds)):
        rng = np.random.default_rng(seed)
        train, val, test = [], [], []
        for cl in classes:
            idx = np.argwhere(y == cl)  # (k, 1) indices of the nodes of class `cl`
            rng.shuffle(idx)            # in-place permutation along axis 0
            train.append(idx[:n_train])
            val.append(idx[n_train:n_train + n_val])
            test.append(idx[n_train + n_val:])

        # Turn the per-class index pieces into node-level boolean masks.
        for parts, masks in ((train, train_mask), (val, val_mask), (test, test_mask)):
            mask = torch.zeros_like(labels, dtype=torch.bool)
            mask[np.concatenate(parts).ravel()] = True
            masks.append(mask)

    splits = {"train": torch.stack(train_mask, 1),
              "valid": torch.stack(val_mask, 1),
              "test": torch.stack(test_mask, 1)}
    torch.save(splits, root + "own_23_splits.pt")


In [4]:
# for name in ["Photo", "Computers"]:
#     create_23_split(name)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 110.98it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 69.74it/s]


In [7]:
# for name in ["Cora", "CiteSeer", "PubMed"]:
#     create_23_split(name)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 372.95it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 259.40it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 51.86it/s]


In [3]:
def load_ds(name):
    """Load a node-classification dataset together with its train/valid/test split.

    Args:
        name: one of "Cora", "CiteSeer", "PubMed", "Photo", "Computers",
            "Roman-empire", "Minesweeper", "Chameleon", "Squirrel", "Actor",
            "Arxiv", "Products".

    Returns:
        ``(ds, splits)`` where ``splits`` has keys "train", "valid", "test":
        - Planetoid / Amazon datasets: read from ``dataset/<name>/own_23_splits.pt``
          (the file written by ``create_23_split``), so that file must exist.
        - Heterophilous / Wikipedia / Actor datasets: the dataset's own
          ``train_mask`` / ``val_mask`` / ``test_mask``.
        - "Arxiv" / "Products": OGB's fixed ``get_idx_split()`` index tensors.

    Raises:
        NotImplementedError: if ``name`` is not one of the supported datasets.
    """
    root = "dataset/" + name + "/"
    if name in ["Cora", "CiteSeer", "PubMed"]:
        return Planetoid(root=root, name=name), torch.load(root + "own_23_splits.pt")
    if name in ["Photo", "Computers"]:
        return Amazon(root=root, name=name), torch.load(root + "own_23_splits.pt")
    if name in ["Roman-empire", "Minesweeper", "Chameleon", "Squirrel", "Actor"]:
        if name in ["Roman-empire", "Minesweeper"]:
            ds = HeterophilousGraphDataset(root=root, name=name)
        elif name == "Actor":
            ds = Actor(root=root)
        else:
            ds = WikipediaNetwork(root=root, name=name)
        data = ds[0]  # fetch the graph once instead of once per mask
        return ds, {"train": data.train_mask, "valid": data.val_mask, "test": data.test_mask}
    if name in ["Arxiv", "Products"]:
        # OGB uses its own default root ("dataset/ogbn_<name>"), not `root`.
        ds = PygNodePropPredDataset(name="ogbn-" + name.lower())
        return ds, ds.get_idx_split()
    # Previously an unknown name fell through and crashed with UnboundLocalError.
    raise NotImplementedError(name + " is an unknown ds")

#create own NOSMOG inductive split (80% of test as unlabeled structure)
def create_ind82_split(ds_name, struct_frac=0.8, n_seeds=10):
    """Split each test set into "structure" nodes and held-out test nodes.

    For every seed in ``range(n_seeds)`` the test nodes are shuffled with
    ``np.random.default_rng(seed)``. The first ``round(struct_frac * n_test)`` of
    them become "structure" nodes and the remainder stay as test nodes. The
    "train"/"valid" entries of the loaded split are copied unchanged.

    For "Arxiv"/"Products" (OGB) the test split is a single fixed index tensor, so
    only the shuffle differs between seeds and the outputs are index tensors.
    For every other dataset the test masks are indexed per seed as
    ``split["test"][:, seed]`` and the outputs are boolean node masks.

    Args:
        ds_name: a dataset name accepted by ``load_ds``.
        struct_frac: fraction of test nodes moved to the structure set (default 0.8).
        n_seeds: number of splits / seeds (default 10).

    Side effects:
        Writes a dict with keys "train", "valid", "structure", "test" to
        ``dataset/<ds_name>/own_82_splits.pt``. For OGB the directory is
        ``dataset/ogbn_<lowercase name>/``. "structure" and "test" hold one
        column per seed.
    """
    dataset, split = load_ds(ds_name)
    is_ogb = ds_name in ["Arxiv", "Products"]
    if not is_ogb:
        labels = dataset[0].y  # template for node-level boolean masks; fetched once

    struct_msk = []
    test_msk = []
    for seed in range(n_seeds):
        rng = np.random.default_rng(seed)
        if is_ogb:
            # .numpy() shares memory with split["test"]; without the copy the in-place
            # shuffle would mutate the loaded split and every seed's slices would alias
            # one buffer, leaving all columns equal to the last permutation.
            pool = split["test"].numpy().copy()
        else:
            pool = torch.nonzero(split["test"][:, seed] == 1, as_tuple=True)[0].numpy()
        rng.shuffle(pool)
        cut = round(struct_frac * len(pool))

        if is_ogb:
            struct_msk.append(torch.from_numpy(pool[:cut]))
            test_msk.append(torch.from_numpy(pool[cut:]))
        else:
            struc = torch.zeros_like(labels, dtype=torch.bool)
            struc[pool[:cut]] = True
            test = torch.zeros_like(labels, dtype=torch.bool)
            test[pool[cut:]] = True
            struct_msk.append(struc)
            test_msk.append(test)

    # OGB datasets live under OGB's default directory name.
    save_name = "ogbn_" + ds_name.lower() if is_ogb else ds_name
    splits = {"train": split["train"],
              "valid": split["valid"],
              "structure": torch.stack(struct_msk, 1),
              "test": torch.stack(test_msk, 1)}
    torch.save(splits, "dataset/" + save_name + "/own_82_splits.pt")


In [4]:
# Build the inductive structure/test splits for each benchmark. The Planetoid and
# Amazon datasets need their own_23_splits.pt to exist first (see create_23_split).
# Append "Arxiv" to the list to also process the OGB graph.
for ds_name in ["Cora", "CiteSeer", "PubMed", "Computers", "Photo",
                "Chameleon", "Squirrel", "Actor", "Roman-empire", "Minesweeper"]:
    create_ind82_split(ds_name)

In [33]:
#create_23_split("Computers")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 46.18it/s]


In [34]:
#create_23_split("Photo")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 72.54it/s]
