In [1]:
!pip install transformers
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv
from torch_geometric.data import Data
from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, RobertaTokenizer, RobertaModel
from sklearn.metrics import accuracy_score, f1_score, classification_report, precision_recall_curve, roc_auc_score
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.ensemble import VotingClassifier, RandomForestClassifier, GradientBoostingClassifier
from imblearn.over_sampling import SMOTE
from tqdm import tqdm
import random
import os
import re
import nltk
from nltk.corpus import stopwords
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import seaborn as sns

# Download NLTK resources
nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True)

# Set seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Enhanced constants
MAX_LEN = 160  # Increased for more context
BATCH_SIZE = 8  # Reduced for better gradient updates
EPOCHS = 10  # Increased training time
LEARNING_RATE = 2e-5  # Fine-tuned learning rate
NUM_PREVIOUS_MESSAGES = 3  # Increased context window
PLAYER_EMBEDDING_DIM = 64  # Increased complexity
METADATA_DIM = 32  # Increased metadata representation
DROPOUT_RATE = 0.4  # Tuned dropout

# Improved text preprocessing
def preprocess_text(text):
    """Enhanced text preprocessing with better feature extraction"""
    if not text:
        return ""

    # Convert to lowercase
    text = text.lower()

    # Replace multiple spaces with single space
    text = re.sub(r'\s+', ' ', text)

    # Standardize punctuation spacing
    text = re.sub(r'([.,!?;:])', r' \1 ', text)

    # Standardize quotes
    text = re.sub(r'["""]', '"', text)
    text = re.sub(r"[‘’´`']", "'", text)


    # Replace repeated characters (like "sooooo" -> "soo")
    text = re.sub(r'(.)\1{2,}', r'\1\1', text)

    # Final cleaning
    text = re.sub(r'\s+', ' ', text).strip()

    return text

# Enhanced linguistic features extraction
def extract_linguistic_features(text):
    """Extract advanced linguistic features that may indicate deception"""
    if not text:
        return np.zeros(10)

    # Tokenize text
    tokens = text.lower().split()

    # Get stopwords
    stop_words = set(stopwords.words('english'))

    features = [
        # Message length (normalized)
        min(1.0, len(tokens) / 100.0),

        # Question marks (may indicate seeking information or creating doubt)
        min(1.0, text.count('?') / 5.0),

        # Exclamation marks (may indicate enthusiasm or overcompensation)
        min(1.0, text.count('!') / 5.0),

        # First-person pronouns (deceptive messages often use fewer)
        len(re.findall(r'\b(i|me|my|mine|myself|we|us|our|ours|ourselves)\b', text)) / max(1, len(tokens)),

        # Third-person pronouns (shifting focus away from self)
        len(re.findall(r'\b(he|she|they|them|their|his|her|theirs|himself|herself|themselves)\b', text)) / max(1, len(tokens)),

        # Tentative language (maybe, perhaps, guess, think)
        len(re.findall(r'\b(maybe|perhaps|possibly|guess|think|probably|about|approximately|around|might|could|may)\b', text)) / max(1, len(tokens)),

        # Certainty language (definitely, certainly, always, never)
        len(re.findall(r'\b(definitely|certainly|always|never|absolutely|completely|totally|surely|undoubtedly)\b', text)) / max(1, len(tokens)),

        # Non-specific language (things, stuff, something)
        len(re.findall(r'\b(thing|things|stuff|something|anything|everything|nothing|somewhere|anywhere|everywhere|nowhere)\b', text)) / max(1, len(tokens)),

        # Ratio of non-stopwords to total words (information density)
        sum(1 for word in tokens if word not in stop_words) / max(1, len(tokens)),

        # Average word length (can indicate complexity/sophistication)
        sum(len(word) for word in tokens) / max(1, len(tokens))
    ]

    return np.array(features)

# Extract message sentiment features
def extract_sentiment_features(text):
    """Extract simple sentiment features"""
    if not text:
        return np.zeros(2)

    positive_words = set([
        'good', 'great', 'excellent', 'amazing', 'wonderful', 'fantastic', 'terrific',
        'best', 'better', 'nice', 'happy', 'glad', 'pleased', 'delighted', 'love',
        'like', 'enjoy', 'trust', 'honest', 'true', 'truth', 'agree', 'yes', 'perfect'
    ])

    negative_words = set([
        'bad', 'terrible', 'awful', 'horrible', 'poor', 'wrong', 'worst',
        'worse', 'sad', 'unhappy', 'angry', 'upset', 'hate', 'dislike',
        'disagree', 'no', 'never', 'not', 'lie', 'lying', 'lied', 'fake', 'false'
    ])

    tokens = text.lower().split()
    total_words = max(1, len(tokens))

    positive_ratio = sum(1 for word in tokens if word in positive_words) / total_words
    negative_ratio = sum(1 for word in tokens if word in negative_words) / total_words

    return np.array([positive_ratio, negative_ratio])

