In [7]:
import json
import re
from nltk.tokenize import word_tokenize
from transformers import BertTokenizer
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import f1_score
import nltk
nltk.download('punkt_tab')

from tqdm import tqdm


[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [8]:
!pip install torch_geometric



In [9]:
# from google.colab import drive
# drive.mount('/content/drive')


In [10]:
import random

seed = 100

torch.manual_seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [11]:
with open("/kaggle/input/dataset/2020_acl_diplomacy-master/data/test.jsonl", "r", encoding="utf-8") as file:
    test_data = [json.loads(line) for line in file]

with open("/kaggle/input/dataset/2020_acl_diplomacy-master/data/train.jsonl", "r", encoding="utf-8") as file:
    train_data = [json.loads(line) for line in file]

with open("/kaggle/input/dataset/2020_acl_diplomacy-master/data/validation.jsonl", "r", encoding="utf-8") as file:
    val_data = [json.loads(line) for line in file]



In [12]:
def preprocess(sentence ):
    sentence=sentence.lower()

    sentence = re.sub(r"[^a-zA-Z0-9 ]", "", sentence)  # can use punctations with bert , not with glove
    sentence = re.sub(r"\s+", " ", sentence).strip()

    return sentence

def prep_data_context(data ,  is_sender ):
    final_data=[]
    for data_points in data:
        sub = []
        for i, message in enumerate(data_points["messages"]):

            msg=preprocess(message )
            msg=word_tokenize(msg)
            if(len(msg)==0): continue

            if(is_sender):
              if(data_points['sender_labels'][i]=='NOANNOTATION'):
                continue
            else:
              if(data_points['receiver_labels'][i]=='NOANNOTATION'):
                continue

            sub.append({"message":msg ,
                        "label":(data_points["receiver_labels"][i],data_points["sender_labels"][i] )[is_sender] ,
                        "game_score_delta": int(data_points["game_score_delta"][i])})
        if(len(sub)==0): continue
        final_data.append(sub)
    return final_data

In [13]:
val=prep_data_context(val_data, 0)
train=prep_data_context(train_data,0)
test=prep_data_context(test_data ,0)

In [14]:
tokens = []
for sub in train:
  for data_p in sub:
    for word in data_p["message"]:
        tokens.append(word)
tokens=list(set(tokens))

vocab = {token:idx+2  for idx , token in enumerate(tokens)}
vocab["<PAD>"]=0
vocab["<UNK>"]=1


In [15]:

class Deception_dataset_context(Dataset):
    def __init__(self, data , vocab):
        self.data = data
        self.vocab= vocab
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
      data_p = self.data[idx]
      msg_ids = []
      for sub in data_p:
          sub_msg_id = []
          for token in sub['message']:
              if token in self.vocab:
                  sub_msg_id.append(self.vocab[token])
              else:
                  sub_msg_id.append(self.vocab["<UNK>"])
          msg_ids.append(torch.tensor(sub_msg_id, dtype=torch.long))
          # pritn(msg_ids[-1].shape)
      try:
          return {
              "messages": msg_ids,
              "labels": torch.tensor([i['label'] for i in data_p], dtype=torch.long),
              "game_score_delta": torch.tensor([i["game_score_delta"] for i in data_p], dtype=torch.float)
          }
      except Exception as e:
          print("issue: ", e)
          return



In [16]:
def collate_fn_context(batch):
    messages = []
    labels = []
    lengths = []
    game_score_deltas = []
    num_messages = []
    for i in batch:
      # print(len(i['messages']))
      # print((i['messages'][0]).shape)
      messages.extend(i['messages'])
      lengths.extend([len(j) for j in i['messages']])
      labels.extend(i['labels'])
      num_messages.append(len(i['messages']))
      game_score_deltas.extend(i['game_score_delta'])
    # print(len(messages))
    padded_messages = pad_sequence(messages,batch_first=True,padding_value=0)


    return {
        "messages": padded_messages,
        "lengths": lengths,
        "labels": torch.tensor(labels),
        "num_messages":num_messages,
        "deltas":torch.tensor(game_score_deltas)
    }


In [17]:
def load_glove(file):
    embeddings = {}
    with open(file, 'r', encoding='utf-8') as f:
        for line in f:
            embed = line.split()
            word = embed[0]
            embedding = torch.tensor([float(i) for i in embed[1:]], dtype=torch.float)
            embeddings[word] = embedding
    return embeddings

In [50]:
import random
import torch_geometric
from torch_geometric.data import Batch, Data

class ContextLSTM(nn.Module):
  def __init__(self,vocab, glove_file, embedding_dim, hidden_size_message, gat_dim, num_classes):
    super(ContextLSTM,self).__init__()
    seed = 100
    torch.manual_seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    glove_embeddings = load_glove(glove_file)

    vocab_size = len(vocab)
    self.embedding_matrix = torch.zeros(vocab_size, embedding_dim)
    for token, idx in vocab.items():
        if token in glove_embeddings:
            self.embedding_matrix[idx] = glove_embeddings[token]
        else:
            self.embedding_matrix[idx] = torch.randn(embedding_dim) * 0.6

    self.embedding = nn.Embedding.from_pretrained(self.embedding_matrix, freeze=True, padding_idx=0)

    self.lstm_message = nn.LSTM(embedding_dim, hidden_size_message, batch_first=True, bidirectional=True)
    # self.gat1 = torch_geometric.nn.conv.GATConv(in_channels=hidden_size_message*2,out_channels=gat_dim//2,heads=4)
    self.gat1 = torch_geometric.nn.conv.GATConv(in_channels=hidden_size_message*2,out_channels=gat_dim*2,heads=4,  concat=False)
      
    self.gat2 = torch_geometric.nn.conv.GATConv(in_channels=gat_dim*2,out_channels=gat_dim,heads=2 , concat=False)
    self.fc = nn.Linear(hidden_size_message * 2 + gat_dim, num_classes)
    self.relu = torch.nn.ReLU()
    self.tanh = torch.nn.Tanh()
    self.sigmoid = torch.nn.Sigmoid()  
    self.gate1 = nn.Linear(hidden_size_message * 2 + gat_dim,hidden_size_message * 2)
    self.gate2 = nn.Linear(hidden_size_message * 2 + gat_dim,gat_dim)
    self.hidden_size_message = hidden_size_message
    self.gat_dim = gat_dim

  def fusion(self,x):
      lstm_emb = x[:,:2*self.hidden_size_message]
      gat_emb = x[:,2*self.hidden_size_message:]
      
      lstm_emb_tanh = self.tanh(lstm_emb) 
      gat_emb_tanh = self.tanh(gat_emb)
      # print('lstm emb: ',lstm_emb.shape)
      # print("gat emb shape: ",gat_emb.shape)
      mixed = torch.cat((lstm_emb,gat_emb),dim=1)
      mixed_sigmoid_left = self.sigmoid(self.gate1(mixed))
      mixed_sigmoid_right = self.sigmoid(self.gate2(mixed))

      final_emb = torch.cat((mixed_sigmoid_left*lstm_emb_tanh,mixed_sigmoid_right*gat_emb_tanh),dim=1)
      return final_emb

  def get_message_emb(self,input_ids,lengths):
    embedded = self.embedding(input_ids)
    # print(embedded.shape)
    packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)

    _, (hn, _1) = self.lstm_message(packed)
    last_hidden = torch.cat((hn[0], hn[1]), dim=1)
    return last_hidden

  

  def forward(self, input_ids, num_messages, lengths , scores):
    total_messages = sum(num_messages)
    inputs_to_msg_encoder = input_ids
    encoded_messages = self.get_message_emb(inputs_to_msg_encoder,lengths)
    input_to_convo_encoder = torch.split(encoded_messages,num_messages)
    # print(input_to_convo_encoder[0].shape)

    tot_data = []

    for i in range(len(input_to_convo_encoder)):
        msg_embs = input_to_convo_encoder[i] # num_msgs x emb_dim
        n_nodes = msg_embs.shape[0]
        edge_ind = torch.combinations(torch.arange(n_nodes), r=2).T
        edge_ind = torch.cat([edge_ind, edge_ind.flip(0)], dim=1).to(input_ids.device)
    
        tot_data.append(Data(x=msg_embs,edge_index=edge_ind))

    tot_data = Batch.from_data_list(tot_data)
    from_gat1 = self.relu(self.gat1(tot_data.x,tot_data.edge_index))
    
    # print("from gat 1 shape: ",from_gat1.shape)

    from_gat2 = self.gat2(from_gat1,tot_data.edge_index)
    # print("enc msg shape: ",encoded_messages.shape)
    # print("from_gat2 shap: ",from_gat2.shape)
      
    input_for_fusion = [] 
    prev = 0
    # print("enc msg shape: ",encoded_messages.shape)
    for i, msg_emb in enumerate(encoded_messages):
        # print("i: ",i)
        inp = torch.cat((msg_emb,from_gat2[i]),dim=0)
        input_for_fusion.append(inp)
    # print("shape before fusion: ",torch.stack(input_for_fusion).shape)
    final_combined = self.fusion(torch.stack(input_for_fusion))
    # print("final combined shape: ",final_combined.shape)
    final_input = torch.cat((final_combined,  scores.unsqueeze(1)), dim=1)
    # print(final_input.shape)
    logits = self.fc(final_combined)
    # print("logits shape: ",logits.shape)
    return logits



