Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Oct 9, 2020
1 parent 3bd9fde commit f11eabc
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 48 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ In detail, the following methods are currently implemented:
* **[MetaPath2Vec](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.MetaPath2Vec)** from Dong *et al.*: [metapath2vec: Scalable Representation Learning for Heterogeneous Networks](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) (KDD 2017) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/metapath2vec.py)]
* **[Deep Graph Infomax](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.DeepGraphInfomax)** from Veličković *et al.*: [Deep Graph Infomax](https://arxiv.org/abs/1809.10341) (ICLR 2019) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/infomax.py)]
* All variants of **[Graph Autoencoders](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.GAE)** and **[Variational Autoencoders](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.VGAE)** from:
* [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308) from Kipf and Welling (NIPS-W 2016) [**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/autoencoder.py)
* [Adversarially Regularized Graph Autoencoder for Graph Embedding](https://arxiv.org/abs/1802.04407) from Pan *et al.* (IJCAI 2018) [**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/argva_node_clustering.py)
* [Keep It Simple: Graph Autoencoders Without Graph Convolutional Networks](https://arxiv.org/abs/1910.00942) from Salha *et al.* (NeurIPS-W 2019, then ECML-PKDD 2020 - [see here](https://arxiv.org/abs/2001.07614)) [**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/autoencoder.py)
* [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308) from Kipf and Welling (NIPS-W 2016) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/autoencoder.py)]
* [Adversarially Regularized Graph Autoencoder for Graph Embedding](https://arxiv.org/abs/1802.04407) from Pan *et al.* (IJCAI 2018) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/argva_node_clustering.py)]
* [Keep It Simple: Graph Autoencoders Without Graph Convolutional Networks](https://arxiv.org/abs/2001.07614) from Salha *et al.* (ECML 2020) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/autoencoder.py)]
* **[SEAL](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/seal_link_pred.py)** from Zhang and Chen: [Link Prediction Based on Graph Neural Networks](https://arxiv.org/pdf/1802.09691.pdf) (NeurIPS 2018)
* **[RENet](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.RENet)** from Jin *et al.*: [Recurrent Event Network for Reasoning over Temporal Knowledge Graphs](https://arxiv.org/abs/1904.05530) (ICLR-W 2019) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/renet.py)]
* **[GraphUNet](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.GraphUNet)** from Gao and Ji: [Graph U-Nets](https://arxiv.org/abs/1905.05178) (ICML 2019) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/graph_unet.py)]
Expand Down
101 changes: 56 additions & 45 deletions examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,76 +2,86 @@

import argparse
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE
from torch_geometric.utils import train_test_split_edges

torch.manual_seed(12345)

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='GAE')
parser.add_argument('--dataset', type=str, default='Cora')
parser.add_argument('--variational', action='store_true')
parser.add_argument('--linear', action='store_true')
parser.add_argument('--dataset', type=str, default='Cora',
choices=['Cora', 'CiteSeer', 'PubMed'])
parser.add_argument('--epochs', type=int, default=400)
args = parser.parse_args()
assert args.model in ['GAE', 'LinearGAE', 'VGAE', 'LinearVGAE']
assert args.dataset in ['Cora', 'CiteSeer', 'PubMed']
kwargs = {'GAE': GAE, 'LinearGAE': GAE, 'VGAE': VGAE, 'LinearVGAE': VGAE}

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
args.dataset)
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures())
data = dataset[0]
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)


# 2-layer GCN encoder from standard graph AE and VAE models: https://arxiv.org/abs/1611.07308
class GCNEncoder(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(GCNEncoder, self).__init__()
self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
if args.model in ['GAE']:
self.conv2 = GCNConv(2 * out_channels, out_channels, cached=True)
elif args.model in ['VGAE']:
self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
self.conv_logstd = GCNConv(2 * out_channels, out_channels,
cached=True)
self.conv2 = GCNConv(2 * out_channels, out_channels, cached=True)

def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
return self.conv2(x, edge_index)


class VariationalGCNEncoder(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(VariationalGCNEncoder, self).__init__()
self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
self.conv_logstd = GCNConv(2 * out_channels, out_channels, cached=True)

def forward(self, x, edge_index):
x = F.relu(self.conv1(x, edge_index))
if args.model in ['GAE']:
return self.conv2(x, edge_index)
elif args.model in ['VGAE']:
return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)


# Linear encoder from linear graph AE and VAE models: https://arxiv.org/abs/1910.00942
x = self.conv1(x, edge_index).relu()
return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)


class LinearEncoder(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(LinearEncoder, self).__init__()
if args.model in ['LinearGAE']:
self.conv1 = GCNConv(in_channels, out_channels, cached=True)
elif args.model in ['LinearVGAE']:
self.conv_mu = GCNConv(in_channels, out_channels, cached=True)
self.conv_logstd = GCNConv(in_channels, out_channels,
cached=True)
self.conv = GCNConv(in_channels, out_channels, cached=True)

def forward(self, x, edge_index):
if args.model in ['LinearGAE']:
return self.conv1(x, edge_index)
elif args.model in ['LinearVGAE']:
return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
return self.conv(x, edge_index)


class VariationalLinearEncoder(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(VariationalLinearEncoder, self).__init__()
self.conv_mu = GCNConv(in_channels, out_channels, cached=True)
self.conv_logstd = GCNConv(in_channels, out_channels, cached=True)

def forward(self, x, edge_index):
return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

channels = 16
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if args.model in ['LinearGAE','LinearVGAE']:
model = kwargs[args.model](LinearEncoder(dataset.num_features, channels)).to(dev)

out_channels = 16
num_features = dataset.num_features

if not args.variational:
if not args.linear:
model = GAE(GCNEncoder(num_features, out_channels))
else:
model = GAE(LinearEncoder(num_features, out_channels))
else:
model = kwargs[args.model](GCNEncoder(dataset.num_features, channels)).to(dev)
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
x, train_pos_edge_index = data.x.to(dev), data.train_pos_edge_index.to(dev)
if args.linear:
model = VGAE(VariationalLinearEncoder(num_features, out_channels))
else:
model = VGAE(VariationalGCNEncoder(num_features, out_channels))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
x = data.x.to(device)
train_pos_edge_index = data.train_pos_edge_index.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


Expand All @@ -80,10 +90,11 @@ def train():
optimizer.zero_grad()
z = model.encode(x, train_pos_edge_index)
loss = model.recon_loss(z, train_pos_edge_index)
if args.model in ['VGAE','LinearVGAE']:
if args.variational:
loss = loss + (1 / data.num_nodes) * model.kl_loss()
loss.backward()
optimizer.step()
return float(loss)


def test(pos_edge_index, neg_edge_index):
Expand All @@ -94,6 +105,6 @@ def test(pos_edge_index, neg_edge_index):


for epoch in range(1, args.epochs + 1):
train()
loss = train()
auc, ap = test(data.test_pos_edge_index, data.test_neg_edge_index)
print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

0 comments on commit f11eabc

Please sign in to comment.