In [61]:
import torch
import torch.nn as nn
from torch_geometric.nn.models import DeepGraphInfomax, GCN
from layer.readout import AvgReadout
from layer.corrupt import FeatureShuffle
from torch_geometric.datasets import CitationFull
from torch_geometric.utils import get_laplacian, remove_self_loops

import os.path as osp
import numpy as np

import GCL.augmentors as A
import copy

from aug.my_feature_masking import MyFeatureMasking
from layer.MyGCN import MyGCN

In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
hid_units = 512
seed = 42  # fixes the train/val/test split and (downstream) model initialisation

torch.manual_seed(seed)

# load dataset
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
path = osp.join(osp.expanduser('~'), 'datasets')
dataset = CitationFull(path, name='Cora')
assert len(dataset) == 1, "Expecting node classification on one huge graph"

data = dataset[0]
x = data.x
y = data.y
edge_index = data.edge_index
n_features = dataset.num_features
num_classes = dataset.num_classes
num_nodes = x.shape[0]

# Discriminator targets, one row per node: column 0 = 1 (positive), column 1 = 0 (negative).
lbl_1 = torch.ones(num_nodes, 1)
lbl_2 = torch.zeros(num_nodes, 1)
double_lbl = torch.cat((lbl_1, lbl_2), 1)

# A dedicated seeded generator keeps the 80/0/20 split reproducible across kernel restarts
# and independent of any other consumer of the global RNG.
split_generator = torch.Generator().manual_seed(seed)
train_set, val_set, test_set = torch.utils.data.random_split(
    range(num_nodes), [0.8, 0.0, 0.2], generator=split_generator)
# dtype=torch.long: the validation split is empty, and torch.tensor([]) would otherwise be a
# float tensor, which cannot be used to index node features/labels.
idx_train = torch.tensor(train_set.indices, dtype=torch.long)
idx_val = torch.tensor(val_set.indices, dtype=torch.long)
idx_test = torch.tensor(test_set.indices, dtype=torch.long)

# NOTE(review): presumably a GCL edge-dropping augmentor with p=0.05 — it is not used in the visible cells.
augmentor = A.EdgeRemoving(0.05)
# Replaces the raw connectivity with the symmetric normalized Laplacian I - D^{-1/2} A D^{-1/2}
# (self loops with weight 1, negative off-diagonal weights). The raw graph stays in `data.edge_index`.
edge_index, edge_weight = get_laplacian(edge_index, normalization = 'sym')

In [36]:
# One-layer project-local GCN encoder wrapped in PyG's DeepGraphInfomax.
# DGI's `hidden_channels` sizes its (hidden x hidden) discriminator weight and must match the
# encoder's output width (presumably `hidden_channels` for MyGCN — TODO confirm). Both use the
# `hid_units` constant so they cannot drift apart.
encoder = MyGCN(in_channels=n_features, hidden_channels=hid_units, num_layers=1, act='prelu')
summary = AvgReadout()
corruption = FeatureShuffle()
model = DeepGraphInfomax(hidden_channels=hid_units,
                         encoder=encoder,
                         summary=summary,
                         corruption=corruption)

In [47]:
# Build two augmented views of the graph.
# FIXME: `aug_node` and `aug_mask` are not defined anywhere in this notebook (only `augmentor` and
# the imported `MyFeatureMasking` are). Define them in an earlier cell, or this cell raises
# NameError on a fresh kernel.
x1, edge_index1, _ = aug_node(data.x, data.edge_index)
x2, edge_index2, _ = aug_mask(data.x, data.edge_index)
# Give both views the same symmetric-Laplacian preprocessing as the full graph in the data-loading
# cell, so the two views feed comparable edge weights to the shared encoder.
edge_index1, edge_weight1 = get_laplacian(edge_index1, normalization = 'sym')
edge_index2, edge_weight2 = get_laplacian(edge_index2, normalization = 'sym')

In [48]:
# Sanity check: run the encoder's inner `gcn` module directly on `x` / `edge_index`.
# NOTE(review): no `edge_weight` is passed here, so the Laplacian weights computed earlier are not
# handed to this call — confirm whether `encoder.gcn` is meant to receive them.
encoder.gcn(x, edge_index)

tensor([[-0.1439,  0.0794, -0.0705,  ...,  0.0196, -0.1589, -0.0187],
        [-0.0268,  0.0704, -0.0988,  ..., -0.0169,  0.0222,  0.0472],
        [-0.0268,  0.0704, -0.0988,  ..., -0.0169,  0.0222,  0.0472],
        ...,
        [-0.0334, -0.0432,  0.0212,  ..., -0.0057, -0.0629, -0.0028],
        [-0.0912,  0.0022, -0.0392,  ..., -0.0208, -0.1000,  0.0148],
        [ 0.0209, -0.0495, -0.0114,  ...,  0.0036, -0.0212,  0.0468]],
       grad_fn=<AddBackward0>)

