In [None]:
import pandas as pd
import csv
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import pickle

In [None]:

# Load the filtered, weighted temporal transaction graph (a pickled networkx object).
# pickle.load can execute arbitrary code: only open files you produced or trust.
with open("temporal_graph_filtered_weighted.gpickle", "rb") as graph_file:
    G = pickle.load(graph_file)

In [None]:

# Load the directed link-prediction graph (pickled networkx object; trusted local file only).
with open("digraph_temporal_LP.gpickle", "rb") as lp_file:
    G_LP = pickle.load(lp_file)

In [None]:
# Size of G_LP: number of nodes, then number of edges.
lp_node_count = G_LP.number_of_nodes()
lp_edge_count = G_LP.number_of_edges()
print(lp_node_count)
print(lp_edge_count)

In [None]:
# First ten edges of G_LP together with their attribute dicts.
lp_edge_sample = list(G_LP.edges(data=True))[:10]
print(lp_edge_sample)

In [None]:
# Weakly connected components ordered from largest to smallest.
Gcc = sorted(nx.weakly_connected_components(G_LP), key=lambda component: len(component), reverse=True)

# Induced subgraph view on the largest component (shares node/edge data with G_LP).
largest_component = Gcc[0]
G_LP_connected = G_LP.subgraph(largest_component)

In [None]:
# Size of the giant component: node count, then edge count, one per line.
for component_count in (G_LP_connected.number_of_nodes(), G_LP_connected.number_of_edges()):
    print(component_count)

In [None]:
# First ten edges (with attributes) of the giant component.
component_edge_sample = list(G_LP_connected.edges(data=True))[:10]
print(component_edge_sample)

In [None]:
# Randomly split the giant component's nodes into train / validation / test
# with ratio 0.7 : 0.1 : 0.2.
# A dedicated, seeded generator makes the split (and the node lists written to
# disk below) reproducible across kernel restarts; the global RNG is untouched.
SPLIT_SEED = 42

nodes = list(G_LP_connected.nodes())
np.random.default_rng(SPLIT_SEED).shuffle(nodes)

train_end = int(len(nodes) * 0.7)
validation_end = int(len(nodes) * 0.8)

train_nodes = nodes[:train_end]
validation_nodes = nodes[train_end:validation_end]
test_nodes = nodes[validation_end:]

In [None]:
# Phishing account addresses, one per line (surrounding whitespace stripped).
with open("eoa_addr_list.txt", "r") as addr_file:
    eoa_addr_list = [raw_line.strip() for raw_line in addr_file]
        

In [None]:
# For each split, keep the nodes that appear in the phishing account list,
# preserving the split's (shuffled) order.
# A set gives O(1) membership tests instead of scanning the whole list per node.
phishing_addrs = set(eoa_addr_list)

train_nodes_phishing = [node for node in train_nodes if node in phishing_addrs]
validation_nodes_phishing = [node for node in validation_nodes if node in phishing_addrs]
test_nodes_phishing = [node for node in test_nodes if node in phishing_addrs]

In [None]:
# Number of phishing accounts in the train, validation and test splits (in that order).
for split_phishing in (train_nodes_phishing, validation_nodes_phishing, test_nodes_phishing):
    print(len(split_phishing))

In [None]:
# Persist each split's phishing addresses, one per line.
phishing_outputs = (
    ("train_nodes_wt.txt", train_nodes_phishing),
    ("validation_nodes_wt.txt", validation_nodes_phishing),
    ("test_nodes_wt.txt", test_nodes_phishing),
)
for out_path, split_phishing in phishing_outputs:
    with open(out_path, "w") as out_file:
        out_file.writelines(addr + "\n" for addr in split_phishing)

In [None]:

from itertools import islice

# Pick as many normal (non-phishing) training nodes as there are phishing ones,
# taking them in the (already shuffled) order of train_nodes.
train_node_num = len(train_nodes_phishing)
phishing_addrs = set(eoa_addr_list)  # O(1) membership instead of list scans

# islice stops at exactly train_node_num items. The previous append-then-check
# loop never stopped when train_node_num was 0 and returned every normal node.
train_nodes_normal = list(
    islice((node for node in train_nodes if node not in phishing_addrs), train_node_num)
)

