<a href="https://colab.research.google.com/github/alexlimatds/victor-doc_classification/blob/main/victor_doc_classification_CNN3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Document classification of Victor project using a CNN as machine learning model

This notebook replicates the CNN-based document classification described in the paper _VICTOR: a Dataset for Brazilian Legal Documents Classification_. In addition, it explores other pre-processing strategies.

- Deep learning library: PyTorch
- NLP library: spaCy

### Installing dependencies

In [None]:
# NOTE(review): none of the packages below is version-pinned; the recorded run
# downloaded pt-core-news-sm 3.0.0 — consider pinning for reproducibility.
# tqdm: progress bars used by the training and prediction loops
!pip install tqdm
# spaCy
!pip install -U pip setuptools wheel
!pip install -U spacy
# Portuguese pipeline that is later loaded with spacy.load('pt_core_news_sm')
!python -m spacy download pt_core_news_sm

2021-02-11 14:05:14.885347: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
Collecting pt-core-news-sm==3.0.0
  Downloading https://github.com/explosion/spacy-models/releases/download/pt_core_news_sm-3.0.0/pt_core_news_sm-3.0.0-py3-none-any.whl (22.1 MB)
[K     |████████████████████████████████| 22.1 MB 1.3 MB/s 
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('pt_core_news_sm')


### Mounting Google Drive

In [None]:
from google.colab import drive
# Mount Google Drive inside the Colab runtime (force_remount requests a fresh mount)
drive.mount('/content/gdrive', force_remount=True)
# Base folder of the mounted drive; dataset_dir is derived from it in the parameters cell
root_dir = '/content/gdrive/My Drive/'

Mounted at /content/gdrive


### Application parameters

In [None]:
L2 = 0.0001 # passed as weight_decay to the Adam optimizer (L2 regularization)
TAG_STOP_WORDS = True # True: stop words become the 'STOP_WORD' token; False: they are dropped
LEMMATIZE = True # True: alphabetic tokens are replaced by their lemma; False: by their text

S = 500 # number of tokens per document (fix_length of the text field and sentence_len of the model)
BATCH_SIZE = 64
NUM_OF_CLASSES = 6
dataset_fraction = 1.0 # fraction of train and validation datasets to be used
# NOTE(review): dataset_fraction is not referenced by any code visible in this notebook — confirm it is still needed

MLP_HIDDEN_UNITS = 128
EMBEDDING_DIM = 100

# Folder on the mounted drive holding the CSV datasets and the CNN_3 report folder
dataset_dir = root_dir + 'Machine Learning/Victor datasets/'
# The best checkpoint is written under the filesystem root of the runtime, not under dataset_dir
# NOTE(review): presumably lost when the Colab runtime is recycled — confirm this is intended
model_path = '/'
# The checkpoint name only encodes L2; the report name also encodes the pre-processing flags
model_file = model_path + f'pytorch_model-L2_{L2}.pt'
report_file = dataset_dir + f'CNN_3/report-L2_{L2}-tag_stop_words_{TAG_STOP_WORDS}-lemmatize_{LEMMATIZE}.txt'

### Loading and preprocessing datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext import data
from tqdm.notebook import trange, tqdm_notebook
import numpy as np
from datetime import datetime

In [None]:
# CSV files (looked up inside dataset_dir) with the training and validation splits.
# The test split file name is hard-coded in the dataset-loading cell.
train_ds_file = 'train_small.csv'
validation_ds_file = 'validation_small.csv'

In [None]:
import spacy

# Portuguese spaCy pipeline shared by tokenizer(); 'tagger' and 'parser' are excluded from loading
spacy_pt = spacy.load('pt_core_news_sm', exclude=['tagger', 'parser'])

def tokenizer(text):
  """
  Converts a raw document into the list of tokens fed to the model.

  text: the document content as a string.

  Returns a list of strings where, for each spaCy token (first matching rule wins):
    - a stop word becomes 'STOP_WORD' if TAG_STOP_WORDS is set, otherwise it is dropped;
    - an alphabetic token without entity type becomes its lowercased lemma
      (LEMMATIZE set) or its lowercased text;
    - the first token of a named entity becomes the entity type label;
    - a digit token becomes 'NUMBER';
    - a currency symbol becomes 'CURRENCY';
    - any other token is dropped.
  """
  pieces = []
  for tok in spacy_pt(text):
    if tok.is_stop:
      if TAG_STOP_WORDS:
        pieces.append('STOP_WORD')
      continue

    in_entity = bool(tok.ent_type_)
    if tok.is_alpha and not in_entity:
      surface = tok.lemma_ if LEMMATIZE else tok.text
      pieces.append(surface.lower())
    elif in_entity and tok.ent_iob == 3: # 3 is spaCy's IOB code for the token that begins an entity
      pieces.append(tok.ent_type_)
    elif tok.is_digit:
      pieces.append('NUMBER')
    elif tok.is_currency:
      pieces.append('CURRENCY')

  return pieces

In [None]:
%%time

# Text field: tokenized by tokenizer(), lowercased, fixed to S tokens per example
TEXT = data.Field(tokenize=tokenizer, lower=True, fix_length=S)
# Label field: one non-sequential value per example, with no unknown token in its vocabulary
LABEL = data.Field(sequential=False, unk_token=None)

# One (name, field) pair per CSV column, in file order; (None, None) columns are not loaded
csv_fields = [(None, None)] * 3 + [('label', LABEL), (None, None), ('text', TEXT)]

train_data, valid_data, test_data = data.TabularDataset.splits(
    fields=csv_fields,
    format='csv',
    skip_header=True,
    path=dataset_dir,
    train=train_ds_file,
    validation=validation_ds_file,
    test='test_small.csv')

CPU times: user 1h 23min 40s, sys: 42 s, total: 1h 24min 22s
Wall time: 1h 24min 23s


In [None]:
# Both vocabularies are built from the training split only
for field in (TEXT, LABEL):
  field.build_vocab(train_data)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

dataset_splits = (train_data, valid_data, test_data)
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
  dataset_splits,
  batch_sizes=(BATCH_SIZE,) * len(dataset_splits), # same batch size for every split
  sort=False, #don't sort test/validation data
  device=device)

