In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os.path as osp
import pathlib
import math

from typing import Union, Tuple
from typing import List, Optional, Set, get_type_hints
import pickle
from collections import defaultdict
import numpy as np

from torch_geometric.data import InMemoryDataset
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data, Batch
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_undirected
from torch_geometric.typing import OptPairTensor, Adj, Size
from torch_scatter import gather_csr, scatter, segment_csr
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import TUDataset

import torch
from torch import Tensor
from torch.nn import Sequential, ReLU, Linear, Dropout, BatchNorm1d
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from einops import rearrange, reduce, repeat

# from src.data_preparation import RWDataset, data_gen_e_aug
# from src.models import GraphNet, add_weight_decay
# from src.utils import create_nx_graph,create_gt_graph, draw_deg_distr, relabel, init_graph
# from src.utils import sel_start_node, sel_start_node_old, get_errors
# from src.train import LabelSmoothing
# from src.utils import NodeSelector
from src.model import GPT

In [49]:
def load_PYG_datasets(path, d_name = 'PROTEINS'):
    """Load a TUDataset and swap its node features for node indices.

    Args:
        path: directory (joined onto the current working directory) under
            which the dataset folder lives.
        d_name: TUDataset name; also used as the last path component.

    Returns:
        The dataset, with the original node feature matrix stored in
        ``dset.features`` and ``dset.data.x`` replaced by
        ``arange(num_nodes)``.
    """
    root = osp.join(pathlib.Path().absolute(), path, d_name)
    dataset = TUDataset(root, d_name)

    # Keep the real features in a separate field and put indices into `x`,
    # because the per-vertex vectors change later on.
    node_features = dataset.data.x
    dataset.features = node_features
    dataset.data.x = torch.arange(node_features.shape[0])
    return dataset

In [50]:
# Load PROTEINS; load_PYG_datasets resolves the path against the current working directory.
dset = load_PYG_datasets(path='./data/proteins', d_name='PROTEINS')

#dset = RWDataset('')  random walks for Cora

In [51]:
# Graphs per batch; the same value is passed to data_gen_e_aug, which cuts `slices` by batch index.
BSIZE = 16
# shuffle=False: data_gen_e_aug slices `slices` by batch position, so the loader order must stay fixed.
train_loader = DataLoader(dset, batch_size=BSIZE, shuffle=False)# , exclude_keys=['x']

