<a href="https://colab.research.google.com/github/alexlimatds/victor-doc_classification/blob/main/victor_doc_classification_LSTM_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Document classification for the Victor project using a BiLSTM neural network with an attention mechanism

Built with the PyTorch, torchtext and spaCy libraries.

Reference paper for the attention mechanism: Lin et al., *A Structured Self-Attentive Sentence Embedding* (ICLR 2017), https://arxiv.org/abs/1703.03130

### Installing dependencies

In [None]:
# Download the Portuguese spaCy model; as the output below shows, the 'pt' shortcut
# resolves to pt_core_news_sm and is linked so that spacy.load('pt') works later.
!python -m spacy download pt
# tqdm provides the notebook progress bars used by the training/prediction loops.
!pip install tqdm

Collecting pt_core_news_sm==2.2.5
[?25l  Downloading https://github.com/explosion/spacy-models/releases/download/pt_core_news_sm-2.2.5/pt_core_news_sm-2.2.5.tar.gz (21.2MB)
[K     |████████████████████████████████| 21.2MB 1.4MB/s 
Building wheels for collected packages: pt-core-news-sm
  Building wheel for pt-core-news-sm (setup.py) ... [?25l[?25hdone
  Created wheel for pt-core-news-sm: filename=pt_core_news_sm-2.2.5-cp36-none-any.whl size=21186283 sha256=1c693b3e7c02cfa8854373efbed76a2ae5d5d97078323a6d36df11f3138d483e
  Stored in directory: /tmp/pip-ephem-wheel-cache-dlvknkf5/wheels/ea/94/74/ec9be8418e9231b471be5dc7e1b45dd670019a376a6b5bc1c0
Successfully built pt-core-news-sm
Installing collected packages: pt-core-news-sm
Successfully installed pt-core-news-sm-2.2.5
[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('pt_core_news_sm')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.6/dist-packages/pt_core_news_sm -->
/usr/local/

### Mounting Google Drive

In [None]:
from google.colab import drive

# Mount Google Drive (forcing a fresh mount) so the datasets and checkpoints kept
# under "My Drive" are reachable from this runtime.
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT, force_remount=True)
root_dir = f'{GDRIVE_MOUNT_POINT}/My Drive/'

Mounted at /content/gdrive


### Application parameters

In [None]:
# --- Data ---
S = 500                 # every document is padded/truncated to S tokens (torchtext fix_length)
BATCH_SIZE = 32
dataset_fraction = 0.6  # fraction of the train and validation datasets to use (1.0 = full dataset)

# --- Model hyperparameters ---
EMBEDDING_DIM = 200
RNN_HIDDEN_UNITS = 200
NUM_OF_CLASSES = 6
ATTENTION_HOPS = 10     # r in Lin et al. (2017)
ATTENTION_UNITS = 30    # d_a in Lin et al. (2017)
USE_DROPOUT = True

# --- Paths on the mounted Google Drive ---
dataset_dir = f'{root_dir}Machine Learning/Victor datasets/'
model_path = f'{dataset_dir}LSTM+attention/'
model_file = f'{model_path}pytorch_model-{S}-{dataset_fraction}.pt'

### Loading and preprocessing datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext import data
from tqdm.notebook import trange, tqdm_notebook
import numpy as np
from datetime import datetime

In [None]:
import spacy

# The full Portuguese pipeline is loaded, but only its tokenizer is called below.
spacy_pt = spacy.load('pt')

def tokenizer(text):
  """Split `text` into a list of token strings using spaCy's Portuguese tokenizer."""
  doc = spacy_pt.tokenizer(text)
  return [token.text for token in doc]

In [None]:
# dataset_fraction == 1.0 selects the full CSVs; any other value selects the
# pre-cropped copies, whose names append '-croped_<fraction>.csv' to the originals.
if dataset_fraction != 1.0:
  _crop_suffix = f'-croped_{dataset_fraction}.csv'
else:
  _crop_suffix = ''
train_ds_file = 'train_small.csv' + _crop_suffix
validation_ds_file = 'validation_small.csv' + _crop_suffix

In [None]:
%%time

TEXT = data.Field(
    tokenize=tokenizer, 
    lower=True, 
    fix_length=S)
LABEL = data.Field(
    sequential=False, 
    unk_token=None)

train_data, valid_data, test_data = data.TabularDataset.splits(
    path=dataset_dir, 
    train=train_ds_file,
    validation=validation_ds_file, 
    test='test_small.csv', 
    format='csv', 
    skip_header = True, 
    fields=[(None, None), (None, None), (None, None), ('label', LABEL), (None, None), ('text', TEXT)])

CPU times: user 3min 57s, sys: 2.68 s, total: 4min
Wall time: 4min 4s


In [None]:
# Both vocabularies are built from the training split only.
LABEL.build_vocab(train_data)
TEXT.build_vocab(train_data)
vocab = TEXT.vocab  # its length sizes the model's embedding table

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Every example is padded/truncated to S tokens, so length-based bucketing has
# nothing to gain; sort=False is passed to all three iterators (train, validation, test).
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
  (train_data, valid_data, test_data),
  sort=False,
  batch_sizes=(BATCH_SIZE,) * 3,
  device=device)


### Model

In [None]:
class SelfAttentionSentenceEncoder(nn.Module):
  """
  Structured self-attention of Lin et al. (2017):
      A = softmax(W_s2 tanh(W_s1 H^T)),   M = A H

  NOTE(review): no padding mask is applied, so padded time steps (documents are
  padded to a fixed length upstream) can receive attention weight — consider masking.
  """
  def __init__(self, h_dim, d_a, r):
    """
    h_dim: size of one BiLSTM direction's hidden vector; the input vectors have 2 * h_dim features.
    d_a: number of attention units (output size of W_s1).
    r: number of attention hops (output size of W_s2).
    """
    super().__init__()
    self.u, self.d_a, self.r = h_dim, d_a, r
    in_features = 2 * h_dim
    self.W_s1 = nn.Linear(in_features, d_a, bias=False)
    self.W_s2 = nn.Linear(d_a, r, bias=False)

  def forward(self, rnn_outputs):
    """
    rnn_outputs: BiLSTM hidden vectors for all time steps, shape (batch_size, seq_len, 2 * h_dim).
    returns:
      M: sentence embeddings, shape (batch_size, r, 2 * h_dim)
      A: attention weights (each hop sums to 1 over seq_len), shape (batch_size, r, seq_len)
    """
    hop_scores = self.W_s2(torch.tanh(self.W_s1(rnn_outputs)))  # (batch_size, seq_len, r)
    A = F.softmax(hop_scores.transpose(1, 2), dim=-1)           # normalise over seq_len
    M = A @ rnn_outputs
    return M, A

class AttBiLSTM(nn.Module):
  """
  BiLSTM document classifier with structured self-attention.
  vocab_size: number of tokens in the vocabulary.
  embed_dim:  dimension of word embedding vectors.
  hidden_dim: number of hidden units per direction in the BiLSTM.
  n_classes:  number of classes.
  r:          number of attention hops.
  d_a:        number of attention units.
  dropout:    True to apply dropout to the flattened attention output, False otherwise.
  dropout_p:  dropout probability used when `dropout` is True (default 0.15).
  """
  def __init__(self, vocab_size, embed_dim, hidden_dim, n_classes, r, d_a, dropout, dropout_p=0.15):
    super(AttBiLSTM, self).__init__()
    self.r = r      # attention hops
    self.d_a = d_a  # attention units
    self.dropout = None

    # Attribute names are kept stable so previously saved state_dicts still load.
    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
    self.attention = SelfAttentionSentenceEncoder(hidden_dim, self.d_a, self.r)
    # The r attention-weighted BiLSTM vectors (2 * hidden_dim each) are flattened into one feature vector.
    self.linear = nn.Linear(hidden_dim * 2 * self.r, n_classes)
    if dropout:
      self.dropout = nn.Dropout(dropout_p)

  def forward(self, sentence):
    """
    sentence: LongTensor of token ids, shape (batch, seq_len).
    returns: (logits with shape (batch, n_classes), attention weights A from the attention layer).
    """
    embs = self.embedding(sentence)                  # (batch, seq_len, embed_dim)
    lstm_out, _ = self.lstm(embs)                    # (batch, seq_len, 2 * hidden_dim)
    M, A = self.attention(lstm_out)
    features = torch.flatten(M, start_dim=1)         # (batch, r * 2 * hidden_dim)
    if self.dropout is not None:
      features = self.dropout(features)
    logits = self.linear(features)                   # (batch, n_classes)
    return logits, A


### Training

In [None]:
EPOCHS = 20
learning_rate = 1e-3
SEED = 42

# Seed torch's RNGs so weight initialisation (and the RNG stream dropout draws from)
# is reproducible across runs.
# NOTE(review): the data iterators were created in an earlier cell; whether their
# shuffling is covered by this seed depends on torchtext — confirm.
torch.manual_seed(SEED)

model = AttBiLSTM(len(vocab), EMBEDDING_DIM, RNN_HIDDEN_UNITS, NUM_OF_CLASSES,
                  ATTENTION_HOPS, ATTENTION_UNITS, USE_DROPOUT).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
from sklearn.metrics import f1_score

# Identity matrix for the attention penalization ||A A^T - I||_F^2 (Lin et al., 2017):
# one (r x r) identity per batch element, shape (BATCH_SIZE, ATTENTION_HOPS, ATTENTION_HOPS).
# (Previously `zeros + 1`, i.e. a matrix of ones, which rewarded all hops for
# attending to the same token instead of penalizing it.)
I = torch.eye(ATTENTION_HOPS, device=device).repeat(BATCH_SIZE, 1, 1)

def compute_metrics(targets, predictions):
  """Macro-averaged F1 of the argmax class of each row of `predictions` against `targets`."""
  predicted_classes = np.argmax(predictions, axis=1)
  return f1_score(targets, predicted_classes, average='macro')

def _attention_penalty(A, coef):
  """coef * sum over the batch of ||A A^T - I||_F^2, with I the r x r identity (Lin et al., 2017)."""
  identity = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)  # broadcasts over the batch
  gram = torch.bmm(A, A.transpose(1, 2))                              # (batch, r, r)
  # Sum of squares == squared Frobenius norm, without the sqrt/square round trip.
  return coef * ((gram - identity) ** 2).sum()

def train(model, iterator, optimizer, criterion, epoch, penalty_coef=0.5):
  """
  Run one training epoch and return the mean per-batch loss.

  Each batch loss is `criterion(logits, labels)` plus the attention penalization
  penalty_coef * sum_batch ||A A^T - I||_F^2, which pushes the r attention hops
  towards attending to different tokens.

  model: returns (logits, A) for a (batch, seq_len) LongTensor, A with shape (batch, r, seq_len).
  iterator: yields batches with `.text` shaped (seq_len, batch) and `.label` shaped (batch,).
  optimizer, criterion: torch optimizer and classification loss.
  epoch: unused; kept for call-site compatibility.
  penalty_coef: weight of the attention penalization (default 0.5).
  returns: average of the per-batch total losses (0.0 if the iterator is empty).
  """
  model.train()
  total_loss = 0.0
  for batch in tqdm_notebook(iterator, desc='Train', unit='batch', leave=False):
    optimizer.zero_grad()
    logits, A = model(batch.text.transpose(0, 1))
    loss = criterion(logits, batch.label) + _attention_penalty(A, penalty_coef)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

  n_batches = len(iterator)
  return total_loss / n_batches if n_batches else 0.0

def predict(model, iterator, set_name):
  """
  Run `model` in eval mode over every batch of `iterator` and collect its raw outputs.

  model: returns (logits, attention) for a (batch, seq_len) LongTensor; only logits are kept.
  iterator: yields batches with `.text` shaped (seq_len, batch) and `.label` shaped (batch,).
  set_name: label shown in the progress bar.
  returns: (predictions, targets) as numpy arrays, concatenated over all batches.
  raises: ValueError if the iterator yields no batches.
  """
  model.eval()
  output_chunks, label_chunks = [], []
  with torch.no_grad():
    for batch in tqdm_notebook(iterator, desc=f'Predicting ({set_name})', unit='batch', leave=False):
      out, _ = model(batch.text.transpose(0, 1))
      output_chunks.append(out)
      label_chunks.append(batch.label)

  if not output_chunks:
    raise ValueError(f'No batches to predict on ({set_name}).')
  # Concatenate once: growing a tensor with torch.cat per batch re-copies everything
  # accumulated so far on every step (quadratic in the number of batches).
  predictions = torch.cat(output_chunks, dim=0)
  targets = torch.cat(label_chunks, dim=0)
  return predictions.cpu().numpy(), targets.cpu().numpy()

def evaluate(model, iterator, set_name):
  """Predict on every batch of `iterator` and return the macro-F1 of those predictions."""
  outputs, labels = predict(model, iterator, set_name)
  macro_f1 = compute_metrics(labels, outputs)
  return macro_f1


In [None]:
%%time
import pandas as pd
from IPython.display import display, update_display

metrics_df = pd.DataFrame(columns=['Epoch', 'Loss (train)', 'F1 macro (train)', 'F1 macro (validation)'])
metrics_display = display(metrics_df, display_id='metrics_table')

best_valid_f1 = 0.0

for epoch in range(EPOCHS):
  train_loss = train(model, train_iterator, optimizer, criterion, epoch)
  train_f1_m = evaluate(model, train_iterator, 'train set')
  valid_f1_m = evaluate(model, valid_iterator, 'validation set')
  
  #saving
  if valid_f1_m > best_valid_f1:
    best_valid_f1 = valid_f1_m
    torch.save(model.state_dict(), model_file)

  #printing
  metrics_df.loc[epoch] = [epoch + 1, train_loss, train_f1_m, valid_f1_m]
  metrics_display.update(metrics_df)

Unnamed: 0,Epoch,Loss (train),F1 macro (train),F1 macro (validation)
0,1.0,38.188632,0.439305,0.453706
1,2.0,0.325276,0.71352,0.676701
2,3.0,0.222242,0.773879,0.687085
3,4.0,0.167181,0.820483,0.686353
4,5.0,0.132312,0.863214,0.725663
5,6.0,0.10508,0.89112,0.720405
6,7.0,0.087935,0.907477,0.734854
7,8.0,0.074609,0.929314,0.734578
8,9.0,0.064639,0.939544,0.729896
9,10.0,0.05729,0.950601,0.755839


HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Train', max=2987.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1910.0, style=ProgressS…

CPU times: user 1h 26min 32s, sys: 6min 46s, total: 1h 33min 19s
Wall time: 1h 34min 26s


### Evaluating

In [None]:
def load_saved_model(file_name):
  """
  Rebuild an AttBiLSTM with this notebook's hyperparameters, load the weights saved
  at `file_name` onto `device`, and return the model in eval mode.

  NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
  """
  hyperparams = (len(vocab), EMBEDDING_DIM, RNN_HIDDEN_UNITS, NUM_OF_CLASSES,
                 ATTENTION_HOPS, ATTENTION_UNITS, USE_DROPOUT)
  restored = AttBiLSTM(*hyperparams).to(device)
  restored.load_state_dict(torch.load(file_name, map_location=device))
  restored.eval()
  return restored

# model_file holds the checkpoint with the best validation macro-F1 from training.
# To evaluate a different checkpoint, point model_file at it before this call,
# e.g. model_path + 'pytorch_model-500-1.0.pt'.
model = load_saved_model(model_file)

In [None]:
# Raw logits and gold labels for the train and test splits, using the model loaded above.
(train_predictions, train_targets), (test_predictions, test_targets) = (
    predict(model, train_iterator, 'train set'),
    predict(model, test_iterator, 'test set'))

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2987.0, style=ProgressStyle(…

HBox(children=(FloatProgress(value=0.0, description='Predicting (test set)', max=2986.0, style=ProgressStyle(d…

In [None]:
from sklearn.metrics import classification_report

class_names = LABEL.vocab.itos

def _classification_report(targets, predictions):
  """Per-class precision/recall/F1 (4 digits) of the argmax predictions, labelled with class_names."""
  return classification_report(
      targets,
      np.argmax(predictions, axis=1),
      # Explicit labels keep target_names aligned even if a class never occurs in a split.
      labels=list(range(len(class_names))),
      digits=4,
      target_names=class_names)

test_report = _classification_report(test_targets, test_predictions)
train_report = _classification_report(train_targets, train_predictions)

print(test_report)

# The reports are saved together with the run's hyperparameters so each file is self-describing.
report_lines = [
    f'Test {test_report}',
    f'Train {train_report}',
    f'learning rate: {learning_rate}',
    f'optimizer: {type(optimizer).__name__}',
    f'criterion: {type(criterion).__name__}',
    f'Attention hops: {model.r}',
    f'Attention units: {model.d_a}',
    f'Dropout: {USE_DROPOUT}',
    f'Embedding dim: {EMBEDDING_DIM}',
    f'BiLSTM hidden units: {RNN_HIDDEN_UNITS}',
]
# `with` guarantees the file is closed even if a write fails.
with open(model_file + "-test_report.txt", "wt", encoding="utf-8") as rep_file:
  rep_file.writelines(line + '\n' for line in report_lines)

                                  precision    recall  f1-score   support

                          outros     0.9702    0.9592    0.9647     85408
                   peticao_do_RE     0.7244    0.7338    0.7291      6331
agravo_em_recurso_extraordinario     0.4141    0.5932    0.4877      1841
                        sentenca     0.6875    0.7295    0.7079      1475
          acordao_de_2_instancia     0.8659    0.8755    0.8707       273
     despacho_de_admissibilidade     0.5744    0.5657    0.5700       198

                        accuracy                         0.9326     95526
                       macro avg     0.7061    0.7428    0.7217     95526
                    weighted avg     0.9377    0.9326    0.9348     95526



References:
- http://anie.me/On-Torchtext/
- https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/A%20-%20Using%20TorchText%20with%20Your%20Own%20Datasets.ipynb
- https://github.com/bentrevett/pytorch-sentiment-analysis