In [None]:
# First ten sampled normal training accounts.
normal_sample = train_nodes_normal[:10]
print(normal_sample)

In [None]:
# Persist the sampled normal training accounts, one address per line.
with open("train_nodes_normal.txt", "w") as out_file:
    out_file.write("".join(addr + "\n" for addr in train_nodes_normal))

In [None]:
from itertools import islice

# Pick as many normal (non-phishing) validation nodes as there are phishing ones,
# in the (already shuffled) order of validation_nodes.
validation_node_num = len(validation_nodes_phishing)
phishing_addrs = set(eoa_addr_list)  # O(1) membership instead of list scans

# islice takes exactly validation_node_num items, including the 0 case. The old
# loop never broke in that case and collected every normal node.
validation_nodes_normal = list(
    islice((node for node in validation_nodes if node not in phishing_addrs), validation_node_num)
)

In [None]:
# Persist the sampled normal validation accounts, one address per line.
with open("validation_nodes_normal.txt", "w") as out_file:
    out_file.writelines(f"{addr}\n" for addr in validation_nodes_normal)

In [None]:
from itertools import islice

# Pick as many normal (non-phishing) test nodes as there are phishing ones,
# in the (already shuffled) order of test_nodes.
test_node_num = len(test_nodes_phishing)
phishing_addrs = set(eoa_addr_list)  # O(1) membership instead of list scans

# islice takes exactly test_node_num items, including the 0 case. The old loop
# never broke in that case and collected every normal node.
test_nodes_normal = list(
    islice((node for node in test_nodes if node not in phishing_addrs), test_node_num)
)

In [None]:
# Persist the sampled normal test accounts, one address per line.
with open("test_nodes_normal.txt", "w") as out_file:
    for addr in test_nodes_normal:
        print(addr, file=out_file)

In [None]:
# First ten phishing accounts of the train, validation and test splits.
for split_phishing in (train_nodes_phishing, validation_nodes_phishing, test_nodes_phishing):
    print(split_phishing[:10])

In [None]:
# Label every node of the giant component: 1 for phishing accounts, 0 otherwise.
# Writing through the subgraph view updates the node attribute dicts shared with G_LP.
# A set gives O(1) membership instead of scanning eoa_addr_list for every node.
phishing_addrs = set(eoa_addr_list)
for node in G_LP_connected.nodes():
    G_LP_connected.nodes[node]["label"] = 1 if node in phishing_addrs else 0

In [None]:
# train_pos mask: 1 if the node is a training phishing node, else 0.
# (The next cell recomputes this mask together with all the others.)
# A set gives O(1) membership instead of scanning the list for every node.
train_phishing_set = set(train_nodes_phishing)
for node in G_LP_connected.nodes():
    G_LP_connected.nodes[node]["train_pos"] = 1 if node in train_phishing_set else 0

In [None]:
# Add binary split-mask attributes to every node of the giant component:
# attribute = 1 if the node belongs to the corresponding set, else 0.
# Sets make each membership check O(1); the previous version scanned six lists
# for every node. The dict order fixes the attribute insertion order.
mask_members = {
    "train_pos": set(train_nodes_phishing),
    "train_neg": set(train_nodes_normal),
    "validation_pos": set(validation_nodes_phishing),
    "validation_neg": set(validation_nodes_normal),
    "test_pos": set(test_nodes_phishing),
    "test_neg": set(test_nodes_normal),
}

for node in G_LP_connected.nodes():
    node_attrs = G_LP_connected.nodes[node]
    for attr_name, members in mask_members.items():
        node_attrs[attr_name] = 1 if node in members else 0


In [None]:
# Materialise the node order once, then show the attribute dicts of the first five nodes.
node_list = list(G_LP_connected.nodes())

for first_node in node_list[:5]:
    print(G_LP_connected.nodes[first_node])


In [None]:
# Structural node features for a directed transaction graph.
def calculate_average_neighbor_degree(node, graph):
    """Mean total degree of the distinct in- and out-neighbours of ``node``.

    Returns 0 when the node has no neighbours.
    """
    adjacent = set(graph.predecessors(node)) | set(graph.successors(node))
    if not adjacent:
        return 0
    return sum(graph.degree(other) for other in adjacent) / len(adjacent)


