<a href="https://colab.research.google.com/github/alexlimatds/victor-doc_classification/blob/main/victor_doc_classification_CNN2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Document classification for the VICTOR project using a CNN

This notebook replicates the CNN-based document classification described in the paper _VICTOR: a Dataset for Brazilian Legal Documents Classification_. It is also set up to explore different regularization strategies (L1, L2 and dropout).

Deep learning library: PyTorch

### Installing dependencies

In [None]:
# tqdm provides the notebook progress bars shown during training and prediction.
!pip install tqdm
# Downloads spaCy's Portuguese model. As the output shows, this also creates the 'pt'
# shortcut link that spacy.load('pt') relies on further down (shortcut links are a spaCy 2.x feature).
!python -m spacy download pt

[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('pt_core_news_sm')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.6/dist-packages/pt_core_news_sm -->
/usr/local/lib/python3.6/dist-packages/spacy/data/pt
You can now load the model via spacy.load('pt')


### Mounting Google Drive

In [None]:
from google.colab import drive
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT, force_remount=True)
# Every dataset and checkpoint path used later is built on top of this folder.
root_dir = GDRIVE_MOUNT_POINT + '/My Drive/'

Mounted at /content/gdrive


### Application parameters

In [None]:
# Regularization switches: a falsy value (None / 0) turns the technique off.
L1 = None            # L1 penalty strength, added to the batch loss inside train()
L2 = 0.5             # L2 penalty, passed to the Adam optimizer as weight_decay
DROPOUT_RATE = None  # dropout probability applied before the MLP hidden layer
# NOTE(review): in the logged run with L2 = 0.5 the model predicted only the majority class
# ('outros') and the F1 scores never changed; weight decay this strong may be the cause — confirm.

S = 500                 # tokens per document: every text is padded/truncated to this length
BATCH_SIZE = 64
NUM_OF_CLASSES = 6
dataset_fraction = 1.0  # fraction of train and validation datasets to be used
# NOTE(review): dataset_fraction is not referenced by any other cell; the CSV files are loaded whole.

MLP_HIDDEN_UNITS = 128
EMBEDDING_DIM = 100

dataset_dir = root_dir + 'Machine Learning/Victor datasets/'
model_path = dataset_dir + 'CNN_2/'
# The checkpoint name encodes every active regularization setting
# (e.g. 'pytorch_model-L2_0.5.pt'), so runs with different settings do not overwrite each other.
model_file = model_path + 'pytorch_model' + ''.join(
    f'-{tag}_{value}'
    for tag, value in (('L1', L1), ('L2', L2), ('D', DROPOUT_RATE))
    if value
) + '.pt'

### Loading and preprocessing datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext import data
from tqdm.notebook import trange, tqdm_notebook
import numpy as np
from datetime import datetime

In [None]:
# CSV file names (relative to dataset_dir) of the train and validation splits.
train_ds_file, validation_ds_file = 'train_small.csv', 'validation_small.csv'

In [None]:
import spacy

spacy_pt = spacy.load('pt')

def tokenizer(text):
  """Splits `text` into token strings using only the tokenizer of the Portuguese spaCy model."""
  doc = spacy_pt.tokenizer(text)
  return list(map(lambda token: token.text, doc))


In [None]:
%%time

# Document text: lower-cased spaCy tokens, padded/truncated to exactly S tokens.
TEXT = data.Field(tokenize=tokenizer, lower=True, fix_length=S)
# Class label: a single, non-tokenized value whose vocabulary has no <unk> entry.
LABEL = data.Field(sequential=False, unk_token=None)

# Six CSV columns are declared; only the 4th (label) and the 6th (text) are kept.
csv_columns = [(None, None)] * 6
csv_columns[3] = ('label', LABEL)
csv_columns[5] = ('text', TEXT)

train_data, valid_data, test_data = data.TabularDataset.splits(
    path=dataset_dir,
    train=train_ds_file,
    validation=validation_ds_file,
    test='test_small.csv',
    format='csv',
    skip_header=True,
    fields=csv_columns)

CPU times: user 5min 15s, sys: 3.59 s, total: 5min 19s
Wall time: 5min 19s


In [None]:
# Both vocabularies are built from the training split only.
for field in (TEXT, LABEL):
  field.build_vocab(train_data)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# sort=False turns off length-sorting for all three splits. Because TEXT is padded to a
# fixed length, there is nothing to bucket on anyway.
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
  (train_data, valid_data, test_data),
  sort=False,
  batch_sizes=(BATCH_SIZE,) * 3,
  device=device)

### Model

