Skip to content

Commit

Permalink
[Bugfix] Fix dgk/graph2vec/gdc/grace (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
cenyk1230 committed Feb 21, 2022
1 parent 30a69a0 commit 48298bc
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 353 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
[![PyPI Latest Release](https://badge.fury.io/py/cogdl.svg)](https://pypi.org/project/cogdl/)
[![Build Status](https://travis-ci.org/THUDM/cogdl.svg?branch=master)](https://travis-ci.org/THUDM/cogdl)
[![Documentation Status](https://readthedocs.org/projects/cogdl/badge/?version=latest)](https://cogdl.readthedocs.io/en/latest/?badge=latest)
[![Downloads](https://pepy.tech/badge/cogdl)](https://pepy.tech/project/cogdl)
[![Coverage Status](https://coveralls.io/repos/github/THUDM/cogdl/badge.svg?branch=master)](https://coveralls.io/github/THUDM/cogdl?branch=master)
[![License](https://img.shields.io/github/license/thudm/cogdl)](https://github.com/THUDM/cogdl/blob/master/LICENSE)
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black)
Expand Down
3 changes: 2 additions & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

[![PyPI Latest Release](https://badge.fury.io/py/cogdl.svg)](https://pypi.org/project/cogdl/)
[![Build Status](https://travis-ci.org/THUDM/cogdl.svg?branch=master)](https://travis-ci.org/THUDM/cogdl)
[![Coverage Status](https://coveralls.io/repos/github/THUDM/cogdl/badge.svg?branch=master)](https://coveralls.io/github/THUDM/cogdl?branch=master)
[![Documentation Status](https://readthedocs.org/projects/cogdl/badge/?version=latest)](https://cogdl.readthedocs.io/en/latest/?badge=latest)
[![Downloads](https://pepy.tech/badge/cogdl)](https://pepy.tech/project/cogdl)
[![Coverage Status](https://coveralls.io/repos/github/THUDM/cogdl/badge.svg?branch=master)](https://coveralls.io/github/THUDM/cogdl?branch=master)
[![License](https://img.shields.io/github/license/thudm/cogdl)](https://github.com/THUDM/cogdl/blob/master/LICENSE)
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black)

Expand Down
2 changes: 1 addition & 1 deletion cogdl/models/emb/dgk.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def wl_iterations(graph, features, rounds):
neighbors = graph.neighbors(node)
neigh_feats = [features[x] for x in neighbors]
neigh_feats = [features[node]] + sorted(neigh_feats)
hash_feat = hashlib.md5("_".join(neigh_feats).encode())
hash_feat = hashlib.md5("_".join([str(x) for x in neigh_feats]).encode())
hash_feat = hash_feat.hexdigest()
new_feats[node] = hash_feat
wl_features = wl_features + list(new_feats.values())
Expand Down
2 changes: 1 addition & 1 deletion cogdl/models/emb/graph2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def wl_iterations(graph, features, rounds):
neighbors = graph.neighbors(node)
neigh_feats = [features[x] for x in neighbors]
neigh_feats = [features[node]] + sorted(neigh_feats)
hash_feat = hashlib.md5("_".join(neigh_feats).encode())
hash_feat = hashlib.md5("_".join([str(x) for x in neigh_feats]).encode())
hash_feat = hash_feat.hexdigest()
new_feats[node] = hash_feat
wl_features = wl_features + list(new_feats.values())
Expand Down
7 changes: 5 additions & 2 deletions cogdl/models/nn/gdc_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def add_args(parser):
parser.add_argument("--t", type=float, default=5.0)
parser.add_argument("--k", type=int, default=128)
parser.add_argument("--eps", type=float, default=0.01)
parser.add_argument("--dataset", default=None)
parser.add_argument("--gdc-type", default="ppr")
# fmt: on

Expand Down Expand Up @@ -76,8 +75,12 @@ def __init__(self, nfeat, nhid, nclass, dropout, alpha, t, k, eps, gdctype):
self.dropout = dropout

def forward(self, graph):
if self.data is None:
self.reset_data(graph)
graph = self.data
x = graph.x
graph.sym_norm()
if self.gdc_type == "none":
graph.sym_norm()

x = F.dropout(x, self.dropout, training=self.training)
x = F.relu(self.gc1(graph, x))
Expand Down
73 changes: 0 additions & 73 deletions cogdl/models/nn/grace.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ def __init__(
)
self.encoder = GraceEncoder(in_feats, hidden_size, num_layers, activation)

def augment(self, graph):
pass

def forward(
self,
graph: Graph,
Expand All @@ -95,76 +92,6 @@ def forward(
graph.sym_norm()
return self.encoder(graph, x)

def prop(
self,
graph: Graph,
x: torch.Tensor,
drop_feature_rate: float = 0.0,
drop_edge_rate: float = 0.0,
):
x = self.drop_feature(x, drop_feature_rate)
with graph.local_graph():
graph = self.drop_adj(graph, drop_edge_rate)
return self.forward(graph, x)

def contrastive_loss(self, z1: torch.Tensor, z2: torch.Tensor):
z1 = F.normalize(z1, p=2, dim=-1)
z2 = F.normalize(z2, p=2, dim=-1)

def score_func(emb1, emb2):
scores = torch.matmul(emb1, emb2.t())
scores = torch.exp(scores / self.tau)
return scores

intro_scores = score_func(z1, z1)
inter_scores = score_func(z1, z2)

_loss = -torch.log(intro_scores.diag() / (intro_scores.sum(1) - intro_scores.diag() + inter_scores.sum(1)))
return torch.mean(_loss)

def batched_loss(
self,
z1: torch.Tensor,
z2: torch.Tensor,
batch_size: int,
):
num_nodes = z1.shape[0]
num_batches = (num_nodes - 1) // batch_size + 1

losses = []
indices = torch.arange(num_nodes).to(z1.device)
for i in range(num_batches):
train_indices = indices[i * batch_size : (i + 1) * batch_size]
_loss = self.contrastive_loss(z1[train_indices], z2)
losses.append(_loss)
return sum(losses) / len(losses)

def embed(self, data):
pred = self.forward(data, data.x)
return pred

def drop_adj(self, graph: Graph, drop_rate: float = 0.5):
if drop_rate < 0.0 or drop_rate > 1.0:
raise ValueError("Dropout probability has to be between 0 and 1, " "but got {}".format(drop_rate))
if not self.training:
return graph

num_edges = graph.num_edges
mask = torch.full((num_edges,), 1 - drop_rate, dtype=torch.float)
mask = torch.bernoulli(mask).to(torch.bool)
row, col = graph.edge_index
row = row[mask]
col = col[mask]
edge_weight = graph.edge_weight[mask]
graph.edge_index = (row, col)
graph.edge_weight = edge_weight
return graph

def drop_feature(self, x: torch.Tensor, droprate: float):
n = x.shape[1]
drop_rates = torch.ones(n) * droprate
if self.training:
masks = torch.bernoulli(1.0 - drop_rates).view(1, -1).expand_as(x)
masks = masks.to(x.device)
x = masks * x
return x
Loading

0 comments on commit 48298bc

Please sign in to comment.