### Model

In [None]:
class VictorCNN(nn.Module):
  """
  CNN document classifier.

  Word embeddings feed three parallel convolutional branches (kernel sizes 3, 4
  and 5). The branch outputs are concatenated along the channel dimension,
  max-pooled, flattened and classified by an MLP with one hidden layer.
  """

  def __init__(self, sentence_len, vocab_size, embed_dim, n_classes, mlp_h):
    """
    sentence_len: the length (in tokens) of each input sentence. It must be even
                  and at least 100 so the pooling stages produce a valid output.
    vocab_size:   the number of tokens in the vocabulary.
    embed_dim:    the dimension of each embedding word vector.
    n_classes:    number of classes, i.e., the output dimension of this NN.
    mlp_h:        number of hidden units of the MLP NN.

    Raises ValueError when sentence_len is not supported by the architecture.
    """
    super(VictorCNN, self).__init__()

    n_filters = 256   # filters of each convolutional branch
    branch_pool = 2   # pooling window applied inside each branch (see create_cnn_layer)
    final_pool = 50   # pooling window applied after the branches are concatenated

    # The kernel-4/padding-2 branch outputs sentence_len + 1 positions, the other
    # two output sentence_len. After the per-branch pooling the three lengths only
    # match (as torch.cat requires) when sentence_len is even.
    if sentence_len % branch_pool != 0:
      raise ValueError(f'sentence_len must be even, got {sentence_len}')
    pooled_len = (sentence_len // branch_pool) // final_pool
    if pooled_len < 1:
      raise ValueError(
          f'sentence_len must be at least {branch_pool * final_pool}, got {sentence_len}')

    self.word_embeddings = nn.Embedding(vocab_size, embed_dim)
    self.cnn_a = self.create_cnn_layer(embed_dim, n_filters, 3, 1)
    self.cnn_b = self.create_cnn_layer(embed_dim, n_filters, 4, 2)
    self.cnn_c = self.create_cnn_layer(embed_dim, n_filters, 5, 2)
    self.max_pool = nn.MaxPool1d(final_pool)
    # Flattened size = channels of the 3 branches * remaining positions
    # (3 * 256 * 5 = 3840 for sentence_len = 500).
    self.linear_h = nn.Linear(3 * n_filters * pooled_len, mlp_h)
    self.linear_o = nn.Linear(mlp_h, n_classes)

  def create_cnn_layer(self, n_channels, n_filters, kernel_size, padding):
    """
    Builds one convolutional branch: Conv1d -> BatchNorm1d -> MaxPool1d(2).

    n_channels:  number of input channels (the embedding dimension).
    n_filters:   number of output channels of the convolution.
    kernel_size: width of the convolution kernel.
    padding:     padding added to both sides of the input sequence.
    """
    return nn.Sequential(
        nn.Conv1d(n_channels, n_filters, kernel_size, padding=padding),
        nn.BatchNorm1d(n_filters),
        nn.MaxPool1d(2)
    )

  def forward(self, sentence):
    """
    sentence: tensor of token indices with shape (s_len, b_len).

    Returns a tensor of shape (b_len, n_classes) with the unnormalized class scores.
    """
    # Conv1d expects channels before positions: (b_len, embedding_dim, s_len)
    embeds = self.word_embeddings(sentence).permute(1, 2, 0)
    a = self.cnn_a(embeds)
    b = self.cnn_b(embeds)
    c = self.cnn_c(embeds)
    x = torch.cat((a, b, c), dim=1) # stack the branches along the channel dimension
    x = self.max_pool(x)
    x = torch.flatten(x, start_dim=1)
    x = F.relu(self.linear_h(x))
    x = self.linear_o(x)
    return x

### Training functions

In [None]:
from sklearn.metrics import f1_score

def compute_metrics(targets, predictions):
  """
  Computes the macro-averaged F1 score.

  targets:     the expected class index of each example.
  predictions: per-class scores of each example; the predicted class is the
               argmax along axis 1.
  """
  predicted_classes = np.argmax(predictions, axis=1)
  return f1_score(targets, predicted_classes, average='macro')

def train(model, iterator, optimizer, criterion, epoch):
  """
  Runs one training pass over all the batches of the iterator.

  model:     the network to be trained (switched to training mode).
  iterator:  yields batches exposing the text and label attributes.
  optimizer: updates the model parameters after each batch.
  criterion: loss function called as criterion(predictions, labels).
  epoch:     index of the current epoch (not used by this function).

  Returns the mean of the batch losses.
  """
  model.train()
  running_loss = 0
  progress = tqdm_notebook(iterator, desc='Train', unit='batch', leave=False)
  for batch in progress:
    optimizer.zero_grad()
    loss = criterion(model(batch.text), batch.label)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()

  return running_loss / len(iterator)

def predict(model, iterator, set_name):
  """
  Runs the model in evaluation mode over all the batches of the iterator.

  model:    the network used to predict.
  iterator: yields batches exposing the text and label attributes.
  set_name: name of the dataset, shown in the progress bar description.

  Returns a (predictions, targets) tuple of numpy arrays: the model outputs and
  the labels of all batches, concatenated along the first dimension.

  Raises ValueError when the iterator yields no batch.
  """
  model.eval()
  # Collect per-batch tensors and concatenate once at the end, instead of
  # re-concatenating the growing result on every batch.
  batch_outputs = []
  batch_targets = []
  with torch.no_grad():
    for batch in tqdm_notebook(iterator, desc=f'Predicting ({set_name})', unit='batch', leave=False):
      batch_outputs.append(model(batch.text))
      batch_targets.append(batch.label)

  if not batch_outputs:
    raise ValueError(f'No batches to predict ({set_name})')

  predictions = torch.cat(batch_outputs, dim=0)
  targets = torch.cat(batch_targets, dim=0)
  return predictions.cpu().numpy(), targets.cpu().numpy()

def evaluate(model, iterator, set_name):
  """
  Predicts over the iterator and returns the macro F1 score of the predictions.

  model:    the network to be evaluated.
  iterator: yields the batches of the dataset.
  set_name: name of the dataset, forwarded to predict().
  """
  scores, gold = predict(model, iterator, set_name)
  return compute_metrics(gold, scores)


### Training

In [None]:
EPOCHS = 20
learning_rate = 1e-3

# Freshly initialized network placed on the selected device
model = VictorCNN(S, len(TEXT.vocab), EMBEDDING_DIM, NUM_OF_CLASSES, MLP_HIDDEN_UNITS).to(device)

# L2 regularization is applied through Adam's weight_decay; a falsy L2 turns it off
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=L2 or 0.0)