In [None]:
class VictorCNN(nn.Module):
  """Text classifier: three parallel Conv1d branches over word embeddings followed by an MLP head."""

  # (kernel_size, padding) of the branches cnn_a, cnn_b and cnn_c, in that order.
  CONV_SPECS = ((3, 1), (4, 2), (5, 2))
  N_FILTERS = 256   # output channels of each convolutional branch
  BRANCH_POOL = 2   # max-pool width applied inside each branch
  FINAL_POOL = 50   # max-pool width applied to the concatenated branches

  def __init__(self, sentence_len, vocab_size, embed_dim, n_classes, mlp_h, dropout):
    """
    sentence_len: the length (in tokens) of each input sentence.
    vocab_size:   the number of tokens in the vocabulary.
    embed_dim:    the dimension of each embedding word vector.
    n_classes:    number of classes, i.e., the output dimension of this NN.
    mlp_h:        number of hidden units of the MLP NN.
    dropout:      the dropout rate. None to not apply dropout.

    Raises ValueError when sentence_len makes the branches produce outputs of different
    lengths (any odd length does) or is too short to survive the final pooling.
    """
    super().__init__()

    self.word_embeddings = nn.Embedding(vocab_size, embed_dim)
    (k_a, p_a), (k_b, p_b), (k_c, p_c) = self.CONV_SPECS
    self.cnn_a = self.create_cnn_layer(embed_dim, self.N_FILTERS, k_a, p_a)
    self.cnn_b = self.create_cnn_layer(embed_dim, self.N_FILTERS, k_b, p_b)
    self.cnn_c = self.create_cnn_layer(embed_dim, self.N_FILTERS, k_c, p_c)
    self.max_pool = nn.MaxPool1d(self.FINAL_POOL)
    self.dropout = (nn.Dropout(p=dropout) if dropout else None)
    # Derived from sentence_len instead of a hard-coded 3840, which only matched sentence_len=500.
    self.linear_h = nn.Linear(self.flattened_size(sentence_len), mlp_h)
    self.linear_o = nn.Linear(mlp_h, n_classes)

  def flattened_size(self, sentence_len):
    """Number of features entering linear_h for inputs of `sentence_len` tokens."""
    # Conv1d (stride 1) output length is L + 2*padding - kernel_size + 1;
    # MaxPool1d(w) then reduces a length L to L // w.
    branch_lens = {
        (sentence_len + 2 * padding - kernel + 1) // self.BRANCH_POOL
        for kernel, padding in self.CONV_SPECS
    }
    if len(branch_lens) != 1:
      # torch.cat in forward() needs every branch to have the same length.
      raise ValueError(
          f'sentence_len={sentence_len} gives branch outputs of different lengths '
          f'{sorted(branch_lens)}; use an even sentence length')
    pooled_len = branch_lens.pop() // self.FINAL_POOL
    if pooled_len < 1:
      raise ValueError(
          f'sentence_len={sentence_len} is too short: it must be at least '
          f'{self.BRANCH_POOL * self.FINAL_POOL}')
    return len(self.CONV_SPECS) * self.N_FILTERS * pooled_len

  def create_cnn_layer(self, n_channels, n_filters, kernel_size, padding):
    """One branch: Conv1d -> BatchNorm1d -> MaxPool1d(BRANCH_POOL)."""
    # NOTE(review): no activation between BatchNorm and MaxPool — confirm this matches the paper.
    return nn.Sequential(
        nn.Conv1d(n_channels, n_filters, kernel_size, padding=padding),
        nn.BatchNorm1d(n_filters),
        nn.MaxPool1d(self.BRANCH_POOL)
    )

  def forward(self, sentence):
    """
    sentence: token ids shaped (s_len, b_len) (sequence-first).
    Returns unnormalized class scores shaped (b_len, n_classes).
    """
    # Conv1d expects (batch, channels, length), so move the embedding dim to the channel axis.
    embeds = self.word_embeddings(sentence).permute(1, 2, 0)  # (b_len, embedding_dim, s_len)
    a = self.cnn_a(embeds)
    b = self.cnn_b(embeds)
    c = self.cnn_c(embeds)
    x = torch.cat((a, b, c), dim=1)  # stack the branches along the channel axis
    x = self.max_pool(x)
    x = torch.flatten(x, start_dim=1)
    if self.dropout is not None:
      x = self.dropout(x)
    x = F.relu(self.linear_h(x))
    x = self.linear_o(x)
    return x

### Training functions

In [None]:
from sklearn.metrics import f1_score

def compute_metrics(targets, predictions):
  """Macro-averaged F1 of the arg-max class of each row of `predictions` against `targets`."""
  predicted_classes = np.argmax(predictions, axis=1)
  return f1_score(targets, predicted_classes, average='macro')

