<a href="https://colab.research.google.com/github/alexlimatds/victor-doc_classification/blob/main/victor_doc_classification_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Document classification for the VICTOR project using a CNN

This notebook replicates the CNN-based document classification described in the paper _VICTOR: a Dataset for Brazilian Legal Documents Classification_.

- Deep learning library: PyTorch
- NLP Library: spaCy

### Installing dependencies

In [None]:
# Install dependencies (Colab).
# - tqdm: progress bars used by the training and prediction loops.
# - `spacy download pt`: installs the pt_core_news_sm package and links it as the
#   `pt` shortcut (see the install log below). NOTE(review): shortcut links exist
#   only in spaCy 2.x; on spaCy 3.x download `pt_core_news_sm` instead.
!pip install tqdm
!python -m spacy download pt

Collecting pt_core_news_sm==2.2.5
[?25l  Downloading https://github.com/explosion/spacy-models/releases/download/pt_core_news_sm-2.2.5/pt_core_news_sm-2.2.5.tar.gz (21.2MB)
[K     |████████████████████████████████| 21.2MB 120.2MB/s 
Building wheels for collected packages: pt-core-news-sm
  Building wheel for pt-core-news-sm (setup.py) ... [?25l[?25hdone
  Created wheel for pt-core-news-sm: filename=pt_core_news_sm-2.2.5-cp36-none-any.whl size=21186283 sha256=1502cfd97f3b2c1b384c1b933ae674d5163c9453392d3b59d2fe81236512f3ec
  Stored in directory: /tmp/pip-ephem-wheel-cache-e6xyhdbo/wheels/ea/94/74/ec9be8418e9231b471be5dc7e1b45dd670019a376a6b5bc1c0
Successfully built pt-core-news-sm
Installing collected packages: pt-core-news-sm
Successfully installed pt-core-news-sm-2.2.5
[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('pt_core_news_sm')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.6/dist-packages/pt_core_news_sm -->
/usr/loca

### Mounting Google Drive

In [None]:
# Colab-only: mount Google Drive (forcing a remount if it is already mounted).
# root_dir is the prefix from which the dataset and checkpoint paths are built.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
root_dir = '/content/gdrive/My Drive/'

Mounted at /content/gdrive


### Application parameters

In [None]:
import random

import numpy as np
import torch

S = 500 # number of tokens per document (used as torchtext's fix_length)
BATCH_SIZE = 64
NUM_OF_CLASSES = 6

dataset_fraction = 1.0 # fraction of train and validation datasets to be used

MLP_HIDDEN_UNITS = 128
EMBEDDING_DIM = 100

# Seed Python's `random`, NumPy and PyTorch (all devices) before any data iterator
# or model is created. This makes weight initialisation, dropout and (presumably,
# since torchtext shuffles with Python's `random`) batch order repeatable across
# Restart & Run All. cuDNN kernels may still be non-deterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset_dir = root_dir + 'Machine Learning/Victor datasets/'
model_path = dataset_dir + 'CNN/'
model_file = model_path + f'pytorch_model-{S}-{dataset_fraction}.pt'

### Loading and preprocessing datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext import data
from tqdm.notebook import trange, tqdm_notebook
import numpy as np
from datetime import datetime

In [None]:
# Pick the train/validation CSVs: the full files, or pre-cropped copies whose
# names embed the fraction ("croped" is the spelling used by the files on disk).
train_ds_file, validation_ds_file = (
    f'{split}_small.csv' if dataset_fraction == 1.0
    else f'{split}_small.csv-croped_{dataset_fraction}.csv'
    for split in ('train', 'validation'))

In [None]:
import spacy

# Load the Portuguese pipeline by its package name. The `pt` shortcut link only
# exists in spaCy 2.x; the package name (installed by the download cell) works
# in spaCy 2.x and 3.x alike.
spacy_pt = spacy.load('pt_core_news_sm')

def tokenizer(text):
  """Split `text` into a list of token strings.

  Only spaCy's tokenizer is invoked (not the full pipeline), which is all that
  torchtext's `tokenize=` hook needs.
  """
  return [tok.text for tok in spacy_pt.tokenizer(text)]


In [None]:
%%time

# TEXT: tokenised with the spaCy tokenizer defined above, lower-cased, and fixed
# to S tokens per document.
TEXT = data.Field(
    tokenize=tokenizer, 
    lower=True, 
    fix_length=S)
# LABEL: a single (non-sequential) value per row with no unknown-token entry
# (unk_token=None). Presumably every label must therefore occur in the training
# split used to build its vocabulary — verify.
LABEL = data.Field(
    sequential=False, 
    unk_token=None)

# Load the three CSV splits, skipping the header row. Columns are mapped by
# position: only the 4th column (label) and 6th column (text) are kept; the
# others are discarded. NOTE(review): this relies on the CSV column order —
# confirm against the files' header.
train_data, valid_data, test_data = data.TabularDataset.splits(
    path=dataset_dir, 
    train=train_ds_file,
    validation=validation_ds_file, 
    test='test_small.csv', 
    format='csv', 
    skip_header = True, 
    fields=[(None, None), (None, None), (None, None), ('label', LABEL), (None, None), ('text', TEXT)])

CPU times: user 2min 8s, sys: 1.44 s, total: 2min 9s
Wall time: 2min 11s


In [None]:
# Both vocabularies are built from the training split only; validation/test tokens
# never seen in training map to TEXT's unknown token (torchtext's default).
for field in (TEXT, LABEL):
  field.build_vocab(train_data)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# sort=False disables length-based sorting for every split, not only
# validation/test. Whether a split is shuffled is left to torchtext's defaults.
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
  (train_data, valid_data, test_data),
  batch_sizes=(BATCH_SIZE,) * 3,
  device=device,
  sort=False)

### Model

In [None]:
class VictorCNN(nn.Module):
  """CNN document classifier replicating the VICTOR paper's model.

  Token ids -> embeddings -> three parallel Conv1d/BatchNorm1d/MaxPool1d(2)
  branches (kernel sizes 3, 4, 5) -> channel-wise concatenation -> MaxPool1d(50)
  -> flatten -> dropout -> hidden ReLU layer -> class scores (logits, no softmax).

  The input size of the hidden layer is derived from `sentence_len` instead of
  being hard-coded. For sentence_len=500 it is 3840, exactly as before, so
  previously saved checkpoints still load.
  """

  def __init__(self, sentence_len, vocab_size, embed_dim, n_classes, mlp_h):
    """
    sentence_len: the length (number of tokens) of each input sentence; it
                  determines the input size of the hidden layer.
    vocab_size:   the number of tokens in the vocabulary.
    embed_dim:    the dimension of each embedding word vector.
    n_classes:    number of classes, i.e., the output dimension of this NN.
    mlp_h:        number of hidden units of the MLP NN.

    Raises ValueError if the three branches yield different lengths for
    `sentence_len` (e.g. odd lengths), or if it is too short for the final pooling.
    """
    super(VictorCNN, self).__init__()

    self.word_embeddings = nn.Embedding(vocab_size, embed_dim)
    self.cnn_a = self.create_cnn_layer(embed_dim, 256, 3, 1)
    self.cnn_b = self.create_cnn_layer(embed_dim, 256, 4, 2)
    self.cnn_c = self.create_cnn_layer(embed_dim, 256, 5, 2)
    self.max_pool = nn.MaxPool1d(50)
    self.dropout = nn.Dropout(p=0.5)
    self.linear_h = nn.Linear(self._flattened_size(sentence_len), mlp_h)
    self.linear_o = nn.Linear(mlp_h, n_classes)

  def create_cnn_layer(self, n_channels, n_filters, kernel_size, padding):
    """One branch: Conv1d -> BatchNorm1d -> MaxPool1d(2) (floor-halves the length)."""
    return nn.Sequential(
        nn.Conv1d(n_channels, n_filters, kernel_size, padding=padding), 
        nn.BatchNorm1d(n_filters), 
        nn.MaxPool1d(2)
    )

  def _flattened_size(self, sentence_len):
    """Number of features reaching `linear_h` for inputs of `sentence_len` tokens."""
    branch_lens = set()
    n_channels = 0
    for branch in (self.cnn_a, self.cnn_b, self.cnn_c):
      conv, _, pool = branch
      # Conv1d output-length formula.
      conv_len = (sentence_len + 2 * conv.padding[0]
                  - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) // conv.stride[0] + 1
      # MaxPool1d with default stride (== kernel size) and no padding.
      branch_lens.add(conv_len // pool.kernel_size)
      n_channels += conv.out_channels
    # torch.cat along channels requires every branch to have the same length.
    if len(branch_lens) != 1:
      raise ValueError(
          f'sentence_len={sentence_len} gives different branch lengths '
          f'{sorted(branch_lens)}; the branches cannot be concatenated')
    pooled_len = branch_lens.pop() // self.max_pool.kernel_size
    if pooled_len < 1:
      raise ValueError(f'sentence_len={sentence_len} is too short for the final MaxPool1d')
    return n_channels * pooled_len

  def forward(self, sentence):
    """Return unnormalised class scores of shape (b_len, n_classes)."""
    # sentence.shape: (s_len, b_len)
    embeds = self.word_embeddings(sentence).permute(1, 2, 0) # embeds shape: (b_len, embedding_dim, s_len)
    a = self.cnn_a(embeds)
    b = self.cnn_b(embeds)
    c = self.cnn_c(embeds)
    x = torch.cat((a, b, c), dim=1)  # stack the branches' filters as channels
    x = self.max_pool(x)
    x = torch.flatten(x, start_dim=1)
    x = self.dropout(x)
    x = F.relu(self.linear_h(x))
    x = self.linear_o(x)
    return x

### Training functions

In [None]:
from sklearn.metrics import f1_score

def compute_metrics(targets, predictions):
  """Macro-averaged F1 of the arg-max class of each prediction row against `targets`.

  targets:     gold class indices, one per example.
  predictions: per-class scores, one row per example.
  """
  predicted_classes = np.argmax(predictions, axis=1)
  return f1_score(targets, predicted_classes, average='macro')

def train(model, iterator, optimizer, criterion, epoch):
  """Run one training epoch over `iterator` and return the mean per-batch loss.

  Each batch must expose `.text` (model input) and `.label` (targets).
  `epoch` is currently unused; it is kept so existing calls keep working.
  """
  model.train()
  running_loss = 0
  for batch in tqdm_notebook(iterator, desc='Train', unit='batch', leave=False):
    optimizer.zero_grad()
    batch_loss = criterion(model(batch.text), batch.label)
    batch_loss.backward()
    optimizer.step()
    running_loss += batch_loss.item()
  return running_loss / len(iterator)

def predict(model, iterator, set_name):
  """Run `model` over every batch of `iterator` in eval mode without gradients.

  model:    the network; it is switched to eval mode (dropout off).
  iterator: yields batches exposing `.text` (model input) and `.label` (targets).
  set_name: label shown in the progress bar.

  Returns (predictions, targets) as NumPy arrays: the model outputs and the
  labels of all batches, concatenated in iteration order.
  Raises ValueError if the iterator yields no batches.
  """
  model.eval()
  batch_outputs = []
  batch_targets = []
  with torch.no_grad():
    for batch in tqdm_notebook(iterator, desc=f'Predicting ({set_name})', unit='batch', leave=False):
      batch_outputs.append(model(batch.text))
      batch_targets.append(batch.label)

  if not batch_outputs:
    raise ValueError(f'No batches to predict on ({set_name}).')

  # Concatenate once at the end; concatenating inside the loop would re-copy the
  # whole accumulated tensor on every batch (quadratic in the number of batches).
  predictions = torch.cat(batch_outputs, dim=0)
  targets = torch.cat(batch_targets, dim=0)
  return predictions.cpu().numpy(), targets.cpu().numpy()

def evaluate(model, iterator, set_name):
  """Predict over `iterator` and return the macro-averaged F1 score."""
  outputs, gold = predict(model, iterator, set_name)
  f1_macro = compute_metrics(targets=gold, predictions=outputs)
  return f1_macro


### Training

In [None]:
EPOCHS = 20
learning_rate = 1e-3

# Fresh model, Adam optimiser and cross-entropy loss (applied to the model's raw
# scores), all placed on `device`.
model = VictorCNN(S, len(TEXT.vocab), EMBEDDING_DIM, NUM_OF_CLASSES, MLP_HIDDEN_UNITS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
%%time
import pandas as pd
from IPython.display import display, update_display

metrics_df = pd.DataFrame(columns=['Epoch', 'Loss (train)', 'F1 macro (train)', 'F1 macro (validation)'])
metrics_display = display(metrics_df, display_id='metrics_table')

best_valid_f1 = 0.0

for epoch in range(EPOCHS):
  train_loss = train(model, train_iterator, optimizer, criterion, epoch)
  train_f1_m = evaluate(model, train_iterator, 'train set')
  valid_f1_m = evaluate(model, valid_iterator, 'validation set')
  
  #saving
  if valid_f1_m > best_valid_f1:
    best_valid_f1 = valid_f1_m
    torch.save(model.state_dict(), model_file)

  #printing
  metrics_df.loc[epoch] = [epoch + 1, train_loss, train_f1_m, valid_f1_m]
  metrics_display.update(metrics_df)

Unnamed: 0,Epoch,Loss (train),F1 macro (train),F1 macro (validation)
0,1.0,0.95545,0.63318,0.619375
1,2.0,0.639045,0.794214,0.747756
2,3.0,0.512292,0.831055,0.777132
3,4.0,0.4345,0.865736,0.796191
4,5.0,0.363273,0.901452,0.812181
5,6.0,0.323134,0.904849,0.816269
6,7.0,0.27744,0.931104,0.827486
7,8.0,0.240107,0.930703,0.82666
8,9.0,0.204044,0.960224,0.839051
9,10.0,0.183545,0.964173,0.837131


HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…



HBox(children=(FloatProgress(value=0.0, description='Train', max=446.0, style=ProgressStyle(description_width=…



HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (validation set)', max=298.0, style=ProgressSt…

CPU times: user 7min 8s, sys: 2min 37s, total: 9min 46s
Wall time: 9min 54s


### Evaluation

In [None]:
def load_saved_model(file_name):
  """Rebuild a VictorCNN with this notebook's hyper-parameters, load the weights
  stored in `file_name` onto `device`, and return it in eval mode."""
  restored = VictorCNN(S, len(TEXT.vocab), EMBEDDING_DIM, NUM_OF_CLASSES, MLP_HIDDEN_UNITS).to(device)
  state = torch.load(file_name, map_location=device)
  restored.load_state_dict(state)
  restored.eval()
  return restored

# Evaluate the best checkpoint saved during training (reassign model_file first to score a different one).
model = load_saved_model(model_file)

In [None]:
# Raw model scores (logits) and gold labels for the training and test splits.
train_predictions, train_targets = predict(model, iterator=train_iterator, set_name='train set')
test_predictions, test_targets = predict(model, iterator=test_iterator, set_name='test set')

HBox(children=(FloatProgress(value=0.0, description='Predicting (train set)', max=446.0, style=ProgressStyle(d…



HBox(children=(FloatProgress(value=0.0, description='Predicting (test set)', max=1493.0, style=ProgressStyle(d…



In [None]:
from sklearn.metrics import classification_report

def _report(targets, predictions):
  """Per-class precision/recall/F1 table for raw scores (arg-max over the class axis)."""
  return classification_report(
      targets, 
      np.argmax(predictions, axis=1), 
      digits=4, 
      target_names=LABEL.vocab.itos)

test_report = _report(test_targets, test_predictions)
train_report = _report(train_targets, train_predictions)

print(test_report)

# `with` guarantees the report file is closed even if a write raises.
with open(model_file + "-test_report.txt", "wt") as rep_file:
  rep_file.write('CNN evaluation report\n')
  rep_file.write(f'Test {test_report}\n')
  rep_file.write(f'Train {train_report}\n')
  rep_file.write(f'learning rate: {learning_rate}\n')
  rep_file.write(f'optimizer: {type(optimizer).__name__}\n')
  rep_file.write(f'criterion: {type(criterion).__name__}\n')
  rep_file.write(f'MLP_HIDDEN_UNITS: {MLP_HIDDEN_UNITS}\n')
  rep_file.write(f'EMBEDDING_DIM: {EMBEDDING_DIM}\n')
  # Remaining settings needed to reproduce the run.
  rep_file.write(f'S: {S}\n')
  rep_file.write(f'BATCH_SIZE: {BATCH_SIZE}\n')
  rep_file.write(f'EPOCHS: {EPOCHS}\n')

                                  precision    recall  f1-score   support

                          outros     0.9821    0.9004    0.9395     85408
                   peticao_do_RE     0.4650    0.8449    0.5999      6331
agravo_em_recurso_extraordinario     0.4539    0.7034    0.5518      1841
                        sentenca     0.5730    0.8590    0.6875      1475
          acordao_de_2_instancia     0.6448    0.9377    0.7642       273
     despacho_de_admissibilidade     0.5367    0.7020    0.6083       198

                        accuracy                         0.8920     95526
                       macro avg     0.6093    0.8246    0.6919     95526
                    weighted avg     0.9295    0.8920    0.9044     95526



References:

- https://medium.com/@rohit_agrawal/using-fine-tuned-gensim-word2vec-embeddings-with-torchtext-and-pytorch-17eea2883cd
