<a href="https://colab.research.google.com/github/abyaadrafid/LDA_Lab_Defence/blob/main/LegalBertClaudette.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
DEBUG = False
LR = 5e-6
MIN_LR = 1e-7
EPOCHS = 20
BS = 32

In [2]:
from google.colab import drive
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import random_split
from torch.optim import Adam
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from torch.utils.data import WeightedRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
import warnings

In [3]:
!pip install torch-lr-finder
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:
from tqdm.notebook import tqdm
from transformers import BertTokenizer, AutoTokenizer, AutoModelForPreTraining

In [5]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
MODEL_NAME = 'nlpaueb/legal-bert-small-uncased'
EMBEDDING_SIZE = 512

In [7]:
def get_sentences(path):
    sentences= []
    for filename in os.listdir(path):
        with open(path+filename, 'r') as f:
            for sentence in f :
                sentences.append(sentence)
    return sentences

In [8]:
def get_labels(path):
    all_labels = []
    for filename in os.listdir(path):
        file_labels = []
        with open(path+filename, 'r') as f:
            for label in f :
                all_labels.append(int(label))
    return all_labels

In [9]:
all_sentences = get_sentences("/content/drive/MyDrive/Sentences/")

In [10]:
all_labels = get_labels("/content/drive/MyDrive/Labels/")

In [11]:
if DEBUG : 
  all_sentences = all_sentences[:20]
  all_labels = all_labels[:20]

In [12]:
all_labels =  [0 if label ==-1 else label for label in all_labels]

In [13]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self,sentences ,labels):

        self.labels = labels
        self.texts = [tokenizer(text, 
                               padding='max_length', max_length = 512, truncation=True,
                                return_tensors="pt") for text in tqdm(all_sentences)]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [14]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [15]:
dataset = Dataset(all_sentences, all_labels)

  0%|          | 0/9414 [00:00<?, ?it/s]

In [16]:
dataset_size = dataset.__len__()
train_count = int(dataset_size * 0.8)
valid_count = dataset_size - train_count

In [17]:
train_ds, valid_ds = random_split(dataset, [train_count, valid_count])

In [18]:
y_train_indices = train_ds.indices

y_train = [dataset.labels[i] for i in y_train_indices]

class_sample_count = np.array(
    [len(np.where(y_train == t)[0]) for t in np.unique(y_train)])

In [19]:
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in y_train])
samples_weight = torch.from_numpy(samples_weight)

In [20]:
sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))

In [21]:
from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module):

    def __init__(self):

        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(MODEL_NAME)
        #self.dropout = nn.Dropout(0.3)
        self.l1 = nn.Linear(EMBEDDING_SIZE, 512)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(512, 2)

    def forward(self, input_id, mask):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        linear_output = self.l2(self.relu(self.l1(pooled_output)))

        return linear_output

In [22]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1., gamma=1.):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets, **kwargs):
        CE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-CE_loss)
        F_loss = self.alpha * ((1-pt)**self.gamma) * CE_loss
        return F_loss.mean()

In [23]:
use_cuda = torch.cuda.is_available()

In [24]:
device = torch.device("cuda" if use_cuda else "cpu")

In [25]:
model = BertClassifier().to(device)

Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [26]:
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BS, sampler= sampler)
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=BS)

In [27]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr = LR, amsgrad= True)
#scheduler = CosineAnnealingLR(optimizer, 600, eta_min = MIN_LR)

