<a href="https://colab.research.google.com/github/alexlimatds/fact_extraction/blob/main/AILA2020/FACTS_AILA_LEGAL_BERT_BASE_cv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Facts extraction with AILA data and LEGAL-BERT-BASE

The model is evaluated through cross-validation, but this notebook does not perform the whole cross-validation process. Instead, it runs just one fold of the cross-validation; the fold to be used is indicated by the ``fold_id`` variable.

We use the training dataset from AILA 2020, which can be obtained at https://github.com/Law-AI/semantic-segmentation.

LEGAL-BERT is available at https://huggingface.co/nlpaueb/legal-bert-base-uncased.

Related tutorial (multiclass text classification with Transformers): https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb

### Notebook parameters

In [None]:
# Hugging Face checkpoint to fine-tune, and the short name used in report file names.
model_id, model_reference = 'nlpaueb/legal-bert-base-uncased', 'legal-bert-base'
# Cross-validation fold to run (matched against the 'fold id' column of the fold definitions).
fold_id = 0

### Dependencies

In [None]:
!pip install transformers

In [None]:
import transformers
transformers.__version__  # displayed as the cell output so the run records which version was used

In [None]:
import torch
torch.__version__  # displayed as the cell output so the run records which version was used

In [None]:
import sklearn
sklearn.__version__  # displayed as the cell output so the run records which version was used

In [None]:
import numpy as np
np.__version__  # displayed as the cell output so the run records which version was used

### Loading dataset

In [None]:
from google.colab import drive
# Google Drive holds the dataset archive and the fold definitions; reports are written there too.
g_drive_dir = '/content/gdrive/MyDrive/'
dataset_dir = 'fact_extraction_AILA/'
drive.mount('/content/gdrive', force_remount=True)

In [None]:
!rm -r data
!mkdir data
!mkdir data/train
!tar -xf {g_drive_dir}{dataset_dir}/train.tar.xz -C data/train

train_dir = 'data/train/'

In [None]:
from os import listdir
import pandas as pd
import csv

def read_docs(dir_name):
  """
  Read the docs in a directory.
  Each file is a tab-separated table with one sentence per line followed by its label.
  Params:
    dir_name : the directory that contains the documents (a trailing slash is optional).
  Returns:
    A dictionary whose keys are the names of the read files (in sorted order) and the
    values are pandas dataframes. Each dataframe has sentence and label columns.
  """
  import os

  docs = {} # key: file name, value: dataframe with sentences and labels
  # sorted() makes the result independent of the filesystem's listing order.
  for file_name in sorted(os.listdir(dir_name)):
    path = os.path.join(dir_name, file_name)
    if not os.path.isfile(path):  # skip sub-directories and other non-regular entries
      continue
    docs[file_name] = pd.read_csv(
        path,
        sep='\t',
        quoting=csv.QUOTE_NONE,
        names=['sentence', 'label'],
        # Keep every field as text: otherwise sentences such as "NA" or empty fields
        # become NaN floats, which the tokenizer rejects.
        dtype=str,
        keep_default_na=False)
  return docs

docs_dic = read_docs(train_dir)
# Fail fast: an empty result would otherwise only surface later as a KeyError
# when the fold's documents are looked up.
assert docs_dic, f'No documents found in {train_dir}'

print('Number of documents: ', len(docs_dic))

### Tokenizer and Dataset preparation

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)  # tokenizer matching the checkpoint in model_id

