In [1]:
import torch
import torch.nn as nn
from torch_geometric.nn.models import DeepGraphInfomax, GCN
from layer.readout import AvgReadout
from layer.corrupt import FeatureShuffle
from torch_geometric.datasets import CitationFull
from torch_geometric.utils import get_laplacian

import os.path as osp
import numpy as np

import GCL.augmentors as A
import copy

from aug.my_feature_masking import MyFeatureMasking

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `layer` and `aug` packages imported above) are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [20]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on CPU-only machines.
# NOTE(review): no visible cell moves the model or tensors to `device` —
# confirm whether `.to(device)` calls are intended.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Root directory handed to CitationFull: ~/datasets.
path = osp.join(osp.expanduser('~'), 'datasets')
dataset = CitationFull(path, name='Cora')

n_features = dataset.num_features
data = dataset[0]  # first entry of the dataset
num_nodes = data.x.shape[0]  # size of the first dimension of the feature matrix

# Targets for the two-column logits: column 0 is all ones, column 1 is all
# zeros, with one row per node.
lbl_1 = torch.ones(num_nodes, 1)
lbl_2 = torch.zeros(num_nodes, 1)
double_lbl = torch.cat([lbl_1, lbl_2], dim=1)

# Single augmentation-strength knob shared by every augmentor below
# (presumably a drop/mask probability — confirm against each augmentor).
aug_pe = 0.5
aug_mask = MyFeatureMasking(aug_pe)
aug_edge = A.EdgeRemoving(aug_pe)

# NOTE(review): walk_length is set to a fraction of the node count while
# num_seeds is fixed at 1000 — confirm these are the intended sampler settings.
subgraph_size = int((1 - aug_pe) * num_nodes)
aug_subgraph = A.RWSampling(num_seeds=1000, walk_length=subgraph_size)

# Use the shared knob instead of repeating the literal 0.5, so changing
# `aug_pe` keeps all augmentors in sync.
aug_node = A.NodeDropping(aug_pe)

In [10]:
# Hidden width used by both the encoder and the DeepGraphInfomax wrapper.
# Defined once so the two values cannot drift apart.
hidden_channels = 512

encoder = GCN(
    in_channels=n_features,
    hidden_channels=hidden_channels,
    num_layers=1,
    act='prelu',
)
summary = AvgReadout()
corruption = FeatureShuffle()
contrastiveModel = DeepGraphInfomax(
    hidden_channels=hidden_channels,
    encoder=encoder,
    summary=summary,
    corruption=corruption,
)

In [11]:
# Build two augmented views of the graph: view 1 via `aug_node`, view 2 via
# `aug_mask`. Each augmentor call is unpacked as a 3-tuple whose third element
# is discarded here.
x1, edge_index1, _ = aug_node(data.x, data.edge_index)
x2, edge_index2, _ = aug_mask(data.x, data.edge_index)
# Overwrite view 1's edge index with the (index, weight) pair returned by
# get_laplacian.
# NOTE(review): `edge_weight1` is not passed to the model in the visible
# forward call, and `x2` / `edge_index2` are not used by any visible cell —
# confirm this is intended.
edge_index1, edge_weight1 = get_laplacian(edge_index1)

In [12]:
# Forward pass on view 1; the result is unpacked into three values
# (presumably positive embeddings, corrupted/negative embeddings and the
# summary vector — TODO confirm).
# NOTE(review): the recorded output of this cell is
# "TypeError: BasicGNN.forward() takes 3 positional arguments but 4 were given",
# so `pz1`, `nz1` and `s1` were not assigned by this run. Presumably the
# encoder is being invoked with an extra positional argument somewhere inside
# the wrapper — check what `FeatureShuffle` returns and how the installed
# torch_geometric version forwards it to the encoder.
pz1, nz1, s1 = contrastiveModel(x=x1, edge_index=edge_index1)

TypeError: BasicGNN.forward() takes 3 positional arguments but 4 were given

In [87]:
# Split the three-element result into its parts and print each shape on its
# own line.
# NOTE(review): no visible cell assigns `out`; it must already exist in the
# kernel from an earlier (possibly deleted) cell — confirm its origin.
pz, nz, s = out
print(pz.shape, nz.shape, s.shape, sep='\n')

torch.Size([19793, 512])
torch.Size([19793, 512])
torch.Size([512])


In [88]:
def disc(summary_aug1, pos, neg, DGI):
    """Score two sets of embeddings against a single summary.

    Args:
        summary_aug1: Value forwarded as ``summary`` to ``DGI.discriminate``.
        pos: Embeddings scored first; their scores fill index 0 of dim 1.
        neg: Embeddings scored second; their scores fill index 1 of dim 1.
        DGI: Object exposing ``discriminate(z=..., summary=..., sigmoid=...)``.

    Returns:
        The two raw score tensors (requested with ``sigmoid=False``) joined
        along a newly inserted dimension 1.
    """
    scores = [
        DGI.discriminate(z=embeddings, summary=summary_aug1, sigmoid=False)
        for embeddings in (pos, neg)
    ]
    # Stacking on dim 1 inserts the new axis and concatenates along it.
    return torch.stack(scores, dim=1)

In [89]:
# Score the same positive/negative embeddings against each view's summary,
# then add the two logit tensors element-wise.
# NOTE(review): `s2` is not assigned in any visible cell, and the visible cell
# that assigns `s1` ended in a TypeError — confirm where both summaries come
# from.
double_logits1, double_logits2 = (
    disc(view_summary, pz, nz, contrastiveModel) for view_summary in (s1, s2)
)
double_logits = double_logits1 + double_logits2
print(double_logits.shape)

torch.Size([19793, 2])


In [94]:
# Binary cross-entropy computed directly on raw logits (the logits were
# requested with sigmoid=False), compared against the ones/zeros label matrix.
b_xent = nn.BCEWithLogitsLoss()
loss = b_xent(double_logits, double_lbl)

In [96]:
# Backpropagate the loss.
# NOTE(review): no optimizer, zero_grad() or step() appears in the visible
# cells — confirm whether a training loop is still to be added.
loss.backward()

In [17]:
def test(a, b, *, d):
    print(a)
    print(b)
    print(d)

In [19]:
# Call with two positional arguments and `d` by keyword, as required by the
# keyword-only marker in `test`'s signature.
test(1,2,d=3)

1
2
3
