In [19]:
import jsonlines
from tqdm import tqdm
import os
from torch.optim import AdamW
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import re
import json
import torch.nn as nn
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, DataLoader as GeoDataLoader, Batch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
from nltk.tokenize import word_tokenize
import torch.nn.functional as F

In [20]:
# Use the current CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [21]:
FOLDER = "../data"


def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list, one object per line."""
    with jsonlines.open(file_path) as reader:
        records = list(reader)
    return records


# Raw conversation splits; each element is one conversation dict.
train_data = load_jsonl(f"{FOLDER}/train.jsonl")
val_data = load_jsonl(f"{FOLDER}/validation.jsonl")
test_data = load_jsonl(f"{FOLDER}/test.jsonl")

In [22]:
class DiplomacyVocabulary(Dataset):
    """Bidirectional token <-> integer-id mapping for messages and moves.

    Id 0 is the padding token "PAD" and id 1 the unknown token "UNK"; every
    token added afterwards receives the next free id. Tokens are stored
    lowercased.
    """

    def __init__(self):
        specials = ("PAD", "UNK")
        self.word2idx = {token: idx for idx, token in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))

    def add_token(self, token):
        """Register ``token`` (lowercased) if it is not already known."""
        key = token.lower()
        if key in self.word2idx:
            return
        new_id = len(self.word2idx)
        self.word2idx[key] = new_id
        self.idx2word[new_id] = key

    def __len__(self):
        """Number of known tokens, including the two special tokens."""
        return len(self.word2idx)

    def tokenize(self, message):
        """Lowercase and word-tokenize ``message``, mapping unknown tokens to UNK (1)."""
        unknown_id = 1
        lookup = self.word2idx.get
        return [lookup(piece, unknown_id) for piece in word_tokenize(message.lower())]

In [23]:
class DiplomacyDataset(Dataset):
    """Conversation-level dataset pairing Diplomacy messages with board state and orders.

    Each item is one conversation. For every message that is non-empty after
    preprocessing and not labelled "NOANNOTATION", ``__getitem__`` returns:
      * the tokenized message, prefixed with "<sender> to <receiver>: ",
      * the board graph of the message's game phase (node features + edge index,
        loaded from ``map_folder``),
      * the tokenized moves of the sender and the receiver for that phase (from
        ``move_folder``),
      * the sender label (1 if truthy, else 0).

    Args:
        data: list of conversation dicts with the keys ``messages``, ``sender_labels``,
            ``speakers``, ``receivers``, ``game_id``, ``years`` and ``seasons``.
            The list itself is not modified.
        map_folder: directory holding ``DiplomacyGame{game_id}_{year}_{season}.json`` maps.
        move_folder: directory holding the per-phase move files with the same names.
        vocab: object providing ``tokenize`` (and ``add_token`` when ``construct``);
            a fresh ``DiplomacyVocabulary`` is created when omitted.
        construct: when True, add every token of this split (messages and both
            players' moves) to the vocabulary.
    """

    _URL_PATTERN = re.compile(r'http\S+|www\S+|https\S+', flags=re.MULTILINE)
    # https://stackoverflow.com/questions/33404752/removing-emojis-from-a-string-in-python
    _EMOJI_PATTERN = re.compile(
        pattern = u"[\U0001F600-\U0001F64F"     # emoticons
                    "\U0001F300-\U0001F5FF"     # symbols & pictographs
                    "\U0001F680-\U0001F6FF"     # transport & map symbols
                    "\U0001F1E0-\U0001FAD6]+",  # regional indicators (flags) through U+1FAD6
        flags = re.UNICODE)

    def __init__(self, data, map_folder, move_folder, vocab=None, construct=False):
        self.map_folder = map_folder
        self.move_folder = move_folder
        self.countries = ["Austria", "England", "France", "Germany", "Italy", "Russia", "Turkey", "None"]
        self.province_types = ["Inland", "Coastal", "Water"]
        self.unit_types = ["A", "F", "None"]
        # Province order is fixed by the first map that gets loaded (see load_map).
        self.provinces = None
        self.vocab = vocab if vocab is not None else DiplomacyVocabulary()

        # Keep only conversations with at least one usable message. A new list is built
        # so the caller's ``data`` is left untouched; every kept item therefore yields
        # >= 1 message, which the downstream pack_padded_sequence requires.
        self.data = [item for item in data if self._has_usable_message(item)]

        if construct:
            self._build_vocab()

    def __len__(self):
        return len(self.data)

    def _has_usable_message(self, item):
        """True if any message is non-empty after preprocessing and annotated."""
        return any(
            self.preprocess_text(message) and label != "NOANNOTATION"
            for message, label in zip(item['messages'], item['sender_labels'])
        )

    @staticmethod
    def _phase_name(item, i):
        """File stem shared by the map and move files of message ``i``'s game phase."""
        return f"DiplomacyGame{item['game_id']}_{item['years'][i]}_{item['seasons'][i].lower()}"

    def _build_vocab(self):
        """Add all message and move tokens of ``self.data`` to ``self.vocab``."""
        for item in self.data:
            for i, message in enumerate(item['messages']):
                sender = item['speakers'][i]
                receiver = item['receivers'][i]
                sender_moves, receiver_moves = self.load_moves(self._phase_name(item, i), sender, receiver)
                message = f"{sender} to {receiver}: {self.preprocess_text(message)}"
                for text in (message, sender_moves, receiver_moves):
                    for token in word_tokenize(text):
                        self.vocab.add_token(token)

    def preprocess_text(self, sentence):
        """Strip URLs and emoji, collapse whitespace, and lowercase ``sentence``."""
        sentence = self._URL_PATTERN.sub('', sentence)
        sentence = self._EMOJI_PATTERN.sub(r'', sentence)
        sentence = re.sub(r'\s+', ' ', sentence).strip()
        return sentence.lower()

    def one_hot_encode(self, value, categories):
        """Return a float one-hot numpy vector of ``value``'s position in ``categories``."""
        one_hot = np.zeros(len(categories))
        one_hot[categories.index(value)] = 1
        return one_hot

    def load_map(self, map_file):
        """Load one phase's board as ``(features, edge_index)``.

        Returns:
            features: float tensor (num_provinces, 32), one row per province in
                ``self.provinces`` order. Row layout: controlledBy (8) | provinceType (3)
                | unitType (3) | supplyCentre (2) | controlledBy (8) | currentControl (8).
            edge_index: long tensor (2, num_edges) of (province, adjacent) index pairs.

        Raises:
            ValueError: if the map names a province unknown to the first loaded map,
                or lacks one of those provinces.
        """
        with open(f"{os.path.join(self.map_folder, map_file)}.json", 'r') as f:
            map_data = json.load(f)
        if self.provinces is None:
            self.provinces = list(map_data.keys())
        index = {province: i for i, province in enumerate(self.provinces)}

        def position(province):
            # Keep list.index semantics: an unknown province raises ValueError.
            try:
                return index[province]
            except KeyError:
                raise ValueError(f"{province!r} is not in the province list of the first loaded map") from None

        features = [None] * len(self.provinces)
        edges = []
        for province, info in map_data.items():
            controlled_by = self.one_hot_encode(info["controlledBy"], self.countries)
            # NOTE(review): controlledBy appears twice in the vector. It is kept because the
            # model's FeatureVectorEncoder(32, 32) and any saved weights expect exactly 32 inputs.
            feature_vector = np.concatenate([
                controlled_by,
                self.one_hot_encode(info["provinceType"], self.province_types),
                self.one_hot_encode(info["unitType"], self.unit_types),
                [0, 1] if info["supplyCentre"] else [1, 0],
                controlled_by,
                self.one_hot_encode(info["currentControl"], self.countries),
            ])
            src = position(province)
            features[src] = feature_vector
            edges.extend((src, position(adj)) for adj in info["adjacency"])
        missing = [name for name, row in zip(self.provinces, features) if row is None]
        if missing:
            raise ValueError(f"map {map_file!r} is missing provinces {missing}")
        # np.stack first: building a tensor from a list of ndarrays is very slow.
        features = torch.tensor(np.stack(features), dtype=torch.float)
        # reshape(-1, 2) keeps the (2, 0) shape when there are no edges at all.
        edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
        return features, edge_index

    def load_moves(self, move_file, sender, receiver):
        """Return the sender's and receiver's moves for one phase as lowercase strings.

        Each player's move list (keyed by the capitalized country name in the file)
        is joined with " <sep> ".
        """
        with open(f"{os.path.join(self.move_folder, move_file)}.json", 'r') as f:
            move_data = json.load(f)
        sender_moves = " <sep> ".join(move_data[sender.capitalize()])
        receiver_moves = " <sep> ".join(move_data[receiver.capitalize()])
        return sender_moves.lower(), receiver_moves.lower()

    def __getitem__(self, idx):
        """Return the per-message lists of one conversation plus its message count."""
        item = self.data[idx]
        sample = {key: [] for key in ("messages", "map_features", "map_edges",
                                      "sender_moves", "receiver_moves", "labels")}
        # Messages of one conversation often share a phase; load each board only once.
        boards = {}
        for i, message in enumerate(item['messages']):
            message = self.preprocess_text(message)
            label = item['sender_labels'][i]
            if not message or label == "NOANNOTATION":
                continue
            sender = item['speakers'][i]
            receiver = item['receivers'][i]
            phase = self._phase_name(item, i)
            if phase not in boards:
                boards[phase] = self.load_map(phase)
            features, edge_index = boards[phase]
            sender_moves, receiver_moves = self.load_moves(phase, sender, receiver)
            sample["messages"].append(self.vocab.tokenize(f"{sender} to {receiver}: {message}"))
            sample["map_features"].append(features)
            sample["map_edges"].append(edge_index)
            sample["sender_moves"].append(self.vocab.tokenize(sender_moves))
            sample["receiver_moves"].append(self.vocab.tokenize(receiver_moves))
            sample["labels"].append(1 if label else 0)
        sample["lengths"] = torch.tensor(len(sample["messages"]), dtype=torch.long)
        return sample

In [24]:
# Per-phase board maps and move texts.
MAP_FOLDER = "../moves_map"
MOVE_FOLDER = "../moves_sentences"

# The vocabulary is built from the training split only and shared with the other
# splits, so tokens never seen in training map to the UNK id there.
train_dataset = DiplomacyDataset(train_data, MAP_FOLDER, MOVE_FOLDER, construct=True)
vocab = train_dataset.vocab
val_dataset, test_dataset = (
    DiplomacyDataset(split, MAP_FOLDER, MOVE_FOLDER, vocab=vocab)
    for split in (val_data, test_data)
)

In [25]:
def collate_fn(batch):
    """Flatten a batch of conversations into one row per message.

    Args:
        batch: list of items produced by ``DiplomacyDataset.__getitem__``.

    Returns:
        dict with
          * "messages", "sender_moves", "receiver_moves": LongTensors of token ids,
            right-padded with 0 (PAD) to the longest sequence in the batch;
          * "map_features", "map_edges": per-message board tensors stacked along a
            new first dim (torch.stack requires every map to have the same shape,
            e.g. the same number of edges);
          * "labels": LongTensor with one label per message;
          * "length": LongTensor with the number of messages per conversation, which
            the model uses to split the flat rows back into conversations.
    """
    def flatten(key):
        # Concatenate the per-message lists of all conversations, preserving order.
        return [entry for item in batch for entry in item[key]]

    def pad_tokens(sequences):
        # dtype=long is explicit because torch.tensor([]) defaults to float32 for an
        # empty token list (e.g. a player with no moves), which would otherwise make
        # the whole padded batch float.
        return pad_sequence(
            [torch.tensor(seq, dtype=torch.long) for seq in sequences],
            batch_first=True,
            padding_value=0,
        )

    return {
        "messages": pad_tokens(flatten("messages")),
        "map_features": torch.stack(flatten("map_features")),
        "map_edges": torch.stack(flatten("map_edges")),
        "sender_moves": pad_tokens(flatten("sender_moves")),
        "receiver_moves": pad_tokens(flatten("receiver_moves")),
        "labels": torch.tensor(flatten("labels"), dtype=torch.long),
        "length": torch.tensor([int(item["lengths"]) for item in batch], dtype=torch.long),
    }

In [26]:
# Each dataset item is a whole conversation; collate_fn flattens a batch of
# conversations into one row per message. Only the training loader shuffles.
BATCH_SIZE = 4  # conversations per batch
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=shuffle, collate_fn=collate_fn)
    for split, shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

