
![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)





[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/legal-nlp/07.1.Training_Legal_Assertion.ipynb)

# Train Legal Assertion


# Installation

In [None]:
# Install the johnsnowlabs package; -q keeps pip's output quiet.
! pip install -q johnsnowlabs

## Automatic Installation
Using my.johnsnowlabs.com SSO

In [None]:
from johnsnowlabs import nlp, legal

# nlp.install(force_browser=True)

## Manual downloading
If you are not registered at my.johnsnowlabs.com, received your license via email, or are using Safari, you may need to install the license manually.

- Go to my.johnsnowlabs.com
- Download your license
- Upload it using the following command

In [None]:
# Colab-only: prompt for the license file through Colab's upload helper.
from google.colab import files
print('Please Upload your John Snow Labs License using the button below')
# The value returned by files.upload() is kept in license_keys.
license_keys = files.upload()

- Install it

In [None]:
# Run the John Snow Labs installer (manual-license path described above).
nlp.install()

# Starting

In [None]:
# Start the session; `spark` is used below to build and read DataFrames.
spark = nlp.start()

# Data Prep 

In [None]:
# Download the assertion training CSV from the spark-nlp-workshop repository into the working directory.
! wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/legal-nlp/data/assertion_fin.csv

In [None]:
import pandas as pd

# Load the CSV downloaded in the previous cell into a pandas DataFrame.
training_df = pd.read_csv('./assertion_fin.csv')

In [None]:
# Convert the pandas frame to a Spark DataFrame and preview its first rows.
training_data = spark.createDataFrame(training_df)
training_data.show()

+-------+--------------------+---------+-------+--------------------+------+---------------+
|task_id|            sentence|tkn_start|tkn_end|               chunk|entity|assertion_label|
+-------+--------------------+---------+-------+--------------------+------+---------------+
|      1|The Swedish East ...|        1|      4|Swedish East Indi...|   ORG|           PAST|
|      1|The Swedish East ...|        6|      8|Svenska Ostindisk...| ALIAS|           PAST|
|      1|The Swedish East ...|       10|     10|                SOIC| ALIAS|           PAST|
|      1|The Swedish East ...|       14|     14|          Gothenburg|   LOC|           PAST|
|      1|The Swedish East ...|       15|     15|              Sweden|   LOC|           PAST|
|      1|The Swedish East ...|       17|     17|                1731|  DATE|           PAST|
|      1|The Swedish East ...|       25|     25|               China|   LOC|           PAST|
|      1|The Swedish East ...|       28|     29|            Far East| 

In [None]:
# Inspect the column names and types of the full dataset.
training_data.printSchema()

root
 |-- task_id: long (nullable = true)
 |-- sentence: string (nullable = true)
 |-- tkn_start: long (nullable = true)
 |-- tkn_end: long (nullable = true)
 |-- chunk: string (nullable = true)
 |-- entity: string (nullable = true)
 |-- assertion_label: string (nullable = true)



In [None]:
# Count the rows of the full dataset; %time reports how long the count takes.
%time training_data.count()

CPU times: user 12.2 ms, sys: 1.13 ms, total: 13.4 ms
Wall time: 830 ms


8050

In [None]:
# Split the full dataset with 0.9 / 0.1 weights into training and test sets (seed fixed at 100).
(train_data, test_data) = training_data.randomSplit([0.9, 0.1], seed = 100)
# Report the size of each split: train_data here, not the unsplit training_data.
print("Training Dataset Count: " + str(train_data.count()))
print("Test Dataset Count: " + str(test_data.count()))

Training Dataset Count: 8050
Test Dataset Count: 802


In [None]:
# Preview the first rows of the training split.
train_data.show()

+-------+--------------------+---------+-------+--------------------+------+---------------+
|task_id|            sentence|tkn_start|tkn_end|               chunk|entity|assertion_label|
+-------+--------------------+---------+-------+--------------------+------+---------------+
|      1|"Stockholms-varve...|        6|      6|           Stockholm|   LOC|           PAST|
|      1|"The funny busine...|        5|      8|Swedish East Indi...|   ORG|           PAST|
|      1|             (1998).|        0|      0|                1998|  DATE|           PAST|
|      1|2.5 tonnes) and t...|       34|     34|              Sweden|   LOC|           PAST|
|      1|37. Gothenburg: R...|        2|      7|Royal Society of ...|   ORG|           PAST|
|      1|= Decline and fal...|       11|     11|                1806|  DATE|           PAST|
|      1|= Early attempts ...|        9|     11|  Swedish East India|   ORG|           PAST|
|      1|= Early attempts ...|       19|     19|            merchant| 

