# Graph Neural Network for Twitter Bot Detection

This notebook implements a Graph Neural Network (GNN) to detect bot accounts on Twitter using the TwiBot-22 dataset.

## What is a Graph Neural Network?

**Graph Neural Networks (GNNs)** are a class of deep learning models designed to work with graph-structured data. Unlike traditional neural networks that work with grid-like data (images, sequences), GNNs can handle irregular structures like social networks.

### Key Concepts:

1. **Graph Structure**: 
   - **Nodes**: Represent entities (users in our case)
   - **Edges**: Represent relationships (retweets, replies, mentions)
   - **Features**: Each node has features (follower count, tweet metrics, text embeddings)

2. **Message Passing**:
   - Each node aggregates information from its neighbors
   - Information flows through edges in multiple layers
   - Each layer refines the node representations

3. **Why GNNs for Bot Detection?**:
   - Bots often have distinct interaction patterns (who they follow, retweet patterns)
   - GNNs can learn from both node features AND network structure
   - Accounts with similar behavior cluster together in the graph
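To make the message-passing idea concrete, here is a minimal sketch in plain Python. The three-node graph and its feature values are hypothetical, and the aggregation is a bare neighbor mean; real GNN layers also apply learned weights and combine the result with the node's own features.

```python
# Toy undirected graph 0-1-2 (hypothetical data, for illustration only)
neighbors = {0: [1], 1: [0, 2], 2: [1]}
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [2.0, 2.0]}

def mean_aggregate(neighbors, features):
    """One round of message passing: average the features of each node's neighbors."""
    out = {}
    for node, nbrs in neighbors.items():
        dim = len(features[node])
        out[node] = [sum(features[n][d] for n in nbrs) / len(nbrs) for d in range(dim)]
    return out

h1 = mean_aggregate(neighbors, features)  # each node now sees its 1-hop neighbors
h2 = mean_aggregate(neighbors, h1)        # a second round reaches 2-hop neighbors
print(h1)
print(h2)
```

After the second round, node 0 carries information that originated at node 2, even though the two are not directly connected. Stacking layers widens each node's view of the graph in the same way.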

### Our Implementation:

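We'll build a homogeneous graph (one node type, one edge type) where:
- **Nodes** = Twitter users
- **Edges** = Shared conversations (two users are linked if they posted in the same `conversation_id`), a simplified proxy for retweet, reply, and quote interactions
- **Node Features** = Averaged tweet engagement metrics + averaged TF-IDF vectors of each user's tweets
- **Task** = Binary classification (bot vs human)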

## 1. Import Required Libraries

In [None]:
# Install required packages (run once)
# %pip install torch torchvision torchaudio
# %pip install torch-geometric
# %pip install transformers
# %pip install scikit-learn
# %pip install imbalanced-learn
# %pip install pandas pyarrow python-dotenv matplotlib seaborn

import warnings
warnings.filterwarnings('ignore')

# Core libraries
import pandas as pd
import numpy as np
import os
from dotenv import load_dotenv

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# PyTorch Geometric
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, SAGEConv, global_mean_pool
from torch_geometric.utils import degree

# Scikit-learn
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight

# Imbalanced learning
from imblearn.over_sampling import SMOTE

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Load Preprocessed Data

In [None]:
# Load environment variables
load_dotenv()
DATASET_DIR = os.getenv("DATASET_DIR")

# Load the preprocessed tweets with labels
data_path = os.path.join(DATASET_DIR, "tweets_with_labels.parquet")
df = pd.read_parquet(data_path)

print(f"Loaded dataset shape: {df.shape}")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nLabel distribution:")
print(df['label'].value_counts())
print(f"\nSplit distribution:")
print(df['split'].value_counts() if 'split' in df.columns else "No split column")

# Display sample
df.head()

## 3. Construct Graph Structure

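We create a graph where:
- **Nodes** = Unique users (authors of tweets)
- **Edges** = Co-participation in a conversation: users A and B are linked if both posted under the same `conversation_id`. This is a simplified stand-in for direct retweet/reply links.
- Each edge is stored in both directions, so the graph is undirected
- This captures the social network structure for the GNN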
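As a small illustration of this edge rule, the sketch below builds conversation-based edges for a handful of made-up tweets. The `tweets` list and the user names are hypothetical; only the pairing logic mirrors the approach described above.

```python
from itertools import combinations

# Hypothetical toy data: (conversation_id, author_id) pairs
tweets = [("c1", "alice"), ("c1", "bob"), ("c1", "carol"),
          ("c2", "alice"), ("c2", "bob"),
          ("c3", "dave")]

user_to_idx = {}
conversations = {}
for conv_id, author in tweets:
    user_to_idx.setdefault(author, len(user_to_idx))  # assign node indices in order of appearance
    conversations.setdefault(conv_id, set()).add(user_to_idx[author])

edges = set()
for members in conversations.values():
    for a, b in combinations(sorted(members), 2):
        edges.add((a, b))
        edges.add((b, a))  # store both directions for an undirected graph

edge_list = sorted(edges)
print(user_to_idx)
print(edge_list)
```

Using a set means that alice and bob, who share two conversations, are still connected by a single undirected edge. A user who never shares a conversation (dave here) becomes an isolated node and receives no messages from neighbors.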

In [None]:
# Create user mapping: author_id_str -> node index
unique_users = df['author_id_str'].unique()
user_to_idx = {user: idx for idx, user in enumerate(unique_users)}
idx_to_user = {idx: user for user, idx in user_to_idx.items()}

num_nodes = len(unique_users)
print(f"Number of nodes (unique users): {num_nodes}")

# Build edges from shared conversations.
# Simplification: two users are linked if both posted under the same conversation_id.
edge_set = set()

# Group by conversation to find interactions
for conv_id, group in df.groupby('conversation_id'):
    authors = group['author_id_str'].unique()
    # Create edges between all users in the same conversation (quadratic in conversation size)
    for i, user_a in enumerate(authors):
        for user_b in authors[i+1:]:
            a, b = user_to_idx[user_a], user_to_idx[user_b]
            edge_set.add((a, b))
            edge_set.add((b, a))  # Undirected: store both directions

# The set removes duplicate edges between users who share several conversations
edge_list = sorted(edge_set)
edge_index = torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()
print(f"Number of edges: {edge_index.shape[1]}")
print(f"Edge index shape: {edge_index.shape}")

## 4. Feature Engineering and Text Processing

**Text Handling Strategy**: Use TF-IDF to convert tweet text into numerical vectors
- Produces a sparse bag-of-words representation of each tweet
- Weights terms by how distinctive they are, and caps the vocabulary at 100 terms to keep the feature dimension small
- Cheaper to compute than large transformer embeddings at this scale
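Because the model classifies users rather than tweets, the per-tweet vectors have to be averaged per author. The sketch below shows one way to do that with NumPy; the `tweet_vecs` and `author_idx` arrays are hypothetical stand-ins for the TF-IDF matrix and the author-to-node mapping.

```python
import numpy as np

# Hypothetical per-tweet vectors (4 tweets, 3 TF-IDF dimensions)
tweet_vecs = np.array([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [0.0, 0.0, 2.0],
                       [2.0, 2.0, 0.0]])
# Node index of each tweet's author: user 0 wrote tweets 0 and 1
author_idx = np.array([0, 0, 1, 2])
num_users = 3

sums = np.zeros((num_users, tweet_vecs.shape[1]))
np.add.at(sums, author_idx, tweet_vecs)                # scatter-add each tweet into its author's row
counts = np.bincount(author_idx, minlength=num_users)  # number of tweets per user
user_vecs = sums / np.maximum(counts, 1)[:, None]      # guard against users with no tweets
print(user_vecs)
```

Indexing by position in `author_idx`, rather than by DataFrame index labels, keeps the rows of the tweet matrix and the author lookup aligned regardless of how the DataFrame is indexed.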

In [None]:
# 1. Text features using TF-IDF (vocabulary capped at 100 terms)
tfidf = TfidfVectorizer(max_features=100, stop_words='english', max_df=0.8, min_df=2)
text_features = tfidf.fit_transform(df['text'].fillna('')).toarray()
num_text_features = text_features.shape[1]  # Can be below 100 if the vocabulary is small

# Aggregate text features per user (mean of all their tweets).
# Positional author indices are used so this works for any DataFrame index,
# not only the default 0..n-1 RangeIndex.
author_idx = df['author_id_str'].map(user_to_idx).to_numpy()
user_text_features = np.zeros((num_nodes, num_text_features))
np.add.at(user_text_features, author_idx, text_features)

# Average the features
tweet_counts = np.bincount(author_idx, minlength=num_nodes)
user_text_features /= tweet_counts[:, None]

print(f"Text features shape: {user_text_features.shape}")

# 2. Numerical user features (per-user mean of each tweet-level metric)
numerical_features = ['retweet_count', 'like_count', 'reply_count', 'quote_count', 'text_length']
user_numerical_features = (
    df.groupby('author_id_str')[numerical_features]
      .mean()
      .reindex(unique_users)  # Align rows with node indices
      .to_numpy(dtype=float)
)

print(f"Numerical features shape: {user_numerical_features.shape}")

# 3. Combine all features
node_features = np.concatenate([user_text_features, user_numerical_features], axis=1)
print(f"Combined node features shape: {node_features.shape}")

## 5. Feature Scaling and Normalization

Normalize features to have zero mean and unit variance for better training stability.
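Standardization maps each feature column to z = (x - mean) / std. The sketch below is a hypothetical variant that estimates the mean and standard deviation from training nodes only, which avoids leaking validation and test statistics into the features; the cell below instead fits the scaler on all nodes. The arrays and the mask are made up for illustration.

```python
import numpy as np

# Hypothetical feature matrix (4 nodes, 2 features) and split
features = np.array([[1.0, 10.0],
                     [3.0, 30.0],
                     [5.0, 50.0],
                     [100.0, 0.0]])
train_mask = np.array([True, True, True, False])

mean = features[train_mask].mean(axis=0)
std = features[train_mask].std(axis=0)  # population std (ddof=0), as StandardScaler uses
std[std == 0] = 1.0                     # leave constant columns unscaled instead of dividing by zero
scaled = (features - mean) / std

print(scaled[train_mask].mean(axis=0))
print(scaled[train_mask].std(axis=0))
```

The training rows end up with zero mean and unit variance, while the held-out row is transformed with the same statistics and may fall well outside that range.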

In [None]:
# Normalize features using StandardScaler
scaler = StandardScaler()
node_features_scaled = scaler.fit_transform(node_features)

# Convert to PyTorch tensors
x = torch.tensor(node_features_scaled, dtype=torch.float)

print(f"Scaled features shape: {x.shape}")
print(f"Feature statistics after scaling:")
print(f"  Mean: {x.mean(dim=0)[:5]}...")  # Show first 5
print(f"  Std: {x.std(dim=0)[:5]}...")

# Extract labels for each node (taken from each user's first tweet)
node_labels = (
    df.drop_duplicates('author_id_str')
      .set_index('author_id_str')['label']
      .reindex(unique_users)  # Align with node indices
      .to_numpy(dtype=np.int64)
)

y = torch.tensor(node_labels, dtype=torch.long)

print(f"\nLabel distribution in graph:")
unique, counts = torch.unique(y, return_counts=True)
for label, count in zip(unique, counts):
    print(f"  Label {label}: {count} ({100*count/num_nodes:.2f}%)")

## 6. Handle Class Imbalance

The dataset is imbalanced (~77% human, ~23% bot). We'll use weighted loss to handle this during training.
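The "balanced" weighting gives class c the weight n_samples / (n_classes * n_c), so rarer classes count more in the loss. A minimal worked example in plain Python, using a hypothetical 77/23 split that matches the ratio quoted above:

```python
def balanced_class_weights(labels):
    """Weight for class c = n_samples / (n_classes * count_c)."""
    n = len(labels)
    classes = sorted(set(labels))
    return {c: n / (len(classes) * labels.count(c)) for c in classes}

# Hypothetical label list with the ratio quoted above: 77% human (0), 23% bot (1)
labels = [0] * 77 + [1] * 23
weights = balanced_class_weights(labels)
print(weights)
```

With these proportions a misclassified bot costs roughly 3.3 times as much as a misclassified human, which pushes the model away from always predicting the majority class.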

In [None]:
# Compute class weights for balanced loss
class_weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=node_labels)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

print(f"Class weights: {class_weights}")
print(f"  Human (0): {class_weights[0]:.4f}")
print(f"  Bot (1): {class_weights[1]:.4f}")

# Create train/val/test masks based on the 'split' column
# (each user's split is taken from their first tweet)
user_split = (
    df.drop_duplicates('author_id_str')
      .set_index('author_id_str')['split']
      .reindex(unique_users)  # Align with node indices
      .to_numpy()
)
train_mask = torch.from_numpy(user_split == 'train')
val_mask = torch.from_numpy(user_split == 'val')
test_mask = torch.from_numpy(user_split == 'test')

print(f"\nData splits:")
print(f"  Train: {train_mask.sum()} nodes")
print(f"  Val: {val_mask.sum()} nodes")
print(f"  Test: {test_mask.sum()} nodes")

## 7. Define GNN Model Architecture

We'll use a 3-layer GraphSAGE model with:
- Hidden dimension: 128 (about 93K parameters with 105 input features, far below the 100M limit)
- Dropout for regularization
- Message passing to aggregate neighbor information
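The size of this architecture can be estimated by hand. The sketch below assumes each `SAGEConv` layer holds two weight matrices (one for the aggregated neighbors, one for the node itself) plus a single bias vector, and that the input has 105 features; both are assumptions, so the result should be compared against the parameter count printed by the next cell.

```python
def sage_conv_params(in_dim, out_dim):
    # Assumption: neighbor weights + root weights + one bias vector
    return 2 * in_dim * out_dim + out_dim

def linear_params(in_dim, out_dim):
    # Weight matrix plus bias
    return in_dim * out_dim + out_dim

in_channels, hidden_channels, out_channels = 105, 128, 2

total = (
    sage_conv_params(in_channels, hidden_channels)            # conv1
    + 2 * sage_conv_params(hidden_channels, hidden_channels)  # conv2 and conv3
    + linear_params(hidden_channels, out_channels)            # final classifier
)
print(f"{total:,}")
```

Under these assumptions the model has 93,058 parameters, and most of them sit in the two hidden-to-hidden layers.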

In [None]:
class BotDetectionGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super(BotDetectionGNN, self).__init__()
        
        # GraphSAGE layers for message passing
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        
        # Dropout for regularization
        self.dropout = dropout
        
        # Final classifier
        self.lin = nn.Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index):
        # Layer 1: Message passing + activation + dropout
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Layer 2: Message passing + activation + dropout
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Layer 3: Message passing + activation + dropout
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Classification layer
        x = self.lin(x)
        
        return x

# Initialize model
in_channels = x.shape[1]  # 105 features
hidden_channels = 128
out_channels = 2  # Binary classification

model = BotDetectionGNN(in_channels, hidden_channels, out_channels).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model architecture:")
print(model)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Parameters < 100M: {total_params < 100_000_000}")

## 8. Training Setup and Loop

Set up optimizer, loss function with class weights, and implement the training loop with early stopping.
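Early stopping halts training once the validation loss has failed to improve for `patience` consecutive epochs. The helper below is a hypothetical, framework-free sketch of that rule applied to a made-up loss curve; it is not the notebook's training loop.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (epoch at which training stops, epoch with the best validation loss)."""
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0  # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch

# Hypothetical validation losses: best at epoch 3, then no further improvement
stop_epoch, best_epoch = early_stopping_epoch(
    [0.90, 0.70, 0.60, 0.65, 0.62, 0.61, 0.70], patience=3
)
print(stop_epoch, best_epoch)
```

Training stops at epoch 6, but the weights worth keeping are those from epoch 3. This is why the best checkpoint is saved whenever the validation loss improves and reloaded after the loop.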

In [None]:
# Move data to device
x = x.to(device)
y = y.to(device)
edge_index = edge_index.to(device)
train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)

# Optimizer and scheduler
# (`verbose` is omitted because it is deprecated in recent PyTorch releases)
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

# Loss function with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Training function
def train():
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    out = model(x, edge_index)
    loss = criterion(out[train_mask], y[train_mask])
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    # Calculate accuracy
    pred = out[train_mask].argmax(dim=1)
    correct = (pred == y[train_mask]).sum()
    acc = int(correct) / int(train_mask.sum())
    
    return loss.item(), acc

# Validation function
@torch.no_grad()
def validate():
    model.eval()
    out = model(x, edge_index)
    
    # Loss
    loss = criterion(out[val_mask], y[val_mask])
    
    # Accuracy
    pred = out[val_mask].argmax(dim=1)
    correct = (pred == y[val_mask]).sum()
    acc = int(correct) / int(val_mask.sum())
    
    return loss.item(), acc

# Test function
@torch.no_grad()
def test():
    model.eval()
    out = model(x, edge_index)
    
    # Predictions and probabilities
    pred = out[test_mask].argmax(dim=1)
    probs = F.softmax(out[test_mask], dim=1)[:, 1]
    
    # Metrics
    y_true = y[test_mask].cpu().numpy()
    y_pred = pred.cpu().numpy()
    y_probs = probs.cpu().numpy()
    
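    acc = accuracy_score(y_true, y_pred)
    # zero_division=0 avoids warnings when no bots are predicted
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', zero_division=0)
    auc = roc_auc_score(y_true, y_probs)
    # Fix the label order so the matrix is always 2x2 (row/col 0 = human, 1 = bot)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])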
    
    return acc, precision, recall, f1, auc, cm, y_true, y_pred

print("Training setup complete!")
print(f"Optimizer: Adam (lr=0.001, weight_decay=5e-4)")
print(f"Loss: CrossEntropyLoss with class weights")
print(f"Scheduler: ReduceLROnPlateau")

In [None]:
# Training loop with early stopping
epochs = 200
patience = 20
best_val_loss = float('inf')
patience_counter = 0

train_losses = []
val_losses = []
train_accs = []
val_accs = []

print("Starting training...")
print("=" * 70)

for epoch in range(1, epochs + 1):
    # Train
    train_loss, train_acc = train()
    
    # Validate
    val_loss, val_acc = validate()
    
    # Track metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    
    # Learning rate scheduling
    scheduler.step(val_loss)
    
    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Save best model
        torch.save(model.state_dict(), 'best_gnn_model.pt')
    else:
        patience_counter += 1
    
    # Print progress every 10 epochs
    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    
    # Early stopping check
    if patience_counter >= patience:
        print(f"\nEarly stopping triggered at epoch {epoch}")
        break

print("=" * 70)
print("Training complete!")

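# Load best model (map_location keeps this working if the checkpoint was saved on another device)
model.load_state_dict(torch.load('best_gnn_model.pt', map_location=device))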
print(f"Loaded best model (val_loss: {best_val_loss:.4f})")

## 9. Model Evaluation

Evaluate the trained model on the test set with comprehensive metrics.
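All of the threshold-based metrics reported below follow from the four cells of the confusion matrix, with "bot" as the positive class. A short worked example in plain Python, using hypothetical counts rather than results from this model:

```python
def binary_metrics(tn, fp, fn, tp):
    """Accuracy, precision, recall and F1 for the positive (bot) class."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return accuracy, precision, recall, f1

# Hypothetical counts for 100 test users (77 humans, 23 bots)
accuracy, precision, recall, f1 = binary_metrics(tn=70, fp=7, fn=8, tp=15)
print(f"accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f} f1={f1:.4f}")
```

In this made-up case accuracy is 0.85 while bot recall is only about 0.65, which shows why accuracy alone can be misleading on an imbalanced dataset.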

In [None]:
# Evaluate on test set
acc, precision, recall, f1, auc, cm, y_true, y_pred = test()

print("Test Set Performance:")
print("=" * 70)
print(f"Accuracy:  {acc:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1-Score:  {f1:.4f}")
print(f"ROC-AUC:   {auc:.4f}")
print("=" * 70)

print("\nConfusion Matrix:")
print(f"              Predicted Human  Predicted Bot")
print(f"Actual Human       {cm[0, 0]:6d}         {cm[0, 1]:6d}")
print(f"Actual Bot         {cm[1, 0]:6d}         {cm[1, 1]:6d}")

# Calculate per-class metrics
print("\nPer-Class Performance:")
tn, fp, fn, tp = cm.ravel()
human_precision = tn / (tn + fn) if (tn + fn) > 0 else 0
human_recall = tn / (tn + fp) if (tn + fp) > 0 else 0
bot_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
bot_recall = tp / (tp + fn) if (tp + fn) > 0 else 0

print(f"Human - Precision: {human_precision:.4f}, Recall: {human_recall:.4f}")
print(f"Bot   - Precision: {bot_precision:.4f}, Recall: {bot_recall:.4f}")

## 10. Visualization

Visualize training progress and model performance.

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
axes[0].plot(train_losses, label='Train Loss', linewidth=2)
axes[0].plot(val_losses, label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy curves
axes[1].plot(train_accs, label='Train Accuracy', linewidth=2)
axes[1].plot(val_accs, label='Val Accuracy', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Plot confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Human', 'Bot'], 
            yticklabels=['Human', 'Bot'])
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

# Plot metrics comparison
metrics_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC-AUC']
metrics_values = [acc, precision, recall, f1, auc]

plt.figure(figsize=(10, 6))
bars = plt.bar(metrics_names, metrics_values, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'])
plt.ylim(0, 1)
plt.ylabel('Score')
plt.title('Model Performance Metrics on Test Set')
plt.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{height:.3f}',
             ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

## Summary

### How GNNs Work for Bot Detection

1. **Graph Construction**: We built a graph where users are nodes and interactions (conversations) are edges
2. **Feature Engineering**: Combined TF-IDF text features (100-dim) with numerical engagement metrics (5-dim)
3. **Message Passing**: Each GraphSAGE layer aggregates information from neighboring nodes
4. **Classification**: After 3 layers of message passing, the model classifies each user as bot or human

### Implementation Details

- **Model**: 3-layer GraphSAGE with 128 hidden dimensions (~93K parameters with 105 input features)
- **Text Handling**: TF-IDF vectorization (max 100 features) for computational efficiency
- **Class Imbalance**: Weighted CrossEntropyLoss with balanced class weights
- **Training**: Adam optimizer with ReduceLROnPlateau scheduler and early stopping
- **Data Splits**: Train (70%), Validation (20%), Test (10%)

### Key Advantages of GNNs

- Learns from both user features AND social network structure
- Captures interaction patterns that distinguish bots from humans
- Can identify coordinated bot behavior through graph connectivity
- Can be more robust than classifiers that use node features alone, since a bot must also imitate realistic interaction patterns to evade detection