In [364]:
from torch_geometric.data import InMemoryDataset
import pickle
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_undirected
from typing import Union, Tuple
from torch_geometric.typing import OptPairTensor, Adj, Size
from typing import List, Optional, Set, get_type_hints
from torch_scatter import gather_csr, scatter, segment_csr

from torch import Tensor
from torch.nn import Linear
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


import numpy as np
from collections import defaultdict

In [249]:
class MyConv(MessagePassing):
    def __init__(self, **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', 'add')
        super(MyConv, self).__init__(**kwargs)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, deg: Tensor,
                size: Size = None) -> Tensor:
        
        x = torch.cat((x, deg.view(-1,1)), dim=1)
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)      
        out = self.propagate(edge_index, x=x, size=size)
        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j
    
    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.

        By default, this function will delegate its call to scatter functions
        that support "add", "mean" and "max" operations as specified in
        :meth:`__init__` by the :obj:`aggr` argument.
        """
        if ptr is not None:
            ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
            return segment_csr(inputs, ptr, reduce=self.aggr)
        else:
            print(inputs, index)
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=self.aggr)
        
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

In [None]:
class GENConv(MessagePassing):
    def __init__(self):
        kwargs.setdefault('aggr', None)
        super(GENConv, self).__init__(**kwargs)

    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x=x, size=size)
        if self.msg_norm is not None:
            out = self.msg_norm(x[0], out)
        x_r = x[1]
        if x_r is not None:
            out += x_r
        return self.mlp(out)

    def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
        return F.relu(msg) + self.eps

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None) -> Tensor:

        out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
        return scatter(inputs * out, index, dim=self.node_dim,
                           dim_size=dim_size, reduce='sum')

#         elif self.aggr == 'softmax_sg':
#             out = scatter_softmax(inputs * self.t, index,
#                                   dim=self.node_dim).detach()
#             return scatter(inputs * out, index, dim=self.node_dim,
#                            dim_size=dim_size, reduce='sum')

#         else:
#             min_value, max_value = 1e-7, 1e1
#             torch.clamp_(inputs, min_value, max_value)
#             out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
#                           dim_size=dim_size, reduce='mean')
#             torch.clamp_(out, min_value, max_value)
#             return torch.pow(out, 1 / self.p)

In [250]:
conv = MyConv()

In [127]:
el = [(0,13),(0,5),(0,14),(13,1),(0,18),(0,8),(14,7),(0,11),(0,3),
         (3,10),(0,19),(11,4),(1,9),(9,12),(0,12),(3,7),(0,16),(11,8),
         (1,3),(1,6),(0,13),(14,13),(11,15),(0,12),(4,17),(11,12),(3,2)]
ei = torch.tensor(el).T

In [300]:
ei_dict = defaultdict(set)
for e in ei.T:
    e = e.to(dtype=torch.long)
    ei_dict[e[0]].add(e[1]) 
    ei_dict[e[1]].add(e[0]) 

In [305]:
b = {torch.tensor(0): {torch.tensor(13)}}
b

{tensor(0): {tensor(13)}}

In [147]:
oh = F.one_hot(ei[0],num_classes=20)+F.one_hot(ei[1],num_classes=20)
oh

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0,

In [365]:
# torch.cumsum(oh, dim=0)


In [254]:
row, col = ei
adj_t = SparseTensor(row=col, col=row,
                     sparse_sizes=(data.num_nodes, data.num_nodes))

In [259]:
adj_t

SparseTensor(row=tensor([ 1,  2,  3,  3,  4,  5,  6,  7,  7,  8,  8,  9, 10, 11, 12, 12, 12, 12,
                           13, 13, 13, 14, 15, 16, 17, 18, 19]),
             col=tensor([13,  3,  0,  1, 11,  0,  1,  3, 14,  0, 11,  1,  3,  0,  0,  0,  9, 11,
                            0,  0, 14,  0, 11,  0,  4,  0,  0]),
             size=(275, 275), nnz=27, density=0.04%)

In [258]:
adj_t[3]

SparseTensor(row=tensor([0, 0]),
             col=tensor([0, 1]),
             size=(1, 275), nnz=2, density=0.73%)

In [262]:
delta_v = torch.ones(10, 8)
delta_v

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
torch.zeros(16, 8, dtype=src.dtype).scatter_add_(0, index, src)


In [252]:
_, dg = torch.unique(ei, return_counts = True)
dg

tensor([12,  4,  1,  5,  2,  1,  1,  2,  2,  2,  1,  5,  4,  4,  3,  1,  1,  1,
         1,  1])

In [253]:
xx = torch.rand(20,8)
y = conv(x = xx, edge_index = ei, deg = torch.ones(20))

tensor([[0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.9967, 0.6499, 0.4298, 0.2115, 0.7027, 0.8352, 0.6765, 0.0435, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.7772, 0.1651, 0.0392, 0.4880, 0.2471, 0.5597, 0.5317, 0.4280, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.0049, 0.0289, 0.7177, 0.0580, 0.8168, 0.5887, 0.8685, 0.2321, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.7230, 0.3680, 0.4621, 0.3198, 0.0290, 0.2466, 0.4086, 0.0826, 1.0000],
        [0.7480,

In [10]:
class RWDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(RWDataset, self).__init__(root, transform, pre_transform)
        self.pth = '/data/egor/graph_generation/graph_generation/data/rw/cora/graphs/'
        self.data, self.slices = torch.load(self.processed_paths[0])
        self.features = torch.load(self.pth+'cora_x.pt')
        
    @property
    def raw_file_names(self):
        return []
    
    @property
    def processed_file_names(self):
        return ['/data/egor/graph_generation/graph_generation/data/rw/cora//cora.dataset']

    def download(self):
        pass
    
    def process(self):
        self.pth = '/data/egor/graph_generation/graph_generation/data/rw/cora/graphs/'

        data_list = []

        edge_list_array = []
        
        for ig in range(14538):
            with open(self.pth + 'graph' + str(ig) + '.dat', 'rb') as f:        
                G = pickle.load(f)
                x, ei = torch.unique(torch.tensor(list(G.edges)).T, return_inverse  = True)
                ei = to_undirected(ei)
#                 data_list.append(Data(x = x, edge_index = ei))
                
                c,r = ei
                data_list.append(Data(x = x, edge_index = ei[:,c<r]))
        
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [396]:
BSIZE = 16
dd = RWDataset('')
train_loader = DataLoader(dd, batch_size=BSIZE, shuffle=False)# , exclude_keys=['x']

In [413]:
def data_gen_e_aug(train_loader, slices, batch_size = 4, step = 1):
    for ib, data in enumerate(train_loader):
#         r,c = data.edge_index
#         print(data.edge_index)
#         print(data.edge_index[:,r>c].shape)
        e_ptr = slices[ib*batch_size:(ib+1)*batch_size+1]
        e_ptr = e_ptr - e_ptr[0]
        szs = e_ptr[1:]-e_ptr[:-1]
        e_ind_start = e_ptr[1:]-szs.min()+1
        visited_e = torch.full((e_ptr[-1],), False, dtype = torch.bool)
        for i in range(e_ind_start.shape[0]):
            visited_e[e_ptr[:-1][i] :e_ind_start[i]] = True # setting emask True for edges in graph
        edges_num = torch.arange(e_ptr[-1])

        visited_v = torch.full((data.ptr[-1],), False)
        visited_v[torch.unique(data.edge_index[:, visited_e])] = True
        
        vert_ind = []
        last_ind_v = torch.arange(data.ptr[-1])
        last_ind_max = data.ptr[-1].item()        
        ei_dict = defaultdict(set)
        for e in data.edge_index[:, visited_e].T:
            e = e.to(dtype=torch.long)
            ei_dict[e[0].item()].add(e[1].item()) 
            ei_dict[e[1].item()].add(e[0].item()) 
        
        for i in range(data.edge_index.shape[1]): # max number of iterations, usually we stop earlier
            if torch.all(visited_e):
                break           
                  
            e1_mask = (visited_v[data.edge_index[0]] | visited_v[data.edge_index[1]]) & ~visited_e # Source in graph
            nnedges = edges_num[e1_mask]
            e1_ind = []
            for j in range(1, e_ptr.shape[0]):
                mmask = (nnedges < e_ptr[j]) & (nnedges >= e_ptr[j-1])
                e1_ind.append(nnedges[mmask][:step])         
            e1_ind = torch.cat(e1_ind)
            edges_1 = data.edge_index[:, e1_ind]
                        
            for e in edges_1.T: 
                for iv in (True,False):
                    ind = last_ind_v[e[int(iv)]].item()
                    if ind in ei_dict.keys():
                        ind = last_ind_max
                        vert_ind.append(e[int(iv)])
                        last_ind_v[e[int(iv)]] = ind
                        ei_dict[ind] = ei_dict[e[int(iv)].item()]   
                        last_ind_max += 1                        
                    ei_dict[ind].add(last_ind_v[e[int(~iv)]].item())
        
            # selecting source-target when both vertices in graph            
            visited_e[e1_ind] = True  
            visited_v[edges_1.view(-1)] = True
        
        edge_index = []
        for k,v in ei_dict.items():
            e = torch.tensor(list(v)).view(1,-1)
            edge_index.append(torch.cat((e, torch.full_like(e, k)), dim=0))
        vert_index = torch.cat((torch.arange(data.ptr[-1]),torch.Tensor(vert_ind))).to(dtype=torch.long)

        yield  data.x[vert_index],\
                torch.cat(edge_index, dim=1)
        
        
itt = iter(data_gen_e_aug(train_loader, dd.slices['edge_index'], batch_size = BSIZE,step = 1))
vi, ei = next(itt)

In [414]:
vi

tensor([  15,   62,  121,  ..., 1906, 2355, 2342])

In [264]:
def data_gen_edges(train_loader, slices, batch_size = 4, step = 1):
    for ib, data in enumerate(train_loader):
#         print(data)
        e_ptr = slices[ib*batch_size:(ib+1)*batch_size+1]
        e_ptr = e_ptr - e_ptr[0]
#         print(ib, e_ptr)

        szs = e_ptr[1:]-e_ptr[:-1]
        e_ind_start = e_ptr[1:]-szs.min()
        e_mask = torch.full((e_ptr[-1],), False, dtype = torch.bool)
        for i in range(e_ind_start.shape[0]):
            e_mask[e_ptr[:-1][i] :e_ind_start[i]] = True # setting emask True for edges in graph

        for i in range(szs.min()-1): #growing the graph
            new_ei = e_ind_start + i
            e_mask[new_ei] = True  
            graph = data.edge_index[:, e_mask]
            e_1 = data.edge_index[:, new_ei]
            print(graph.shape, e_1.shape)
        yield graph, e_1
        
#                 v_add_1_mask.view((batch_size, data.num_nodes)).sum(dim=1),\
#                 torch.cat((v_exp_0,v_exp_1)),\
#                 torch.cat((v_add_0,v_add_1)), \
#                 torch.cat((e_0,e_1), dim=2).reshape((2,-1)).to(dtype = torch.long),\
#                 visited_mask.view((batch_size, data.num_nodes))        
#     if ib>5:
#         break
itt = iter(data_gen_edges(train_loader, dd.slices['edge_index'], batch_size = BSIZE))
g, e1 = next(itt)

torch.Size([2, 83]) torch.Size([2, 4])
torch.Size([2, 87]) torch.Size([2, 4])
torch.Size([2, 91]) torch.Size([2, 4])
torch.Size([2, 95]) torch.Size([2, 4])
torch.Size([2, 99]) torch.Size([2, 4])
torch.Size([2, 103]) torch.Size([2, 4])
torch.Size([2, 107]) torch.Size([2, 4])
torch.Size([2, 111]) torch.Size([2, 4])
torch.Size([2, 115]) torch.Size([2, 4])
torch.Size([2, 119]) torch.Size([2, 4])
torch.Size([2, 123]) torch.Size([2, 4])
torch.Size([2, 127]) torch.Size([2, 4])
torch.Size([2, 131]) torch.Size([2, 4])
torch.Size([2, 135]) torch.Size([2, 4])
torch.Size([2, 139]) torch.Size([2, 4])
torch.Size([2, 143]) torch.Size([2, 4])
torch.Size([2, 147]) torch.Size([2, 4])
torch.Size([2, 151]) torch.Size([2, 4])
torch.Size([2, 155]) torch.Size([2, 4])
torch.Size([2, 159]) torch.Size([2, 4])
torch.Size([2, 163]) torch.Size([2, 4])
torch.Size([2, 167]) torch.Size([2, 4])
torch.Size([2, 171]) torch.Size([2, 4])
torch.Size([2, 175]) torch.Size([2, 4])
torch.Size([2, 179]) torch.Size([2, 4])
torch

In [None]:
def data_gen_e_connected(train_loader, slices, batch_size = 4, step = 1):
    for ib, data in enumerate(train_loader):
        e_ptr = slices[ib*batch_size:(ib+1)*batch_size+1]
        e_ptr = e_ptr - e_ptr[0]
        szs = e_ptr[1:]-e_ptr[:-1]
        e_ind_start = e_ptr[1:]-szs.min()+1
        visited_e = torch.full((e_ptr[-1],), False, dtype = torch.bool)
        for i in range(e_ind_start.shape[0]):
            visited_e[e_ptr[:-1][i] :e_ind_start[i]] = True # setting emask True for edges in graph
        edges_num = torch.arange(e_ptr[-1])

        visited_v = torch.full((data.ptr[-1],), False)
        visited_v[torch.unique(data.edge_index[:, visited_e])] = True
        
        last_ind_v = torch.arange(data.ptr[-1])
        print(last_ind_v)
                
#         for i in range(szs.min()-1): #growing the graph
        for i in range(data.edge_index.shape[1]):
            if torch.all(visited_e):
                break           
            ei_graph = data.edge_index[:, visited_e]
            ei_graph = torch.cat((ei_graph,torch.flip(ei_graph, [0])), dim=1)
                                  
            e1_mask = (visited_v[data.edge_index[0]] | visited_v[data.edge_index[1]]) & ~visited_e # Source in graph
            nnedges = edges_num[e1_mask]
            e1_ind = []
            for j in range(1, e_ptr.shape[0]):
                mmask = (nnedges < e_ptr[j]) & (nnedges >= e_ptr[j-1])
                e1_ind.append(nnedges[mmask][:step])         
            e1_ind = torch.cat(e1_ind)
            edges_1 = data.edge_index[:, e1_ind]
            
            print(edges_1)
                        
            # selecting source-target when both vertices in graph            
            edges_1 = torch.where(visited_v[edges_1[0]], edges_1,torch.flip(edges_1, [0]))            
            visited_e[e1_ind] = True  
            visited_v[edges_1.view(-1)] = True
            
            print(ei_graph.shape, edges_1.shape)
            yield ei_graph, edges_1
        
        
itt = iter(data_gen_e_connected(train_loader, dd.slices['edge_index'], batch_size = BSIZE,step = 1))
g, e1 = next(itt)

In [None]:
def data_gen_edges(data, batch_size = 4, step = 8):
    istart = step
    n_edges = data.edge_index.shape[1]
    e_perm_edges = []
    for i in range(batch_size):
        e_perm_edges.append(data.edge_index[:,torch.randperm(n_edges)] + data.num_nodes *i)
    e_perm_edges = torch.stack(e_perm_edges)
    edge_index =  torch.transpose(e_perm_edges, 0, 1)

    active_mask = torch.full((data.num_nodes * batch_size,), False,  dtype=torch.bool)
    visited_mask = torch.full((data.num_nodes * batch_size,), False,  dtype=torch.bool)
    nodes = torch.arange(batch_size*data.num_nodes)
    visited_mask[torch.flatten(edge_index[:,:,:istart])] = True 

    for i in range(istart+step, n_edges, step):
        # *********  sampling edges
        graph = edge_index[:,:,:(i-step)]
        e_1 = edge_index[:,:,(i-step):i]

        # *********  node expansion
        active_mask.fill_(False)  
        active_mask[torch.flatten(e_1)] = True  
        v_exp_0_mask = visited_mask & ~active_mask # in graph but not expanding
        v_exp_1_mask = visited_mask & active_mask  # in graph AND expanding
        n_e_1 = v_exp_1_mask.sum()
        v_exp_1 = nodes[v_exp_1_mask]
        v_exp_0 = nodes[v_exp_0_mask][:n_e_1]

        # *********  node additon
        v_add_1_mask = active_mask & ~visited_mask
        
        v_add_0_mask = ~visited_mask 
        n_v_1 = v_add_1_mask.sum()
        v_left = nodes[v_add_0_mask]
        v_add_0 = v_left[torch.randperm(v_left.shape[0])][:n_v_1]
        v_add_1 = nodes[v_add_1_mask]
        # print(n_v_1, n_e_1)

        # *********  negative edges 
        r1,c1 = e_1
        c0 = torch.zeros(size=c1.shape)
        c0[:, 0] = c1[:, -1]
        c0[:, 1:] = c1[:, :-1]
        e_0 = torch.stack( (r1,c0))
            
        yield graph.reshape((2,-1)), \
                v_add_1_mask.view((batch_size, data.num_nodes)).sum(dim=1),\
                torch.cat((v_exp_0,v_exp_1)),\
                torch.cat((v_add_0,v_add_1)), \
                torch.cat((e_0,e_1), dim=2).reshape((2,-1)).to(dtype = torch.long),\
                visited_mask.view((batch_size, data.num_nodes))

        visited_mask = visited_mask | v_add_1_mask  # add new to visited