In [None]:
import functools
import sys

from datasets import Dataset, DatasetDict
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import tqdm
from nltk import word_tokenize

In [None]:
# Seed PyTorch's global random number generator so repeated runs of the
# notebook start from the same random state.
seed = 0

torch.manual_seed(seed)

In [None]:
def read_dataset(filepath, lower=True):
    """Load a tab-separated ``text<TAB>label`` file into a tokenised ``Dataset``.

    Each non-blank line must contain exactly one tab separating the text from
    its label. The text is tokenised with NLTK's Russian ``word_tokenize``.

    Args:
        filepath: Path to a UTF-8 encoded TSV file.
        lower: If true, lower-case every token after tokenisation.

    Returns:
        A ``Dataset`` with two columns: ``tokens`` (list of token strings per
        example) and ``label`` (the raw label string per example).

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            tab-separated fields; the message names the file and line number.
    """
    tokens, labels = [], []
    with open(filepath, encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            # Blank lines (e.g. a trailing empty line) carry no example; skip
            # them instead of crashing on the tuple unpacking below.
            if not line:
                continue
            fields = line.split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"{filepath}: line {line_no}: expected 'text<TAB>label', "
                    f"got {len(fields)} tab-separated field(s)"
                )
            text, label = fields
            cur_tokens = word_tokenize(text, language='russian')
            if lower:
                cur_tokens = [token.lower() for token in cur_tokens]
            labels.append(label)
            tokens.append(cur_tokens)

    return Dataset.from_dict({'tokens': tokens, 'label': labels})

In [None]:
# Whether tokens are lower-cased; passed as the `lower` argument to
# read_dataset when loading the splits and to process_line at inference time.
TEXT_LOWER = True

In [None]:
# Read and tokenise the TSV file of every split, collecting the resulting
# datasets in a single DatasetDict keyed by split name.
data = DatasetDict()

for split_name in ['train', 'validation', 'test']:
    data[split_name] = read_dataset(f'data/sensitive_topics/{split_name}.tsv', lower=TEXT_LOWER)

In [None]:
# Tokens seen fewer than `min_freq` times in the training split are left out
# of the vocabulary; the special tokens are always included.
min_freq = 2
special_tokens = ['<unk>', '<pad>']

tokens_vocab = torchtext.vocab.build_vocab_from_iterator(data['train']['tokens'],
                                                  min_freq=min_freq,
                                                  specials=special_tokens)

# Sort the distinct labels so the label <-> index mapping is identical on every
# run. Iteration order of a set of strings varies between interpreter sessions
# (hash randomisation), which would silently re-assign class indices and make a
# checkpoint saved in one session decode to the wrong labels in another.
idx_to_label = sorted(set(data['train']['label']))
label_to_idx = {label: idx for idx, label in enumerate(idx_to_label)}

In [None]:
# Look up the integer ids the vocabulary assigned to the two special tokens.
unk_index = tokens_vocab['<unk>']
pad_index = tokens_vocab['<pad>']

In [None]:
# Tokens that are not in the vocabulary are mapped to the <unk> id.
tokens_vocab.set_default_index(unk_index)

In [None]:
def numericalize_data(example, tokens_vocab, label_to_idx):
    """Convert one example's tokens and label into integer ids.

    Args:
        example: Mapping with a ``'tokens'`` sequence and a ``'label'`` key.
        tokens_vocab: Object whose ``forward`` maps a token list to id list.
        label_to_idx: Mapping from label value to class index.

    Returns:
        A dict with the same two keys, holding the token ids and class index.
    """
    return {
        'tokens': tokens_vocab.forward(example['tokens']),
        'label': label_to_idx[example['label']],
    }

In [None]:
# Apply numericalize_data to every example of every split, replacing token
# strings and label strings with their integer ids.
transformed_data = data.map(numericalize_data, fn_kwargs={'tokens_vocab': tokens_vocab,
                                                          'label_to_idx': label_to_idx})