# Using RoBerta Embeddings

In [None]:
# Pretrained RoBERTa embeddings: reads the "document" and "token" columns
# and writes the "embeddings" column.
roberta_embeddings = (
    nlp.RoBertaEmbeddings.pretrained("roberta_embeddings_legal_roberta_base", "en")
    .setInputCols(["document", "token"])
    .setOutputCol("embeddings")
    .setMaxSentenceLength(512)
)

roberta_embeddings_legal_roberta_base download started this may take some time.
Approximate size to download 447.2 MB
[OK!]


In [None]:
# Wrap the raw "sentence" column as a "document" annotation.
document = (
    nlp.DocumentAssembler()
    .setInputCol("sentence")
    .setOutputCol("document")
)

# Build the "doc_chunk" annotation from the "chunk" column, using
# "tkn_start" as a token index for the chunk start.
chunk = (
    nlp.Doc2Chunk()
    .setInputCols("document")
    .setOutputCol("doc_chunk")
    .setChunkCol("chunk")
    .setStartCol("tkn_start")
    .setStartColByTokenIndex(True)
    .setFailOnMissing(False)
    .setLowerCase(False)
)

# Tokenize the document into the "token" column.
token = (
    nlp.Tokenizer()
    .setInputCols(['document'])
    .setOutputCol('token')
)


We save the test data in parquet format to use in `AssertionDLApproach()`. 

In [None]:
# Preprocessing pipeline: document -> chunk -> tokens -> embeddings.
assertion_pipeline = nlp.Pipeline(stages=[document, chunk, token, roberta_embeddings])

# Add the annotation and embedding columns to the held-out test split.
assertion_test_data = (
    assertion_pipeline
    .fit(test_data)
    .transform(test_data)
)

In [None]:
# Check that the pipeline added document, doc_chunk, token and embeddings columns.
assertion_test_data.columns

['task_id',
 'sentence',
 'tkn_start',
 'tkn_end',
 'chunk',
 'entity',
 'assertion_label',
 'document',
 'doc_chunk',
 'token',
 'embeddings']

In [None]:
# Persist the preprocessed test split; the trainer reads this path via setTestDataset().
assertion_test_data.write.mode('overwrite').parquet('test_data.parquet')

In [None]:
# Preprocess and persist only the training split (train_data). Using the
# unsplit training_data here would feed the held-out test rows to the trainer
# and inflate the reported test metrics.
assertion_train_data = assertion_pipeline.fit(train_data).transform(train_data)
assertion_train_data.write.mode('overwrite').parquet('train_data.parquet')

In [None]:
# Check that the training frame carries the same preprocessed columns as the test frame.
assertion_train_data.columns

['task_id',
 'sentence',
 'tkn_start',
 'tkn_end',
 'chunk',
 'entity',
 'assertion_label',
 'document',
 'doc_chunk',
 'token',
 'embeddings']

## Graph setup

In [None]:
# TensorFlow is pinned to 2.7.0; the note below says TFGraphBuilder needs TensorFlow <= 2.7.
!pip install -q tensorflow==2.7.0
# NOTE(review): tensorflow-addons is unpinned — consider pinning a release compatible with TF 2.7.
!pip install -q tensorflow-addons

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m489.6/489.6 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m463.1/463.1 KB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m72.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25h

We will use TFGraphBuilder annotator which can be used to create graphs in the model training pipeline. 

TFGraphBuilder inspects the data and creates the proper graph if a suitable version of TensorFlow (<= 2.7) is available. The graph is stored in the defined folder and loaded by the approach.

In [None]:
# Folder shared by TFGraphBuilder (writes the graph) and AssertionDLApproach (loads it).
graph_folder= "./tf_graphs"

In [None]:
# Graph builder for the "assertion_dl" model; writes assertion_graph.pb into
# graph_folder. The max sequence length matches setMaxSentLen() on the trainer.
assertion_graph_builder = (
    legal.TFGraphBuilder()
    .setModelName("assertion_dl")
    .setInputCols(["sentence", "token", "embeddings"])
    .setLabelColumn("assertion_label")
    .setGraphFolder(graph_folder)
    .setGraphFile("assertion_graph.pb")
    .setMaxSequenceLength(1200)
    .setHiddenUnitsNumber(25)
)

**Setting the Scope Window (Target Area) Dynamically in Assertion Status Detection Models**


