## Import necessary modules

In [238]:
import torch
import dgl
from dgl.data import DGLDataset
import h5py
import numpy as np
import os
from dgl import save_graphs, load_graphs
from dgl.data.utils import save_info

## Construct a DGL dataset class that stores the signed-network graph

In [239]:
class SignNet(DGLDataset):
    """Single-graph DGL dataset built from a signed adjacency matrix in an HDF5 (``.mat``) file.

    The raw file must contain a dataset named ``"Gwl_ud"``. Every non-zero entry
    ``(i, j)`` of the (transposed) matrix becomes an edge ``i -> j`` whose
    ``"sign"`` edge feature is the matrix value. Each node also receives its full
    matrix row as the ``"conn"`` node feature, plus boolean ``train_mask`` /
    ``val_mask`` / ``test_mask`` node features.

    Processed graphs are written to, and reloaded from, a cache directory derived
    from ``save_dir`` (see ``_cache_file``).

    Args:
        file_name: Name of the raw file inside ``dataset_path``. Cache naming
            strips fixed-length suffixes, so it assumes names shaped like
            ``"<stem>_UD.mat"`` (e.g. ``"wiki_UD.mat"``).
        dataset_path: Directory containing ``file_name`` (DGL ``raw_dir``).
        train_ratio: Fraction of nodes in the training split, in ``[0, 1]``; the
            remaining nodes are split (floor) in half between validation and test.
        mode: Stored as ``self.mode``; not read anywhere in this class.
        save_dir: Cache root passed through to DGL as ``save_dir``.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, file_name: str, dataset_path: str, train_ratio: float = 0.7, mode: str = "read", save_dir: str = "") -> None:
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
        # Must be set before super().__init__(): DGLDataset's constructor loads the
        # cache or calls process() right away, and process() reads train_ratio.
        self.train_ratio = train_ratio
        self.mode = mode
        super().__init__(name=file_name, raw_dir=dataset_path, save_dir=save_dir)

    def process(self) -> None:
        """Build ``self.graph`` from the raw file: edges, ``sign``/``conn`` features and split masks."""
        # Open read-only and release the HDF5 handle as soon as the matrix is in memory.
        with h5py.File(self.raw_path, "r") as mat_file:
            # Transposed presumably to undo MATLAB's column-major storage — confirm
            # against how the .mat files were written.
            np_data = np.transpose(np.array(mat_file["Gwl_ud"])).astype(np.int8)
        # from_numpy shares memory with np_data, so the tensor is already int8.
        adjacency = torch.from_numpy(np_data)
        # Assumes a square N x N matrix: rows and columns both index nodes.
        num_nodes = adjacency.shape[0]

        src, dst = torch.nonzero(adjacency, as_tuple=True)
        self.graph = dgl.graph((src, dst), num_nodes=num_nodes)
        self.graph.edata["sign"] = adjacency[src, dst]
        # Dense per-node feature: node i's whole adjacency row (O(N^2) memory).
        self.graph.ndata["conn"] = adjacency

        # Contiguous split by node id (no shuffling):
        # [0, num_train) train, [num_train, num_train + num_val) val, rest test.
        num_train = int(self.train_ratio * num_nodes)
        num_val = (num_nodes - num_train) // 2
        node_ids = torch.arange(num_nodes)
        self.graph.ndata["train_mask"] = node_ids < num_train
        self.graph.ndata["val_mask"] = (node_ids >= num_train) & (node_ids < num_train + num_val)
        self.graph.ndata["test_mask"] = node_ids >= num_train + num_val

    def __getitem__(self, idx) -> torch.Tensor:
        """Return row(s) ``idx`` of the ``"conn"`` node feature.

        NOTE(review): this indexes node adjacency rows, not graphs, while
        ``__len__`` reports a single item — confirm which contract callers expect.
        """
        return self.graph.ndata["conn"][idx]

    def __len__(self) -> int:
        """Return 1: the dataset holds a single graph."""
        return 1

    def _cache_file(self, suffix: str) -> str:
        """Return the path of a cache file for this dataset.

        The directory is ``self.save_path`` minus its last 4 characters (the
        ``".mat"`` extension); the file name is ``self.name`` minus its last 7
        characters (``"_UD.mat"``) followed by ``suffix``. For ``"wiki_UD.mat"``
        this gives ``<save_dir>/wiki_UD/wiki<suffix>``.
        """
        return os.path.join(self.save_path[:-4], self.name[:-7] + suffix)

    def save(self) -> None:
        """Write the processed graph and a small metadata dict to the cache directory."""
        # save_graphs is called first; the info file goes into the same directory.
        save_graphs(self._cache_file("_dgl_graph.bin"), self.graph)
        # save graph information
        save_info(self._cache_file("_info.pkl"), {'graph_type': "directed signed network"})

    def load(self) -> None:
        """Restore ``self.graph`` from the cached graph file (the info file is not read)."""
        graphs, _ = load_graphs(self._cache_file("_dgl_graph.bin"))
        self.graph = graphs[0]

    def has_cache(self) -> bool:
        """Return True if the cached graph file exists."""
        return os.path.exists(self._cache_file("_dgl_graph.bin"))

## Test the `SignNet` dataset on the Epinions, Slashdot and Wiki signed networks

In [1]:
if __name__ == "__main__":
    # All three signed networks share the raw/cache directories and split ratio;
    # only the .mat file name differs, so build them from one shared config.
    shared_kwargs = dict(dataset_path="dataset", save_dir="dataset_cache", train_ratio=0.6)
    # Constructed in this order; each is loaded from cache or processed on creation.
    epinions, slashdot, wiki = (
        SignNet(file_name=file_name, **shared_kwargs)
        for file_name in ("epinions_UD.mat", "slashdot_UD.mat", "wiki_UD.mat")
    )

NameError: name 'SignNet' is not defined