# Test Dataset to check

In [7]:
#As it turns out, interactive shells (like Jupyter) cannot handle CPU multiprocessing well, so check which medium the code is running in.
#We write the code in Jupyter for understanding purposes, but the final execution happens in a shell.

In [8]:
import DeviceDir

# Resolve project directories and the compute device / processor count via the local DeviceDir helper.
# `device` is used by later cells for `.to(device)` placement and NUM_PROCESSORS as the process-pool size;
# DIR and RESULTS_DIR are not referenced in this section.
DIR, RESULTS_DIR = DeviceDir.get_directory()
device, NUM_PROCESSORS = DeviceDir.get_device()

In [9]:
import random
import multiprocessing
import pandas as pd
import os
from tqdm import tqdm

In [10]:
from ipynb.fs.full.Utils import isnotebook
from ipynb.fs.full.Dataset import get_data, generate_synthetic
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx, from_networkx
import torch_geometric.utils.homophily as homophily
import copy
import ipynb.fs.full.utils.MoonGraph as MoonGraph
import logging
from sklearn.metrics import f1_score, accuracy_score
from torch_geometric.utils import add_self_loops
import heapq
import math

In [11]:
import argparse
from argparse import ArgumentParser
from ipynb.fs.full.Dataset import datasets as available_datasets

#set default arguments here
def get_configuration(argv=None):
    """Parse the notebook's command-line options.

    Boolean options accept ``true/false``, ``t/f``, ``yes/no``, ``y/n`` and ``1/0``
    (case-insensitive). ``type=bool`` is deliberately not used: argparse would
    call ``bool("False")``, which is ``True``.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        ``(args, dict_args)``: the parsed ``argparse.Namespace`` and ``vars(args)``.
    """
    def str2bool(value):
        # Explicit boolean parsing; invalid spellings become a normal argparse error.
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument('--log_info', type=str2bool, default=True)
    parser.add_argument('--pbar', type=str2bool, default=False)
    parser.add_argument('--balance', type=str2bool, default=True)
    parser.add_argument('--num_worker', type=int, default=0)
    parser.add_argument('--dataset', type=str, default="karate", choices=available_datasets)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--link_batch_size', type=int, default=4096*8) #8192
    parser.add_argument('--link_num_steps', type=int, default=100)
    parser.add_argument('--num_neurons', type=int, default=32)
    parser.add_argument('--f') ##dummy for jupyternotebook
    args = parser.parse_args(argv)

    dict_args = vars(args)

    return args, dict_args

# Parse options once; the global `args` is read throughout the notebook and mutated later by train_link.
args, dict_args = get_configuration()

In [12]:
import torch
import torch.nn as nn
from torch_sparse import SparseTensor
from tqdm import tqdm
import math
import time
import torch.nn.functional as F

import random
import numpy as np

# Seed every RNG this notebook draws from so Restart & Run All is reproducible:
# Python's `random`, NumPy, and torch (LinkSampler samples pairs with torch.randint;
# LinkModel's nn.Linear initialisation and dropout use torch's generator).
SEED = 12345
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [13]:
import sklearn
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from multiprocessing.pool import ThreadPool, Pool
import os.path as osp
from typing import Optional, List, Dict
from torch_sparse import SparseTensor
from tqdm import tqdm
from torch_geometric.typing import EdgeType, InputNodes

# Link Prediction Model

In [14]:
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv

In [15]:
# NOTE(review): GNNlayer is not referenced by any code visible in this section (LinkModel below is MLP-only).
GNNlayer=GCNConv

class LinkModel(nn.Module):
    """Pairwise link scorer built from a shared endpoint encoder and a linear head.

    Both endpoint feature tensors pass through the same Linear(input_rep, num_neurons)
    -> ReLU -> dropout(p=0.2). The pair is represented as ``[hx - hy, hx * hy]``
    and mapped to one unbounded score by Linear(2 * num_neurons, 1); no output
    activation is applied.
    """

    def __init__(self, input_rep, num_neurons=64):
        super().__init__()
        # Attribute names are kept: they define the state_dict keys.
        self.MLP1 = nn.Linear(input_rep, num_neurons)
        self.MLP3 = nn.Linear(2 * num_neurons, 1)

    def _embed(self, h):
        """Shared endpoint encoder; dropout is active only in training mode."""
        return F.dropout(self.MLP1(h).relu(), p=0.2, training=self.training)

    def forward(self, x, y):
        """Score each row pair (x[i], y[i]); inputs are 2-D feature tensors, output has one column."""
        hx = self._embed(x)
        hy = self._embed(y)
        pair = torch.cat((hx - hy, hx * hy), dim=1)
        return self.MLP3(pair)

In [16]:
# model = LinkModel(data.num_features, num_neurons=64).to(device)
# print(model)

## Link Prediction Sampler

In [17]:
class LinkSampler(torch.utils.data.DataLoader):
    """Random node-pair sampler used to train the link (same-class) model.

    Each step draws ``batch_size`` (x, y) endpoint pairs from the candidate
    nodes, with replacement. ``__collate__`` labels a pair 1.0 when
    ``data.y[x] == data.y[y]`` and 0.0 otherwise.

    With ``balance=True`` the same-class fraction of a uniform batch is
    estimated once at construction (``self.ratio``). When it is <= 0.40,
    ``__getitem__`` first draws ``per_class`` same-class pairs from every
    non-empty class, then fills the rest of the batch uniformly.

    Args:
        data: graph object exposing ``x``, ``y``, ``num_nodes``, ``num_edges``.
        input_nodes: boolean node mask of candidate endpoints; ``None`` means all nodes.
        batch_size: pairs per step.
        num_steps: steps per pass; also ``len(self)``.
        save_dir, recompute, log: stored only; not used by this class.
        balance: enable the same-class oversampling described above.
        **kwargs: forwarded to ``torch.utils.data.DataLoader``; ``collate_fn`` is ignored.
    """

    def __init__(self, data, input_nodes: "InputNodes" = None, batch_size: int=1, num_steps: int = 1,
                 save_dir: Optional[str] = None, recompute = True,log: bool = True, balance=False, **kwargs):

        # The collate function is always self.__collate__.
        kwargs.pop('collate_fn', None)

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.save_dir = save_dir
        self.recompute = recompute
        self.log = log
        self.balance = balance

        self.data = data

        self.N = data.num_nodes
        self.E = data.num_edges

        if input_nodes is None:
            # No mask supplied: every node is a candidate endpoint.
            self.input_nodeidx = torch.arange(self.N)
        else:
            self.input_nodeidx = torch.nonzero(input_nodes).flatten()

        if balance:
            # Estimate the same-class fraction of a uniformly drawn batch.
            x, y = self._uniform_pairs(self.__batch_size__)
            label = (self.data.y[x] == self.data.y[y]).type(torch.float)
            self.ratio = torch.sum(label).item()/self.__batch_size__

            # Group candidate nodes by label (vectorised; keeps ascending node order).
            # Python int so the per-class arithmetic in __getitem__ stays in float64.
            self.num_class = int(torch.max(data.y)) + 1
            node_labels = data.y[self.input_nodeidx].view(-1)
            self.clusters = [self.input_nodeidx[node_labels == c] for c in range(self.num_class)]

        super().__init__(self, batch_size=1, collate_fn=self.__collate__,
                         **kwargs)

    @property
    def __filename__(self):
        # NOTE(review): `sample_coverage` is never set on this class, so accessing this
        # property raises AttributeError; it appears unused in the visible code.
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def _uniform_pairs(self, size):
        """Draw ``size`` (x, y) pairs uniformly, with replacement, from the candidate nodes."""
        n = len(self.input_nodeidx)
        x = self.input_nodeidx[torch.randint(n, (size, ))]
        y = self.input_nodeidx[torch.randint(n, (size, ))]
        return x, y

    def __getitem__(self, idx):
        """Return one batch ``(x, y)`` of exactly ``batch_size`` node ids; ``idx`` is ignored."""
        batch_size = self.__batch_size__

        if not (self.balance and self.ratio<=0.40):
            return self._uniform_pairs(batch_size)

        per_class = math.ceil(batch_size*(0.5-self.ratio)/self.num_class)

        xs, ys = [], []
        for members in self.clusters:
            if len(members)==0:
                continue
            xs.append(members[torch.randint(len(members), (per_class, ))])
            ys.append(members[torch.randint(len(members), (per_class, ))])

        Xs = torch.cat([torch.LongTensor([])] + xs)
        Ys = torch.cat([torch.LongTensor([])] + ys)

        # Fill up to batch_size using what was actually drawn: empty classes
        # contribute nothing, and the oversampled part may already exceed a tiny batch.
        remaining = max(0, batch_size - len(Xs))
        x, y = self._uniform_pairs(remaining)

        x = torch.cat((Xs, x))[:batch_size]
        y = torch.cat((Ys, y))[:batch_size]

        return x, y

    def __collate__(self, data_list):
        """Wrap the single sampled (x, y) pair batch with labels and endpoint features."""
        assert len(data_list) == 1

        x, y = data_list[0]

        label = (self.data.y[x] == self.data.y[y]).type(torch.float)
        b_data = self.data.__class__()
        b_data.x = x
        b_data.y = y
        b_data.label = label
        b_data.x_feat = self.data.x[x]
        b_data.y_feat = self.data.x[y]

        return b_data