This parameter allows you to train the Assertion Status Models to focus on specific context windows when resolving the status of a NER chunk. The window has the format `[X,Y]`, where `X` is the maximum number of tokens to consider to the left of the chunk and `Y` is the maximum number of tokens to consider to the right. Let’s take a look at what different windows mean:


*   By default, the window is `[-1,-1]` which means that the Assertion Status will look at all of the tokens in the sentence/document (up to a maximum of tokens set in `setMaxSentLen()` ).
*   `[0,0]` means “don’t pay attention to any token except the ner_chunk”, which basically means no context is considered when resolving the assertion status.
*   `[9,15]` is what empirically seems to be the best baseline, meaning that we look up to 9 tokens on the left and 15 on the right of the ner chunk to understand the context and resolve the status.


Check this [Scope Window Tuning Assertion Status Detection notebook](https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/2.1.Scope_window_tuning_assertion_status_detection.ipynb)  that illustrates the effect of the different windows and how to properly fine-tune your AssertionDLModels to get the best of them.

In our case, the best Scope Window is around [10,10]

In [None]:
# Context window as [tokens to the left, tokens to the right] of the chunk.
# NOTE(review): the text above says the best window is around [10,10], but
# this run uses [50, 50] — confirm which value is intended.
scope_window = [50, 50]

# Assertion trainer: uses the graph produced by TFGraphBuilder and evaluates
# each epoch against the preprocessed test parquet written earlier.
assertionStatus = (
    legal.AssertionDLApproach()
    .setLabelCol("assertion_label")
    .setInputCols("document", "doc_chunk", "embeddings")
    .setOutputCol("assertion")
    .setBatchSize(128)
    .setLearningRate(0.001)
    .setEpochs(2)
    .setStartCol("tkn_start")
    .setEndCol("tkn_end")
    .setMaxSentLen(1200)
    .setEnableOutputLogs(True)
    .setOutputLogsPath('training_logs/')
    .setGraphFolder(graph_folder)
    .setGraphFile(f"{graph_folder}/assertion_graph.pb")
    .setTestDataset(path="test_data.parquet", read_as='SPARK', options={'format': 'parquet'})
    .setScopeWindow(scope_window)
)
# Optional settings left disabled in this run:
#   .setValidationSplit(0.2)
#   .setDropout(0.1)

In [None]:
# Training pipeline. The document, doc_chunk, token and embeddings columns are
# already present in the preprocessed parquet data, so only the graph builder
# and the assertion trainer are staged here.
clinical_assertion_pipeline = nlp.Pipeline(stages=[assertion_graph_builder, assertionStatus])

In [None]:
# Schema of the raw (not yet preprocessed) dataset, for reference.
training_data.printSchema()

root
 |-- task_id: long (nullable = true)
 |-- sentence: string (nullable = true)
 |-- tkn_start: long (nullable = true)
 |-- tkn_end: long (nullable = true)
 |-- chunk: string (nullable = true)
 |-- entity: string (nullable = true)
 |-- assertion_label: string (nullable = true)



In [None]:
# Reload the preprocessed training data from the parquet file written earlier.
assertion_train_data = spark.read.parquet('train_data.parquet')

In [None]:
%%time
# Fit the training pipeline (graph builder, then assertion trainer).
# Keep %%time as the first line of the cell: cell magics must come first.
assertion_model = clinical_assertion_pipeline.fit(assertion_train_data)

TF Graph Builder configuration:
Model name: assertion_dl
Graph folder: ./tf_graphs
Graph file name: assertion_graph.pb
Build params: {'n_classes': 4, 'feat_size': 768, 'max_seq_len': 1200, 'n_hidden': 25}


Instructions for updating:
non-resource variables are not supported in the long term


Device mapping: no known devices.


Instructions for updating:
Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Device mapping: no known devices.
assertion_dl graph exported to ./tf_graphs/assertion_graph.pb
CPU times: user 15.1 s, sys: 1.48 s, total: 16.6 s
Wall time: 19min 4s


Checking the results saved in the log file

In [None]:
import os

# List the log files in ./training_logs, the path given to setOutputLogsPath().
log_files = os.listdir("./training_logs")
log_files

['AssertionDLApproach_2d70e95b57ae.log']

In [None]:
import os

# os.listdir() returns entries in arbitrary order, so when several runs have
# left logs behind, print the most recently modified one rather than log_files[0].
log_dir = "./training_logs"
latest_log = max(
    (os.path.join(log_dir, name) for name in log_files),
    key=os.path.getmtime,
)
with open(latest_log) as log_file:
    print(log_file.read())

