In [None]:
import json
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, balanced_accuracy_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from datetime import datetime
import wandb
import pandas as pd

In [None]:
import wandb

# Authenticate with Weights & Biases without embedding the API key in the notebook:
# a key committed here would be exposed to anyone who can read the file.
# With no explicit key, wandb.login() uses the WANDB_API_KEY environment variable or
# credentials stored by a previous `wandb login`, and otherwise prompts for the key.
wandb.login()

# Preprocess data

In [None]:
def _read_jsonl(path):
    """Read a JSON Lines file and return one parsed object per non-blank line.

    The file is decoded as UTF-8 explicitly so the result does not depend on the
    platform's default encoding. Blank lines (e.g. a trailing newline at the end
    of the file) are skipped instead of crashing json.loads.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


# Raw propagation paths: one JSON object per line in each file.
data_fake = _read_jsonl('../data/ordered_fake_propagation_paths.jsonl')
data_real = _read_jsonl('../data/ordered_real_propagation_paths.jsonl')

In [None]:
# Load the precomputed embeddings for the fake and real propagation paths.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python objects,
# which can execute code stored in the file - only load .pt files from a trusted source.
# presumably each file holds one entry per sequence, in the same order as the JSONL rows - TODO confirm
emb_fake = torch.load('../data/fake_news_net/ordered_fake_propagation_paths_emb.pt', weights_only=False)
emb_real = torch.load('../data/fake_news_net/ordered_real_propagation_paths_emb.pt', weights_only=False)

In [None]:
# Concatenate fake first, then real, in the same order for both containers so that
# data[i] and embs[i] refer to the same sequence.
# NOTE(review): `+` is concatenation only if the loaded embeddings are Python lists;
# for tensors it would be element-wise addition - confirm the type stored in the .pt files.
data = data_fake + data_real
embs = emb_fake + emb_real

# Later cells index embs[seq_idx][tweet_idx] in lockstep with data, so fail fast here
# if the two containers do not describe the same number of sequences.
assert len(embs) == len(data), (
    f"embeddings ({len(embs)}) and sequences ({len(data)}) are misaligned"
)

In [None]:
def get_targets(data):
    """Return, for each tweet sequence, the sum of its tweets' `favorite_count`.

    `data` is an iterable of sequences, each an iterable of tweet mappings with
    a 'favorite_count' entry. An empty sequence contributes 0.
    """
    return [
        sum(tweet['favorite_count'] for tweet in sequence)
        for sequence in data
    ]

# Total likes per sequence, turned into a binary label:
# 1 when the total is strictly above the median, 0 otherwise.
# Sequences exactly at the median land in class 0, so the two classes need not be balanced.
likes_per_sequence = get_targets(data)
targets_likes = torch.tensor(likes_per_sequence)
median_likes = targets_likes.median()
categorical_targets = torch.gt(targets_likes, median_likes).int()  # categorical median-based target

In [None]:
# Aliases under the names the following cells expect (same objects, no copy).
X_full, embs_full = data, embs

In [None]:
def parse_twitter_date(date_str):
    """Convert a Twitter timestamp such as 'Thu Nov 17 21:19:15 +0000 2016' to a datetime.

    The '%z' directive keeps the UTC offset, so the result is timezone-aware.
    """
    twitter_format = '%a %b %d %H:%M:%S %z %Y'
    parsed = datetime.strptime(date_str, twitter_format)
    return parsed

def load_tweet_scalar_values(tweet):
    """Return the tweet's scalar features as a list.

    Order: verified, followers_count, following_count, favorite_count.
    Raises KeyError if any of these keys is missing from `tweet`.
    """
    # NOTE(review): 'favorite_count' is also what get_targets sums to build the label,
    # so this feature may leak the target - confirm this is intended.
    scalar_keys = ('verified', 'followers_count', 'following_count', 'favorite_count')
    values = []
    for key in scalar_keys:
        values.append(tweet[key])
    return values

In [None]:
def get_tensors(data, embs):
    """Build one feature vector per tweet: embedding + scalar features + elapsed time.

    For every sequence in `data`, each tweet's vector is the concatenation of
    embs[seq_idx][tweet_idx], the values from load_tweet_scalar_values, and the
    number of seconds between the tweet and the first tweet of its sequence.

    Returns a list (one entry per sequence) of lists of 1-D tensors.
    """
    processed_data = []

    for seq_idx, sequence in enumerate(data):
        # Every time offset in a sequence is measured from its first tweet.
        first_tweet_time = parse_twitter_date(sequence[0]['created_at'])
        sequence_features = []

        for tweet_idx, tweet in enumerate(sequence):
            embedding = embs[seq_idx][tweet_idx]
            scalars = load_tweet_scalar_values(tweet)
            elapsed = parse_twitter_date(tweet['created_at']) - first_tweet_time

            features = torch.cat((
                embedding,
                torch.tensor(scalars, dtype=torch.float32),
                torch.tensor([elapsed.total_seconds()], dtype=torch.float32),
            ))
            sequence_features.append(features)

        processed_data.append(sequence_features)

    return processed_data

In [None]:
# Build the per-tweet feature vectors, then stack each sequence's vectors
# into a single tensor with one row per tweet.
processed_data = get_tensors(X_full, embs_full)
all_data = list(map(torch.stack, processed_data))

In [None]:
# Sanity check on the first sequence before any truncation/padding is applied.
first_sequence_shape = all_data[0].shape
print(f"Data shape before truncation (first tweet sequence): {first_sequence_shape}")
print("Features: 768 BERT embeddings + 5 scalar values (verified, followers, following, favorites, time_diff)")

In [None]:
def truncate_sequences(data, max_length=5):
    """Force every sequence to exactly `max_length` rows.

    Sequences longer than `max_length` keep only their first `max_length` rows;
    shorter ones are padded at the end with zero rows.

    Args:
        data: iterable of tensors whose first axis is the position in the sequence.
        max_length: number of rows every returned tensor will have.

    Returns:
        A list of tensors, each with `max_length` rows and the trailing shape of
        the corresponding input.
    """
    truncated_data = []
    for sequence in data:
        clipped = sequence[:max_length]
        missing = max_length - len(clipped)
        if missing > 0:
            # new_zeros inherits dtype and device from the sequence, so padding never
            # changes the dtype and does not fail for tensors that live off-CPU.
            padding = clipped.new_zeros((missing, *clipped.shape[1:]))
            clipped = torch.cat([clipped, padding])
        truncated_data.append(clipped)
    return truncated_data

In [None]:
# Prepare final dataset with truncated sequences (default max_length of truncate_sequences)
X_full_processed = torch.stack(truncate_sequences(all_data))
y_targets_full = categorical_targets  # renaming for compatibility

# Features and labels are indexed together in every fold; catch misalignment early.
assert X_full_processed.shape[0] == y_targets_full.shape[0], "features and targets are misaligned"

print(f"Final dataset shape: {X_full_processed.shape}")
print(f"Targets shape: {y_targets_full.shape}")
# The labels are the median-based likes classes, not fake/real, so label the counts accordingly.
print(f"Target distribution (likes <= median / likes > median): {torch.bincount(y_targets_full)}")

In [89]:
class TweetSequenceDataset(Dataset):
    """Map-style dataset pairing each tweet-sequence tensor with its label."""

    def __init__(self, X, y):
        """Store features `X` and labels `y`, which must have the same length.

        Raises:
            ValueError: if `X` and `y` do not contain the same number of samples.
        """
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same number of samples, got {len(X)} and {len(y)}"
            )
        self.X = X  # indexed along the first axis, one entry per sample
        self.y = y

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the (features, label) pair at position `idx`."""
        return self.X[idx], self.y[idx]

