In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from nltk.tokenize import word_tokenize
from sklearn.model_selection import train_test_split
import nltk
from tqdm import tqdm_notebook
import torch.nn as nn                                                             
import torch.nn.functional as F

from collections import Counter
from typing import List
import string

import seaborn
seaborn.set(palette='summer')

nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\roma0\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [8]:
import pandas as pd
train_data = pd.read_csv('https://raw.githubusercontent.com/MSUcourses/Data-Analysis-with-Python/main/Deep%20Learning/Files/train_authors.csv')
train_data.head()

Unnamed: 0,text,label
0,"-Да, я поторопился. Капитан, примите мои извин...",Pratchett
1,-Похороны по первому разряду! Довольно благоро...,Remark
2,"Третий округ штата Мэн настолько велик, что ег...",King
3,В мире существуют миллиарды и миллиарды богов....,Pratchett
4,Особенность историографии киевского периода со...,Akunin


In [9]:
test_data = pd.read_csv('https://raw.githubusercontent.com/MSUcourses/Data-Analysis-with-Python/main/Deep%20Learning/Files/test_authors.csv')
test_data.head()

Unnamed: 0,text
0,"-Да, я поторопился. Капитан, примите мои извин..."
1,-Похороны по первому разряду! Довольно благоро...
2,"Третий округ штата Мэн настолько велик, что ег..."
3,В мире существуют миллиарды и миллиарды богов....
4,Особенность историографии киевского периода со...


In [10]:
writers = ['Akunin', 'Bulychev', 'Chehov', 'Dostoevsky', 'Gogol', 'King',
       'Pratchett', 'Remark']
writers_to_label = {writer: i for i, writer in enumerate(writers)}
label_to_writers = {i: writer for i, writer in enumerate(writers)}

In [11]:
dataset = {}

dataset['train'] = [{'text':text, 'label':writers_to_label[label]} \
              for text, label in zip(np.array(train_data['text']), np.array(train_data['label']))]
dataset['test'] = [{'text':text, 'label': 0} \
              for text in np.array(test_data['text'])]

In [12]:
def process_and_tokenize_text(text):
    prccessed_text = text.lower().translate(
        str.maketrans('', '', string.punctuation)
    )
    tokens = word_tokenize(prccessed_text)
    return tokens

train_data = []
test_data = []

words = Counter()

for example in tqdm_notebook(dataset['train']):
    text = example['text']
    label = example['label']
    text_processed = process_and_tokenize_text(text)
    train_data.append((text_processed, label))

    for word in text_processed:
        words[word] += 1

for example in tqdm_notebook(dataset['test']):
    text = example['text']
    label = example['label']
    text_processed = process_and_tokenize_text(text)
    test_data.append((text_processed, label))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for example in tqdm_notebook(dataset['train']):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1651.0), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for example in tqdm_notebook(dataset['test']):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1651.0), HTML(value='')))




In [13]:
vocab = set(['<unk>', '<bos>', '<eos>', '<pad>'])

counter_threshold = 1

for char, cnt in words.items():
    if cnt > counter_threshold:
        vocab.add(char)

In [14]:
word2ind = {char: i for i, char in enumerate(vocab)}
ind2word = {i: char for char, i in word2ind.items()}

In [15]:
def collate_fn_with_padding(input_batch, max_len=256):
    texts = [x[0] for x in input_batch]
    labels = [x[1] for x in input_batch]
    seq_lens = [len(x) for x in texts]
    max_seq_len = min(max(seq_lens), max_len)

    processed_texts = []
    for text, label in zip(texts, labels):
        text = text[:max_seq_len]
        text = [word2ind[x] if x in vocab else word2ind['<unk>'] for x in text]
        for _ in range(max_seq_len - len(text)):
            text.append(word2ind['<pad>'])

        processed_texts.append(text)

    processed_texts = torch.LongTensor(processed_texts).to(device)
    labels = torch.LongTensor(labels).to(device)

    processed_batch = {
        'input_ids': processed_texts,
        'label': labels
    }

    return processed_batch

In [16]:
batch_size = 128
train_dataloader = DataLoader(
    train_data, shuffle=True, collate_fn=collate_fn_with_padding, batch_size=batch_size)

test_dataloader = DataLoader(
    test_data, shuffle=False, collate_fn=collate_fn_with_padding, batch_size=batch_size)

