In [1]:
#FINE TUNING
#1. Model -> bert-large-uncased-whole-word-masking-finetuned-squad
#2. Dataset -> Legal Dataset
#3. Batch size -> 16

#Steps
#1. Loading the Dataset
#2. Processing the Data
#3. Fine Tuning
#4. Evaluation

In [2]:
!pip install datasets
!pip install transformers

Collecting datasets
[?25l  Downloading https://files.pythonhosted.org/packages/46/1a/b9f9b3bfef624686ae81c070f0a6bb635047b17cdb3698c7ad01281e6f9a/datasets-1.6.2-py3-none-any.whl (221kB)
[K     |█▌                              | 10kB 14.1MB/s eta 0:00:01[K     |███                             | 20kB 17.8MB/s eta 0:00:01[K     |████▍                           | 30kB 15.4MB/s eta 0:00:01[K     |██████                          | 40kB 14.0MB/s eta 0:00:01[K     |███████▍                        | 51kB 12.2MB/s eta 0:00:01[K     |████████▉                       | 61kB 12.9MB/s eta 0:00:01[K     |██████████▍                     | 71kB 10.4MB/s eta 0:00:01[K     |███████████▉                    | 81kB 11.1MB/s eta 0:00:01[K     |█████████████▎                  | 92kB 10.5MB/s eta 0:00:01[K     |██████████████▊                 | 102kB 10.4MB/s eta 0:00:01[K     |████████████████▎               | 112kB 10.4MB/s eta 0:00:01[K     |█████████████████▊              | 122kB 10

**Fine Tuning Parameters**

In [3]:
#Pretrained Model Name
model = "bert-large-uncased-whole-word-masking-finetuned-squad"

#Dataset Files
training_file = "train_dataset.json"
validation_file = "validate_dataset.json"
test_file = "test_dataset.json"

#Batch_size for training
batch_size = 16

**Loading Dataset**

In [57]:
#Loading the Dataset
from datasets import load_dataset
dataset = load_dataset('json', data_files={'train': training_file, 'validation': validation_file, 'test': test_file}, field='data')
print(dataset)

Using custom data configuration default-df13f75c8e6e1c60


Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/json/default-df13f75c8e6e1c60/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/json/default-df13f75c8e6e1c60/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02. Subsequent calls will reuse this data.
DatasetDict({
    train: Dataset({
        features: ['title', 'id', 'context', 'question', 'answer'],
        num_rows: 2365
    })
    validation: Dataset({
        features: ['title', 'id', 'context', 'question', 'answer'],
        num_rows: 79
    })
    test: Dataset({
        features: ['title', 'id', 'context', 'question', 'answer'],
        num_rows: 119
    })
})


**Processing Data**

In [5]:
#Processing the Data
import transformers
from transformers import AutoTokenizer

#Fetching the Fast Tokenizer of the respective model
tokenizer = AutoTokenizer.from_pretrained(model)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

#Function to prepare the train features for training
def prepare_train_features(examples):
    max_length = 480    #max length of the input(question+context)
    doc_stride = 128    #length of overlap between consecutive features of the same example

    # Tokenizing with truncation and padding, but keeping the overflows using a stride.
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # One example might give us several features if it has a long context, therefore,
    # we need a map from the feature to its corresponding example
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # The offset mappings will give us a map from token to character positions in the original context
    offset_mapping = tokenized_examples.pop("offset_mapping")

    #labeling of the samples
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]

        # We will label impossible answers with the index of the CLS token.
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the original example number containing this span of text.
        sample_index = sample_mapping[i]
        answer = examples["answer"][sample_index]
        
        # Start/end character index of the answer in the text.
        start_char = answer["answer_start"]
        end_char = answer["answer_end"] + 1

        # Start token index of the current context in the text.
        token_start_index = 0
        while sequence_ids[token_start_index] != 1:
            token_start_index += 1

        # End token index of the current context in the text.
        token_end_index = len(input_ids) - 1
        while sequence_ids[token_end_index] != 1:
            token_end_index -= 1

        # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
        if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
            # Note: we could go after the last offset if the answer is the last word (edge case).
            while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                token_start_index += 1
            tokenized_examples["start_positions"].append(token_start_index - 1)
            while offsets[token_end_index][1] >= end_char:
                token_end_index -= 1
            tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=443.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




In [6]:
#Preparing the tokenized data input
tokenized_dataset = dataset.map(prepare_train_features, batched=True, remove_columns=dataset["train"].column_names)

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




**Fine Tuning**

In [7]:
#Fine Tuning
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
model = AutoModelForQuestionAnswering.from_pretrained(model)

#Not changing the base parameters of the model(Non-Task specific layer)
for param in model.base_model.parameters():
    param.requires_grad = False

#Defining the Training Arguments
args = TrainingArguments(
    f"test-squad",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=4,
    weight_decay=0.01,
    logging_dir = './logs',
    logging_steps = 10
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1340675298.0, style=ProgressStyle(descr…




In [8]:
#Fetching the Data Collator to batch the processed examples
from transformers import default_data_collator
data_collator = default_data_collator

In [9]:
#Trainer Object for training
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

#Train the model (with evaluation loss only)
trainer.train()

Epoch,Training Loss,Validation Loss,Runtime,Samples Per Second
1,0.5924,0.289995,16.8729,4.682
2,0.5089,0.275751,16.898,4.675
3,0.6359,0.267938,16.8855,4.679
4,0.5231,0.265579,16.9044,4.673


TrainOutput(global_step=604, training_loss=0.5449028508552652, metrics={'train_runtime': 2219.5599, 'train_samples_per_second': 0.272, 'total_flos': 9298620525404160.0, 'epoch': 4.0, 'init_mem_cpu_alloc_delta': 283545600, 'init_mem_gpu_alloc_delta': 1337192960, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 12804096, 'train_mem_gpu_alloc_delta': 211968, 'train_mem_cpu_peaked_delta': 2162688, 'train_mem_gpu_peaked_delta': 1589934080})

In [10]:
#Saving the trained model
trainer.save_model("LegalTrained")

**Evaluation**

In [58]:
#Evaluation
#Function to prepare the validation features for evaluation
def prepare_test_features(examples):
    max_length = 480    #max length of input(question + context)
    doc_stride = 128    #length of overlap between consecutive features of the same example

    # Tokenizing with truncation and padding, but keeping the overflows using a stride
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # One example might give us several features if it has a long context, therefore,
    # we need a map from a feature to its corresponding example
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the original example number containing this context
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == 1 else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

In [59]:
#preparing the validation features
test_features = dataset["test"].map(
    prepare_test_features,
    batched=True,
    remove_columns=dataset["test"].column_names
)

#Getting the predictions on the test set
raw_predictions = trainer.predict(test_features)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [60]:
#Getting the column names of the test_features
test_features.set_format(type=test_features.format["type"], columns=list(test_features.features.keys()))

In [61]:
from tqdm.auto import tqdm
import numpy as np
import collections

#Function to find the best possible answers using the raw predictions
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    all_start_logits, all_end_logits = raw_predictions
    # Building a mapping of examples to its feature list
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # final predictions to be evaluated
    predictions = collections.OrderedDict()

    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []
        
        context = example["context"]
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update minimum null prediction.
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )
        
        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}
        
        predictions[example["id"]] = best_answer["text"]
    return predictions

In [62]:
#finding the final predictions (the best possible answers)
final_predictions = postprocess_qa_predictions(dataset["test"], test_features, raw_predictions.predictions)

HBox(children=(FloatProgress(value=0.0, max=119.0), HTML(value='')))




In [75]:
#Preparing the prediction and reference items for evaluation
predictions = [v for k, v in final_predictions.items()]
references = []
for ex in dataset['test']:
    references.append(ex["context"][ex["answer"]["answer_start"] : ex["answer"]["answer_end"]+1])

In [78]:
#Evaluation Metric Function
import re
from collections import Counter

#Function to compute the Accuracy and the F1 score
#Predictions -> Predictions by the model
#References -> Reference items, to compare with the predictions
def compute(predictions, references):
    predictions_score = 0
    common_cnt = 0
    predicted_cnt = 0
    reference_cnt = 0
    sample_pred = list(predictions)
    sample_ref = list(references)

    for i in range(len(references)) :
        #Converting all the lower case
        sample_ref[i] = sample_ref[i].lower()
        sample_pred[i] = sample_pred[i].lower()

        #Removing the punctuations
        sample_ref[i] = re.sub(r'[^\w\s]', ' ', sample_ref[i])
        sample_pred[i] = re.sub(r'[^\w\s]', ' ', sample_pred[i])

        print(sample_ref[i])
        print(sample_pred[i])
        #Fetching the individual tokens
        sample_ref[i] = sample_ref[i].split()
        sample_pred[i] = sample_pred[i].split()

        #Finding the number of common words between the predicted item and the reference item
        cnt_common = sum(( Counter(sample_ref[i]) & Counter(sample_pred[i]) ).values())

        #Evaluating the difference counts for f1 measurement
        common_cnt += cnt_common
        predicted_cnt += len(sample_pred[i])
        reference_cnt += len(sample_ref[i])
        #Evaluating the match fraction of the prediction with the reference
        match_ratio = cnt_common / len(sample_ref[i])

        #Evaluating the scores based on the amount of match
        if match_ratio >= 0.25 and match_ratio < 0.5 :
            predictions_score += 1
        elif match_ratio >= 0.5 and match_ratio < 0.75 :
            predictions_score += 2
        elif match_ratio >= 0.75 and match_ratio < 1 :
            predictions_score += 3
        elif match_ratio == 1 :
            predictions_score += 4

    #Total score for references
    reference_score = 4 * len(sample_ref)

    #Various Evaluation Metrics
    accuracy = predictions_score / reference_score * 100
    precision = common_cnt / predicted_cnt
    recall = common_cnt / reference_cnt
    F1 = ((2 * precision * recall) / (precision + recall)) * 100

    #Sending the result
    metric = dict()
    metric['accuracy'] = accuracy
    metric['precision'] = precision
    metric['recall'] = recall
    metric['f1'] = F1
    return metric

In [79]:
#Evaluating the results
result = compute(predictions, references)
print("Accuracy : " + str(result['accuracy']))
print("Precision : " + str(result['precision']))
print("Recall : " + str(result['recall']))
print("F1 : " + str(result['f1']))

1982
1982
the hague
the hague
the convention for the suppression of unlawful seizure of aircraft
the convention for the suppression of unlawful seizure of aircraft
anti hijacking act  1982
the anti hijacking act  1982
official gazette
official gazette
2016
2016
the central government
the central government
any offence thereunder committed outside india by any person
any offence thereunder committed outside india by any person
the whole of india
the whole of india
security personnel
security personnel
national investigation agency act  2008
national investigation agency act  2008
unlawful seizure of aircraft in flight
unlawful seizure of aircraft in flight
hostage taking
hostage taking
hijacking
hijacking
criminal prosecution
criminal prosecution
sub  section  1 
sub  section  1 
hijacking
hijacking
twenty four hours
twenty four hours
imprisonment for life
imprisonment for life
death
death
india
india
all officers of police and all officers of government
all officers of police and all o