有几个模块

1. DataModule：实现数据模块，能够分批加载训练、验证数据
2. SAGELightning：继承LightningModule的模块，用于sage的信息传递
3. CrossEntropyLoss：定义损失函数，graphsage里的无监督训练
4. UnsuptervisedClassfication：无监督训练，产出emb后进行下游的分类任务


## 1. data loader

In [22]:
import dgl
import torch as th
from dgl.data import AsNodePredDataset
import numpy as np 

In [2]:
def load_cora():
    data0 = dgl.data.CSVDataset('../graph_dgl/cora_csv/')
    data = AsNodePredDataset(data0, split_ratio=(0.5,0.2,0.3))
    g = data[0]
    g.ndata["features"] = g.ndata.pop("feat")
    g.ndata["labels"] = g.ndata.pop("label")
    return g, data.num_classes

In [3]:
def load_reddit(self_loop=True, raw_dir='~/.dgl/'):
    from dgl.data import RedditDataset

    # load reddit data
    data = RedditDataset(self_loop=self_loop, raw_dir=raw_dir)
    g = data[0]
    g.ndata["features"] = g.ndata.pop("feat")
    g.ndata["labels"] = g.ndata.pop("label")
    return g, data.num_classes

In [4]:
g, n_classes = load_cora()

Done loading data from cached files.


In [5]:
n_edges = g.num_edges()
reverse_eids = torch.cat([
    torch.arange(n_edges // 2, n_edges),
    torch.arange(0, n_edges // 2)])

In [6]:
n_edges

5429

In [7]:
reverse_eids

tensor([2714, 2715, 2716,  ..., 2711, 2712, 2713])

In [10]:
train_nid = th.nonzero(g.ndata['train_mask'], as_tuple=True)[0]
val_nid = th.nonzero(g.ndata['val_mask'], as_tuple=True)[0]
test_nid = th.nonzero(~(g.ndata['train_mask'] | g.ndata['val_mask']), as_tuple=True)[0]

In [11]:
def inductive_split(g):
    """Split the graph into training graph, validation graph, and test graph by training
    and validation masks.  Suitable for inductive models."""
    train_g = g.subgraph(g.ndata["train_mask"])
    val_g = g.subgraph(g.ndata["train_mask"] | g.ndata["val_mask"])
    test_g = g
    return train_g, val_g, test_g

In [15]:
fan_out=[10, 25]
base_sampler = dgl.dataloading.MultiLayerNeighborSampler([int(_) for _ in fan_out])

## 2. nagetive sampler

In [18]:
class NegativeSampler(object):
    def __init__(self, g, k, neg_share=False, device=None):
        if device is None:
            device = g.device
        self.weights = g.in_degrees().float().to(device) ** 0.75
        self.k = k
        self.neg_share = neg_share

    def __call__(self, g, eids):
        src, _ = g.find_edges(eids)
        n = len(src)
        if self.neg_share and n % self.k == 0:
            dst = self.weights.multinomial(n, replacement=True)
            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
        else:
            dst = self.weights.multinomial(n * self.k, replacement=True)
        src = src.repeat_interleave(self.k)
        return src, dst


In [20]:
sampler = dgl.dataloading.as_edge_prediction_sampler(
            base_sampler, exclude='reverse_id',
            reverse_eids=reverse_eids,
            negative_sampler=NegativeSampler(g, 1, False))

In [28]:
train_dataloder = dgl.dataloading.DataLoader(
            g,
            np.arange(g.num_edges()),
            sampler,
            device=th.device('cpu'),
            batch_size=128,
            shuffle=True,
            drop_last=False,
            num_workers=0)

In [29]:
one_batch = next(iter(train_dataloder))

In [31]:
len(one_batch) # input_nodes, pos_graph, neg_graph, mfgs = batch

4

In [32]:
one_batch[0]

tensor([ 543, 1095,  552,  379,   28,   46,  435,  822,  331,   71,    0,  301,
        1134,  589,  325,   97, 1543,  236,   15, 1713,  146,  424,    3, 1850,
          62,  493,  577, 1039, 1304, 1287,  184, 1549, 1662, 1004,  486,  341,
         857,  858, 1724,   73,  568, 1698,  618,  784, 1069,  470, 1040,  414,
          74, 1210,  328, 1344,  713, 1077, 1305,  511, 1207,  579,  173,  672,
         228,  259,   39,  479,  188, 1113,   47,   87, 1608,  642, 1712,  509,
         513,  333, 1034,  921,   83,  812, 1502,  728, 1050,  527,  503, 1200,
         529,  445,  619, 1710, 1491,  652,   76,  355,  195,   49,  452, 1184,
         608,   13,  111,  121,  701,   52, 1716, 1111,  446,  898,  272, 1676,
        1074,  279,  402, 1403,   40,  282,   64, 2065, 2685, 1758,  549, 2637,
        2655,  845,  370,  737, 1754, 1126, 1485, 2251,  518,  324,   98, 2329,
        1944,  790, 1952, 2461, 1101,  477, 2560,   37, 1942, 2001, 2143, 2266,
        1925,   58, 1663,  546,  804,  5