In [2]:
# 数据导入
import torch
!pip install dgl-cu111 dglgo -f https://data.dgl.ai/wheels/repo.html
from dgl.data import CoraGraphDataset
import scipy.sparse as sp


dataset = CoraGraphDataset()
graph = dataset[0]

def norm_nodes(nodes):
    """Row-normalise node features (used as a DGL ``apply_nodes`` UDF).

    Each node's feature row is divided by that row's sum. Rows summing to
    zero are left all-zero instead of becoming inf/NaN.

    Args:
        nodes: DGL node batch; ``nodes.data['feat']`` is expected to be a 2-D
            (num_nodes, num_feats) tensor.

    Returns:
        dict ``{'norm': tensor}`` with the normalised features; DGL stores it
        as ``ndata['norm']``.
    """
    feat = nodes.data['feat']
    row_sum = feat.sum(1)  # renamed from `sum` to avoid shadowing the builtin
    inv = 1. / row_sum
    inv[torch.isinf(inv)] = 0.
    # Broadcast multiply instead of torch.mm(torch.diag(inv), feat): same
    # result without materialising a dense num_nodes x num_nodes matrix.
    norm = feat * inv.unsqueeze(1)
    return {'norm': norm}

graph.apply_nodes(norm_nodes)  # writes ndata['norm']

label = graph.ndata['label']
print('label shape:', label.shape, label)  # class indices, not one-hot

# Boolean split masks, then the node ids each mask selects.
train_mask, val_mask, test_mask = (
    graph.ndata[key] for key in ('train_mask', 'val_mask', 'test_mask'))
train_idx, val_idx, test_idx = (
    graph.nodes()[mask] for mask in (train_mask, val_mask, test_mask))

# Swap the raw features for the row-normalised ones produced by norm_nodes.
graph.ndata['feat'] = graph.ndata.pop('norm')
feat = graph.ndata['feat']
adj = graph.adj(scipy_fmt='csr')

num_nodes, num_feats = graph.number_of_nodes(), feat.shape[1]
num_classes = dataset.num_classes

Looking in links: https://data.dgl.ai/wheels/repo.html
Collecting dgl-cu111
  Downloading https://data.dgl.ai/wheels/dgl_cu111-0.8.0.post1-cp37-cp37m-manylinux1_x86_64.whl (252.7 MB)
