不管创建哪种数据集，都要继承`dgl.data.DGLDataset`类。并定义以下三种方法：
- process: 创建静态数据部分，例如从数据源读取数据，创建图并为节点和边分配特征。
- __getitem__: 对于单图直接返回图对象，对于多图，指定索引返回图对象。（一般有固定格式）
- __len__: 返回数据集中图的数量。（一般有固定格式）

## 1. 单图数据集

In [1]:
import urllib.request

import pandas as pd

urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/members.csv", "./data/members.csv"
)
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/interactions.csv",
    "./data/interactions.csv",
)

members = pd.read_csv("./data/members.csv")
interactions = pd.read_csv("./data/interactions.csv")

In [2]:
'''
在上面的数据中，id是唯一的，Club有3个
'''

print("Club总共有%d种" %len(members['Club'].unique()))
members['Club'].astype('category').cat.codes.to_numpy()

Club总共有2种


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)

In [3]:
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
from dgl.data import DGLDataset


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class TestClubDataset(DGLDataset):
    def __init__(self):
        super().__init__(name="test_club")
    
    def process(self):
        nodes_data = pd.read_csv("./data/members.csv")
        edges_data = pd.read_csv("./data/interactions.csv")
        node_features = torch.tensor(nodes_data["Age"].to_numpy())
        #NOTE: 将字符串用one-hot编码
        node_labels = torch.tensor(nodes_data["Club"].astype("category").cat.codes.to_numpy())
        edge_features = torch.tensor(edges_data["Weight"].to_numpy())
        edge_src = torch.tensor(edges_data["Src"].to_numpy())
        edge_dst = torch.tensor(edges_data["Dst"].to_numpy())
        
        self.graph = dgl.graph((edge_src, edge_dst), num_nodes=nodes_data.shape[0])
        self.graph.ndata['feat'] = node_features
        self.graph.ndata['label'] = node_labels
        self.graph.edata['weight'] = edge_features
        
        # If your dataset is a node classification dataset, you will need to assign
        # masks indicating whether a node belongs to training, validation, and test set.
        
        '''
        这里用简单的方式直接指定前60%的节点为训练集，20%的节点为验证集，20%的节点为测试集
        '''
        
        n_nodes = nodes_data.shape[0]
        n_train = int(n_nodes * 0.6)
        n_val = int(n_nodes * 0.2)
        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        train_mask[:n_train] = True
        val_mask[n_train : n_train + n_val] = True
        test_mask[n_train + n_val :] = True
        self.graph.ndata["train_mask"] = train_mask
        self.graph.ndata["val_mask"] = val_mask
        self.graph.ndata["test_mask"] = test_mask

    def __getitem__(self, i):
        #NOTE: 只有1个图，所以直接返回即可,i都用不到
        return self.graph

    def __len__(self):
        return 1
    
dataset = TestClubDataset()
graph = dataset[0]



In [5]:
print(graph)
print(graph.ndata)

Graph(num_nodes=34, num_edges=156,
      ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int8), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={'weight': Scheme(shape=(), dtype=torch.float64)})
{'feat': tensor([44, 37, 37, 40, 30, 32, 36, 47, 35, 37, 35, 46, 46, 48, 41, 49, 46, 38,
        44, 41, 48, 34, 43, 41, 40, 34, 38, 42, 42, 44, 48, 41, 35, 46]), 'label': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int8), 'train_mask': tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False]), 'val_mask': tensor([False, False, False,

## 2. 多图数据集

In [6]:
'''
从CSV文件创建多图数据集，以图分类为例。
数据包含多个图本身以及对应的label。

* 每个表多了一个graph_id列，用于标识图的ID

'''

urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_edges.csv", "./data/graph_edges.csv"
)
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_properties.csv",
    "./data/graph_properties.csv",
)
edges = pd.read_csv("./data/graph_edges.csv")
properties = pd.read_csv("./data/graph_properties.csv")

edges.head()

properties.head()

Unnamed: 0,graph_id,label,num_nodes
0,0,0,15
1,1,0,10
2,2,0,13
3,3,0,13
4,4,0,17


In [7]:
class SyntheticDataset(DGLDataset):
    def __init__(self):
        super().__init__(name="synthetic")

    def process(self):
        edges = pd.read_csv("./data/graph_edges.csv")
        properties = pd.read_csv("./data/graph_properties.csv")
        self.graphs = []
        self.labels = []

        # Create a graph for each graph ID from the edges table.
        # First process the properties table into two dictionaries with graph IDs as keys.
        # The label and number of nodes are values.
        label_dict = {}
        num_nodes_dict = {}
        for _, row in properties.iterrows():
            label_dict[row["graph_id"]] = row["label"]
            num_nodes_dict[row["graph_id"]] = row["num_nodes"]

        # For the edges, first group the table by graph IDs.
        # NOTE: pandas的groupby方法, 返回的是一个按图编号的GroupBy对象
        edges_group = edges.groupby("graph_id")

        # For each graph ID...
        for graph_id in edges_group.groups:
            # Find the edges as well as the number of nodes and its label.
            edges_of_id = edges_group.get_group(graph_id)
            src = edges_of_id["src"].to_numpy()
            dst = edges_of_id["dst"].to_numpy()
            num_nodes = num_nodes_dict[graph_id]
            label = label_dict[graph_id]

            # Create a graph and add it to the list of graphs and labels.
            g = dgl.graph((src, dst), num_nodes=num_nodes) # NOTE: 加num_nodes参数可以避免节点不连续以及存在鼓励节点的情况
            #NOTE: 最关键的就是在list里追加图和标签
            self.graphs.append(g)
            self.labels.append(label)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)




In [8]:
dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)

Graph(num_nodes=15, num_edges=45,
      ndata_schemes={}
      edata_schemes={}) tensor(0)