In [None]:
# Have the datasets return their columns as torch tensors when indexed.
transformed_data = transformed_data.with_format(type='torch')

In [None]:
# Inspect the first numericalised training example.
transformed_data['train'][0]

In [None]:
def collate_batch(batch):
    """Merge a list of examples into one padded batch.

    Args:
        batch: List of examples, each a mapping with a 1-D ``'tokens'`` tensor
            and a scalar ``'label'`` tensor.

    Returns:
        A dict with ``'tokens'`` (examples right-padded with the ``<pad>`` id
        to the longest sequence, batch dimension first) and ``'label'``
        (the labels stacked into one tensor).
    """
    sequences, targets = [], []
    for example in batch:
        sequences.append(example['tokens'])
        targets.append(example['label'])

    stacked_labels = torch.stack(targets)
    padded_tokens = nn.utils.rnn.pad_sequence(
        sequences,
        batch_first=True,
        padding_value=tokens_vocab['<pad>'],
    )
    return {'tokens': padded_tokens, 'label': stacked_labels}

In [None]:
BATCH_SIZE = 8

# All three loaders share the batch size and collate function; only the
# training loader shuffles its examples.
_make_loader = functools.partial(torch.utils.data.DataLoader,
                                 batch_size=BATCH_SIZE,
                                 collate_fn=collate_batch)

train_dataloader = _make_loader(transformed_data['train'], shuffle=True)
validation_dataloader = _make_loader(transformed_data['validation'])
test_dataloader = _make_loader(transformed_data['test'])

In [None]:
# Pull a single batch for inspection. next(iter(...)) raises StopIteration
# right here if the loader is empty, instead of leaving `batch` undefined
# (or stale from an earlier run) for the following cell.
batch = next(iter(train_dataloader))

In [None]:
# Display the sampled batch (padded token ids and labels).
batch

In [None]:
class LSTM(nn.Module):
    """Text classifier: embedding -> (bi)directional multi-layer LSTM -> linear.

    Padded positions are excluded from the recurrence by packing the batch, so
    the prediction for a sequence does not depend on how much padding the
    batch it sits in happens to need.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden state size of the LSTM (per direction).
        output_dim: Number of output classes (logits).
        n_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM reads the sequence in both directions.
        dropout_rate: Dropout probability applied to the embeddings, between
            LSTM layers, and to the final hidden state.
        pad_index: Vocabulary id of the padding token.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional,
                 dropout_rate, pad_index):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)
        # nn.LSTM only applies dropout between layers; passing a non-zero rate
        # with a single layer has no effect and triggers a warning.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, bidirectional=bidirectional,
                            dropout=dropout_rate if n_layers > 1 else 0.0, batch_first=True)
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, ids, length=None):
        """Compute class logits for a batch of token id sequences.

        Args:
            ids: Integer tensor of shape [batch size, seq len], right-padded
                with the padding id.
            length: Optional true (unpadded) length of every sequence,
                shape [batch size]. When omitted it is derived by counting the
                non-padding ids in each row.

        Returns:
            Tensor of shape [batch size, output dim] with unnormalised logits.
        """
        # ids = [batch size, seq len]
        if length is None:
            pad_index = self.embedding.padding_idx
            if pad_index is None:
                length = torch.full((ids.shape[0],), ids.shape[1], dtype=torch.long)
            else:
                length = (ids != pad_index).sum(dim=1)
        # pack_padded_sequence requires CPU int64 lengths that are all >= 1.
        length = torch.as_tensor(length).to(device='cpu', dtype=torch.int64).clamp(min=1)
        # length = [batch size]
        embedded = self.dropout(self.embedding(ids))
        # embedded = [batch size, seq len, embedding dim]
        packed = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True,
                                                   enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)
        # hidden = [n layers * n directions, batch size, hidden dim]
        if self.lstm.bidirectional:
            # hidden[-1] is the last layer's backward state, hidden[-2] its
            # forward state; the order is kept so saved weights stay valid.
            hidden = self.dropout(torch.cat([hidden[-1], hidden[-2]], dim=-1))
            # hidden = [batch size, hidden dim * 2]
        else:
            hidden = self.dropout(hidden[-1])
            # hidden = [batch size, hidden dim]
        prediction = self.fc(hidden)
        # prediction = [batch size, output dim]
        return prediction

In [None]:
# Model hyperparameters. The embedding table has one row per vocabulary entry
# and the classifier emits one logit per distinct training label.
vocab_size = len(tokens_vocab)
embedding_dim = 300
hidden_dim = 300
output_dim = len(idx_to_label)
n_layers = 2
bidirectional = True
dropout_rate = 0.5

model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout_rate, 
             pad_index)

In [None]:
def count_parameters(model):
    """Return the total number of scalar parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the model size, with thousands separators for readability.
