Skip to content

Commit

Permalink
[Feature] Update UnsupGraphsage (#425)
Browse files Browse the repository at this point in the history
* [Feature] Update UnsupGraphsage
  • Loading branch information
QingFei1 committed Apr 17, 2023
1 parent e54f16a commit d52b271
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 46 deletions.
9 changes: 6 additions & 3 deletions cogdl/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,15 @@ def __getitem__(self, idx):
"""
batch = self.node_idx[idx * self.batch_size : (idx + 1) * self.batch_size]
self.random_walker.build_up(self.edge_index, self.total_num_nodes)
walk_res=self.random_walker.walk_one(batch,length=1,p=0.0)

walk_res = self.random_walker.walk(
batch, walk_length=2, parallel=False
)[:,1]

neg_batch = torch.randint(0, self.total_num_nodes, (batch.numel(), ),
dtype=torch.int64)
pos_batch=torch.tensor(walk_res)
batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
if self.sizes != [-1]:
batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
node_id = batch
adj_list = []
for size in self.sizes:
Expand Down
2 changes: 1 addition & 1 deletion cogdl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def build_model(args):
"sortpool": "cogdl.models.nn.sortpool.SortPool",
"srgcn": "cogdl.models.nn.srgcn.SRGCN",
"gcc": "cogdl.models.nn.gcc_model.GCCModel",
"unsup_graphsage": "cogdl.models.nn.graphsage.Graphsage",
"unsup_graphsage": "cogdl.models.nn.graphsage.UnsupGraphsage",
"graphsaint": "cogdl.models.nn.graphsaint.GraphSAINT",
"m3s": "cogdl.models.nn.m3s.M3S",
"moe_gcn": "cogdl.models.nn.moe_gcn.MoEGCN",
Expand Down
16 changes: 16 additions & 0 deletions cogdl/models/nn/graphsage.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,19 @@ def forward(self, graph):
for layer in self.layers:
x = layer(graph, x)
return x

class UnsupGraphsage(Graphsage):
def __init__(self, num_features, num_classes, hidden_size, num_layers, sample_size, dropout, aggr):
super(Graphsage, self).__init__()
assert num_layers == len(sample_size)
self.adjlist = {}
self.num_features = num_features
self.num_classes = num_classes
self.hidden_size = hidden_size
self.num_layers = num_layers
self.sample_size = sample_size
self.dropout = dropout
shapes = [num_features] + hidden_size * num_layers
self.convs = nn.ModuleList(
[SAGELayer(shapes[layer], shapes[layer + 1], aggr=aggr) for layer in range(num_layers)]
)
22 changes: 0 additions & 22 deletions cogdl/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,3 @@ def walk(self, start, walk_length, restart_p=0.0, parallel=True):
result = random_walk_single(start, walk_length, self.indptr, self.indices, restart_p)
result = np.array(result, dtype=np.int64)
return result

def walk_one(self, start, length, p):
walk_res = [np.zeros(length, dtype=np.int32)] * len(start)
p = 0.0
for i in range(len(start)):
node = start[i]
result = [np.int32(0)] * length
index = np.int32(0)
_node = node
while index < length:
start1 = self.indptr[node]
end1 = self.indptr[node + 1]
sample1 = random.randint(start1, end1 - 1)
node = self.indices[sample1]
if np.random.uniform(0, 1) > p:
result[index] = node
else:
result[index] = _node
index += 1
k = int(np.floor(np.random.rand() * len(result)))
walk_res[i] = result[k]
return walk_res
Original file line number Diff line number Diff line change
@@ -1,52 +1,54 @@
import torch

import numpy as np
from cogdl.utils import RandomWalker
from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_logreg
from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_liblinear
from .. import UnsupervisedModelWrapper

from torch.nn import functional as F

class UnsupGraphSAGEModelWrapper(UnsupervisedModelWrapper):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument("--num-shuffle", type=int, default=1)
parser.add_argument("--training-percents", default=[0.2], type=float, nargs="+")
parser.add_argument("--walk-length", type=int, default=10)
parser.add_argument("--negative-samples", type=int, default=30)
# fmt: on

def __init__(self, model, optimizer_cfg, walk_length, negative_samples):
def __init__(self, model, optimizer_cfg, walk_length, negative_samples, num_shuffle=1, training_percents=[0.1]):
super(UnsupGraphSAGEModelWrapper, self).__init__()
self.model = model
self.optimizer_cfg = optimizer_cfg
self.walk_length = walk_length
self.num_negative_samples = negative_samples
self.num_shuffle = num_shuffle
self.training_percents = training_percents


def train_step(self, batch):
x_src, adjs = batch
out = self.model(x_src,adjs)
out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)

pos_loss = torch.log(torch.sigmoid((out * pos_out).sum(-1)).mean())
neg_loss = torch.log(torch.sigmoid(-(out * neg_out).sum(-1)).mean())
pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
loss = -pos_loss - neg_loss
return loss


def test_step(self, batch):
dataset, test_loader = batch
def test_step(self, graph):
dataset, test_loader = graph
graph = dataset.data
if hasattr(self.model, "inference"):
pred = self.model.inference(graph.x, test_loader)
with torch.no_grad():
if hasattr(self.model, "inference"):
pred = self.model.inference(graph.x, test_loader)
else:
pred = self.model(graph)
if len(graph.y.shape) > 1:
self.label_matrix = graph.y.numpy()
else:
pred = self.model(graph)
pred= pred.split(pred.size(0) // 3, dim=0)[0]
pred = pred[graph.test_mask]
y = graph.y[graph.test_mask]

metric = self.evaluate(pred, y, metric="auto")
self.note("test_loss", self.default_loss_fn(pred, y))
self.note("test_metric", metric)
self.label_matrix = np.zeros((graph.num_nodes, graph.num_classes), dtype=int)
self.label_matrix[range(graph.num_nodes), graph.y.numpy()] = 1
return evaluate_node_embeddings_using_liblinear(pred, self.label_matrix, self.num_shuffle, self.training_percents)


def setup_optimizer(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/models/ssl/test_contrastive_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_unsupervised_graphsage():
args.epochs = 2
args.checkpoint_path = "graphsage.pt"
ret = train(args)
assert ret["test_acc"] > 0
assert ret["micro-f1 0.1"] > 0


def test_dgi():
Expand Down

0 comments on commit d52b271

Please sign in to comment.