In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import jsonlines
from tqdm import tqdm
import nltk
from nltk.tokenize import word_tokenize
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score
import copy

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [3]:
DATA_PATH = '../data'

In [4]:
SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [5]:
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Saurav\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Saurav\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [6]:
class DiplomacyVocabulary(Dataset):
    def __init__(self):
        # Initialize the vocabulary with special tokens
        self.word2idx = {"PAD": 0, "UNK": 1}
        self.idx2word = {0: "PAD", 1: "UNK"}
        
    def add_token(self, token):
        # Add a new token to the vocabulary
        if token not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[token] = idx
            self.idx2word[idx] = token
        
    def __len__(self):
        return len(self.word2idx)
    
    def tokenize(self, message):
        message = message.lower()
        tokens = word_tokenize(message)
        return [self.word2idx.get(token, 1) for token in tokens]

In [7]:
class DiplomacyDataset(Dataset):
    def __init__(self, file_path, vocab=None, construct=False):
        self.data = []
        self.vocab = vocab if vocab else DiplomacyVocabulary()
        
        with jsonlines.open(file_path, 'r') as f:
            for line in f:
                for i, message in enumerate(line['messages']):
                    # Ignore messages with no sender labels
                    if line['sender_labels'][i] == 'NOANNOTATION':
                        continue
                    self.data.append({
                        'message': message,
                        'label': 1 if line['sender_labels'][i] else 0 # 1 for True, 0 for False message
                    })
                    
                    if construct:
                        tokens = word_tokenize(message)
                        for token in tokens:
                            self.vocab.add_token(token)
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tokens = self.vocab.tokenize(self.data[idx]['message'])
        
        return {
            'tokens': torch.tensor(tokens, dtype=torch.long),
            'label': torch.tensor(self.data[idx]['label'], dtype=torch.long),
        }

In [8]:
def collate_fn(batch):
    batch = sorted(batch, key=lambda x: len(x['tokens']), reverse=True)
    
    tokens = pad_sequence([item['tokens'] for item in batch], batch_first=True, padding_value=0)
    labels = torch.stack([item['label'] for item in batch])
    
    return {
        'tokens': tokens,
        'labels': labels,
    }

In [9]:
train_dataset = DiplomacyDataset(f'{DATA_PATH}/train.jsonl', construct=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
vocab = train_dataset.vocab

val_dataset = DiplomacyDataset(f'{DATA_PATH}/validation.jsonl', vocab=vocab)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

test_dataset = DiplomacyDataset(f'{DATA_PATH}/test.jsonl', vocab=vocab)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [10]:
class LSTM_Model(nn.Module):
    def __init__(self, vocab_size, pretrained_embeddings=None, embedding_dim=200, hidden_size=100):
        super(LSTM_Model, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        if pretrained_embeddings is not None:
            # Load the pretrained embeddings
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
            self.embedding.weight.requires_grad = False
        # Initialize the BiLSTM layer
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(hidden_size * 2, 1)
        
    def forward(self, tokens):
        embeddings = self.embedding(tokens)
        
        lstm_out, _ = self.lstm(embeddings)
        
        # Use max pooling over the LSTM outputs
        pooled_out, _ = torch.max(lstm_out, dim=1)
        
        pooled_out = self.dropout(pooled_out)
        
        logits = self.fc(pooled_out)
        # Return the logits for binary classification
        return logits.squeeze(1)

In [11]:
embedding_path = '../embeddings/glove.twitter.27B.200d.txt'

pretrained_embeddings = np.zeros((len(vocab), 200))

with open(embedding_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        values = line.split()
        word = values[0]
        if word in vocab.word2idx:
            idx = vocab.word2idx[word]
            pretrained_embeddings[idx] = np.array(values[1:], dtype='float32')

1193514it [00:17, 68701.75it/s]


In [12]:
model = LSTM_Model(vocab_size=len(vocab), pretrained_embeddings=pretrained_embeddings).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.003)

# Adjusting the loss function to account for class imbalance
criteria = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0/30.0]).to(device))

EPOCHS = 15
PATIENCE = 5

patience_counter = 0
best_f1 = 0.0
best_model = None

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader):
        tokens = batch['tokens'].to(device)
        labels = batch['labels'].float().to(device)
        
        optimizer.zero_grad()
        
        logits = model(tokens)
        
        loss = criteria(logits, labels)
        loss.backward()
        
        optimizer.step()
        
        total_loss += loss.item()
        
    print(f'Epoch: {epoch+1}/{EPOCHS}')
    print(f'Train Loss: {total_loss/len(train_loader)}')
        
    model.eval()
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in val_loader:
            tokens = batch['tokens'].to(device)
            labels = batch['labels'].float().to(device)
            
            logits = model(tokens)
            preds = (logits > 0.5).float()
            
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
            
    f1 = f1_score(val_labels, val_preds, average='macro')
    weighted_f1 = f1_score(val_labels, val_preds, average='weighted')
    accuracy = accuracy_score(val_labels, val_preds)
    
    if f1 > best_f1:
        best_f1 = f1
        best_model = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        
    print(f'Validation Macro F1: {f1}')
    print(f'Validation Weighted F1: {weighted_f1}')
    print(f'Validation Accuracy: {accuracy}')
        
    if patience_counter >= PATIENCE:
        print(f'Early stopping at epoch {epoch+1}')
        break
    
