In [42]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from typing import Union, Tuple
from typing import List, Optional, Set, get_type_hints
import pickle
from collections import defaultdict
import numpy as np

from torch_geometric.data import InMemoryDataset
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data, Batch
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_undirected
from torch_geometric.typing import OptPairTensor, Adj, Size
from torch_scatter import gather_csr, scatter, segment_csr
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv

import torch
from torch import Tensor
from torch.nn import Sequential, ReLU, Linear, Dropout, BatchNorm1d
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


from src.data_preparation import RWDataset, data_gen_e_aug
from src.models import GraphNet, add_weight_decay
from src.utils import create_nx_graph,create_gt_graph, draw_deg_distr, relabel, init_graph
from src.utils import sel_start_node, sel_start_node_old, get_errors
from src.train import LabelSmoothing
from src.utils import NodeSelector
from src.model import GPT

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
BSIZE = 16
dset = RWDataset('')
train_loader = DataLoader(dset, batch_size=BSIZE, shuffle=False)# , exclude_keys=['x']

In [5]:
train_gen = data_gen_e_aug(train_loader, dset.slices['edge_index'], batch_size = BSIZE,step = 1)

In [6]:
itt = iter(train_gen)
vi, ei = next(itt)

In [7]:
ei

tensor([[   2,   54,    0,  ..., 1835, 2790, 1967],
        [   0,    0,    2,  ..., 4343, 4343, 4343]])

In [19]:
def create_model(num_features, arch, device):
    dc = ModelGen(num_features, device, arch)
    dc = dc.to(device = device)
    prs = add_weight_decay(dc.named_parameters(), arch['weight_decay'])
#     optimizer = RAdam(prs, lr=arch['lr'])
    optimizer = optim.Adam(prs, lr=arch['lr'])  
    return dc, optimizer

In [36]:
model(vi, dset.features, ei).shape

torch.Size([4344, 128])

In [51]:
device = torch.device('cpu')

arch0 = {'input_dim': dset.features.shape[1], 
        'hidden_dim': 32, 
        'block_size': None,
        'num_layers_tr': 2,
        'num_heads': 8,
        'attn_pdrop': 0.5,
        'resid_pdrop': 0.5,
        'embd_pdrop': 0.5,
        'gnn_pdrop': 0.5,
        'num_gnn_layers': 1,
        'mlp_pdrop': 0.5}
conf = objectview(arch0)


In [50]:
class FeatureTransform(torch.nn.Module):
    def __init__(self, num_features, nn1 , nn2 = None):
        super().__init__()      
        lin1 = nn.Linear(num_features,nn1)
        lin1.bias.data.fill_(0.)  
        bn1 = BatchNorm1d(nn1)
        self.encoder = Sequential(lin1, ReLU(),bn1)
        self.dropout = Dropout(0.2)
        
    def forward(self, x):
#         self.dropout(x)
        z = self.encoder(x)
        return z

class ModelGen(nn.Module):
    def __init__(self, num_features, device, arch):
        super(ModelGen, self).__init__()
        self.lin_inp = FeatureTransform(num_features, arch['nd'])

        self.n_layers_gnn = arch['nlayers_gnn']
        self.convs = nn.ModuleList()
        for i in range(self.n_layers_gnn):
            self.convs.append(SAGEConv(arch['nd'], arch['nd']))
        self.batch_norms = torch.nn.ModuleList(
            [nn.BatchNorm1d(
                num_features=dims[i+1]
            ) for i in range(num_layers-1)])
        
    def forward(self, v_ind, features, edge_index):
        features = self.lin_inp(features)
        x = features[v_ind]
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != self.n_layers_gnn - 1:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)
        return x
    
class GraphNet(nn.Module):
    def __init__(self, input_dim,
                       hidden_dim,
                       output_dim,
                       num_layers,
                       dropout_p):

        super(GraphNet, self).__init__()

        self.num_layers = num_layers
        self.dropout_p = dropout_p

        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(
                in_channels=dims[i],
                out_channels=dims[i+1]
            ) for i in range(num_layers)]
        )
        self.batch_norms = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(
                num_features=dims[i+1]
            ) for i in range(num_layers-1)]
        )

    def forward(self, x, edge_index):
        for i in range(self.num_layers-1):
            x = self.convs[i](x, edge_index)
            x = self.batch_norms[i](x)
            x = F.gelu(x)
            # x = F.dropout(x, p=self.dropout_p, training=self.training)

        x = self.convs[-1](x, edge_index)

        return x
    
class SRAN(nn.Module):
    def __init__(self, config):
        super(SRAN, self).__init__()

        self.input_dim = config.input_dim
        self.hidden_dim = config.hidden_dim
        self.block_size = config.block_size
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.attn_pdrop = config.attn_pdrop
        self.resid_pdrop = config.resid_pdrop
        self.embd_pdrop = config.embd_pdrop
        self.gnn_pdrop  = config.gnn_pdrop
        self.num_gnn_layers = config.num_gnn_layers
        self.mlp_pdrop = config.mlp_pdrop

        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Dropout(self.mlp_pdrop)
        )

        self.gnn = GCN(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            output_dim=self.hidden_dim,
            num_layers=self.num_gnn_layers,
            dropout_p=self.gnn_pdrop
        )

        self.gpt = GPT(
            block_size=self.block_size,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            attn_pdrop=self.attn_pdrop,
            resid_pdrop=self.resid_pdrop,
            embd_pdrop=self.embd_pdrop
        )
        
        
    def _infer(self, x, edge_index):
        node_feat = self.mlp(x)

        # if there is no edges yet -- send the first n nodes double
        # embeddings to gpt (n - block size of gpt)