In [52]:
def data_gen_e_aug(train_loader, slices, batch_size = 4, step = 1):
    r"""Generator: for every batch, grow an augmented graph edge by edge.

    Takes the concatenated (batched) graphs, repeatedly picks unvisited edges
    that touch an already visited vertex, and builds an augmented graph in
    which an endpoint that already has an adjacency entry is given a fresh
    node index each time a new edge is attached to it.

    Args:
        train_loader: PyG DataLoader.
        slices: edge offsets along which the individual graphs were
            concatenated (this notebook passes ``dset.slices['edge_index']``).
        batch_size: number of graphs per batch, used to cut ``slices``;
            presumably must equal the loader's batch size - TODO confirm.
        step: how many edges per graph are added per iteration, usually 1.

    Yields:
        A 4-tuple of tensors:
          * ``data.x`` indexed by the original vertex id of every processed
            edge endpoint, followed by the initially visited vertices;
          * the augmented index recorded for every processed endpoint,
            followed by the final ``last_ind_v`` table;
          * the added edges, shape (2, n_added), in the augmented indices
            that were current when each edge was added;
          * the edge_index of the augmented graph built from the adjacency
            dict (row 0: neighbours, row 1: the owning key).
    """
    
    for ib, data in enumerate(train_loader):
        # NOTE(review): debug output, printed once per batch.
        print(data)
        # Edge offsets of the graphs in this batch, rebased so the batch starts at 0.
        e_ptr = slices[ib*batch_size:(ib+1)*batch_size+1]
        e_ptr = e_ptr - e_ptr[0]
        # Number of edges in every graph of the batch.
        szs = e_ptr[1:]-e_ptr[:-1]
        # Leave the last (szs.min() - 1) edges of every graph unvisited, so all
        # graphs have the same number of edges left to add.
        e_ind_start = e_ptr[1:]-szs.min()+1
        visited_e = torch.full((e_ptr[-1],), False, dtype = torch.bool)
        for i in range(e_ind_start.shape[0]):
            visited_e[e_ptr[:-1][i] :e_ind_start[i]] = True # setting emask True for edges in graph
        edges_num = torch.arange(e_ptr[-1])
        
        # Vertices touched by at least one initially visited edge.
        visited_v = torch.full((data.ptr[-1],), False)
        visited_v[torch.unique(data.edge_index[:, visited_e])] = True
        
        vert_base = torch.where(visited_v)[0]
        
        # vert_ind: augmented index of every processed endpoint;
        # v_feature_ind: the matching original vertex id.
        vert_ind,v_feature_ind = [], []
        # last_ind_v[v] is the most recent augmented index of original vertex v
        # (identity at the start); last_ind_max is the next free augmented index.
        last_ind_v = torch.arange(data.ptr[-1])
        last_ind_max = data.ptr[-1].item()        
        # Adjacency sets of the augmented graph, seeded with both directions of
        # every initially visited edge.
        ei_dict = defaultdict(set)
        for e in data.edge_index[:, visited_e].T:
            e = e.to(dtype=torch.long)
            ei_dict[e[0].item()].add(e[1].item()) 
            ei_dict[e[1].item()].add(e[0].item()) 
        edge_added = []
        
        for i in range(data.edge_index.shape[1]): # max number of iterations, usually we stop earlier
            if torch.all(visited_e):
                break           
                  
            # Frontier: unvisited edges with at least one visited endpoint.
            e1_mask = (visited_v[data.edge_index[0]] | visited_v[data.edge_index[1]]) & ~visited_e # Source in graph
            nnedges = edges_num[e1_mask]
            # For every graph of the batch take its first `step` frontier edges.
            e1_ind = []
            for j in range(1, e_ptr.shape[0]):
                mmask = (nnedges < e_ptr[j]) & (nnedges >= e_ptr[j-1])
                e1_ind.append(nnedges[mmask][:step])         
            e1_ind = torch.cat(e1_ind)
            edges_1 = data.edge_index[:, e1_ind]
            
            for e in edges_1.T: 
                e_reind = []
                # Record the edge in the augmented indices valid before re-indexing.
                edge_added.append(last_ind_v[e].view(-1,1))
                
                # int(True) == 1, so the endpoint in row 1 is handled before row 0.
                for iv in (True,False):
                    ind = last_ind_v[e[int(iv)]].item()
                    vert_ind.append(ind)
                    v_feature_ind.append(e[int(iv)])
                    if ind in ei_dict.keys():
                        # The endpoint already has an adjacency entry: give it a new index.
                        ind = last_ind_max                        
                        last_ind_v[e[int(iv)]] = ind
                        # NOTE(review): this binds the same set object (no copy) and looks it
                        # up by the original vertex id, not by the previous augmented
                        # index - confirm this is intended.
                        ei_dict[ind] = ei_dict[e[int(iv)].item()]   
                        last_ind_max += 1  
                    e_reind.append(ind)
                    
                # Connect the two (possibly re-indexed) endpoints in both directions.
                ei_dict[e_reind[0]].add(e_reind[1]) # debatable
                ei_dict[e_reind[1]].add(e_reind[0])
        
            # selecting source-target when both vertices in graph            
            # Mark the edges taken in this iteration and their endpoints as visited.
            visited_e[e1_ind] = True  
            visited_v[edges_1.view(-1)] = True
        
        # Flatten the adjacency dict into a (2, n) tensor: row 0 neighbours, row 1 the key.
        edge_index = []
        for k,v in ei_dict.items():
            e = torch.tensor(list(v)).view(1,-1)
            edge_index.append(torch.cat((e, torch.full_like(e, k)), dim=0))


        # Yield order: feature indices, augmented vertex indices, added edges, augmented edge_index.
        yield  data.x[torch.cat((torch.Tensor(v_feature_ind).to(dtype=torch.long),vert_base))],\
                torch.cat((torch.Tensor(vert_ind).to(dtype=torch.long),last_ind_v)), \
                torch.cat(edge_added, dim=1),\
                torch.cat(edge_index, dim=1)

