Skip to content

Commit

Permalink
writing structure using LightningModule
Browse files Browse the repository at this point in the history
- Partially Fixed issue #5
- Not yet completed or tested
  • Loading branch information
ParthaPratimBanik committed Aug 20, 2023
1 parent 2a4f3c8 commit b1c798b
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 0 deletions.
21 changes: 21 additions & 0 deletions pygod_lightning/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# from .adone import AdONEBase
# from .anomalydae import AnomalyDAEBase
# from .cola import CoLABase
from .dominant import DOMINANTBase
# from .done import DONEBase
# from .gaan import GAANBase
# from .gae import GAEBase
# from .guide import GUIDEBase
# from .ocgnn import OCGNNBase
# from . import conv
# from . import decoder
# from . import encoder
# from . import functional

# __all__ = [
# "AdONEBase", "AnomalyDAEBase", "CoLABase", "DOMINANTBase", "DONEBase",
# "GAANBase", "GAEBase", "GUIDEBase", "OCGNNBase"
# ]
# Public API of this package: only the DOMINANT base model is exported for now.
__all__ = ["DOMINANTBase"]
154 changes: 154 additions & 0 deletions pygod_lightning/nn/dominant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import math
import torch
# import torch.nn as nn
from torch_geometric.nn import GCN
# from torch_geometric.utils import to_dense_adj

from .decoder import DotProductDecoder
from .functional import double_recon_loss

import lightning.pytorch as pl


class DOMINANTBase(pl.LightningModule):
    """
    Deep Anomaly Detection on Attributed Networks

    DOMINANT is an anomaly detector consisting of a shared graph
    convolutional encoder, a structure reconstruction decoder, and an
    attribute reconstruction decoder. The reconstruction mean squared
    error of the decoders are defined as structure anomaly score and
    attribute anomaly score, respectively.

    See :cite:`ding2019deep` for details.

    Parameters
    ----------
    in_dim : int
        Input dimension of model.
    hid_dim : int, optional
        Hidden dimension of model. Default: ``64``.
    num_layers : int, optional
        Total number of layers in model. A half (floor) of the layers
        are for the encoder, the other half (ceil) of the layers are
        for decoders. Default: ``4``.
    dropout : float, optional
        Dropout rate. Default: ``0.``.
    act : callable activation function or None, optional
        Activation function if not None.
        Default: ``torch.nn.functional.relu``.
    sigmoid_s : bool, optional
        Whether to apply sigmoid to the structure reconstruction.
        Default: ``False``.
    backbone : torch.nn.Module, optional
        The backbone of the deep detector implemented in PyG.
        Default: ``torch_geometric.nn.GCN``.
    weight_decay : float, optional
        Weight decay handed to the Adam optimizer in
        :meth:`configure_optimizers`. Default: ``0.``.
    lr : float, optional
        Learning rate handed to the Adam optimizer in
        :meth:`configure_optimizers`. Default: ``0.004``.
    weight : float, optional
        Balancing weight passed as the last argument of
        ``double_recon_loss`` in :meth:`training_step`.
        Default: ``0.5``.
    **kwargs : optional
        Additional arguments for the backbone.
    """

    def __init__(self,
                 in_dim,
                 hid_dim=64,
                 num_layers=4,
                 dropout=0.,
                 act=torch.nn.functional.relu,
                 sigmoid_s=False,
                 backbone=GCN,
                 weight_decay=0.,
                 lr=0.004,
                 weight=0.5,
                 **kwargs):
        super(DOMINANTBase, self).__init__()

        # split the number of layers for the encoder and decoders
        assert num_layers >= 2, \
            "Number of layers must be greater than or equal to 2."
        encoder_layers = math.floor(num_layers / 2)
        decoder_layers = math.ceil(num_layers / 2)

        self.shared_encoder = backbone(in_channels=in_dim,
                                       hidden_channels=hid_dim,
                                       num_layers=encoder_layers,
                                       out_channels=hid_dim,
                                       dropout=dropout,
                                       act=act,
                                       **kwargs)

        self.attr_decoder = backbone(in_channels=hid_dim,
                                     hidden_channels=hid_dim,
                                     num_layers=decoder_layers,
                                     out_channels=in_dim,
                                     dropout=dropout,
                                     act=act,
                                     **kwargs)

        self.struct_decoder = DotProductDecoder(in_dim=hid_dim,
                                                hid_dim=hid_dim,
                                                num_layers=decoder_layers - 1,
                                                dropout=dropout,
                                                act=act,
                                                sigmoid_s=sigmoid_s,
                                                backbone=backbone,
                                                **kwargs)

        # Optimisation / loss hyperparameters. ``lr`` and ``weight`` are
        # read by configure_optimizers() and training_step(); they must be
        # stored here, otherwise those hooks fail with AttributeError and
        # the values would leak into the backbone through **kwargs.
        self.weight_decay = weight_decay
        self.lr = lr
        self.weight = weight

        # Node embeddings produced by the most recent forward pass.
        self.emb = None

    def forward(self, x, edge_index):
        """
        Forward computation.

        Parameters
        ----------
        x : torch.Tensor
            Input attribute embeddings.
        edge_index : torch.Tensor
            Edge index.

        Returns
        -------
        x_ : torch.Tensor
            Reconstructed attribute embeddings.
        s_ : torch.Tensor
            Reconstructed adjacency matrix.
        """
        # encode feature matrix
        self.emb = self.shared_encoder(x, edge_index)

        # reconstruct feature matrix
        x_ = self.attr_decoder(self.emb, edge_index)

        # decode adjacency matrix
        s_ = self.struct_decoder(self.emb, edge_index)

        return x_, s_

    def training_step(self, batch, batch_idx):
        """
        Compute the reconstruction loss for one mini-batch.

        Parameters
        ----------
        batch : object
            Mini-batch exposing ``x``, ``s``, ``edge_index``,
            ``batch_size`` and ``n_id`` attributes.
            NOTE(review): presumably a PyG sampled sub-graph whose first
            ``batch_size`` nodes are the seed nodes and whose ``s`` is a
            dense adjacency matrix -- confirm against the data loader.
        batch_idx : int
            Index of the batch within the epoch (unused).

        Returns
        -------
        loss : torch.Tensor
            Mean of the per-node scores returned by
            ``double_recon_loss``.
        """
        batch_size = batch.batch_size
        node_idx = batch.n_id

        x = batch.x.to(self.device)
        s = batch.s.to(self.device)
        edge_index = batch.edge_index.to(self.device)

        x_, s_ = self(x, edge_index)

        # Only the first ``batch_size`` rows contribute to the loss; the
        # adjacency target is additionally restricted to the columns
        # listed in ``node_idx``.
        score = double_recon_loss(x[:batch_size],
                                  x_[:batch_size],
                                  s[:batch_size, node_idx],
                                  s_[:batch_size],
                                  self.weight)

        loss = torch.mean(score)

        # Pass batch_size explicitly so the epoch-level average is
        # weighted by the number of scored nodes instead of relying on
        # Lightning inferring it from the batch object.
        self.log("train_loss", loss, on_step=False, on_epoch=True,
                 prog_bar=True, logger=True, batch_size=batch_size)

        return loss

    def configure_optimizers(self):
        """
        Build the optimizer used for training.

        Returns
        -------
        optimizer : torch.optim.Adam
            Adam over all module parameters, using the stored ``lr`` and
            ``weight_decay``.
        """
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.weight_decay)
        return optimizer