Train Context LSTM:

In [53]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

glove_file = "/kaggle/input/gloveembed/glove.6B.100d.txt"
embedding_dim = 100
hidden_size = 100
num_classes = 2
gat_dim = 50

seed = 100

torch.manual_seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


model = ContextLSTM(vocab, glove_file, embedding_dim, hidden_size,gat_dim, num_classes)
model = model.to(device)

class_weights = torch.tensor([1.0 / 0.10, 1.0 / 0.90], dtype=torch.float)
class_weights = class_weights.to(device)


# model = ContextLSTM(vocab, glove_file, embedding_dim, hidden_size//2,hidden_size, num_classes)
# model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
loss = nn.CrossEntropyLoss(weight=class_weights)

train_dataset = Deception_dataset_context(train, vocab)
val_dataset = Deception_dataset_context(val, vocab)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn_context)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn_context)


test_dataset = Deception_dataset_context(test, vocab)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn_context)

print("hey")
for epoch in range(16):
    
    model.train()
    train_loss = 0
    train_preds, train_labels = [], []
    for batch in tqdm(train_dataloader):
        messages = batch["messages"].to(device)
        lengths = batch["lengths"]
        labels = batch["labels"].to(device)
        num_messages = batch['num_messages']
        scores = batch['deltas'].to(device)

        optimizer.zero_grad()
        logits = model(messages, num_messages,lengths,scores)
        loss_ = loss(logits, labels)
        loss_.backward()
        optimizer.step()

        train_loss += loss_.item()
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        train_preds.extend(preds)
        train_labels.extend(labels.cpu().numpy())


    train_loss /= len(train_dataloader)
    train_f1 = f1_score(train_labels, train_preds, average='macro')


    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_dataloader:
            messages = batch["messages"].to(device)
            lengths = batch["lengths"]
            labels = batch["labels"].to(device)
            num_messages = batch['num_messages']
            scores = batch['deltas'].to(device)
            logits = model(messages, num_messages,lengths,scores)
            loss_ = loss(logits, labels)
            val_loss += loss_.item()
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            val_preds.extend(preds)
            val_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_dataloader)
    val_f1 = f1_score(val_labels, val_preds, average='macro')

    model.eval()
    test_preds, test_labels = [], []
    with torch.no_grad():
        for batch in test_dataloader:
            messages = batch["messages"].to(device)
            lengths = batch["lengths"]
            labels = batch["labels"].to(device)
            num_messages = batch['num_messages']
            scores = batch['deltas'].to(device)
            logits = model(messages, num_messages,lengths,scores)
    
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            test_preds.extend(preds)
            test_labels.extend(labels.cpu().numpy())
    
    
    test_f1 = f1_score(test_labels, test_preds, average='macro')
    
    print(test_f1)


    print(f"Epoch {epoch+1}:  Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f} Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}")

