![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/legal-nlp/06.2.Relation_Extraction_Training.ipynb)

# Relation Extraction Model Training


This is our BertSpan-based Relation Extraction model, based on [this paper](https://arxiv.org/abs/1907.10529) and implemented by John Snow Labs on TensorFlow 1.x.

Unfortunately, Google Colab discontinued support for TF 1.x in November 2022.

**We are working on the TF 2.x version of it.**

In the meantime, please use a non-Colab environment with Jupyter and TF 1.x.

Training on a GPU machine will reduce your training time.

In [None]:
! nvidia-smi

Fri Jan  6 21:16:04 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   30C    P0    23W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [None]:
! nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89


# 1.1. Loading the license keys

In [None]:
import json
import os
from os.path import exists

with open('your_license_path', 'r') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)

# Adding license key-value pairs to environment variables
os.environ.update(license_keys)
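
The cells that follow read `SECRET`, `PUBLIC_VERSION`, and `JSL_VERSION` from this file. A minimal sketch of a sanity check, assuming the license JSON is a flat object of string values; the helper name `missing_license_keys` is hypothetical and the sample values are placeholders:

```python
import json

# Keys that later cells in this notebook reference
REQUIRED_KEYS = ("SECRET", "PUBLIC_VERSION", "JSL_VERSION")

def missing_license_keys(license_keys, required=REQUIRED_KEYS):
    """Return the required keys that are absent or empty (hypothetical helper)."""
    return [k for k in required if not license_keys.get(k)]

# Placeholder content standing in for a real license file
sample = json.loads('{"SECRET": "xxx", "PUBLIC_VERSION": "4.2.4"}')
print(missing_license_keys(sample))
```

Running such a check before the `pip install` cells makes a missing key visible early instead of surfacing as an empty shell variable.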

# 1.2. Installing Spark NLP (licensed)

In [None]:
# Installing pyspark and spark-nlp
! pip install --upgrade pyspark==3.1.2 spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install spark-nlp-display

# 1.3. Starting Spark NLP

In [None]:
import pandas as pd
import requests
import json
from zipfile import ZipFile
from io import BytesIO
import os
from pyspark.ml import Pipeline,PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.base import *
import sparknlp_jsl
import sparknlp

import warnings
warnings.filterwarnings('ignore')

print("Spark NLP Version :", sparknlp.version())
print("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark = sparknlp_jsl.start(license_keys['SECRET'])

spark

Spark NLP Version : 4.2.4
Spark NLP_JSL Version : 4.2.4


# Relation Extraction training using TensorFlow 1.x and BERT

# 2. Download BERT code implementation and BERT weights
In this section we will download the official BERT code and the pretrained BERT weights that we will fine-tune to create our RE model.

## 2.1. Downloading BERT code

In [None]:
!git clone https://github.com/google-research/bert

## 2.2. Downloading pretrained BERT weights

In [None]:
!wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip
#!wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip

In [None]:
!rm -Rf models trained || true

In [None]:
!mkdir models || true

In [None]:
!mkdir trained || true

In [None]:
!mv cased_L-12_H-768_A-12.zip ./models
#!mv uncased_L-12_H-768_A-12.zip ./models

In [None]:
!cd models && unzip -n cased_L-12_H-768_A-12.zip
#!cd models && unzip -n uncased_L-12_H-768_A-12.zip

In [None]:
!mv models/cased_L-12_H-768_A-12/* models/

## 2.3. Add BERT to System Path
The BERT code imports several of its own modules by name, so the `bert` folder must be added to the system path for them to be found.

In [None]:
import sys
import os

BERT_CODE = os.getcwd() + "/bert"

if BERT_CODE not in sys.path:
    sys.path.append(BERT_CODE)

## 3. Library installation
Required:
```
- Java 8
- The Data Science classical kit (pandas+numpy+scipy)
- Tensorflow 1.x
- PySpark+SparkNLP
```

#### Java

In [None]:
# Make sure java 8 is installed.
!java -version

In [None]:
# If not, run:
!sudo apt-get update
!sudo apt-get purge -y openjdk-11* -qq > /dev/null && sudo apt-get autoremove -y -qq > /dev/null
!sudo apt-get install -y openjdk-8-jdk-headless -qq > /dev/null

In [None]:
# Make sure java 8 is installed.
!java -version

#### Data Science

In [None]:
!pip install pandas numpy==1.19.5 scipy

In [None]:
!pip install tensorflow==1.15

## Imports

In [None]:
import tensorflow as tf

#### Make sure TF 1.x is installed

In [None]:
print(tf.__version__)

1.15.0


In [None]:
import tensorflow as tf
import pandas as pd
import pickle
import sys
import re
import json
import numpy as np
import modeling
import optimization
import tokenization
import run_classifier
import shutil
import os
import pprint
from IPython.display import clear_output
from scipy.spatial.distance import cosine, euclidean
from functools import reduce
import scipy.stats as stats

pp = pprint.PrettyPrinter(indent=4)

In [None]:
config = tf.ConfigProto(device_count = {'GPU': 1})
config

device_count {
  key: "GPU"
  value: 1
}

## Hyperparam configuration
Two configurations are available: generic BERT and domain-specific BioBERT.

### Base BERT

In [None]:
class BaseBertI2B2Config:
    #maximum sequence length, can be up to 512 for standard BERT models
    #larger values require more GPU memory
    #A GTX1080 Ti with 11 GB of memory can do no more than batch size 16 with max_seq_len 128
    MAX_SEQ_LENGTH = 256

    #location of pretrained Bert model
    BERT_MODEL_PATH = "./models"
    #location of Bert checkpoint used for initializing the model
    BERT_MODEL_CHECKPOINT_PATH = "{}/bert_model.ckpt".format(BERT_MODEL_PATH)
    #location of Bert configuration file
    BERT_MODEL_CONFIG_PATH = "{}/bert_config.json".format(BERT_MODEL_PATH)
    #location of Bert vocabulary file
    BERT_VOCAB_PATH = "{}/vocab.txt".format(BERT_MODEL_PATH)

    #Location for storing trained models (in checkpoint format)
    CHKPOINT_PATH = "./trained"
   
    #Location to export trained models to (in saved_model format)
    EXPORT_PATH = "./models/basebert_re"

    #Initial LR, real LR depends on warm-up and training progress
    LEARNING_RATE = 2e-5
    #Number of training epochs (how many times to iterate through the training set)
    NUM_TRAIN_EPOCHS = 3
    #Proportion of training steps (i.e. number of batches) used for warming up (adaptive LR in the beginning)
    WARMUP_PROPORTION = 0
    #Training batch size
    BATCH_SIZE = 16
    #Batch size during testing/validation
    V_BATCH_SIZE = 100

    #Sentence column name
    SENTENCE_COLUMN = "text"
    #Relation label column name
    REL_LABEL_COLUMN = "rel"
    #Relation argument binding column name - used if (some of the) relations are not symmetric
    #0 - symmetric relation, argument order doesn't matter
    #1 - rel(ARG1, ARG2), where ARG1 is the entity which first appears in the text
    #2 - rel(ARG2, ARG1)
    #if None, then ignore argument order(i.e. treat all relations as symmetric)
    REL_ARG_BINDING_COLUMN = 'direction'

    #Entities positions in the dataset
    ENTITY1_BEGIN_COLUMN = "firstCharEnt1"
    ENTITY1_END_COLUMN = "lastCharEnt1"
    ENTITY2_BEGIN_COLUMN = "firstCharEnt2"
    ENTITY2_END_COLUMN = "lastCharEnt2"


    ENTITY1_START_TAG = "e1b"
    ENTITY1_END_TAG = "e1e"
    ENTITY2_START_TAG = "e2b"
    ENTITY2_END_TAG = "e2e"

    ENTITY1_START_TAG_ID = 10
    ENTITY1_END_TAG_ID = 11
    ENTITY2_START_TAG_ID = 12
    ENTITY2_END_TAG_ID = 13

    
    NUM_HIDDEN_UNITS = 0
    
    DROPOUT_RATE = 0

    #standard padding id value for Bert models
    PAD_ID = 0

    #proportion of training examples
    TRAIN_SET_PROB = 0.8
    
    #Not used at the moment
    REPLACE_ARG_PROB = 0    
    
    USE_ENTITY_POSITIONS = True
    USE_CLS_POSITION = True

In [None]:
# By default, we will use BaseBert (see step 1 in Main to change it)
BertREConfig = BaseBertI2B2Config
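
Because every setting is a class attribute, a variant configuration can be written as a subclass that overrides only what changes. The sketch below uses a reduced stand-in base class so that it runs on its own; `BaseConfig`, `SmallGpuConfig`, and the overridden values are hypothetical and not part of the notebook:

```python
class BaseConfig:
    # Stand-in for BaseBertI2B2Config, reduced to three fields
    MAX_SEQ_LENGTH = 256
    BATCH_SIZE = 16
    LEARNING_RATE = 2e-5

class SmallGpuConfig(BaseConfig):
    # Hypothetical variant: shorter sequences and smaller batches
    MAX_SEQ_LENGTH = 128
    BATCH_SIZE = 8

# Same aliasing pattern as BertREConfig = BaseBertI2B2Config
ActiveConfig = SmallGpuConfig
print(ActiveConfig.MAX_SEQ_LENGTH, ActiveConfig.BATCH_SIZE, ActiveConfig.LEARNING_RATE)
```

Functions that use `BertREConfig.MAX_SEQ_LENGTH` or `BertREConfig.NUM_HIDDEN_UNITS` as a default argument capture the value when they are defined, so the alias should be reassigned before those cells run.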

## Data collection
Set of functions to get input data from pandas or Spark dataframes

### Reading RE data from a pandas dataset

In [None]:
def collect_data_from_pandas_dataset(dataset):
    
    rel_labels = sorted(dataset[BertREConfig.REL_LABEL_COLUMN].unique())
    
    def process_row(row):

        row["sentence"] = annotate_sentence(
            row[BertREConfig.SENTENCE_COLUMN], 
            row[BertREConfig.ENTITY1_BEGIN_COLUMN],
            row[BertREConfig.ENTITY1_END_COLUMN],
            row[BertREConfig.ENTITY2_BEGIN_COLUMN],
            row[BertREConfig.ENTITY2_END_COLUMN]
        )
        row["rel_label_id"] = rel_labels.index(row[BertREConfig.REL_LABEL_COLUMN])
        row["rel_arg_binding"] = row[BertREConfig.REL_ARG_BINDING_COLUMN] if BertREConfig.REL_ARG_BINDING_COLUMN else 0

        return row
    
    
    dataset = dataset.apply(process_row, axis=1)
    
    return dataset.sentence, dataset.rel_label_id, dataset.rel_arg_binding, rel_labels

### Reading RE data from a spark dataset

In [None]:
def collect_data_from_spark_dataset(dataset):
    
    rel_labels = sorted([row[0] for row in dataset.select(BertREConfig.REL_LABEL_COLUMN).distinct().collect()])
    
    def process_row(row):
        sentence = annotate_sentence(
            row[BertREConfig.SENTENCE_COLUMN], 
            int(row[BertREConfig.ENTITY1_BEGIN_COLUMN]),
            int(row[BertREConfig.ENTITY1_END_COLUMN]),
            int(row[BertREConfig.ENTITY2_BEGIN_COLUMN]),
            int(row[BertREConfig.ENTITY2_END_COLUMN])
        )       
        rel_label_id = rel_labels.index(row[BertREConfig.REL_LABEL_COLUMN])
        
        rel_arg_binding = (
            row[BertREConfig.REL_ARG_BINDING_COLUMN] if BertREConfig.REL_ARG_BINDING_COLUMN else 0)
        
        return (sentence, rel_label_id, rel_arg_binding)
    
    sentences, rel_label_ids, rel_arg_bindings = tuple(
        map(list, zip(*dataset.rdd.map(process_row).collect())))
    
    return sentences, rel_label_ids, rel_arg_bindings, rel_labels
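
Both readers map relation labels to integer ids the same way: the distinct labels are sorted, and each row's label is replaced by its index in that list. A small sketch with made-up labels (the label names are illustrative and not taken from any dataset):

```python
# Illustrative relation labels, one per training row
row_labels = ["has_party", "other", "has_date", "other", "has_party"]

# Sorted distinct labels define a stable label -> id mapping
rel_labels = sorted(set(row_labels))
rel_label_ids = [rel_labels.index(label) for label in row_labels]

print(rel_labels)
print(rel_label_ids)
```

Sorting keeps the mapping reproducible across runs, as long as the set of labels in the data does not change.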

## Data annotation
Set of functions to annotate the sentences using reserved BERT tokens

In [None]:
#Add entity markers to Bert vocabulary
def update_vocab():
    vocab = []

    # Use a lowercase handle so it does not shadow pyspark's `functions as F`
    with open(BertREConfig.BERT_VOCAB_PATH, 'r') as f:
        vocab = f.readlines()
        vocab[BertREConfig.ENTITY1_START_TAG_ID] = BertREConfig.ENTITY1_START_TAG + "\n"
        vocab[BertREConfig.ENTITY1_END_TAG_ID] = BertREConfig.ENTITY1_END_TAG + "\n"
        vocab[BertREConfig.ENTITY2_START_TAG_ID] = BertREConfig.ENTITY2_START_TAG + "\n"
        vocab[BertREConfig.ENTITY2_END_TAG_ID] = BertREConfig.ENTITY2_END_TAG + "\n"

    with open(BertREConfig.BERT_VOCAB_PATH, "w") as f:
        f.writelines(vocab)

In [None]:
#Tokenize sentence using Bert tokenizer, adding entity markers
def tokenize_sentence(sentence, tokenizer, seq_length=BertREConfig.MAX_SEQ_LENGTH, is_test=False):
    
    tokens = ["[CLS]"]

    entity_starts = []
    entity_ends = []
    
    for token in tokenizer.tokenize(sentence)[:seq_length - 2]:
        if token in [BertREConfig.ENTITY1_START_TAG, BertREConfig.ENTITY2_START_TAG]:
            entity_starts.append(len(tokens))
            
        elif token in [BertREConfig.ENTITY1_END_TAG, BertREConfig.ENTITY2_END_TAG]:
            entity_ends.append(len(tokens))
            
        tokens.append(token)
    
    tokens.append("[SEP]")    
        
    if (len(entity_starts) != 2) or (len(entity_ends) != 2):
        return False

    if not is_test:
        if np.random.rand() < BertREConfig.REPLACE_ARG_PROB:
            e1_length_diff = (entity_ends[0] - entity_starts[0]) - 2

            tokens = tokens[:entity_starts[0] + 1] + ["[MASK]"] + tokens[entity_ends[0]:]

            entity_ends[0] = entity_starts[0] + 2        

            entity_starts[1] = entity_starts[1] - e1_length_diff
            entity_ends[1] = entity_ends[1] - e1_length_diff

        if np.random.rand() < BertREConfig.REPLACE_ARG_PROB:
            tokens = tokens[:entity_starts[1] + 1] + ["[MASK]"] + tokens[entity_ends[1]:]
            entity_ends[1] = entity_starts[1] + 2        
    
    assert(tokens[entity_starts[0]] == BertREConfig.ENTITY1_START_TAG)
    assert(tokens[entity_starts[1]] == BertREConfig.ENTITY2_START_TAG)
    assert(tokens[entity_ends[0]] == BertREConfig.ENTITY1_END_TAG)
    assert(tokens[entity_ends[1]] == BertREConfig.ENTITY2_END_TAG)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
        
    return (input_ids, entity_starts[0], entity_starts[1], tokens)
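
The loop in `tokenize_sentence` records the index of each marker token, counting from the leading `[CLS]`. A self-contained sketch on an already tokenized sentence, assuming the marker strings come out of the tokenizer as single tokens (which is what replacing vocabulary entries in `update_vocab` is meant to achieve); the helper name `find_markers` is hypothetical:

```python
def find_markers(word_pieces, start_tags=("e1b", "e2b"), end_tags=("e1e", "e2e")):
    """Return (tokens, start positions, end positions) for a marked token list."""
    tokens = ["[CLS]"]
    starts, ends = [], []
    for piece in word_pieces:
        if piece in start_tags:
            starts.append(len(tokens))  # index the marker is about to occupy
        elif piece in end_tags:
            ends.append(len(tokens))
        tokens.append(piece)
    tokens.append("[SEP]")
    return tokens, starts, ends

pieces = ["e1b", "Acme", "e1e", "acquired", "e2b", "Beta", "e2e"]
tokens, starts, ends = find_markers(pieces)
print(starts, ends)
```

A sentence whose markers are cut off by truncation yields fewer than two start or end positions, which is the case in which the notebook function returns `False`.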

In [None]:
def annotate_sentence(sentence, e1_begin, e1_end, e2_begin, e2_end):
    
    a1_start = min(e1_begin - 1, e2_begin)    
    a1_end = min(e1_end + 1, e2_end + 1)

    a2_start = max(e1_begin - 1, e2_begin)
    a2_end = max(e1_end + 1, e2_end + 1)
    
    new_sentence = " ".join([
        sentence[:a1_start], 
        BertREConfig.ENTITY1_START_TAG, 
        sentence[a1_start:a1_end],
        BertREConfig.ENTITY1_END_TAG, 
        sentence[a1_end:a2_start],
        BertREConfig.ENTITY2_START_TAG, 
        sentence[a2_start:a2_end],
        BertREConfig.ENTITY2_END_TAG, 
        sentence[a2_end:]
    ])

    return new_sentence      
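
To illustrate the idea of wrapping both entities in marker tokens, here is a simplified sketch. It assumes half-open `[start, end)` character spans and does not reproduce the offset adjustments of `annotate_sentence` above; the function name `mark_entities` and the example sentence are hypothetical:

```python
def mark_entities(sentence, span_a, span_b, tags=("e1b", "e1e", "e2b", "e2e")):
    """Wrap two non-overlapping [start, end) character spans in marker tokens."""
    (s1, t1), (s2, t2) = sorted([span_a, span_b])  # earlier span gets the e1 tags
    parts = [
        sentence[:s1], tags[0], sentence[s1:t1], tags[1],
        sentence[t1:s2], tags[2], sentence[s2:t2], tags[3],
        sentence[t2:],
    ]
    # Drop empty pieces and normalize the spacing around markers
    return " ".join(p.strip() for p in parts if p.strip())

text = "Acme Corp acquired Beta LLC in 2020."
a = text.index("Acme Corp")
b = text.index("Beta LLC")
marked = mark_entities(text, (a, a + len("Acme Corp")), (b, b + len("Beta LLC")))
print(marked)
```

As in the notebook function, the first pair of tags always goes to whichever entity appears first in the text; the argument order of the relation is carried separately by the binding column.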

## Feature Engineering
RE features consist of token ids (`input_ids`), entity positions, and label ids

In [None]:
#Representation of RE features
class REFeatures(object):
    
    def __str__(self):
        return "{} ({})".format(
            ", ".join(
                map(lambda x: str(x), self.input_ids)), 
            self.sentence)
    
    def __init__(self,
                 input_ids,
                 entity1_pos,
                 entity2_pos,
                 rel_label_id,
                 rel_arg_binding,
                 sentence=""):

        self.input_ids = input_ids
        self.entity1_pos = entity1_pos
        self.entity2_pos = entity2_pos
        self.rel_label_id = rel_label_id
        self.rel_arg_binding = rel_arg_binding
        self.sentence = sentence

In [None]:
#Create RE features from a list of sentences and targets
def make_features(sentences, targets, tokenizer, is_test=False):    
    features = []
    for i in range(len(sentences)):
        ts = tokenize_sentence(sentences[i], tokenizer, is_test=is_test)
        if ts:
            features.append(
                REFeatures(
                    input_ids=ts[0],
                    entity1_pos=ts[1],
                    entity2_pos=ts[2],
                    rel_label_id=targets[i][0],
                    rel_arg_binding=targets[i][1],
                    sentence=" ".join(ts[3])
                ))

    return features

## Batches creation
For feeding the training process

In [None]:
#Make a batch of training /testing examples. 
#If max_seq_len is None, then use the sequence max length in the batch

def make_batch(features, max_seq_len = None):
    batch_size = len(features)
    use_rel_args = BertREConfig.REL_ARG_BINDING_COLUMN is not None
    if max_seq_len is None:
        max_seq_len = max([len(f.input_ids) for f in features])
    
    input_ids = np.ones([batch_size, max_seq_len], dtype=np.int32) * BertREConfig.PAD_ID
    input_mask = np.zeros([batch_size, max_seq_len], dtype=np.int32)
    segment_ids = np.zeros([batch_size, max_seq_len], dtype=np.int32)
    entity1_pos = np.zeros([batch_size], dtype=np.int32)
    entity2_pos = np.zeros([batch_size], dtype=np.int32)
    rel_label_ids = np.zeros([batch_size], dtype=np.int32)
    rel_arg_bindings = np.zeros([batch_size], dtype=np.int32)
    
    i = 0
    
    for f in features:
        
        input_ids[i, :len(f.input_ids)] = np.array(f.input_ids)
        input_mask[i, :len(f.input_ids)] = 1
        rel_label_ids[i] = f.rel_label_id
        rel_arg_bindings[i] = f.rel_arg_binding
        entity1_pos[i] = f.entity1_pos
        entity2_pos[i] = f.entity2_pos
        i += 1
    
    batch = {
        "input_ids:0": input_ids,
        "input_mask:0": input_mask,
        "segment_ids:0": segment_ids,
        "rel_label_ids:0": rel_label_ids,
        "entity1_pos:0": entity1_pos,
        "entity2_pos:0": entity2_pos,
    }
    
    if use_rel_args:
        batch["rel_arg_bindings:0"] = rel_arg_bindings
        
    return batch
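
The core of `make_batch` is right-padding each id sequence to a common length and marking the real tokens in a mask. A reduced sketch of just that step, assuming NumPy is available; `pad_batch` and the token ids are illustrative:

```python
import numpy as np

PAD_ID = 0  # same padding id as in the configuration

def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad id lists to the longest one and build the attention mask."""
    max_len = max(len(s) for s in sequences)
    input_ids = np.full((len(sequences), max_len), pad_id, dtype=np.int32)
    input_mask = np.zeros((len(sequences), max_len), dtype=np.int32)
    for i, seq in enumerate(sequences):
        input_ids[i, :len(seq)] = seq
        input_mask[i, :len(seq)] = 1  # 1 marks real tokens, 0 marks padding
    return input_ids, input_mask

ids, mask = pad_batch([[101, 10, 7, 102], [101, 102]])
print(ids)
print(mask)
```

Padding to the longest sequence in the batch, rather than to a fixed maximum, keeps short batches cheap.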

## Optimizer creation
To carry out gradient descent and weight update with specific warm up, learning rate, etc.

In [None]:
#Create Bert RE optimizer graph
def create_optimizer(loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu):
    
    global_step = tf.train.get_or_create_global_step()    
    
    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(
      learning_rate,
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=1.0,
      cycle=False)

    tf.identity(learning_rate, name="c_lr")
    
    # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
    
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = num_warmup_steps

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = learning_rate * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = (
        (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

    
    tf.identity(learning_rate, name="c_lr2")
    
    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    optimizer = optimization.AdamWeightDecayOptimizer(
      learning_rate=learning_rate,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)

    # This is how the model was pre-trained.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    train_op = optimizer.apply_gradients(
      zip(grads, tvars), global_step=global_step)

    # Normally the global step update is done inside of `apply_gradients`.
    # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
    # a different optimizer, you should probably take this line out.
    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)], name="optimizer")
    return train_op
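
The schedule built in the graph above can be traced in plain Python: a linear decay to zero over `num_train_steps`, multiplied during warm-up by `step / num_warmup_steps`. The sketch mirrors the arithmetic of the function above, in which the warm-up factor scales the already decayed rate; `learning_rate_at` is a hypothetical helper and the step counts are example values:

```python
def learning_rate_at(step, init_lr, num_train_steps, num_warmup_steps):
    """Linear decay to zero, scaled by a linear warm-up factor early on."""
    progress = min(step, num_train_steps) / num_train_steps
    decayed = init_lr * (1.0 - progress)
    if step < num_warmup_steps:
        return decayed * (step / num_warmup_steps)
    return decayed

for step in (0, 50, 100, 500, 1000):
    print(step, learning_rate_at(step, 2e-5, 1000, 100))
```

The rate starts at zero, peaks when warm-up ends, and returns to zero at the final training step.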

## BERT model creation
Creation of the model using the BERT architecture on TensorFlow

In [None]:
#Create Bert RE model graph, for training (is_trainable = True) and for inference (is_trainable = False)
def create_model(
    num_relations, 
    num_arg_bindings=3, 
    num_hidden_units=BertREConfig.NUM_HIDDEN_UNITS, 
    chkpoint_path=None, 
    is_trainable=True,
    config=config):
    
    with tf.Session(config=config) as session:    

        num_train_steps = tf.compat.v1.placeholder_with_default(
            input=tf.constant(1000, dtype=tf.float32), shape=(), name="num_train_steps")
        
        num_warm_up_steps = tf.compat.v1.placeholder_with_default(
            input=tf.cast(tf.round(0.1 * num_train_steps), tf.int32), shape=(), name="num_warm_up_steps")

        input_ids = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None, None), name="input_ids")
        
        batch_size = tf.shape(input_ids)[0]
        seq_len = tf.shape(input_ids)[1]
        
        input_mask = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None, None), name="input_mask")
        
        segment_ids = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None, None), name="segment_ids")
        
        rel_label_ids = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None), name="rel_label_ids")
        
        rel_arg_bindings = tf.compat.v1.placeholder_with_default(
            input=tf.zeros(shape=(batch_size),dtype=tf.compat.v1.int32), shape=(None), name="rel_arg_bindings")
        
        entity1_pos = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None), name="entity1_pos")
        
        entity2_pos = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None), name="entity2_pos")
        
        dropout_rate = tf.compat.v1.placeholder_with_default(
            input=tf.constant(BertREConfig.DROPOUT_RATE, dtype=tf.float32), shape=(), name="dropout_rate")
        
        learning_rate = tf.compat.v1.placeholder_with_default(
            input=tf.constant(2e-5, dtype=tf.float32), shape=(), name="learning_rate")
        
        config = modeling.BertConfig.from_json_file(BertREConfig.BERT_MODEL_CONFIG_PATH)

        bert_model = modeling.BertModel(
            config=config,
            is_training=is_trainable,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids)

        if chkpoint_path:
            tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
                tvars, chkpoint_path)
            tf.train.init_from_checkpoint(chkpoint_path, assignment_map)    

        output_layer = bert_model.get_sequence_output()

        if BertREConfig.USE_ENTITY_POSITIONS:
            #get entity start marker embeddings

            #E1 mask
            entity1_mask =  tf.repeat(
                tf.one_hot(entity1_pos, seq_len), 
                config.hidden_size, 
                axis=1)

            #E2 mask
            entity2_mask =  tf.repeat(
                tf.one_hot(entity2_pos, seq_len), 
                config.hidden_size, 
                axis=1)


            #Hidden layer representation for E1
            entity1_embd = tf.reduce_sum(
                tf.reshape(
                    (tf.reshape(output_layer, shape=(batch_size, -1)) * entity1_mask), 
                    shape=[batch_size, seq_len, config.hidden_size]), 
                axis=1)

            #Hidden layer representation for E2
            entity2_embd = tf.reduce_sum(
                tf.reshape(
                    (tf.reshape(output_layer, shape=(batch_size, -1)) * entity2_mask), 
                    shape=[batch_size, seq_len, config.hidden_size]), 
                axis=1)

            #Concat representions
            if BertREConfig.USE_CLS_POSITION:                
                classification_layer = tf.concat([entity1_embd, entity2_embd, output_layer[:,0,:]], axis=1)
            else:
                classification_layer = tf.concat([entity1_embd, entity2_embd], axis=1)
        else:
            if not BertREConfig.USE_CLS_POSITION:
                raise ValueError("Either USE_ENTITY_POSITIONS or USE_CLS_POSITION should be set to True.")
            else:
                classification_layer = output_layer[:,0,:]
                
        '''Add full connection layer and dropout layer'''
        
        if num_hidden_units > 0:
            fc = tf.layers.dense(
                classification_layer, 
                num_hidden_units, 
                name='fc1')
            fc = tf.nn.relu(fc)
        else:
            fc = tf.identity(classification_layer, name='fc1')

        fc = tf.nn.dropout(fc, rate=dropout_rate)

        '''logits'''
        rel_label_logits = tf.layers.dense(fc, num_relations, name='rel_label_logits')
        rel_label_log_probs = tf.nn.softmax(rel_label_logits, name="rel_label_probs")
        rel_label_predictions = tf.argmax(
            rel_label_log_probs, axis=-1, output_type=tf.int32, name="rel_label_predictions")

        if num_arg_bindings > 1:    
            rel_arg_binding_logits = tf.layers.dense(fc, num_arg_bindings, name='rel_arg_binding_logits')
            rel_arg_binding_probs = tf.nn.softmax(rel_arg_binding_logits, name="rel_arg_binding_probs")
            rel_arg_binding_predictions = tf.argmax(
                rel_arg_binding_probs, axis=-1, output_type=tf.int32, name="rel_arg_binding_predictions")
        else:
            rel_arg_binding_probs = tf.ones_like(rel_arg_bindings, name="rel_arg_binding_probs")
            rel_arg_binding_predictions = tf.zeros_like(
                rel_arg_bindings, name="rel_arg_binding_predictions")

        '''Calculate loss. Convert target labels into one-hot form. '''            
        rel_label_targets = tf.one_hot(rel_label_ids, depth=num_relations)
        rel_label_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=rel_label_targets,
                logits=rel_label_logits)

        if num_arg_bindings > 1:                
            rel_arg_binding_targets = tf.one_hot(rel_arg_bindings, depth=num_arg_bindings)
            rel_arg_binding_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=rel_arg_binding_targets,
                    logits=rel_arg_binding_logits)
        else:
            rel_arg_binding_loss = 0                    
        
        rel_label_example_accuracy = tf.cast(
            tf.equal(rel_label_predictions, rel_label_ids), tf.float32, name="rel_label_acc")
        
        rel_label_accuracy = tf.reduce_mean(rel_label_example_accuracy, name="rel_label_mean_acc")        
        
        rel_arg_binding_example_accuracy = tf.cast(
            tf.equal(rel_arg_binding_predictions, rel_arg_bindings), tf.float32, name="rel_arg_binding_acc")
        
        rel_arg_binding_accuracy = tf.reduce_mean(
            rel_arg_binding_example_accuracy, name="rel_arg_binding_mean_acc")
        
        total_example_accuracy = tf.identity(
            rel_label_example_accuracy * rel_arg_binding_example_accuracy, name="total_acc")
        
        # Average over the batch, consistent with the other "*_mean_acc" tensors
        total_accuracy = tf.reduce_mean(total_example_accuracy, name="total_mean_acc")
        
        
        loss = tf.reduce_mean(rel_label_loss + rel_arg_binding_loss, name="loss")

        if is_trainable:
            train_op = create_optimizer(
                loss, 
                learning_rate, 
                num_train_steps, 
                num_warm_up_steps,
                use_tpu=False)        
        else:
            train_op = tf.no_op()
            
        init = tf.global_variables_initializer()

        return (
                train_op,
                loss,       
                total_accuracy,
                (rel_label_predictions, rel_arg_binding_predictions), 
                (rel_label_log_probs, rel_arg_binding_probs))
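
The entity representation in `create_model` selects the hidden vector at each entity start marker through a one-hot mask and then concatenates it with the `[CLS]` vector. A NumPy sketch of the same selection, equivalent in effect to the repeat/reshape sequence in the graph; the random array stands in for the BERT sequence output, and all shapes and positions are made-up example values:

```python
import numpy as np

batch_size, seq_len, hidden = 2, 5, 3
rng = np.random.default_rng(0)
# Stand-in for bert_model.get_sequence_output(): (batch, seq_len, hidden)
sequence_output = rng.normal(size=(batch_size, seq_len, hidden))
entity1_pos = np.array([1, 3])
entity2_pos = np.array([4, 0])

def select_positions(output, positions):
    """Pick one hidden vector per example via a one-hot mask over positions."""
    one_hot = np.eye(output.shape[1])[positions]       # (batch, seq_len)
    return (output * one_hot[:, :, None]).sum(axis=1)  # (batch, hidden)

e1 = select_positions(sequence_output, entity1_pos)
e2 = select_positions(sequence_output, entity2_pos)
cls = sequence_output[:, 0, :]
features = np.concatenate([e1, e2, cls], axis=1)
print(features.shape)
```

With `USE_ENTITY_POSITIONS` and `USE_CLS_POSITION` both enabled, the classifier input is therefore three hidden vectors wide.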

## Model saving
Function to export the trained BERT model to disk in TensorFlow `saved_model` format

In [None]:
def export_model(model_id, is_trainable = True, num_arg_bindings = 3, config=config):
    
    with tf.Session(config=config) as session:
    
        model = create_model(
            len(rel_labels), 
            is_trainable=is_trainable,
            num_arg_bindings=num_arg_bindings)
    
        input_tensors = {}
        output_tensors = {}

        input_tensors_names = {
            "input_ids:0",
            "input_mask:0",
            "segment_ids:0",
            "entity1_pos:0",
            "entity2_pos:0",
        }

        output_tensors_names = [
            "loss:0",
            "rel_label_acc:0",
            "rel_arg_binding_acc:0",
            "total_acc:0",
            "rel_label_probs:0",
            "rel_label_predictions:0",
            "rel_arg_binding_probs:0",
            "rel_arg_binding_predictions:0"        
        ]


        for k in input_tensors_names:
            t = session.graph.get_tensor_by_name(k)
            input_tensors[t.name] = t

        for k in output_tensors_names:
            t = session.graph.get_tensor_by_name(k)
            output_tensors[k] = t
        
        print("{} trainable variables: ".format(len(tf.trainable_variables())))
        size_f = lambda v: reduce(lambda x, y: x*y, v.get_shape().as_list())
        n = sum(size_f(v) for v in tf.trainable_variables())
        print("{} trainable parameters.".format(n))    
        
        tf.train.Saver().restore(session, f"{BertREConfig.CHKPOINT_PATH}/{model_id}/model")

        shutil.rmtree(BertREConfig.EXPORT_PATH, ignore_errors=True)

        #save model
        tf.saved_model.simple_save(
            session,
            BertREConfig.EXPORT_PATH,
            inputs=input_tensors,
            outputs=output_tensors
        )

        #copy assets to the destination folder
        shutil.copytree(
            f"{BertREConfig.CHKPOINT_PATH}/{model_id}/assets", 
            f"{BertREConfig.EXPORT_PATH}/assets")


## Training
Function to train the model for RE

In [None]:
!nvidia-smi

Fri Jan  6 21:17:08 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   30C    P0    23W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [None]:
#Train a Bert RE model and save it in the checkpoints folder
def train_model(model_id, train_features, test_features, rel_labels, num_arg_bindings = 3, config=config):

    tf.reset_default_graph()
    
    use_rel_args = BertREConfig.REL_ARG_BINDING_COLUMN is not None
    
    with tf.Session(config=config) as session:

        model = create_model(
            len(rel_labels), 
            chkpoint_path=BertREConfig.BERT_MODEL_CHECKPOINT_PATH, 
            num_arg_bindings=num_arg_bindings, 
            is_trainable=True,
            config=config
        )

        session.run("init")

        num_train_steps = (BertREConfig.NUM_TRAIN_EPOCHS * len(train_features)) // BertREConfig.BATCH_SIZE    

        clear_output(wait=True)

        print("{:^11}{:^11}{:>10}{:>10}{:>10}{:>10}{:>10}{:>10}{:>10}".format(
            "Epoch", "Batch", 
            "Loss", 
            "L_ACC", "Arg_ACC", "ACC",
            "vL_ACC", "vArg_ACC", "vACC"))

        for e in range(BertREConfig.NUM_TRAIN_EPOCHS):

            np.random.shuffle(train_features)            

            b_loss = []
            b_rel_label_acc = []
            b_rel_arg_binding_acc = []
            b_total_acc = []
            for b in range(0, len(train_features) // BertREConfig.BATCH_SIZE):

                batch = make_batch(
                    train_features[b * BertREConfig.BATCH_SIZE: (b + 1) * BertREConfig.BATCH_SIZE]
                )

                data = batch

                # Feed the schedule inputs on every step: an unfed placeholder
                # falls back to its default (1000 steps, 2e-5), not to the last fed value
                data["num_train_steps:0"] = num_train_steps
                data["learning_rate:0"] = BertREConfig.LEARNING_RATE

                eval_tensors = [
                    "optimizer", 
                    "loss:0", 
                    "rel_label_mean_acc:0",
                    "rel_arg_binding_mean_acc:0",
                    "total_mean_acc:0"
                ]
                _, loss, rel_label_acc, rel_arg_bind_acc, total_acc = session.run(
                    eval_tensors, feed_dict=data)
                b_loss.append(loss)
                b_rel_label_acc.append(rel_label_acc)
                b_rel_arg_binding_acc.append(rel_arg_bind_acc)
                b_total_acc.append(total_acc)

                print("\r{:>5}/{:<5}{:>5}/{:<5}{:>10.4f}{:>10.3f}{:>10.3f}{:>10.3f}".format(
                        e+1,
                        BertREConfig.NUM_TRAIN_EPOCHS,
                        b + 1, 
                        len(train_features) // BertREConfig.BATCH_SIZE,
                        np.mean(b_loss), 
                        np.mean(b_rel_label_acc), 
                        np.mean(b_rel_arg_binding_acc), 
                        np.mean(b_total_acc)  
                    ), end="")


            v_rel_label_acc = []
            v_rel_arg_binding_acc = []
            v_total_acc = []

            for v_b in range(0, len(test_features) // BertREConfig.V_BATCH_SIZE):
                batch = make_batch(
                    test_features[v_b * BertREConfig.V_BATCH_SIZE: (v_b + 1) * BertREConfig.V_BATCH_SIZE])

                data = batch

                eval_tensors = [
                    "rel_label_mean_acc:0",
                    "rel_arg_binding_mean_acc:0",
                    "total_mean_acc:0",
                ]

                rel_label_acc, rel_arg_bind_acc, total_acc = session.run(eval_tensors, feed_dict=data)
                v_rel_label_acc.append(rel_label_acc)
                v_rel_arg_binding_acc.append(rel_arg_bind_acc)
                v_total_acc.append(total_acc)

            print("{:>10.3f}{:>10.3f}{:>10.3f}".format(
                np.mean(v_rel_label_acc), 
                np.mean(v_rel_arg_binding_acc), 
                np.mean(v_total_acc)))                


        
        shutil.rmtree(f"{BertREConfig.CHKPOINT_PATH}/{model_id}", ignore_errors=True)
        # makedirs also creates CHKPOINT_PATH itself if it does not exist yet
        os.makedirs(f"{BertREConfig.CHKPOINT_PATH}/{model_id}", exist_ok=True)
        
        saver = tf.train.Saver()
        saver.save(session, f"{BertREConfig.CHKPOINT_PATH}/{model_id}/model")
        
        os.mkdir(f"{BertREConfig.CHKPOINT_PATH}/{model_id}/assets/")
        
        shutil.copy(
            BertREConfig.BERT_VOCAB_PATH, 
            f"{BertREConfig.CHKPOINT_PATH}/{model_id}/assets/vocab.txt")
        
        with open(f"{BertREConfig.CHKPOINT_PATH}/{model_id}/assets/categories.txt", "wt") as F:
            F.write("\n".join(rel_labels))


## Evaluation
Functions to compute the metrics for the RE model and print them.

In [None]:
def eval_metrics(model_id, features, rel_labels, num_arg_bindings = 3, exclude_rels=[], config=config):
    with tf.Session(config=config) as session:

        model = create_model(
            len(rel_labels), 
            num_arg_bindings=num_arg_bindings,
            is_trainable=False)
        
        tf.train.Saver().restore(session, f"{BertREConfig.CHKPOINT_PATH}/{model_id}/model")

        metrics_data = {}
        for rel in rel_labels:
            metrics_data[rel] = ([], [], [])
            
        for v_b in range(0, len(features) // BertREConfig.V_BATCH_SIZE):
            batch = make_batch(
                features[v_b * BertREConfig.V_BATCH_SIZE: (v_b + 1) * BertREConfig.V_BATCH_SIZE])

            data = batch

            eval_tensors = ["total_acc:0", "rel_label_ids:0", "rel_label_predictions:0"]

            total_acc, rel_label_ids, rel_label_preds = session.run(eval_tensors, feed_dict=data)
            
            for i in range(len(rel_label_ids)):
                acc = total_acc[i]
                pred = rel_label_preds[i]
                target = rel_label_ids[i]
                rel_target = rel_labels[target]
                rel_pred = rel_labels[pred]
                
                metrics_data[rel_target][2].append(1)
                
                if acc:
                    metrics_data[rel_target][0].append(1)
                    metrics_data[rel_pred][1].append(1)
                else:
                    metrics_data[rel_target][0].append(0)
                    metrics_data[rel_pred][1].append(0)                

        results = {}        
        
        for rel in [rel for rel in rel_labels if rel not in exclude_rels]:
            if len(metrics_data[rel][0]):
                recall = np.mean(metrics_data[rel][0])
            else:
                recall = 0
            if len(metrics_data[rel][1]):
                precision = np.mean(metrics_data[rel][1])
            else:
                precision = 0
            if (recall + precision):
                f1 = 2 * (recall * precision) / (recall + precision)
            else:
                f1 = np.nan
               
            support = np.sum(metrics_data[rel][2])
            
            results[rel] = (recall, precision, f1, support)
        
        return results
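
The bookkeeping in `eval_metrics` can be hard to follow, so here is a minimal, framework-free sketch of the same idea. It assumes an example counts as correct when the predicted label id equals the target id (the original reads the `total_acc:0` tensor, which may also account for argument binding). `per_label_metrics` and the toy inputs are hypothetical names used only for illustration.

```python
from collections import defaultdict


def per_label_metrics(targets, preds, labels):
    """Per-label (recall, precision, f1, support) from integer label ids."""
    hits_by_target = defaultdict(list)  # one 0/1 entry per example, keyed by true label
    hits_by_pred = defaultdict(list)    # one 0/1 entry per example, keyed by predicted label

    for target, pred in zip(targets, preds):
        correct = int(target == pred)   # assumption: correctness means label match only
        hits_by_target[labels[target]].append(correct)
        hits_by_pred[labels[pred]].append(correct)

    results = {}
    for label in labels:
        t_hits, p_hits = hits_by_target[label], hits_by_pred[label]
        recall = sum(t_hits) / len(t_hits) if t_hits else 0.0
        precision = sum(p_hits) / len(p_hits) if p_hits else 0.0
        denom = recall + precision
        f1 = 2 * recall * precision / denom if denom else float("nan")
        results[label] = (recall, precision, f1, len(t_hits))
    return results


# Toy data: two labels, four examples, one misclassification
toy = per_label_metrics(targets=[0, 0, 1, 1], preds=[0, 1, 1, 1], labels=["a", "b"])
print(toy)
```

Recall averages the hits over examples whose true label is the given relation, while precision averages them over examples predicted as that relation.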

In [None]:
def print_metrics(results):
    print("\n")
    print("{:<15}{:>10}{:>10}{:>10}{:>10}\n".format("Relation", "Recall", "Precision", "F1", "Support"))

    for rel in results:

        print(f"{rel:<15}{results[rel][0]:>10.3f}{results[rel][1]:>10.3f}{results[rel][2]:>10.3f}{results[rel][3]:>10}")

    mean_recall = np.mean([results[rel][0] for rel in results])
    mean_precision = np.mean([results[rel][1] for rel in results])
    mean_f1 = np.mean([results[rel][2] for rel in results])

    support_sum = np.sum([results[rel][3] for rel in results])

    w_mean_recall = np.sum([results[rel][0] * results[rel][3] for rel in results]) / support_sum
    w_mean_precision = np.sum([results[rel][1] * results[rel][3] for rel in results]) / support_sum
    w_mean_f1 = np.sum([results[rel][2] * results[rel][3] for rel in results]) / support_sum


    metrics_name = "Avg."

    print(f"\n{metrics_name:<15}{mean_recall:>10.3f}{mean_precision:>10.3f}{mean_f1:>10.3f}")

    metrics_name = "Weighted Avg."

    print(f"\n{metrics_name:<15}{w_mean_recall:>10.3f}{w_mean_precision:>10.3f}{w_mean_f1:>10.3f}")
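
`print_metrics` reserves 15 characters for the relation name, so longer names such as `has_collective_alias` push their row out of alignment. A possible variant, sketched below, sizes the first column from the longest label. `format_metrics_table` is a hypothetical helper, not part of the original notebook, and the two rows are copied from the evaluation output only as sample input.

```python
def format_metrics_table(results):
    """Return table lines with the first column sized to the longest label."""
    name_width = max([len("Relation")] + [len(rel) for rel in results]) + 2
    header = (f"{'Relation':<{name_width}}{'Recall':>10}"
              f"{'Precision':>10}{'F1':>10}{'Support':>10}")
    lines = [header]
    for rel, (recall, precision, f1, support) in results.items():
        lines.append(
            f"{rel:<{name_width}}{recall:>10.3f}{precision:>10.3f}{f1:>10.3f}{support:>10}")
    return lines


# Same tuple layout as eval_metrics: (recall, precision, f1, support)
sample_results = {
    "dated_as": (1.000, 0.957, 0.978, 44),
    "has_collective_alias": (0.667, 1.000, 0.800, 3),
}
table_lines = format_metrics_table(sample_results)
print("\n".join(table_lines))
```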

## MAIN: STEP-BY-STEP RE MODEL TRAINING EXECUTION

### 1. Selecting the hyperparameter configuration

In [None]:
# Active hyperparameter configuration; assign a different config class here to change it
BertREConfig = BaseBertI2B2Config

### 2. Update Bert vocabulary with special tokens

In [None]:
#Update Bert vocabulary
update_vocab()

### 3. Creating a Tokenizer

In [None]:
#create tokenizer
tokenizer = tokenization.FullTokenizer(vocab_file=BertREConfig.BERT_VOCAB_PATH, do_lower_case=True)

### 5. Read the input data.
It should look as follows (see output) and have the following columns:

| Column      |          Explanation                     |
|:-----------:|:----------------------------------------:|
|dataset      | train/test                               |
|source       | data provider                            |
|txt_file     | .txt file                                |
|sentence     | tokenized text sentence                  |
|sent_id      | sentence id                              |
|chunk1       | first entity                             |
|begin1       | first token number of the first entity   |
|end1         | last token number of the first entity    |
|rel          | relation (O for no-relation)             |
|chunk2       | second entity                            |
|begin2       | first token number of the second entity  |
|end2         | last token number of the second entity   |
|label1       | label of the first entity                |
|label2       | label of the second entity               |
|lastCharEnt1 | last char number of the first entity     |
|firstCharEnt1| first char number of the first entity    |
|lastCharEnt2 | last char number of the second entity    |
|firstCharEnt2| first char number of the second entity   |
|words_in_ent1| number of words in first entity          |
|words_in_ent2| number of words in second entity         |
|words_between| words between the entities               |
|is_train     | is it used for training?                 |

We are now ready to train the REDL model. We will train it to extract relations between **DOC**, **PARTY**, **ALIAS** and **EFFDATE** entities.

Let's look at our dataset.

Your dataset has to follow this format.

In [None]:
data = pd.read_csv('relations.csv')
data

Unnamed: 0,text,firstCharEnt1,firstCharEnt2,lastCharEnt1,lastCharEnt2,chunk1,chunk2,label1,label2,rel,direction
0,EXHIBIT 10.43 Dated 29/3/18\n\nDistributorship...,29,65,54,95,Distributorship agreement,Signature Orthopaedics Pty Ltd,DOC,PARTY,signed_by,1
1,EXHIBIT 10.43 Dated 29/3/18\n\nDistributorship...,29,102,54,129,Distributorship agreement,CPM Medical Consultants LLC,DOC,PARTY,signed_by,1
2,EXHIBIT 10.43 Dated 29/3/18\n\nDistributorship...,20,29,27,54,29/3/18,Distributorship agreement,EFFDATE,DOC,dated_as,2
3,Sections 200.80(b)(4) and Rule 406 of the Secu...,173,236,196,247,Collaboration Agreement,"Xencor, Inc",DOC,PARTY,signed_by,1
4,"Monrovia, CA 91016 USA\n\n(hereinafter called ...",63,170,97,173,Boehringer Ingelheim International,BII,PARTY,ALIAS,has_alias,1
...,...,...,...,...,...,...,...,...,...,...,...
3496,By execution of this Supplier/Subcontractor Co...,85,85,93,93,Supplier,Supplier,PARTY,ROLE,other,0
3497,/s Liu Gang Name: LIU GANG Title: Authoriz...,20,38,28,58,LIU GANG,Authorized Signatory,SIGNING_PERSON,SIGNING_TITLE,other,0
3498,"HOFV: HOF VILLAGE, LLC By: /s / Brian Parisi N...",193,212,204,227,David Baker,President & CEO,SIGNING_PERSON,SIGNING_TITLE,other,0
3499,By: /s/ Robert Mattacchione Name: Robert Matta...,34,61,53,64,Robert Mattacchione,CEO,SIGNING_PERSON,SIGNING_TITLE,other,0
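
Before building features, it can help to confirm that the character offsets really point at the chunks. The sketch below is one way to do that. It assumes the end offsets are exclusive, which is what the sample rows above suggest (for instance `29/3/18` spans 20 to 27). `REQUIRED_COLUMNS`, `find_offset_mismatches` and the toy rows are hypothetical and not part of the notebook.

```python
import pandas as pd

# Assumption: these are the columns needed for a basic sanity check
REQUIRED_COLUMNS = ["text", "firstCharEnt1", "lastCharEnt1", "chunk1",
                    "firstCharEnt2", "lastCharEnt2", "chunk2",
                    "label1", "label2", "rel"]


def find_offset_mismatches(df):
    """Return index labels of rows where text[first:last] differs from the chunk."""
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")

    bad_rows = []
    for idx, row in df.iterrows():
        for n in ("1", "2"):
            span = row["text"][row[f"firstCharEnt{n}"]:row[f"lastCharEnt{n}"]]
            if span != row[f"chunk{n}"]:
                bad_rows.append(idx)
                break
    return bad_rows


toy_text = "Supply Agreement between Acme LLC and Beta Inc"
toy_rows = pd.DataFrame([
    # correct offsets
    dict(text=toy_text, firstCharEnt1=0, lastCharEnt1=16, chunk1="Supply Agreement",
         firstCharEnt2=25, lastCharEnt2=33, chunk2="Acme LLC",
         label1="DOC", label2="PARTY", rel="signed_by"),
    # second entity deliberately starts one character too late
    dict(text=toy_text, firstCharEnt1=0, lastCharEnt1=16, chunk1="Supply Agreement",
         firstCharEnt2=26, lastCharEnt2=33, chunk2="Acme LLC",
         label1="DOC", label2="PARTY", rel="signed_by"),
])
mismatches = find_offset_mismatches(toy_rows)
print(mismatches)
```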


In [None]:
data.value_counts('rel')

rel
other                   1639
signed_by                865
has_alias                471
dated_as                 435
has_collective_alias      91
dtype: int64

In [None]:
data.value_counts('label1')

label1
PARTY                         1284
DOC                           1267
SIGNING_PERSON                 737
ORG                             72
ALIAS                           57
ROLE                            41
EFFDATE                         18
SIGNING_TITLE                   11
TITLE                            7
NAME                             6
PERMISSION_INDIRECT_OBJECT       1
dtype: int64

In [None]:
data.value_counts('label2')

label2
PARTY                882
SIGNING_TITLE        724
SIGNING_PERSON       498
ALIAS                483
ROLE                 389
EFFDATE              376
ORG                   63
AGRDATE               52
DOC                   19
NAME                   7
TITLE                  6
PERMISSION             1
FORMER_PARTY_NAME      1
dtype: int64

In [None]:
# Get the list of relation names. No frequency filter is applied here; labels with fewer than
# 10 occurrences are probably wrong labels or too rare to learn, so consider dropping them.
valid_rel_labels = data['rel'].unique()
valid_rel_labels

array(['signed_by', 'dated_as', 'has_alias', 'has_collective_alias',
       'other'], dtype=object)
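
The cell above keeps every relation name that occurs in the data. If you do want to drop very rare labels, as the comment in that cell suggests, a minimal sketch could look like this. `frequent_labels`, the threshold of 10, and the toy DataFrame are assumptions for illustration.

```python
import pandas as pd


def frequent_labels(df, column="rel", min_count=10):
    """Return the sorted label names that occur at least min_count times."""
    counts = df[column].value_counts()
    return sorted(counts[counts >= min_count].index)


# Toy data: "typo_label" occurs only twice and should be filtered out
toy_df = pd.DataFrame({"rel": ["signed_by"] * 12 + ["dated_as"] * 10 + ["typo_label"] * 2})
kept = frequent_labels(toy_df)
filtered_df = toy_df[toy_df["rel"].isin(kept)]
print(kept, len(filtered_df))
```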

### 6. Create the training and test datasets

In [None]:
df_train = data.sample(frac=0.9, random_state=1)
df_test = data.drop(df_train.index)
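
The split above is purely random, so a rare relation such as `has_collective_alias` may end up with very few test rows. One alternative is to sample 90% within each relation. The sketch below assumes pandas 1.1 or newer (for `DataFrameGroupBy.sample`) and a unique index. `stratified_split` and the toy data are hypothetical names.

```python
import pandas as pd


def stratified_split(df, label_col="rel", frac=0.9, random_state=1):
    """Sample `frac` of the rows of every label for training; the rest is the test set."""
    train = df.groupby(label_col).sample(frac=frac, random_state=random_state)
    test = df.drop(train.index)  # relies on a unique index
    return train, test


toy_data = pd.DataFrame({"rel": ["a"] * 20 + ["b"] * 10})
toy_train, toy_test = stratified_split(toy_data)
print(toy_train["rel"].value_counts().to_dict())
print(toy_test["rel"].value_counts().to_dict())
```

Each label keeps the same train/test proportion, at the cost of a split that is no longer a single uniform sample of the whole dataset.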

In [None]:
#collect data
train_sentences, train_rel_label_ids, train_rel_arg_bindings, rel_labels = (
    collect_data_from_pandas_dataset(df_train))

test_sentences, test_rel_label_ids, test_rel_arg_bindings, _ = (
    collect_data_from_pandas_dataset(df_test))

In [None]:
train_sentences = train_sentences.values
train_sentences

In [None]:
train_rel_label_ids = train_rel_label_ids.values

In [None]:
train_rel_arg_bindings = train_rel_arg_bindings.values

In [None]:
rel_labels

['dated_as', 'has_alias', 'has_collective_alias', 'other', 'signed_by']

In [None]:
test_sentences = test_sentences.values
test_rel_label_ids = test_rel_label_ids.values
test_rel_arg_bindings = test_rel_arg_bindings.values

### 7. Create the features from the datasets

In [None]:
#create features
train_features = make_features(
    train_sentences, 
    list(zip(train_rel_label_ids, train_rel_arg_bindings)),
    tokenizer, 
    is_test=False)

if BertREConfig.REPLACE_ARG_PROB and BertREConfig.REPLICATE_DATASET:
    train_features += make_features(
        train_sentences, 
        list(zip(train_rel_label_ids, train_rel_arg_bindings)),
        tokenizer, 
        is_test=True)

In [None]:
test_features = make_features(
    test_sentences, 
    list(zip(test_rel_label_ids, test_rel_arg_bindings)),
    tokenizer, 
    is_test=True)

In [None]:
print("{} training examples".format(len(train_features)))
print("{} test examples".format(len(test_features)))

2995 training examples
333 test examples


### 8. Training the model
Here is where all the fun happens! 🏄

In [None]:
import sys, os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
train_model(
    "CONTRACT_DOC_PARTIES", 
    train_features=train_features, 
    test_features=test_features, 
    rel_labels=rel_labels,
    num_arg_bindings=3,
    config=config)

   Epoch      Batch         Loss     L_ACC   Arg_ACC       ACC    vL_ACC  vArg_ACC      vACC
    1/3      187/187      0.5342     0.893     0.918     0.865     0.963     0.980     0.957
    2/3      187/187      0.1363     0.986     0.985     0.980     0.977     0.980     0.970
    3/3      187/187      0.0643     0.994     0.991     0.988     0.977     0.987     0.973
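
The `187/187` in the log follows from the integer division used in `train_model`. The arithmetic is sketched below under the assumption that `BATCH_SIZE` is 16, which is inferred from 2995 examples and 187 batches, not read from the config. `batch_plan` is a hypothetical helper.

```python
def batch_plan(num_examples, batch_size, num_epochs):
    """Mirror the integer arithmetic used by the training loop."""
    batches_per_epoch = num_examples // batch_size
    # Examples beyond the last full batch are skipped in each epoch
    dropped_per_epoch = num_examples - batches_per_epoch * batch_size
    num_train_steps = (num_epochs * num_examples) // batch_size
    return batches_per_epoch, dropped_per_epoch, num_train_steps


# 2995 examples and 3 epochs come from the outputs above; batch_size=16 is an assumption
plan = batch_plan(num_examples=2995, batch_size=16, num_epochs=3)
print(plan)
```

Because the training set is reshuffled every epoch, the few examples left out of the last partial batch differ from epoch to epoch.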


### 9. Evaluating the model

In [None]:
tf.reset_default_graph()        
metrics = eval_metrics(
    "CONTRACT_DOC_PARTIES", test_features, rel_labels, num_arg_bindings=3, exclude_rels=[])        
print_metrics(metrics)

INFO:tensorflow:Restoring parameters from ./trained/CONTRACT_DOC_PARTIES/model


Relation           Recall Precision        F1   Support

dated_as            1.000     0.957     0.978        44
has_alias           0.950     0.974     0.962        40
has_collective_alias     0.667     1.000     0.800         3
other               0.992     0.968     0.980       121
signed_by           0.957     0.989     0.972        92

Avg.                0.913     0.977     0.938

Weighted Avg.       0.973     0.974     0.973
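
The two summary rows differ only in how the per-relation values are combined: `Avg.` is the plain mean over relations, while `Weighted Avg.` weights each relation by its support. Here is a small sketch using the rounded recall values from the table above. Since the inputs are rounded, the results agree with the printed figures only to within rounding. The variable names are illustrative.

```python
# (recall, support) per relation, copied from the evaluation table above
recall_and_support = {
    "dated_as": (1.000, 44),
    "has_alias": (0.950, 40),
    "has_collective_alias": (0.667, 3),
    "other": (0.992, 121),
    "signed_by": (0.957, 92),
}

recalls = [r for r, _ in recall_and_support.values()]
supports = [s for _, s in recall_and_support.values()]

# Unweighted mean: every relation counts equally, however rare it is
macro_recall = sum(recalls) / len(recalls)
# Support-weighted mean: frequent relations dominate
weighted_recall = sum(r * s for r, s in zip(recalls, supports)) / sum(supports)

print(f"{macro_recall:.4f} {weighted_recall:.4f}")
```

The rare `has_collective_alias` relation (support 3) pulls the unweighted average down noticeably but barely affects the weighted one.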


# Train with all data

In [None]:
tf.reset_default_graph()    

train_model(
    "CONTRACT_DOC_PARTIES", 
    train_features=train_features+test_features,
    test_features=[], 
    rel_labels=rel_labels,
    num_arg_bindings=3)

   Epoch      Batch         Loss     L_ACC   Arg_ACC       ACC    vL_ACC  vArg_ACC      vACC
    1/3      208/208      0.5453     0.887     0.918     0.868       nan       nan       nan
    2/3      208/208      0.1223     0.986     0.986     0.980       nan       nan       nan
    3/3      208/208      0.0614     0.994     0.991     0.988       nan       nan       nan
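
The `nan` values in the validation columns are expected here: `test_features` is empty, so the validation lists stay empty and `np.mean` of an empty list returns `nan` (with a `RuntimeWarning`). The sketch below shows that behavior, plus a hypothetical `mean_or_none` helper that makes the "no validation data" case explicit.

```python
import warnings

import numpy as np


def mean_or_none(values):
    """Return the mean as a float, or None when there is nothing to average."""
    return float(np.mean(values)) if len(values) else None


with warnings.catch_warnings():
    # np.mean([]) warns about an empty slice; silence it for this demonstration
    warnings.simplefilter("ignore", category=RuntimeWarning)
    empty_mean = np.mean([])

print(empty_mean, mean_or_none([]), mean_or_none([0.5, 1.0]))
```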


### 10. Finally saving it!

In [None]:
tf.reset_default_graph()

export_model("CONTRACT_DOC_PARTIES", is_trainable=False, num_arg_bindings=3)

In [None]:
!sudo apt-get -y install zip unzip

In [None]:
!zip -r redl.zip ./models/basebert_re

# Testing in Spark NLP

Let's test our model with Spark NLP.

In [None]:
import json
import os

from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

import sparknlp
import sparknlp_jsl
from sparknlp.annotator import *
from sparknlp.base import *
from sparknlp_jsl.annotator import *
from sparknlp.training import CoNLL

import copy
import glob
import pickle
import random
import re
import traceback

import numpy as np
import pandas as pd
from pyspark.sql.types import StringType
from termcolor import colored
from tqdm import tqdm

In [None]:
import pandas as pd

def get_relations_df(results, col='relations'):
    # Flatten the relation annotations of the first result into one row per relation
    rel_pairs = []
    for rel in results[0][col]:
        rel_pairs.append((
            rel.result,
            rel.metadata['entity1'],
            rel.metadata['entity1_begin'],
            rel.metadata['entity1_end'],
            rel.metadata['chunk1'],
            rel.metadata['entity2'],
            rel.metadata['entity2_begin'],
            rel.metadata['entity2_end'],
            rel.metadata['chunk2'],
            rel.metadata['confidence']
        ))

    rel_df = pd.DataFrame(rel_pairs, columns=['relation','entity1','entity1_begin','entity1_end','chunk1','entity2','entity2_begin','entity2_end','chunk2', 'confidence'])

    return rel_df
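
To see what `get_relations_df` expects without starting Spark, the sketch below feeds a compact equivalent with a stand-in object that only mimics the `result` and `metadata` attributes of a relation annotation. The stand-in, the `FIELDS` list and `relations_to_df` are hypothetical. The values are taken from the first result row shown later in this notebook, and storing metadata values as strings is an assumption.

```python
from types import SimpleNamespace

import pandas as pd

FIELDS = ["entity1", "entity1_begin", "entity1_end", "chunk1",
          "entity2", "entity2_begin", "entity2_end", "chunk2", "confidence"]


def relations_to_df(results, col="relations"):
    """Compact equivalent of get_relations_df: one row per relation annotation."""
    rows = [[rel.result] + [rel.metadata[f] for f in FIELDS]
            for rel in results[0][col]]
    return pd.DataFrame(rows, columns=["relation"] + FIELDS)


# Stand-in for a relation annotation (not a real Spark NLP object)
fake_relation = SimpleNamespace(
    result="dated_as",
    metadata={
        "entity1": "DOC", "entity1_begin": "6", "entity1_end": "36",
        "chunk1": "INTELLECTUAL PROPERTY AGREEMENT",
        "entity2": "EFFDATE", "entity2_begin": "70", "entity2_end": "86",
        "chunk2": "December 31, 2018", "confidence": "0.99994016",
    },
)
fake_results = [{"relations": [fake_relation]}]

toy_rel_df = relations_to_df(fake_results)
# If the values are strings, cast before filtering numerically
toy_rel_df["confidence"] = toy_rel_df["confidence"].astype(float)
print(toy_rel_df)
```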

Here, we import our exported model into Spark NLP and save it in Spark NLP format.

In [None]:
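# Named re_model so that it does not shadow Python's built-in `re` module imported above
re_model = RelationExtractionDLModel().loadSavedModel('models/basebert_re', spark)
re_model.write().overwrite().save('legre_contract_doc_parties_md')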

Before we can extract relations, we have to extract the entities from the given text. For this, we will use the `legner_contract_doc_parties` NER model.

In [None]:
document_assembler = DocumentAssembler()\
  .setInputCol("text")\
  .setOutputCol("document")

tokenizer = Tokenizer()\
    .setInputCols("document")\
    .setOutputCol("token")

embeddings = RoBertaEmbeddings.pretrained("roberta_embeddings_legal_roberta_base", "en") \
    .setInputCols("document", "token") \
    .setOutputCol("embeddings")\
    .setMaxSentenceLength(512)

# This pipeline has no sentence detector, so the NER model reads the "document" column
ner_model = legal.LegalNerModel.pretrained('legner_contract_doc_parties', 'en', 'legal/models')\
    .setInputCols(["document", "token", "embeddings"])\
    .setOutputCol("ner")

ner_converter = NerConverter()\
    .setInputCols(["document","token","ner"])\
    .setOutputCol("ner_chunk")

# We use the load function to run our trained model.
reDL = RelationExtractionDLModel().load('legre_contract_doc_parties_md')\
    .setPredictionThreshold(0.5)\
    .setInputCols(["ner_chunk", "document"])\
    .setOutputCol("relations")

nlpPipeline = Pipeline(stages=[
    document_assembler,
    tokenizer,
    embeddings,
    ner_model,
    ner_converter,
    reDL
    ])

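# All stages are pretrained, so fitting on an empty DataFrame is enough to build the pipeline model
empty_data = spark.createDataFrame([[""]]).toDF("text")

model = nlpPipeline.fit(empty_data)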

In [None]:
light_model = LightPipeline(model)

text='''
This INTELLECTUAL PROPERTY AGREEMENT (this "Agreement"), dated as of December 31, 2018 (the "Effective Date") is entered into by and between Armstrong Flooring, Inc., a Delaware corporation ("Seller") and AFI Licensing LLC, a Delaware limited liability company ("Licensing" and together with Seller, "Arizona") and AHF Holding, Inc. (formerly known as Tarzan HoldCo, Inc.), a Delaware corporation ("Buyer") and Armstrong Hardwood Flooring Company, a Tennessee corporation (the "Company" and together with Buyer the "Buyer Entities") (each of Arizona on the one hand and the Buyer Entities on the other hand, a "Party" and collectively, the "Parties").
'''

results = light_model.fullAnnotate(text)

In [None]:
rel_df = get_relations_df(results)
rel_df = rel_df[rel_df['relation']!='other']
rel_df

Unnamed: 0,relation,entity1,entity1_begin,entity1_end,chunk1,entity2,entity2_begin,entity2_end,chunk2,confidence
0,dated_as,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,EFFDATE,70,86,"December 31, 2018",0.99994016
1,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,142,164,"Armstrong Flooring, Inc",0.9995191
2,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,193,198,Seller,0.9823355
3,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,206,222,AFI Licensing LLC,0.9989542
4,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,264,272,Licensing,0.92109
5,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,293,298,Seller,0.9938019
6,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,316,331,"AHF Holding, Inc",0.9989403
7,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,400,404,Buyer,0.89959186
8,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,412,446,Armstrong Hardwood Flooring Company,0.9974464
9,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,479,485,Company,0.95839113
