import torch_geometric

torch_geometric.__version__ >>  '1.5.0'

torch.__version__ >> '1.5.0'

torch.version.cuda >> '10.2'

torch.cuda.get_device_name(0) >> 'TITAN RTX'

## Load data

In [1]:
from torch_geometric.datasets import Planetoid
import os

In [2]:
# Root directory for the datasets loaded in the cells below (machine-specific absolute path).
path = '/home/lixue/Ipy5/IrregularMessagePassing/data'

In [None]:
# Alternative dataset root; running this cell rebinds `path` and replaces any earlier value.
path = '/data'

In [None]:
# Build the Planetoid 'Cora' dataset rooted at `path` and keep its first element as `cora`.
Cora = Planetoid(root=path, name='Cora')
cora = Cora[0]

In [3]:
# Build the Planetoid 'CiteSeer' dataset rooted at `path` and keep its first element as `cite_seer`.
C_S = Planetoid(root=path, name='CiteSeer')
cite_seer = C_S[0]

In [None]:
# Build the Planetoid 'PubMed' dataset rooted at `path` and keep its first element as `pub_med`.
PM = Planetoid(root=path, name='PubMed')
pub_med = PM[0]

In [None]:
# Load the object serialized in `Air_USA.pt` under `path`.
# NOTE: `torch` is imported in a later cell; run that import before this cell.
air = torch.load( os.path.join(path,'Air_USA.pt') )

In [None]:
# Load the object serialized in `dblp.pt` under `path`.
# NOTE: `torch` is imported in a later cell; run that import before this cell.
db = torch.load( os.path.join(path,'dblp.pt') )

In [None]:
# Load the object serialized in `dgl_cora.pt` under `path`.
# NOTE: `torch` is imported in a later cell; run that import before this cell.
d_cora = torch.load( os.path.join(path,'dgl_cora.pt'))

In [4]:
import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.utils import remove_self_loops, add_self_loops,softmax
from torch_geometric.nn.conv import MessagePassing
import time
import numpy as np
from utils import EarlyStopping

In [5]:
# Select the graph that the model and training cells read through the
# notebook-global names `data`, `num_features` and `num_classes`.
data = cite_seer
num_features = data.num_features
# Count distinct labels through tolist() instead of numpy(): Tensor.numpy()
# raises for a tensor that lives on a CUDA device, which is the case once the
# training cell has moved the graph to the GPU and this cell is re-run.
num_classes = len(set(data.y.tolist()))

In [None]:
# Two alternative normalisations for PubMed; run only one of the two cells.
# Option 1 (works better): min-max rescale each feature column (dim=0).
# The 1e-12 term guards against division by zero for constant columns.
pub_max = data.x.max(dim=0, keepdim=True).values
pub_min = data.x.min(dim=0, keepdim=True).values
pub_x = (data.x - pub_min) / (pub_max - pub_min + 1e-12)
data.x = pub_x

In [None]:
# Option 2: min-max rescale each node's feature row (dim=1).
# The 1e-12 term guards against division by zero for constant rows.
pub_max = data.x.max(dim=1, keepdim=True).values
pub_min = data.x.min(dim=1, keepdim=True).values
pub_x = (data.x - pub_min) / (pub_max - pub_min + 1e-12)
data.x = pub_x

## Random Attention Network