Name of the selected graph: ./tf_graphs/assertion_graph.pb
Training started, trainExamples: 8050


Epoch: 0 started, learning rate: 0.001, dataset size: 8050
Done, 500.450298869 total training loss: 52.26979, avg training loss: 0.8296792, batches: 63
Quality on test dataset: 
time to finish evaluation: 44.26s
Total test loss: 2.1977	Avg test loss: 0.3140
label	 tp	 fp	 fn	 prec	 rec	 f1
PRESENT	 219	 26	 22	 0.89387757	 0.9087137	 0.90123457
POSSIBLE	 166	 27	 22	 0.8601036	 0.88297874	 0.87139106
FUTURE	 95	 6	 25	 0.9405941	 0.7916667	 0.8597286
PAST	 234	 29	 19	 0.88973385	 0.9249012	 0.90697676
tp: 714 fp: 88 fn: 88 labels: 4
Macro-average	 prec: 0.8960773, rec: 0.8770651, f1: 0.8864693
Micro-average	 prec: 0.8902743, rec: 0.8902743, f1: 0.8902743


Epoch: 1 started, learning rate: 9.5E-4, dataset size: 8050
Done, 508.122675912 total training loss: 17.050142, avg training loss: 0.27063718, batches: 63
Quality on test dataset: 
time to finish evaluation: 43.71s
Total test loss: 1.1

In [None]:
# Reload the preprocessed test split from the parquet file written earlier.
assertion_test_data = spark.read.parquet('test_data.parquet')

In [None]:
# Run the trained model on the test split, keeping the gold label next to the
# predicted label(s) from the "assertion" output column.
preds = (
    assertion_model
    .transform(assertion_test_data)
    .select('assertion_label', 'assertion.result')
)

preds.show()

+---------------+---------+
|assertion_label|   result|
+---------------+---------+
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|        PRESENT|[PRESENT]|
|           PAST|   [PAST]|
|        PRESENT|[PRESENT]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST| [FUTURE]|
|        PRESENT|[PRESENT]|
|        PRESENT|[PRESENT]|
|        PRESENT|[PRESENT]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
+---------------+---------+
only showing top 20 rows



In [None]:
# Convert the predictions to a pandas DataFrame for evaluation with scikit-learn.
preds_df = preds.toPandas()

In [None]:
def _first_label(value):
    """Return the single label held in one cell of the `result` column.

    A non-empty sequence yields its first element and an empty one yields
    `pd.NA`. A value that is already a plain string (or already missing) is
    passed through unchanged, so re-running this cell does not truncate
    labels such as "PAST" to "P".
    """
    if isinstance(value, str):
        return value
    if value is None or value is pd.NA:
        return pd.NA
    return value[0] if len(value) else pd.NA


# Build a new frame instead of mutating in place, so the cell is idempotent;
# rows without a prediction (or with any other missing value) are dropped.
preds_df = preds_df.assign(result=preds_df["result"].apply(_first_label)).dropna()

preds_df

Unnamed: 0,assertion_label,result
0,PAST,PAST
1,PAST,PAST
2,PAST,PAST
3,PRESENT,PRESENT
4,PAST,PAST
...,...,...
797,PRESENT,PRESENT
798,PRESENT,PRESENT
799,PRESENT,PRESENT
800,PRESENT,PRESENT


In [None]:
# Evaluate the predictions on the test dataset with scikit-learn.
from sklearn.metrics import classification_report

print(
    classification_report(
        y_true=preds_df['assertion_label'],
        y_pred=preds_df['result'],
    )
)

              precision    recall  f1-score   support

      FUTURE       0.90      0.97      0.94       120
        PAST       0.95      0.96      0.95       250
    POSSIBLE       0.96      0.94      0.95       188
     PRESENT       0.97      0.94      0.95       241

    accuracy                           0.95       799
   macro avg       0.95      0.95      0.95       799
weighted avg       0.95      0.95      0.95       799



### Saving the trained model

In [None]:
# Show the fitted stages; the last one is the trained assertion model saved below.
assertion_model.stages

[TFGraphBuilderModel_2ab67fe97fd2, LEGAL-ASSERTION_DL_6b6a15bdecd2]

In [None]:
# Save a Spark NLP model
assertion_model.stages[-1].write().overwrite().save('Assertion')

# cd into saved dir and zip
! cd /content/Assertion ; zip -r /content/Assertion.zip *