In [None]:
import nltk
import torch
import string
import datasets
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


from typing import Any
from nltk.tokenize import word_tokenize
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    f1_score,
    recall_score,
    roc_auc_score,
    precision_score
)

In [None]:
# nltk.download("punkt")
# nltk.download("punkt_tab")

## 1 Подготовка данных

Загружаем данные.

In [None]:
# Load the "imdb" dataset through the `datasets` library; the cells below read
# its "train" and "test" splits.
dataset = datasets.load_dataset("imdb")

Подсчитаем, сколько раз каждый токен встречается в обучающей части датасета.

In [None]:
# Count how often each token occurs in the training split.
# The text is normalised (lower-cased, ASCII punctuation removed) BEFORE it is
# tokenised, which is the same normalisation CustomDataset.__getitem__ applies,
# so the vocabulary keys match the tokens that are looked up later.
punctuation_table = str.maketrans("", "", string.punctuation)  # built once, reused for every text

token_to_count = {}

for text in tqdm(dataset["train"]["text"]):
    text_clean = text.lower().translate(punctuation_table)
    # Tokenise the cleaned text, not the raw one: tokenising `text` would put
    # cased / punctuated tokens into the vocabulary that are never looked up.
    for token in word_tokenize(text_clean):
        token = token.strip()
        token_to_count[token] = token_to_count.get(token, 0) + 1

print(f"Кол-во токенов в словаре: {len(token_to_count)}")

Добавим специальные токены:
- \<unk\>: неизвестный токен;
- \<bos\>: начало последовательности;
- \<eos\>: конец последовательности;
- \<pad\>: специальный токен для объединения последовательностей разных длин в один батч.

А также удалим редкие токены — те, что встречаются реже заданного порога.

In [None]:
# Build the vocabulary: the four special tokens plus every counted token that
# occurs at least `threshold` times.
threshold = 25

vocabulary = {"<unk>", "<bos>", "<eos>", "<pad>"}
vocabulary.update(
    token
    for token, token_count in tqdm(token_to_count.items())
    if token_count >= threshold
)

print(f"Кол-во токенов в словаре: {len(vocabulary)}")

In [None]:
# Assign indices in sorted token order. Enumerating the set directly yields an
# order that depends on string hashing, which may differ between interpreter
# runs, so the same token could get a different index after a kernel restart.
token_to_index = {token: index for index, token in enumerate(sorted(vocabulary))}
# Inverse mapping: index -> token.
index_to_token = {index: token for token, index in token_to_index.items()}

Заведем специальный класс для хранения данных.

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that converts raw texts into lists of token indices.

    Token lookup uses the module-level ``token_to_index`` mapping, which must
    contain the ``<unk>``, ``<bos>`` and ``<eos>`` entries.
    """

    # Translation table that deletes ASCII punctuation. Built once for the
    # class instead of on every __getitem__ call.
    _punctuation_table = str.maketrans("", "", string.punctuation)

    def __init__(self, text: list[str], label: list[int]) -> None:
        """Store the texts and their labels.

        Args:
            text: Raw texts, one per example.
            label: Labels, aligned with ``text`` by position.
        """
        self.text = text
        self.label = label

        self.unk_id = token_to_index["<unk>"]
        self.bos_id = token_to_index["<bos>"]
        self.eos_id = token_to_index["<eos>"]

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.text)

    def __getitem__(self, index: int) -> Any:
        """Return ``(token_indices, label)`` for the example at ``index``.

        The text is lower-cased, stripped of ASCII punctuation and tokenised
        with ``word_tokenize``. Tokens missing from ``token_to_index`` map to
        ``<unk>``, and the sequence is wrapped in ``<bos>`` ... ``<eos>``.
        """
        text = self.text[index].lower().translate(self._punctuation_table)
        label = self.label[index]

        text_indices = [self.bos_id]
        text_indices.extend(
            token_to_index.get(token, self.unk_id) for token in word_tokenize(text)
        )
        text_indices.append(self.eos_id)

        return (text_indices, label)
    

def collate_fn(
    batch: list[tuple[list[int], int]], max_length: int = 256
) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine ``(token_indices, label)`` pairs into two batch tensors.

    Every sequence is truncated to the target length and then right-padded
    with the ``<pad>`` index up to that length. The target length is the
    longest sequence in the batch, capped at ``max_length``.

    Args:
        batch: Items whose element 0 is a list of token indices and whose
            element 1 is the label.
        max_length: Upper bound on the sequence length of the batch.

    Returns:
        A pair of ``torch.LongTensor``: token indices with one row per item,
        and the labels.

    Raises:
        ValueError: If ``batch`` is empty (``max`` of an empty sequence).
    """
    pad_id = token_to_index["<pad>"]  # looked up once, not per padded position
    max_length = min(max_length, max(len(item[0]) for item in batch))

    batch_text_indices = []
    for item in batch:
        # Slicing copies the list, so the item's own index list is not modified.
        text_indices = item[0][:max_length]
        text_indices.extend([pad_id] * (max_length - len(text_indices)))
        batch_text_indices.append(text_indices)

    batch_labels = [item[1] for item in batch]

    return (torch.LongTensor(batch_text_indices), torch.LongTensor(batch_labels))

