In [13]:
%load_ext autoreload
%autoreload 2

import torch
import wandb
import torch.nn.functional as F

import numpy as np
from pathlib import Path
from torch_geometric.nn.models import GCN
from ogb.nodeproppred import PygNodePropPredDataset
from tqdm import tqdm
from torch_geometric.datasets import Planetoid, Amazon
from torch_geometric.utils import to_undirected
#from OpenGraphCon import OpenGraph
from open_dataset import load_dataset, load_arxiv, load_plentoid, create_class_folds,prepare_fold_masks

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
#data = load_plentoid("Cora", train_portion=0.6, val_portion=0.2, test_portion=0.2, seed=0)
data = load_arxiv()
data = load_plentoid("cora")
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [15]:
unknown_class_ratio=0.2
folds = create_class_folds(data, unknown_class_ratio=unknown_class_ratio)
folds
for i, f in enumerate(folds):
    print(i)
    print(f)
path = Path("fold_indices/cora_class_split_"+str(unknown_class_ratio)+".pt")
torch.save(folds, path)

0
tensor([5, 3])
1
tensor([2, 0])
2
tensor([4, 1, 6])


In [16]:
torch.load(path)

[tensor([5, 3]), tensor([2, 0]), tensor([4, 1, 6])]

In [17]:
n_folds = len(folds)
datasets = []

for test_fold_idx in range(n_folds):
    val_fold_idx = (test_fold_idx - 1) % n_folds
    data_new = prepare_fold_masks(data, folds, val_fold_idx, test_fold_idx)
    datasets.append(data_new)

In [18]:
datasets[0]

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], classes=[7], known_classes=[2], val_classes=[3], test_classes=[2], unknown_classes=[5], known_class_mask=[2708], labeled_mask=[2708], val_class_mask=[2708], known_class_val_mask=[2708], unknown_class_val_mask=[2708], all_class_val_mask=[2708], test_class_mask=[2708], known_class_test_mask=[2708], unknown_class_test_mask=[2708], all_class_test_mask=[2708], unlabeled_mask=[2708])

In [19]:
known_val_acc = 2
result_dict = {
        "known_val_acc" : known_val_acc,
    }

In [20]:
def create_23_split(ds_name):
    dataset = Amazon(root='dataset/'+ds_name+"/", name=ds_name)
    #dataset = Planetoid(root='dataset/'+ds_name+"/", name=ds_name)

    y = dataset[0].y.cpu().detach().numpy()
    unique, counts = np.unique(y, return_counts=True)
    train_mask = []
    val_mask = []
    test_mask = []
    for seed in tqdm(range(10)):
        rng = np.random.default_rng(seed)
        train = []
        val = []
        test = []

        for cl in unique:
            tmp = np.argwhere(y==cl)
            rng.shuffle(tmp)
            train.append(tmp[:20])
            val.append(tmp[20:50])
            test.append(tmp[50:])

        train_ix = np.concatenate(train)
        val_ix = np.concatenate(val)
        test_ix = np.concatenate(test)

        train = torch.full_like(dataset[0].y, False, dtype=torch.bool)
        train[train_ix] = True
        val = torch.full_like(dataset[0].y, False, dtype=torch.bool)
        val[val_ix] = True
        test = torch.full_like(dataset[0].y, False, dtype=torch.bool)
        test[test_ix] = True
        train_mask.append(train)
        val_mask.append(val)
        test_mask.append(test)
    dict = {"train":torch.stack(train_mask, 1), "valid":torch.stack(val_mask, 1), "test":torch.stack(test_mask, 1)}
    torch.save(dict,"dataset/"+ds_name+"/own_23_splits.pt")

In [24]:
create_23_split("computers")

Downloading https://github.com/shchur/gnn-benchmark/raw/master/data/npz/amazon_electronics_computers.npz
Processing...
Done!
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 86.66it/s]


In [27]:
ds_name = "computers"
splits = torch.load("dataset/"+ds_name+"/own_23_splits.pt")
splits

{'train': tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]),
 'valid': tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]),
 'test': tensor([[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
        

In [29]:
splits["test"]

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])