def node_feature_func(node, G):
    """Return the 8-element structural feature list of ``node`` in directed graph ``G``.

    Order: [degree, in_degree, out_degree, #out-neighbours, #in-neighbours,
    #distinct neighbours, max incident edge 'weight' (missing weight counts as 1,
    0 for isolated nodes), mean neighbour degree].
    """
    out_nbrs = list(G.neighbors(node))      # successors in a DiGraph
    in_nbrs = list(G.predecessors(node))
    distinct_nbrs = set(out_nbrs) | set(in_nbrs)

    # Weights of incoming edges first, then outgoing; empty only when isolated.
    edge_weights = [G[src][node].get('weight', 1) for src in in_nbrs]
    edge_weights += [G[node][dst].get('weight', 1) for dst in out_nbrs]
    max_transactions = max(edge_weights, default=0)

    return [
        G.degree(node),
        G.in_degree(node),
        G.out_degree(node),
        len(out_nbrs),
        len(in_nbrs),
        len(distinct_nbrs),
        max_transactions,
        calculate_average_neighbor_degree(node, G),
    ]


In [None]:
from tqdm import tqdm

# Attach the structural feature vector to every node of the giant component.
component_nodes = G_LP_connected.nodes
for current in tqdm(component_nodes(), total=len(G_LP_connected)):
    component_nodes[current]['features'] = node_feature_func(current, G_LP_connected)


In [None]:
# First five (node, attribute dict) pairs, now including 'features'.
feature_sample = list(G_LP_connected.nodes(data=True))[:5]
print(feature_sample)

In [None]:
from sklearn.preprocessing import MinMaxScaler

# Rescale each structural feature column to [0, 1] across the giant component.
# `node_features` was never defined before this cell (NameError on a fresh
# kernel), so build it here from each node's 'features' attribute. Rows follow
# node_list, which is the same order used when writing the results back.
node_list = list(G_LP_connected.nodes())
node_features = np.array([G_LP_connected.nodes[node]["features"] for node in node_list])

scaler = MinMaxScaler()
normalized_features = scaler.fit_transform(node_features)

# Overwrites the raw 'features' attribute with its normalized version.
for node, features in zip(node_list, normalized_features):
    G_LP_connected.nodes[node]["features"] = features


In [None]:
# Label the nodes of the temporal split graphs: 1 if the address appears in
# eoa_addr_list.txt (phishing), else 0; phishing addresses are printed.
# G_train / G_val / G_test are only built by the temporal edge-split cell further
# down, so on a fresh kernel this cell used to stop with NameError. Graphs that
# do not exist yet are skipped; labelling is repeated after the split anyway.
with open("eoa_addr_list.txt", "r") as f:
    eoa_addr_list = [line.strip() for line in f]
phishing_addrs = set(eoa_addr_list)  # O(1) membership

for split_name in ("G_train", "G_val", "G_test"):
    split_graph = globals().get(split_name)
    if split_graph is None:
        print(f"{split_name} is not defined yet; skipping (labels are assigned after the temporal split)")
        continue
    for node in split_graph.nodes():
        if node in phishing_addrs:
            split_graph.nodes[node]['label'] = 1
            print(node)
        else:
            split_graph.nodes[node]['label'] = 0
        

In [None]:
import networkx as nx

# Temporal edge split of the giant component. Edges are ordered by
# block_number (ascending): the first 70% form the training graph, the next 10%
# validation, and the remainder test.

# Largest weakly connected component (recomputed so this cell stands on its own).
Gcc = sorted(nx.weakly_connected_components(G_LP), key=len, reverse=True)
G_LP_connected = G_LP.subgraph(Gcc[0])

edges = list(G_LP_connected.edges(data=True))

# kind="stable" keeps edges that share a block_number in edge-iteration order,
# which makes the split reproducible. numpy's default sort kind gives no such
# guarantee for ties.
block_numbers = [d['block_number'] for _, _, d in edges]
sorted_edge_indexes = np.argsort(block_numbers, kind="stable")

total_edges = len(edges)
train_size = int(total_edges * 0.7)
val_size = int(total_edges * 0.1)
test_size = total_edges - train_size - val_size  # remaining edges