In [28]:
def train_loop(n_epochs) :
  tr_metrics = []
  va_metrics = []
  best_valid_f1 = 0.0
  model.train()
  for epoch in tqdm(range(n_epochs)) :
    train_loss = 0.0
    train_preds = []
    train_targets = []
    model.train()
    for inputs, targets in tqdm(train_dl) :
  
      masks = inputs['attention_mask'].to(device)
      input_ids = inputs['input_ids'].squeeze(1).to(device)
      outputs = model(input_ids, masks)
      oh_targets = F.one_hot(targets, num_classes=2).to(torch.float32).to(device)

      loss = criterion(outputs, oh_targets)
      train_loss += loss

      train_preds.extend(outputs.argmax(dim=1).cpu().numpy())
      train_targets.extend(targets.numpy())

      model.zero_grad()
      loss.backward()
      optimizer.step()
      #scheduler.step()

    valid_loss = 0.0
    valid_preds = []
    valid_targets = []
    model.eval()
    with torch.no_grad():
      for inputs, targets in tqdm(valid_dl) :
        masks = inputs['attention_mask'].to(device)
        input_ids = inputs['input_ids'].squeeze(1).to(device)
        outputs = model(input_ids, masks)
        oh_targets = F.one_hot(targets, num_classes=2).to(torch.float32).to(device)

        loss = criterion(outputs, oh_targets)
        valid_loss += loss

        valid_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        valid_targets.extend(targets.numpy())
    
    train_acc = accuracy_score(train_targets, train_preds)
    train_precision = precision_score(train_targets, train_preds)
    train_recall = recall_score(train_targets, train_preds)
    train_f1 = f1_score(train_targets, train_preds)

    valid_acc = accuracy_score(valid_targets, valid_preds)
    valid_precision = precision_score(valid_targets, valid_preds)
    valid_recall = recall_score(valid_targets, valid_preds)
    valid_f1 = f1_score(valid_targets, valid_preds)

    print(
        f'Epoch {epoch+1} : \n\
        Train_loss : {train_loss/len(train_ds)}\n\
        Train_acc : {train_acc}\n\
        Train_F1 : {train_f1}\n\
        Train_precision : {train_precision}\n\
        Train_recall : {train_recall}\n\
        Valid_loss : {valid_loss/len(valid_ds)}\n\
        Valid_acc : {valid_acc}\n\
        Valid_F1 : {valid_f1}\n\
        Valid_precision : {valid_precision}\n\
        Valid_recall : {valid_recall}'
      )
    if best_valid_f1 < valid_f1 :
      best_valid_f1 = valid_f1 
      print("Model Performance updated, saving model")
      model_name = "/content/drive/MyDrive/best_valid_f1.pt"
      torch.save(model.state_dict(), model_name)

In [None]:
train_loop(EPOCHS)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/236 [00:00<?, ?it/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Epoch 1 : 
        Train_loss : 0.021665608510375023
        Train_acc : 0.5270216438719958
        Train_F1 : 0.47323277136941727
        Train_precision : 0.5213424568263278
        Train_recall : 0.4332520985648524
        Valid_loss : 0.021079046651721
        Valid_acc : 0.6463090812533192
        Valid_F1 : 0.1569620253164557
        Valid_precision : 0.10671256454388985
        Valid_recall : 0.2966507177033493
Model Performance updated, saving model


  0%|          | 0/236 [00:00<?, ?it/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Epoch 2 : 
        Train_loss : 0.02138662151992321
        Train_acc : 0.5663258531403532
        Train_F1 : 0.5911367050575864
        Train_precision : 0.5636190021484841
        Train_recall : 0.6214793366675441
        Valid_loss : 0.02146766521036625
        Valid_acc : 0.5379713223579394
        Valid_F1 : 0.1730038022813688
        Valid_precision : 0.10794780545670225
        Valid_recall : 0.4354066985645933
Model Performance updated, saving model


  0%|          | 0/236 [00:00<?, ?it/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Epoch 3 : 
        Train_loss : 0.020820822566747665
        Train_acc : 0.6054972779179392
        Train_F1 : 0.6075815612204465
        Train_precision : 0.6038330270412182
        Train_recall : 0.6113769271664008
        Valid_loss : 0.024474384263157845
        Valid_acc : 0.3574083908656399
        Valid_F1 : 0.1868279569892473
        Valid_precision : 0.10867865519937452
        Valid_recall : 0.6650717703349283
Model Performance updated, saving model


  0%|          | 0/236 [00:00<?, ?it/s]