In [1]:
import dgl
import torch
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from sklearn.metrics import roc_auc_score
import numpy as np
import torch.backends.cudnn as cudnn
import random
import pickle as pkl
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score
from dgl.nn import GraphConv

In [2]:
# change dropout rate to 0.3, 0.4, 0.5
class GATModel(torch.nn.Module):
    # hyperparameters: 0.3, 0.4, 0.5
    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_rate=0):
        super().__init__()
        self.conv1 = GATConv(in_dim, hidden_dim, num_heads=num_heads)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.conv2 = GATConv(hidden_dim * num_heads, hidden_dim, num_heads)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.conv3 = GATConv(hidden_dim * num_heads, out_dim, num_heads)
    
    def forward(self, g, h):
        h = self.conv1(g, h).flatten(1)
        # h = F.elu(self.dropout1(h))
        h = self.conv2(g, h).flatten(1)
        # h = F.elu(self.dropout2(h))
        h = self.conv3(g, h).mean(1)
        return h


In [3]:
# load /home/qian/HNE/Model/GCN/Ethereum/matching_link/positive_test_edge_indices.pkl, /home/qian/HNE/Model/GCN/Ethereum/matching_link/positive_train_edge_indices.pkl, /home/qian/HNE/Model/GCN/Ethereum/matching_link/positive_validation_edge_indices.pkl, /home/qian/HNE/Model/GCN/Ethereum/matching_link/negative_test_edge_indices.pkl, /home/qian/HNE/Model/GCN/Ethereum/matching_link/negative_train_edge_indices.pkl, /home/qian/HNE/Model/GCN/Ethereum/matching_link/negative_validation_edge_indices.pkl
with open('/home/qian/HNE/Model/GCN/Ethereum/matching_link/positive_test_edge_indices.pkl', 'rb') as f:
    positive_test_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/matching_link/positive_train_edge_indices.pkl', 'rb') as f:
    positive_train_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/matching_link/positive_validation_edge_indices.pkl', 'rb') as f:
    positive_validation_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/matching_link/negative_test_edge_indices.pkl', 'rb') as f:
    negative_test_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/matching_link/negative_train_edge_indices.pkl', 'rb') as f:
    negative_train_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/matching_link/negative_validation_edge_indices.pkl', 'rb') as f:
    negative_validation_edge_indices = pkl.load(f)

In [4]:
# print first examples of these files
# print(positive_test_edge_indices[0]))
print(positive_train_edge_indices[0])
print(negative_train_edge_indices[1])

tensor([1222979, 2519947, 2545154,  ..., 1286348, 1368773, 1190305])
tensor([956188, 432713, 636185,  ...,  74506, 690609, 301482])


In [5]:
# load G_dgl_training
import dgl
with open('/home/qian/HNE/Model/GCN/Ethereum/matching_link/G_dgl_training', 'rb') as f:
    G_dgl_training = pkl.load(f)

In [6]:
# define generate_edge_embeddings function
def generate_edge_embeddings(h, edges):
    # Extract the source and target node indices from the edges
    src, dst = edges[0], edges[1]
    
    # Use the node indices to get the corresponding node embeddings
    src_embed = h[src]
    dst_embed = h[dst]

    # Concatenate the source and target node embeddings
    edge_embs = torch.cat([src_embed, dst_embed], dim=1)

    return edge_embs

In [7]:
# print some features examples
print(G_dgl_training.ndata['features'][0])

tensor([-0.0115, -0.0124, -0.0082, -0.0082, -0.0124, -0.0093, -0.0498, -0.1515,
        -0.0511, -0.0597, -0.0609, -0.0619, -0.0627, -0.0624, -0.0618, -0.0620])


