In [30]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score
from sklearn.exceptions import UndefinedMetricWarning, ConvergenceWarning
import warnings
import pickle
import copy
import numpy as np


In [31]:
import pickle as pkl
# read G_dgl_with_twitter_converted.pkl
with open('/home/qian/HNE/Model/GCN/Ethereum/G_dgl_with_twitter_features_converted.pkl', 'rb') as f:
    G_dgl_with_twitter_features_converted = pkl.load(f)

In [38]:
# again print some examples
print(G_dgl_with_twitter_features_converted.nodes[0].data['combine_normalized_pca_8_twitter_features'])

tensor([[ 1.0142e-01,  6.7766e-02,  6.7329e-02,  1.0352e-01,  6.7766e-02,
          6.7329e-02,  7.1084e-02,  4.1487e-01, -9.9902e-01,  1.0765e-02,
          4.2618e-02,  5.1163e-03,  2.6459e-04, -7.2428e-04, -5.4747e-04,
         -2.2076e-03]], dtype=torch.float64)


In [39]:
class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, out_feats, dropout_rate):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, hidden_size)
        self.conv2 = GraphConv(hidden_size, hidden_size)  
        self.conv3 = GraphConv(hidden_size, out_feats)  
        self.dropout = nn.Dropout(dropout_rate)
        self.batchnorm1 = nn.BatchNorm1d(hidden_size) 

    def forward(self, g, features):
        x = F.relu(self.conv1(g, features))
        x = self.dropout(x)  
        x = self.batchnorm1(x)
        x = F.relu(self.conv2(g, x))
        x = self.dropout(x)
        # x = self.batchnorm1(x)
        x = self.conv3(g, x)
        return x


In [34]:
# store all edge_indices in separate files
with open('/home/qian/HNE/Model/GCN/Ethereum/positive_train_edge_indices.pkl', 'rb') as f:
    positive_train_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/negative_train_edge_indices.pkl', 'rb') as f:
    negative_train_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/positive_validation_edge_indices.pkl', 'rb') as f:
    positive_validation_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/negative_validation_edge_indices.pkl', 'rb') as f:
    negative_validation_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/positive_test_edge_indices.pkl', 'rb') as f:
    positive_test_edge_indices = pkl.load(f)
    
with open('/home/qian/HNE/Model/GCN/Ethereum/negative_test_edge_indices.pkl', 'rb') as f:
    negative_test_edge_indices = pkl.load(f)

In [35]:
def generate_edge_embeddings(h, edges):
    # Extract the source and target node indices from the edges
    src, dst = edges[0], edges[1]
    
    # Use the node indices to get the corresponding node embeddings
    src_embed = h[src]
    dst_embed = h[dst]

    # Concatenate the source and target node embeddings
    edge_embs = torch.cat([src_embed, dst_embed], dim=1)

    return edge_embs

In [42]:
# write a five loop to get the result and document them
import copy

from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score
linear = (
    nn.Linear(256, 1)
)