## Test Dataset

In [18]:
# # data, dataset = get_data('Reddit', log=False)

# from torch_geometric.data import Data
# x = torch.Tensor([[1,0],[1,0],[1,0],[0,1],[0,1],[0,1],[0,1]])
# y = torch.LongTensor([0,0,0, 1, 1, 1, 1])
# edge_index = torch.LongTensor([[1,2],[1,4],[1,5],[2,1],[3,6],[3,7],[4,5],[4,1],[4,6],[4,7],[5,1],[5,4],[5,6],[6,3],[6,4],[6,5],[6,7],[7,3],[7,4],[7,6]]).T
# edge_index = edge_index-1
# train_mask = torch.zeros(len(y)).type(torch.bool)
# train_mask[[0,1,2]]=True
# data = Data(x=x, y=y, edge_index = edge_index, train_mask = train_mask, val_mask = train_mask, test_mask = train_mask)

In [19]:
# batch_size = 4096
# train_sampler  = LinkSampler(data, input_nodes = data.train_mask, batch_size = batch_size, num_steps = 10, save_dir = None,
#                              recompute = True,log = True, balance=True)
# for batch in train_sampler:
#     print(batch)
#     print(sum(batch.label)/batch_size)

In [20]:
# model = LinkModel(data.num_features, num_neurons=64).to(device)
# print(model)


## Train and Test

In [21]:
def predict(model, data, sampler, log = True):
    """Evaluate the link model on the pairs produced by ``sampler``.

    A pair is predicted positive (1) when the model output is >= 0.5, matching
    the 0/1 targets the model is regressed toward with MSE in ``train``.

    Args:
        model: link model called as ``model(x_feat, y_feat)``.
        data: unused; kept for call compatibility.
        sampler: iterable of batches exposing ``x_feat``, ``y_feat``, ``label``
            and ``.to(device)`` (e.g. a ``LinkSampler``).
        log: show a tqdm progress bar sized from the global ``args``.

    Returns:
        ``(accuracy, micro_f1, weighted_f1)`` over all sampled pairs.
    """
    pred_chunks = []
    true_chunks = []

    pbar = None
    if log:
        pbar = tqdm(total=args.link_batch_size*args.link_num_steps)
        pbar.set_description(f'Predicting: ')

    model.eval()
    with torch.no_grad():
        for b_data in sampler:
            b_data = b_data.to(device)
            out = model(b_data.x_feat, b_data.y_feat).view(-1)

            # Threshold the regression output: 1 where out >= 0.5, else 0 (same dtype as out).
            pred = (out >= 0.5).to(out.dtype)

            # Flatten labels so the result matches np.append's flattening behaviour.
            pred_chunks.append(pred.cpu().numpy())
            true_chunks.append(b_data.label.cpu().numpy().reshape(-1))

            if pbar is not None:
                pbar.update(args.link_batch_size)
    if pbar is not None:
        pbar.close()

    # One concatenation instead of np.append per batch, which re-copied the whole array each time.
    y_pred = np.concatenate(pred_chunks) if pred_chunks else np.array([])
    y_true = np.concatenate(true_chunks) if true_chunks else np.array([])

    micro=f1_score(y_true, y_pred, average='micro')
    weighted=f1_score(y_true, y_pred, average='weighted')
    acc=accuracy_score(y_true, y_pred)

    return acc, micro, weighted