torch.save(best_model, 'model.pth')

100%|██████████| 411/411 [00:22<00:00, 17.96it/s]


Epoch: 1/15
Train Loss: 0.05138159521522313
Validation Macro F1: 0.06983685364811085
Validation Weighted F1: 0.061972671472173
Validation Accuracy: 0.06991525423728813


100%|██████████| 411/411 [00:22<00:00, 18.57it/s]


Epoch: 2/15
Train Loss: 0.0494637707261926
Validation Macro F1: 0.17325739649897282
Validation Weighted F1: 0.25552196682004963
Validation Accuracy: 0.182909604519774


100%|██████████| 411/411 [00:22<00:00, 18.30it/s]


Epoch: 3/15
Train Loss: 0.047815250060600376
Validation Macro F1: 0.20125930888704707
Validation Weighted F1: 0.3061913845820305
Validation Accuracy: 0.2175141242937853


100%|██████████| 411/411 [00:24<00:00, 17.08it/s]


Epoch: 4/15
Train Loss: 0.0442759493057709
Validation Macro F1: 0.26118296276170605
Validation Weighted F1: 0.42022648318255285
Validation Accuracy: 0.3015536723163842


100%|██████████| 411/411 [00:24<00:00, 16.98it/s]


Epoch: 5/15
Train Loss: 0.041205868222399494
Validation Macro F1: 0.4312021014433883
Validation Weighted F1: 0.735277220486472
Validation Accuracy: 0.6228813559322034


100%|██████████| 411/411 [00:23<00:00, 17.75it/s]


Epoch: 6/15
Train Loss: 0.03769032843410969
Validation Macro F1: 0.39120478274446624
Validation Weighted F1: 0.6592518039523143
Validation Accuracy: 0.530367231638418


100%|██████████| 411/411 [00:23<00:00, 17.43it/s]


Epoch: 7/15
Train Loss: 0.03499473093435567
Validation Macro F1: 0.3691295386381273
Validation Weighted F1: 0.6386902833862214
Validation Accuracy: 0.5049435028248588


100%|██████████| 411/411 [00:22<00:00, 18.19it/s]


Epoch: 8/15
Train Loss: 0.032257872019552256
Validation Macro F1: 0.343328436545944
Validation Weighted F1: 0.5803524118877647
Validation Accuracy: 0.4442090395480226


100%|██████████| 411/411 [00:22<00:00, 18.18it/s]


Epoch: 9/15
Train Loss: 0.02718315339237089
Validation Macro F1: 0.44807001411837466
Validation Weighted F1: 0.7643057771596424
Validation Accuracy: 0.661723163841808


100%|██████████| 411/411 [00:22<00:00, 18.46it/s]


Epoch: 10/15
Train Loss: 0.02406762770791776
Validation Macro F1: 0.4643421052631579
Validation Weighted F1: 0.8182842699970265
Validation Accuracy: 0.7401129943502824


100%|██████████| 411/411 [00:22<00:00, 18.05it/s]


Epoch: 11/15
Train Loss: 0.022138018966618465
Validation Macro F1: 0.47217320722755507
Validation Weighted F1: 0.831925689147502
Validation Accuracy: 0.7612994350282486


100%|██████████| 411/411 [00:22<00:00, 18.04it/s]


Epoch: 12/15
Train Loss: 0.01918560004432815
Validation Macro F1: 0.47905818428797825
Validation Weighted F1: 0.8544459068647202
Validation Accuracy: 0.7980225988700564


100%|██████████| 411/411 [00:24<00:00, 16.67it/s]


Epoch: 13/15
Train Loss: 0.01661848627459122
Validation Macro F1: 0.49203422346459724
Validation Weighted F1: 0.886812678659658
Validation Accuracy: 0.8538135593220338


100%|██████████| 411/411 [00:26<00:00, 15.38it/s]


Epoch: 14/15
Train Loss: 0.01586018660085371
Validation Macro F1: 0.4618880461355231
Validation Weighted F1: 0.8090526160710081
Validation Accuracy: 0.7259887005649718


100%|██████████| 411/411 [00:21<00:00, 18.86it/s]


Epoch: 15/15
Train Loss: 0.01583990312822689
Validation Macro F1: 0.48811972231386397
Validation Weighted F1: 0.8869391441948788
Validation Accuracy: 0.8545197740112994


In [14]:
model = LSTM_Model(vocab_size=len(vocab)).to(device)
model.load_state_dict(torch.load('model.pth'))

model.eval()
with torch.no_grad():
    test_preds = []
    test_labels = []
    for batch in test_loader:
        tokens = batch['tokens'].to(device)
        labels = batch['labels'].float().to(device)
        
        logits = model(tokens)
        preds = (logits > 0.5).float()
        
        test_preds.extend(preds.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())
        
    f1 = f1_score(test_labels, test_preds, average='macro')
    weighted_f1 = f1_score(test_labels, test_preds, average='weighted')
    accuracy = accuracy_score(test_labels, test_preds)
    
    
print(f'Test Macro F1: {f1}')
print(f'Test Weighted F1: {weighted_f1}')
print(f'Test Accuracy: {accuracy}')

Test Macro F1: 0.5216035825665345
Test Weighted F1: 0.830274129151589
Test Accuracy: 0.8143013498723094