# Extract game state features
def extract_game_state_features(item):
    """Extract features related to the game state"""
    year = int(item['year'])
    season = item['season']

    # Game phase
    early_game = 1.0 if year < 1905 else 0.0
    mid_game = 1.0 if 1905 <= year < 1910 else 0.0
    late_game = 1.0 if year >= 1910 else 0.0

    # Season binary
    is_spring = 1.0 if season == 'Spring' else 0.0
    is_fall = 1.0 if season == 'Fall' else 0.0

    # Communication intensity
    rel_msg_idx = item['relative_message_index'] / 20.0  # Normalize
    abs_msg_idx = item['absolute_message_index'] / 1000.0  # Normalize

    # Score deltas (normalized)
    score_delta = min(1.0, max(-1.0, int(item['score_delta']) / 5.0))
    game_score = min(1.0, max(-1.0, int(item['game_score']) / 10.0))

    return np.array([
        early_game, mid_game, late_game,
        is_spring, is_fall,
        rel_msg_idx, abs_msg_idx,
        score_delta, game_score
    ])

class ImprovedDiplomacyDataset(Dataset):
    def __init__(self, data, tokenizer, max_len, player_to_idx, game_graphs=None,
                 tfidf_vectorizer=None, bow_vectorizer=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.player_to_idx = player_to_idx
        self.game_graphs = game_graphs
        self.tfidf_vectorizer = tfidf_vectorizer
        self.bow_vectorizer = bow_vectorizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Get current message and context
        current_message = item['current_message']
        context_messages = item['context_messages']

        # Preprocess text
        current_message = preprocess_text(current_message)
        context_messages = [preprocess_text(msg) for msg in context_messages]

        # Combine context and current message with special tokens
        full_text = "[SEP]".join(context_messages + [current_message])

        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            full_text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # Get sender and receiver indices
        sender_idx = self.player_to_idx.get(item['sender'], 0)
        receiver_idx = self.player_to_idx.get(item['receiver'], 0)

        # Get game ID
        game_id = item['game_id']

        # Extract all features
        ling_features = extract_linguistic_features(current_message)
        sent_features = extract_sentiment_features(current_message)
        game_features = extract_game_state_features(item)

        # Prepare TF-IDF and BOW features if vectorizers are provided
        tfidf_features = None
        if self.tfidf_vectorizer:
            tfidf_sparse = self.tfidf_vectorizer.transform([current_message])
            tfidf_features = torch.tensor(tfidf_sparse.toarray()[0], dtype=torch.float32)

        bow_features = None
        if self.bow_vectorizer:
            bow_sparse = self.bow_vectorizer.transform([current_message])
            bow_features = torch.tensor(bow_sparse.toarray()[0], dtype=torch.float32)

        # Comprehensive metadata features
        metadata = torch.tensor(np.concatenate([
            ling_features,
            sent_features,
            game_features
        ]), dtype=torch.float32)

        # Label
        label = 1 if item['sender_label'] else 0

        result = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'sender_idx': torch.tensor(sender_idx, dtype=torch.long),
            'receiver_idx': torch.tensor(receiver_idx, dtype=torch.long),
            'metadata': metadata,
            'game_id': game_id,
            'label': torch.tensor(label, dtype=torch.long),
            'text': current_message  # Store raw text for traditional models
        }

        # Add TF-IDF and BOW features if available
        if tfidf_features is not None:
            result['tfidf_features'] = tfidf_features

        if bow_features is not None:
            result['bow_features'] = bow_features

        return result

def build_enhanced_player_graphs(data_list):
    """Build enhanced player interaction graphs with more features"""
    game_graphs = {}

    # Group by game_id
    game_data = defaultdict(list)
    for item in data_list:
        game_id = item['game_id']
        game_data[game_id].append(item)

    for game_id, game_items in game_data.items():
        # Get all players in this game
        players = list(set([item['sender'] for item in game_items] +
                          [item['receiver'] for item in game_items]))
        player_to_idx = {player: idx for idx, player in enumerate(players)}

        # Initialize player statistics
        player_stats = {player: {
            'message_count': 0,
            'deceptive_count': 0,
            'countries_contacted': set(),
            'avg_message_len': [],
        } for player in players}

        # Calculate player statistics
        for item in game_items:
            sender = item['sender']
            receiver = item['receiver']
            is_deceptive = item['sender_label']

            player_stats[sender]['message_count'] += 1
            if is_deceptive:
                player_stats[sender]['deceptive_count'] += 1
            player_stats[sender]['countries_contacted'].add(receiver)
            player_stats[sender]['avg_message_len'].append(len(item['current_message'].split()))

        # Create node features (richer player representations)
        num_players = len(players)
        node_features = []

        for player in players:
            stats = player_stats[player]
            message_count = stats['message_count']

            # Compute statistics
            deception_ratio = stats['deceptive_count'] / max(1, message_count)
            contact_diversity = len(stats['countries_contacted']) / max(1, len(players) - 1)
            avg_msg_len = sum(stats['avg_message_len']) / max(1, len(stats['avg_message_len']))

            # Normalize message length
            norm_msg_len = min(1.0, avg_msg_len / 100.0)

            # Create feature vector for this player
            player_vector = [
                deception_ratio,
                contact_diversity,
                norm_msg_len,
                message_count / 100.0  # Normalize message count
            ]

            node_features.append(player_vector)

        # Convert to tensor
        x = torch.tensor(node_features, dtype=torch.float)

        # Create edges (sender -> receiver) with enhanced edge features
        edge_index = []
        edge_attr = []  # Enhanced edge features

        for item in game_items:
            sender_idx = player_to_idx[item['sender']]
            receiver_idx = player_to_idx[item['receiver']]

            # Edge features
            is_deceptive = 1.0 if item['sender_label'] else 0.0
            message_length = min(1.0, len(item['current_message'].split()) / 100.0)
            year_normalized = (int(item['year']) - 1900) / 30.0

            # Game state features
            is_spring = 1.0 if item['season'] == 'Spring' else 0.0
            score_delta = min(1.0, max(-1.0, int(item['score_delta']) / 5.0))

            edge_features = [
                is_deceptive,
                message_length,
                year_normalized,
                is_spring,
                score_delta
            ]

            edge_index.append([sender_idx, receiver_idx])
            edge_attr.append(edge_features)

        if edge_index:  # Check if there are any edges
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_attr, dtype=torch.float)

            # Create PyTorch Geometric Data object
            graph_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
            game_graphs[game_id] = graph_data

    return game_graphs