In [27]:
# Encodes the concatenated one-hot vectors of the map
class FeatureVectorEncoder(nn.Module):
    """Single linear projection applied to each province's feature vector.

    Args:
        input_dim: size of the incoming feature vector (last dimension of ``x``).
        output_dim: size of the projected vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name "fc" is part of the saved state_dict keys.
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project the last dimension of ``x`` from input_dim to output_dim."""
        projected = self.fc(x)
        return projected

In [28]:
# Encodes the player moves
class MoveEncoder(nn.Module):
    """Encode padded move-token sequences into fixed-size vectors.

    Each input row is embedded (200 dims) and run through a single-layer
    bidirectional LSTM. The final forward and backward hidden states are
    concatenated and projected back to ``output_dim``.

    Args:
        vocab_size: number of rows in the embedding table.
        output_dim: LSTM hidden size per direction and size of the returned vectors.
        pretrained_embeddings: optional numpy array matching the (vocab_size, 200)
            embedding table; when given it is copied in and the table is frozen.
    """

    def __init__(self, vocab_size, output_dim, pretrained_embeddings=None):
        super(MoveEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, 200)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
            self.embedding.weight.requires_grad = False
        self.lstm = nn.LSTM(200, output_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(output_dim * 2, output_dim)
        # NOTE(review): defined but not applied in forward(); kept so the module layout is unchanged.
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Map a (batch, seq_len) tensor of token ids to a (batch, output_dim) tensor.

        Padding ids (0) are NOT stripped: every row is read over the full padded
        length, as before, so existing trained weights behave the same. Rows of an
        unpacked LSTM are independent, so one batched call equals encoding each
        row separately.
        """
        x = x.long()
        if x.size(0) == 0:
            return self.fc.weight.new_zeros((0, self.fc.out_features))
        _, (hidden, _) = self.lstm(self.embedding(x))
        # hidden: (2, batch, output_dim) for one bidirectional layer; [-2] is the
        # forward direction's final state and [-1] the backward direction's.
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.fc(hidden)

In [29]:
class DetectionModel(nn.Module):
    """Per-message classifier combining text, conversation context and board state.

    For every message (one row of the flattened batch) it concatenates
      * a BiLSTM encoding of the message tokens (256 dims),
      * the unidirectional context-LSTM output at that message within its
        conversation (128 dims), and
      * a board summary (64 dims): GAT node embeddings of the phase's map, pooled
        by dot-product attention whose query is the encoded sender+receiver moves,
    then maps the 448-dim fusion to 2 logits.

    Attribute names are part of the saved state_dict and must not be renamed.
    Device placement is done by the caller via ``model.to(device)``.
    """

    def __init__(self, vocab_size, pretrained_embeddings=None):
        super(DetectionModel, self).__init__()
        # Token embeddings for message text (MoveEncoder keeps its own table).
        self.embedding = nn.Embedding(vocab_size, 200)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
            self.embedding.weight.requires_grad = False
        # 32 = width of the per-province feature vector fed to the graph encoder.
        self.feature_vector_encoder = FeatureVectorEncoder(32, 32)
        self.move_encoder = MoveEncoder(vocab_size, 32, pretrained_embeddings=pretrained_embeddings)
        # GAT stack: 4 concatenated heads of 16 (= 64), then 2 averaged heads of 64.
        self.gan1 = GATConv(32, 16, heads=4, dropout=0.4, concat=True)
        self.gan2 = GATConv(16*4, 64, heads=2, dropout = 0.1, concat = False)

        self.norm1 = nn.LayerNorm(64)
        self.norm2 = nn.LayerNorm(64)

        self.message_lstm = nn.LSTM(200, 128, batch_first=True, bidirectional=True)  # 2 * 128 = 256 per message
        self.context_lstm = nn.LSTM(256, 128, batch_first=True, bidirectional=False)  # left-to-right over a conversation

        # 448 = 256 (message) + 128 (context) + 64 (attended board)
        self.fusion = nn.Linear(448, 128)
        self.classifier = nn.Linear(128, 2)
        self.dropout = nn.Dropout(0.1)
        self.norm3 = nn.LayerNorm(128)

    def get_graph_embeddings(self, x, edge_index, chunk_size=32):
        """Run the GAT stack over a stack of board graphs.

        Args:
            x: node features; ``x[g]`` holds graph g's per-province feature rows.
            edge_index: per-graph edge indices; ``edge_index[g]`` is graph g's (2, E) index.
            chunk_size: number of graphs merged into one PyG ``Batch`` per GAT pass.
                Graphs in a Batch are disjoint, so this only trades memory for speed.

        Returns:
            Tensor of shape (num_graphs * num_nodes, 64), graphs concatenated in order.
        """
        model_device = next(self.parameters()).device
        chunks = []
        for start in range(0, x.size(0), chunk_size):
            node_feats = self.feature_vector_encoder(x[start:start + chunk_size].to(model_device))
            edges = edge_index[start:start + chunk_size]
            graphs = Batch.from_data_list(
                [Data(x=node_feats[j], edge_index=edges[j]) for j in range(node_feats.size(0))]
            ).to(model_device)
            y = self.norm1(self.gan1(graphs.x, graphs.edge_index))
            y = self.norm2(self.gan2(y, graphs.edge_index))
            chunks.append(y)
        return torch.cat(chunks, dim=0)

    def get_message_embeddings(self, x):
        """Encode each padded message row with the BiLSTM, ignoring PAD (0) tokens.

        Returns:
            Tensor of shape (num_messages, 1, 256): the final forward and backward
            hidden states of each message, concatenated.
        """
        message_embeddings = []
        for row in x:
            tokens = row[row != 0].unsqueeze(0)  # drop padding so the LSTM sees only real tokens
            _, (hidden, _) = self.message_lstm(self.embedding(tokens))
            message_embeddings.append(torch.cat((hidden[0], hidden[1]), dim=-1))  # (1, 256)
        return torch.stack(message_embeddings, dim=0)

    def forward(self, messages, map_features, map_edges, sender_moves, receiver_moves, lengths):
        """Return (num_messages, 2) logits for a flattened batch of conversations.

        ``lengths`` gives the number of messages per conversation; its sum must equal
        ``messages.size(0)``. Sub-module call order is unchanged (graph, moves,
        messages), so dropout draws random numbers in the same order during training.
        """
        lengths = lengths.cpu()  # torch.split / pack_padded_sequence take CPU lengths
        num_messages = messages.size(0)

        # Board state, attended with the players' moves as the query.
        map_embeddings = self.get_graph_embeddings(map_features, map_edges)
        sender_vec = self.move_encoder(sender_moves)
        receiver_vec = self.move_encoder(receiver_moves)
        query = torch.cat((sender_vec, receiver_vec), dim=-1).unsqueeze(1)  # (messages, 1, 64)
        map_embeddings = map_embeddings.view(num_messages, -1, 64)  # (messages, nodes, 64): keys and values
        attn_weights = torch.softmax(torch.bmm(query, map_embeddings.transpose(1, 2)), dim=-1)  # (messages, 1, nodes)
        attn_output = torch.bmm(attn_weights, map_embeddings).squeeze(1)  # (messages, 64)

        # Message text, then left-to-right context across each conversation.
        message_embeddings = self.get_message_embeddings(messages)  # (messages, 1, 256)
        conversations = pad_sequence(
            torch.split(message_embeddings, lengths.tolist()), batch_first=True, padding_value=0
        ).squeeze(2)  # (conversations, max_len, 256)
        packed = pack_padded_sequence(conversations, lengths, batch_first=True, enforce_sorted=False)
        context, _ = self.context_lstm(packed)
        conversation_embeddings, _ = pad_packed_sequence(context, batch_first=True)  # (conversations, max_len, 128)
        # Back to one row per message, in the original flat order (row-major over
        # (conversation, position)). The context LSTM is unidirectional, so each row
        # only depends on the messages up to and including it.
        valid = torch.arange(conversation_embeddings.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
        causal_embeddings = conversation_embeddings[valid.to(conversation_embeddings.device)]

        # Fuse the message embeddings, causal embeddings, and attention output.
        fused = torch.cat((message_embeddings.squeeze(1), causal_embeddings, attn_output), dim=-1)  # (messages, 448)
        x = self.dropout(F.relu(self.norm3(self.fusion(fused))))
        logits = self.classifier(x)
        return logits.squeeze(-1)

In [30]:
# Rebuild the architecture and restore the trained weights
# (weights_only=True refuses to unpickle arbitrary Python objects).
model = DetectionModel(len(vocab)).to(device)
model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
model.eval()

test_preds, test_labels = [], []
with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Test", unit="batch"):
        messages, map_features, map_edges, sender_moves, receiver_moves, labels, lengths = (
            batch[key].to(device)
            for key in ("messages", "map_features", "map_edges", "sender_moves", "receiver_moves", "labels", "length")
        )
        logits = model(messages, map_features, map_edges, sender_moves, receiver_moves, lengths)
        preds = logits.argmax(dim=1)  # one predicted class per message
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

# Macro F1 weighs both classes equally; weighted F1 and accuracy follow the class imbalance.
test_f1 = f1_score(test_labels, test_preds, average='macro')
test_weighted_f1 = f1_score(test_labels, test_preds, average='weighted')
test_accuracy = accuracy_score(test_labels, test_preds)
cm = confusion_matrix(test_labels, test_preds)

print(f'Test Macro F1: {test_f1}')
print(f'Test Weighted F1: {test_weighted_f1}')
print(f'Test Accuracy: {test_accuracy}')
print('Confusion Matrix:')
print(cm)

Test: 100%|██████████| 11/11 [00:25<00:00,  2.34s/batch]

Test Macro F1: 0.5463822432375554
Test Weighted F1: 0.8531173194606964
Test Accuracy: 0.8513513513513513
Confusion Matrix:
[[  43  197]
 [ 210 2288]]



