In [None]:
import dgl
import torch
import numpy as np
import os
import random
# from pygod.utils import load_data
import pandas
import bidict
from dgl.data import FraudAmazonDataset, FraudYelpDataset
from sklearn.model_selection import train_test_split

def set_seed(seed=3407):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus all CUDA devices when CUDA is available), and asks
    cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    # NOTE: changing PYTHONHASHSEED at runtime does not re-seed str/bytes
    # hashing of the interpreter that is already running; it is only inherited
    # by processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # These are plain process-level flags, so they are set unconditionally.
    # Disabling benchmark mode stops cuDNN from auto-tuning (and possibly
    # picking a different convolution algorithm) on each run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
class Dataset:
    """Load a node-classification graph and attach multi-column split masks.

    ``self.graph`` holds the loaded DGL graph and ``self.name`` the dataset
    name. Calling :meth:`split` stores ``train_masks`` / ``val_masks`` /
    ``test_masks`` (one column per split) in ``self.graph.ndata``.
    """

    def __init__(self, name='tfinance', prefix='datasets/',homo=True, add_self_loop=True, to_bidirectional=False, to_simple=True):
        """Load the graph called ``name`` and apply the requested transforms.

        Args:
            name: ``'yelp'`` and ``'amazon'`` load DGL's ``FraudYelpDataset`` /
                ``FraudAmazonDataset``; any other value is loaded with
                ``dgl.load_graphs(prefix + name)`` and the first graph is used.
            prefix: String prepended to ``name`` to build the path for
                graphs loaded from disk.
            homo: For ``'yelp'`` / ``'amazon'``, convert the graph with
                ``dgl.to_homogeneous`` keeping features, labels and masks.
            add_self_loop: Apply ``dgl.add_self_loop``.
            to_bidirectional: Apply ``dgl.to_bidirected`` (node data copied).
            to_simple: Apply ``dgl.to_simple``.
        """
        mask_keys = ['train_mask', 'val_mask', 'test_mask']
        if name == 'yelp':
            graph = FraudYelpDataset()[0]
            for key in mask_keys:
                graph.ndata[key] = graph.ndata[key].bool()
            if homo:
                # Convert the graph whose masks were just cast, rather than
                # fetching the dataset item a second time.
                graph = dgl.to_homogeneous(graph, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask'])

        elif name == 'amazon':
            graph = FraudAmazonDataset()[0]
            for key in mask_keys:
                graph.ndata[key] = graph.ndata[key].bool()
            # 'mark' is True for every node that belongs to any official split.
            graph.ndata['mark'] = graph.ndata['train_mask'] | graph.ndata['val_mask'] | graph.ndata['test_mask']
            if homo:
                graph = dgl.to_homogeneous(graph, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask', 'mark'])

        else:
            graph = dgl.load_graphs(prefix+name)[0][0]

        graph.ndata['feature'] = graph.ndata['feature'].float()
        graph.ndata['label'] = graph.ndata['label'].long()
        self.name = name
        self.graph = graph
        if add_self_loop:
            self.graph = dgl.add_self_loop(self.graph)
        if to_bidirectional:
            self.graph = dgl.to_bidirected(self.graph, copy_ndata=True)
        if to_simple:
            self.graph = dgl.to_simple(self.graph)

    def split(self, samples=20, official_repeats=10, random_splits=10):
        """Build boolean split masks and store them in ``self.graph.ndata``.

        The masks have shape ``(num_nodes, official_repeats + random_splits)``:

        * the first ``official_repeats`` columns are copies of the graph's
          existing ``train_mask`` / ``val_mask`` / ``test_mask``;
        * each of the remaining ``random_splits`` columns is a random split
          drawn with ``np.random.choice``: training gets ``samples`` nodes
          with label 1 plus ``4 * samples`` other nodes, validation gets the
          same amounts from disjoint nodes, and test gets every remaining
          node.

        Args:
            samples: Number of label-1 nodes placed in each random train
                (and each random validation) split.
            official_repeats: Number of leading columns copied from the
                existing masks. With 0 the existing masks are not read.
            random_splits: Number of randomly drawn split columns.

        Raises:
            ValueError: If random splits are requested but the graph has
                fewer than ``2 * samples`` label-1 nodes or fewer than
                ``8 * samples`` other nodes.
        """
        labels = self.graph.ndata['label']
        n = self.graph.num_nodes()
        total = official_repeats + random_splits
        train_masks = torch.zeros([n, total], dtype=torch.bool)
        val_masks = torch.zeros([n, total], dtype=torch.bool)
        test_masks = torch.zeros([n, total], dtype=torch.bool)

        # official split
        if official_repeats > 0:
            for masks, key in ((train_masks, 'train_mask'),
                               (val_masks, 'val_mask'),
                               (test_masks, 'test_mask')):
                # A (n, 1) column broadcasts across all official columns.
                masks[:, :official_repeats] = self.graph.ndata[key].bool().reshape(-1, 1)

        if random_splits > 0:
            # Loop-invariant: compute the candidate pools once.
            # assumes labels is a 1-D tensor with one entry per node -- TODO confirm
            is_pos = (labels == 1).cpu().numpy()
            pos_index = np.flatnonzero(is_pos)
            neg_index = np.flatnonzero(~is_pos)
            if len(pos_index) < 2*samples or len(neg_index) < 8*samples:
                raise ValueError(
                    'Not enough nodes to sample from: need %d label-1 nodes and %d '
                    'other nodes, found %d and %d'
                    % (2*samples, 8*samples, len(pos_index), len(neg_index)))

        for i in range(random_splits):
            col = official_repeats + i
            # Draw train and validation candidates together so they are disjoint.
            pos_pick = np.random.choice(pos_index, size=2*samples, replace=False)
            neg_pick = np.random.choice(neg_index, size=8*samples, replace=False)
            train_idx = np.concatenate([pos_pick[:samples], neg_pick[:4*samples]])
            val_idx = np.concatenate([pos_pick[samples:], neg_pick[4*samples:]])
            train_masks[train_idx, col] = True
            val_masks[val_idx, col] = True
            # Test set = every node not used for training or validation.
            test_masks[:, col] = True
            test_masks[train_idx, col] = False
            test_masks[val_idx, col] = False

        self.graph.ndata['train_masks'] = train_masks
        self.graph.ndata['val_masks'] = val_masks
        self.graph.ndata['test_masks'] = test_masks

In [None]:
# Root directory of the NFTGraph data. Can be overridden through the
# NFTGRAPH_ROOT environment variable so the notebook is not tied to one
# machine's absolute path; the original location remains the default.
prefix = os.environ.get('NFTGRAPH_ROOT', '/data/sx/NFTGraph')

In [None]:
# Load each NFTGraph variant, attach the split masks and save the result
# under the local 'datasets/' directory.
for data_name in ['tinynftgraph', 'nftgraph']:
    # Dataset.split() draws its random splits from NumPy's global generator;
    # seed before each dataset so the saved masks are reproducible and do not
    # depend on the order in which the datasets are processed.
    set_seed()
    data = Dataset(data_name,prefix=prefix+'/datasets/dgl_graph/')
    data.split()
    print(data.graph)
    # Number of selected nodes per split column, for a quick sanity check.
    print(data.graph.ndata['train_masks'].sum(0), data.graph.ndata['val_masks'].sum(0), data.graph.ndata['test_masks'].sum(0))
    dgl.save_graphs('datasets/'+data_name, [data.graph])