In [None]:
def plot_confusion_matrix_image(y_true, y_pred, labels):
    """Draw a confusion-matrix heatmap and return the matplotlib figure.

    Args:
        y_true: ground-truth class labels.
        y_pred: predicted class labels.
        labels: class values, in the order used for both the matrix rows/columns
            and the tick labels.

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.
    """
    # Passing `labels` fixes the matrix to len(labels) x len(labels) even when a class
    # is absent from y_true/y_pred, so the tick labels always match the cells.
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fig, ax = plt.subplots(figsize=(8, 6))
    # Cells are integer counts, hence fmt="d"; draw explicitly on `ax` rather than
    # relying on pyplot's current-axes state.
    sns.heatmap(cm, annot=True, fmt="d", cmap="BuGn", xticklabels=labels, yticklabels=labels, cbar=True, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix")
    return fig

In [91]:
def _log_standardize(train_values, val_values):
    """Apply log1p, then standardize both splits with statistics fitted on the training split.

    Both inputs are tensors whose last axis indexes the feature columns; all leading
    axes are flattened into rows for the scaler and restored afterwards.
    Returns the transformed (train, val) pair as float32 tensors.
    """
    train_log = torch.log1p(train_values)
    val_log = torch.log1p(val_values)
    n_columns = train_log.shape[-1]

    scaler = StandardScaler()
    # Fit on training rows only so that no validation statistics leak into the transform.
    train_scaled = scaler.fit_transform(train_log.reshape(-1, n_columns).detach().cpu().numpy())
    val_scaled = scaler.transform(val_log.reshape(-1, n_columns).detach().cpu().numpy())

    return (
        torch.tensor(train_scaled, dtype=torch.float32).reshape(train_log.shape),
        torch.tensor(val_scaled, dtype=torch.float32).reshape(val_log.shape),
    )


def preprocess_features_for_fold(X_full, train_indices, val_indices):
    """
    Preprocess features for a specific fold, fitting scalers on training data only.
    This prevents data leakage during cross-validation.

    The last five features of every timestep are expected to be
    [verified, followers_count, following_count, favorite_count, time_diff_seconds]:
      * verified (index -5) is left unchanged (binary feature),
      * the three count features (indices -4..-2) get log1p + standard scaling,
      * the time difference (index -1) gets log1p(|t|) + its own standard scaling.
    Everything before those five columns is returned untouched.

    Args:
        X_full: tensor indexed as (sample, timestep, feature).
        train_indices: indices of the training samples for this fold.
        val_indices: indices of the validation samples for this fold.

    Returns:
        (X_train_fold, X_val_fold): preprocessed copies; X_full is not modified.
    """
    # Work on copies so the shared full dataset stays untouched between folds.
    X_train_fold = X_full[train_indices].clone()
    X_val_fold = X_full[val_indices].clone()

    count_cols = slice(-4, -1)   # followers_count, following_count, favorite_count
    time_col = slice(-1, None)   # time difference in seconds, kept as a trailing axis of size 1

    # 1. Count features: log1p + standard scaling (fit on train, applied to both).
    counts_train, counts_val = _log_standardize(
        X_train_fold[:, :, count_cols], X_val_fold[:, :, count_cols]
    )

    # 2. Time difference: log1p of the absolute value + separate scaling.
    # Keeping the size-1 feature axis avoids a squeeze(), which would also drop a
    # sample or timestep axis of size 1 and break the assignment below.
    time_train, time_val = _log_standardize(
        torch.abs(X_train_fold[:, :, time_col]), torch.abs(X_val_fold[:, :, time_col])
    )

    # 3. Write the transformed columns back (the verified column stays as it is).
    X_train_fold[:, :, count_cols] = counts_train
    X_train_fold[:, :, time_col] = time_train

    X_val_fold[:, :, count_cols] = counts_val
    X_val_fold[:, :, time_col] = time_val

    return X_train_fold, X_val_fold

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [94]:
# Cross-validation overview. Feature scaling is deliberately deferred to each fold
# so that scaler statistics are never computed on validation samples.
for message in (
    "Setting up k-fold cross-validation...",
    f"Dataset shape: {X_full_processed.shape}",
    f"Target distribution: {torch.bincount(y_targets_full)}",
    "Preprocessing will be performed within each fold to prevent data leakage.",
):
    print(message)

Setting up k-fold cross-validation...
Dataset shape: torch.Size([578, 40, 773])
Target distribution: tensor([404, 174])
Preprocessing will be performed within each fold to prevent data leakage.


In [None]:
class LSTMClassifier(nn.Module):
    """Bidirectional LSTM binary classifier over sequences of tweet features.

    Each timestep is a vector whose first `bert_dim` entries are a text embedding and
    whose last `numerical_features` entries are scalar features. The scalar features
    are projected to `numerical_embedding_dim`, concatenated with the text embedding,
    and fed to the LSTM. The final hidden states of both directions are passed to an
    MLP head ending in a sigmoid, so the output is a probability (suitable for BCELoss).

    Args:
        bert_dim: size of the text-embedding part of each timestep.
        numerical_features: number of trailing scalar features per timestep.
        numerical_embedding_dim: output size of the scalar-feature projection.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout_p: dropout in the head, and between LSTM layers when num_layers > 1.
    """

    def __init__(self, bert_dim=768, numerical_features=5, numerical_embedding_dim=32,
                 hidden_dim=128, num_layers=1, dropout_p=0.1):
        super().__init__()

        self.numerical_features = numerical_features

        # Embedding layer for numerical features (its dropout rate is fixed at 0.1)
        self.numerical_embedding = nn.Sequential(
            nn.Linear(numerical_features, numerical_embedding_dim),
            nn.GELU(),
            nn.Dropout(0.1)
        )

        # Total input dimension after concatenating BERT embeddings with numerical embeddings
        total_input_dim = bert_dim + numerical_embedding_dim

        # nn.LSTM only applies dropout between stacked layers, so pass 0 for a single layer.
        self.lstm = nn.LSTM(total_input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True, dropout=dropout_p if num_layers > 1 else 0)

        self.head = nn.Sequential(
            nn.Linear(hidden_dim * 2, 32),
            nn.GELU(),
            nn.Dropout(dropout_p),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one probability per sequence.

        Args:
            x: tensor of shape (B, T, bert_dim + numerical_features).

        Returns:
            Tensor of shape (B,) with values in [0, 1].

        Raises:
            ValueError: if `x` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(f"expected input of shape (B, T, F), got {tuple(x.shape)}")

        # Split BERT embeddings and numerical features
        bert_embeddings = x[:, :, :-self.numerical_features]      # (B, T, bert_dim)
        numerical_features = x[:, :, -self.numerical_features:]   # (B, T, numerical_features)

        # Transform numerical features to embeddings
        numerical_emb = self.numerical_embedding(numerical_features)  # (B, T, numerical_embedding_dim)

        # Concatenate BERT embeddings with numerical embeddings
        combined_input = torch.cat([bert_embeddings, numerical_emb], dim=-1)

        # h_n: (num_layers * 2, B, hidden_dim); its last two entries are the final
        # forward and backward states of the top layer.
        _, (h_n, _) = self.lstm(combined_input)

        # Use the final state of EACH direction. Taking output[:, -1, :] instead would
        # pair the forward summary with a backward state that has only processed the
        # last timestep, discarding almost all of the backward pass.
        sequence_summary = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # (B, 2*hidden_dim)

        return self.head(sequence_summary).squeeze(-1)  # (B,)

In [96]:
def test_model_wandb(model, test_loader, device):
    """Evaluate `model` on `test_loader`, log the metrics to wandb and return them.

    Args:
        model: module returning one probability per sample (as required by BCELoss).
        test_loader: iterable of (X_batch, y_batch) pairs.
        device: device the batches are moved to before the forward pass.

    Returns:
        dict with keys test_loss, test_accuracy, test_balanced_accuracy, test_f1,
        test_precision, test_recall, test_roc_auc and test_confusion_matrix
        (a wandb.Image). Predictions are thresholded at 0.5.
    """
    model.eval()
    # Sum the per-sample losses and divide by the sample count at the end: averaging
    # the batch means would over-weight a smaller final batch.
    criterion = nn.BCELoss(reduction="sum")
    total_loss = 0.0
    all_probs, all_targets = [], []

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device).float()
            y_pred = model(X_batch)
            total_loss += criterion(y_pred, y_batch).item()

            all_probs.append(y_pred.detach().cpu())
            all_targets.append(y_batch.detach().cpu())

    all_probs = torch.cat(all_probs)
    all_targets = torch.cat(all_targets)
    all_preds = (all_probs > 0.5).long()

    # scikit-learn works on NumPy arrays; convert once instead of relying on implicit conversion.
    y_true = all_targets.numpy()
    y_prob = all_probs.numpy()
    y_hat = all_preds.numpy()

    # Create confusion matrix
    fig = plot_confusion_matrix_image(y_true, y_hat, [0, 1])

    metrics = {
        "test_loss": total_loss / len(y_true),
        "test_accuracy": accuracy_score(y_true, y_hat),
        "test_balanced_accuracy": balanced_accuracy_score(y_true, y_hat),
        # zero_division=0 keeps the 0.0 result when no positives are predicted,
        # without emitting a warning on every evaluation.
        "test_f1": f1_score(y_true, y_hat, zero_division=0),
        "test_precision": precision_score(y_true, y_hat, zero_division=0),
        "test_recall": recall_score(y_true, y_hat, zero_division=0),
        "test_roc_auc": roc_auc_score(y_true, y_prob),
        "test_confusion_matrix": wandb.Image(fig),
    }

    wandb.log(metrics)
    plt.close(fig)  # release the figure; this function runs once per epoch
    return metrics

In [None]:
def train_single_fold(config, fold, train_indices, val_indices, X_full, y_full, project_name):
    """Train a single fold with proper preprocessing and wandb logging.

    Args:
        config: dict with keys 'batch_size', 'numerical_embedding_dim', 'hidden_dim',
            'num_layers', 'learning_rate', 'weight_decay' and 'num_epochs'.
        fold: fold number, used for the wandb run name and in the returned metrics.
        train_indices: indices into X_full / y_full used for training.
        val_indices: indices into X_full / y_full used for validation.
        X_full: feature tensor; its last axis is 768 embedding dims + numerical features.
        y_full: binary targets aligned with X_full.
        project_name: wandb project the run is logged to.

    Returns:
        The validation metrics (as returned by test_model_wandb) of the epoch with
        the highest validation F1, plus the keys 'best_epoch' (0-based) and 'fold'.

    Raises:
        ValueError: if config['num_epochs'] is smaller than 1.
    """
    from transformers import get_linear_schedule_with_warmup

    # Without at least one epoch there is no validation result to return.
    if config['num_epochs'] < 1:
        raise ValueError("config['num_epochs'] must be at least 1")

    # Initialize wandb for this fold
    run_name = f"fold_{fold}"
    with wandb.init(config=config, project=project_name, name=run_name, save_code=True):

        # Preprocess features for this fold (prevents data leakage)
        X_train_fold, X_val_fold = preprocess_features_for_fold(X_full, train_indices, val_indices)
        y_train_fold = y_full[train_indices]
        y_val_fold = y_full[val_indices]

        # Create datasets for this fold
        train_dataset = TweetSequenceDataset(X_train_fold.float(), y_train_fold.float())
        val_dataset = TweetSequenceDataset(X_val_fold.float(), y_val_fold.float())

        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

        # Log fold information
        wandb.log({"fold": fold})

        # Initialize model and optimizer; every feature after the 768 embedding
        # dimensions is treated as a numerical feature.
        numerical_features = X_full.shape[-1] - 768
        model = LSTMClassifier(
            bert_dim=768,
            numerical_features=numerical_features,
            numerical_embedding_dim=config['numerical_embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_layers=config['num_layers']
        ).to(device)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        criterion = nn.BCELoss()

        # Scheduler: linear decay to zero over all optimizer steps, no warmup
        total_steps = len(train_loader) * config['num_epochs']
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

        print(f"Training fold {fold}...")

        # Initialize best metrics tracking
        best_metrics = None
        best_f1 = 0.0
        best_epoch = 0

        for epoch in range(config['num_epochs']):
            model.train()  # test_model_wandb switches to eval mode at the end of every epoch
            total_loss = 0.0
            all_probs, all_targets = [], []

            for X_batch, y_batch in tqdm(train_loader, desc=f"Fold {fold}, Epoch {epoch+1}"):
                optimizer.zero_grad()
                X_batch, y_batch = X_batch.to(device), y_batch.to(device).float()
                y_pred = model(X_batch)
                loss = criterion(y_pred, y_batch)
                # Weight the batch mean by the batch size so the reported epoch loss is a
                # true per-sample average even when the last batch is smaller.
                total_loss += loss.item() * y_batch.size(0)
                loss.backward()
                optimizer.step()
                scheduler.step()

                all_probs.append(y_pred.detach().cpu())
                all_targets.append(y_batch.detach().cpu())

            all_probs = torch.cat(all_probs)
            all_targets = torch.cat(all_targets)
            all_preds = (all_probs > 0.5).long()

            # Create training confusion matrix
            fig = plot_confusion_matrix_image(all_targets.numpy(), all_preds.numpy(), [0, 1])

            train_metrics = {
                "epoch": epoch,
                "train_loss": total_loss / all_targets.numel(),
                "train_accuracy": accuracy_score(all_targets, all_preds),
                "train_balanced_accuracy": balanced_accuracy_score(all_targets, all_preds),
                "train_f1": f1_score(all_targets, all_preds),
                "train_precision": precision_score(all_targets, all_preds),
                "train_recall": recall_score(all_targets, all_preds),
                "train_roc_auc": roc_auc_score(all_targets, all_probs),
                "train_confusion_matrix": wandb.Image(fig),
                "learning_rate": optimizer.param_groups[0]['lr']
            }

            wandb.log(train_metrics)
            plt.close(fig)

            # Validation
            val_metrics = test_model_wandb(model, val_loader, device)

            # Check if this is the best epoch based on validation F1 score.
            # The first epoch is always accepted: otherwise a fold whose F1 never rises
            # above 0.0 would leave best_metrics as None and crash after the loop.
            current_f1 = val_metrics["test_f1"]
            if best_metrics is None or current_f1 > best_f1:
                best_f1 = current_f1
                best_epoch = epoch
                best_metrics = val_metrics.copy()
                best_metrics["best_epoch"] = best_epoch

                # Log best metrics with a special prefix
                best_log_metrics = {f"best_{key}": value for key, value in best_metrics.items()
                                   if key != "test_confusion_matrix"}  # Exclude confusion matrix from best logging
                wandb.log(best_log_metrics)

                print(f"New best validation f1: {best_f1:.4f} at epoch {epoch + 1}")

        # Log final best metrics summary
        best_metrics["fold"] = fold
        print(f"Fold {fold} completed. Best f1: {best_f1:.4f} at epoch {best_epoch + 1}")

        return best_metrics

In [None]:
def run_cross_validation(config=None, n_splits=10, test_size=0.2, random_state=42):
    """Run repeated stratified shuffle-split validation and summarize the best metrics.

    This is not classic k-fold: StratifiedShuffleSplit draws `n_splits` independent
    random train/validation splits, so validation sets may overlap between splits.
    Uses the notebook-level X_full_processed and y_targets_full.

    Args:
        config: hyperparameter dict passed to train_single_fold; a default
            configuration is used when None.
        n_splits: number of random splits ("folds") to train.
        test_size: fraction of samples held out for validation in each split.
        random_state: seed for the splitter, for reproducible splits.

    Returns:
        List with one best-metrics dict per split, as returned by train_single_fold.
    """

    # Configuration for the experiment
    if config is None:
        config = {
            'learning_rate': 0.0001,
            'weight_decay': 0.01,
            'batch_size': 16,
            'numerical_embedding_dim': 32,
            'hidden_dim': 128,
            'num_layers': 1,
            'num_epochs': 50
        }

    project_name = 'PROJECT NAME HERE'

    skf = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=random_state)

    # The splitter only needs the number of samples and the labels, so give it plain
    # NumPy inputs instead of the 3-D feature tensor.
    split_labels = np.asarray(y_targets_full)
    split_placeholder = np.zeros(len(split_labels))

    all_best_metrics = []

    for fold, (train_indices, val_indices) in enumerate(skf.split(split_placeholder, split_labels)):
        print(f"\n=== FOLD {fold + 1}/{n_splits} ===")

        best_fold_metrics = train_single_fold(
            config, fold + 1, train_indices, val_indices,
            X_full_processed, y_targets_full, project_name
        )

        all_best_metrics.append(best_fold_metrics)

    # Calculate and print summary statistics using BEST metrics from each fold.
    # NOTE(review): the best epoch is selected on the same validation split whose metrics
    # are reported here, so these averages are optimistic estimates.
    print("\n=== CROSS-VALIDATION SUMMARY (BEST METRICS) ===")
    metric_names = ['test_accuracy', 'test_balanced_accuracy', 'test_f1', 'test_precision', 'test_recall', 'test_roc_auc']

    for metric_name in metric_names:
        values = [fold_metrics[metric_name] for fold_metrics in all_best_metrics]
        mean_val = np.mean(values)
        std_val = np.std(values)
        print(f"Best {metric_name}: {mean_val:.4f} ± {std_val:.4f}")

    # Also show which epoch was best for each fold
    print("\nBest epochs for each fold:")
    for i, fold_metrics in enumerate(all_best_metrics):
        print(f"Fold {i+1}: Epoch {fold_metrics['best_epoch'] + 1} (F1: {fold_metrics['test_f1']:.4f}, Balanced Accuracy: {fold_metrics['test_balanced_accuracy']:.4f})")

    return all_best_metrics

In [None]:
# Hyperparameters for this cross-validation experiment
config = dict(
    learning_rate=0.00008,
    weight_decay=0.01,
    batch_size=16,
    numerical_embedding_dim=32,
    hidden_dim=128,
    num_layers=1,
    num_epochs=100,
)

print("Starting 10 shuffle cross-validation for LSTM model...")
print(f"Configuration: {config}")

# Each fold is trained and logged as its own wandb run
all_fold_results = run_cross_validation(config)