[K     |████████████████████████████████| 252.7 MB 56 kB/s 
[?25hCollecting dglgo
  Downloading dglgo-0.0.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 2.7 MB/s 
Collecting isort>=5.10.1
  Downloading isort-5.10.1-py3-none-any.whl (103 kB)
[K     |████████████████████████████████| 103 kB 15.0 MB/s 
[?25hCollecting typer>=0.4.0
  Downloading typer-0.4.0-py3-none-any.whl (27 kB)
Collecting pydantic>=1.9.0
  Downloading pydantic-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.9 MB)
[K     |████████████████████████████████| 10.9 MB 43.8 MB/s 
[?25hCollecting PyYAML>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 50.3 MB

DGL backend not selected or invalid.  Assuming PyTorch for now.


Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /root/.dgl/cora_v2
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.
label shape: torch.Size([2708]) tensor([4, 4, 4,  ..., 4, 3, 3])


In [3]:
import random

import numpy as np

hid_feats = 512    # GCN / embedding width (also the probe's input width)
lr = 1e-3          # Adam learning rate for GraphCL pre-training
lr2 = 1e-2         # Adam learning rate for the downstream LogReg probe
l2_coef = .0       # Adam weight decay, shared by both optimizers
aug_type = 'node'  # one of: node mask edge subgraph
patience = 20      # early-stopping patience (epochs without loss improvement)
epochs = 500       # maximum number of pre-training epochs
seed = 0           # seed for python `random`, numpy and torch

# Augmentations (random.sample, DGL transforms), feature shuffling
# (np.random.permutation) and weight init are all stochastic; seed them so
# Restart & Run All reproduces the reported numbers.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [12]:
# Diagnostic only: prints the GPU model and current memory usage; no later cell uses its output.
! nvidia-smi

Fri Mar 25 12:10:49 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   67C    P0   115W / 149W |   6270MiB / 11441MiB |     62%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
# 数据处理
# 四种数据增强方法 data augmentation
#   - Node dropping 
#   - Edge perturbation
#   - Attribute masking
#   - Subgraph: Random Walk

import random
import copy
import dgl
from dgl.transforms import DropNode, DropEdge, AddEdge
from dgl.sampling import node2vec_random_walk

def aug_random_mask(graph, drop_prob=0.2):
    """Attribute masking: zero out the full feature vector of random nodes.

    Chooses ``int(num_nodes * drop_prob)`` distinct nodes with
    ``random.sample`` and overwrites their ``ndata['feat']`` rows with zeros
    on a deep copy of ``graph``, which is returned.
    """
    total = graph.number_of_nodes()
    victims = random.sample(range(total), int(total * drop_prob))

    aug_graph = copy.deepcopy(graph)
    # A single zero row broadcasts over every selected row.
    aug_graph.ndata['feat'][victims] = torch.zeros_like(graph.ndata['feat'][0])
    return aug_graph

def aug_drop_node(graph, drop_prob=0.2):
    """Node dropping: apply DGL's ``DropNode(p=drop_prob)`` to a deep copy of ``graph``."""
    return DropNode(p=drop_prob)(copy.deepcopy(graph))

def aug_random_edge(graph, drop_prob=0.2):
    """Edge perturbation: drop random edges, then add back about as many random ones.

    ``DropEdge(p=drop_prob)`` is applied first, then ``AddEdge(ratio=...)``.
    The ratio is ``d / (m - d)`` with ``m`` the edge count and
    ``d = int(m * drop_prob)``. Assuming DGL's AddEdge adds ``ratio`` times
    the current edge count (confirm for the installed version), this restores
    roughly ``m`` edges.

    Args:
        graph: input DGL graph; it is deep-copied before transforming.
        drop_prob: edge drop probability passed to ``DropEdge``.

    Returns:
        The perturbed graph.
    """
    num_edges = graph.number_of_edges()
    num_drop_edges = int(num_edges * drop_prob)
    num_kept_edges = num_edges - num_drop_edges
    # Guard: drop_prob == 1 or an edgeless graph used to raise ZeroDivisionError.
    # With no edges left there is nothing to scale a ratio by, so add none.
    add_ratio = num_drop_edges / num_kept_edges if num_kept_edges > 0 else 0.0
    aug_graph = copy.deepcopy(graph)

    aug_graph = DropEdge(p=drop_prob)(aug_graph)
    if add_ratio > 0:
        aug_graph = AddEdge(ratio=add_ratio)(aug_graph)

    return aug_graph

def aug_subgraph(graph, drop_prob=0.2):
    """Subgraph augmentation: grow a connected node set from a random center.

    Starting at a uniformly random node, each step adds the neighbours of the
    most recently selected node to an accumulated candidate pool. It then
    removes already-selected nodes from the pool and picks one remaining
    candidate at random. The pool spans the whole selected region, so this
    grows a connected region rather than walking strictly from the last node.
    Growth stops at ``int(num_nodes * (1 - drop_prob))`` nodes, or earlier
    when the center's connected component is exhausted.

    Args:
        graph: input DGL graph.
        drop_prob: fraction of nodes to leave out of the subgraph.

    Returns:
        ``dgl.node_subgraph(graph, selected_nodes)``.
    """
    num_nodes = graph.number_of_nodes()
    num_subgraph_nodes = int(num_nodes * (1 - drop_prob))
    center_node_id = random.randint(0, num_nodes - 1)

    adj = graph.adj(scipy_fmt='csr')

    subgraph_idx = [center_node_id]  # selection order (passed to node_subgraph)
    selected = {center_node_id}      # O(1) membership; was an O(n) list scan
    neighbor_idx = []

    for _ in range(num_subgraph_nodes - 1):
        # Neighbours of the newest node = column indices of its adjacency row.
        neighbor_idx.extend(adj[subgraph_idx[-1]].nonzero()[1])
        neighbor_idx = [idx for idx in set(neighbor_idx) if idx not in selected]

        if not neighbor_idx:
            break  # no unvisited neighbours left in this component
        new_node = random.sample(neighbor_idx, 1)[0]
        subgraph_idx.append(new_node)
        selected.add(new_node)

    return dgl.node_subgraph(graph, subgraph_idx)



# Build the two stochastic augmented views that are contrasted in training.
# Remove existing self-loops first: add_self_loop() below is not idempotent,
# so without this, re-running the cell would stack duplicate loops on `graph`.
graph = dgl.remove_self_loop(graph)

augmenters = {
    'node': aug_drop_node,
    'edge': aug_random_edge,
    'mask': aug_random_mask,
    'subgraph': aug_subgraph,
}
if aug_type not in augmenters:
    # Raise instead of `assert False`, which is skipped under `python -O`.
    raise ValueError('No aug type {}'.format(aug_type))

aug_view1 = augmenters[aug_type](graph)
aug_view2 = augmenters[aug_type](graph)

# Self-loops keep each node's own features in the GraphConv aggregation
# (and presumably avoid zero-in-degree nodes, e.g. after node dropping).
graph = graph.add_self_loop()
aug_view1 = aug_view1.add_self_loop()
aug_view2 = aug_view2.add_self_loop()


In [10]:
import torch
import torch.nn as nn
from dgl.nn import GraphConv, AvgPooling

class Discriminator(nn.Module):
    """Bilinear critic scoring node embeddings against a graph summary vector.

    Args:
        hid_feats: width of both the node embeddings and the summary vector.
    """

    def __init__(self, hid_feats):
        super().__init__()
        self.fc = nn.Bilinear(hid_feats, hid_feats, 1)

    def forward(self, c, h_pl, h_mi):
        """Score positive and negative node embeddings against summary ``c``.

        Args:
            c: summary vector, broadcast to ``h_pl``'s shape (presumably
                (1, hid_feats) from the readout — confirm at the call site).
            h_pl: positive node embeddings, (N, hid_feats).
            h_mi: negative node embeddings, same shape as ``h_pl``.

        Returns:
            1-D logits of length 2N: N positive scores followed by N negative scores.
        """
        summary = c.expand_as(h_pl).contiguous()
        pos_scores = self.fc(h_pl, summary).squeeze(1)
        neg_scores = self.fc(h_mi, summary).squeeze(1)
        return torch.cat((pos_scores, neg_scores))

class DGI(nn.Module):
    """GraphCL model trained with a DGI-style (Deep Graph Infomax) objective.

    One GraphConv encoder is shared by four passes:
    - the original graph with its real features (positive node embeddings);
    - the original graph with shuffled features (negative node embeddings);
    - two augmented views, whose average-pooled, sigmoid-squashed embeddings
      act as summary vectors.

    The Discriminator scores each summary against the positive and negative
    node embeddings.

    Args:
        in_feats: input feature width.
        hid_feats: embedding width.
    """

    def __init__(self, in_feats, hid_feats):
        super(DGI, self).__init__()
        self.gcn = GraphConv(in_feats, hid_feats, norm='both', bias=True, activation=nn.PReLU())
        self.read = AvgPooling()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(hid_feats)

    def forward(self, graph, aug_view1, aug_view2, shuf_feat):
        """Return discriminator logits summed over the two augmented views.

        Args:
            graph: original graph; its ``ndata['feat']`` gives the positives.
            aug_view1, aug_view2: augmented graphs, each providing a summary vector.
            shuf_feat: row-shuffled features run on ``graph`` to form negatives.

        Returns:
            Elementwise sum of the two Discriminator outputs (positive scores
            followed by negative scores).
        """
        h_0 = self.gcn(graph, graph.ndata['feat'])  # positives
        h_2 = self.gcn(graph, shuf_feat)            # negatives

        h_1 = self.gcn(aug_view1, aug_view1.ndata['feat'])
        h_3 = self.gcn(aug_view2, aug_view2.ndata['feat'])

        c_1 = self.sigm(self.read(aug_view1, h_1))
        c_3 = self.sigm(self.read(aug_view2, h_3))

        ret1 = self.disc(c_1, h_0, h_2)
        ret2 = self.disc(c_3, h_0, h_2)

        return ret1 + ret2

    def get_embedding(self, graph):
        """Encode ``graph`` for downstream use.

        Runs under ``torch.no_grad()`` so inference does not build an autograd
        graph; the outputs are still returned detached.

        Returns:
            (h, c): GCN node embeddings and their average-pooled readout
            (no sigmoid, unlike the summaries in ``forward``).
        """
        with torch.no_grad():
            h_1 = self.gcn(graph, graph.ndata['feat'])
            c = self.read(graph, h_1)

        return h_1.detach(), c.detach()
    

In [18]:
import numpy as np

ckpt_path = 'graphcl.pkl'  # best-loss checkpoint

model = DGI(num_feats, hid_feats)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)
loss_fn = nn.BCEWithLogitsLoss()

model.to(device)
graph = graph.to(device)
aug_view1 = aug_view1.to(device)
aug_view2 = aug_view2.to(device)

# Targets are the same every epoch, so build them once: 1 for the
# real-feature (positive) nodes, 0 for the shuffled (negative) ones.
lbl = torch.cat((torch.ones(num_nodes), torch.zeros(num_nodes)), 0).to(device)

best = float('inf')
cnt_wait = 0

for e in range(epochs):
    model.train()
    optimizer.zero_grad()

    # Negatives: the same nodes with their feature rows randomly permuted.
    shuf_idx = np.random.permutation(num_nodes)
    shuf_feat = feat[shuf_idx, :].to(device)

    logits = model(graph, aug_view1, aug_view2, shuf_feat)
    loss = loss_fn(logits, lbl)
    loss_value = loss.item()  # plain float, so `best` does not hold a tensor
    print('Epoch: {:03d}, Loss: {:.4f}'.format(e, loss_value))

    if loss_value < best:
        best = loss_value
        cnt_wait = 0
        torch.save(model.state_dict(), ckpt_path)
    else:
        cnt_wait += 1

    if cnt_wait == patience:
        print('Early stopping!')
        break

    loss.backward()
    optimizer.step()

model.load_state_dict(torch.load(ckpt_path, map_location=device))
embeds, _ = model.get_embedding(graph)
    


Epoch: 000, Loss: 0.6926
Epoch: 001, Loss: 0.7176
Epoch: 002, Loss: 0.6847
Epoch: 003, Loss: 0.6917
Epoch: 004, Loss: 0.6921
Epoch: 005, Loss: 0.6857
Epoch: 006, Loss: 0.6834
Epoch: 007, Loss: 0.6854
Epoch: 008, Loss: 0.6798
Epoch: 009, Loss: 0.6744
Epoch: 010, Loss: 0.6731
Epoch: 011, Loss: 0.6687
Epoch: 012, Loss: 0.6621
Epoch: 013, Loss: 0.6565
Epoch: 014, Loss: 0.6526
Epoch: 015, Loss: 0.6431
Epoch: 016, Loss: 0.6352
Epoch: 017, Loss: 0.6259
Epoch: 018, Loss: 0.6147
Epoch: 019, Loss: 0.6080
Epoch: 020, Loss: 0.5921
Epoch: 021, Loss: 0.5823
Epoch: 022, Loss: 0.5653
Epoch: 023, Loss: 0.5606
Epoch: 024, Loss: 0.5432
Epoch: 025, Loss: 0.5243
Epoch: 026, Loss: 0.5145
Epoch: 027, Loss: 0.4969
Epoch: 028, Loss: 0.4826
Epoch: 029, Loss: 0.4662
Epoch: 030, Loss: 0.4495
Epoch: 031, Loss: 0.4419
Epoch: 032, Loss: 0.4188
Epoch: 033, Loss: 0.4053
Epoch: 034, Loss: 0.3942
Epoch: 035, Loss: 0.3974
Epoch: 036, Loss: 0.3711
Epoch: 037, Loss: 0.3572
Epoch: 038, Loss: 0.3650
Epoch: 039, Loss: 0.3401


In [21]:
# downstream task
# node classification 只有节点的嵌入

# model
class LogReg(nn.Module):
    """Linear probe: one affine layer mapping embeddings to class logits.

    Args:
        hid_dim: input embedding width.
        n_classes: number of output classes.
    """

    def __init__(self, hid_dim, n_classes):
        super().__init__()
        self.fc = nn.Linear(hid_dim, n_classes)

    def forward(self, x):
        # Raw (unnormalised) logits; pair with a loss that expects logits,
        # such as nn.CrossEntropyLoss.
        return self.fc(x)

# evaluation 
embeds = embeds.to('cpu')
train_embs = embeds[train_idx]
test_embs = embeds[test_idx]


train_labels = label[train_idx]
test_labels = label[test_idx]
accs = []

for _ in range(10):
  model = LogReg(hid_feats, num_classes)

  opt = torch.optim.Adam(model.parameters(), lr=lr2, weight_decay=l2_coef)
  loss_fn = nn.CrossEntropyLoss()
  # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

  for epoch in range(300):
    model.train()

    logits = model(train_embs)
    loss = loss_fn(logits, train_labels) # target可以接受 shape (N) 或 (N, d_1, ..., d_k)

    opt.zero_grad()
    loss.backward()
    opt.step()

  model.eval()
  logits = model(test_embs)
  preds = torch.argmax(logits, dim=1)
  acc = torch.sum(preds == test_labels).float() / test_labels.shape[0]
  accs.append(acc * 100)

accs = torch.stack(accs) # Concatenates a sequence of tensors along a new dimension. 类型转换
print(accs.mean().item(), accs.std().item())
  


82.07000732421875 0.0823286846280098