In [22]:
def train(model, data, log = True, epochs=1, worker = 0):
    """Train the link model on randomly sampled node pairs; optionally report metrics.

    Pairs come from a ``LinkSampler`` over ``data.train_mask``. Batch size, steps
    per epoch and balancing are read from the global ``args``. The model output
    is regressed (MSE) toward the 0/1 same-class label.

    Args:
        model: link model called as ``model(x_feat, y_feat)``; trained in place.
        data: graph with ``train_mask`` (plus ``val_mask``/``test_mask`` when ``log``).
        log: show progress bars, print per-epoch loss and train metrics, and
            evaluate on validation/test pairs at the end.
        epochs: number of passes, each over ``args.link_num_steps`` sampled batches.
        worker: ``num_workers`` forwarded to the samplers' DataLoader.

    Returns:
        The same ``model`` object.
    """
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.MSELoss()

    def make_sampler(mask):
        # Every split uses the same sampling configuration.
        return LinkSampler(data, input_nodes = mask, batch_size = args.link_batch_size,
                           num_steps = args.link_num_steps, save_dir = None, recompute = False,
                           log = log, balance = args.balance, num_workers = worker)

    train_sampler = make_sampler(data.train_mask)

    for epoch in range(1,epochs+1):
        total_loss = 0.0
        num_batches = 0

        pbar = None
        if log:
            pbar = tqdm(total=args.link_batch_size*args.link_num_steps)
            pbar.set_description(f'Epoch {epoch:02d}')

        model.train()
        for b_data in train_sampler:
            b_data = b_data.to(device)

            optimizer.zero_grad()
            out = model(b_data.x_feat,b_data.y_feat)
            loss = criterion(out, b_data.label.view(-1,1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if pbar is not None:
                pbar.update(args.link_batch_size)
        if pbar is not None:
            pbar.close()

        # Mean over the batches actually seen (equals args.link_num_steps for LinkSampler).
        epoch_loss = total_loss / max(num_batches, 1)

        if log:
            print(f'Epoch {epoch:03d} Loss {epoch_loss:.4f}', end=' ')
            tr_a, tr_b, tr_c = predict(model, data, train_sampler, log = log)
            print(f'\t{tr_a:.4f},\t{tr_b:.4f},\t{tr_c:.4f}')

    if log:
        # Validation/test samplers are only needed for this report, so they are built here.
        for name, mask in (('Val', data.val_mask), ('Test', data.test_mask)):
            acc, micro, weighted = predict(model, data, make_sampler(mask), log = log)
            print(f'{name}\t{acc:.4f},\t{micro:.4f},\t{weighted:.4f}')

    return model

# train(model, data, log = True, epochs=100)    

In [23]:
def train_link(data, selfloop = True, log = True, worker = 0):
    """Set the global link-training hyper-parameters for ``data`` and train a LinkModel.

    Side effects on the global ``args``:
    - ``args.balance`` is always forced to False.
    - ``epochs``, ``num_neurons``, ``link_batch_size`` and ``link_num_steps`` are set by
      graph-size tier.

    ``selfloop`` is accepted for interface compatibility but not used here.

    Returns:
        The trained LinkModel, placed on ``device``.
    """
    args.balance = False

    n_nodes = data.num_nodes
    # (epochs, pairs per batch) by size tier; every tier uses 32 hidden units and 200 steps.
    if n_nodes >= 100000:
        epochs, link_batch = 5, 4096 * 8
    elif n_nodes >= 10000:
        epochs, link_batch = 10, 4096 * 2
    else:
        epochs, link_batch = 20, min(4096, n_nodes * n_nodes)

    args.epochs = epochs
    args.num_neurons = 32
    args.link_batch_size = link_batch
    args.link_num_steps = 200

    model = LinkModel(data.num_features, num_neurons=args.num_neurons).to(device)
    if log:
        print(model)

    train(model, data, log = log, epochs=args.epochs, worker = worker)
    return model

## Edge Weight Computation, Edge-Wise (Fast)

In [24]:
def get_link_weight(data, selfloop = True, log = True, worker = 0):
    """Train a link model on ``data`` (via ``train_link``) and score every edge with it.

    Edges are scored in chunks of ``args.link_batch_size`` (set by ``train_link``).

    Returns:
        1-D CPU float tensor with one weight per column of ``data.edge_index``, in
        edge order, with each score clamped to [1e-3, 1.0].
    """
    #if data.num_nodes<10000:
    worker = 0 ##worker 0 found to be fastest

    link_model = train_link(data, selfloop = selfloop, log = log, worker = worker)

    num_edges = data.edge_index.shape[1]
    chunk = args.link_batch_size
    scores = []

    link_model.eval()
    with torch.no_grad():
        for start in range(0, num_edges, chunk):
            src, dst = data.edge_index[:, start:start + chunk]
            ew = link_model(data.x[src].to(device), data.x[dst].to(device)).view(-1)
            scores.append(torch.clamp(ew, min=1e-3, max=1.0))

    if not scores:
        return torch.empty(0)
    # A single concatenation instead of growing the result with torch.cat on every chunk.
    return torch.cat(scores).cpu()

## Test code

In [25]:
# data, dataset = get_data('Reddit', log=False, h_score=True, split_no = 0)
# args.epochs = 1
# args.num_neurons = 32
# args.link_batch_size = 4096*8
# args.link_num_steps = 100
# args.balance = True

# start = time.time()
# link_model = train_link(data, selfloop = True, log = True)
# end = time.time()
# print(end-start)

In [26]:
# w = get_link_weight(data, selfloop = True, log = False, worker=0)
# print(w)

## Nearest Neighbor Weight assignment

In [27]:
class LinkNN():
    """Tiered edge weighting driven by per-edge link-model scores.

    Every edge gets a score from ``get_link_weight``. For each node ``u`` its
    outgoing edges (row ``u`` of ``adj``) are sorted by score:
    - with ``value='min'`` the highest score comes first;
    - otherwise the lowest score comes first.

    Weights are then assigned by position in that order:
    - the first ``l1 = ceil(deg*lambda1)`` edges get ``w1``;
    - the next ``min(deg - l1, ceil(deg*lambda2))`` get ``w2``;
    - the rest get ``w3``.
    """

    def __init__(self, data, value='min', log=True, worker=0,
                 lambda1=0.25, lambda2=0.25, w1=1.0, w2=0.5, w3=0.1):

        self.N = N = data.num_nodes
        self.E = E = data.num_edges
        self.data = data
        self.value = value
        self.log = log
        self.lambda1=lambda1
        self.lambda2=lambda2
        self.w1=w1
        self.w2=w2
        self.w3=w3

        # sign=-1 turns the ascending argsort in node_weight into a descending score order.
        self.sign = -1 if value=='min' else 1

        # Values are edge ids, so adj[u, :].coo() yields (row, neighbour, edge_id).
        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        self.weight = get_link_weight(data, selfloop = True, log = log, worker=worker)

    def node_weight(self,u):
        """Weight node ``u``'s outgoing edges by their score rank.

        Returns:
            ``(S_G, S_edge)``: parallel lists where ``S_G[i]`` is the weight of edge id ``S_edge[i]``.
        """
        _, col, edge_index = self.adj[u,:].coo()

        scores = self.weight[edge_index].cpu().numpy()
        # Stable sort: tied scores (frequent after clamping to [1e-3, 1]) keep edge order,
        # so tier assignment among ties is well-defined.
        order = np.argsort(self.sign*scores, kind='stable')

        deg = len(col)
        l1 = math.ceil(deg*self.lambda1)
        l2 = min(deg-l1, math.ceil(deg*self.lambda2))
        l3 = max(0, int(deg-l1-l2))

        tier_values = np.array([self.w1, self.w2, self.w3], dtype=float)
        S_G = np.repeat(tier_values, [l1, l2, l3]).tolist()
        S_edge = edge_index[torch.as_tensor(order)].tolist()

        return S_G, S_edge

    def _to_edge_tensor(self, edge_weight, edge_index):
        """Scatter per-edge weights into a float tensor ordered by edge id.

        The number of weighted edges must equal ``E``.
        """
        assert len(edge_index)==self.E
        weight = torch.zeros(len(edge_index))
        weight[edge_index] = torch.Tensor(edge_weight)
        return weight

    def get_nn_weight(self):
        """Serially weight all nodes; returns a length-E tensor indexed by edge id."""
        pbar = None
        if self.log:
            pbar = tqdm(total=self.N)
            pbar.set_description(f'Nodes')

        edge_weight=[]
        edge_index=[]

        for u in range(self.N):
            weight, e_index = self.node_weight(u)
            edge_weight.extend(weight)
            edge_index.extend(e_index)
            if pbar is not None:
                pbar.update(1)
        if pbar is not None:
            pbar.close()

        return self._to_edge_tensor(edge_weight, edge_index)

    def process_block(self, list_u):
        """Weight every node in ``list_u``; returns (weights, edge ids, number of nodes processed)."""
        edge_weight = []
        edge_index = []

        for u in list_u:
            weight, e_index = self.node_weight(u)
            edge_weight.extend(weight)
            edge_index.extend(e_index)

        return edge_weight, edge_index, len(list_u)

    #multiprocessing
    def get_nn_weight_multiproces(self):
        """Parallel ``get_nn_weight`` using a process pool of NUM_PROCESSORS workers."""
        N = self.N
        pool_size = NUM_PROCESSORS

        # Contiguous, non-empty blocks of node ids, one per worker. np.array_split also
        # copes with N < pool_size, where a reshape into pool_size rows fails.
        nodes = [block.tolist() for block in np.array_split(np.arange(N), pool_size) if len(block)]

        if self.log:
            print("Pool Size: ", pool_size)

        pbar = None
        if self.log:
            pbar = tqdm(total=N)
            pbar.set_description(f'Nodes')

        edge_weight=[]
        edge_index=[]

        # The context manager tears the worker processes down; the original pool was never closed.
        with Pool(pool_size) as pool:
            for (weight, e_index, num_el) in pool.imap_unordered(self.process_block, nodes):
                edge_weight.extend(weight)
                edge_index.extend(e_index)
                if pbar is not None:
                    pbar.update(num_el)
        if pbar is not None:
            pbar.close()

        return self._to_edge_tensor(edge_weight, edge_index)

    def compute_weights(self):
        """Serial computation below 10000 nodes, multiprocessing otherwise."""
        if self.data.num_nodes<10000:
            weight = self.get_nn_weight()
        else:
            weight = self.get_nn_weight_multiproces()

        return weight

## Submodular Weight Assignment (legacy version: `LinkSubOld`)

In [28]:
class LinkSubOld():
    """Legacy lazy-greedy ("submodular") edge weighting driven by a trained link model.

    For every node ``u``, its outgoing neighbours (row ``u`` of ``adj``) are
    ordered by a lazy greedy pass over a min-heap. A neighbour ``v`` is keyed by
    ``sign * min(sim(v, s) for s in selected)``. ``selected`` starts as ``{u}``
    and grows as neighbours are accepted. ``sim`` comes from ``pairwise_link``
    (link-model outputs clamped to [1e-3, 1]).

    Neighbours are weighted by acceptance rank:
    - the first ``l1 = ceil(deg*lambda1)`` get ``w1``;
    - the next ``min(deg - l1, ceil(deg*lambda2))`` get ``w2``;
    - the rest get ``w3``.

    ``value='max'`` (sign=+1) pops the least similar neighbour first;
    ``value='min'`` (sign=-1) pops the most similar neighbour first.
    """

    def __init__(self, data, value='max', selfloop = False, log = True, worker=0, lambda1=0.25, lambda2=0.25, w1=1.0, w2=0.5, w3=0.1):
        # Validate `value` before the expensive link-model training, and raise a real
        # exception type (raising a bare string is itself a TypeError).
        if value == 'max':
            sign = 1   ## heap pops the lowest similarity first (farthest)
        elif value == 'min':
            sign = -1  ## heap pops the highest similarity first (nearest)
        else:
            raise NotImplementedError(f"value must be 'max' or 'min', got {value!r}")

        self.N = N = data.num_nodes
        self.E = E = data.num_edges
        self.data = data
        self.log = log
        self.selfloop = selfloop

        self.lambda1=lambda1
        self.lambda2=lambda2
        self.w1=w1
        self.w2=w2
        self.w3=w3

        # pairwise_link moves pair batches to `device` while on_device is True.
        self.X = self.data.x
        self.on_device=True

        self.model = train_link(data, selfloop = selfloop, log= log, worker=worker)
        self.model.eval()

        # Values are edge ids, so adj[u, :].coo() yields (row, neighbour, edge_id).
        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        if self.log:
            print("value: ", value)

        self.value = value
        self.sign = sign

    def pairwise_link(self, x):
        """Score every ordered pair of rows of ``x`` with the link model.

        Builds all n*n row pairs (memory grows quadratically with n) and runs the
        model without gradients.

        Returns:
            An (n, n) CPU tensor whose entry [i, j] is ``model(x[i], x[j])``
            clamped to [1e-3, 1.0].
        """
        n, _ = x.shape

        # Row k of x_col1 is x[k // n] and row k of x_col2 is x[k % n].
        x_col1 = x.repeat_interleave(n, dim=0)
        x_col2 = x.repeat(n,1)

        if self.on_device:
            x_col1 = x_col1.to(device)
            x_col2 = x_col2.to(device)

        with torch.no_grad():
            output = self.model(x_col1, x_col2).detach().cpu()
            output = torch.clamp(output, min=1e-3, max=1.0)

        return output.view(n,n)

    def _tier_weight(self, rank, l1, l2):
        """Weight for the ``rank``-th accepted neighbour (1-based): w1, then w2, then w3."""
        if rank <= l1:
            return self.w1
        if rank <= l1+l2:
            return self.w2
        return self.w3

    def lazy_greedy_weight(self,u):
        """Order node ``u``'s outgoing neighbours by lazy greedy selection and weight them by rank.

        Returns:
            ``(S_G, S_edge)``: parallel lists where ``S_G[i]`` is the weight of edge id
            ``S_edge[i]``. Both are empty when ``u`` has no outgoing edges.
        """
        _, col, edge_index = self.adj[u,:].coo()

        if len(col)==0:
            return [],[]

        vertices = [u]+col.tolist()
        # Vertex id -> row/column in the local similarity matrix.
        v2i = {v: i for i, v in enumerate(vertices)}

        kernel_dist = self.pairwise_link(self.X[vertices])

        # Heap entries: (sign * similarity-to-u, neighbour, edge id); heapq is a min-heap.
        gain_list=[(self.sign*kernel_dist[v2i[u],v2i[v.item()]],v.item(), e.item()) for v,e in zip(col,edge_index)]
        heapq.heapify(gain_list)

        S_G=[]
        S_edge=[]
        S_index=[v2i[u]]   # local indices of the selected set, seeded with u

        deg = len(col)
        l1=math.ceil(deg*self.lambda1)
        l2=min(deg-l1,math.ceil(deg*self.lambda2))
        l3=max(0,int(deg-l1-l2))

        rank=1 # weights depend on acceptance rank, not on the gain value

        while gain_list:
            _, v, e = heapq.heappop(gain_list)
            remaining = len(gain_list)

            if remaining==0:
                # Last neighbour: nothing left to compete with.
                S_G.append(self._tier_weight(rank, l1, l2))
            elif remaining<l3:
                # Inside the final l3 block: lowest tier, no re-evaluation.
                S_G.append(self.w3)
            else:
                # Lazy update: re-key v against everything selected so far.
                gain_v_update = self.sign*kernel_dist[v2i[v],S_index].min()
                if not (gain_v_update<=gain_list[0][0]):
                    # Still behind the heap top: push back with the updated key, which is
                    # already multiplied by sign (same convention as every other heap entry).
                    heapq.heappush(gain_list,(gain_v_update,v, e))
                    continue
                S_G.append(self._tier_weight(rank, l1, l2))

            rank+=1
            S_edge.append(e)
            S_index.append(v2i[v])

        return S_G, S_edge

    def _to_edge_tensor(self, edge_weight, edge_index):
        """Scatter per-edge weights into a float tensor ordered by edge id.

        The number of weighted edges must equal ``E``.
        """
        assert len(edge_index)==self.E
        weight=torch.zeros(len(edge_index))
        weight[edge_index]=torch.Tensor(edge_weight)
        return weight

    #serial
    def get_submodular_weight(self):
        """Serially weight every node's edges; returns a length-E tensor indexed by edge id."""
        pbar = None
        if self.log:
            pbar = tqdm(total=self.N)
            pbar.set_description(f'Nodes')

        edge_weight=[]
        edge_index=[]

        for u in range(self.N):
            weight, e_index = self.lazy_greedy_weight(u)
            edge_weight.extend(weight)
            edge_index.extend(e_index)
            if pbar is not None:
                pbar.update(1)
        if pbar is not None:
            pbar.close()

        return self._to_edge_tensor(edge_weight, edge_index)

    def process_block(self, list_u):
        """Weight every node in ``list_u``; returns (weights, edge ids, number of nodes processed)."""
        edge_weight = []
        edge_index = []

        for u in list_u:
            weight, e_index = self.lazy_greedy_weight(u)
            edge_weight.extend(weight)
            edge_index.extend(e_index)

        return edge_weight, edge_index, len(list_u)

    #multiprocessing
    def get_submodular_weight_multiproces(self):
        """Parallel ``get_submodular_weight`` using a process pool of NUM_PROCESSORS workers."""
        N = self.N
        pool_size = NUM_PROCESSORS

        # Contiguous, non-empty blocks of node ids; np.array_split also handles N < pool_size.
        nodes = [block.tolist() for block in np.array_split(np.arange(N), pool_size) if len(block)]

        if self.log:
            print("Pool Size: ", pool_size)

        pbar = None
        if self.log:
            pbar = tqdm(total=N)
            pbar.set_description(f'Nodes')

        edge_weight=[]
        edge_index=[]

        # The context manager tears the worker processes down; the original pool was never closed.
        with Pool(pool_size) as pool:
            for (weight, e_index, num_el) in pool.imap_unordered(self.process_block, nodes):
                edge_weight.extend(weight)
                edge_index.extend(e_index)
                if pbar is not None:
                    pbar.update(num_el)
        if pbar is not None:
            pbar.close()

        return self._to_edge_tensor(edge_weight, edge_index)

    def compute_weights(self):
        """Compute all edge weights; for >= 10000 nodes the model and features are moved to CPU first."""
        if self.data.num_nodes<10000:
            weight = self.get_submodular_weight()
        else:
            self.on_device=False
            self.X = self.data.x.to('cpu')
            self.model = self.model.to('cpu')
            self.model.eval()
            #weight = self.get_submodular_weight_multiproces()
            weight = self.get_submodular_weight()

        return weight
    
# data, dataset = get_data('karate', log=False)
# submodular_weight = LinkSub(data, selfloop = False, log = True)
# submodular_weight.lazy_greedy_weight(1)
# #data.weight = submodular_weight.compute_weights()    

## Submodular Weight Assignment (current version: `LinkSub`)

In [29]:
class LinkSub():
    """Submodular (lazy-greedy) edge weighting driven by a learned link model.

    For each node ``u`` the neighbours are ordered by lazy greedy maximisation
    of the facility-location objective

        F(S) = sum_j max_{i in S} K[i, j],

    where ``S`` always contains ``u`` and ``K`` is the clamped link-model score
    matrix over ``u`` and its neighbours (see ``pairwise_link``). The rank at
    which a neighbour is selected decides its edge weight:

    * ranks ``1 .. l1``        -> ``w1``, with ``l1 = ceil(deg * lambda1)``
    * ranks ``l1+1 .. l1+l2``  -> ``w2``, with ``l2 = min(deg - l1, ceil(deg * lambda2))``
    * all remaining ranks      -> ``w3``

    Args:
        data: graph exposing ``num_nodes``, ``num_edges``, ``x`` and
            ``edge_index`` (presumably a torch_geometric ``Data``).
        value: stored as ``self.value`` and printed when ``log``; not read by
            the methods of this class.
        selfloop, log, worker: forwarded to ``train_link``; ``log`` also
            enables progress output here.
        lambda1, lambda2: fractions of each neighbourhood placed in the first
            and second tier.
        w1, w2, w3: weights assigned to the three tiers.
    """

    def __init__(self, data, value='max', selfloop = False, log = True, worker=0, lambda1=0.25, lambda2=0.25, w1=1.0, w2=0.5, w3=0.1):

        self.N = N = data.num_nodes
        self.E = E = data.num_edges
        self.data = data
        self.log = log
        self.selfloop = selfloop

        # Tier sizes (as fractions of a node's degree) and the weight of each tier.
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.w1 = w1
        self.w2 = w2
        self.w3 = w3

        # Features stay where data.x lives; compute_weights may move them to CPU.
        self.X = self.data.x
        self.on_device = True

        # Model returned by train_link; its scores form the similarity kernel.
        self.model = train_link(data, selfloop = selfloop, log= log, worker=worker)
        self.model.eval()

        # Sparse adjacency whose stored value for each entry is the column of
        # data.edge_index it came from, so a row lookup yields edge ids.
        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        if self.log:
            print("value: ", value)

        self.value = value

    def pairwise_link(self, x):
        """Return the ``(n, n)`` numpy matrix of link-model scores for all row pairs of ``x``.

        Entry ``[i, j]`` is ``model(x[i], x[j])`` clamped to ``[1e-3, 1.0]``, so
        every similarity is strictly positive.
        """
        n, _ = x.shape

        # Row i*n + j of the two stacked inputs pairs x[i] with x[j].
        x_col1 = x.repeat_interleave(n, dim=0)
        x_col2 = x.repeat(n, 1)

        if self.on_device:
            x_col1 = x_col1.to(device)
            x_col2 = x_col2.to(device)

        with torch.no_grad():
            output = self.model(x_col1, x_col2).detach().cpu()
            output = torch.clamp(output, min=1e-3, max=1.0)

        return output.view(n, n).numpy()

    def _tier_weight(self, rank, l1, l2):
        """Map a 1-based selection rank to ``w1`` / ``w2`` / ``w3``."""
        if rank <= l1:
            return self.w1
        if rank <= l1 + l2:
            return self.w2
        return self.w3

    def lazy_greedy_weight(self, u):
        """Rank the neighbours of ``u`` by lazy greedy selection and weight their edges.

        Args:
            u: node id.

        Returns:
            ``(S_G, S_edge)``: the tier weight and the edge id of each neighbour,
            both in selection order. Both lists are empty when ``u`` has no
            outgoing edges.
        """
        _, col, edge_index = self.adj[u, :].coo()

        if len(col) == 0:
            return [], []

        vertices = [u] + col.tolist()
        # Kernel row/column of each vertex. If u also appears in col (self-loop)
        # the later position wins; both positions index the same feature row.
        v2i = {v: i for i, v in enumerate(vertices)}
        iu = v2i[u]

        kernel_dist = self.pairwise_link(self.X[vertices])

        # F({u}) and the initial marginal gain F({u, v}) - F({u}) of each
        # neighbour. heapq is a min-heap, so gains are stored negated.
        gain_of_u = np.sum(kernel_dist[iu, :])
        gain_list = [
            (-1 * (sum(np.max(kernel_dist[[iu, v2i[v.item()]], :], axis=0)) - gain_of_u), v.item(), e.item())
            for v, e in zip(col, edge_index)
        ]
        heapq.heapify(gain_list)

        S_G = []          # tier weight of each selected neighbour
        S_edge = []       # edge id of each selected neighbour
        S_index = [iu]    # kernel positions of the selected set (u included)

        deg = len(col)
        l1 = math.ceil(deg * self.lambda1)
        l2 = min(deg - l1, math.ceil(deg * self.lambda2))
        l3 = max(0, int(deg - l1 - l2))

        rank = 1
        S_index_gain = gain_of_u  # F(S) for the current selected set

        while gain_list:
            _, v, e = heapq.heappop(gain_list)

            if len(gain_list) == 0:
                # Last candidate: nothing left to compare against.
                S_G.append(self._tier_weight(rank, l1, l2))
                rank += 1
                S_edge.append(e)
                S_index.append(v2i[v])
                break

            if len(gain_list) < l3:
                # Every remaining candidate lands in the bottom tier, so their
                # relative order cannot change any weight; skip the gain update.
                S_G.append(self.w3)
                rank += 1
                S_edge.append(e)
                S_index.append(v2i[v])
                continue

            # Refresh the (possibly stale) gain of v against the current set.
            f_with_v = sum(np.max(kernel_dist[np.append(S_index, v2i[v]), :], axis=0))
            gain_v_update = f_with_v - S_index_gain
            # Best stored (upper-bound) gain among the other candidates.
            gain_v_second = -1 * gain_list[0][0]

            if gain_v_update >= gain_v_second:
                # v's fresh gain beats every other upper bound: select it.
                S_index_gain = f_with_v
                S_G.append(self._tier_weight(rank, l1, l2))
                rank += 1
                S_edge.append(e)
                S_index.append(v2i[v])
            else:
                heapq.heappush(gain_list, (-1 * gain_v_update, v, e))

        return S_G, S_edge

    def _scatter_weights(self, edge_weight, edge_index):
        """Build a length-``self.E`` tensor with ``edge_weight[i]`` at position ``edge_index[i]``."""
        # The number of weighted edges must match the graph's edge count.
        assert len(edge_index) == self.E

        weight = torch.zeros(len(edge_index))
        weight[edge_index] = torch.Tensor(edge_weight)
        return weight

    #serial
    def get_submodular_weight(self):
        """Compute the weight of every edge by visiting nodes one at a time.

        Returns:
            1-D float tensor of length ``self.E`` indexed by edge id.
        """
        N = self.N

        if self.log:
            pbar = tqdm(total=N)
            pbar.set_description('Nodes')

        edge_weight = []
        edge_index = []

        for u in range(N):
            weight, e_index = self.lazy_greedy_weight(u)
            edge_weight.extend(weight)
            edge_index.extend(e_index)
            if self.log:
                pbar.update(1)

        if self.log:
            pbar.close()

        return self._scatter_weights(edge_weight, edge_index)

    def process_block(self, list_u):
        """Weight the edges of every node in ``list_u`` (one worker's share).

        Returns:
            ``(edge_weight, edge_index, len(list_u))``. The count drives the
            caller's progress bar.
        """
        edge_weight = []
        edge_index = []

        for u in list_u:
            weight, e_index = self.lazy_greedy_weight(u)
            edge_weight.extend(weight)
            edge_index.extend(e_index)

        return edge_weight, edge_index, len(list_u)

    #multiprocessing
    def get_submodular_weight_multiproces(self):
        """Compute the weight of every edge in parallel with a process pool.

        Node ids are split into ``NUM_PROCESSORS`` equal chunks plus one
        remainder chunk. Each chunk goes through ``process_block`` in a worker.

        Returns:
            1-D float tensor of length ``self.E`` indexed by edge id.
        """
        N = self.N
        num_blocks = NUM_PROCESSORS
        elem_size = int(N / num_blocks)

        nodes = np.arange(num_blocks * elem_size).reshape(num_blocks, -1).tolist()
        if num_blocks * elem_size < N:
            nodes.append(list(range(num_blocks * elem_size, N)))

        pool_size = NUM_PROCESSORS
        if self.log:
            print("Pool Size: ", pool_size)
            pbar = tqdm(total=N)
            pbar.set_description('Nodes')

        edge_weight = []
        edge_index = []
        # The context manager shuts the workers down once all blocks are consumed.
        with Pool(pool_size) as pool:
            for weight, e_index, num_el in pool.imap_unordered(self.process_block, nodes):
                edge_weight.extend(weight)
                edge_index.extend(e_index)
                if self.log:
                    pbar.update(num_el)

        if self.log:
            pbar.close()

        return self._scatter_weights(edge_weight, edge_index)

    def compute_weights(self, cpu_threshold=10000):
        """Return the submodular weight of every edge (tensor of length ``self.E``).

        Args:
            cpu_threshold: graphs with at least this many nodes are processed
                on CPU. Features and model are moved there and ``on_device``
                is cleared, so ``pairwise_link`` no longer sends batches to
                ``device``.
        """
        if self.data.num_nodes >= cpu_threshold:
            self.on_device = False
            self.X = self.data.x.to('cpu')
            self.model = self.model.to('cpu')
            self.model.eval()
            # The parallel path (get_submodular_weight_multiproces) is not used
            # here; the notebook header notes that interactive shells handle
            # CPU multiprocessing poorly.
        return self.get_submodular_weight()
    
# data, dataset = get_data('karate', log=False)
# submodular_weight = LinkSub(data, selfloop = False, log = True)
# submodular_weight.lazy_greedy_weight(1)
# data.weight = submodular_weight.compute_weights()    

In [30]:
# data, dataset = get_data('karate', log = False)
# submodular_weight = LinkSub(data, log = True)

In [31]:
# submodular_weight.lazy_greedy_weight(0)
#data.weight = submodular_weight.compute_weights()    

# Main

In [32]:
if __name__ == '__main__':
    # Smoke test: build each of the three edge-weighting schemes on a small
    # dataset and print the wall-clock time of each one. Every step overwrites
    # data.weight, so only the last scheme's weights remain afterwards.

    log = True
    datasetname = 'karate'
    #args.link_batch_size = 32
    # NOTE(review): the captured output of this cell shows 20 training epochs per
    # model, so this override does not appear to reach train_link -- confirm.
    args.epochs = 2

    data, dataset = get_data(datasetname, log=log, h_score=True)
    #data = generate_synthetic(data, d=100, h=0.25, train=0.1, random_state=1, log=log)

    # --- 1) link-prediction weights ---
    tic = time.time()
    data.weight = get_link_weight(data, selfloop = True, log = log, worker = 0)
    print("Execution time: ", time.time() - tic)

    # --- 2) nearest-neighbour weights ---
    tic = time.time()
    nn_weight = LinkNN(data, value ='min', log = log)
    data.weight = nn_weight.compute_weights()
    print("Execution time: ", time.time() - tic)

    # --- 3) submodular weights ---
    tic = time.time()
    submodular_weight = LinkSub(data, value ='max', selfloop = True, log = log)
    data.weight = submodular_weight.compute_weights()
    print("Execution time: ", time.time() - tic)

    # Optional analysis (disabled): drop edges with weight < 1.0 and report the
    # homophily of the pruned graph.
#     if 'weight' in data:
#         cp_data= copy.deepcopy(data)
#         G = to_networkx(cp_data, to_undirected=True, edge_attrs=['weight'])
#         to_remove = [(a,b) for a, b, attrs in G.edges(data=True) if attrs["weight"] <1.0 ]
#         G.remove_edges_from(to_remove)
#         updated_data = from_networkx(G)

#         print(updated_data)

#         updated_data = from_networkx(G, group_edge_attrs=['weight'])
#         updated_data.weight = updated_data.edge_attr.view(-1)

#         row, col = updated_data.edge_index
#         updated_data.edge_index = torch.stack((torch.cat((row, col),dim=0), torch.cat((col, row),dim=0)),dim=0)
#         updated_data.weight = torch.cat((updated_data.weight, updated_data.weight),dim=0)

#         print("Node Homophily:", homophily(updated_data.edge_index, cp_data.y, method='node'))
#         print("Edge Homophily:", homophily(updated_data.edge_index, cp_data.y, method='edge'))
#         print("Edge_insensitive Homophily:", homophily(updated_data.edge_index, cp_data.y, method='edge_insensitive'))

Data directory:  ./Dataset/
Result directory: ./Dataset/RESULTS/

Dataset: KarateClub():
Number of graphs: 1
Number of features: 34
Number of classes: 4

Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34], val_mask=[34], test_mask=[34])
Number of nodes: 34
Number of edges: 156
Average node degree: 4.59
Number of training nodes: 4
Training node label rate: 0.12
Has isolated nodes: False
Has self-loops: False
Is undirected: True
N  34  E  156  d  4.588235294117647 0.8020520210266113 0.7564102411270142 0.6170591711997986 -0.4756128787994385 LinkModel(
  (MLP1): Linear(in_features=34, out_features=32, bias=True)
  (MLP3): Linear(in_features=64, out_features=1, bias=True)
)