def prepare_enhanced_data(jsonl_file):
    """Prepare enhanced data with more context and features"""
    processed_data = []

    with open(jsonl_file, 'r') as f:
        for line in f:
            game_data = json.loads(line)
            messages = game_data['messages']
            sender_labels = game_data['sender_labels']
            speakers = game_data['speakers']
            receivers = game_data['receivers']
            rel_indices = game_data['relative_message_index']
            abs_indices = game_data['absolute_message_index']
            seasons = game_data['seasons']
            years = game_data['years']
            game_scores = game_data['game_score']
            score_deltas = game_data['game_score_delta']
            game_id = game_data['game_id']

            # Parse messages by player for better context
            player_messages = defaultdict(list)
            for i in range(len(messages)):
                if i < len(speakers) and i < len(messages):
                    player = speakers[i]
                    player_messages[player].append(i)

            for i in range(len(messages)):
                # Skip messages without labels
                if i >= len(sender_labels) or sender_labels[i] == "NOANNOTATION":
                    continue

                # Get context (previous N messages overall)
                context_start = max(0, i - NUM_PREVIOUS_MESSAGES)
                general_context = messages[context_start:i]

                # Also get the last message from this player if it exists
                sender = speakers[i]
                player_msg_indices = [idx for idx in player_messages[sender] if idx < i]

                player_context = []
                if player_msg_indices:
                    # Get the most recent message from this player
                    last_player_msg_idx = max(player_msg_indices)
                    player_context = [messages[last_player_msg_idx]]

                # Combine contexts with priority to player's own previous messages
                context_messages = player_context + general_context
                context_messages = context_messages[-NUM_PREVIOUS_MESSAGES:] if context_messages else []

                # If there are fewer than N previous messages, pad with empty strings
                while len(context_messages) < NUM_PREVIOUS_MESSAGES:
                    context_messages.insert(0, "")

                item = {
                    'current_message': messages[i],
                    'context_messages': context_messages,
                    'sender': speakers[i],
                    'receiver': receivers[i],
                    'sender_label': sender_labels[i],
                    'relative_message_index': rel_indices[i],
                    'absolute_message_index': abs_indices[i],
                    'season': seasons[i],
                    'year': years[i],
                    'game_score': game_scores[i],
                    'score_delta': score_deltas[i],
                    'game_id': game_id
                }

                processed_data.append(item)

    return processed_data