In [6]:
class RAN_Conv(MessagePassing):
    """Graph convolution driven by a fixed, randomly drawn attention vector.

    Node features are propagated ``k`` times over the graph (with self-loops)
    using per-edge coefficients, then passed through a single linear layer.
    The coefficients come from a random vector ``att`` that is a plain tensor,
    not an ``nn.Parameter``, so it is never returned by ``parameters()`` and
    is not trained.

    Args:
        num_F: Number of input features per node.
        num_C: Number of output channels of the linear layer.
        k: Number of propagation steps.
        bias: Whether the linear layer has a bias term.
        cached: If true, the propagated features are computed on the first
            call to ``forward`` and reused on every later call; later ``x`` /
            ``edge_index`` arguments are then ignored until
            ``reset_parameters`` clears the cache.
        improved: If true, 1.0 is added to the raw score of every self-loop
            edge before the softmax.
    """

    def __init__(self, num_F, num_C, k=2, bias=True, cached=True, improved=False):
        super(RAN_Conv, self).__init__(aggr='add')
        print('random process...')
        self.in_channels = num_F
        self.out_channels = num_C
        self.k = k
        self.bias = bias
        self.cached = cached
        self.improved = improved

        self.lin = nn.Linear(num_F, num_C, bias=bias)

        # Drawn on the CPU so that construction also works on hosts without
        # CUDA; My_norms moves it to the device of the input features on use.
        self.att = torch.rand(1, 2 * num_F)

        # ============== Start::Manually choose Lp normalization ==============
        L2_att = 1
        if L2_att:
            self.att = F.normalize(self.att, dim=1, p=2.0)

        L1_att = 0
        if L1_att:
            self.att = F.normalize(self.att, dim=1, p=1.0)
        # ========================= End::L2 by default ========================

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layer and drop any cached propagation result."""
        self.lin.reset_parameters()
        self.cached_result = None
        self.cached_num_edges = None

    def My_norms(self, x, edge_index, num_nodes):
        """Return one attention coefficient per edge.

        The raw score of an edge is the dot product of ``att`` with the
        concatenation ``[x_i, x_j]``, where ``j = edge_index[0]`` and
        ``i = edge_index[1]``. Scores are finally normalised with a softmax
        over all edges that share the same index ``i``.

        Args:
            x: Node feature matrix, indexed by node id along dim 0.
            edge_index: Tensor whose rows 0 and 1 hold the ``j`` and ``i``
                node ids of every edge.
            num_nodes: Number of nodes, forwarded to the softmax.

        Returns:
            Tensor with one coefficient per edge.
        """
        edge_index_j = edge_index[0]
        edge_index_i = edge_index[1]
        x_j = x[edge_index_j]
        x_i = x[edge_index_i]

        # Keep the attention vector on the same device as the features.
        if self.att.device != x.device:
            self.att = self.att.to(x.device)

        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)

        if self.improved:
            # Boost self-loop edges before the softmax.
            self_weight = edge_index_i == edge_index_j
            alpha[self_weight] += 1.0

        # ============ Start::No weight normalization by default ============
        # NOTE(review): `scatter` is not imported anywhere in the visible part
        # of this file. The three branches below are disabled by their constant
        # flags; before enabling one, make sure a
        # `scatter(src, index, dim, dim_size, reduce)` callable is in scope.

        # L2 edge weight norm
        L2 = 0
        if L2:
            print('L2')
            n_sum = scatter(alpha**2, edge_index_i, dim=-1, dim_size=x.size(0), reduce='add')
            n_sqrt = torch.sqrt(n_sum)
            n_sqrt = n_sqrt[edge_index_i]
            alpha = alpha / (n_sqrt + 1e-12)

        # L1 edge weight norm
        L1 = 0
        if L1:
            print('L1')
            n_sum = scatter(abs(alpha), edge_index_i, dim=-1, dim_size=x.size(0), reduce='add')
            n_sum = n_sum[edge_index_i]
            alpha = alpha / (n_sum + 1e-12)

        # max-min edge weight norm
        L_inf = 0
        if L_inf:
            print('L_inf')
            n_max = scatter(abs(alpha), edge_index_i, dim=-1, dim_size=x.size(0), reduce='max')
            n_min = scatter(abs(alpha), edge_index_i, dim=-1, dim_size=x.size(0), reduce='min')
            n_mm = n_max - n_min
            n_mm = n_mm[edge_index_i]
            alpha = alpha - n_min[edge_index_i]
            alpha = alpha / (n_mm + 1e-12)
        # ====== End::For Attention Norm vs Edge Weight Norm test ===========

        alpha = softmax(alpha, edge_index_i, num_nodes=num_nodes)

        return alpha

    def forward(self, x, edge_index, size=None):
        """Propagate ``x`` for ``k`` steps and apply the linear layer.

        Args:
            x: Node feature matrix.
            edge_index: Edge list; existing self-loops are removed and one
                self-loop per node is added before propagation.
            size: Unused; kept for signature compatibility.

        Returns:
            The output of the linear layer applied to the propagated features.
        """
        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)

            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

            # Coefficients are computed once from the input features and
            # reused for every propagation step.
            norm = self.My_norms(x, edge_index, x.size(0))
            for _ in range(self.k):
                x = self.propagate(edge_index, x=x, norm=norm)

            self.cached_result = x
        else:
            x = self.cached_result

        # The linear layer is applied on both paths; previously it was skipped
        # when cached=False, so raw propagated features were returned.
        return self.lin(x)

    def message(self, x_j, norm):
        """Scale every source-node feature row by its edge coefficient."""
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        return '{}({}, {}, num_iter={})'.format(self.__class__.__name__,
                                                self.in_channels,
                                                self.out_channels, self.k)

In [7]:
class Net(torch.nn.Module):
    """One ``RAN_Conv`` layer followed by a log-softmax over dim 1.

    The layer sizes and the graph are taken from the notebook-global names
    ``num_features``, ``num_classes`` and ``data`` rather than from arguments.
    """

    def __init__(self):
        super().__init__()
        # Four propagation steps, with the extra self-loop score enabled.
        self.conv = RAN_Conv(num_features, num_classes, k=4, improved=True)

    def forward(self):
        """Run the layer on the global ``data`` and return log-probabilities."""
        scores = self.conv(data.x, data.edge_index)
        return F.log_softmax(scores, dim=1)

In [8]:
# Train and evaluate ten freshly initialised models; `das` collects the test
# accuracy of each run for the summary cells that follow.
das = []
for _ in range(10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net().to(device)
    data = data.to(device)

    # Optimizer settings used per dataset (only one line should be active):
    #
    # cora_pytorch k = 4
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.6, weight_decay=5e-3)
    #
    # pubmed  k = 2
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.4, weight_decay=5e-4)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.5, weight_decay=5e-3)
    #
    # air
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.99, weight_decay=5e-5)
    #
    # dblp
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=5e-4)
    #
    # dgl_cora
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.99, weight_decay=5e-5)

    # cite_seer k = 4
    optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=5e-1)

    def train():
        """Run one full-batch optimisation step and return the training loss."""
        model.train()
        optimizer.zero_grad()
        out = model()
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        return loss

    def test(mask):
        """Return the fraction of correctly classified nodes selected by `mask`."""
        model.eval()
        with torch.no_grad():
            pred = model()[mask].max(1)[1]
            correct = pred.eq(data.y[mask]).sum().item()
        return correct / mask.sum().item()

    early_stop = True
    if early_stop:
        stopper = EarlyStopping(patience=100)
    dur = []  # wall-clock time of each training step from epoch 3 onwards
    # print(model)

    for epoch in range(1, 101):
        model.train()
        # The first two epochs are excluded from the timing statistics.
        t0 = time.time()
        loss = train()
        if epoch >= 3:
            dur.append(time.time() - t0)

        val_acc = test(data.val_mask)

        # A truthy return value from the stopper ends training early.
        if early_stop and stopper.step(val_acc, model):
            break

        # Optional per-epoch log:
        # print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
        #       " ValAcc {:.4f}".
        #       format(epoch, np.mean(dur), loss, val_acc))

    # print()  # leave a blank line

    if early_stop:
        # Restore the checkpointed weights before measuring test accuracy
        # (presumably written by EarlyStopping.step; confirm in utils).
        model.load_state_dict(torch.load('es_checkpoint.pt'))
    test_acc = test(data.test_mask)
    das.append(test_acc)
    print("Test Accuracy {:.4f}".format(test_acc))

random process...
Test Accuracy 0.7290
random process...
Test Accuracy 0.7240
random process...
Test Accuracy 0.7300
random process...
Test Accuracy 0.7320
random process...
Test Accuracy 0.7250
random process...
Test Accuracy 0.7270
random process...
Test Accuracy 0.7250
random process...
Test Accuracy 0.7210
random process...
Test Accuracy 0.7220
random process...
Test Accuracy 0.7260


In [9]:
# Mean test accuracy over the runs collected in `das`, in percent.
np.mean(das)*100

72.61

In [10]:
# Standard deviation of the test accuracy over the runs in `das`, in percentage points.
np.std(das)*100

0.33000000000000024

### <font color='blue'>Try tuning k: you may get better results than those reported in my paper!</font>