In [None]:
# Hold out a random subset of the test split for validation.
valid_size = 2000
batch_size = 128
random_seed = 42

# Sample WITHOUT replacement so no example appears twice in the validation set
# (the default of `choice` is replace=True). A seeded generator keeps the
# subset identical between runs.
rng = np.random.default_rng(random_seed)
valid_indices = rng.choice(dataset["test"].num_rows, size=valid_size, replace=False)

# Select the subset once and reuse it for both columns.
valid_split = dataset["test"].select(valid_indices)

train_dataset = CustomDataset(dataset["train"]["text"], dataset["train"]["label"])
valid_dataset = CustomDataset(valid_split["text"], valid_split["label"])

train_loader = DataLoader(train_dataset, shuffle=True,  collate_fn=collate_fn, batch_size=batch_size)
valid_loader = DataLoader(valid_dataset, shuffle=False, collate_fn=collate_fn, batch_size=batch_size)

## 2 RecurrentNeuralNetwork

In [None]:
class RecurrentNeuralNetwork(nn.Module):
    """Sequence classifier: embedding -> ``nn.RNN`` -> max over time -> linear layer."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        vocabulary_size: int,
        num_layers: int = 1,
        dropout: float = 0.1,
        bidirectional: bool = False
    ) -> None:
        """Create the layers.

        Args:
            input_size: Embedding dimension, which is also the RNN input size.
            hidden_size: Hidden state size of the RNN.
            output_size: Number of outputs of the final linear layer.
            vocabulary_size: Number of rows in the embedding table.
            num_layers: Number of stacked RNN layers.
            dropout: Dropout between stacked RNN layers. It has no effect when
                ``num_layers == 1``.
            bidirectional: Whether the RNN runs in both directions.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional

        self.embedding = nn.Embedding(vocabulary_size, input_size)

        self.rnn = nn.RNN(
            input_size,
            hidden_size,
            # nn.RNN applies dropout only between stacked layers and warns when a
            # non-zero value is passed with a single layer, so pass 0.0 there.
            dropout=dropout if num_layers > 1 else 0.0,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True
        )

        # A bidirectional RNN concatenates both directions, doubling the feature size.
        self.dense = nn.Linear(2 * hidden_size if self.bidirectional else hidden_size, output_size)

    def forward(self, x):
        """Return one ``output_size``-sized vector per input sequence.

        ``x`` holds token indices; because the RNN is built with
        ``batch_first=True``, dimension 1 of its output is the sequence axis.
        The RNN outputs are reduced with an element-wise max over that axis.
        No masking is applied, so every position of ``x`` takes part in the max.
        """
        embedding = self.embedding(x)

        output, _ = self.rnn(embedding)
        output = output.max(dim=1)[0]  # [0] picks the values, dropping the argmax indices
        output = self.dense(output)

        return output

In [None]:
# Train the classifier for `num_epoch` epochs, recording the mean loss per epoch
# and reporting ROC AUC on the training and validation data after each epoch.
num_epoch = 3
learning_rate = 0.001

model = RecurrentNeuralNetwork(
    input_size=128,
    hidden_size=128,
    output_size=2,
    vocabulary_size=len(token_to_index),
    num_layers=1,
    dropout=0.0,
    bidirectional=False
)

# The targets here are class labels, not tokens. `ignore_index` filters TARGET
# values, so passing the <pad> vocabulary index would silently drop every example
# whose label happened to equal that index. No target should be ignored.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

losses = {"train": [], "valid": []}

for epoch in range(num_epoch):

    model.train()

    y_train_true = []
    y_train_pred = []   # hard class predictions (argmax)
    y_train_score = []  # probability assigned to class index 1, used for ROC AUC

    train_loss = 0.0
    for x, y in tqdm(train_loader, desc="Train"):
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        y_train_true.extend(y.tolist())
        y_train_pred.extend(output.argmax(dim=1).tolist())
        y_train_score.extend(output.detach().softmax(dim=1)[:, 1].tolist())
    losses["train"].append(train_loss / len(train_loader))

    model.eval()

    y_valid_true = []
    y_valid_pred = []   # hard class predictions (argmax)
    y_valid_score = []  # probability assigned to class index 1, used for ROC AUC

    valid_loss = 0.0
    with torch.no_grad():
        for x, y in tqdm(valid_loader, desc="Valid"):
            output = model(x)
            y_valid_true.extend(y.tolist())
            y_valid_pred.extend(output.argmax(dim=1).tolist())
            y_valid_score.extend(output.softmax(dim=1)[:, 1].tolist())
            loss = criterion(output, y)

            valid_loss += loss.item()
    losses["valid"].append(valid_loss / len(valid_loader))

    # ROC AUC is a ranking metric: feed it continuous scores, not thresholded labels.
    print(f"train roc_auc_score={roc_auc_score(y_train_true, y_train_score)}, valid roc_auc_score={roc_auc_score(y_valid_true, y_valid_score)}")