60 changes: 60 additions & 0 deletions pygod_lightning/test/test_dominant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import tqdm
import torch
import argparse
from random import choice
from pygod.detector import *
from pygod.utils import load_data

from torch_geometric.seed import seed_everything
from pygod_lightning.dataset import DataSet
from pygod_lightning.nn import DOMINANTBase

def main(args):
    """Build the PyGOD and the Lightning DOMINANT models for comparison.

    Reads the module-level settings ``hid_dim``, ``dropout``, ``lr``,
    ``alpha``, ``weight_decay``, ``epoch``, ``gpu``, ``batch_size`` and
    ``num_neigh`` that the ``__main__`` section defines before calling
    this function.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; only ``args.dataset`` is read here.
    """
    # Draw every hyperparameter once so that both implementations are
    # configured identically; independent draws per model would make the
    # two results incomparable.
    hidden = choice(hid_dim)
    drop = choice(dropout)
    rate = choice(lr)
    balance = choice(alpha)

    data = load_data(args.dataset)

    ## Checking Training Result on pygodv100 implementation
    model = DOMINANT(hid_dim=hidden,
                     weight_decay=weight_decay,
                     dropout=drop,
                     lr=rate,
                     epoch=epoch,
                     gpu=gpu,
                     weight=balance,
                     batch_size=batch_size,
                     num_neigh=num_neigh)
    model.fit(data)
    # score = model.decision_score_

    ## Checking Training Result on pytorch.lightning implementation
    # DOMINANTBase requires ``in_dim``.
    # NOTE(review): assumes data.x is a (num_nodes, num_features) tensor
    # -- confirm against load_data.
    # epoch, gpu, batch_size and num_neigh are not parameters of
    # DOMINANTBase.__init__; passing them would forward them to the
    # backbone through **kwargs, so they are left out here.
    modelPL = DOMINANTBase(in_dim=data.x.shape[1],
                           hid_dim=hidden,
                           weight_decay=weight_decay,
                           dropout=drop)
    # configure_optimizers() reads ``self.lr`` and training_step() reads
    # ``self.weight``; set both on the instance so those hooks can run.
    modelPL.lr = rate
    modelPL.weight = balance
    # TODO: the Lightning model is only constructed; it is not trained yet.


if __name__ == '__main__':
    # Command-line interface: device index and dataset name.
    parser = argparse.ArgumentParser()
    # The help text previously advertised a default of -1 while the actual
    # default is 0; the message now states the real default.
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU Index. Default: 0. Use -1 for CPU.")
    parser.add_argument("--dataset", type=str, default='inj_cora',
                        help="supported dataset: [inj_cora, inj_amazon, "
                             "inj_flickr, weibo, reddit, disney, books, "
                             "enron]. Default: inj_cora")
    args = parser.parse_args()

    ## Declare the parameters
    # These names are module-level globals that main() reads directly.
    # The list-valued ones are candidate pools sampled with random.choice.
    dropout = [0, 0.1, 0.3]
    lr = [0.1, 0.05, 0.01]
    weight_decay = 0.01
    batch_size = 0
    num_neigh = -1
    epoch = 300
    gpu = args.gpu
    hid_dim = [32, 64, 128, 256]
    alpha = [0.8, 0.5, 0.2]

    main(args)

0 comments on commit b1c798b

Please sign in to comment.