def train(model, iterator, optimizer, criterion, epoch):
  """
  Runs one training epoch over `iterator` and returns the mean per-batch loss.

  When the global L1 is set, an L1 penalty over all model parameters, scaled by
  L1 / batch size, is added to each batch loss. The returned value therefore includes it.
  `epoch` is accepted but not used.
  """
  model.train()
  total_loss = 0
  for batch in tqdm_notebook(iterator, desc='Train', unit='batch', leave=False):
    optimizer.zero_grad()
    logits = model(batch.text)
    loss = criterion(logits, batch.label)

    if L1:
      # Sum of the absolute values of every parameter tensor.
      l1_norm = sum((param.norm(p=1) for param in model.parameters()), 0.0)
      loss = loss + L1 / logits.shape[0] * l1_norm

    loss.backward()
    optimizer.step()
    total_loss += loss.item()

  return total_loss / len(iterator)

def predict(model, iterator, set_name):
  """
  Runs `model` in eval mode (no gradients) over every batch of `iterator`.

  set_name: label shown in the progress bar.
  Returns two numpy arrays, both in iteration order:
    - the concatenated model outputs, one row per example;
    - the matching `batch.label` targets.
  Raises ValueError if the iterator yields no batches.
  """
  model.eval()
  output_chunks = []
  target_chunks = []
  with torch.no_grad():
    for batch in tqdm_notebook(iterator, desc=f'Predicting ({set_name})', unit='batch', leave=False):
      # Collect per-batch results and concatenate once at the end. Growing one tensor
      # with torch.cat on every batch re-copies all earlier rows (quadratic work).
      output_chunks.append(model(batch.text))
      target_chunks.append(batch.label)

  if not output_chunks:
    raise ValueError(f'No batches to predict for {set_name}')

  predictions = torch.cat(output_chunks, dim=0)
  targets = torch.cat(target_chunks, dim=0)
  return predictions.cpu().numpy(), targets.cpu().numpy()

def evaluate(model, iterator, set_name):
  """Macro F1 of `model` over every batch of `iterator` (predict + compute_metrics)."""
  outputs, labels = predict(model, iterator, set_name)
  return compute_metrics(targets=labels, predictions=outputs)


### Training

In [None]:
EPOCHS = 20
learning_rate = 1e-3

model = VictorCNN(S, len(TEXT.vocab), EMBEDDING_DIM, NUM_OF_CLASSES,
                  MLP_HIDDEN_UNITS, DROPOUT_RATE).to(device)
# L2 regularization is applied through Adam's weight_decay (0.0 when L2 is unset).
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=L2 or 0.0)
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
%%time
import pandas as pd
from IPython.display import display, update_display

import os

metrics_df = pd.DataFrame(columns=['Epoch', 'Loss (train)', 'F1 macro (train)', 'F1 macro (validation)'])
metrics_display = display(metrics_df, display_id='metrics_table')

# torch.save does not create missing folders, so make sure the checkpoint folder exists.
checkpoint_dir = os.path.dirname(model_file)
if checkpoint_dir:
  os.makedirs(checkpoint_dir, exist_ok=True)

# Start below any reachable macro F1 (which is >= 0) so the first epoch always writes a checkpoint.
# Starting at 0.0, a run whose validation F1 never rises above 0 would save nothing. The
# evaluation section would then load a missing file, or a stale one left by an earlier
# run (the file name only encodes the regularization settings).
best_valid_f1 = float('-inf')

for epoch in range(EPOCHS):
  train_loss = train(model, train_iterator, optimizer, criterion, epoch)
  train_f1_m = evaluate(model, train_iterator, 'train set')
  valid_f1_m = evaluate(model, valid_iterator, 'validation set')

  # Keep only the weights with the best validation macro F1 seen so far.
  if valid_f1_m > best_valid_f1:
    best_valid_f1 = valid_f1_m
    torch.save(model.state_dict(), model_file)

  # Append this epoch's row and refresh the live table in place.
  metrics_df.loc[epoch] = [epoch + 1, train_loss, train_f1_m, valid_f1_m]
  metrics_display.update(metrics_df)

Unnamed: 0,Epoch,Loss (train),F1 macro (train),F1 macro (validation)
0,1.0,0.440954,0.157795,0.156759
1,2.0,0.459111,0.157795,0.156759
2,3.0,0.452644,0.157795,0.156759
3,4.0,0.451549,0.157795,0.156759
4,5.0,0.451636,0.157795,0.156759
5,6.0,0.451434,0.157795,0.156759
6,7.0,0.451456,0.157795,0.156759
7,8.0,0.451529,0.157795,0.156759
8,9.0,0.451146,0.157795,0.156759
9,10.0,0.451228,0.157795,0.156759


HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



HBox(children=(FloatProgress(value=0.0, description='Train', max=2332.0, style=ProgressStyle(description_width…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…

CPU times: user 39min, sys: 15min 10s, total: 54min 11s
Wall time: 54min 47s


### Evaluation

In [None]:
def load_saved_model(file_name):
  """
  Rebuilds a VictorCNN with the notebook's hyper-parameters, loads the state dict stored in
  `file_name` onto `device` and returns the model in eval mode.
  """
  restored = VictorCNN(S, len(TEXT.vocab), EMBEDDING_DIM, NUM_OF_CLASSES,
                       MLP_HIDDEN_UNITS, DROPOUT_RATE).to(device)
  # torch.load unpickles the file: only load checkpoints this notebook produced.
  state = torch.load(file_name, map_location=device)
  restored.load_state_dict(state)
  restored.eval()
  return restored

# Evaluate the checkpoint with the best validation F1 written during training.
model = load_saved_model(model_file)

In [None]:
# Model outputs and targets (numpy arrays) of the reloaded checkpoint, one pair per split.
split_results = [
    predict(model, split_iterator, split_name)
    for split_iterator, split_name in ((train_iterator, 'train set'),
                                       (test_iterator, 'test set'),
                                       (valid_iterator, 'validation set'))
]
((train_predictions, train_targets),
 (test_predictions, test_targets),
 (validation_predictions, validation_targets)) = split_results

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=2332.0, style=ProgressStyle(…



HBox(children=(FloatProgress(value=0.0, description='Predicting (test set)', max=1493.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=1481.0, style=ProgressS…



In [None]:
from sklearn.metrics import classification_report

def make_report(targets, predictions):
  """
  Text classification report (4 decimal digits) for one split.

  targets:     class indices, as returned by predict.
  predictions: model outputs with one column per class; the arg-max column is the prediction.
  """
  # Pass every class index explicitly so the report always has one row per LABEL entry.
  # Without `labels`, sklearn infers the classes from the data. When a class is absent from
  # both targets and predictions, the class count no longer matches target_names, and
  # recent scikit-learn versions raise a ValueError.
  return classification_report(
      targets,
      np.argmax(predictions, axis=1),
      labels=list(range(len(LABEL.vocab.itos))),
      digits=4,
      target_names=LABEL.vocab.itos)

test_report = make_report(test_targets, test_predictions)
validation_report = make_report(validation_targets, validation_predictions)
train_report = make_report(train_targets, train_predictions)

def format_int(value):
  """Formats the Epoch column (stored as float in metrics_df) as a 2-wide integer."""
  value = int(value)
  return f'{value:2d}'

def format_float(value):
  """Formats a metric with 4 decimal places."""
  return f'{value:.4f}'

print('Test report\n' + test_report)

# `with` closes the report file even if one of the writes fails.
with open(model_file + "-test_report.txt", "wt") as rep_file:
  rep_file.write('CNN evaluation report\n')
  rep_file.write(f'L1 lambda: {L1}\n')
  rep_file.write(f'L2 lambda: {L2}\n')
  rep_file.write(f'DROPOUT_RATE: {DROPOUT_RATE}\n')
  rep_file.write(f'learning rate: {learning_rate}\n')
  rep_file.write(f'optimizer: {type(optimizer).__name__}\n')
  rep_file.write(f'criterion: {type(criterion).__name__}\n')
  rep_file.write(f'MLP_HIDDEN_UNITS: {MLP_HIDDEN_UNITS}\n')
  rep_file.write(f'EMBEDDING_DIM: {EMBEDDING_DIM}\n')
  rep_file.write(f'Test\n {test_report}\n')
  rep_file.write(f'Validation\n {validation_report}\n')
  rep_file.write(f'Train\n {train_report}\n')
  rep_file.write(f'Train log\n {metrics_df.to_string(index=False, formatters=[format_int, format_float, format_float, format_float])}\n')

  _warn_prf(average, modifier, msg_start, len(result))


Test report
                                  precision    recall  f1-score   support

                          outros     0.8941    1.0000    0.9441     85408
                   peticao_do_RE     0.0000    0.0000    0.0000      6331
agravo_em_recurso_extraordinario     0.0000    0.0000    0.0000      1841
                        sentenca     0.0000    0.0000    0.0000      1475
          acordao_de_2_instancia     0.0000    0.0000    0.0000       273
     despacho_de_admissibilidade     0.0000    0.0000    0.0000       198

                        accuracy                         0.8941     95526
                       macro avg     0.1490    0.1667    0.1573     95526
                    weighted avg     0.7994    0.8941    0.8441     95526