hey


100%|██████████| 12/12 [00:01<00:00,  6.45it/s]


0.4824690321226118
Epoch 1:  Train Loss: 0.6508, Train F1: 0.4873 Val Loss: 0.6437, Val F1: 0.4906


100%|██████████| 12/12 [00:01<00:00,  6.86it/s]


0.4826862539349423
Epoch 2:  Train Loss: 0.6250, Train F1: 0.4897 Val Loss: 0.6309, Val F1: 0.4906


100%|██████████| 12/12 [00:01<00:00,  6.84it/s]


0.4824690321226118
Epoch 3:  Train Loss: 0.6261, Train F1: 0.4879 Val Loss: 0.6310, Val F1: 0.4906


100%|██████████| 12/12 [00:01<00:00,  6.92it/s]


0.4824690321226118
Epoch 4:  Train Loss: 0.6335, Train F1: 0.4894 Val Loss: 0.6339, Val F1: 0.4906


100%|██████████| 12/12 [00:01<00:00,  6.97it/s]


0.4824690321226118
Epoch 5:  Train Loss: 0.6066, Train F1: 0.4877 Val Loss: 0.6401, Val F1: 0.4906


100%|██████████| 12/12 [00:01<00:00,  6.97it/s]


0.4824690321226118
Epoch 6:  Train Loss: 0.6167, Train F1: 0.4877 Val Loss: 0.6351, Val F1: 0.4906