train_edges = [edges[i] for i in sorted_edge_indexes[:train_size]]
val_edges = [edges[i] for i in sorted_edge_indexes[train_size:train_size + val_size]]
test_edges = [edges[i] for i in sorted_edge_indexes[train_size + val_size:]]


def _graph_from_edges(edge_list):
    """Build a DiGraph from (u, v, data) triples, keeping only block_number and weight."""
    graph = nx.DiGraph()
    graph.add_edges_from(
        (u, v, {'block_number': d['block_number'], 'weight': d['weight']})
        for u, v, d in edge_list
    )
    return graph


G_train = _graph_from_edges(train_edges)
G_val = _graph_from_edges(val_edges)
G_test = _graph_from_edges(test_edges)


In [None]:
# Label the nodes of G_train, G_val and G_test: 1 if the address appears in
# eoa_addr_list.txt (phishing), else 0. Phishing addresses are printed as they
# are labelled, graph by graph (train, then val, then test).
with open("eoa_addr_list.txt", "r") as f:
    eoa_addr_list = [line.strip() for line in f]
phishing_addrs = set(eoa_addr_list)  # O(1) membership instead of list scans

for split_graph in (G_train, G_val, G_test):
    for node in split_graph.nodes():
        if node in phishing_addrs:
            split_graph.nodes[node]['label'] = 1
            print(node)
        else:
            split_graph.nodes[node]['label'] = 0
        



In [None]:
# For each split graph (train, val, test): node count, edge count, first ten edges with data.
for split_graph in (G_train, G_val, G_test):
    print(split_graph.number_of_nodes())
    print(split_graph.number_of_edges())
    print(list(split_graph.edges(data=True))[:10])

In [None]:
# Out-neighbours (successors) of one example address in the training graph.
probe_addr = '0x47b81da0bbe08cb3ae51cd378ab060a0fcd51338'
neighbors = list(G_train.successors(probe_addr))

In [None]:
G_train.edges['0x47b81da0bbe08cb3ae51cd378ab060a0fcd51338', neighbors[0]]['weight']

In [None]:
# Total (in + out) degree of the example address in the training graph.
G_train.degree['0x47b81da0bbe08cb3ae51cd378ab060a0fcd51338']

In [None]:
# In-degree of the example address in the training graph.
G_train.in_degree['0x47b81da0bbe08cb3ae51cd378ab060a0fcd51338']

In [None]:
# Persist the three temporal split graphs with pickle.
split_outputs = (
    ("G_train.gpickle", G_train),
    ("G_val.gpickle", G_val),
    ("G_test.gpickle", G_test),
)
for out_path, split_graph in split_outputs:
    with open(out_path, "wb") as out_file:
        pickle.dump(split_graph, out_file)

In [None]:
def calculate_average_neighbor_degree(node, graph):
    """Return the mean degree of the distinct in/out neighbours of ``node`` (0 if it has none).

    Same definition as the earlier feature cell; re-declared at this point in the notebook.
    """
    neighbour_degrees = [
        graph.degree(peer)
        for peer in set(graph.predecessors(node)) | set(graph.successors(node))
    ]
    return sum(neighbour_degrees) / len(neighbour_degrees) if neighbour_degrees else 0

In [None]:
def calculate_average_neighbor_degree(node, graph):
    """Average total degree over the distinct predecessors and successors of ``node``.

    Returns 0 when the node has no neighbours.
    """
    seen = set(graph.predecessors(node))
    seen.update(graph.successors(node))
    if len(seen) == 0:
        return 0
    running_total = 0
    for peer in seen:
        running_total += graph.degree(peer)
    return running_total / len(seen)


def node_feature_func(node, G):
    """Compute the structural feature list of ``node`` in directed graph ``G``.

    Returns [degree, in_degree, out_degree, #successors, #predecessors,
    #distinct neighbours, max incident edge weight (default weight 1; 0 when
    isolated), average neighbour degree].
    """
    successors = list(G.neighbors(node))
    predecessors = list(G.predecessors(node))
    neighbourhood_size = len(set(successors).union(predecessors))

    if neighbourhood_size:
        incoming = [G[p][node].get('weight', 1) for p in predecessors]
        outgoing = [G[node][s].get('weight', 1) for s in successors]
        max_transactions = max(incoming + outgoing)
    else:
        max_transactions = 0

    return [
        G.degree(node),
        G.in_degree(node),
        G.out_degree(node),
        len(successors),
        len(predecessors),
        neighbourhood_size,
        max_transactions,
        calculate_average_neighbor_degree(node, G),
    ]