criterion = nn.CrossEntropyLoss().to(device)

In [None]:
%%time
import pandas as pd
from IPython.display import display, update_display

# Per-epoch metrics, rendered as a table that is refreshed in place after each epoch
metrics_df = pd.DataFrame(columns=['Epoch', 'Loss (train)', 'F1 macro (train)', 'F1 macro (validation)'])
metrics_display = display(metrics_df, display_id='metrics_table')

# Start below any attainable F1 so the first epoch always writes a checkpoint.
# Otherwise a run whose validation F1 never exceeds 0 would leave model_file
# missing (or holding a checkpoint of a previous run).
best_valid_f1 = float('-inf')

for epoch in range(EPOCHS):
  train_loss = train(model, train_iterator, optimizer, criterion, epoch)
  train_f1_m = evaluate(model, train_iterator, 'train set')
  valid_f1_m = evaluate(model, valid_iterator, 'validation set')

  # Keep only the weights of the epoch with the best validation F1
  if valid_f1_m > best_valid_f1:
    best_valid_f1 = valid_f1_m
    torch.save(model.state_dict(), model_file)

  # Append the metrics of this epoch and refresh the displayed table
  metrics_df.loc[epoch] = [epoch + 1, train_loss, train_f1_m, valid_f1_m]
  metrics_display.update(metrics_df)