In [None]:
from torch.utils.data import Dataset

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class MyDataset(Dataset):
  """
  Tokenized sentences paired with binary 'Facts' / 'Other' targets.

  All sentences are tokenized once, up front, padded to the longest sentence
  (with truncation=True). The resulting tensors are kept on the CPU; move each
  batch to the model's device (e.g. ``batch['ids'].to(device)``) before the forward pass.
  """

  # Float targets for BCEWithLogitsLoss: 'Facts' is the positive class.
  LABEL_TO_TARGET = {'Facts': 1.0, 'Other': 0.0}

  def __init__(self, sentences, labels, tokenizer):
    """
    Arguments:
      sentences : list of strings.
      labels : list of strings, each either 'Facts' or 'Other'.
      tokenizer : a Hugging Face tokenizer; it is called once with the whole sentence list.
    Raises:
      ValueError : if a label is unknown or the number of labels differs from
                   the number of sentences.
    """
    self.len = len(sentences)
    targets = []
    for label in labels:
      if label not in self.LABEL_TO_TARGET:
        raise ValueError(f'Unknown label: {label!r}')
      targets.append(self.LABEL_TO_TARGET[label])
    if len(targets) != self.len:
      raise ValueError(
        f'Got {self.len} sentences but {len(targets)} labels.')
    self.targets = torch.tensor(targets, dtype=torch.float)
    self.data = tokenizer(
      sentences,
      None,
      add_special_tokens=True,
      padding='longest',
      return_token_type_ids=True,
      truncation=True,
      return_tensors='pt'
    )

  def __getitem__(self, index):
    """Return the token ids, attention mask and float target of one sentence."""
    return {
      'ids': self.data['input_ids'][index],
      'mask': self.data['attention_mask'][index],
      'targets': self.targets[index]
    }

  def __len__(self):
    return self.len


In [None]:
df_folds = pd.read_csv(
  g_drive_dir + dataset_dir + 'train_docs_by_fold.csv',
  sep=';',
  names=['fold id', 'train', 'test'],
  header=0)

# Select the requested fold explicitly. Failing loudly here prevents silently
# reusing train/test lists left in the kernel by an earlier run.
fold_rows = df_folds[df_folds['fold id'] == fold_id]
if len(fold_rows) != 1:
  raise ValueError(
    f'Expected exactly one row for fold {fold_id}, found {len(fold_rows)}.')
fold_row = fold_rows.iloc[0]
# strip() tolerates spaces after the commas in the fold file.
train_files = [name.strip() for name in fold_row['train'].split(',')]
test_files = [name.strip() for name in fold_row['test'].split(',')]

# A document in both lists would leak test data into training.
overlap = set(train_files) & set(test_files)
assert not overlap, f'Documents in both train and test splits: {sorted(overlap)}'

print('Train documents: ', train_files)
print('Test documents: ', test_files)

In [None]:
def get_dataset(docs_list, max_sentences=None):
  """
  Build a MyDataset from the sentences of the given documents.

  Arguments:
    docs_list : names of the documents (keys of ``docs_dic``) to include, in order.
    max_sentences : optional cap on the number of sentences, handy for a quick
                    smoke test of the pipeline. ``None`` (default) keeps all of them.
  Returns:
    A MyDataset tokenized with the notebook's ``tokenizer``.
  """
  sentences = []
  labels = []
  for doc_id in docs_list:
    doc = docs_dic[doc_id]
    sentences.extend(doc['sentence'].to_list())
    labels.extend(doc['label'].to_list())
  if max_sentences is not None:
    sentences = sentences[:max_sentences]
    labels = labels[:max_sentences]
  return MyDataset(sentences, labels, tokenizer)

ds_train = get_dataset(train_files)
ds_test = get_dataset(test_files)

### Model