In [None]:
# Compute structural features separately inside each temporal split graph.
# The loop variable is `split_graph`, not `G`, so the temporal graph loaded into
# G at the top of the notebook is no longer silently replaced by G_test.
for split_graph in [G_train, G_val, G_test]:
    for node in split_graph.nodes:
        split_graph.nodes[node]['features'] = node_feature_func(node, split_graph)

In [None]:
# First ten (node, attributes) pairs of each temporal split graph: train, val, test.
for split_graph in (G_train, G_val, G_test):
    print(list(split_graph.nodes(data=True))[:10])

In [None]:
# log1p-transform the structural features, then standardize them.
from sklearn.preprocessing import StandardScaler


def _stacked_features(nx_graph):
    """Stack each node's 'features' list into one array, in nx_graph.nodes() order."""
    return np.array([nx_graph.nodes[node]['features'] for node in nx_graph.nodes()])


# The scaler is fit on the training graph only and the same transform is then
# applied to each split. Previously val/test nodes were assigned rows of the
# *train* matrix by position, i.e. features of unrelated nodes, and a split with
# more nodes than G_train would have raised IndexError.
scaler = StandardScaler()
all_features = scaler.fit_transform(np.log1p(_stacked_features(G_train)))  # normalized train matrix

normalized_by_graph = (
    (G_train, all_features),
    (G_val, scaler.transform(np.log1p(_stacked_features(G_val)))),
    (G_test, scaler.transform(np.log1p(_stacked_features(G_test)))),
)
for split_graph, normalized in normalized_by_graph:
    for node, row in zip(split_graph.nodes(), normalized):
        split_graph.nodes[node]['normalized_log_features'] = row

In [None]:
# First ten (node, attributes) pairs per split, now including 'normalized_log_features'.
split_node_samples = [list(graph.nodes(data=True))[:10] for graph in (G_train, G_val, G_test)]
for node_sample in split_node_samples:
    print(node_sample)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import networkx as nx


def _nx_to_dgl(nx_graph):
    """Convert a networkx DiGraph to a DGLGraph whose node i is list(nx_graph.nodes())[i].

    DGL relabels non-integer networkx node IDs to 0..N-1 with its own ordering
    (sorted order in recent versions; verify for the installed version). That
    ordering need not match the ``nx_graph.nodes()`` order that later cells use
    to build the feature/label arrays. Relabelling to 0..N-1 in insertion order
    first pins the mapping. Node/edge attributes are not copied here.
    """
    relabeled = nx.convert_node_labels_to_integers(nx_graph, ordering="default")
    return dgl.from_networkx(relabeled)


# dgl.DGLGraph(nx_graph) is deprecated in favour of dgl.from_networkx.
G_train_dgl = _nx_to_dgl(G_train)
G_val_dgl = _nx_to_dgl(G_val)
G_test_dgl = _nx_to_dgl(G_test)

In [None]:
def _feature_matrix(nx_graph):
    """Stack every node's 'features' list into one array, in nx_graph.nodes() order."""
    return np.array([nx_graph.nodes[node]['features'] for node in nx_graph.nodes()])


# Raw structural features per split. Row i is assumed to be DGL node i, i.e.
# DGL node IDs follow nx_graph.nodes() order (TODO confirm against how the
# G_*_dgl graphs were built).
node_features_train = _feature_matrix(G_train)
node_features_val = _feature_matrix(G_val)
node_features_test = _feature_matrix(G_test)

# float32 to match GraphConv's parameters. torch.tensor on a float64 array gives
# a float64 tensor, which fails against float32 weights in the forward pass.
G_train_dgl.ndata['features'] = torch.tensor(node_features_train, dtype=torch.float32)
G_val_dgl.ndata['features'] = torch.tensor(node_features_val, dtype=torch.float32)
G_test_dgl.ndata['features'] = torch.tensor(node_features_test, dtype=torch.float32)