Epoch 01: 100%|██████████| 231200/231200 [00:00<00:00, 351522.92it/s]


Epoch 001 Loss 0.1713 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 820732.97it/s]


	0.8744,	0.8744,	0.8582


Epoch 02: 100%|██████████| 231200/231200 [00:00<00:00, 330570.10it/s]


Epoch 002 Loss 0.0517 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 359883.66it/s]


	1.0000,	1.0000,	1.0000


Epoch 03: 100%|██████████| 231200/231200 [00:00<00:00, 428052.74it/s]


Epoch 003 Loss 0.0322 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 870894.03it/s] 


	1.0000,	1.0000,	1.0000


Epoch 04: 100%|██████████| 231200/231200 [00:00<00:00, 423541.24it/s]


Epoch 004 Loss 0.0290 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 748440.63it/s]


	1.0000,	1.0000,	1.0000


Epoch 05: 100%|██████████| 231200/231200 [00:00<00:00, 407108.67it/s]


Epoch 005 Loss 0.0270 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 757610.37it/s]


	1.0000,	1.0000,	1.0000


Epoch 06: 100%|██████████| 231200/231200 [00:00<00:00, 386223.38it/s]


Epoch 006 Loss 0.0257 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 622238.00it/s]


	1.0000,	1.0000,	1.0000