100%|██████████| 12/12 [00:01<00:00,  6.80it/s]


0.4824690321226118
Epoch 7:  Train Loss: 0.6188, Train F1: 0.4912 Val Loss: 0.6390, Val F1: 0.4904


100%|██████████| 12/12 [00:01<00:00,  6.56it/s]


0.4818162707588816
Epoch 8:  Train Loss: 0.6043, Train F1: 0.5031 Val Loss: 0.6564, Val F1: 0.4890


100%|██████████| 12/12 [00:01<00:00,  6.48it/s]


0.489213691026827
Epoch 9:  Train Loss: 0.5940, Train F1: 0.5148 Val Loss: 0.6427, Val F1: 0.4964


100%|██████████| 12/12 [00:01<00:00,  6.39it/s]


0.49724859430363044
Epoch 10:  Train Loss: 0.5949, Train F1: 0.5316 Val Loss: 0.6394, Val F1: 0.4975


100%|██████████| 12/12 [00:01<00:00,  6.39it/s]


0.502075221332762
Epoch 11:  Train Loss: 0.5783, Train F1: 0.5492 Val Loss: 0.6688, Val F1: 0.4888


100%|██████████| 12/12 [00:01<00:00,  6.55it/s]


0.5120532980322534
Epoch 12:  Train Loss: 0.5623, Train F1: 0.5637 Val Loss: 0.6750, Val F1: 0.4939


100%|██████████| 12/12 [00:01<00:00,  6.43it/s]


0.5093535616065762
Epoch 13:  Train Loss: 0.5171, Train F1: 0.5779 Val Loss: 0.7949, Val F1: 0.5032


100%|██████████| 12/12 [00:01<00:00,  6.61it/s]


0.5408900603168308
Epoch 14:  Train Loss: 0.5187, Train F1: 0.5909 Val Loss: 0.6650, Val F1: 0.4860


100%|██████████| 12/12 [00:01<00:00,  6.57it/s]


0.5241594448719896
Epoch 15:  Train Loss: 0.4682, Train F1: 0.6151 Val Loss: 0.8264, Val F1: 0.5229


100%|██████████| 12/12 [00:01<00:00,  6.36it/s]


0.5456896551724137
Epoch 16:  Train Loss: 0.4312, Train F1: 0.6416 Val Loss: 0.8424, Val F1: 0.5037


In [55]:
 torch.save(model.state_dict(),f"gat+lstm+glove_best.pth")

In [54]:
test_dataset = Deception_dataset_context(test, vocab)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn_context)


model.eval()
test_preds, test_labels = [], []
with torch.no_grad():
    for batch in test_dataloader:
        messages = batch["messages"].to(device)
        lengths = batch["lengths"]
        labels = batch["labels"].to(device)
        num_messages = batch['num_messages']
        scores = batch['deltas'].to(device)
        logits = model(messages, num_messages,lengths,scores)

        preds = torch.argmax(logits, dim=1).cpu().numpy()
        test_preds.extend(preds)
        test_labels.extend(labels.cpu().numpy())


test_f1 = f1_score(test_labels, test_preds, average='macro')

print(test_f1)

0.5456896551724137