print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:
# Adam over all model parameters with a fixed learning rate.
lr = 5e-4

optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Cross-entropy over the class logits produced by the model.
criterion = nn.CrossEntropyLoss()

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# Move the model and the loss module to the selected device.
model = model.to(device)
criterion = criterion.to(device)

In [None]:
def train(dataloader, model, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Puts the model in training mode and performs one optimiser step per batch.

    Args:
        dataloader: Iterable of batches, each a mapping with ``'tokens'`` and
            ``'label'`` tensors.
        model: Module mapping a token id batch to class logits.
        criterion: Loss callable taking ``(logits, labels)``.
        optimizer: Optimiser updating the model's parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Two lists with one float per batch: the losses and the accuracies.
    """
    model.train()
    batch_losses, batch_accs = [], []

    progress = tqdm.tqdm(dataloader, desc='training...', file=sys.stdout)
    for batch in progress:
        inputs = batch['tokens'].to(device)
        targets = batch['label'].to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
        batch_accs.append(get_accuracy(logits, targets).item())

    return batch_losses, batch_accs

In [None]:
def evaluate(dataloader, model, criterion, device):
    """Score the model on ``dataloader`` without updating it.

    Puts the model in evaluation mode and disables gradient tracking.

    Args:
        dataloader: Iterable of batches, each a mapping with ``'tokens'`` and
            ``'label'`` tensors.
        model: Module mapping a token id batch to class logits.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Two lists with one float per batch: the losses and the accuracies.
    """
    model.eval()
    batch_losses, batch_accs = [], []

    progress = tqdm.tqdm(dataloader, desc='evaluating...', file=sys.stdout)
    with torch.no_grad():
        for batch in progress:
            inputs = batch['tokens'].to(device)
            targets = batch['label'].to(device)
            logits = model(inputs)
            batch_losses.append(criterion(logits, targets).item())
            batch_accs.append(get_accuracy(logits, targets).item())

    return batch_losses, batch_accs

In [None]:
def get_accuracy(prediction, label):
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        prediction: 2-D tensor of per-class scores, one row per example.
        label: Tensor of target class indices, one per example.

    Returns:
        A scalar tensor holding the batch accuracy.
    """
    n_examples, _n_classes = prediction.shape
    hits = (prediction.argmax(dim=-1) == label).sum()
    return hits / n_examples

In [None]:
# Train for a fixed number of epochs, keeping the weights of the epoch with
# the lowest mean validation loss in 'lstm.pt'.
n_epochs = 3
best_valid_loss = float('inf')

# Per-batch histories accumulated over all epochs (used by the plots below).
train_losses = []
train_accs = []
valid_losses = []
valid_accs = []

for epoch in range(n_epochs):

    train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)
    valid_loss, valid_acc = evaluate(validation_dataloader, model, criterion, device)

    train_losses.extend(train_loss)
    train_accs.extend(train_acc)
    valid_losses.extend(valid_loss)
    valid_accs.extend(valid_acc)
    
    # Epoch-level figures are unweighted means of the per-batch values.
    epoch_train_loss = np.mean(train_loss)
    epoch_train_acc = np.mean(train_acc)
    epoch_valid_loss = np.mean(valid_loss)
    epoch_valid_acc = np.mean(valid_acc)
    
    # Checkpoint only when the validation loss improves.
    if epoch_valid_loss < best_valid_loss:
        best_valid_loss = epoch_valid_loss
        torch.save(model.state_dict(), 'lstm.pt')
    
    print(f'epoch: {epoch+1}')
    print(f'train_loss: {epoch_train_loss:.3f}, train_acc: {epoch_train_acc:.3f}')
    print(f'valid_loss: {epoch_valid_loss:.3f}, valid_acc: {epoch_valid_acc:.3f}')