def augment_enhanced_data(data_list):
    """Enhanced data augmentation with more sophisticated techniques"""
    augmented_data = []

    # Focus on augmenting the minority class (usually truthful messages)
    truthful_data = [item for item in data_list if not item['sender_label']]
    deceptive_data = [item for item in data_list if item['sender_label']]

    truthful_count = len(truthful_data)
    deceptive_count = len(deceptive_data)
    print(f"Original data - Truthful: {truthful_count}, Deceptive: {deceptive_count}")

    # Determine which class is minority
    minority_class = truthful_data if truthful_count < deceptive_count else deceptive_data
    majority_class = deceptive_data if truthful_count < deceptive_count else truthful_data

    minority_label = False if truthful_count < deceptive_count else True

    print(f"Minority class: {'Truthful' if truthful_count < deceptive_count else 'Deceptive'}")

    # Calculate augmentation factor to balance classes
    target_count = len(majority_class)
    needed_augmentations = max(0, target_count - len(minority_class))
    augmentation_factor = max(1, needed_augmentations // len(minority_class) + 1)

    print(f"Creating {augmentation_factor} augmentations per minority class message")

    # Advanced augmentation techniques
    def swap_words(text, swap_ratio=0.15):
        words = text.split()
        if len(words) <= 3:
            return text
        num_swaps = max(1, int(swap_ratio * len(words)))
        for _ in range(num_swaps):
            if len(words) >= 2:
                idx1, idx2 = random.sample(range(len(words)), 2)
                words[idx1], words[idx2] = words[idx2], words[idx1]
        return ' '.join(words)

    def synonym_replacement(text):
        # Expanded synonym dictionary
        synonyms = {
            'good': ['great', 'excellent', 'fine', 'positive', 'beneficial'],
            'bad': ['poor', 'terrible', 'negative', 'awful', 'undesirable'],
            'attack': ['assault', 'strike', 'hit', 'offensive', 'charge'],
            'defend': ['protect', 'guard', 'shield', 'cover', 'safeguard'],
            'ally': ['partner', 'friend', 'supporter', 'confederate', 'associate'],
            'enemy': ['opponent', 'adversary', 'foe', 'rival', 'antagonist'],
            'think': ['believe', 'consider', 'feel', 'reckon', 'suppose'],
            'move': ['advance', 'proceed', 'shift', 'progress', 'transfer'],
            'help': ['assist', 'aid', 'support', 'back', 'facilitate'],
            'plan': ['strategy', 'scheme', 'idea', 'approach', 'method'],
            'want': ['need', 'desire', 'wish', 'hope', 'prefer'],
            'take': ['grab', 'seize', 'capture', 'acquire', 'secure'],
            'give': ['provide', 'offer', 'supply', 'deliver', 'furnish'],
            'make': ['create', 'produce', 'form', 'construct', 'develop'],
            'see': ['observe', 'notice', 'spot', 'witness', 'perceive'],
            'say': ['tell', 'mention', 'state', 'declare', 'express'],
            'know': ['understand', 'recognize', 'realize', 'comprehend', 'grasp']
        }

        words = text.split()
        for i, word in enumerate(words):
            word_lower = word.lower()
            if word_lower in synonyms and random.random() < 0.3:
                words[i] = random.choice(synonyms[word_lower])

        return ' '.join(words)

    def random_insertion(text):
        """Insert random common words in the text"""
        common_inserts = [
            'I think', 'perhaps', 'maybe', 'actually', 'honestly',
            'of course', 'clearly', 'anyway', 'definitely', 'absolutely',
            'in fact', 'obviously', 'frankly', 'truthfully'
        ]

        words = text.split()
        if len(words) <= 2:
            return text

        insert_pos = random.randint(0, len(words) - 1)
        insert_word = random.choice(common_inserts)
        words.insert(insert_pos, insert_word)

        return ' '.join(words)

    def random_deletion(text, p=0.1):
        """Randomly delete words from the text"""
        words = text.split()
        if len(words) <= 3:
            return text

        kept_words = []
        for word in words:
            if random.random() > p:
                kept_words.append(word)

        if not kept_words:  # Make sure we keep at least one word
            kept_words = [random.choice(words)]

        return ' '.join(kept_words)

    def add_noise(text, noise_level=0.1):
        words = text.split()
        if len(words) <= 3:
            return text

        # Randomly skip some words (mimics typos/omissions)
        if random.random() < noise_level:
            skip_idx = random.randint(0, len(words)-1)
            words.pop(skip_idx)

        # Randomly add fillers
        fillers = ['um', 'well', 'like', 'you know', 'I mean', 'so', 'anyway', 'basically']
        if random.random() < noise_level and len(words) > 3:
            insert_idx = random.randint(0, len(words)-1)
            words.insert(insert_idx, random.choice(fillers))

        return ' '.join(words)

    # Apply augmentation to minority class messages
    for item in minority_class:
        current_message = item['current_message']

        # Skip very short messages
        if not current_message or len(current_message.split()) < 3:
            continue

        for i in range(augmentation_factor):
            augmented_item = item.copy()

            # Apply different augmentation strategies
            if i % 5 == 0:
                augmented_text = swap_words(current_message)
            elif i % 5 == 1:
                augmented_text = synonym_replacement(current_message)
            elif i % 5 == 2:
                augmented_text = random_insertion(current_message)
            elif i % 5 == 3:
                augmented_text = random_deletion(current_message)
            else:
                augmented_text = add_noise(current_message)

            augmented_item['current_message'] = augmented_text
            augmented_data.append(augmented_item)

    # Combine original and augmented data
    combined_data = data_list + augmented_data

    # Count classes after augmentation
    truthful_count_after = sum(1 for item in combined_data if not item['sender_label'])
    deceptive_count_after = sum(1 for item in combined_data if item['sender_label'])
    print(f"After augmentation - Truthful: {truthful_count_after}, Deceptive: {deceptive_count_after}")

    return combined_data

class EnhancedGNNModule(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(EnhancedGNNModule, self).__init__()
        # GCN layer for initial feature processing
        self.conv1 = GCNConv(input_dim, hidden_dim)

        # GAT layers for attention-based message passing
        self.gat1 = GATConv(hidden_dim, hidden_dim // 2, heads=2, dropout=DROPOUT_RATE)
        self.gat2 = GATv2Conv(hidden_dim, output_dim, heads=1, dropout=DROPOUT_RATE)

        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x, edge_index, edge_attr=None):
        # Initial convolution
        x = F.relu(self.conv1(x, edge_index))
        x = self.layer_norm1(x)
        x = self.dropout(x)

        # First GAT layer with multi-head attention
        x = F.elu(self.gat1(x, edge_index))
        x = self.layer_norm2(x)
        x = self.dropout(x)

        # Second GAT layer
        x = self.gat2(x, edge_index)

        return x

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

class EnhancedDeceptionModel(nn.Module):
    def __init__(self, num_players, player_embedding_dim, metadata_dim, tfidf_dim=None, bow_dim=None):
        super(EnhancedDeceptionModel, self).__init__()

        # Use a stronger transformer model
        self.bert = RobertaModel.from_pretrained("roberta-base")
        self.bert_dim = self.bert.config.hidden_size  # 768 for roberta-base

        # Freeze part of the transformer to prevent overfitting
        # Only fine-tune the top 4 layers
        for param in list(self.bert.parameters())[:-4 * len(list(self.bert.encoder.layer[-1].parameters()))]:
            param.requires_grad = False

        # Enhanced GNN for player relationship modeling
        self.gnn = EnhancedGNNModule(num_players + 4, 128, player_embedding_dim)  # +4 for additional node features

        # Player embeddings as backup when graph unavailable
        self.player_embeddings = nn.Embedding(num_players, player_embedding_dim)

        # Metadata projection with layer normalization
        self.metadata_projection = nn.Sequential(
            nn.Linear(21, metadata_dim),  # Updated for additional features
            nn.LayerNorm(metadata_dim),
            nn.Dropout(DROPOUT_RATE),
            nn.ReLU()
        )

        # TF-IDF and BOW projections with regularization
        self.tfidf_dim = tfidf_dim
        if tfidf_dim is not None:
            self.tfidf_projection = nn.Sequential(
                nn.Linear(tfidf_dim, 64),
                nn.LayerNorm(64),
                nn.Dropout(DROPOUT_RATE),
                nn.ReLU()
            )

        self.bow_dim = bow_dim
        if bow_dim is not None:
            self.bow_projection = nn.Sequential(
                nn.Linear(bow_dim, 64),
                nn.LayerNorm(64),
                nn.Dropout(DROPOUT_RATE),
                nn.ReLU()
            )

        # Calculate fusion input dimension
        fusion_input_dim = self.bert_dim + player_embedding_dim * 2 + metadata_dim
        if tfidf_dim is not None:
            fusion_input_dim += 64
        if bow_dim is not None:
            fusion_input_dim += 64

        # Enhanced fusion layer with residual connections
        self.fusion_layer1 = nn.Linear(fusion_input_dim, 256)
        self.layer_norm1 = nn.LayerNorm(256)
        self.fusion_layer2 = nn.Linear(256, 128)
        self.layer_norm2 = nn.LayerNorm(128)

        # Output layer
        self.output_layer = nn.Linear(128, 2)

        # Additional regularization
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, input_ids, attention_mask, sender_idx, receiver_idx, metadata, game_id=None,
                tfidf_features=None, bow_features=None, game_graphs=None):
        # Process text with BERT
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = bert_output.pooler_output

        # Process metadata
        metadata_features = self.metadata_projection(metadata)

        # Process sender and receiver embeddings
        # If game graph is available, use GNN-processed embeddings
        sender_embed = None
        receiver_embed = None

        if game_graphs is not None and game_id in game_graphs:
            graph = game_graphs[game_id]
            node_features = graph.x
            edge_index = graph.edge_index
            edge_attr = graph.edge_attr if hasattr(graph, 'edge_attr') else None

            # Get node embeddings from GNN
            node_embeddings = self.gnn(node_features, edge_index, edge_attr)

            # Extract embeddings for sender and receiver
            sender_embed = node_embeddings[sender_idx]
            receiver_embed = node_embeddings[receiver_idx]
        else:
            # Fallback to regular embeddings if graph not available
            sender_embed = self.player_embeddings(sender_idx)
            receiver_embed = self.player_embeddings(receiver_idx)

        # Process TF-IDF and BOW features if available
        tfidf_projected = None
        if self.tfidf_dim is not None and tfidf_features is not None:
            tfidf_projected = self.tfidf_projection(tfidf_features)

        bow_projected = None
        if self.bow_dim is not None and bow_features is not None:
            bow_projected = self.bow_projection(bow_features)

        # Concatenate all features
        combined = [pooled_output, sender_embed, receiver_embed, metadata_features]
        if tfidf_projected is not None:
            combined.append(tfidf_projected)
        if bow_projected is not None:
            combined.append(bow_projected)

        fused = torch.cat(combined, dim=1)

        # Pass through fusion layers with residual connections
        hidden1 = self.fusion_layer1(fused)
        hidden1 = self.layer_norm1(hidden1)
        hidden1 = F.relu(hidden1)
        hidden1 = self.dropout(hidden1)

        hidden2 = self.fusion_layer2(hidden1)
        hidden2 = self.layer_norm2(hidden2)
        hidden2 = F.relu(hidden2)
        hidden2 = self.dropout(hidden2)

        # Output
        logits = self.output_layer(hidden2)

        return logits


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                num_epochs, device, game_graphs=None, early_stopping_patience=3):
    """Train the model with extensive logging and early stopping"""
    best_val_f1 = 0.0
    no_improvement_count = 0
    training_history = {'train_loss': [], 'train_acc': [], 'train_f1': [],
                       'val_loss': [], 'val_acc': [], 'val_f1': []}

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Training phase
        model.train()
        train_loss = 0.0
        train_preds = []
        train_labels = []

        for batch in tqdm(train_loader, desc="Training"):
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            sender_idx = batch['sender_idx'].to(device)
            receiver_idx = batch['receiver_idx'].to(device)
            metadata = batch['metadata'].to(device)
            labels = batch['label'].to(device)
            game_ids = batch['game_id']

            # Get optional features if available
            tfidf_features = batch.get('tfidf_features')
            if tfidf_features is not None:
                tfidf_features = tfidf_features.to(device)

            bow_features = batch.get('bow_features')
            if bow_features is not None:
                bow_features = bow_features.to(device)

            # Clear gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                sender_idx=sender_idx,
                receiver_idx=receiver_idx,
                metadata=metadata,
                game_id=game_ids,
                tfidf_features=tfidf_features,
                bow_features=bow_features,
                game_graphs=game_graphs
            )

            # Calculate loss
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Update statistics
            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_preds.extend(preds.cpu().numpy())
            train_labels.extend(labels.cpu().numpy())

        # Calculate training metrics
        train_loss = train_loss / len(train_loader)
        train_acc = accuracy_score(train_labels, train_preds)
        train_f1 = f1_score(train_labels, train_preds, average='macro')

        # Evaluation phase
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_labels = []
        val_probs = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                # Move batch to device
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                sender_idx = batch['sender_idx'].to(device)
                receiver_idx = batch['receiver_idx'].to(device)
                metadata = batch['metadata'].to(device)
                labels = batch['label'].to(device)
                game_ids = batch['game_id']

                # Get optional features if available
                tfidf_features = batch.get('tfidf_features')
                if tfidf_features is not None:
                    tfidf_features = tfidf_features.to(device)

                bow_features = batch.get('bow_features')
                if bow_features is not None:
                    bow_features = bow_features.to(device)

                # Forward pass
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    sender_idx=sender_idx,
                    receiver_idx=receiver_idx,
                    metadata=metadata,
                    game_id=game_ids,
                    tfidf_features=tfidf_features,
                    bow_features=bow_features,
                    game_graphs=game_graphs
                )

                # Calculate loss
                loss = criterion(outputs, labels)

                # Update statistics
                val_loss += loss.item()
                probs = F.softmax(outputs, dim=1)
                val_probs.extend(probs[:, 1].cpu().numpy())
                _, preds = torch.max(outputs, 1)
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        # Calculate validation metrics
        val_loss = val_loss / len(val_loader)
        val_acc = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds, average='macro')

        # Update learning rate
        scheduler.step(val_loss)

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")

        # Update history
        training_history['train_loss'].append(train_loss)
        training_history['train_acc'].append(train_acc)
        training_history['train_f1'].append(train_f1)
        training_history['val_loss'].append(val_loss)
        training_history['val_acc'].append(val_acc)
        training_history['val_f1'].append(val_f1)

        # Early stopping check
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            no_improvement_count = 0
            # Save the best model
            print("Saving new best model...")
            torch.save(model.state_dict(), 'best_model.pt')
        else:
            no_improvement_count += 1
            if no_improvement_count >= early_stopping_patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

        print("-" * 50)

    # Plot training history
    plot_training_history(training_history)

    return training_history