In [53]:
# Build a per-graph padding mask for the first batch:
# mask[g, j] is True when j < number of nodes in graph g.
for ib, data in enumerate(train_loader):
    print(data.ptr)
    szs =data.ptr[1:]-data.ptr[:-1]
    # Take the number of rows from the batch itself rather than a hard-coded 16,
    # so the comparison below still broadcasts when the batch size changes.
    r = repeat(torch.arange(szs.max()), 'h -> h c', c=szs.shape[0]).T
    mask = r < szs.unsqueeze(1)
    print(mask.shape, mask)
    break

tensor([  0,  42,  69,  79, 103, 114, 450, 558, 712, 731, 742, 762, 814, 835,
        879, 899, 939])
torch.Size([16, 336]) tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]])


In [None]:
# NOTE(review): in the visible cells `e_ptr` is only assigned inside data_gen_e_aug, and this
# cell has no execution count - presumably it raises NameError on a fresh kernel; confirm or remove.
szs =e_ptr[1:]-e_ptr[:-1]

In [54]:
# Generator of augmented batches; batch_size is BSIZE because `slices` is cut by batch index.
train_gen = data_gen_e_aug(train_loader, dset.slices['edge_index'], batch_size = BSIZE,step = 1)

In [55]:
# Pull the first augmented batch. In yield order: feature indices, augmented vertex
# indices, added edges (2, n_added), augmented edge_index.
itt = iter(train_gen)
vi, vj, e1, graph = next(itt)

Batch(batch=[939], edge_index=[2, 3928], ptr=[17], x=[939], y=[16])


In [134]:
vi.shape

torch.Size([1964])

In [133]:
vj.shape

torch.Size([1995])

In [56]:
device = torch.device('cpu')


class objectview:
    """Attribute-style view over a dict (``view.key`` instead of ``d['key']``).

    The dict is not copied: it becomes the instance ``__dict__``, so writes
    through either the view or the dict are visible through both.
    """

    def __init__(self, d):
        # Rebind the instance namespace to the caller's dict itself.
        setattr(self, '__dict__', d)
        
# Hyper-parameters of the SRAN model; the input width is read from the dataset features.
arch0 = dict(
    input_dim=dset.features.shape[1],
    hidden_dim=32,
    num_layers=2,
    num_heads=8,
    attn_pdrop=0.5,
    resid_pdrop=0.5,
    embd_pdrop=0.5,
    gnn_pdrop=0.5,
    num_gnn_layers=1,
    mlp_pdrop=0.5,
)
# Attribute-style access (conf.hidden_dim, ...) for the model constructor.
conf = objectview(arch0)

