In [1]:
import json
import re
from nltk.tokenize import word_tokenize
from transformers import BertTokenizer
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import f1_score
import nltk
nltk.download('punkt_tab')

from tqdm import tqdm


[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
import random

seed = 100

torch.manual_seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
with open("/kaggle/input/diplomacy-dataset/test.jsonl", "r", encoding="utf-8") as file:
    test_data = [json.loads(line) for line in file]

with open("/kaggle/input/diplomacy-dataset/train.jsonl", "r", encoding="utf-8") as file:
    train_data = [json.loads(line) for line in file]

with open("/kaggle/input/diplomacy-dataset/validation.jsonl", "r", encoding="utf-8") as file:
    val_data = [json.loads(line) for line in file]



In [5]:
def preprocess(sentence ):
    sentence=sentence.lower()

    sentence = re.sub(r"[^a-zA-Z0-9 ]", "", sentence)  # can use punctations with bert , not with glove
    sentence = re.sub(r"\s+", " ", sentence).strip()

    return sentence

def prep_data_context(data ,  is_sender ):
    final_data=[]
    for data_points in data:
        sub = []
        for i, message in enumerate(data_points["messages"]):

            msg=preprocess(message )
            msg=word_tokenize(msg)
            if(len(msg)==0): continue

            if(is_sender):
              if(data_points['sender_labels'][i]=='NOANNOTATION'):
                continue
            else:
              if(data_points['receiver_labels'][i]=='NOANNOTATION'):
                continue

            sub.append({"message":msg ,
                        "label":(data_points["receiver_labels"][i],data_points["sender_labels"][i] )[is_sender] ,
                        "game_score_delta": int(data_points["game_score_delta"][i])})
        if(len(sub)==0): continue
        final_data.append(sub)
    return final_data

In [6]:
val=prep_data_context(val_data, 1)
train=prep_data_context(train_data,1)
test=prep_data_context(test_data ,1)

In [7]:
# tokens = []
# for sub in train:
#   for data_p in sub:
#     for word in data_p["message"]:
#         tokens.append(word)
# tokens=sorted(set(tokens))

# vocab = {token:idx+2  for idx , token in enumerate(tokens)}
# vocab["<PAD>"]=0
# vocab["<UNK>"]=1
from transformers import BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

2025-04-14 16:34:52.695445: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744648492.923600      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744648492.992937      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [8]:
class Deception_dataset_context(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        data_p = self.data[idx]
        msg_ids = []
        for sub in data_p:
            # Join the pre-tokenized words into a single string.
            # The BertTokenizer will perform its own tokenization.
            text = " ".join(sub['message'])
            # Get token IDs from BERT tokenizer.
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            msg_ids.append(torch.tensor(token_ids, dtype=torch.long))
        try:
            return {
                "messages": msg_ids,
                "labels": torch.tensor([i['label'] for i in data_p], dtype=torch.long),
                "game_score_delta": torch.tensor([i["game_score_delta"] for i in data_p], dtype=torch.float)
            }
        except Exception as e:
            print("issue: ", e)
            return


In [9]:
def collate_fn_context(batch):
    messages = []
    labels = []
    lengths = []
    num_messages = []
    game_score_deltas = []
    for i in batch:
        messages.extend(i['messages'])
        lengths.extend([j.shape[0] for j in i['messages']])
        labels.extend(i['labels'])
        num_messages.append(len(i['messages']))
        game_score_deltas.extend(i['game_score_delta'])
    padded_messages = pad_sequence(messages, batch_first=True, padding_value=0)
    return {
        "messages": padded_messages,
        "lengths": lengths,
        "labels": torch.tensor(labels),
        "num_messages": num_messages,
        "deltas":torch.tensor(game_score_deltas)
    }


In [None]:
# def load_glove(file):
#     embeddings = {}
#     with open(file, 'r', encoding='utf-8') as f:
#         for line in f:
#             embed = line.split()
#             word = embed[0]
#             embedding = torch.tensor([float(i) for i in embed[1:]], dtype=torch.float)
#             embeddings[word] = embedding
#     return embeddings

In [10]:
import random
import torch_geometric
from torch_geometric.data import Batch, Data

class ContextLSTM(nn.Module):
  def __init__(self, embedding_dim, hidden_size_message, gat_dim, num_classes):
    super(ContextLSTM,self).__init__()
    # glove_embeddings = load_glove(glove_file)

    # vocab_size = len(vocab)
    # self.embedding_matrix = torch.zeros(vocab_size, embedding_dim)
    # for token, idx in vocab.items():
    #     if token in glove_embeddings:
    #         self.embedding_matrix[idx] = glove_embeddings[token]
    #     else:
    #         self.embedding_matrix[idx] = torch.randn(embedding_dim) * 0.6

    # self.embedding = nn.Embedding.from_pretrained(self.embedding_matrix, freeze=True, padding_idx=0)
    self.bert_model = BertModel.from_pretrained('bert-base-uncased')
    self.embedding = self.bert_model.embeddings.word_embeddings
    # Unfreeze BERT embeddings so they can be fine-tuned
    for param in self.bert_model.embeddings.parameters():
        param.requires_grad = False

    self.lstm_message = nn.LSTM(embedding_dim, hidden_size_message, batch_first=True, bidirectional=True)
    # self.gat1 = torch_geometric.nn.conv.GATConv(in_channels=hidden_size_message*2,out_channels=gat_dim//2,heads=4)
    self.gat1 = torch_geometric.nn.conv.GATConv(in_channels=hidden_size_message*2,out_channels=gat_dim*2,heads=4,  concat=False)
      
    self.gat2 = torch_geometric.nn.conv.GATConv(in_channels=gat_dim*2,out_channels=gat_dim,heads=2 , concat=False)
    self.fc = nn.Linear(hidden_size_message * 2 + gat_dim, num_classes)
    self.relu = torch.nn.ReLU()
    self.tanh = torch.nn.Tanh()
    self.sigmoid = torch.nn.Sigmoid()  
    self.gate1 = nn.Linear(hidden_size_message * 2 + gat_dim,hidden_size_message * 2)
    self.gate2 = nn.Linear(hidden_size_message * 2 + gat_dim,gat_dim)
    self.hidden_size_message = hidden_size_message
    self.gat_dim = gat_dim

  def fusion(self,x):
      lstm_emb = x[:,:2*self.hidden_size_message]
      gat_emb = x[:,2*self.hidden_size_message:]
      
      lstm_emb_tanh = self.tanh(lstm_emb) 
      gat_emb_tanh = self.tanh(gat_emb)
      # print('lstm emb: ',lstm_emb.shape)
      # print("gat emb shape: ",gat_emb.shape)
      mixed = torch.cat((lstm_emb,gat_emb),dim=1)
      mixed_sigmoid_left = self.sigmoid(self.gate1(mixed))
      mixed_sigmoid_right = self.sigmoid(self.gate2(mixed))

      final_emb = torch.cat((mixed_sigmoid_left*lstm_emb_tanh,mixed_sigmoid_right*gat_emb_tanh),dim=1)
      return final_emb

  def get_message_emb(self,input_ids,lengths):
    embedded = self.embedding(input_ids)
    # print(embedded.shape)
    packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)

    _, (hn, _1) = self.lstm_message(packed)
    last_hidden = torch.cat((hn[0], hn[1]), dim=1)
    return last_hidden

  

  def forward(self, input_ids, num_messages, lengths,scores):
    total_messages = sum(num_messages)
    inputs_to_msg_encoder = input_ids
    encoded_messages = self.get_message_emb(inputs_to_msg_encoder,lengths)
    input_to_convo_encoder = torch.split(encoded_messages,num_messages)
    # print(input_to_convo_encoder[0].shape)

    tot_data = []

    for i in range(len(input_to_convo_encoder)):
        msg_embs = input_to_convo_encoder[i] # num_msgs x emb_dim
        n_nodes = msg_embs.shape[0]
        edge_ind = torch.combinations(torch.arange(n_nodes), r=2).T
        edge_ind = torch.cat([edge_ind, edge_ind.flip(0)], dim=1).to(input_ids.device)
    
        tot_data.append(Data(x=msg_embs,edge_index=edge_ind))

    tot_data = Batch.from_data_list(tot_data)
    from_gat1 = self.relu(self.gat1(tot_data.x,tot_data.edge_index))
    
    # print("from gat 1 shape: ",from_gat1.shape)

    from_gat2 = self.gat2(from_gat1,tot_data.edge_index)
    # print("enc msg shape: ",encoded_messages.shape)
    # print("from_gat2 shap: ",from_gat2.shape)
      
    input_for_fusion = [] 
    prev = 0
    # print("enc msg shape: ",encoded_messages.shape)
    for i, msg_emb in enumerate(encoded_messages):
        # print("i: ",i)
        inp = torch.cat((msg_emb,from_gat2[i]),dim=0)
        input_for_fusion.append(inp)
    # print("shape before fusion: ",torch.stack(input_for_fusion).shape)
    final_combined = self.fusion(torch.stack(input_for_fusion))
    # print("final combined shape: ",final_combined.shape)
    final_input = torch.cat((final_combined,  scores.unsqueeze(1)), dim=1)
    # print(final_input.shape)
    logits = self.fc(final_combined)
    # print("logits shape: ",logits.shape)
    return logits



Train Context LSTM:

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from sklearn.metrics import accuracy_score

# Set embedding_dim to 768 for BERT embeddings.
embedding_dim = 768
hidden_size = 128
num_classes = 2
gat_dim=50

seed = 100

torch.manual_seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model = ContextLSTM(embedding_dim, hidden_size, gat_dim, num_classes)
model = model.to(device)

class_weights = torch.tensor([1.0 / 0.05, 1.0 / 0.95], dtype=torch.float).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
loss = nn.CrossEntropyLoss(weight=class_weights)

train_dataset = Deception_dataset_context(train, tokenizer)
val_dataset = Deception_dataset_context(val, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn_context)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn_context)
test_dataset = Deception_dataset_context(test, tokenizer)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn_context)