Epoch 07: 100%|██████████| 231200/231200 [00:00<00:00, 420031.20it/s]


Epoch 007 Loss 0.0238 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 899188.91it/s] 


	1.0000,	1.0000,	1.0000


Epoch 08: 100%|██████████| 231200/231200 [00:00<00:00, 370645.19it/s]


Epoch 008 Loss 0.0226 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 861798.35it/s]


	1.0000,	1.0000,	1.0000


Epoch 09: 100%|██████████| 231200/231200 [00:00<00:00, 424572.65it/s]


Epoch 009 Loss 0.0219 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 751748.38it/s]


	1.0000,	1.0000,	1.0000


Epoch 10: 100%|██████████| 231200/231200 [00:00<00:00, 425088.01it/s]


Epoch 010 Loss 0.0218 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 780933.25it/s] 


	1.0000,	1.0000,	1.0000


Epoch 11: 100%|██████████| 231200/231200 [00:00<00:00, 380179.50it/s]


Epoch 011 Loss 0.0222 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 503538.02it/s]


	1.0000,	1.0000,	1.0000


Epoch 12: 100%|██████████| 231200/231200 [00:00<00:00, 319929.44it/s]


Epoch 012 Loss 0.0218 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 730994.25it/s]


	1.0000,	1.0000,	1.0000


