有几个模块

1. DataModule：实现数据模块，能够分批加载训练、验证数据
2. SAGELightning：继承LightningModule的模块，用于sage的信息传递
3. CrossEntropyLoss：定义损失函数，graphsage里的无监督训练
4. UnsuptervisedClassfication：无监督训练，产出emb后进行下游的分类任务


## 1. data loader

In [9]:
import dgl
import torch as th
from dgl.data import AsNodePredDataset

In [2]:
def load_cora():
    data0 = dgl.data.CSVDataset('../graph_dgl/cora_csv/')
    data = AsNodePredDataset(data0, split_ratio=(0.5,0.2,0.3))
    g = data[0]
    g.ndata["features"] = g.ndata.pop("feat")
    g.ndata["labels"] = g.ndata.pop("label")
    return g, data.num_classes

In [3]:
def load_reddit(self_loop=True, raw_dir='~/.dgl/'):
    from dgl.data import RedditDataset

    # load reddit data
    data = RedditDataset(self_loop=self_loop, raw_dir=raw_dir)
    g = data[0]
    g.ndata["features"] = g.ndata.pop("feat")
    g.ndata["labels"] = g.ndata.pop("label")
    return g, data.num_classes

In [4]:
g, n_classes = load_cora()

Done loading data from cached files.


In [5]:
n_edges = g.num_edges()
reverse_eids = torch.cat([
    torch.arange(n_edges // 2, n_edges),
    torch.arange(0, n_edges // 2)])

In [6]:
n_edges

5429

In [7]:
reverse_eids

tensor([2714, 2715, 2716,  ..., 2711, 2712, 2713])

In [10]:
train_nid = th.nonzero(g.ndata['train_mask'], as_tuple=True)[0]
val_nid = th.nonzero(g.ndata['val_mask'], as_tuple=True)[0]
test_nid = th.nonzero(~(g.ndata['train_mask'] | g.ndata['val_mask']), as_tuple=True)[0]

In [11]:
def inductive_split(g):
    """Split the graph into training graph, validation graph, and test graph by training
    and validation masks.  Suitable for inductive models."""
    train_g = g.subgraph(g.ndata["train_mask"])
    val_g = g.subgraph(g.ndata["train_mask"] | g.ndata["val_mask"])
    test_g = g
    return train_g, val_g, test_g

In [12]:
fan_out=[10, 25]
sampler = dgl.dataloading.MultiLayerNeighborSampler([int(_) for _ in fan_out])

In [13]:
sampler

<dgl.dataloading.neighbor_sampler.NeighborSampler at 0x7fbc6187ab30>

In [14]:
g.ndata.keys()

dict_keys(['train_mask', 'val_mask', 'test_mask', 'features', 'labels'])

## 2. nagetive sampler