In [None]:
# Per-batch training and validation loss collected during training.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='train loss')
ax.plot(valid_losses, label='valid loss')
ax.legend()
ax.set(xlabel='updates', ylabel='loss');

In [None]:
# Per-batch training and validation accuracy collected during training.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accs, label='train accuracy')
ax.plot(valid_accs, label='valid accuracy')
ax.legend()
ax.set(xlabel='updates', ylabel='accuracy');

In [None]:
# Restore the best checkpoint. map_location remaps the stored tensors onto the
# current device, so a checkpoint written on a GPU still loads on a CPU-only
# machine (and vice versa).
model.load_state_dict(torch.load('lstm.pt', map_location=device))

test_loss, test_acc = evaluate(test_dataloader, model, criterion, device)

# Unweighted means of the per-batch values.
epoch_test_loss = np.mean(test_loss)
epoch_test_acc = np.mean(test_acc)

print(f'test_loss: {epoch_test_loss:.3f}, test_acc: {epoch_test_acc:.3f}')

In [None]:
def process_line(text, model, tokens_vocab, idx_to_label, device, lower=True):
    """Classify a single raw text and return the predicted label and its probability.

    The text goes through the same preprocessing as the training data:
    Russian ``word_tokenize`` followed by optional lower-casing. The model is
    run in evaluation mode without gradient tracking, and its previous
    training/eval mode is restored afterwards.

    Args:
        text: Raw input string.
        model: Module mapping a ``[1, seq len]`` id tensor to class logits.
        tokens_vocab: Vocabulary whose ``forward`` maps tokens to ids.
        idx_to_label: Sequence mapping class index to label.
        device: Device the input tensor is moved to.
        lower: If true, lower-case the tokens before the vocabulary lookup.

    Returns:
        Tuple ``(predicted_class, predicted_probability)``.

    Raises:
        ValueError: If tokenisation yields no tokens.
    """
    tokens = word_tokenize(text, language='russian')
    if lower:
        # Match the casing used when the vocabulary was built; otherwise
        # capitalised words would fall back to <unk>.
        tokens = [token.lower() for token in tokens]
    if not tokens:
        raise ValueError('text contains no tokens to classify')

    ids = tokens_vocab.forward(tokens)
    tensor = torch.tensor(ids, dtype=torch.long).unsqueeze(dim=0).to(device)

    # Dropout must be disabled for a deterministic prediction; restore the
    # caller's mode afterwards so training can continue unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            prediction = model(tensor).squeeze(dim=0)
    finally:
        model.train(was_training)

    probability = torch.softmax(prediction, dim=-1)
    predicted_idx = prediction.argmax(dim=-1).item()
    predicted_class = idx_to_label[predicted_idx]
    predicted_probability = probability[predicted_idx].item()
    return predicted_class, predicted_probability

In [None]:
# Classify a sample sentence with the trained model.
text = 'Все ложь, макаронного монстра не существует, пастафарианство было ошибкой!'

process_line(text, model, tokens_vocab, idx_to_label, device, TEXT_LOWER)

In [None]:
# Classify a second sample sentence with the trained model.
text = 'Я куплю арбалет и пойду охотиться на единорогов!'

process_line(text, model, tokens_vocab, idx_to_label, device, TEXT_LOWER)