Epoch 13: 100%|██████████| 231200/231200 [00:00<00:00, 248543.70it/s]


Epoch 013 Loss 0.0218 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 537363.87it/s]


	1.0000,	1.0000,	1.0000


Epoch 14: 100%|██████████| 231200/231200 [00:00<00:00, 412108.17it/s]


Epoch 014 Loss 0.0217 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 685536.14it/s] 


	1.0000,	1.0000,	1.0000


Epoch 15: 100%|██████████| 231200/231200 [00:00<00:00, 322162.96it/s]


Epoch 015 Loss 0.0216 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 869871.41it/s] 


	1.0000,	1.0000,	1.0000


Epoch 16: 100%|██████████| 231200/231200 [00:00<00:00, 351399.36it/s]


Epoch 016 Loss 0.0217 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 333308.73it/s]


	1.0000,	1.0000,	1.0000


Epoch 17: 100%|██████████| 231200/231200 [00:00<00:00, 442283.97it/s]


Epoch 017 Loss 0.0217 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 818125.29it/s] 


	1.0000,	1.0000,	1.0000


Epoch 18: 100%|██████████| 231200/231200 [00:00<00:00, 297990.96it/s]


Epoch 018 Loss 0.0221 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 855198.34it/s] 


	1.0000,	1.0000,	1.0000