In [63]:
# Smallest edge weight, as a quick sign/scale sanity check.
# NOTE(review): in top-to-bottom order this runs before the cell that negates the weights and
# strips self loops, so on Restart & Run All it sees raw Laplacian weights (negative
# off-diagonals). The saved output was produced by running it after that cell.
torch.min(edge_weight)

tensor(0.0065)

In [62]:
# Recompute the Laplacian from the untouched `data.edge_index` so this cell is idempotent.
# Previously, re-running it negated an already-negated `edge_weight`.
lap_edge_index, lap_edge_weight = get_laplacian(data.edge_index, normalization='sym')
# NOTE(review): the encoder is fed the Laplacian weights themselves (self loops = 1, negative
# off-diagonals), not the normalized adjacency built below — confirm this is intended.
pos_z = model.encoder(x, lap_edge_index, edge_weight=lap_edge_weight)
# Off the diagonal, -L_sym = D^{-1/2} A D^{-1/2}. Negating and dropping the self loops leaves the
# symmetric-normalized adjacency (positive weights) in `edge_index` / `edge_weight`.
edge_index, edge_weight = remove_self_loops(lap_edge_index, -lap_edge_weight)

In [64]:
# Reference PyG GCN with the same width as the DGI encoder, run on the normalized adjacency.
# NOTE(review): GCNConv defaults to normalize=True / add_self_loops=True, so the pre-normalized
# `edge_weight` is normalized again (and self loops re-added) here — confirm that is intended.
gcn = GCN(in_channels=n_features, hidden_channels=hid_units, num_layers=1, act='prelu')
gcn(x, edge_index, edge_weight = edge_weight)

tensor([[ 0.0902,  0.0316, -0.0217,  ...,  0.0839, -0.0078,  0.0388],
        [ 0.1330,  0.0760, -0.0780,  ...,  0.0748,  0.0993,  0.1501],
        [ 0.1330,  0.0760, -0.0780,  ...,  0.0748,  0.0993,  0.1501],
        ...,
        [-0.0965, -0.0693,  0.0226,  ...,  0.0684,  0.1946,  0.0115],
        [-0.0166,  0.0364,  0.0570,  ..., -0.0003, -0.0682,  0.0269],
        [-0.0053, -0.0453, -0.1507,  ..., -0.0905, -0.0432,  0.0371]],
       grad_fn=<AddBackward0>)

In [87]:
# Unpack the model output — presumably (positive embeddings, negative embeddings, summary) as
# returned by a DeepGraphInfomax forward pass — and show each shape on its own line.
# FIXME: `out` is never assigned in this notebook (it presumably comes from a deleted cell);
# define it explicitly so the notebook survives Restart & Run All.
pz, nz, s = out
print(*(t.shape for t in (pz, nz, s)), sep='\n')

torch.Size([19793, 512])
torch.Size([19793, 512])
torch.Size([512])


In [88]:
def disc(summary_aug1, pos, neg, DGI):
    """Score positive and negative embeddings against a single summary vector.

    Args:
        summary_aug1: summary vector handed to ``DGI.discriminate``.
        pos: positive node embeddings.
        neg: negative (corrupted) node embeddings.
        DGI: object exposing ``discriminate(z, summary, sigmoid)``, e.g. PyG's DeepGraphInfomax.

    Returns:
        Raw (pre-sigmoid) logits stacked along dim 1: index 0 scores ``pos``, index 1 scores ``neg``.
    """
    scores = [DGI.discriminate(z=emb, summary=summary_aug1, sigmoid=False) for emb in (pos, neg)]
    return torch.stack(scores, dim=1)

In [89]:
# Score the same positive/negative embeddings against each view's summary, then add the two logit tables.
# FIXME: `s1`, `s2` and `contrastiveModel` are not defined in this notebook (presumably from a
# deleted cell); define them explicitly so the notebook survives Restart & Run All.
# NOTE(review): both summaries reuse the same `pz` / `nz` — confirm per-view embeddings aren't intended.
double_logits1, double_logits2 = (
    disc(view_summary, pz, nz, contrastiveModel) for view_summary in (s1, s2)
)
double_logits = double_logits1 + double_logits2
print(double_logits.shape)

torch.Size([19793, 2])


In [94]:
# BCE on raw logits: each row is (positive logit, negative logit), scored against targets (1, 0).
b_xent = nn.BCEWithLogitsLoss(reduction='mean')
loss = b_xent(input=double_logits, target=double_lbl)

In [96]:
# Backpropagate the DGI loss; gradients accumulate into each parameter's `.grad`.
# Re-running this cell raises, because the graph is freed after the first backward (retain_graph=False).
torch.autograd.backward(loss)

In [17]:
def test(a, b, *, d):
    print(a)
    print(b)
    print(d)

In [19]:
# Scratch check of the keyword-only parameter: `d` must be passed by name.
test(a=1, b=2, d=3)

1
2
3