for epoch in range(15):
    model.train()
    train_loss = 0
    train_preds, train_labels = [], []
    for batch in train_dataloader:
        messages = batch["messages"].to(device)
        lengths = batch["lengths"]
        labels = batch["labels"].to(device)
        num_messages = batch['num_messages']
        scores = batch['deltas'].to(device)
        optimizer.zero_grad()
        logits = model(messages, num_messages, lengths,scores)
        loss_ = loss(logits, labels)
        loss_.backward()
        optimizer.step()

        train_loss += loss_.item()
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        train_preds.extend(preds)
        train_labels.extend(labels.cpu().numpy())

    train_loss /= len(train_dataloader)
    train_f1 = f1_score(train_labels, train_preds, average='macro')
    train_acc = accuracy_score(train_labels,train_preds)

    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_dataloader:
            messages = batch["messages"].to(device)
            lengths = batch["lengths"]
            labels = batch["labels"].to(device)
            num_messages = batch['num_messages']
            scores = batch['deltas'].to(device)
            logits = model(messages, num_messages, lengths,scores)
            loss_ = loss(logits, labels)
            val_loss += loss_.item()
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            val_preds.extend(preds)
            val_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_dataloader)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_acc = accuracy_score(val_labels,val_preds)

    print(f"Epoch {epoch+1}:  Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}, Train acc: {train_acc:.4f}  Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}, Val acc: {val_acc:.4f}")
    torch.save(model.state_dict(),f"gat_contextlstm_glove_{epoch+1}.pth")


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Epoch 1:  Train Loss: 0.6930, Train F1: 0.4885, Train acc: 0.9550  Val Loss: 0.6879, Val F1: 0.4899, Val acc: 0.9603
Epoch 2:  Train Loss: 0.6890, Train F1: 0.4885, Train acc: 0.9550  Val Loss: 0.6869, Val F1: 0.4899, Val acc: 0.9603
Epoch 3:  Train Loss: 0.6857, Train F1: 0.4885, Train acc: 0.9550  Val Loss: 0.6853, Val F1: 0.4899, Val acc: 0.9603
Epoch 4:  Train Loss: 0.6823, Train F1: 0.4885, Train acc: 0.9550  Val Loss: 0.6834, Val F1: 0.4899, Val acc: 0.9603
Epoch 5:  Train Loss: 0.6841, Train F1: 0.4885, Train acc: 0.9550  Val Loss: 0.6819, Val F1: 0.4899, Val acc: 0.9603
Epoch 6:  Train Loss: 0.6805, Train F1: 0.5102, Train acc: 0.9029  Val Loss: 0.6788, Val F1: 0.5168, Val acc: 0.8673
Epoch 7:  Train Loss: 0.6736, Train F1: 0.5292, Train acc: 0.8760  Val Loss: 0.6742, Val F1: 0.5504, Val acc: 0.9163
Epoch 8:  Train Loss: 0.6548, Train F1: 0.5167, Train acc: 0.8126  Val Loss: 0.6691, Val F1: 0.4878, Val acc: 0.7977
Epoch 9:  Train Loss: 0.6519, Train F1: 0.4831, Train acc: 0.710

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from sklearn.metrics import accuracy_score
embedding_dim = 768
hidden_size = 128
num_classes = 2
gat_dim=50