Epoch 19: 100%|██████████| 231200/231200 [00:00<00:00, 245173.03it/s]


Epoch 019 Loss 0.0221 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 753946.59it/s]


	1.0000,	1.0000,	1.0000


Epoch 20: 100%|██████████| 231200/231200 [00:00<00:00, 440339.48it/s]


Epoch 020 Loss 0.0216 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 623340.74it/s] 


	1.0000,	1.0000,	1.0000


Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 481133.48it/s]


Val	0.6772,	0.6772,	0.5469


Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 735793.67it/s]


Test	0.6779,	0.6779,	0.5477
Execution time:  32.23678493499756
LinkModel(
  (MLP1): Linear(in_features=34, out_features=32, bias=True)
  (MLP3): Linear(in_features=64, out_features=1, bias=True)
)


Epoch 01: 100%|██████████| 231200/231200 [00:00<00:00, 284711.83it/s]


Epoch 001 Loss 0.1562 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 409271.37it/s]


	0.9369,	0.9369,	0.9336


Epoch 02: 100%|██████████| 231200/231200 [00:00<00:00, 440110.65it/s]


Epoch 002 Loss 0.0522 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 802357.69it/s] 


	1.0000,	1.0000,	1.0000


Epoch 03: 100%|██████████| 231200/231200 [00:01<00:00, 212251.20it/s]


Epoch 003 Loss 0.0270 

Predicting: : 100%|██████████| 231200/231200 [00:01<00:00, 169082.20it/s]


	1.0000,	1.0000,	1.0000


Epoch 04: 100%|██████████| 231200/231200 [00:00<00:00, 260868.43it/s]


Epoch 004 Loss 0.0246 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 737262.69it/s]


	1.0000,	1.0000,	1.0000


Epoch 05: 100%|██████████| 231200/231200 [00:00<00:00, 241761.88it/s]


Epoch 005 Loss 0.0233 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 741888.98it/s]


	1.0000,	1.0000,	1.0000


Epoch 06: 100%|██████████| 231200/231200 [00:00<00:00, 243977.90it/s]


Epoch 006 Loss 0.0220 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 647117.96it/s]


	1.0000,	1.0000,	1.0000


Epoch 07: 100%|██████████| 231200/231200 [00:00<00:00, 274620.09it/s]


Epoch 007 Loss 0.0217 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 373288.00it/s]


	1.0000,	1.0000,	1.0000


Epoch 08: 100%|██████████| 231200/231200 [00:00<00:00, 411488.95it/s]


Epoch 008 Loss 0.0218 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 265533.08it/s]


	1.0000,	1.0000,	1.0000


Epoch 09: 100%|██████████| 231200/231200 [00:00<00:00, 234719.19it/s]


Epoch 009 Loss 0.0216 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 746525.42it/s]


	1.0000,	1.0000,	1.0000


Epoch 10: 100%|██████████| 231200/231200 [00:01<00:00, 215682.48it/s]


Epoch 010 Loss 0.0217 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 747259.46it/s]


	1.0000,	1.0000,	1.0000