In [10]:
# write a loop to run 5 times of the model and get the average performance
import copy
for i in range(5):
    model = GATModel(16,128,128,4,0)
    # Use the learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    criterion = nn.BCEWithLogitsLoss()

    best_val_loss = float('inf')
    best_model = None
    num_epochs = 200
    patience = 20
    early_stopping_counter = 0
    
    
    transform = nn.Sequential(
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 1)
    )
    
    for epoch in range(num_epochs):
        model.train()
        
        # forward pass
        logits = model(G_dgl_training, G_dgl_training.ndata['features'].float())
        
        # generate edge embeddings
        pos_train_edge_embs = generate_edge_embeddings(logits, positive_train_edge_indices)
        neg_train_edge_embs = generate_edge_embeddings(logits, negative_train_edge_indices)
        
        # concatenete positive and negative edge embeddings
        train_edge_embs = torch.cat([pos_train_edge_embs, neg_train_edge_embs], dim=0)
        train_edge_labels = torch.cat([torch.ones(pos_train_edge_embs.shape[0]), torch.zeros(neg_train_edge_embs.shape[0])], dim=0).unsqueeze(1)
        
        # print shapes of tensors for debugging
        # print(f"Train Edge Embeddings Shape: {train_edge_embs.shape}")
        # print(f"Train Edge Labels Shape: {train_edge_labels.shape}")
        
        # calculate loss
        loss = criterion(transform(train_edge_embs), train_edge_labels)
        print(f"Training Loss: {loss.item()}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
        # validation
        model.eval()
        
        with torch.no_grad():
            # repeat the same process as above for validation samples
            logits = model(G_dgl_training, G_dgl_training.ndata['features'].float())
            pos_val_edge_embs = generate_edge_embeddings(logits, positive_validation_edge_indices)
            neg_val_edge_embs = generate_edge_embeddings(logits, negative_validation_edge_indices)
            val_edge_embs = torch.cat([pos_val_edge_embs, neg_val_edge_embs], dim=0)
            val_edge_labels = torch.cat([torch.ones(pos_val_edge_embs.shape[0]), torch.zeros(neg_val_edge_embs.shape[0])], dim=0).unsqueeze(1)
            # # print shapes of tensors for debugging
            # print(f"Validation Edge Embeddings Shape: {val_edge_embs.shape}")
            # print(f"Validation Edge Labels Shape: {val_edge_labels.shape}")

            val_loss = criterion(transform(val_edge_embs), val_edge_labels)
            print(f"Validation Loss: {val_loss.item()}")
            
            # early stopping based on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # add patience
                early_stopping_counter = 0
                # # save the best model
                best_model = copy.deepcopy(model)
            
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= patience:
                    print("Early Stopping!")
                    break
                
    # switch to evaluation mode
    best_model.eval()

    with torch.no_grad():
        # generate the embeddings using the best model
        logits = best_model(G_dgl_training, G_dgl_training.ndata['features'].float())

        # generate edge embeddings for the test samples
        pos_test_edge_embs = generate_edge_embeddings(logits, positive_test_edge_indices)
        neg_test_edge_embs = generate_edge_embeddings(logits, negative_test_edge_indices)

        # concatenate the positive and negative edge embeddings and labels
        test_edge_embs = torch.cat([pos_test_edge_embs, neg_test_edge_embs], dim=0)
        test_edge_labels = torch.cat([torch.ones(pos_test_edge_embs.shape[0]), torch.zeros(neg_test_edge_embs.shape[0])], dim=0)


        # test_loss = criterion(linear(test_edge_embs), val_edge_labels)
        # calculate predictions using the linear layer
        
        predictions = torch.sigmoid(transform(test_edge_embs))
        
        # reshape the predictions and the labels
        predictions = predictions.view(-1).cpu().numpy()
        test_edge_labels = test_edge_labels.cpu().numpy()

        # calculate scores and entropyloss
        
        
        auc = roc_auc_score(test_edge_labels, predictions)
        # here use 0.5 as threshold
        predictions_binary = (predictions > 0.5).astype(int)
        f1 = f1_score(test_edge_labels, predictions_binary)
        precision = precision_score(test_edge_labels, predictions_binary)
        recall = recall_score(test_edge_labels, predictions_binary)
        accuracy = accuracy_score(test_edge_labels, predictions_binary)

    print(f"AUC: {auc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Accuracy: {accuracy}")
    
    # write the result to a txt file
    with open('result.txt', 'a') as f:
        # write auc, f1, precision, recall
        f.write(f"AUC: {auc}, F1 Score: {f1}, Precision: {precision}, Recall: {recall}, Accuracy: {accuracy}\n")

  assert input.numel() == input.storage().size(), (


Training Loss: 3.1702635288238525
Validation Loss: 9.903468132019043
Training Loss: 10.059246063232422
Validation Loss: 4.155209541320801
Training Loss: 3.0584399700164795
Validation Loss: 5.698225021362305
Training Loss: 4.284885406494141
Validation Loss: 3.8864736557006836
Training Loss: 2.2817788124084473
Validation Loss: 3.975794792175293
Training Loss: 1.929756999015808
Validation Loss: 3.896003007888794
Training Loss: 1.6471501588821411
Validation Loss: 1.9297116994857788
Training Loss: 1.0205742120742798
Validation Loss: 1.5224859714508057
Training Loss: 1.1497074365615845
Validation Loss: 1.593841791152954
Training Loss: 0.898798406124115
Validation Loss: 1.9451152086257935
Training Loss: 0.7995600700378418
Validation Loss: 1.9717553853988647
Training Loss: 0.7641164064407349
Validation Loss: 1.6385269165039062
Training Loss: 0.7147717475891113
Validation Loss: 1.5149133205413818
Training Loss: 0.6821632385253906
Validation Loss: 1.4791579246520996
Training Loss: 0.610832750797

  assert input.numel() == input.storage().size(), (


Training Loss: 6.16180419921875
Validation Loss: 7.108376979827881
Training Loss: 7.633109092712402
Validation Loss: 3.8246872425079346
Training Loss: 3.116374969482422
Validation Loss: 3.5233707427978516
Training Loss: 3.1878011226654053
Validation Loss: 2.39612078666687
Training Loss: 2.022153615951538
Validation Loss: 2.03983473777771
Training Loss: 1.3719801902770996
Validation Loss: 1.792319416999817
Training Loss: 1.0018177032470703
Validation Loss: 1.7982491254806519
Training Loss: 0.9050778746604919
Validation Loss: 1.4646501541137695
Training Loss: 0.688761293888092
Validation Loss: 1.2798947095870972
Training Loss: 0.7240995168685913
Validation Loss: 1.3472564220428467
Training Loss: 0.7025464177131653
Validation Loss: 1.3828989267349243
Training Loss: 0.6878400444984436
Validation Loss: 1.3542176485061646
Training Loss: 0.6919074058532715
Validation Loss: 1.2734901905059814
Training Loss: 0.6515313982963562
Validation Loss: 1.2977912425994873
Training Loss: 0.623601913452148

  assert input.numel() == input.storage().size(), (


Training Loss: 6.189218997955322
Validation Loss: 2.387789249420166
Training Loss: 2.388380527496338
Validation Loss: 1.1135143041610718
Training Loss: 1.2227368354797363
Validation Loss: 0.9992151260375977
Training Loss: 0.8624670505523682
Validation Loss: 1.1037571430206299
Training Loss: 0.8950212001800537
Validation Loss: 1.3045669794082642
Training Loss: 0.7214423418045044
Validation Loss: 1.750486135482788
Training Loss: 1.0475847721099854
Validation Loss: 1.3827694654464722
Training Loss: 0.8172352313995361
Validation Loss: 1.3129127025604248
Training Loss: 0.8888106942176819
Validation Loss: 1.165693759918213
Training Loss: 0.6881834268569946
Validation Loss: 1.597733736038208
Training Loss: 1.0404070615768433
Validation Loss: 1.1208863258361816
Training Loss: 0.6256932616233826
Validation Loss: 1.0877881050109863
Training Loss: 0.7794337272644043
Validation Loss: 1.0544394254684448
Training Loss: 0.7152877449989319
Validation Loss: 1.1391608715057373
Training Loss: 0.611754715

  assert input.numel() == input.storage().size(), (


Training Loss: 9.67262077331543
Validation Loss: 2.9064114093780518
Training Loss: 3.043586492538452
Validation Loss: 1.8948558568954468
Training Loss: 1.9592279195785522
Validation Loss: 1.1174644231796265
Training Loss: 0.8880476355552673
Validation Loss: 1.7707282304763794
Training Loss: 1.1660184860229492
Validation Loss: 1.6212135553359985
Training Loss: 0.8819423317909241
Validation Loss: 1.2415385246276855
Training Loss: 0.739717960357666
Validation Loss: 0.9355064630508423
Training Loss: 0.7249707579612732
Validation Loss: 0.8371751308441162
Training Loss: 0.7921336889266968
Validation Loss: 0.8077995181083679
Training Loss: 0.6948301792144775
Validation Loss: 0.9801841974258423
Training Loss: 0.6463878154754639
Validation Loss: 1.1979671716690063
Training Loss: 0.6536133289337158
Validation Loss: 1.3404113054275513
Training Loss: 0.6606062054634094
Validation Loss: 1.3447331190109253
Training Loss: 0.6479756236076355
Validation Loss: 1.2382289171218872
Training Loss: 0.6131206

  assert input.numel() == input.storage().size(), (


Training Loss: 3.8992111682891846
Validation Loss: 21.197423934936523
Training Loss: 21.99502944946289
Validation Loss: 11.92469596862793
Training Loss: 12.328338623046875
Validation Loss: 2.9636151790618896
Training Loss: 3.0092477798461914
Validation Loss: 4.614392280578613
Training Loss: 4.927485466003418
Validation Loss: 4.6732096672058105
Training Loss: 4.30834436416626
Validation Loss: 3.4253625869750977
Training Loss: 1.9900323152542114
Validation Loss: 3.062199831008911
Training Loss: 1.7914032936096191
Validation Loss: 2.931201219558716
Training Loss: 2.161802291870117
Validation Loss: 2.476923704147339
Training Loss: 1.873234748840332
Validation Loss: 1.91412353515625
Training Loss: 1.1154625415802002
Validation Loss: 2.2070648670196533
Training Loss: 1.135870337486267
Validation Loss: 2.542001247406006
Training Loss: 1.5904510021209717
Validation Loss: 1.5394912958145142
Training Loss: 0.9923771619796753
Validation Loss: 1.6533523797988892
Training Loss: 1.4791573286056519
V

KeyboardInterrupt: 