In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import jsonlines
from tqdm import tqdm
import nltk
from nltk.tokenize import word_tokenize
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score
import copy

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [13]:
DATA_PATH = '../data'

In [14]:
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Saurav\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Saurav\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [15]:
class DiplomacyVocabulary(Dataset):
    def __init__(self):
        self.word2idx = {"PAD": 0, "UNK": 1}
        self.idx2word = {0: "PAD", 1: "UNK"}
        
    def add_token(self, token):
        if token not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[token] = idx
            self.idx2word[idx] = token
        
    def __len__(self):
        return len(self.word2idx)
    
    def tokenize(self, message):
        message = message.lower()
        tokens = word_tokenize(message)
        return [self.word2idx.get(token, 1) for token in tokens]

In [16]:
class DiplomacyDataset(Dataset):
    def __init__(self, file_path, vocab=None, construct=False):
        self.data = []
        self.vocab = vocab if vocab else DiplomacyVocabulary()
        
        with jsonlines.open(file_path, 'r') as f:
            for line in f:
                for i, message in enumerate(line['messages']):
                    if line['sender_labels'][i] == 'NOANNOTATION':
                        continue
                    self.data.append({
                        'message': message,
                        'label': 1 if line['sender_labels'][i] else 0
                    })
                    
                    if construct:
                        tokens = word_tokenize(message)
                        for token in tokens:
                            self.vocab.add_token(token)
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tokens = self.vocab.tokenize(self.data[idx]['message'])
        
        return {
            'tokens': torch.tensor(tokens, dtype=torch.long),
            'label': torch.tensor(self.data[idx]['label'], dtype=torch.long),
        }

In [17]:
def collate_fn(batch):
    batch = sorted(batch, key=lambda x: len(x['tokens']), reverse=True)
    
    tokens = pad_sequence([item['tokens'] for item in batch], batch_first=True, padding_value=0)
    labels = torch.stack([item['label'] for item in batch])
    
    return {
        'tokens': tokens,
        'labels': labels,
    }

In [18]:
train_dataset = DiplomacyDataset(f'{DATA_PATH}/train.jsonl', construct=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
vocab = train_dataset.vocab

test_dataset = DiplomacyDataset(f'{DATA_PATH}/test.jsonl', vocab=vocab)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [19]:
class LSTM_Model(nn.Module):
    def __init__(self, vocab_size, pretrained_embeddings=None, embedding_dim=200, hidden_size=100):
        super(LSTM_Model, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
            self.embedding.weight.requires_grad = False
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(hidden_size * 2, 1)
        
    def forward(self, tokens):
        embeddings = self.embedding(tokens)
        
        lstm_out, _ = self.lstm(embeddings)
        
        pooled_out, _ = torch.max(lstm_out, dim=1)
        
        pooled_out = self.dropout(pooled_out)
        
        logits = self.fc(pooled_out)
        return logits.squeeze(1)

In [20]:
model = LSTM_Model(vocab_size=len(vocab)).to(device)
model.load_state_dict(torch.load('model.pth'))

model.eval()
with torch.no_grad():
    test_preds = []
    test_labels = []
    for batch in test_loader:
        tokens = batch['tokens'].to(device)
        labels = batch['labels'].float().to(device)
        
        logits = model(tokens)
        preds = (logits > 0.5).float()
        
        test_preds.extend(preds.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())
        
    f1 = f1_score(test_labels, test_preds, average='macro')
    weighted_f1 = f1_score(test_labels, test_preds, average='weighted')
    accuracy = accuracy_score(test_labels, test_preds)
    
    
print(f'Test Macro F1: {f1}')
print(f'Test Weighted F1: {weighted_f1}')
print(f'Test Accuracy: {accuracy}')

Test Macro F1: 0.5216035825665345
Test Weighted F1: 0.830274129151589
Test Accuracy: 0.8143013498723094