Epoch 11: 100%|██████████| 231200/231200 [00:01<00:00, 163920.90it/s]


Epoch 011 Loss 0.0215 

Predicting: : 100%|██████████| 231200/231200 [00:01<00:00, 175109.74it/s]


	1.0000,	1.0000,	1.0000


Epoch 12: 100%|██████████| 231200/231200 [00:00<00:00, 341800.09it/s]


Epoch 012 Loss 0.0218 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 554380.27it/s]


	1.0000,	1.0000,	1.0000


Epoch 13: 100%|██████████| 231200/231200 [00:00<00:00, 290024.52it/s]


Epoch 013 Loss 0.0219 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 568638.81it/s]


	1.0000,	1.0000,	1.0000


Epoch 14: 100%|██████████| 231200/231200 [00:00<00:00, 276736.64it/s]


Epoch 014 Loss 0.0218 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 369727.20it/s]


	1.0000,	1.0000,	1.0000


Epoch 15: 100%|██████████| 231200/231200 [00:00<00:00, 382735.02it/s]


Epoch 015 Loss 0.0220 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 284172.27it/s]


	1.0000,	1.0000,	1.0000


Epoch 16: 100%|██████████| 231200/231200 [00:00<00:00, 387494.92it/s]


Epoch 016 Loss 0.0217 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 296465.27it/s]


	1.0000,	1.0000,	1.0000


Epoch 17: 100%|██████████| 231200/231200 [00:00<00:00, 240000.70it/s]


Epoch 017 Loss 0.0219 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 848791.51it/s] 


	1.0000,	1.0000,	1.0000


Epoch 18: 100%|██████████| 231200/231200 [00:01<00:00, 218945.52it/s]


Epoch 018 Loss 0.0215 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 878043.89it/s] 


	1.0000,	1.0000,	1.0000


Epoch 19: 100%|██████████| 231200/231200 [00:01<00:00, 187309.70it/s]


Epoch 019 Loss 0.0219 

Predicting: : 100%|██████████| 231200/231200 [00:01<00:00, 179508.34it/s]


	1.0000,	1.0000,	1.0000


Epoch 20: 100%|██████████| 231200/231200 [00:00<00:00, 289469.66it/s]


Epoch 020 Loss 0.0217 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 527835.54it/s]


	1.0000,	1.0000,	1.0000


Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 384176.86it/s]


Val	0.6785,	0.6785,	0.5486


Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 282522.06it/s]


Test	0.6805,	0.6805,	0.5511


Nodes: 100%|██████████| 34/34 [00:00<00:00, 2272.61it/s]


Execution time:  35.98017597198486
LinkModel(
  (MLP1): Linear(in_features=34, out_features=32, bias=True)
  (MLP3): Linear(in_features=64, out_features=1, bias=True)
)


Epoch 01: 100%|██████████| 231200/231200 [00:00<00:00, 295239.53it/s]


Epoch 001 Loss 0.1575 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 378563.55it/s]


	0.9374,	0.9374,	0.9342


Epoch 02: 100%|██████████| 231200/231200 [00:00<00:00, 235803.51it/s]


Epoch 002 Loss 0.0518 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 870624.28it/s]


	1.0000,	1.0000,	1.0000


Epoch 03: 100%|██████████| 231200/231200 [00:00<00:00, 232672.99it/s]


Epoch 003 Loss 0.0319 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 742471.21it/s]


	1.0000,	1.0000,	1.0000


Epoch 04: 100%|██████████| 231200/231200 [00:00<00:00, 254693.03it/s]


Epoch 004 Loss 0.0274 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 513163.54it/s]


	1.0000,	1.0000,	1.0000


Epoch 05: 100%|██████████| 231200/231200 [00:00<00:00, 261325.59it/s]


Epoch 005 Loss 0.0259 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 363262.37it/s]


	1.0000,	1.0000,	1.0000


Epoch 06: 100%|██████████| 231200/231200 [00:01<00:00, 140488.78it/s]


Epoch 006 Loss 0.0244 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 265453.49it/s]


	1.0000,	1.0000,	1.0000


Epoch 07: 100%|██████████| 231200/231200 [00:00<00:00, 292809.79it/s]


Epoch 007 Loss 0.0238 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 362539.46it/s]


	1.0000,	1.0000,	1.0000


Epoch 08: 100%|██████████| 231200/231200 [00:00<00:00, 235087.52it/s]


Epoch 008 Loss 0.0235 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 703205.93it/s]


	1.0000,	1.0000,	1.0000


Epoch 09: 100%|██████████| 231200/231200 [00:01<00:00, 207263.94it/s]


Epoch 009 Loss 0.0238 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 730626.89it/s]


	1.0000,	1.0000,	1.0000


Epoch 10: 100%|██████████| 231200/231200 [00:00<00:00, 254029.84it/s]


Epoch 010 Loss 0.0237 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 750475.44it/s] 


	1.0000,	1.0000,	1.0000


Epoch 11: 100%|██████████| 231200/231200 [00:00<00:00, 282936.37it/s]


Epoch 011 Loss 0.0234 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 739499.79it/s]


	1.0000,	1.0000,	1.0000


Epoch 12: 100%|██████████| 231200/231200 [00:00<00:00, 257585.63it/s]


Epoch 012 Loss 0.0230 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 624512.94it/s]


	1.0000,	1.0000,	1.0000


Epoch 13: 100%|██████████| 231200/231200 [00:00<00:00, 239161.02it/s]


Epoch 013 Loss 0.0216 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 734353.86it/s]


	1.0000,	1.0000,	1.0000


Epoch 14: 100%|██████████| 231200/231200 [00:01<00:00, 159642.46it/s]


Epoch 014 Loss 0.0218 

Predicting: : 100%|██████████| 231200/231200 [00:01<00:00, 170531.56it/s]


	1.0000,	1.0000,	1.0000


Epoch 15: 100%|██████████| 231200/231200 [00:00<00:00, 431468.79it/s]


Epoch 015 Loss 0.0219 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 268639.29it/s]


	1.0000,	1.0000,	1.0000


Epoch 16: 100%|██████████| 231200/231200 [00:00<00:00, 451794.48it/s]


Epoch 016 Loss 0.0219 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 294640.83it/s]


	1.0000,	1.0000,	1.0000


Epoch 17: 100%|██████████| 231200/231200 [00:01<00:00, 221573.25it/s]


Epoch 017 Loss 0.0218 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 880154.99it/s] 


	1.0000,	1.0000,	1.0000


Epoch 18: 100%|██████████| 231200/231200 [00:00<00:00, 234337.12it/s]


Epoch 018 Loss 0.0222 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 867563.96it/s]


	1.0000,	1.0000,	1.0000


Epoch 19: 100%|██████████| 231200/231200 [00:01<00:00, 230966.13it/s]


Epoch 019 Loss 0.0217 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 880020.01it/s] 


	1.0000,	1.0000,	1.0000


Epoch 20: 100%|██████████| 231200/231200 [00:00<00:00, 238043.74it/s]


Epoch 020 Loss 0.0215 

Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 739034.84it/s]


	1.0000,	1.0000,	1.0000


Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 287044.35it/s]


Val	0.6790,	0.6790,	0.5492


Predicting: : 100%|██████████| 231200/231200 [00:00<00:00, 880282.83it/s] 


Test	0.6768,	0.6768,	0.5464
value:  max


Nodes: 100%|██████████| 34/34 [00:00<00:00, 502.19it/s]

Execution time:  35.348976373672485