In [None]:
# First ten rows of the raw training feature matrix.
train_feature_head = node_features_train[:10]
print(train_feature_head)

In [None]:
from sklearn.preprocessing import StandardScaler

# Raw structural features as float32 (GraphConv parameters are float32).
G_train_dgl.ndata['features'] = torch.tensor(node_features_train, dtype=torch.float32)
G_val_dgl.ndata['features'] = torch.tensor(node_features_val, dtype=torch.float32)
G_test_dgl.ndata['features'] = torch.tensor(node_features_test, dtype=torch.float32)

# The GCN cells read ndata['normalized_features'], which no other cell in this
# notebook sets, so training would stop with KeyError. log1p compresses the
# heavy-tailed counts. The scaler is fit on training nodes only and reused for
# val/test so their statistics do not leak into training.
dgl_feature_scaler = StandardScaler().fit(np.log1p(node_features_train))
for dgl_graph, raw_features in (
    (G_train_dgl, node_features_train),
    (G_val_dgl, node_features_val),
    (G_test_dgl, node_features_test),
):
    dgl_graph.ndata['normalized_features'] = torch.tensor(
        dgl_feature_scaler.transform(np.log1p(raw_features)), dtype=torch.float32
    )

In [None]:
# Feature tensor attached to the training DGL graph.
train_feature_tensor = G_train_dgl.ndata['features']
print(train_feature_tensor)

In [None]:
# Reload the pickled split graphs (trusted local files only: pickle can run code on load).
def _load_pickle(path):
    with open(path, "rb") as handle:
        return pickle.load(handle)


G_train = _load_pickle("G_train.gpickle")
G_val = _load_pickle("G_val.gpickle")
G_test = _load_pickle("G_test.gpickle")

In [None]:
# First ten (node, attributes) pairs of the reloaded training graph.
reloaded_sample = list(G_train.nodes(data=True))[:10]
print(reloaded_sample)

In [None]:
# How many nodes appear in both the training and the validation graph.
G_train_nodes = set(G_train.nodes())
G_val_nodes = set(G_val.nodes())

shared_count = len(G_train_nodes & G_val_nodes)
print(shared_count)

In [None]:
# Take one node present in both graphs and print its degree in the train graph, then in the validation graph.
shared_nodes = G_train_nodes.intersection(G_val_nodes)
node = list(shared_nodes)[0]

for split_graph in (G_train, G_val):
    print(split_graph.degree(node))

In [None]:
# Among nodes present in both the train and validation graphs, count how many
# are labelled 1 (phishing) in both. The original cell only built the node set
# and never computed the count its comment promised.
G_train_nodes = set(G_train.nodes())
shared_nodes = G_train_nodes & set(G_val.nodes())
shared_phishing = [
    shared for shared in shared_nodes
    if G_train.nodes[shared]['label'] == 1 and G_val.nodes[shared]['label'] == 1
]
print(len(shared_phishing))

In [None]:
# Copy the 0/1 'label' attribute of the networkx graphs onto the DGL graphs.
# Assumes DGL node i corresponds to list(nx_graph.nodes())[i] — TODO confirm
# against how the G_*_dgl graphs were built.
def _node_labels(nx_graph):
    return np.array([nx_graph.nodes[n]['label'] for n in nx_graph.nodes()])


node_labels_train = _node_labels(G_train)
node_labels_val = _node_labels(G_val)
node_labels_test = _node_labels(G_test)

for dgl_graph, labels_array in (
    (G_train_dgl, node_labels_train),
    (G_val_dgl, node_labels_val),
    (G_test_dgl, node_labels_test),
):
    dgl_graph.ndata['label'] = torch.tensor(labels_array)

In [None]:
# First ten labels of the train, validation and test DGL graphs.
for dgl_graph in (G_train_dgl, G_val_dgl, G_test_dgl):
    print(dgl_graph.ndata['label'][:10])

## GCN training on the largest connected component

In [None]:
import networkx as nx
import numpy as np
import scipy.sparse as sp
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score
import dgl


In [None]:
G_train_dgl.ndata['features'].size(1)

In [None]:
G_train_dgl.ndata['features'].shape[-1]