for i in range(5):
    model = GCN(16, 128, 128, 0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = nn.BCEWithLogitsLoss()
    best_val_loss = float('inf')
    best_model = None
    num_epochs = 200
    patience = 30
    early_stopping_counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        
        # forward pass
        logits = model(G_dgl_with_twitter_features_converted, G_dgl_with_twitter_features_converted.ndata['combine_normalized_pca_8_twitter_features'].float())
        
        # generate edge embeddings
        pos_train_edge_embs = generate_edge_embeddings(logits, positive_train_edge_indices)
        neg_train_edge_embs = generate_edge_embeddings(logits, negative_train_edge_indices)
        
        # concatenete positive and negative edge embeddings
        train_edge_embs = torch.cat([pos_train_edge_embs, neg_train_edge_embs], dim=0)
        train_edge_labels = torch.cat([torch.ones(pos_train_edge_embs.shape[0]), torch.zeros(neg_train_edge_embs.shape[0])], dim=0).unsqueeze(1)
        
        # print shapes of tensors for debugging
        # print(f"Train Edge Embeddings Shape: {train_edge_embs.shape}")
        # print(f"Train Edge Labels Shape: {train_edge_labels.shape}")
        
        # calculate loss
        loss = criterion(linear(train_edge_embs), train_edge_labels)
        print(f"Training Loss: {loss.item()}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
        # validation
        model.eval()
        
        with torch.no_grad():
            # repeat the same process as above for validation samples
            logits = model(G_dgl_with_twitter_features_converted, G_dgl_with_twitter_features_converted.ndata['combine_normalized_pca_8_twitter_features'].float())
            pos_val_edge_embs = generate_edge_embeddings(logits, positive_validation_edge_indices)
            neg_val_edge_embs = generate_edge_embeddings(logits, negative_validation_edge_indices)
            val_edge_embs = torch.cat([pos_val_edge_embs, neg_val_edge_embs], dim=0)
            val_edge_labels = torch.cat([torch.ones(pos_val_edge_embs.shape[0]), torch.zeros(neg_val_edge_embs.shape[0])], dim=0).unsqueeze(1)
            # print shapes of tensors for debugging
            # print(f"Validation Edge Embeddings Shape: {val_edge_embs.shape}")
            # print(f"Validation Edge Labels Shape: {val_edge_labels.shape}")

            val_loss = criterion(linear(val_edge_embs), val_edge_labels)
            print(f"Validation Loss: {val_loss.item()}")
            
            # early stopping based on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stopping_counter = 0
                # save the best model
                best_model = copy.deepcopy(model)
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= patience:
                    print('early stopping due to validation loss not improving')
                    break
                
    # switch to evaluation mode
    best_model.eval()

    with torch.no_grad():
        # generate the embeddings using the best model
        logits = best_model(G_dgl_with_twitter_features_converted, G_dgl_with_twitter_features_converted.ndata['combine_normalized_pca_8_twitter_features'].float())

        # generate edge embeddings for the test samples
        pos_test_edge_embs = generate_edge_embeddings(logits, positive_test_edge_indices)
        neg_test_edge_embs = generate_edge_embeddings(logits, negative_test_edge_indices)

        # concatenate the positive and negative edge embeddings and labels
        test_edge_embs = torch.cat([pos_test_edge_embs, neg_test_edge_embs], dim=0)
        test_edge_labels = torch.cat([torch.ones(pos_test_edge_embs.shape[0]), torch.zeros(neg_test_edge_embs.shape[0])], dim=0)


        # test_loss = criterion(linear(test_edge_embs), val_edge_labels)
        # calculate predictions using the linear layer
        
        predictions = torch.sigmoid(linear(test_edge_embs))
        
        # reshape the predictions and the labels
        predictions = predictions.view(-1).cpu().numpy()
        test_edge_labels = test_edge_labels.cpu().numpy()

        # calculate scores and entropyloss
        
        
        auc = roc_auc_score(test_edge_labels, predictions)
        predictions_binary = (predictions > 0.5).astype(int)
        f1 = f1_score(test_edge_labels, predictions_binary)
        precision = precision_score(test_edge_labels, predictions_binary)
        recall = recall_score(test_edge_labels, predictions_binary)
        accuracy = accuracy_score(test_edge_labels, predictions_binary)

        print(f"AUC: {auc}")
        print(f"F1 Score: {f1}")
        print(f"Precision: {precision}")
        print(f"Recall: {recall}")
        print(f"Accuracy: {accuracy}")
    # print accuracy, f1, precision, recall, auc-roc
    # print(f"Test Loss: {test_loss.item()}")
        with open('results_with_twitter.txt', 'a') as f:
            f.write(f"AUC: {auc}\n")
            f.write(f"F1 Score: {f1}\n")
            f.write(f"Precision: {precision}\n")
            f.write(f"Recall: {recall}\n")
            f.write(f"Accuracy: {accuracy}\n")
            f.write('\n')
    

  assert input.numel() == input.storage().size(), (


Training Loss: 1.7386268377304077
Validation Loss: 0.6901754140853882
Training Loss: 1.039158821105957
Validation Loss: 0.6898903250694275
Training Loss: 1.0595930814743042
Validation Loss: 0.689062774181366
Training Loss: 1.2785909175872803
Validation Loss: 0.6881284713745117
Training Loss: 0.9178532361984253
Validation Loss: 0.6880518794059753
Training Loss: 0.8904348611831665
Validation Loss: 0.6888184547424316
Training Loss: 1.2134063243865967
Validation Loss: 0.688949465751648
Training Loss: 1.3337510824203491
Validation Loss: 0.6881977319717407
Training Loss: 1.2286165952682495
Validation Loss: 0.6870120763778687
Training Loss: 1.0566109418869019
Validation Loss: 0.6862386465072632
Training Loss: 0.997449517250061
Validation Loss: 0.6859010457992554
Training Loss: 0.9251148700714111
Validation Loss: 0.6858291625976562
Training Loss: 1.0010982751846313
Validation Loss: 0.685610830783844
Training Loss: 1.0663191080093384
Validation Loss: 0.6848577260971069
Training Loss: 0.87840402

  assert input.numel() == input.storage().size(), (


Training Loss: 1.426124095916748
Validation Loss: 0.6959084272384644
Training Loss: 1.343600869178772
Validation Loss: 0.6954771280288696
Training Loss: 1.446946144104004
Validation Loss: 0.695195198059082
Training Loss: 1.5230010747909546
Validation Loss: 0.6944419741630554
Training Loss: 1.5867723226547241
Validation Loss: 0.69330894947052
Training Loss: 1.465413212776184
Validation Loss: 0.6927035450935364
Training Loss: 1.1155540943145752
Validation Loss: 0.6933996677398682
Training Loss: 1.148374319076538
Validation Loss: 0.6936255693435669
Training Loss: 1.1377243995666504
Validation Loss: 0.69321608543396
Training Loss: 1.102708339691162
Validation Loss: 0.6925396919250488
Training Loss: 1.2515819072723389
Validation Loss: 0.6912766098976135
Training Loss: 1.1340980529785156
Validation Loss: 0.6897817850112915
Training Loss: 0.8324885964393616
Validation Loss: 0.688910186290741
Training Loss: 0.7680529952049255
Validation Loss: 0.6889806985855103
Training Loss: 0.979704916477203

  assert input.numel() == input.storage().size(), (


Training Loss: 1.7202366590499878
Validation Loss: 0.6918795108795166
Training Loss: 0.9200410842895508
Validation Loss: 0.6908190846443176
Training Loss: 1.0656102895736694
Validation Loss: 0.6908101439476013
Training Loss: 1.0868250131607056
Validation Loss: 0.6917164325714111
Training Loss: 0.8863468766212463
Validation Loss: 0.6937862038612366
Training Loss: 0.7986826300621033
Validation Loss: 0.694739818572998
Training Loss: 1.0155960321426392
Validation Loss: 0.6940116882324219
Training Loss: 0.9140011668205261
Validation Loss: 0.6921234726905823
Training Loss: 0.954328179359436
Validation Loss: 0.690581202507019
Training Loss: 0.8950116634368896
Validation Loss: 0.6898202300071716
Training Loss: 1.0804890394210815
Validation Loss: 0.6899120211601257
Training Loss: 1.1129658222198486
Validation Loss: 0.6903224587440491
Training Loss: 1.0290710926055908
Validation Loss: 0.6912075877189636
Training Loss: 0.8680955171585083
Validation Loss: 0.6931306719779968
Training Loss: 0.928642

  assert input.numel() == input.storage().size(), (


Training Loss: 1.243715763092041
Validation Loss: 0.6925697922706604
Training Loss: 0.9422377347946167
Validation Loss: 0.6912044882774353
Training Loss: 1.2430756092071533
Validation Loss: 0.6900621652603149
Training Loss: 1.1531386375427246
Validation Loss: 0.6888607740402222
Training Loss: 1.152767300605774
Validation Loss: 0.6883586645126343
Training Loss: 1.1039377450942993
Validation Loss: 0.6884068846702576
Training Loss: 1.0549745559692383
Validation Loss: 0.6882723569869995
Training Loss: 0.953557014465332
Validation Loss: 0.6884655356407166
Training Loss: 0.6823407411575317
Validation Loss: 0.6900383234024048
Training Loss: 0.9055289626121521
Validation Loss: 0.691127598285675
Training Loss: 0.9636735320091248
Validation Loss: 0.6914748549461365
Training Loss: 0.9682511687278748
Validation Loss: 0.6905336380004883
Training Loss: 0.8904240727424622
Validation Loss: 0.688751757144928
Training Loss: 0.7950706481933594
Validation Loss: 0.6866374015808105
Training Loss: 0.68364411

  assert input.numel() == input.storage().size(), (


Training Loss: 1.4716542959213257
Validation Loss: 0.6937825083732605
Training Loss: 1.1478114128112793
Validation Loss: 0.6986373662948608
Training Loss: 1.6590659618377686
Validation Loss: 0.6987977623939514
Training Loss: 1.725767970085144
Validation Loss: 0.6968786716461182
Training Loss: 1.4449292421340942
Validation Loss: 0.6945701837539673
Training Loss: 1.079770803451538
Validation Loss: 0.6927202343940735
Training Loss: 1.1108312606811523
Validation Loss: 0.6910439729690552
Training Loss: 1.1238659620285034
Validation Loss: 0.690416157245636
Training Loss: 1.086656093597412
Validation Loss: 0.6902177929878235
Training Loss: 1.1545814275741577
Validation Loss: 0.6900887489318848
Training Loss: 1.212859869003296
Validation Loss: 0.6896464228630066
Training Loss: 1.1433355808258057
Validation Loss: 0.6893115043640137
Training Loss: 0.9726083874702454
Validation Loss: 0.6898865699768066
Training Loss: 0.915260910987854
Validation Loss: 0.6906912326812744
Training Loss: 0.886118829