In [1]:
import torch
from torch import nn

from sklearn.metrics import f1_score
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import roc_curve,auc

from sklearn.model_selection import train_test_split

from gsage-model import BotGraphSAGE
from utils import accuracy, init_weights

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cpu'
embedding_size = 32
dropout = 0.1
lr = 1e-2
weight_decay = 5e-2

In [3]:
#Re calculate Des Tensor
path = "/Dataset/"
des_tensor = torch.load(path + "filtered_des_tensor.pt").t().to(device)

num_prop = torch.load(path + "filtered_num_properties_tensor.pt").t().to(device)
category_prop = torch.load(path + "filtered_cat_properties_tensor.pt").t().to(device)
labels = torch.load(path + "filtered_label.pt").t().to(device)

tweets_tensor = torch.load("/Users/ketanjadhav/Documents/BotRGCN/processed_data/tweets_tensor.pt").t().to(device)

train_idx = torch.load(path + "filtered_train_idx.pt").to(device)
val_idx = torch.load(path + "filtered_val_idx.pt").to(device)
test_idx = torch.load(path + "filtered_test_idx.pt").to(device)

edge_index = torch.load(path + "filtered_edge_index.pt").to(device)
edge_type = torch.load(path + "filtered_edge_type.pt").to(device)

In [4]:
# followers_edge_index = torch.load("./tensors/followers_edge_index.pt").to(device)
# following_edge_index = torch.load("./tensors/following_edge_index.pt").to(device)
# combined_foll_edge_index = torch.load("./tensors/foll_combined_edge_index.pt").to(device)
# combined_foll_edge_type = torch.load("./tensors/foll_combined_edge_type.pt").to(device)
# interactions_edge_index = torch.load("./tensors/interactions_edge_index.pt").to(device)

In [5]:
sage_model=BotGraphSAGE(cat_prop_size=3,embedding_dimension=32).to(device)
loss=nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(sage_model.parameters(),
                    lr=lr,weight_decay=weight_decay)

sage_model.apply(init_weights)