In [None]:
# Give every node exactly one self-loop, so GraphConv mixes in the node's own
# features and no node has zero in-degree (GraphConv rejects those by default).
# dgl.add_self_loop adds a loop even where one already exists, so any existing
# self-loops are removed first to avoid duplicated self-edges.
G_train_dgl = dgl.add_self_loop(dgl.remove_self_loop(G_train_dgl))
G_val_dgl = dgl.add_self_loop(dgl.remove_self_loop(G_val_dgl))
G_test_dgl = dgl.add_self_loop(dgl.remove_self_loop(G_test_dgl))

In [None]:
# Shape of the test label tensor (one entry per node, i.e. total node count).
test_label_tensor = G_test_dgl.ndata['label']
print(test_label_tensor.shape)

In [None]:
# Number of test nodes labelled 1 (phishing); the remaining nodes are 0.
test_phishing_total = G_test_dgl.ndata['label'].sum()
print(test_phishing_total)

In [None]:
# Persist the three DGL graphs (with their node data) via pickle.
dgl_outputs = (
    ("G_train_dgl.gpickle", G_train_dgl),
    ("G_val_dgl.gpickle", G_val_dgl),
    ("G_test_dgl.gpickle", G_test_dgl),
)
for out_path, dgl_graph in dgl_outputs:
    with open(out_path, "wb") as out_file:
        pickle.dump(dgl_graph, out_file)


In [None]:
# Reload the pickled DGL graphs (trusted local files only: pickle can run code on load).
def _read_pickle(path):
    with open(path, "rb") as handle:
        return pickle.load(handle)


G_train_dgl = _read_pickle("G_train_dgl.gpickle")
G_val_dgl = _read_pickle("G_val_dgl.gpickle")
G_test_dgl = _read_pickle("G_test_dgl.gpickle")

In [None]:
# {node id: feature row} for the first ten nodes of each DGL graph (train, val, test).
for dgl_graph in (G_train_dgl, G_val_dgl, G_test_dgl):
    node_feats = dgl_graph.ndata['features']
    print({idx: node_feats[idx] for idx in range(10)})


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
import copy
import random


class GCN(nn.Module):
    """Three-layer graph convolutional network for node classification.

    Forward pass: conv1 -> ReLU -> dropout -> batchnorm -> conv2 -> ReLU ->
    dropout -> conv3. Returns raw (unnormalized) logits, one row per node with
    ``out_feats`` columns.

    Args:
        in_feats: size of each input node feature vector.
        hidden_size: width of the two hidden graph-convolution layers.
        out_feats: number of output logits per node.
        dropout_rate: dropout probability applied after each hidden layer.
    """

    def __init__(self, in_feats, hidden_size, out_feats, dropout_rate):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, hidden_size)
        self.conv2 = GraphConv(hidden_size, hidden_size)  # added layer
        self.conv3 = GraphConv(hidden_size, out_feats)  # final layer
        self.dropout = nn.Dropout(dropout_rate)  # shared dropout layer
        self.batchnorm1 = nn.BatchNorm1d(hidden_size)  # applied after the first layer only

    def forward(self, g, features):
        x = F.relu(self.conv1(g, features))
        x = self.dropout(x)
        x = self.batchnorm1(x)
        x = F.relu(self.conv2(g, x))
        x = self.dropout(x)
        x = self.conv3(g, x)
        return x


# The model reads ndata['normalized_features']; fail early with an actionable
# message instead of a bare KeyError when it has not been attached yet.
if 'normalized_features' not in G_train_dgl.ndata:
    raise KeyError(
        "G_train_dgl.ndata has no 'normalized_features'; attach normalized node "
        "features to the DGL graphs before building the model"
    )

# Number of input features per node.
in_feats = G_train_dgl.ndata['normalized_features'].shape[1]

# Model hyperparameters.
hidden_size = 128
out_feats = 2  # binary classification: normal (0) vs phishing (1)
dropout_rate = 0.1


In [None]:
import copy
import random

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score