def evaluate_model(model, test_loader, criterion, device, game_graphs=None):
    """Evaluate the model on the test set with detailed metrics"""
    model.eval()
    test_loss = 0.0
    test_preds = []
    test_labels = []
    test_probs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            sender_idx = batch['sender_idx'].to(device)
            receiver_idx = batch['receiver_idx'].to(device)
            metadata = batch['metadata'].to(device)
            labels = batch['label'].to(device)
            game_ids = batch['game_id']

            # Get optional features if available
            tfidf_features = batch.get('tfidf_features')
            if tfidf_features is not None:
                tfidf_features = tfidf_features.to(device)

            bow_features = batch.get('bow_features')
            if bow_features is not None:
                bow_features = bow_features.to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                sender_idx=sender_idx,
                receiver_idx=receiver_idx,
                metadata=metadata,
                game_id=game_ids,
                tfidf_features=tfidf_features,
                bow_features=bow_features,
                game_graphs=game_graphs
            )

            # Calculate loss
            loss = criterion(outputs, labels)

            # Update statistics
            test_loss += loss.item()
            probs = F.softmax(outputs, dim=1)
            test_probs.extend(probs[:, 1].cpu().numpy())
            _, preds = torch.max(outputs, 1)
            test_preds.extend(preds.cpu().numpy())
            test_labels.extend(labels.cpu().numpy())

    # Calculate average loss
    test_loss = test_loss / len(test_loader)

    # Calculate metrics
    accuracy = accuracy_score(test_labels, test_preds)
    f1 = f1_score(test_labels, test_preds, average='macro')
    roc_auc = roc_auc_score(test_labels, test_probs)

    # Print detailed classification report
    print("\nClassification Report:")
    print(classification_report(test_labels, test_preds, target_names=['Truthful', 'Deceptive']))

    # Plot PR and ROC curves
    plot_evaluation_curves(test_labels, test_probs)

    # Print summary metrics
    print(f"\nTest Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test Macro F1: {f1:.4f}")
    print(f"Test ROC AUC: {roc_auc:.4f}")

    return {
        'loss': test_loss,
        'accuracy': accuracy,
        'f1_score': f1,
        'roc_auc': roc_auc,
        'predictions': test_preds,
        'probabilities': test_probs,
        'true_labels': test_labels
    }