In [57]:
class GraphNet(nn.Module):
    """Stack of SAGEConv layers applied to a selected subset of node features.

    Args:
        hidden_dim: input and output width of every SAGEConv layer.
        num_layers: number of SAGEConv layers.
        dropout_p: dropout probability applied (in training mode) after the
            ReLU of every layer except the last one.
    """

    def __init__(self, hidden_dim, num_layers, dropout_p = 0.2):
        super(GraphNet, self).__init__()
        self.num_layers = num_layers
        # Previously the argument was accepted but ignored and forward() used a
        # hard-coded p=0.5; keep it so the caller's value is actually applied.
        self.dropout_p = dropout_p
        self.convs = nn.ModuleList()
        # NOTE(review): the batch-norm layers are registered but never applied in
        # forward(); they are kept so the module's parameters/state_dict do not change.
        self.batch_norms = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(num_features=hidden_dim))

    def forward(self, v_ind, features, edge_index):
        """Gather ``features[v_ind]`` and pass it through the convolution stack.

        Args:
            v_ind: index tensor selecting rows of ``features``.
            features: node feature matrix.
            edge_index: connectivity handed to every convolution unchanged.

        Returns:
            The output of the last convolution (no activation or dropout after it).
        """
        x = features[v_ind]
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != self.num_layers - 1:
                x = x.relu()
                x = F.dropout(x, p=self.dropout_p, training=self.training)
        return x
    