In [None]:
class SentenceClassifier(torch.nn.Module):
  """
  Binary sentence classifier: a pre-trained encoder (``model_id``) followed by a
  small MLP head applied to the last-layer hidden state of the CLS token.
  The head outputs one logit per sentence; a logit >= 0 is predicted as 'Facts'.
  """
  def __init__(self):
    super(SentenceClassifier, self).__init__()
    self.mlp_hidden_dim = 100
    self.bert = transformers.AutoModel.from_pretrained(model_id)
    # Take the encoder width from its config instead of hard-coding 768, so the
    # head also fits checkpoints with a different hidden size.
    embedding_dim = self.bert.config.hidden_size
    self.pre_classifier = torch.nn.Linear(embedding_dim, self.mlp_hidden_dim)
    self.dropout = torch.nn.Dropout(0.3)
    self.classifier = torch.nn.Linear(self.mlp_hidden_dim, 1)

  def forward(self, input_ids, attention_mask):
    """
    Return raw logits of shape (batch_size, 1).
    input_ids, attention_mask: (batch_size, seq_len) tensors on the model's device.
    """
    output_1 = self.bert(input_ids=input_ids, attention_mask=attention_mask)
    hidden_state = output_1[0]    # hidden states of last BERT's layer => shape: (batch_size, seq_len, embedd_dim)
    pooler = hidden_state[:, 0]   # hidden state of the CLS token => shape: (batch_size, embedd_dim)
    pooler = self.pre_classifier(pooler)  # shape: (batch_size, mlp_hidden_dim)
    pooler = torch.relu(pooler)           # shape: (batch_size, mlp_hidden_dim)
    pooler = self.dropout(pooler)         # shape: (batch_size, mlp_hidden_dim)
    output = self.classifier(pooler)      # shape: (batch_size, 1)
    return output

  def predict(self, input_ids, attention_mask):
    """
    Return hard 0/1 predictions of shape (batch_size,), with the logits' dtype.

    Switches the model to eval mode (disabling dropout) and leaves it there;
    call ``train()`` again before further training. Gradients are not tracked.
    """
    self.eval()
    with torch.no_grad():
      # squeeze(-1) (not squeeze()) keeps a 1-D result for a batch of one sentence.
      logits = self.forward(input_ids, attention_mask).squeeze(-1)
    # logit >= 0 is equivalent to sigmoid(logit) >= 0.5.
    return (logits >= 0).to(logits.dtype)


### Evaluation

In [None]:
from sklearn.metrics import precision_recall_fscore_support
from tqdm.notebook import tqdm_notebook

test_metrics = {}

def evaluate(model, epoch, test_dataloader):
  """
  Evaluate ``model`` on ``test_dataloader``, print binary Precision/Recall/F1 for
  the positive class (label 1, 'Facts') and store them in ``test_metrics[epoch]``.

  Each batch is moved to the device that holds the model's parameters, so the
  dataloader may yield CPU tensors wherever the model lives. Gradients are not tracked.

  Raises:
    ValueError : if ``test_dataloader`` yields no batches.
  """
  model_device = next(model.parameters()).device
  batch_predictions = []
  batch_targets = []
  with torch.no_grad():
    for data in tqdm_notebook(test_dataloader, desc='Evaluation'):
      y_hat = model.predict(data['ids'].to(model_device), data['mask'].to(model_device))
      # reshape(-1) keeps a 1-D tensor even for a final batch of size one.
      batch_predictions.append(y_hat.reshape(-1).cpu())
      batch_targets.append(data['targets'].reshape(-1).cpu())
  if not batch_predictions:
    raise ValueError('The test dataloader yielded no batches.')

  predictions = torch.cat(batch_predictions).numpy()
  y_true = torch.cat(batch_targets).numpy()
  # Precision, Recall, F1
  t_metrics = precision_recall_fscore_support(
    y_true,
    predictions,
    average='binary',
    pos_label=1,
    zero_division=0)
  print(f'Precision: {t_metrics[0]:.4f}')
  print(f'Recall:    {t_metrics[1]:.4f}')
  print(f'F-score:   {t_metrics[2]:.4f}')
  test_metrics[epoch] = {
      'precision': t_metrics[0],
      'recall': t_metrics[1],
      'f1': t_metrics[2]
  }

In [None]:
from datetime import datetime