In [18]:
class SimpleGRU(nn.Module):
    def __init__(
        self, hidden_dim, vocab_size, num_classes,
        aggregation_type: str = 'last'
        ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers = 1, batch_first=True)

        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        self.dropout = nn.Dropout(p=0.1)

        self.aggregation_type = aggregation_type

    def forward(self, input_batch) -> torch.Tensor:
        embeddings = self.embedding(input_batch)  # [batch_size, seq_len, hidden_dim]
        output, _ = self.gru(embeddings)          # [batch_size, seq_len, hidden_dim]

        if self.aggregation_type == 'max':
            output = output.max(dim=1)[0]         #[batch_size, hidden_dim]
        elif self.aggregation_type == 'mean':
            output = output.mean(dim=1)           #[batch_size, hidden_dim]
        elif self.aggregation_type == 'last':
            output = output[:, -1, :]
        else:
            raise ValueError("Invalid aggregation_type")

        output = F.tanh(output)
        output = F.tanh(self.dropout(self.fc1(output)))   # [batch_size, hidden_dim]
        output = self.fc2(output)                         # [batch_size, num_classes]

        return output

In [19]:
def evaluate(model, dataloader):
    """
    Calculate accuracy on data from dataloader.
    """

    predictions = []
    target = []
    with torch.no_grad():
        for batch in tqdm_notebook(dataloader,
                                   desc=f'Evaluating'):
            logits = model(batch['input_ids'])
            predictions.append(logits.argmax(dim=1))
            target.append(batch['label'])

    predictions = torch.cat(predictions)
    target = torch.cat(target)
    accuracy = (predictions == target).float().mean().item()

    return accuracy

In [20]:
def train(model, optimizer, criterion, num_epoch=5, eval_steps=100):

    losses = []
    accs_train = []

    for epoch in range(num_epoch):
        epoch_losses = []
        model.train()
        for i, batch in enumerate(tqdm_notebook(train_dataloader,
                                                desc=f'Training epoch {epoch}:')):
            optimizer.zero_grad()
            logits = model(batch['input_ids'])
            loss = criterion(logits, batch['label'])
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            if i % eval_steps == 0:
                model.eval()
                accs_train.append(evaluate(model, train_dataloader))
                model.train()

        losses.append(sum(epoch_losses) / len(epoch_losses))

    return losses, accs_train

In [23]:
model = SimpleGRU(hidden_dim = 256, vocab_size = len(vocab), num_classes = 8, aggregation_type = 'max').to(device)

In [24]:
opt = torch.optim.Adam(model.parameters(), lr = 1e-3)
crt = nn.CrossEntropyLoss(ignore_index=word2ind['<pad>'])

losses, acca_train = train(model, opt, crt, num_epoch = 10, eval_steps=len(train_dataloader) // 2)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in enumerate(tqdm_notebook(train_dataloader,


HBox(children=(HTML(value='Training epoch 0:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 1:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 2:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 3:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 4:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 5:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 6:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 7:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 8:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





HBox(children=(HTML(value='Training epoch 9:'), FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))





In [25]:
evaluate(model, train_dataloader)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch in tqdm_notebook(dataloader,


HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




0.9836462736129761

In [26]:
def get_predictions(model, dataloader):
    
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in tqdm_notebook(dataloader,
                                   desc=f'Evaluating'):
            logits = model(batch['input_ids'])
            predictions.append(logits.argmax(dim=1))

    predictions = torch.cat(predictions).data.cpu().numpy()

    return predictions

In [27]:
predictions = get_predictions(model, test_dataloader)
predictions = [label_to_writers[x] for x in predictions]
predictions

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch in tqdm_notebook(dataloader,


HBox(children=(HTML(value='Evaluating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))




['Pratchett',
 'Remark',
 'King',
 'Pratchett',
 'Akunin',
 'Akunin',
 'Akunin',
 'Pratchett',
 'Akunin',
 'Bulychev',
 'Bulychev',
 'Bulychev',
 'Akunin',
 'Bulychev',
 'Akunin',
 'Akunin',
 'Akunin',
 'Bulychev',
 'King',
 'Chehov',
 'Akunin',
 'Bulychev',
 'Remark',
 'Remark',
 'Remark',
 'King',
 'Akunin',
 'Pratchett',
 'Bulychev',
 'Akunin',
 'Chehov',
 'Chehov',
 'King',
 'Akunin',
 'Bulychev',
 'Bulychev',
 'Bulychev',
 'Bulychev',
 'Bulychev',
 'Akunin',
 'Gogol',
 'King',
 'Bulychev',
 'Akunin',
 'Chehov',
 'Dostoevsky',
 'Bulychev',
 'Bulychev',
 'Akunin',
 'Akunin',
 'Remark',
 'King',
 'Akunin',
 'Bulychev',
 'King',
 'Akunin',
 'Chehov',
 'Remark',
 'Akunin',
 'Remark',
 'Chehov',
 'King',
 'King',
 'Akunin',
 'Akunin',
 'Akunin',
 'Chehov',
 'Pratchett',
 'Pratchett',
 'Akunin',
 'Akunin',
 'Akunin',
 'Remark',
 'Bulychev',
 'Akunin',
 'King',
 'King',
 'King',
 'Akunin',
 'Bulychev',
 'King',
 'Bulychev',
 'Pratchett',
 'Gogol',
 'King',
 'Gogol',
 'Akunin',
 'King',
 '