Unnamed: 0,Epoch,Loss (train),F1 macro (train),F1 macro (validation)
0,1.0,0.303258,0.411604,0.385603
1,2.0,0.221058,0.608692,0.567933
2,3.0,0.177873,0.708527,0.650265
3,4.0,0.15033,0.754277,0.692871
4,5.0,0.130802,0.740178,0.672271
5,6.0,0.116099,0.827835,0.735397
6,7.0,0.106565,0.831729,0.737424
7,8.0,0.097666,0.841746,0.725287
8,9.0,0.091107,0.856566,0.733073
9,10.0,0.087103,0.844483,0.707254


HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



Unnamed: 0,Epoch,Loss (train),F1 macro (train),F1 macro (validation)
0,1.0,0.303258,0.411604,0.385603
1,2.0,0.221058,0.608692,0.567933
2,3.0,0.177873,0.708527,0.650265
3,4.0,0.15033,0.754277,0.692871
4,5.0,0.130802,0.740178,0.672271
5,6.0,0.116099,0.827835,0.735397
6,7.0,0.106565,0.831729,0.737424
7,8.0,0.097666,0.841746,0.725287
8,9.0,0.091107,0.856566,0.733073
9,10.0,0.087103,0.844483,0.707254


HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…

CPU times: user 37min 17s, sys: 15min 36s, total: 52min 54s
Wall time: 53min 24s


### Evaluation

In [None]:
def load_saved_model(file_name):
  """
  Rebuilds the network and restores the weights stored in a checkpoint file.

  file_name: path of a file holding a state dict saved with torch.save.

  Returns the model placed on the selected device and set to evaluation mode.
  """
  restored = VictorCNN(S, len(TEXT.vocab), EMBEDDING_DIM, NUM_OF_CLASSES, MLP_HIDDEN_UNITS).to(device)
  state = torch.load(file_name, map_location=device)
  restored.load_state_dict(state)
  restored.eval()
  return restored

# Replace the in-memory model with the checkpoint saved during training
model = load_saved_model(model_file)

In [None]:
# Model outputs and labels of each split, as numpy arrays, produced by the reloaded checkpoint
train_predictions, train_targets = predict(model, train_iterator, 'train set')
test_predictions, test_targets = predict(model, test_iterator, 'test set')
validation_predictions, validation_targets = predict(model, valid_iterator, 'validation set')

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (test set)', max=1493.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



In [None]:
from sklearn.metrics import classification_report

def build_report(targets, predictions):
  """
  Builds the text classification report of one dataset split.

  targets:     the expected class index of each example.
  predictions: per-class scores of each example; the predicted class is the
               argmax along axis 1.

  The labels are passed explicitly so the report always lists every class of the
  label vocabulary. Without them, classification_report raises a ValueError when
  a class is absent from both the targets and the predictions of the split.
  Assumes class indices follow the order of LABEL.vocab.itos.
  """
  class_names = LABEL.vocab.itos
  return classification_report(
      targets,
      np.argmax(predictions, axis=1),
      labels=list(range(len(class_names))),
      digits=4,
      target_names=class_names)

test_report = build_report(test_targets, test_predictions)
validation_report = build_report(validation_targets, validation_predictions)
train_report = build_report(train_targets, train_predictions)

def format_int(value):
  """Converts the value to int and formats it right-aligned in at least two characters."""
  return '{:2d}'.format(int(value))

def format_float(value):
  """Formats the value in fixed-point notation with four decimal places."""
  return format(value, '.4f')

import os

print('Test report\n' + test_report)

# Make sure the report folder exists so the results of the long training run
# are not lost to a FileNotFoundError.
report_dir = os.path.dirname(report_file)
if report_dir:
  os.makedirs(report_dir, exist_ok=True)

# Format the training log before opening the file, so a formatting error does
# not leave a truncated report behind.
train_log = metrics_df.to_string(index=False, formatters=[format_int, format_float, format_float, format_float])

# The context manager closes the file even if one of the writes fails
with open(report_file, "wt") as rep_file:
  rep_file.write('CNN evaluation report\n')
  rep_file.write(f'L2 lambda: {L2}\n')
  rep_file.write(f'LEMMATIZE: {LEMMATIZE}\n')
  rep_file.write(f'TAG STOP WORDS: {TAG_STOP_WORDS}\n')
  rep_file.write(f'learning rate: {learning_rate}\n')
  rep_file.write(f'optimizer: {type(optimizer).__name__}\n')
  rep_file.write(f'criterion: {type(criterion).__name__}\n')
  rep_file.write(f'MLP_HIDDEN_UNITS: {MLP_HIDDEN_UNITS}\n')
  rep_file.write(f'EMBEDDING_DIM: {EMBEDDING_DIM}\n')
  rep_file.write(f'Test\n {test_report}\n')
  rep_file.write(f'Validation\n {validation_report}\n')
  rep_file.write(f'Train\n {train_report}\n')
  rep_file.write(f'Train log\n {train_log}\n')