seed = 100

torch.manual_seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model = ContextLSTM(embedding_dim, hidden_size, gat_dim, num_classes)
model.load_state_dict(torch.load("/kaggle/input/diplomacy-dataset/gat_contextlstm_glove_11.pth"))
model = model.to(device)

class_weights = torch.tensor([1.0 / 0.05, 1.0 / 0.95], dtype=torch.float).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
loss = nn.CrossEntropyLoss(weight=class_weights)

train_dataset = Deception_dataset_context(train, tokenizer)
val_dataset = Deception_dataset_context(val, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn_context)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn_context)
test_dataset = Deception_dataset_context(test, tokenizer)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn_context)

model.eval()
test_preds, test_labels = [], []
with torch.no_grad():
    for batch in test_dataloader:
        messages = batch["messages"].to(device)
        lengths = batch["lengths"]
        labels = batch["labels"].to(device)
        num_messages = batch['num_messages']
        scores = batch['deltas'].to(device)
        logits = model(messages, num_messages,lengths,scores)

        preds = torch.argmax(logits, dim=1).cpu().numpy()
        test_preds.extend(preds)
        test_labels.extend(labels.cpu().numpy())


test_f1 = f1_score(test_labels, test_preds, average='macro')
test_acc = accuracy_score(test_labels,test_preds)

print("test_f1: ",test_f1)
print("test accuracy: ",test_acc)

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  model.load_state_dict(torch.load("/kaggle/input/diplomacy-dataset/gat_contextlstm_glove_11.pth"))


test_f1:  0.5350073916701014
test accuracy:  0.8080586080586081
