In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import jsonlines
from tqdm import tqdm
import nltk
from nltk.tokenize import word_tokenize
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score

In [139]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [140]:
DATA_PATH = '../data'

In [141]:
SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [142]:
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Saurav\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Saurav\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [None]:
class DiplomacyVocabulary(Dataset):
    def __init__(self):
        self.word2idx = {"PAD": 0, "UNK": 1}
        self.idx2word = {0: "PAD", 1: "UNK"}
        
    def add_token(self, token):
        if token not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[token] = idx
            self.idx2word[idx] = token
        
    def __len__(self):
        return len(self.word2idx)
    
    def tokenize(self, message):
        message = message.lower()
        tokens = word_tokenize(message)
        return [self.word2idx.get(token, 1) for token in tokens]

In [144]:
class DiplomacyDataset(Dataset):
    def __init__(self, file_path, vocab=None, construct=False):
        self.data = []
        self.vocab = vocab if vocab else DiplomacyVocabulary()
        
        with jsonlines.open(file_path, 'r') as f:
            for line in f:
                for i, message in enumerate(line['messages']):
                    if line['sender_labels'][i] == 'NOANNOTATION':
                        continue
                    self.data.append({
                        'message': message,
                        'label': 1 if line['sender_labels'][i] else 0
                    })
                    
                    if construct:
                        tokens = word_tokenize(message)
                        for token in tokens:
                            self.vocab.add_token(token)
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tokens = self.vocab.tokenize(self.data[idx]['message'])
        
        return {
            'tokens': torch.tensor(tokens, dtype=torch.long),
            'label': torch.tensor(self.data[idx]['label'], dtype=torch.long),
        }

In [145]:
def collate_fn(batch):
    batch = sorted(batch, key=lambda x: len(x['tokens']), reverse=True)
    
    tokens = pad_sequence([item['tokens'] for item in batch], batch_first=True, padding_value=0)
    labels = torch.stack([item['label'] for item in batch])
    
    return {
        'tokens': tokens,
        'labels': labels,
    }

In [146]:
train_dataset = DiplomacyDataset(f'{DATA_PATH}/train.jsonl', construct=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
vocab = train_dataset.vocab

val_dataset = DiplomacyDataset(f'{DATA_PATH}/validation.jsonl', vocab=vocab)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

test_dataset = DiplomacyDataset(f'{DATA_PATH}/test.jsonl', vocab=vocab)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:
class LSTM_Model(nn.Module):
    def __init__(self, vocab_size, pretrained_embeddings, embedding_dim=200, hidden_size=100):
        super(LSTM_Model, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        self.embedding.weight.requires_grad = False
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(hidden_size * 2, 1)
        
    def forward(self, tokens):
        embeddings = self.embedding(tokens)
        
        lstm_out, _ = self.lstm(embeddings)
        
        pooled_out, _ = torch.max(lstm_out, dim=1)
        
        pooled_out = self.dropout(pooled_out)
        
        # apply sigmoid activation
        logits = torch.sigmoid(pooled_out)

In [150]:
embedding_path = '../embeddings/glove.twitter.27B.200d.txt'

pretrained_embeddings = np.zeros((len(vocab), 200))

with open(embedding_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        values = line.split()
        word = values[0]
        if word in vocab.word2idx:
            idx = vocab.word2idx[word]
            pretrained_embeddings[idx] = np.array(values[1:], dtype='float32')

1193514it [00:20, 58560.68it/s]


In [None]:
model = LSTM_Model(vocab_size=len(vocab), pretrained_embeddings=pretrained_embeddings).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.003)

criteria = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 30.0]).to(device))

EPOCHS = 15
PATIENCE = 5

patience_counter = 0
best_f1 = 0.0
best_model = None

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader):
        tokens = batch['tokens'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        
        logits = model(tokens)
        
        loss = criteria(logits.view(-1, 2), labels.view(-1))
        loss.backward()
        
        optimizer.step()
        
        total_loss += loss.item()
        
    print(f'Epoch {epoch+1}/{EPOCHS}')
    print(f'Train Loss: {total_loss/len(train_loader)}')
        
    model.eval()
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in val_loader:
            tokens = batch['tokens'].to(device)
            labels = batch['labels'].to(device)
            
            logits = model(tokens)
            preds = torch.argmax(logits, dim=1)
            
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
            
    f1 = f1_score(val_labels, val_preds)
    accuracy = accuracy_score(val_labels, val_preds)
    
    if f1 > best_f1:
        best_f1 = f1
        best_model = model.state_dict().copy()
        patience_counter = 0
    else:
        patience_counter += 1
        
    print(f'Validation F1: {f1}')
    print(f'Validation Accuracy: {accuracy}')
        
    if patience_counter >= PATIENCE:
        print(f'Early stopping at epoch {epoch+1}')
        break
    
torch.save(best_model, 'model.pth')

100%|██████████| 411/411 [00:22<00:00, 17.88it/s]


Epoch 1/15
Train Loss: 0.01731064073860613


  7%|▋         | 1/15 [00:23<05:35, 23.93s/it]

Validation F1: 0.9798270893371758
Validation Accuracy: 0.96045197740113


100%|██████████| 411/411 [00:22<00:00, 18.51it/s]


Epoch 2/15
Train Loss: 0.012377512930448274


 13%|█▎        | 2/15 [00:46<05:04, 23.40s/it]

Validation F1: 0.9798270893371758
Validation Accuracy: 0.96045197740113


100%|██████████| 411/411 [00:22<00:00, 18.43it/s]


Epoch 3/15
Train Loss: 0.012270296426714282


 20%|██        | 3/15 [01:10<04:39, 23.26s/it]

Validation F1: 0.9798270893371758
Validation Accuracy: 0.96045197740113


100%|██████████| 411/411 [00:22<00:00, 18.45it/s]


Epoch 4/15
Train Loss: 0.012218155585305307


 27%|██▋       | 4/15 [01:33<04:15, 23.21s/it]

Validation F1: 0.9798270893371758
Validation Accuracy: 0.96045197740113


100%|██████████| 411/411 [00:24<00:00, 16.52it/s]


Epoch 5/15
Train Loss: 0.012146962626152865


 33%|███▎      | 5/15 [01:59<04:02, 24.22s/it]

Validation F1: 0.9798270893371758
Validation Accuracy: 0.96045197740113


100%|██████████| 411/411 [00:23<00:00, 17.21it/s]


Epoch 6/15
Train Loss: 0.012162367674072302


 33%|███▎      | 5/15 [02:23<04:47, 28.78s/it]

Validation F1: 0.9798270893371758
Validation Accuracy: 0.96045197740113
Early stopping at epoch 6





In [155]:
model = LSTM_Model(vocab_size=len(vocab), pretrained_embeddings=pretrained_embeddings).to(device)
model.load_state_dict(torch.load('model.pth'))

model.eval()
with torch.no_grad():
    test_preds = []
    test_labels = []
    for batch in test_loader:
        tokens = batch['tokens'].to(device)
        labels = batch['labels'].to(device)
        
        logits = model(tokens)
        preds = torch.argmax(logits, dim=1)
        
        test_preds.extend(preds.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())
        
    f1 = f1_score(test_labels, test_preds)
    accuracy = accuracy_score(test_labels, test_preds)
    
print(f'Test F1: {f1}')
print(f'Test Accuracy: {accuracy}')

Test F1: 0.9542159481114079
Test Accuracy: 0.9124407150674936