BotGraphSAGE(
  (linear_relu_num_prop): Sequential(
    (0): Linear(in_features=4, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (linear_relu_cat_prop): Sequential(
    (0): Linear(in_features=3, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (linear_relu_input): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (sage1): GraphSAGE(32, 32, num_layers=2)
  (sage2): GraphSAGE(32, 32, num_layers=2)
  (linear_relu_output1): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (linear_output2): Linear(in_features=32, out_features=2, bias=True)
)

In [6]:
def train_sage(epoch, optimizer, train_edge_index):
    sage_model.train()
    output = sage_model(des_tensor, tweets_tensor, num_prop, category_prop, train_edge_index)
    loss_train = loss(output[train_idx], labels[train_idx])
    acc_train = accuracy(output[train_idx], labels[train_idx])
    acc_val = accuracy(output[val_idx], labels[val_idx])
    
    optimizer.zero_grad()
    loss_train.backward()
    optimizer.step()
    
    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'acc_val: {:.4f}'.format(acc_val.item()))
    
    return acc_train, loss_train

In [7]:
def test_sage(test_edge_index):
    sage_model.eval()
    output = sage_model(des_tensor, tweets_tensor, num_prop, category_prop, test_edge_index)
    loss_test = loss(output[test_idx], labels[test_idx])
    acc_test = accuracy(output[test_idx], labels[test_idx])
    
    output = output.max(1)[1].to('cpu').detach().numpy()
    label = labels.to('cpu').detach().numpy()
    
    f1 = f1_score(label[test_idx], output[test_idx])
    precision = precision_score(label[test_idx], output[test_idx])
    recall = recall_score(label[test_idx], output[test_idx])
    
    fpr, tpr, thresholds = roc_curve(label[test_idx], output[test_idx], pos_label=1)
    auc_val = auc(fpr, tpr)
    
    print("Test set results:",
          "test_loss= {:.4f}".format(loss_test.item()),
          "test_accuracy= {:.4f}".format(acc_test.item()),
          "precision= {:.4f}".format(precision.item()),
          "recall= {:.4f}".format(recall.item()),
          "f1_score= {:.4f}".format(f1.item()),
          "auc= {:.4f}".format(auc_val.item()))

In [8]:
#In case we want to split edges

# num_edges = followers_edge_index.size(1)
# indices = torch.arange(num_edges)

# train_indices, test_indices = train_test_split(indices.numpy(), test_size=0.2, random_state=42)

# train_edge_index = followers_edge_index[:, train_indices]
# test_edge_index = followers_edge_index[:, test_indices]

COMBINED

In [10]:
epochs = 50

for epoch in range(epochs):
    train_sage(epoch, optimizer, edge_index)
    
test_sage(edge_index)

Epoch: 0001 loss_train: 0.4720 acc_train: 0.7793 acc_val: 0.5041
Epoch: 0002 loss_train: 0.4556 acc_train: 0.7799 acc_val: 0.5057
Epoch: 0003 loss_train: 0.4444 acc_train: 0.7748 acc_val: 0.5953
Epoch: 0004 loss_train: 0.4360 acc_train: 0.7824 acc_val: 0.5550
Epoch: 0005 loss_train: 0.4346 acc_train: 0.7811 acc_val: 0.3351
Epoch: 0006 loss_train: 0.4262 acc_train: 0.7887 acc_val: 0.4035
Epoch: 0007 loss_train: 0.4292 acc_train: 0.7805 acc_val: 0.5807
Epoch: 0008 loss_train: 0.4257 acc_train: 0.7911 acc_val: 0.4430
Epoch: 0009 loss_train: 0.4195 acc_train: 0.7895 acc_val: 0.4854
Epoch: 0010 loss_train: 0.4193 acc_train: 0.7856 acc_val: 0.5636
Epoch: 0011 loss_train: 0.4196 acc_train: 0.7897 acc_val: 0.3534
Epoch: 0012 loss_train: 0.4111 acc_train: 0.7945 acc_val: 0.4925
Epoch: 0013 loss_train: 0.4111 acc_train: 0.7938 acc_val: 0.6158
Epoch: 0014 loss_train: 0.4074 acc_train: 0.7919 acc_val: 0.4737
Epoch: 0015 loss_train: 0.4077 acc_train: 0.7919 acc_val: 0.4705
Epoch: 0016 loss_train: 0

FOLLOWING

In [11]:
epochs = 50

following_edge_index = torch.load(path + "filtered_following_edge_index.pt").to(device)

for epoch in range(epochs):
    train_sage(epoch, optimizer, following_edge_index)
    
test_sage(following_edge_index)

Epoch: 0001 loss_train: 0.3771 acc_train: 0.8256 acc_val: 0.7452
Epoch: 0002 loss_train: 0.3818 acc_train: 0.8175 acc_val: 0.5380
Epoch: 0003 loss_train: 0.3784 acc_train: 0.8231 acc_val: 0.6144
Epoch: 0004 loss_train: 0.3797 acc_train: 0.8216 acc_val: 0.7786
Epoch: 0005 loss_train: 0.3773 acc_train: 0.8240 acc_val: 0.7562
Epoch: 0006 loss_train: 0.3780 acc_train: 0.8235 acc_val: 0.6152
Epoch: 0007 loss_train: 0.3751 acc_train: 0.8279 acc_val: 0.6558
Epoch: 0008 loss_train: 0.3780 acc_train: 0.8213 acc_val: 0.7829
Epoch: 0009 loss_train: 0.3733 acc_train: 0.8284 acc_val: 0.7258
Epoch: 0010 loss_train: 0.3763 acc_train: 0.8267 acc_val: 0.6349
Epoch: 0011 loss_train: 0.3734 acc_train: 0.8259 acc_val: 0.7446
Epoch: 0012 loss_train: 0.3733 acc_train: 0.8250 acc_val: 0.7685
Epoch: 0013 loss_train: 0.3735 acc_train: 0.8268 acc_val: 0.6245
Epoch: 0014 loss_train: 0.3706 acc_train: 0.8294 acc_val: 0.7233
Epoch: 0015 loss_train: 0.3715 acc_train: 0.8277 acc_val: 0.7480
Epoch: 0016 loss_train: 0

FOLLOWER

In [12]:
epochs = 50

followers_edge_index = torch.load(path + "filtered_followers_edge_index.pt").to(device)

for epoch in range(epochs):
    train_sage(epoch, optimizer, followers_edge_index)
    
test_sage(followers_edge_index)

Epoch: 0001 loss_train: 0.3651 acc_train: 0.8321 acc_val: 0.6820
Epoch: 0002 loss_train: 0.3662 acc_train: 0.8286 acc_val: 0.7763
Epoch: 0003 loss_train: 0.3643 acc_train: 0.8338 acc_val: 0.6799
Epoch: 0004 loss_train: 0.3625 acc_train: 0.8348 acc_val: 0.6915
Epoch: 0005 loss_train: 0.3635 acc_train: 0.8310 acc_val: 0.7740
Epoch: 0006 loss_train: 0.3609 acc_train: 0.8344 acc_val: 0.6831
Epoch: 0007 loss_train: 0.3605 acc_train: 0.8338 acc_val: 0.6946
Epoch: 0008 loss_train: 0.3614 acc_train: 0.8331 acc_val: 0.7620
Epoch: 0009 loss_train: 0.3586 acc_train: 0.8356 acc_val: 0.7005
Epoch: 0010 loss_train: 0.3579 acc_train: 0.8370 acc_val: 0.7073
Epoch: 0011 loss_train: 0.3599 acc_train: 0.8333 acc_val: 0.7801
Epoch: 0012 loss_train: 0.3639 acc_train: 0.8300 acc_val: 0.5985
Epoch: 0013 loss_train: 0.3670 acc_train: 0.8267 acc_val: 0.8231
Epoch: 0014 loss_train: 0.3628 acc_train: 0.8325 acc_val: 0.6272
Epoch: 0015 loss_train: 0.3571 acc_train: 0.8355 acc_val: 0.7571
Epoch: 0016 loss_train: 0

INTERACTIONS

In [15]:
epochs = 50

interactions_edge_index = torch.load(path + "filtered_interaction_edge_index.pt").to(device)

for epoch in range(epochs):
    train_sage(epoch, optimizer, interactions_edge_index)
    
test_sage(interactions_edge_index)

Epoch: 0001 loss_train: 0.3492 acc_train: 0.8401 acc_val: 0.6757
Epoch: 0002 loss_train: 0.3498 acc_train: 0.8394 acc_val: 0.7778
Epoch: 0003 loss_train: 0.3553 acc_train: 0.8345 acc_val: 0.6103
Epoch: 0004 loss_train: 0.3610 acc_train: 0.8336 acc_val: 0.8202
Epoch: 0005 loss_train: 0.3703 acc_train: 0.8259 acc_val: 0.5496
Epoch: 0006 loss_train: 0.3626 acc_train: 0.8315 acc_val: 0.8093
Epoch: 0007 loss_train: 0.3481 acc_train: 0.8402 acc_val: 0.7203
Epoch: 0008 loss_train: 0.3573 acc_train: 0.8339 acc_val: 0.6289
Epoch: 0009 loss_train: 0.3546 acc_train: 0.8349 acc_val: 0.8039
Epoch: 0010 loss_train: 0.3484 acc_train: 0.8412 acc_val: 0.7180
Epoch: 0011 loss_train: 0.3499 acc_train: 0.8399 acc_val: 0.6603
Epoch: 0012 loss_train: 0.3573 acc_train: 0.8322 acc_val: 0.8017
Epoch: 0013 loss_train: 0.3513 acc_train: 0.8386 acc_val: 0.6522
Epoch: 0014 loss_train: 0.3476 acc_train: 0.8401 acc_val: 0.7428
Epoch: 0015 loss_train: 0.3487 acc_train: 0.8393 acc_val: 0.7741
Epoch: 0016 loss_train: 0

ALL

In [16]:
epochs = 50

all_edge_index = torch.load(path + "all_combined_edge_index.pt").to(device)

for epoch in range(epochs):
    train_sage(epoch, optimizer, all_edge_index)
    
test_sage(all_edge_index)

Epoch: 0001 loss_train: 0.3485 acc_train: 0.8409 acc_val: 0.6864
Epoch: 0002 loss_train: 0.3479 acc_train: 0.8398 acc_val: 0.7637
Epoch: 0003 loss_train: 0.3453 acc_train: 0.8421 acc_val: 0.7318
Epoch: 0004 loss_train: 0.3459 acc_train: 0.8409 acc_val: 0.7348
Epoch: 0005 loss_train: 0.3448 acc_train: 0.8416 acc_val: 0.7585
Epoch: 0006 loss_train: 0.3443 acc_train: 0.8417 acc_val: 0.7144
Epoch: 0007 loss_train: 0.3434 acc_train: 0.8427 acc_val: 0.7566
Epoch: 0008 loss_train: 0.3431 acc_train: 0.8444 acc_val: 0.6997
Epoch: 0009 loss_train: 0.3433 acc_train: 0.8439 acc_val: 0.7770
Epoch: 0010 loss_train: 0.3483 acc_train: 0.8392 acc_val: 0.6440
Epoch: 0011 loss_train: 0.3794 acc_train: 0.8190 acc_val: 0.8641
Epoch: 0012 loss_train: 0.4409 acc_train: 0.7809 acc_val: 0.2837
Epoch: 0013 loss_train: 0.3589 acc_train: 0.8344 acc_val: 0.7883
Epoch: 0014 loss_train: 0.3980 acc_train: 0.8119 acc_val: 0.8721
Epoch: 0015 loss_train: 0.3671 acc_train: 0.8308 acc_val: 0.7367
Epoch: 0016 loss_train: 0