def plot_training_history(history):
    """Plot training and validation metrics over epochs"""
    plt.figure(figsize=(15, 10))

    # Plot loss
    plt.subplot(3, 1, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot accuracy
    plt.subplot(3, 1, 2)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.title('Accuracy over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    # Plot F1 score
    plt.subplot(3, 1, 3)
    plt.plot(history['train_f1'], label='Train F1')
    plt.plot(history['val_f1'], label='Validation F1')
    plt.title('F1 Score over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()


def plot_evaluation_curves(true_labels, probabilities):
    """Plot Precision-Recall and ROC curves"""
    plt.figure(figsize=(15, 6))

    # Precision-Recall curve
    plt.subplot(1, 2, 1)
    precision, recall, _ = precision_recall_curve(true_labels, probabilities)
    plt.plot(recall, precision, marker='.')
    plt.title('Precision-Recall Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.grid(True)

    # ROC curve
    plt.subplot(1, 2, 2)
    fpr, tpr, _ = precision_recall_curve(true_labels, probabilities)
    plt.plot(fpr, tpr, marker='.')
    plt.title('ROC Curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('evaluation_curves.png')
    plt.close()


def analyze_errors(true_labels, predictions, probabilities, texts, senders, receivers, game_ids):
    """Analyze the model's errors to gain insights"""
    # Find the indices of false positives and false negatives
    fp_indices = np.where((predictions == 1) & (np.array(true_labels) == 0))[0]
    fn_indices = np.where((predictions == 0) & (np.array(true_labels) == 1))[0]

    # Print some examples of false positives
    print("\nFalse Positive Examples (Truthful messages classified as Deceptive):")
    for i in range(min(5, len(fp_indices))):
        idx = fp_indices[i]
        print(f"Game ID: {game_ids[idx]}")
        print(f"Sender: {senders[idx]}, Receiver: {receivers[idx]}")
        print(f"Text: {texts[idx]}")
        print(f"Confidence: {probabilities[idx]:.4f}")
        print("-" * 80)

    # Print some examples of false negatives
    print("\nFalse Negative Examples (Deceptive messages classified as Truthful):")
    for i in range(min(5, len(fn_indices))):
        idx = fn_indices[i]
        print(f"Game ID: {game_ids[idx]}")
        print(f"Sender: {senders[idx]}, Receiver: {receivers[idx]}")
        print(f"Text: {texts[idx]}")
        print(f"Confidence: {1 - probabilities[idx]:.4f}")
        print("-" * 80)

    # Analyze errors by game id, sender, and receiver
    error_by_game = {}
    error_by_sender = {}
    error_by_receiver = {}

    for idx in range(len(true_labels)):
        game_id = game_ids[idx]
        sender = senders[idx]
        receiver = receivers[idx]
        is_error = predictions[idx] != true_labels[idx]

        # Update error counts
        if game_id not in error_by_game:
            error_by_game[game_id] = {'total': 0, 'errors': 0}
        error_by_game[game_id]['total'] += 1
        if is_error:
            error_by_game[game_id]['errors'] += 1

        if sender not in error_by_sender:
            error_by_sender[sender] = {'total': 0, 'errors': 0}
        error_by_sender[sender]['total'] += 1
        if is_error:
            error_by_sender[sender]['errors'] += 1

        if receiver not in error_by_receiver:
            error_by_receiver[receiver] = {'total': 0, 'errors': 0}
        error_by_receiver[receiver]['total'] += 1
        if is_error:
            error_by_receiver[receiver]['errors'] += 1

    # Calculate error rates
    for game_id, counts in error_by_game.items():
        counts['error_rate'] = counts['errors'] / counts['total']

    for sender, counts in error_by_sender.items():
        counts['error_rate'] = counts['errors'] / counts['total']

    for receiver, counts in error_by_receiver.items():
        counts['error_rate'] = counts['errors'] / counts['total']

    # Print error rates
    print("\nError rates by game:")
    for game_id, counts in sorted(error_by_game.items(), key=lambda x: x[1]['error_rate'], reverse=True):
        print(f"Game {game_id}: {counts['error_rate']:.4f} ({counts['errors']}/{counts['total']})")

    print("\nError rates by sender:")
    for sender, counts in sorted(error_by_sender.items(), key=lambda x: x[1]['error_rate'], reverse=True):
        print(f"Sender {sender}: {counts['error_rate']:.4f} ({counts['errors']}/{counts['total']})")

    print("\nError rates by receiver:")
    for receiver, counts in sorted(error_by_receiver.items(), key=lambda x: x[1]['error_rate'], reverse=True):
        print(f"Receiver {receiver}: {counts['error_rate']:.4f} ({counts['errors']}/{counts['total']})")


def main():
    # Load data
    print("Loading data...")
    train_data = prepare_enhanced_data('train.jsonl')
    val_data = prepare_enhanced_data('validation.jsonl')
    test_data = prepare_enhanced_data('test.jsonl')

    # Augment the training data for better balance
    print("Augmenting training data...")
    augmented_train_data = augment_enhanced_data(train_data)

    # Build player embeddings
    print("Building player mappings...")
    all_players = set()
    for item in train_data + val_data + test_data:
        all_players.add(item['sender'])
        all_players.add(item['receiver'])

    player_to_idx = {player: idx + 1 for idx, player in enumerate(sorted(all_players))}
    player_to_idx['UNK'] = 0  # Add unknown player token

    # Build TF-IDF and BOW vectorizers
    print("Building TF-IDF and BOW features...")
    train_texts = [item['current_message'] for item in augmented_train_data]

    tfidf_vectorizer = TfidfVectorizer(max_features=1000, min_df=5, max_df=0.7)
    tfidf_vectorizer.fit(train_texts)

    bow_vectorizer = CountVectorizer(max_features=1000, min_df=5, max_df=0.7)
    bow_vectorizer.fit(train_texts)

    # Build player graphs
    print("Building player interaction graphs...")
    game_graphs = build_enhanced_player_graphs(augmented_train_data + val_data + test_data)

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # Create datasets
    print("Creating datasets...")
    train_dataset = ImprovedDiplomacyDataset(
        augmented_train_data, tokenizer, MAX_LEN, player_to_idx, game_graphs,
        tfidf_vectorizer, bow_vectorizer
    )

    val_dataset = ImprovedDiplomacyDataset(
        val_data, tokenizer, MAX_LEN, player_to_idx, game_graphs,
        tfidf_vectorizer, bow_vectorizer
    )

    test_dataset = ImprovedDiplomacyDataset(
        test_data, tokenizer, MAX_LEN, player_to_idx, game_graphs,
        tfidf_vectorizer, bow_vectorizer
    )

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # Create model
    print("Initializing model...")
    model = EnhancedDeceptionModel(
        num_players=len(player_to_idx),
        player_embedding_dim=PLAYER_EMBEDDING_DIM,
        metadata_dim=METADATA_DIM,
        tfidf_dim=len(tfidf_vectorizer.vocabulary_),
        bow_dim=len(bow_vectorizer.vocabulary_)
    )

    model = model.to(device)

    # Define loss function and optimizer
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

    # Train model
    print("Training model...")
    history = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler,
        num_epochs=EPOCHS, device=device, game_graphs=game_graphs
    )

    # Load best model
    print("Loading best model...")
    model.load_state_dict(torch.load('best_model.pt'))

    # Evaluate on test set
    print("Evaluating on test set...")
    test_results = evaluate_model(model, test_loader, criterion, device, game_graphs)

    # Analyze errors
    print("Analyzing errors...")
    test_texts = [item['current_message'] for item in test_data]
    test_senders = [item['sender'] for item in test_data]
    test_receivers = [item['receiver'] for item in test_data]
    test_game_ids = [item['game_id'] for item in test_data]

    analyze_errors(
        test_results['true_labels'],
        test_results['predictions'],
        test_results['probabilities'],
        test_texts,
        test_senders,
        test_receivers,
        test_game_ids
    )

    # Save the final model
    print("Saving model...")
    torch.save({
        'model_state_dict': model.state_dict(),
        'player_to_idx': player_to_idx,
        'tfidf_vectorizer': tfidf_vectorizer,
        'bow_vectorizer': bow_vectorizer
    }, 'diplomacy_deception_model_complete.pt')

    print("Done!")


if __name__ == "__main__":
    main()

Using device: cuda
Loading data...
Augmenting training data...
Original data - Truthful: 591, Deceptive: 12541
Minority class: Truthful
Creating 21 augmentations per minority class message
After augmentation - Truthful: 12456, Deceptive: 12541
Building player mappings...
Building TF-IDF and BOW features...
Building player interaction graphs...
Loading tokenizer...
Creating datasets...
Initializing model...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training model...
Epoch 1/10


Training: 100%|██████████| 3125/3125 [07:07<00:00,  7.30it/s]
Validation: 100%|██████████| 177/177 [00:13<00:00, 13.25it/s]


Train Loss: 0.0373, Train Acc: 0.6914, Train F1: 0.6914
Val Loss: 0.0298, Val Acc: 0.8319, Val F1: 0.5223
Saving new best model...
--------------------------------------------------
Epoch 2/10


Training: 100%|██████████| 3125/3125 [07:06<00:00,  7.32it/s]
Validation: 100%|██████████| 177/177 [00:13<00:00, 13.29it/s]


Train Loss: 0.0135, Train Acc: 0.9356, Train F1: 0.9356
Val Loss: 0.0224, Val Acc: 0.9435, Val F1: 0.5093
--------------------------------------------------
Epoch 3/10


Training:  42%|████▏     | 1309/3125 [02:57<04:00,  7.55it/s]