# Legal Rhetorical Roles Classification using LEGAL-BERT and LEGAL-ToBERT

These examples show how to use LEGAL-BERT and LEGAL-ToBERT to classify the rhetorical roles of sentences in your own legal documents.

First of all, some imports are required:

In [11]:
import os
from pprint import pprint

import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, matthews_corrcoef
import torch
from tqdm import tqdm

from rhetorical_roles_classification import (
    RhetoricalRolesDataset,
    RhetoricalRolesDatasetForTransformerOverBERT
)

## Load the Models

Our models are not stored in this repository but can be downloaded from [here](https://drive.google.com/drive/folders/12U6XzXmWeNeYmwWG4QZNZrAffZqfMw9p?usp=sharing). In these examples we use BERT and ToBERT for the English language. Each model comes with configuration info, such as the maximum document length supported by ToBERT, which must be taken into account when using the model for your own purposes.

In [2]:
MODELS_FOLDER = "./models/eng"

bert, bert_config = joblib.load(
    os.path.join(MODELS_FOLDER, "LEGAL-BERT.joblib")
)
tobert, tobert_config = joblib.load(
    os.path.join(MODELS_FOLDER, "LEGAL-ToBERT.joblib")
)

print("ToBERT Maximum supported document length:", tobert_config["max_document_length"])
print("Label to rhetorical role mapping:")
pprint(tobert_config["label2rhetRole"])

ToBERT Maximum supported document length: 386
Label to rhetorical role mapping:
{0: 'PREAMBLE',
 1: 'FAC',
 2: 'RLC',
 3: 'ISSUE',
 4: 'ARG_PETITIONER',
 5: 'ARG_RESPONDENT',
 6: 'ANALYSIS',
 7: 'STA',
 8: 'PRE_RELIED',
 9: 'PRE_NOT_RELIED',
 10: 'RATIO',
 11: 'RPC',
 12: 'NONE'}


## Prepare the Datasets

In these examples we use the BUILD public benchmark dataset. The dataset must be preprocessed before LEGAL-BERT and LEGAL-ToBERT can use it; `RhetoricalRolesDataset` and `RhetoricalRolesDatasetForTransformerOverBERT` do exactly this. Specifically:
- `RhetoricalRolesDataset` takes as input the path to a `.csv` file that stores one input sentence per row in a `segments` column;
- `RhetoricalRolesDatasetForTransformerOverBERT` takes as input the path to a `.json` file containing a list of documents. Each document must be a dictionary with a `segments` key whose value is the list of the document's sentences.

Some additional arguments are required for tokenization; be sure to use the same values as in the model's configuration.

In [3]:
DATA_FOLDER = "./BUILD/data"

bert_dataset = RhetoricalRolesDataset(
    data_filepath=os.path.join(DATA_FOLDER, "test.csv"),
    max_segment_length=bert_config["max_segment_length"],
    tokenizer_model_name=bert_config["tokenizer_model_name"],
    has_labels=False
)
tobert_dataset = RhetoricalRolesDatasetForTransformerOverBERT(
    data_filepath=os.path.join(DATA_FOLDER, "test.json"),
    max_document_length=tobert_config["max_document_length"],
    max_segment_length=tobert_config["max_segment_length"],
    tokenizer_model_name=tobert_config["tokenizer_model_name"],
    has_labels=False
)
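
The two input formats described above can be illustrated with a small, self-contained sketch. It uses only the standard library and writes a toy `.csv` file with a `segments` column and a toy `.json` file with a list of documents, each holding a `segments` list. The sentences, the file names `toy.csv` and `toy.json`, and the helper `write_toy_inputs` are hypothetical and only serve the illustration; any other columns or keys your files may need (for example labels) are not covered here.

```python
import csv
import json
import os
import tempfile

# Hypothetical documents: each one is a dictionary with a `segments` list
toy_documents = [
    {"segments": ["The appellant filed a writ petition.", "The petition was dismissed."]},
    {"segments": ["The appeal is allowed."]},
]


def write_toy_inputs(folder, documents):
    """Write the same sentences as a sentence-level CSV and a document-level JSON."""
    csv_path = os.path.join(folder, "toy.csv")
    json_path = os.path.join(folder, "toy.json")

    # CSV: one sentence per row, stored in a `segments` column
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["segments"])
        writer.writeheader()
        for document in documents:
            for sentence in document["segments"]:
                writer.writerow({"segments": sentence})

    # JSON: a list of documents, each with a `segments` key
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(documents, f)

    return csv_path, json_path


with tempfile.TemporaryDirectory() as folder:
    csv_path, json_path = write_toy_inputs(folder, toy_documents)

    # Read both files back to check their structure
    with open(csv_path, newline="", encoding="utf-8") as f:
        csv_rows = [row["segments"] for row in csv.DictReader(f)]
    with open(json_path, encoding="utf-8") as f:
        json_documents = json.load(f)

print(len(csv_rows), "sentences in the CSV file")
print(len(json_documents), "documents in the JSON file")
```

Writing both files from the same list keeps the sentence order of the CSV rows identical to the order of the flattened JSON documents, which is convenient when predictions from both models are compared against the same labels.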

## Inference with LEGAL-BERT

Using LEGAL-BERT for inference is straightforward:

In [None]:
bert_dataloader = torch.utils.data.DataLoader(
    bert_dataset,
    batch_size=128
)

bert.eval()
bert_predictions = []
with torch.no_grad():
    for data in tqdm(bert_dataloader):
        output = bert(data, labels=None)
        logits = output.logits
        bert_predictions += logits.argmax(dim=-1).tolist()
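
The predictions collected above are integer label ids. To read them as rhetorical roles, they can be mapped through the label-to-role dictionary printed earlier. The following is a minimal sketch: the mapping is copied from that output, while the helper `to_rhetorical_roles` and the example predictions are hypothetical stand-ins for your own `bert_predictions`.

```python
from collections import Counter

# Mapping copied from the configuration output shown earlier
LABEL2RHETROLE = {
    0: "PREAMBLE",
    1: "FAC",
    2: "RLC",
    3: "ISSUE",
    4: "ARG_PETITIONER",
    5: "ARG_RESPONDENT",
    6: "ANALYSIS",
    7: "STA",
    8: "PRE_RELIED",
    9: "PRE_NOT_RELIED",
    10: "RATIO",
    11: "RPC",
    12: "NONE",
}


def to_rhetorical_roles(predictions, label2role):
    """Translate integer label ids into rhetorical role names."""
    return [label2role[label] for label in predictions]


# Hypothetical predictions, standing in for `bert_predictions`
example_predictions = [0, 0, 1, 6, 6, 11]
example_roles = to_rhetorical_roles(example_predictions, LABEL2RHETROLE)

# Count how often each role was predicted
role_counts = Counter(example_roles)

print(example_roles)
print(role_counts.most_common())
```

In the notebook itself you would presumably pass `bert_config["label2rhetRole"]` or `tobert_config["label2rhetRole"]` instead of the hard-coded dictionary.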

## Inference with LEGAL-ToBERT

Using LEGAL-ToBERT for inference is straightforward, too. The only difference is that the predictions for padding sentences must be filtered out of the model's output.
This is done by counting the actual sentences in each document, given that padding sentences are all-zero vectors.

In [None]:
tobert_dataloader = torch.utils.data.DataLoader(
    tobert_dataset,
    batch_size=1
)

tobert.eval()
tobert_predictions = []
with torch.no_grad():
    for documents in tqdm(tobert_dataloader):
        output = tobert(documents, labels=None)
        logits = output.logits  # Shape: (batch_size = 1, max_document_length, num_labels)
        # Padding sentences are all-zero vectors, so count only the sentences with a non-zero entry
        n_sentences = len([sentence for sentence in documents[0] if sentence.any()])
        tobert_predictions += logits.argmax(dim=-1).ravel().tolist()[:n_sentences]
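
The padding filter used above can be checked in isolation on toy data. The sketch below uses plain Python lists in place of tensors; the token ids, the predictions, and the helper names `count_real_sentences` and `drop_padding_predictions` are hypothetical. Like the cell above, it assumes that padding sentences are all zeros and come after the real sentences of the document.

```python
def count_real_sentences(document):
    """Count the sentences that contain at least one non-zero token id."""
    return sum(1 for sentence in document if any(token != 0 for token in sentence))


def drop_padding_predictions(predictions, document):
    """Keep only the predictions that belong to non-padding sentences.

    Assumes padding sentences are placed after the real ones.
    """
    return predictions[:count_real_sentences(document)]


# Hypothetical document padded to 5 sentences of 4 token ids each
toy_document = [
    [101, 2023, 102, 0],     # real sentence (token-level padding at the end)
    [101, 2003, 2009, 102],  # real sentence
    [101, 102, 0, 0],        # real sentence
    [0, 0, 0, 0],            # padding sentence
    [0, 0, 0, 0],            # padding sentence
]

# Hypothetical per-sentence predictions, one for every row of the padded document
toy_predictions = [0, 1, 6, 12, 12]

kept_predictions = drop_padding_predictions(toy_predictions, toy_document)
print(kept_predictions)
```

A sentence that is only partly padded still counts as real, because it contains at least one non-zero token id.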

## Evaluation

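For demonstration purposes, we compute accuracy and the Matthews correlation coefficient (MCC) from the BERT and ToBERT predictions. Both models are scored against the labels in `test.csv`, which assumes that the sentences of `test.json`, flattened document by document, appear in the same order as the rows of `test.csv`: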

In [12]:
df = pd.read_csv(os.path.join(DATA_FOLDER, "test.csv"))
labels = df.labels

print("BERT:")
print(f"\tAccuracy: {accuracy_score(labels, bert_predictions)}")
print(f"\tMCC: {matthews_corrcoef(labels, bert_predictions)}")

print("ToBERT:")
print(f"\tAccuracy: {accuracy_score(labels, tobert_predictions)}")
print(f"\tMCC: {matthews_corrcoef(labels, tobert_predictions)}")

BERT:
	Accuracy: 0.6561306009030914
	MCC: 0.5594308542642985
ToBERT:
	Accuracy: 0.7846474470302188
	MCC: 0.7267670113883938
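
To make the two scores less opaque, here is a minimal pure-Python sketch of both. The MCC function follows the usual multiclass generalization of the coefficient, which is the definition scikit-learn documents for `matthews_corrcoef`, so its results should agree up to floating-point error. The function names and the toy labels are hypothetical; for real evaluations the scikit-learn functions used above remain the better choice.

```python
import math
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label equals the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def multiclass_mcc(y_true, y_pred):
    """Multiclass Matthews correlation coefficient computed from label counts."""
    s = len(y_true)                                    # total number of samples
    c = sum(t == p for t, p in zip(y_true, y_pred))    # correctly predicted samples
    t_counts = Counter(y_true)                         # how often each label truly occurs
    p_counts = Counter(y_pred)                         # how often each label is predicted
    labels = set(t_counts) | set(p_counts)

    numerator = c * s - sum(t_counts[k] * p_counts[k] for k in labels)
    denominator = math.sqrt(
        (s * s - sum(v * v for v in p_counts.values()))
        * (s * s - sum(v * v for v in t_counts.values()))
    )
    # Return 0.0 when the coefficient is undefined (e.g. a single predicted class)
    return numerator / denominator if denominator else 0.0


# Hypothetical labels and predictions for six sentences
toy_labels = [0, 0, 1, 1, 2, 2]
toy_predictions = [0, 0, 1, 2, 2, 2]

toy_accuracy = accuracy(toy_labels, toy_predictions)
toy_mcc = multiclass_mcc(toy_labels, toy_predictions)
print(f"Accuracy: {toy_accuracy:.4f}")
print(f"MCC: {toy_mcc:.4f}")
```

Unlike accuracy, MCC takes the distribution of true and predicted labels into account, which makes it a useful complement when the rhetorical roles are unevenly represented.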