def save_report(fold_id, metrics, dest_dir):
  """
  Write a per-epoch metrics table for one cross-validation fold to a text file
  named ``report-<model_reference>_fold-<fold_id>_<YYYY-MM-DD>.txt``, where
  ``model_reference`` is the notebook parameter. An existing file with the same
  name is overwritten.

  Arguments:
    fold_id : The identifier of the cross-validation fold.
    metrics : A dictionary with the metrics by epoch. The key indicates the epoch.
              Each value must be a dictionary with 'precision', 'recall' and 'f1'.
              Rows are written in ascending epoch order.
    dest_dir : The directory where the report will be saved (trailing slash optional).
  """
  import os

  report = 'OBS: the zero epoch concerns to the model\'s performance before any fine-tuning step.\n\n'
  report += 'epoch\t Precision   Recall   F1\n------------------------------------\n'
  for epoch, dic in sorted(metrics.items()):
    report += f'{epoch}\t {dic["precision"]:.4f}      {dic["recall"]:.4f}   {dic["f1"]:.4f}\n'

  file_name = f'report-{model_reference}_fold-{fold_id}_{datetime.now().strftime("%Y-%m-%d")}.txt'
  with open(os.path.join(dest_dir, file_name), 'w') as f:
    f.write(report)

### Fine-tuning

In [None]:
# model
# Seed PyTorch's RNG so the freshly initialised classification head (and the
# subsequent dropout masks) are the same on every run.
torch.manual_seed(42)
sentence_classifier = SentenceClassifier()

In [None]:
batch_size = 8
# Shuffle the training sentences every epoch: get_dataset concatenates them
# document by document, so without shuffling consecutive batches come from the
# same document. The seeded generator keeps the order reproducible.
dl_train = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size,
  shuffle=True,
  generator=torch.Generator().manual_seed(42))
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size)

In [None]:
%%time
# Evaluate the model before fine-tuning (reported as epoch 0): the classification head is still randomly initialised.
evaluate(sentence_classifier, 0, dl_test)

In [None]:
%%time
# Training params
n_epochs = 5
learning_rate = 5e-5

criterion = torch.nn.BCEWithLogitsLoss().to(device)
sentence_classifier.to(device)
optimizer = torch.optim.Adam(
  sentence_classifier.parameters(), 
  lr=learning_rate
)
for epoch in range(1, n_epochs + 1):
  print(f'*** Epoch {epoch} ***')
  for train_data in tqdm_notebook(dl_train, desc=f'Epoch {epoch}'):
    # training
    sentence_classifier.train()
    optimizer.zero_grad()
    y_hat = sentence_classifier(train_data['ids'], train_data['mask'])
    loss = criterion(y_hat.squeeze(), train_data['targets'])
    loss.backward()
    optimizer.step()
  # evaluation
  evaluate(sentence_classifier, epoch, dl_test)


In [None]:
report_dir = g_drive_dir + dataset_dir  # the report is written next to the dataset on Google Drive
save_report(fold_id=fold_id, metrics=test_metrics, dest_dir=report_dir)

### Outline

In [None]:
from IPython.display import display, HTML

# Build the table in one go (instead of growing it row by row), one row per
# epoch in ascending order, with metrics formatted to four decimals.
metrics_df = pd.DataFrame(
  [
    [epoch, f'{metrics["precision"]:.4f}', f'{metrics["recall"]:.4f}', f'{metrics["f1"]:.4f}']
    for epoch, metrics in sorted(test_metrics.items())
  ],
  columns=['Epoch', 'Precision', 'Recall', 'F1'],
)
display(HTML(f'<br><span style="font-weight: bold">Test metrics: fold {fold_id}</span>'))
metrics_display = display(metrics_df, display_id='metrics_table')

### References

- I. Chalkidis, M. Fergadiotis, P. Malakasiotis, N. Aletras and I. Androutsopoulos. **"LEGAL-BERT: The Muppets straight out of Law School"**. In Findings of Empirical Methods in Natural Language Processing (EMNLP 2020) (Short Papers), to be held online, 2020. (https://aclanthology.org/2020.findings-emnlp.261)
- Paheli Bhattacharya, Shounak Paul, Kripabandhu Ghosh, Saptarshi Ghosh, and Adam Wyner. 2019. **Identification of Rhetorical Roles of Sentences in Indian Legal Judgments**. In Proc. International Conference on Legal Knowledge and Information Systems (JURIX).