def _balanced_subgraph(g):
    """Return a node-induced subgraph of ``g`` with equally many label-0 and label-1 nodes.

    The smaller class is kept whole and the larger one is subsampled without
    replacement using torch's global RNG. The selection is shuffled before the
    subgraph is built. Only edges among the selected nodes are kept.
    """
    node_labels = g.ndata['label'].squeeze()
    zero_indices = torch.where(node_labels == 0)[0]
    one_indices = torch.where(node_labels == 1)[0]
    min_count = min(zero_indices.shape[0], one_indices.shape[0])

    selected_zero = zero_indices[torch.randperm(zero_indices.shape[0])[:min_count]]
    selected_one = one_indices[torch.randperm(one_indices.shape[0])[:min_count]]
    selected = torch.cat((selected_zero, selected_one))
    selected = selected[torch.randperm(selected.shape[0])]
    return dgl.node_subgraph(g, selected)


for run in range(10):
    # Draw a fresh random seed for this run and seed every RNG in use with it.
    seed = random.randint(0, 1000)
    print(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    model = GCN(in_feats, hidden_size, out_feats, dropout_rate)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Two output logits scored against one-hot targets with independent sigmoids.
    criterion = nn.BCEWithLogitsLoss()
    best_val_loss = float('inf')
    best_model = None
    best_epoch = 0
    epochs_without_improvement = 0
    num_epochs = 200
    patience = 20  # early-stopping window (previously declared but never used)

    # One fixed balanced validation subgraph per run. It used to be re-sampled
    # every epoch, so successive validation losses were measured on different
    # nodes and "best model" selection was partly driven by sampling noise.
    val_subgraph = _balanced_subgraph(G_val_dgl)
    val_features = val_subgraph.ndata['normalized_features']
    val_targets = F.one_hot(val_subgraph.ndata['label'].squeeze(), num_classes=out_feats).float()

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        # A new balanced training subgraph every epoch.
        train_subgraph = _balanced_subgraph(G_train_dgl)
        logits = model(train_subgraph, train_subgraph.ndata['normalized_features'])
        train_targets = F.one_hot(train_subgraph.ndata['label'].squeeze(), num_classes=out_feats).float()
        loss = criterion(logits, train_targets)
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            val_loss = criterion(model(val_subgraph, val_features), val_targets)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch + 1
            best_model = copy.deepcopy(model)
            torch.save(model.state_dict(), 'best_model.pt')
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        print(f"Epoch: {epoch + 1}/{num_epochs}, Loss: {loss:.4f}, Validation Loss: {val_loss:.4f}")

        if epochs_without_improvement >= patience:
            print(f"Early stopping after epoch {epoch + 1}: no validation improvement for {patience} epochs")
            break

    if best_model is None:
        raise RuntimeError(f"Seed {seed}: validation loss never improved (NaN?); no model to evaluate")

    # Evaluate the best (lowest validation loss) model on a balanced test subgraph.
    best_model.eval()
    with torch.no_grad():
        test_subgraph = _balanced_subgraph(G_test_dgl)
        logits = best_model(test_subgraph, test_subgraph.ndata['normalized_features'])
        predicted_probs = F.softmax(logits, dim=1)[:, 1]
        predicted_labels = (predicted_probs > 0.5).float()

        y_true = test_subgraph.ndata['label'].squeeze().numpy()
        y_prob = predicted_probs.numpy()
        y_pred = predicted_labels.numpy()

        auc = roc_auc_score(y_true, y_prob)
        f1 = f1_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        accuracy = accuracy_score(y_true, y_pred)
        macro_f1 = f1_score(y_true, y_pred, average='macro')
        macro_precision = precision_score(y_true, y_pred, average='macro')
        macro_recall = recall_score(y_true, y_pred, average='macro')

        # Log the best validation loss (the one that selected the evaluated
        # model) and its epoch, not the last epoch's value.
        with open("GCN_wo_results.txt", "a") as f:
            f.write(f"Random seed: {seed}, Epoch: {epoch + 1}/{num_epochs}, Best Epoch: {best_epoch}, Loss: {loss:.4f}, Validation Loss: {best_val_loss:.4f}, AUC: {auc:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, Accuracy: {accuracy:.4f}, Macro-F1: {macro_f1:.4f}, Macro-Precision: {macro_precision:.4f}, Macro-recall: {macro_recall:.4f}\n")
        print(f"AUC: {auc:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, Accuracy: {accuracy:.4f}, Macro-F1: {macro_f1:.4f}, Macro-Precision: {macro_precision:.4f}, Macro-recall: {macro_recall:.4f}\n")
