![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

# Train Legal Assertion

In [0]:
from johnsnowlabs import * 

# Data Prep

In [0]:
! wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Legal/data/assertion_fin.csv
dbutils.fs.cp("file:/databricks/driver/assertion_fin.csv", "dbfs:/Finance") 

In [0]:
import pandas as pd

training_df = pd.read_csv('/dbfs/Finance/assertion_fin.csv')

In [0]:
training_data = spark.createDataFrame(training_df)
training_data.show()

In [0]:
training_data.printSchema()

In [0]:
%time training_data.count()

In [0]:
(train_data, test_data) = training_data.randomSplit([0.9, 0.1], seed = 100)
print("Training Dataset Count: " + str(training_data.count()))
print("Test Dataset Count: " + str(test_data.count()))

In [0]:
train_data.show()

# Using RoBerta Embeddings

In [0]:
roberta_embeddings = nlp.RoBertaEmbeddings.pretrained("roberta_embeddings_legal_roberta_base","en") \
    .setInputCols(["document", "token"]) \
    .setOutputCol("embeddings") \
    .setMaxSentenceLength(512)

In [0]:
document = nlp.DocumentAssembler()\
    .setInputCol("sentence")\
    .setOutputCol("document")

chunk = nlp.Doc2Chunk()\
    .setInputCols("document")\
    .setOutputCol("doc_chunk")\
    .setChunkCol("chunk")\
    .setStartCol("tkn_start")\
    .setStartColByTokenIndex(True)\
    .setFailOnMissing(False)\
    .setLowerCase(False)

token = nlp.Tokenizer()\
    .setInputCols(['document'])\
    .setOutputCol('token')


We save the test data in parquet format to use in `AssertionDLApproach()`.

In [0]:
assertion_pipeline = Pipeline(
    stages = [
    document,
    chunk,
    token,
    roberta_embeddings])

assertion_test_data = assertion_pipeline.fit(test_data).transform(test_data)

In [0]:
assertion_test_data.columns

Out[14]: ['task_id',
 'sentence',
 'tkn_start',
 'tkn_end',
 'chunk',
 'entity',
 'assertion_label',
 'document',
 'doc_chunk',
 'token',
 'embeddings']

In [0]:
assertion_test_data.write.mode('overwrite').parquet('/dbfs/test_data.parquet')

In [0]:
assertion_train_data = assertion_pipeline.fit(training_data).transform(training_data)
assertion_train_data.write.mode('overwrite').parquet('/dbfs/train_data.parquet')

In [0]:
assertion_train_data.columns

Out[17]: ['task_id',
 'sentence',
 'tkn_start',
 'tkn_end',
 'chunk',
 'entity',
 'assertion_label',
 'document',
 'doc_chunk',
 'token',
 'embeddings']

## Graph setup

We will use TFGraphBuilder annotator which can be used to create graphs in the model training pipeline. 

TFGraphBuilder inspects the data and creates the proper graph if a suitable version of TensorFlow (<= 2.7 ) is available. The graph is stored in the defined folder and loaded by the approach.

In [0]:
graph_folder= "/dbfs/tf_graphs"

In [0]:
assertion_graph_builder =  legal.TFGraphBuilder()\
    .setModelName("assertion_dl")\
    .setInputCols(["sentence", "token", "embeddings"]) \
    .setLabelColumn("assertion_label")\
    .setGraphFolder(graph_folder)\
    .setGraphFile("assertion_graph.pb")\
    .setMaxSequenceLength(1200)\
    .setHiddenUnitsNumber(25)

**Setting the Scope Window (Target Area) Dynamically in Assertion Status Detection Models**


This parameter allows you to train the Assertion Status Models to focus on specific context windows when resolving the status of a NER chunk. The window is in format `[X,Y]` being `X` the number of tokens to consider on the left of the chunk, and `Y` the max number of tokens to consider on the right. Let’s take a look at what different windows mean:


*   By default, the window is `[-1,-1]` which means that the Assertion Status will look at all of the tokens in the sentence/document (up to a maximum of tokens set in `setMaxSentLen()` ).
*   `[0,0]` means “don’t pay attention to any token except the ner_chunk”, what basically is not considering any context for the Assertion resolution.
*   `[9,15]` is what empirically seems to be the best baseline, meaning that we look up to 9 tokens on the left and 15 on the right of the ner chunk to understand the context and resolve the status.


Check this [Scope Window Tuning Assertion Status Detection notebook](https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/2.1.Scope_window_tuning_assertion_status_detection.ipynb)  that illustrates the effect of the different windows and how to properly fine-tune your AssertionDLModels to get the best of them.

In our case, the best Scope Window is around [10,10]

In [0]:
scope_window = [50, 50]

assertionStatus = legal.AssertionDLApproach()\
    .setLabelCol("assertion_label")\
    .setInputCols("document", "doc_chunk", "embeddings")\
    .setOutputCol("assertion")\
    .setBatchSize(128)\
    .setLearningRate(0.001)\
    .setEpochs(2)\
    .setStartCol("tkn_start")\
    .setEndCol("tkn_end")\
    .setMaxSentLen(1200)\
    .setEnableOutputLogs(True)\
    .setOutputLogsPath('dbfs:/training_logs/')\
    .setGraphFolder(graph_folder)\
    .setGraphFile(f"{graph_folder}/assertion_graph.pb")\
    .setTestDataset(path="dbfs:/test_data.parquet", read_as='SPARK', options={'format': 'parquet'})\
    .setScopeWindow(scope_window)
    #.setValidationSplit(0.2)\    
    #.setDropout(0.1)\    

In [0]:
clinical_assertion_pipeline = Pipeline(
    stages = [
    #document,
    #chunk,
    #token,
    #embeddings,
    assertion_graph_builder,
    assertionStatus])

In [0]:
training_data.printSchema()

In [0]:
assertion_train_data = spark.read.parquet('/dbfs/train_data.parquet')

In [0]:
%%time
assertion_model = clinical_assertion_pipeline.fit(assertion_train_data)

Checking the results saved in the log file

In [0]:
import os

log_files = os.listdir("/dbfs/training_logs")
log_files

Out[27]: ['AssertionDLApproach_6026f20884ae.log']

In [0]:
with open("/dbfs/training_logs/"+log_files[0]) as log_file:
    print(log_file.read())

Name of the selected graph: ./tf_graphs/assertion_graph.pb
Training started, trainExamples: 8050


Epoch: 0 started, learning rate: 0.001, dataset size: 8050
Done, 479.809131515 total training loss: 62.679592, avg training loss: 0.9949142, batches: 63
Quality on test dataset: 
time to finish evaluation: 41.41s
Total test loss: 2.4015	Avg test loss: 0.3431
label	 tp	 fp	 fn	 prec	 rec	 f1
PRESENT	 216	 26	 25	 0.892562	 0.89626557	 0.8944099
POSSIBLE	 158	 6	 30	 0.9634146	 0.84042555	 0.89772725
FUTURE	 95	 12	 25	 0.88785046	 0.7916667	 0.8370044
PAST	 240	 49	 13	 0.8304498	 0.9486166	 0.88560885
tp: 709 fp: 93 fn: 93 labels: 4
Macro-average	 prec: 0.89356923, rec: 0.8692436, f1: 0.8812386
Micro-average	 prec: 0.8840399, rec: 0.8840399, f1: 0.8840399


Epoch: 1 started, learning rate: 9.5E-4, dataset size: 8050
Done, 475.062305214 total training loss: 18.358475, avg training loss: 0.29140437, batches: 63
Quality on test dataset: 
time to finish evaluation: 41.80s
Total test loss: 1.3

In [0]:
assertion_test_data = spark.read.parquet('/dbfs/test_data.parquet')

In [0]:
preds = assertion_model.transform(assertion_test_data).select('assertion_label','assertion.result')

preds.show()

+---------------+---------+
|assertion_label|   result|
+---------------+---------+
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|        PRESENT|[PRESENT]|
|           PAST|   [PAST]|
|        PRESENT|[PRESENT]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
|           PAST| [FUTURE]|
|        PRESENT|   [PAST]|
|        PRESENT|   [PAST]|
|        PRESENT|[PRESENT]|
|           PAST|   [PAST]|
|           PAST|   [PAST]|
+---------------+---------+
only showing top 20 rows



In [0]:
preds_df = preds.toPandas()

In [0]:
preds_df["result"] = preds_df["result"].apply(lambda x: x[0] if len(x) else pd.NA)
preds_df.dropna(inplace=True)

preds_df

Unnamed: 0,assertion_label,result
0,PAST,PAST
1,PAST,PAST
2,PAST,PAST
3,PRESENT,PRESENT
4,PAST,PAST
...,...,...
797,PRESENT,PRESENT
798,PRESENT,PRESENT
799,PRESENT,PRESENT
800,PRESENT,PRESENT


In [0]:
# We are going to use sklearn to evalute the results on test dataset
from sklearn.metrics import classification_report

print (classification_report( preds_df['assertion_label'], preds_df['result']))

              precision    recall  f1-score   support

      FUTURE       0.90      0.97      0.94       120
        PAST       0.87      0.98      0.92       250
    POSSIBLE       0.99      0.90      0.94       188
     PRESENT       0.97      0.86      0.91       241

    accuracy                           0.93       799
   macro avg       0.93      0.93      0.93       799
weighted avg       0.93      0.93      0.93       799



### Saving the trained model

In [0]:
assertion_model.stages

Out[34]: [TFGraphBuilderModel_bf68c0049344, FINANCE-ASSERTION_DL_70bbc388fb95]

In [0]:
# Save a Spark NLP model
assertion_model.stages[-1].write().overwrite().save('Assertion')

  adding: fields/ (stored 0%)
  adding: fields/datasetParams/ (stored 0%)
  adding: fields/datasetParams/.part-00003.crc (deflated 45%)
  adding: fields/datasetParams/.part-00002.crc (stored 0%)
  adding: fields/datasetParams/part-00000 (deflated 26%)
  adding: fields/datasetParams/.part-00000.crc (stored 0%)
  adding: fields/datasetParams/part-00002 (deflated 26%)
  adding: fields/datasetParams/._SUCCESS.crc (stored 0%)
  adding: fields/datasetParams/part-00003 (deflated 95%)
  adding: fields/datasetParams/part-00001 (deflated 27%)
  adding: fields/datasetParams/.part-00001.crc (stored 0%)
  adding: fields/datasetParams/_SUCCESS (stored 0%)
  adding: metadata/ (stored 0%)
  adding: metadata/part-00000 (deflated 38%)
  adding: metadata/.part-00000.crc (stored 0%)
  adding: metadata/._SUCCESS.crc (stored 0%)
  adding: metadata/_SUCCESS (stored 0%)
  adding: tensorflow (deflated 40%)
