In [1]:
cd /scratch/aqd215/k-gnn/examples/

/scratch/aqd215/k-gnn/examples


In [2]:
import os.path as osp

import argparse
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_scatter import scatter_mean
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
from torch_geometric.nn import NNConv
from k_gnn import GraphConv, DataLoader, avg_pool
from k_gnn import ConnectedThreeMalkin
import sys

In [3]:
class MyFilter(object):
    """Pre-filter that keeps only graphs with more than six nodes.

    Returns True (keep) when ``data.num_nodes > 6``. Graphs with 6 or fewer
    nodes are dropped, including graphs with exactly 6 nodes. The old inline
    comment ("less than 6 nodes") understated this by one.
    """

    def __call__(self, data):
        # Strict '>': a graph with exactly 6 nodes is rejected as well.
        return data.num_nodes > 6


class MyPreTransform(object):
    """Pre-transform that attaches connected 3-set structure to each graph.

    ``ConnectedThreeMalkin`` runs while ``data.x`` is temporarily narrowed to
    its first five columns, so it sees only those columns through ``data.x``.
    The full feature matrix is then restored on whatever object the
    transform returns, and that object is returned.
    """

    def __call__(self, data):
        full_x = data.x
        # NOTE(review): columns 0-4 are presumably the QM9 atom-type one-hot;
        # confirm against the dataset's feature layout.
        data.x = full_x[:, :5]
        result = ConnectedThreeMalkin()(data)
        result.x = full_x
        return result


class MyTransform(object):
    """Per-sample transform that keeps a single regression target.

    Replaces ``data.y`` (indexed as 2-D: rows x target columns) with its
    column ``target``, giving a 1-D tensor with one value per row.

    Args:
        target: column of ``data.y`` to keep. The default 0 (mu) matches the
            previously hard-coded choice, so ``MyTransform()`` behaves as before.
    """

    def __init__(self, target=0):
        self.target = target

    def __call__(self, data):
        data.y = data.y[:, self.target]
        return data


In [4]:
# Regression target index into QM9's `y` columns (0 = dipole moment mu).
# NOTE(review): `target` is not passed on. MyTransform() below is built with
# no argument and selects column 0 itself, so keep the two in sync.
target = 0

# Dataset root: ../data/1-3-QM9 relative to the *working directory*, which the
# first cell sets to the k-gnn examples folder.
# The previous osp.realpath("__file__") resolved a literal file named
# "__file__" inside the cwd, so its dirname was always just the cwd.
path = osp.join(osp.realpath(osp.curdir), '..', 'data', '1-3-QM9')

# MyTransform picks the target column. T.Distance() appends an edge-distance
# feature (edge_attr goes from 4 to 5 columns in the printed batch).
dataset = QM9(
    path,
    transform=T.Compose([MyTransform(), T.Distance()]),
    pre_transform=MyPreTransform(),
    pre_filter=MyFilter())

In [5]:
# Inspect the collated storage of the processed dataset (fields concatenated over
# all graphs), including the 3-set fields (assignment_index_3, edge_index_3,
# iso_type_3), presumably added by ConnectedThreeMalkin in the pre-transform.
dataset.data

Data(assignment_index_3=[2, 13144659], edge_attr=[4823298, 4], edge_index=[2, 4823298], edge_index_3=[2, 1244888], idx=[129410], iso_type_3=[4381553], name=[129410], pos=[2333506, 3], x=[2333506, 13], y=[129410, 19], z=[2333506])

In [8]:
# Raw isomorphism-type ids as stored, alongside their total count.
(dataset.data.iso_type_3, dataset.data.iso_type_3.shape)

(tensor([  6,   6,   6,  ..., 131, 131, 126]), torch.Size([4381553]))

In [13]:
# Distinct raw isomorphism-type ids, sorted ascending. They are sparse, not 0..K-1.
dataset.data.iso_type_3.unique()

tensor([  1,   2,   6,   7,   8,  12,  13,  31,  32,  33,  34,  37,  38,  39,
         43,  44,  49,  62,  63,  68, 126, 127, 131, 132, 133, 137, 138, 156,
        157, 158, 159, 162, 163, 164, 168, 169, 174, 187, 188, 193])

In [14]:
# Sorted distinct ids plus, for every entry, its position in that sorted list.
# Those positions form a dense 0..K-1 re-indexing of the raw ids.
dataset.data.iso_type_3.unique(sorted=True, return_inverse=True)

(tensor([  1,   2,   6,   7,   8,  12,  13,  31,  32,  33,  34,  37,  38,  39,
          43,  44,  49,  62,  63,  68, 126, 127, 131, 132, 133, 137, 138, 156,
         157, 158, 159, 162, 163, 164, 168, 169, 174, 187, 188, 193]),
 tensor([ 2,  2,  2,  ..., 22, 22, 20]))

In [10]:
# Dense 0..K-1 re-indexing of every raw iso_type_3 entry, and its size.
# torch.unique is computed once and reused; previously it ran twice over all
# ~4.4M entries just to display the tensor and its size.
iso_type_3_inverse = torch.unique(
    dataset.data.iso_type_3, sorted=True, return_inverse=True)[1]
iso_type_3_inverse, iso_type_3_inverse.size()

(tensor([ 2,  2,  2,  ..., 22, 22, 20]), torch.Size([4381553]))

In [6]:
# Re-index the raw isomorphism-type ids (sparse values such as 1, 2, 6, ..., 193)
# to dense ids 0..K-1 and one-hot encode them as float features.
# Guarded on dim() so re-running this cell is a no-op. Without the guard,
# torch.unique flattens the already one-hot [N, K] float tensor, and one_hot
# then silently produces an [N, K, 2] tensor.
if dataset.data.iso_type_3.dim() == 1:
    iso_type_3_ids, iso_type_3_dense = torch.unique(
        dataset.data.iso_type_3, sorted=True, return_inverse=True)
    # Number of distinct ids, which equals the old `inverse.max() + 1`.
    num_i_3 = iso_type_3_ids.numel()
    dataset.data.iso_type_3 = F.one_hot(
        iso_type_3_dense, num_classes=num_i_3).to(torch.float)
else:
    # Already encoded by an earlier run: recover the class count from the width.
    num_i_3 = dataset.data.iso_type_3.size(1)

In [7]:
# One graph per split (indices 0, 1, 2 for train/val/test), so each loader
# yields a single batch that is small enough to print in full.
# NOTE(review): these are debugging subsets, not a real train/val/test split.
train_loader, val_loader, test_loader = (
    DataLoader(dataset[i:i + 1], batch_size=64) for i in range(3))

In [8]:
# Print each test batch: its summary, then node features, edge features,
# connectivity and target, each preceded by " \n " (same layout as before).
for data in test_loader:
    print(data, data.x, data.edge_attr, data.edge_index, data.y, sep=' \n ')

Batch(assignment_index_3=[2, 27], batch=[7], batch_3=[9], edge_attr=[12, 5], edge_index=[2, 12], edge_index_3=[2, 0], idx=[1], iso_type_3=[9, 40], name=[1], pos=[7, 3], x=[7, 13], y=[1], z=[7]) 
 tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 0., 1., 3.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 1., 0., 8., 1., 0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]]) 
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 1.0000],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.7237],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.7275],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.7275],
        [1.0000, 0.0000, 0.0000, 0.0000, 1.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.7984],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.7396],
        [0.0000, 