#         if edge_index is None:
#             gpt_input = 

        node_feat = self.gnn(node_feat, edge_index)

    def _sample(self):
        pass
    
class objectview(object):
    def __init__(self, d):
        self.__dict__ = d

In [None]:
d = {'a': 1, 'b': 2}
o = objectview(d)

In [None]:
# device = torch.device('cuda:1')
arch0 = {'nd': 128, 'weight_decay': 0.0007,'lr': 0.001,'nlayers_gnn': 2,'add_trans': False, 'conv_type':'gin', #'gin'
        'node_agg_type': 'mean','agg_type': 'sum_gate','fin_type_e': 'l2_sum',
        'lb_smth': False}
# diff_mlp
# diff_het
# het
model, optimizer = create_model(dset.features.shape[1], arch0, device)

In [249]:
class MyConv(MessagePassing):
    def __init__(self, **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', 'add')
        super(MyConv, self).__init__(**kwargs)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, deg: Tensor,
                size: Size = None) -> Tensor:
        
        x = torch.cat((x, deg.view(-1,1)), dim=1)
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)      
        out = self.propagate(edge_index, x=x, size=size)
        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j
    
    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.

        By default, this function will delegate its call to scatter functions
        that support "add", "mean" and "max" operations as specified in
        :meth:`__init__` by the :obj:`aggr` argument.
        """
        if ptr is not None:
            ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
            return segment_csr(inputs, ptr, reduce=self.aggr)
        else:
            print(inputs, index)
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=self.aggr)
        
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

In [None]:
class GENConv(MessagePassing):
    def __init__(self):
        kwargs.setdefault('aggr', None)
        super(GENConv, self).__init__(**kwargs)

    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x=x, size=size)
        if self.msg_norm is not None:
            out = self.msg_norm(x[0], out)
        x_r = x[1]
        if x_r is not None:
            out += x_r
        return self.mlp(out)

    def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
        return F.relu(msg) + self.eps

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None) -> Tensor:

        out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
        return scatter(inputs * out, index, dim=self.node_dim,
                           dim_size=dim_size, reduce='sum')

#         elif self.aggr == 'softmax_sg':
#             out = scatter_softmax(inputs * self.t, index,
#                                   dim=self.node_dim).detach()
#             return scatter(inputs * out, index, dim=self.node_dim,
#                            dim_size=dim_size, reduce='sum')

#         else:
#             min_value, max_value = 1e-7, 1e1
#             torch.clamp_(inputs, min_value, max_value)
#             out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
#                           dim_size=dim_size, reduce='mean')
#             torch.clamp_(out, min_value, max_value)
#             return torch.pow(out, 1 / self.p)

In [250]:
conv = MyConv()

In [127]:
el = [(0,13),(0,5),(0,14),(13,1),(0,18),(0,8),(14,7),(0,11),(0,3),
         (3,10),(0,19),(11,4),(1,9),(9,12),(0,12),(3,7),(0,16),(11,8),
         (1,3),(1,6),(0,13),(14,13),(11,15),(0,12),(4,17),(11,12),(3,2)]
ei = torch.tensor(el).T

In [300]:
ei_dict = defaultdict(set)
for e in ei.T:
    e = e.to(dtype=torch.long)
    ei_dict[e[0]].add(e[1]) 
    ei_dict[e[1]].add(e[0]) 

In [305]:
b = {torch.tensor(0): {torch.tensor(13)}}
b

{tensor(0): {tensor(13)}}

In [147]:
oh = F.one_hot(ei[0],num_classes=20)+F.one_hot(ei[1],num_classes=20)
oh

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0,

In [365]:
# torch.cumsum(oh, dim=0)


In [254]:
row, col = ei
adj_t = SparseTensor(row=col, col=row,
                     sparse_sizes=(data.num_nodes, data.num_nodes))

In [259]:
adj_t

SparseTensor(row=tensor([ 1,  2,  3,  3,  4,  5,  6,  7,  7,  8,  8,  9, 10, 11, 12, 12, 12, 12,
                           13, 13, 13, 14, 15, 16, 17, 18, 19]),
             col=tensor([13,  3,  0,  1, 11,  0,  1,  3, 14,  0, 11,  1,  3,  0,  0,  0,  9, 11,
                            0,  0, 14,  0, 11,  0,  4,  0,  0]),
             size=(275, 275), nnz=27, density=0.04%)

In [258]:
adj_t[3]

SparseTensor(row=tensor([0, 0]),
             col=tensor([0, 1]),
             size=(1, 275), nnz=2, density=0.73%)

In [262]:
delta_v = torch.ones(10, 8)
delta_v

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
torch.zeros(16, 8, dtype=src.dtype).scatter_add_(0, index, src)


In [252]:
_, dg = torch.unique(ei, return_counts = True)
dg

tensor([12,  4,  1,  5,  2,  1,  1,  2,  2,  2,  1,  5,  4,  4,  3,  1,  1,  1,
         1,  1])

In [253]:
xx = torch.rand(20,8)
y = conv(x = xx, edge_index = ei, deg = torch.ones(20))

tensor([[0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.9967, 0.6499, 0.4298, 0.2115, 0.7027, 0.8352, 0.6765, 0.0435, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.7772, 0.1651, 0.0392, 0.4880, 0.2471, 0.5597, 0.5317, 0.4280, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.0049, 0.0289, 0.7177, 0.0580, 0.8168, 0.5887, 0.8685, 0.2321, 1.0000],
        [0.3218, 0.3033, 0.4526, 0.0737, 0.3362, 0.4346, 0.5787, 0.9123, 1.0000],
        [0.7230, 0.3680, 0.4621, 0.3198, 0.0290, 0.2466, 0.4086, 0.0826, 1.0000],
        [0.7480,