class LinkPredictor(torch.nn.Module):
    """Score a pair of embeddings from the absolute difference |z_true - lin(z_concat)|.

    Args:
        in_channels: width of both input embeddings.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.lin = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_true, z_concat):
        """Return one raw score per row (last dimension of size 1)."""
        projected = self.lin(z_concat)
        distance = (z_true - projected).abs()
        score = self.lin_final(distance)
        return score
    
    
class SRAN(nn.Module):
    """Edge-sequence model: input MLP -> GraphNet -> GPT over per-edge embeddings.

    Args:
        config: object exposing the attributes ``input_dim``, ``hidden_dim``,
            ``num_layers``, ``num_heads``, ``attn_pdrop``, ``resid_pdrop``,
            ``embd_pdrop``, ``gnn_pdrop``, ``num_gnn_layers`` and ``mlp_pdrop``.
    """

    def __init__(self, config):
        super(SRAN, self).__init__()

        self.input_dim = config.input_dim
        self.hidden_dim = config.hidden_dim
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.attn_pdrop = config.attn_pdrop
        self.resid_pdrop = config.resid_pdrop
        self.embd_pdrop = config.embd_pdrop
        self.gnn_pdrop  = config.gnn_pdrop
        self.num_gnn_layers = config.num_gnn_layers
        # NOTE(review): mlp_pdrop is stored but no layer below uses it.
        self.mlp_pdrop = config.mlp_pdrop

        self.lin_inp = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hidden_dim)
        )

        # Forward the configured GNN dropout; it used to be read from the config
        # and then dropped, so GraphNet always fell back to its own default.
        self.gnn = GraphNet(
            hidden_dim=self.hidden_dim,
            num_layers=self.num_gnn_layers,
            dropout_p=self.gnn_pdrop,
        )

        # An edge embedding is the concatenation of its two endpoint embeddings,
        # hence the doubled width.
        self.gpt = GPT(
            hidden_dim=2 * self.hidden_dim,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            attn_pdrop=self.attn_pdrop,
            resid_pdrop=self.resid_pdrop,
            embd_pdrop=self.embd_pdrop
        )

        # NOTE(review): the link predictor is constructed but not used in forward().
        self.lp = LinkPredictor(2 * self.hidden_dim)

    def forward(self, v_ind, features, edge_index, edges):
        """Embed nodes, gather edge embeddings and run them through the GPT.

        Args:
            v_ind: indices into the encoded feature matrix, one per node of
                the graph described by ``edge_index``.
            features: raw node features, width ``input_dim``.
            edge_index: connectivity passed to the GNN.
            edges: index tensor laid out as (2, batch, n_seq), per the
                rearrange pattern below.

        Returns:
            The output of ``self.gpt`` for all sequence steps except the last.
        """
        feat = self.lin_inp(features)
        # NOTE(review): debug output, emitted on every forward call.
        print(v_ind.shape)

        gnn_feat = self.gnn(v_ind, feat, edge_index) # (N, d)

        # (2, batch, n_seq, d) -> (n_seq, batch, 2 * d)
        edge_embs_true = rearrange(gnn_feat[edges],
                                   'e batch n_seq d -> n_seq batch (d e)')
        # Drop the last step: it has no following edge to predict.
        next_edge_embs = self.gpt(edge_embs_true[:-1])

        return next_edge_embs


        
#         h_next_e = self.gpt(torch.cat((h_source, h_target)))
        
# #         y_source = self.choice(self.mlp_y_s(h_next_e), h_source) # сместить на 1 позицию
# #         z_target = self.mlp_y_t(torch.cat((h_next_e, h_source))) #а здесь не смещаить h_source?
# #         y_target = self.choice(z_target, h_target)
        
#         return h_next_e


In [58]:
# Instantiate the model from the attribute-style hyper-parameter view `conf`.
model = SRAN(conf)

In [59]:
# data_gen_e_aug appends the added edges iteration by iteration, one column per graph
# (with step=1), so consecutive columns of e1 cycle over the graphs of the batch.
# Split the columns as (n_seq, BSIZE) first and then swap the axes to get the
# (2, BSIZE, n_seq) layout SRAN.forward expects; a plain view(2, BSIZE, -1) would
# put edges of different graphs into the same sequence.
y = model(vi, dset.features, graph, e1.view(2, -1, BSIZE).transpose(1, 2))

torch.Size([1964])


In [60]:
y.shape

torch.Size([32, 16, 64])

In [22]:
e1

tensor([[  33,   59,   69,  ..., 1927, 1929, 1931],
        [  34,   61,   71,  ..., 1799, 1833, 1867]])

In [73]:
dset.features

tensor([[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        ...,
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.]])

In [18]:
# Scratch: a list of two length-100 vectors is stacked on a new leading axis `b`
# and transposed by the pattern, giving shape (100, 2).
v1 = torch.randn(100)
v2 = torch.randn(100)
tensors = rearrange([v1, v2], 'b a -> a b')


In [12]:
# NOTE(review): the recorded output (100, 16, 16) implies v1 and v2 were 2-D, shape (100, 16),
# when this ran; no visible cell defines them that way, so this depends on hidden kernel state.
(v1.unsqueeze(2) - v2.unsqueeze(1)).shape

torch.Size([100, 16, 16])

In [6]:
# NOTE(review): the pattern needs a 2-D v1 (recorded output implies shape (100, 16)); no
# visible cell defines v1 that way, so this depends on hidden kernel state.
repeat(v1, 'a b -> a b c', c=16).shape

torch.Size([100, 16, 16])

In [18]:
# NOTE(review): the pattern needs a 2-D v1 (recorded output implies shape (100, 16)); no
# visible cell defines v1 that way, so this depends on hidden kernel state.
rearrange(v1, 'a b -> b a 1').shape

torch.Size([16, 100, 1])

In [61]:
# Scratch: build all-pairs tensors from two batches of per-vertex embeddings.
batch_size = 16
Nvert = 100
d = 32

v1 = torch.zeros(batch_size, Nvert, d)
v2 = torch.ones(batch_size, Nvert, d)
# z1[b, i, j] = v1[b, i]: every vertex vector is repeated along a new axis after `a`.
z1 = repeat(v1, 'b a d -> b a c d', c = Nvert)
# z2[b, i, j] = v2[b, j]: every vertex vector is repeated along a new axis before `a`.
z2 = repeat(v2, 'b a d -> b c a d', c = Nvert)
print(z1.shape, z2.shape)


torch.Size([16, 100, 100, 32]) torch.Size([16, 100, 100, 32])


In [62]:
# Concatenate the two all-pairs tensors along the feature axis: for every (i, j) the first
# d entries come from z1 and the last d from z2 (see the zeros/ones in the output below).
tensors = rearrange([z1,z2], 'e b x y d -> b x y (e d)')
tensors.shape

torch.Size([16, 100, 100, 64])

In [63]:
tensors[0, 1,1,:]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [None]:
# TODO: not revised yet!

def train(sampler, mod, optimizer, data, bsz):
    """Accumulate the edge loss over all sampled batches and take one optimizer step.

    Args:
        sampler: callable invoked as ``sampler(data, batch_size=bsz)``; yields
            the per-batch inputs handed to ``mod.iterate``.
        mod: model exposing ``train()``, ``encode(x)``,
            ``iterate(data, inp, encoded, bsz)`` returning
            ``(pred_e, targ_e, w0)``, and the attributes ``device`` and
            ``label_smooth``.
        optimizer: optimizer over the model parameters.
        data: object whose ``x`` attribute holds the node features.
        bsz: batch size forwarded to the sampler and to ``mod.iterate``.

    Returns:
        None. If the sampler yields nothing, no backward pass or optimizer
        step is performed.
    """
    mod.train()
    # A single backward pass follows the loop, so gradients are cleared once.
    optimizer.zero_grad()

    loss = None
    for inp in sampler(data, batch_size = bsz):
        # Features are re-encoded for every batch, as in the original loop.
        encoded_features = mod.encode(data.x.to(device = mod.device))

        # The original passed an undefined name `encoded` here (NameError).
        pred_e, targ_e, _ = mod.iterate(data, inp, encoded_features, bsz)
        if mod.label_smooth:
            # NOTE(review): cal_edge_loss is not defined in the visible part of this file.
            batch_loss = cal_edge_loss(pred_e, targ_e)
        else:
            batch_loss = F.binary_cross_entropy_with_logits(pred_e, targ_e)
        loss = batch_loss if loss is None else loss + batch_loss

    if loss is None:
        # Nothing was sampled: there is no loss tensor to back-propagate.
        return

    loss.backward()
    optimizer.step()


In [None]:
# NOTE(review): `data_gen_edges_dyn`, `optimizer` and `B_SIZE` are not defined in the visible
# cells (the batch-size constant above is spelled `BSIZE`), and this cell has no execution
# count - presumably it does not run on a fresh kernel; confirm before use.
num_epochs = 100
for epoch in range(1, num_epochs + 1):
#     d_train
    train(data_gen_edges_dyn, model, optimizer, data, B_SIZE)
#     train(data_gen_edges_dyn, model, optimizer, d_train, B_SIZE)
#     train(data_gen_edges_dyn, model, optimizer, data, B_SIZE)
#     acc_e_test = test( data_gen_edges_dyn, model, d_test, 1)
#     acc_e_test = test( data_gen_edges_dyn, model, data, 1)
#     acc = acc_e_test.mean()
#     print(acc, np.median(acc_e_test))

In [None]:
# TODO: not revised yet!
@torch.no_grad()
def test(sampler,mod, data, bsz, return_y = False, dtype = 'test'):
    mod.eval()
        
    out = []
    n_1, acc_e,num_e = [],[],[]
    # init_state = torch.zeros((data.num_nodes*bsz,  mod.nd), device=mod.device)
    encoded = mod.encode(data.x.to(device = mod.device))
    indd = torch.cat(bsz*[torch.arange(data.num_nodes, device = mod.device)])
    mod.reset_state(encoded[indd,:])
    for inp in sampler(data, batch_size = bsz, step_max = 512, dtype = dtype):
        pred_e, targ_e, w0 = mod.iterate(data, inp, encoded, bsz) 
        if return_y:
            out.append((pred_e, targ_e, w0))#, dim=1
        pred_e = torch.round(torch.sigmoid(pred_e))
#         print('pred_e', pred_e)
        tp_e = pred_e.eq(targ_e).sum().item()
        n_e = pred_e.shape[0]
        
        n_1.append(targ_e.sum())

        acc_e.append(tp_e)
        num_e.append(n_e)
        
    if return_y:
        return out
    else:
        return  np.array(acc_e)/(np.array(num_e)+1e-9)

    