![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/legal-nlp/06.2.Relation_Extraction_Training.ipynb)

## Starting the spark session

In [None]:
! pip install -q johnsnowlabs

In [None]:
from johnsnowlabs import *

# nlp.install(force_browser=True)

In [None]:
from google.colab import files
print('Please Upload your John Snow Labs License using the button below')
license_keys = files.upload()

In [None]:
nlp.install()

In [None]:
from johnsnowlabs import *
spark = nlp.start()

# Relation Extraction training using TensorFlow 2.x and BERT

In [None]:
import tensorflow.compat.v1 as tf


In [None]:
!wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip



In [None]:
!mkdir /content/trained
!mkdir /content/models
!mkdir /content/models/bert_base

In [None]:
!unzip /content/cased_L-12_H-768_A-12.zip -d /content/models/bert_base/

## RE DL Training Notebook for Tensorflow 2

An adaptation of the original Google BERT TensorFlow 1 repository, patched to run on TensorFlow 2 through `tensorflow.compat.v1`.

# 2. Download BERT code implementation and BERT weights
In this section we download the official BERT source code. We use it together with the pretrained BERT weights to fine-tune and create our RE model.

## 2.1. Downloading BERT code

In [None]:
# Local folder the BERT sources are cloned into (also added to sys.path below)
BERT_SRC: str = "./bert_t"
# Upstream repository the BERT sources are cloned from
BERT_REPO: str = "https://github.com/google-research/bert"

In [None]:
!wget  -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/legal-nlp/data/tf2_contrib.py
!wget  -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/legal-nlp/data/relations.csv

In [None]:
# Clone the BERT sources (skipped when BERT_SRC already exists), then patch them
# so they run on TensorFlow 2 through the tf.compat.v1 API.
!test -d $BERT_SRC || git clone $BERT_REPO $BERT_SRC
# Switch each module's TensorFlow import to tensorflow.compat.v1.
!sed -i 's/import tensorflow as tf/import tensorflow.compat.v1 as tf/g' $BERT_SRC/optimization.py
!sed -i 's/import tensorflow as tf/import tensorflow.compat.v1 as tf/g' $BERT_SRC/run_classifier.py
!sed -i 's/import tensorflow as tf/import tensorflow.compat.v1 as tf/g' $BERT_SRC/tokenization.py
# modeling.py additionally imports tf_contrib_layer_norm from tf2_contrib.py and uses it
# in place of tf.contrib.layers.layer_norm (tf.contrib is not available in TF2).
!sed -i 's/import tensorflow as tf/import tensorflow.compat.v1 as tf\nfrom tf2_contrib import tf_contrib_layer_norm\n/g' $BERT_SRC/modeling.py
!sed -i 's/tf.contrib.layers.layer_norm/tf_contrib_layer_norm/g' $BERT_SRC/modeling.py
# Put tf2_contrib.py next to modeling.py so the injected import resolves.
!cp tf2_contrib.py $BERT_SRC/

Cloning into './bert_t'...
remote: Enumerating objects: 340, done.[K
remote: Total 340 (delta 0), reused 0 (delta 0), pack-reused 340[K
Receiving objects: 100% (340/340), 328.27 KiB | 16.41 MiB/s, done.
Resolving deltas: 100% (182/182), done.


## 2.2.Add BERT to System Path
The BERT code imports several of its own modules by name, so we add the cloned BERT folder (`BERT_SRC`, i.e. `./bert_t`) to the system path so they can be found.

In [None]:
import sys

# Make the BERT modules (modeling, optimization, tokenization, run_classifier)
# importable by name; guarded so re-running the cell adds the path only once.
if BERT_SRC not in sys.path:
    sys.path.append(BERT_SRC)

## Imports

In [None]:
import tensorflow.compat.v1 as tf
import pandas as pd
import pickle
import sys
import re
import json
import numpy as np
import modeling
import optimization
import tokenization
import run_classifier
import shutil
import os
import pprint
from IPython.display import clear_output
from scipy.spatial.distance import cosine, euclidean
from functools import reduce
import scipy.stats as stats
import pyspark.sql.functions as F

pp = pprint.PrettyPrinter(indent=4)

## Hyperparam configuration


### Base BERT

In [None]:
 
class BaseBertI2B2Config:
    """Hyperparameters and paths for fine-tuning an RE model on base cased BERT.

    Used as a plain namespace of class attributes (it is never instantiated);
    the notebook's functions read it through the ``BertREConfig`` alias.
    """

    # --- Input length -------------------------------------------------------
    # Maximum sequence length; standard BERT models accept up to 512.
    # Larger values need more GPU memory: an 11 GB GTX 1080 Ti fits at most
    # batch size 16 with MAX_SEQ_LENGTH = 128.
    MAX_SEQ_LENGTH = 128

    # --- Pretrained BERT ----------------------------------------------------
    # Folder of the pretrained model. It already ends with "/", so the derived
    # paths below contain "//" (harmless on POSIX file systems).
    BERT_MODEL_PATH = "/content/models/bert_base/cased_L-12_H-768_A-12/"
    # Checkpoint used to initialise the encoder
    BERT_MODEL_CHECKPOINT_PATH = f"{BERT_MODEL_PATH}/bert_model.ckpt"
    # BERT architecture configuration file
    BERT_MODEL_CONFIG_PATH = f"{BERT_MODEL_PATH}/bert_config.json"
    # Vocabulary file (the entity markers are written into it)
    BERT_VOCAB_PATH = f"{BERT_MODEL_PATH}/vocab.txt"

    # --- Outputs ------------------------------------------------------------
    # Trained models, in checkpoint format
    CHKPOINT_PATH = "./trained"
    # Exported models, in SavedModel format
    EXPORT_PATH = "/content/models/basebert_re"

    # --- Optimisation -------------------------------------------------------
    # Peak learning rate; the effective rate follows warm-up and linear decay
    LEARNING_RATE = 2e-5
    # Number of passes over the training set
    NUM_TRAIN_EPOCHS = 5
    # Fraction of training steps (batches) used for learning-rate warm-up
    WARMUP_PROPORTION = 0.1
    # Training batch size
    BATCH_SIZE = 16
    # Validation / test batch size
    V_BATCH_SIZE = 100

    # --- Dataset columns ----------------------------------------------------
    # Sentence text column
    SENTENCE_COLUMN = "text"
    # Relation label column
    REL_LABEL_COLUMN = "rel"
    # Argument-binding column, needed only if some relations are directional:
    #   0 - symmetric relation, argument order doesn't matter
    #   1 - rel(ARG1, ARG2), where ARG1 is the entity that appears first in the text
    #   2 - rel(ARG2, ARG1)
    # None ignores argument order (every relation is treated as symmetric).
    REL_ARG_BINDING_COLUMN = None

    # Character-offset columns of the two entities inside SENTENCE_COLUMN
    ENTITY1_BEGIN_COLUMN = "firstCharEnt1"
    ENTITY1_END_COLUMN = "lastCharEnt1"
    ENTITY2_BEGIN_COLUMN = "firstCharEnt2"
    ENTITY2_END_COLUMN = "lastCharEnt2"

    # --- Entity markers -----------------------------------------------------
    # Marker tokens inserted around the entities ...
    ENTITY1_START_TAG = "e1b"
    ENTITY1_END_TAG = "e1e"
    ENTITY2_START_TAG = "e2b"
    ENTITY2_END_TAG = "e2e"

    # ... and the vocabulary ids they are written to
    ENTITY1_START_TAG_ID = 10
    ENTITY1_END_TAG_ID = 11
    ENTITY2_START_TAG_ID = 12
    ENTITY2_END_TAG_ID = 13

    # Id used to pad input_ids (0 is [PAD] in standard BERT vocabularies)
    PAD_ID = 0

    # Proportion of examples used for training (presumably consumed by the
    # train/test split elsewhere in the notebook)
    TRAIN_SET_PROB = 0.8

    # Optional extras, all disabled by their 0 values:
    #   NUM_HIDDEN_UNITS - width of an extra ReLU layer before the output heads
    #   DROPOUT_RATE     - dropout rate before the output heads during training
    #   REPLACE_ARG_PROB - probability of masking an entity (training augmentation)
    NUM_HIDDEN_UNITS = 0
    DROPOUT_RATE = 0
    REPLACE_ARG_PROB = 0

    # Inputs to the classification layer (at least one must be True):
    # the entity start-marker embeddings and/or the [CLS] embedding
    USE_ENTITY_POSITIONS = True
    USE_CLS_POSITION = True

In [None]:
# Active configuration: the functions below read their settings from
# BertREConfig at call time (and, for default arguments, when their cell runs),
# so point this alias at the config class of the model you are training.
BertREConfig = BaseBertI2B2Config

## Data collection
Set of functions to get input data from pandas or Spark dataframes

### Reading RE data from a pandas dataset

In [None]:
def collect_data_from_pandas_dataset(dataset):
    """Annotate a pandas RE dataset and extract the model targets.

    Args:
        dataset: pandas DataFrame with the sentence, relation-label and entity
            offset columns named in ``BertREConfig`` (plus the argument-binding
            column when ``BertREConfig.REL_ARG_BINDING_COLUMN`` is set).

    Returns:
        ``(sentences, rel_label_ids, rel_arg_bindings, rel_labels)``: three
        Series aligned with ``dataset``'s index (empty for an empty dataset)
        and the sorted list of distinct relation labels, which
        ``rel_label_ids`` index into.
    """
    rel_labels = sorted(dataset[BertREConfig.REL_LABEL_COLUMN].unique())
    # Label -> id lookup; avoids an O(#labels) ``list.index`` call per row.
    label_to_id = {label: idx for idx, label in enumerate(rel_labels)}

    if dataset.empty:
        # ``apply`` on an empty frame does not create the new columns, so the
        # attribute access below would raise AttributeError.
        empty = pd.Series(dtype=object)
        return empty, empty.copy(), empty.copy(), rel_labels

    def process_row(row):
        # Offsets are cast to int (as the Spark reader does): pandas can load
        # them as floats (e.g. when a column has missing values), and float
        # offsets cannot be used for string slicing.
        row["sentence"] = annotate_sentence(
            row[BertREConfig.SENTENCE_COLUMN],
            int(row[BertREConfig.ENTITY1_BEGIN_COLUMN]),
            int(row[BertREConfig.ENTITY1_END_COLUMN]),
            int(row[BertREConfig.ENTITY2_BEGIN_COLUMN]),
            int(row[BertREConfig.ENTITY2_END_COLUMN])
        )
        row["rel_label_id"] = label_to_id[row[BertREConfig.REL_LABEL_COLUMN]]
        # Without a binding column every relation is treated as symmetric (0).
        row["rel_arg_binding"] = row[BertREConfig.REL_ARG_BINDING_COLUMN] if BertREConfig.REL_ARG_BINDING_COLUMN else 0

        return row

    dataset = dataset.apply(process_row, axis=1)

    return dataset.sentence, dataset.rel_label_id, dataset.rel_arg_binding, rel_labels
    


### Reading RE data from a spark dataset

In [None]:
def collect_data_from_spark_dataset(dataset):
    """Annotate a Spark RE dataset and extract the model targets.

    Args:
        dataset: Spark DataFrame with the columns named in ``BertREConfig``.

    Returns:
        ``(sentences, rel_label_ids, rel_arg_bindings, rel_labels)``: three
        Python lists in collected row order (all empty for an empty dataset)
        and the sorted list of distinct relation labels, which
        ``rel_label_ids`` index into.
    """
    rel_labels = sorted(
        row[0] for row in dataset.select(BertREConfig.REL_LABEL_COLUMN).distinct().collect())
    # Label -> id lookup; avoids an O(#labels) ``list.index`` call per row.
    label_to_id = {label: idx for idx, label in enumerate(rel_labels)}

    def process_row(row):
        # Executed by Spark through ``rdd.map``; offsets are cast to int so
        # they can be used for string slicing.
        sentence = annotate_sentence(
            row[BertREConfig.SENTENCE_COLUMN],
            int(row[BertREConfig.ENTITY1_BEGIN_COLUMN]),
            int(row[BertREConfig.ENTITY1_END_COLUMN]),
            int(row[BertREConfig.ENTITY2_BEGIN_COLUMN]),
            int(row[BertREConfig.ENTITY2_END_COLUMN])
        )
        rel_label_id = label_to_id[row[BertREConfig.REL_LABEL_COLUMN]]

        # Without a binding column every relation is treated as symmetric (0).
        rel_arg_binding = (
            row[BertREConfig.REL_ARG_BINDING_COLUMN] if BertREConfig.REL_ARG_BINDING_COLUMN else 0)

        return (sentence, rel_label_id, rel_arg_binding)

    processed = dataset.rdd.map(process_row).collect()
    if not processed:
        # zip(*[]) yields nothing, so the three-way unpacking below would
        # raise ValueError on an empty dataset.
        return [], [], [], rel_labels

    sentences, rel_label_ids, rel_arg_bindings = (list(column) for column in zip(*processed))

    return sentences, rel_label_ids, rel_arg_bindings, rel_labels

## Data annotation
Functions that annotate the sentences with entity marker tokens stored in reserved slots of the BERT vocabulary.

In [None]:
#Add entity markers to Bert vocabulary
def update_vocab():
    """Write the entity marker tokens into the BERT vocabulary file, in place.

    Line ``ENTITY*_TAG_ID`` of ``BertREConfig.BERT_VOCAB_PATH`` is replaced by
    the corresponding ``ENTITY*_TAG`` string, so the markers map to those
    fixed token ids. Re-running it rewrites the same lines.

    NOTE(review): assumes those ids hold unused placeholder entries
    (``[unusedN]`` in stock BERT vocabularies) — confirm for custom vocabularies.
    """
    marker_slots = [
        (BertREConfig.ENTITY1_START_TAG_ID, BertREConfig.ENTITY1_START_TAG),
        (BertREConfig.ENTITY1_END_TAG_ID, BertREConfig.ENTITY1_END_TAG),
        (BertREConfig.ENTITY2_START_TAG_ID, BertREConfig.ENTITY2_START_TAG),
        (BertREConfig.ENTITY2_END_TAG_ID, BertREConfig.ENTITY2_END_TAG),
    ]

    # Explicit UTF-8: BERT vocab files are UTF-8, and the platform default
    # encoding is not guaranteed to be (e.g. cp1252 on Windows).
    with open(BertREConfig.BERT_VOCAB_PATH, "r", encoding="utf-8") as vocab_file:
        vocab = vocab_file.readlines()

    for token_id, tag in marker_slots:
        vocab[token_id] = tag + "\n"

    with open(BertREConfig.BERT_VOCAB_PATH, "w", encoding="utf-8") as vocab_file:
        vocab_file.writelines(vocab)

In [None]:

#Tokenize sentence using Bert tokenizer, adding entity markers
def tokenize_sentence(sentence, tokenizer, seq_length=BertREConfig.MAX_SEQ_LENGTH, is_test=False):
    """Tokenize an entity-marked sentence and locate its entity markers.

    Args:
        sentence: text already wrapped with the entity marker tags (as produced
            by ``annotate_sentence``).
        tokenizer: object exposing ``tokenize`` and ``convert_tokens_to_ids``
            (presumably ``tokenization.FullTokenizer`` over the updated vocab).
        seq_length: maximum length including [CLS] and [SEP]; the wordpiece
            sequence is truncated to ``seq_length - 2`` tokens. The default is
            bound when this cell runs, from the config active at that moment.
        is_test: when False, each entity may be replaced by a single [MASK]
            token with probability ``BertREConfig.REPLACE_ARG_PROB``.

    Returns:
        ``(input_ids, entity1_start_index, entity2_start_index, tokens)`` where
        the indices point at the start-marker tokens, or ``False`` when the
        tokens do not contain exactly two start and two end markers (e.g.
        after truncation).
    """
    
    tokens = ["[CLS]"]

    entity_starts = []
    entity_ends = []
    
    # len(tokens) is read before appending, so recorded indices already
    # account for the leading [CLS].
    for token in tokenizer.tokenize(sentence)[:seq_length - 2]:
        if token in [BertREConfig.ENTITY1_START_TAG, BertREConfig.ENTITY2_START_TAG]:
            entity_starts.append(len(tokens))
            
        elif token in [BertREConfig.ENTITY1_END_TAG, BertREConfig.ENTITY2_END_TAG]:
            entity_ends.append(len(tokens))
            
        tokens.append(token)
    
    tokens.append("[SEP]")    
        
    # Truncation can drop markers; such examples are rejected.
    if (len(entity_starts) != 2) or (len(entity_ends) != 2):
        return False

    # Training-time augmentation: optionally collapse an entity's tokens into
    # one [MASK]. Two np.random.rand() draws are made per training example,
    # even when REPLACE_ARG_PROB is 0.
    if not is_test:
        if np.random.rand() < BertREConfig.REPLACE_ARG_PROB:
            # The first entity shrinks from (end - start - 1) tokens to 1, so
            # the second entity's marker indices move left by this amount.
            e1_length_diff = (entity_ends[0] - entity_starts[0]) - 2

            tokens = tokens[:entity_starts[0] + 1] + ["[MASK]"] + tokens[entity_ends[0]:]

            # Layout is now [start, MASK, end].
            entity_ends[0] = entity_starts[0] + 2        

            entity_starts[1] = entity_starts[1] - e1_length_diff
            entity_ends[1] = entity_ends[1] - e1_length_diff

        if np.random.rand() < BertREConfig.REPLACE_ARG_PROB:
            tokens = tokens[:entity_starts[1] + 1] + ["[MASK]"] + tokens[entity_ends[1]:]
            entity_ends[1] = entity_starts[1] + 2        
    
    # The first marker pair (in text order) must be the ENTITY1 tags and the
    # second pair the ENTITY2 tags.
    assert(tokens[entity_starts[0]] == BertREConfig.ENTITY1_START_TAG)
    assert(tokens[entity_starts[1]] == BertREConfig.ENTITY2_START_TAG)
    assert(tokens[entity_ends[0]] == BertREConfig.ENTITY1_END_TAG)
    assert(tokens[entity_ends[1]] == BertREConfig.ENTITY2_END_TAG)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
        
    return (input_ids, entity_starts[0], entity_starts[1], tokens)



In [None]:
def annotate_sentence(sentence, e1_begin, e1_end, e2_begin, e2_end, tags=None):
    """Wrap the two entity spans of ``sentence`` in entity marker tokens.

    The entity that starts first in the text gets the ENTITY1 tags and the
    other one the ENTITY2 tags, regardless of which argument pair it came
    from. End offsets are inclusive (each span is sliced up to ``end + 1``).
    Tags and the text pieces around them are joined with single spaces.

    NOTE(review): ``e1_begin`` is shifted one character to the left while
    ``e2_begin`` is not. This presumably compensates for an offset convention
    in the training data; confirm before reusing with other data.

    Args:
        sentence: raw text.
        e1_begin, e1_end, e2_begin, e2_end: integer character offsets.
        tags: optional ``(e1_start, e1_end, e2_start, e2_end)`` marker tokens;
            defaults to the ``ENTITY*_TAG`` values of ``BertREConfig``.

    Returns:
        The annotated sentence.
    """
    if tags is None:
        tags = (
            BertREConfig.ENTITY1_START_TAG,
            BertREConfig.ENTITY1_END_TAG,
            BertREConfig.ENTITY2_START_TAG,
            BertREConfig.ENTITY2_END_TAG,
        )
    start1_tag, end1_tag, start2_tag, end2_tag = tags

    # Clamp the starts at 0. With e1_begin == 0 (an entity at the very start
    # of the text), the -1 shift gave a negative index. Python slicing then
    # counted from the END of the string and duplicated the sentence.
    a1_start = max(0, min(e1_begin - 1, e2_begin))
    a1_end = min(e1_end + 1, e2_end + 1)

    a2_start = max(0, e1_begin - 1, e2_begin)
    a2_end = max(e1_end + 1, e2_end + 1)

    return " ".join([
        sentence[:a1_start],
        start1_tag,
        sentence[a1_start:a1_end],
        end1_tag,
        sentence[a1_end:a2_start],
        start2_tag,
        sentence[a2_start:a2_end],
        end2_tag,
        sentence[a2_end:],
    ])

## Feature Engineering
Each RE example is represented by its token ids (`input_ids`), the token positions of the two entity start markers, and its relation label ids.

In [None]:
#Representation of RE features
class REFeatures(object):
    """Model-ready features of a single relation example.

    Attributes:
        input_ids: BERT token ids of the entity-marked sentence.
        entity1_pos, entity2_pos: token indices of the two entity start markers.
        rel_label_id: index of the relation label.
        rel_arg_binding: argument-order class (0 symmetric, 1 or 2 directed).
        sentence: readable form of the example, shown by ``str()``.
    """

    def __init__(self,
                 input_ids,
                 entity1_pos,
                 entity2_pos,
                 rel_label_id,
                 rel_arg_binding,
                 sentence=""):
        self.input_ids = input_ids
        self.entity1_pos = entity1_pos
        self.entity2_pos = entity2_pos
        self.rel_label_id = rel_label_id
        self.rel_arg_binding = rel_arg_binding
        self.sentence = sentence

    def __str__(self):
        # "id1, id2, ... (sentence)"
        ids_text = ", ".join(str(token_id) for token_id in self.input_ids)
        return f"{ids_text} ({self.sentence})"


In [None]:
#Create RE features from a list of sentences and targets
def make_features(sentences, targets, tokenizer, is_test=False):
    """Convert annotated sentences and their targets into ``REFeatures``.

    Args:
        sentences: sequence (list or pandas Series) of entity-marked sentences.
        targets: sequence of ``(rel_label_id, rel_arg_binding)`` pairs, one
            per sentence, in the same order.
        tokenizer: tokenizer passed through to ``tokenize_sentence``.
        is_test: passed through to ``tokenize_sentence`` (disables masking).

    Returns:
        List of ``REFeatures``. Sentences rejected by ``tokenize_sentence``
        (entity markers lost to truncation) are skipped.

    Raises:
        IndexError: if ``targets`` is shorter than ``sentences``.
    """
    if len(targets) < len(sentences):
        # Keep the original failure mode rather than letting zip silently
        # drop the unmatched sentences.
        raise IndexError("targets has fewer entries than sentences")

    features = []
    # Iterate positionally: ``sentences[i]`` on a pandas Series is a LABEL
    # lookup, which fails or returns the wrong row when the index is not
    # 0..n-1 (e.g. after a shuffled or sampled train/test split).
    for sentence, target in zip(sentences, targets):
        ts = tokenize_sentence(sentence, tokenizer, is_test=is_test)
        if not ts:
            continue
        input_ids, entity1_pos, entity2_pos, tokens = ts
        features.append(
            REFeatures(
                input_ids=input_ids,
                entity1_pos=entity1_pos,
                entity2_pos=entity2_pos,
                rel_label_id=target[0],
                rel_arg_binding=target[1],
                sentence=" ".join(tokens)
            ))

    return features

## Batches creation
For feeding the training process

In [None]:
#Make a batch of training/testing examples.
def make_batch(features, max_seq_len = None):
    """Pack ``REFeatures`` into a feed dict keyed by graph tensor names.

    Token ids are right-padded with ``BertREConfig.PAD_ID`` up to
    ``max_seq_len`` (default: the longest example in ``features``), with a
    matching 0/1 ``input_mask``; ``segment_ids`` are all zeros. The
    "rel_arg_bindings:0" entry is added only when
    ``BertREConfig.REL_ARG_BINDING_COLUMN`` is set; otherwise the graph's
    all-zeros default is used.
    """
    include_arg_bindings = BertREConfig.REL_ARG_BINDING_COLUMN is not None
    n_examples = len(features)
    if max_seq_len is None:
        max_seq_len = max(len(feat.input_ids) for feat in features)

    matrix_shape = [n_examples, max_seq_len]
    token_ids = np.ones(matrix_shape, dtype=np.int32) * BertREConfig.PAD_ID
    attention_mask = np.zeros(matrix_shape, dtype=np.int32)
    token_type_ids = np.zeros(matrix_shape, dtype=np.int32)
    e1_positions, e2_positions, label_ids, arg_bindings = (
        np.zeros([n_examples], dtype=np.int32) for _ in range(4))

    for row, feat in enumerate(features):
        n_tokens = len(feat.input_ids)
        token_ids[row, :n_tokens] = np.array(feat.input_ids)
        attention_mask[row, :n_tokens] = 1
        label_ids[row] = feat.rel_label_id
        arg_bindings[row] = feat.rel_arg_binding
        e1_positions[row] = feat.entity1_pos
        e2_positions[row] = feat.entity2_pos

    feed = {
        "input_ids:0": token_ids,
        "input_mask:0": attention_mask,
        "segment_ids:0": token_type_ids,
        "rel_label_ids:0": label_ids,
        "entity1_pos:0": e1_positions,
        "entity2_pos:0": e2_positions,
    }
    if include_arg_bindings:
        feed["rel_arg_bindings:0"] = arg_bindings

    return feed

## Optimizer creation
To carry out gradient descent and weight update with specific warm up, learning rate, etc.

In [None]:
#Create Bert RE optimizer graph
def create_optimizer(loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu):
    """Build the AdamW training op with linear warm-up and linear decay.

    Args:
        loss: scalar loss tensor to minimise.
        learning_rate: peak learning rate (float or scalar tensor).
        num_train_steps: total optimizer steps; the rate decays linearly to 0
            over this many steps.
        num_warmup_steps: number of initial steps during which the decayed
            rate is scaled by ``global_step / num_warmup_steps``. May be 0.
        use_tpu: unused; kept for signature compatibility with the BERT code.

    Returns:
        An op named "optimizer" that applies one clipped gradient update and
        increments the global step.
    """
    global_step = tf.train.get_or_create_global_step()

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(
      learning_rate,
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=1.0,
      cycle=False)

    # Named so the decayed rate can be fetched for inspection.
    tf.identity(learning_rate, name="c_lr")

    # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = num_warmup_steps

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    # Guard the divisor. With 0 warm-up steps the ratio was inf or NaN (0/0 at
    # step 0), and `is_warmup * warmup_learning_rate` below then evaluated to
    # 0 * NaN = NaN, even though is_warmup is 0.
    warmup_percent_done = global_steps_float / tf.maximum(warmup_steps_float, 1.0)
    warmup_learning_rate = learning_rate * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = (
        (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

    # Named so the effective (warm-up adjusted) rate can be fetched.
    tf.identity(learning_rate, name="c_lr2")

    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    optimizer = optimization.AdamWeightDecayOptimizer(
      learning_rate=learning_rate,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)

    # This is how the model was pre-trained.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    train_op = optimizer.apply_gradients(
      zip(grads, tvars), global_step=global_step)

    # Normally the global step update is done inside of `apply_gradients`.
    # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
    # a different optimizer, you should probably take this line out.
    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)], name="optimizer")
    return train_op



## BERT model creation
Creation of the RE model graph on top of the BERT architecture in TensorFlow.

In [None]:
#Create Bert RE model graph, for training (is_trainable = True) and for inference (is_trainable = False)
def create_model(
    num_relations, 
    num_arg_bindings=3, 
    num_hidden_units=BertREConfig.NUM_HIDDEN_UNITS, 
    chkpoint_path=None, 
    is_trainable=True):
    """Build the BERT relation-extraction graph.

    Other cells feed and fetch the graph by tensor name.

    * Inputs: ``input_ids``, ``input_mask`` and ``segment_ids`` (batch x
      seq_len), and ``entity1_pos``, ``entity2_pos`` and ``rel_label_ids``
      (batch).
    * Defaulted inputs: ``rel_arg_bindings``, ``dropout_rate``,
      ``learning_rate``, ``num_train_steps`` and ``num_warm_up_steps``.
    * Outputs:
      - ``rel_label_probs`` and ``rel_label_predictions``;
      - ``rel_arg_binding_probs`` and ``rel_arg_binding_predictions``;
      - ``loss``;
      - per-example accuracies ``rel_label_acc``, ``rel_arg_binding_acc``
        and ``total_acc``;
      - their batch means ``rel_label_mean_acc``, ``rel_arg_binding_mean_acc``
        and ``total_mean_acc``;
      - the ``optimizer`` op and the variable initializer.

    Args:
        num_relations: number of relation labels (width of the label softmax).
        num_arg_bindings: number of argument-binding classes. With 1 or fewer,
            no binding head is built and binding predictions are all zeros.
        num_hidden_units: width of an optional ReLU layer before the output
            heads (0 = none).
        chkpoint_path: pretrained BERT checkpoint used to initialise matching
            variables, or None.
        is_trainable: training mode, which builds the optimizer and uses the
            configured dropout. Otherwise (inference) the train op is a no-op
            and dropout defaults to 0.

    Returns:
        ``(train_op, loss, total_accuracy, (rel_label_predictions,
        rel_arg_binding_predictions), (rel_label_probs, rel_arg_binding_probs))``

    Raises:
        ValueError: if neither ``USE_ENTITY_POSITIONS`` nor ``USE_CLS_POSITION``
            is enabled in ``BertREConfig``.
    """
    # Validate before building anything. This was `raise("...")`; raising a
    # str is itself a TypeError in Python 3, so the message was lost.
    if not (BertREConfig.USE_ENTITY_POSITIONS or BertREConfig.USE_CLS_POSITION):
        raise ValueError("Either USE_ENTITY_POSITIONS or USE_CLS_POSITION should be set to True.")

    # The session is never run. Entering it makes its graph the default while
    # the ops below are created. Presumably this is also what lets the
    # tf.compat.v1 placeholders be built when eager execution is enabled, so
    # it is kept.
    with tf.compat.v1.Session() as session:    

        num_train_steps = tf.compat.v1.placeholder_with_default(
            input=tf.constant(1000, dtype=tf.float32), shape=(), name="num_train_steps")

        # Defaults to 10% of num_train_steps (follows a fed num_train_steps).
        num_warm_up_steps = tf.compat.v1.placeholder_with_default(
            input=tf.cast(tf.round(0.1 * num_train_steps), tf.int32), shape=(), name="num_warm_up_steps")

        input_ids = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None, None), name="input_ids")

        batch_size = tf.shape(input_ids)[0]
        seq_len = tf.shape(input_ids)[1]

        input_mask = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None, None), name="input_mask")

        segment_ids = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None, None), name="segment_ids")

        rel_label_ids = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None), name="rel_label_ids")

        # All-zeros (symmetric) bindings unless fed.
        rel_arg_bindings = tf.compat.v1.placeholder_with_default(
            input=tf.zeros(shape=(batch_size),dtype=tf.compat.v1.int32), shape=(None), name="rel_arg_bindings")

        entity1_pos = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None), name="entity1_pos")

        entity2_pos = tf.compat.v1.placeholder(
            dtype=tf.compat.v1.int32, shape=(None), name="entity2_pos")

        #set dropout to 0 if mode is not trainable
        default_dropout_rate = BertREConfig.DROPOUT_RATE if is_trainable else 0.0        
        dropout_rate = tf.compat.v1.placeholder_with_default(
            input=tf.constant(default_dropout_rate, dtype=tf.float32), shape=(), name="dropout_rate")

        learning_rate = tf.compat.v1.placeholder_with_default(
            input=tf.constant(2e-5, dtype=tf.float32), shape=(), name="learning_rate")

        config = modeling.BertConfig.from_json_file(BertREConfig.BERT_MODEL_CONFIG_PATH)

        bert_model = modeling.BertModel(
            config=config,
            is_training=is_trainable,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids)

        if chkpoint_path:
            tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
                tvars, chkpoint_path)
            tf.train.init_from_checkpoint(chkpoint_path, assignment_map)    

        # (batch, seq_len, hidden_size) token representations.
        output_layer = bert_model.get_sequence_output()

        if BertREConfig.USE_ENTITY_POSITIONS:
            flat_output = tf.reshape(output_layer, shape=(batch_size, -1))

            def entity_embedding(entity_pos):
                # Pick each example's hidden vector at entity_pos. The one-hot
                # position mask is repeated hidden_size times so that it lines
                # up with the flattened (batch, seq_len * hidden_size) output.
                mask = tf.repeat(tf.one_hot(entity_pos, seq_len), config.hidden_size, axis=1)
                return tf.reduce_sum(
                    tf.reshape(flat_output * mask, shape=[batch_size, seq_len, config.hidden_size]),
                    axis=1)

            entity1_embd = entity_embedding(entity1_pos)
            entity2_embd = entity_embedding(entity2_pos)

            #Concat representions
            if BertREConfig.USE_CLS_POSITION:                
                classification_layer = tf.concat([entity1_embd, entity2_embd, output_layer[:,0,:]], axis=1)
            else:
                classification_layer = tf.concat([entity1_embd, entity2_embd], axis=1)
        else:
            # USE_CLS_POSITION is guaranteed True by the check at the top.
            classification_layer = output_layer[:,0,:]

        # Optional fully connected layer, then dropout.
        if num_hidden_units > 0:
            fc = tf.layers.dense(
                classification_layer, 
                num_hidden_units, 
                name='fc1')
            fc = tf.nn.relu(fc)
        else:
            fc = tf.identity(classification_layer, name='fc1')

        fc = tf.nn.dropout(fc, rate=dropout_rate)

        # Relation label head: softmax probabilities (not log-probabilities).
        rel_label_logits = tf.layers.dense(fc, num_relations, name='rel_label_logits')
        rel_label_probs = tf.nn.softmax(rel_label_logits, name="rel_label_probs")
        rel_label_predictions = tf.argmax(
            rel_label_probs, axis=-1, output_type=tf.int32, name="rel_label_predictions")

        if num_arg_bindings > 1:    
            rel_arg_binding_logits = tf.layers.dense(fc, num_arg_bindings, name='rel_arg_binding_logits')
            rel_arg_binding_probs = tf.nn.softmax(rel_arg_binding_logits, name="rel_arg_binding_probs")
            rel_arg_binding_predictions = tf.argmax(
                rel_arg_binding_probs, axis=-1, output_type=tf.int32, name="rel_arg_binding_predictions")
        else:
            rel_arg_binding_probs = tf.ones_like(rel_arg_bindings, name="rel_arg_binding_probs")
            rel_arg_binding_predictions = tf.zeros_like(
                rel_arg_bindings, name="rel_arg_binding_predictions")

        # Per-example cross-entropy losses against one-hot targets.
        rel_label_targets = tf.one_hot(rel_label_ids, depth=num_relations)
        rel_label_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=rel_label_targets,
                logits=rel_label_logits)

        if num_arg_bindings > 1:                
            rel_arg_binding_targets = tf.one_hot(rel_arg_bindings, depth=num_arg_bindings)
            rel_arg_binding_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=rel_arg_binding_targets,
                    logits=rel_arg_binding_logits)
        else:
            rel_arg_binding_loss = 0                    

        rel_label_example_accuracy = tf.cast(
            tf.equal(rel_label_predictions, rel_label_ids), tf.float32, name="rel_label_acc")

        rel_label_accuracy = tf.reduce_mean(rel_label_example_accuracy, name="rel_label_mean_acc")        

        rel_arg_binding_example_accuracy = tf.cast(
            tf.equal(rel_arg_binding_predictions, rel_arg_bindings), tf.float32, name="rel_arg_binding_acc")

        rel_arg_binding_accuracy = tf.reduce_mean(
            rel_arg_binding_example_accuracy, name="rel_arg_binding_mean_acc")

        # An example counts as correct only if both label and binding are right.
        total_example_accuracy = tf.identity(
            rel_label_example_accuracy * rel_arg_binding_example_accuracy, name="total_acc")

        # Batch mean, like the other *_mean_acc tensors (this was a per-example
        # tf.identity despite its name); per-example values stay in "total_acc".
        total_accuracy = tf.reduce_mean(total_example_accuracy, name="total_mean_acc")

        loss = tf.reduce_mean(rel_label_loss + rel_arg_binding_loss, name="loss")

        if is_trainable:
            train_op = create_optimizer(
                loss, 
                learning_rate, 
                num_train_steps, 
                num_warm_up_steps,
                use_tpu=False)        
        else:
            train_op = tf.no_op()

        # Variable initializer; train_model runs it by its default op name "init".
        tf.global_variables_initializer()

        return (
                train_op,
                loss,       
                total_accuracy,
                (rel_label_predictions, rel_arg_binding_predictions), 
                (rel_label_probs, rel_arg_binding_probs))

## Model saving
Function to export the trained BERT RE model to disk in TensorFlow SavedModel format.

In [None]:
def export_model(model_id, is_trainable = True, num_arg_bindings = 3, rel_labels = None):
    """Export a trained RE checkpoint as a TensorFlow SavedModel with its assets.

    The function:

    1. rebuilds the model graph in a fresh default graph;
    2. restores ``CHKPOINT_PATH/<model_id>/model``;
    3. writes a SavedModel to ``BertREConfig.EXPORT_PATH``, replacing any
       previous export;
    4. copies the checkpoint's ``assets`` folder next to it.

    Args:
        model_id: name of the checkpoint sub-folder written by ``train_model``.
        is_trainable: forwarded to ``create_model``.
        num_arg_bindings: number of argument-binding classes the model was
            trained with.
        rel_labels: relation labels the model was trained with. When omitted
            they are read from the checkpoint's ``assets/categories.txt``.
            This function used to depend silently on a global ``rel_labels``.
    """
    model_dir = f"{BertREConfig.CHKPOINT_PATH}/{model_id}"

    if rel_labels is None:
        # train_model writes the labels joined by "\n" into this file.
        with open(f"{model_dir}/assets/categories.txt", "rt", encoding="utf-8") as categories_file:
            rel_labels = categories_file.read().split("\n")

    # Start from an empty graph. Otherwise a graph left in this kernel (e.g. by
    # train_model) coexists with the new one, and the name lookups below can
    # resolve to the stale copies.
    tf.reset_default_graph()

    input_tensors_names = [
        "input_ids:0",
        "input_mask:0",
        "segment_ids:0",
        "entity1_pos:0",
        "entity2_pos:0",
    ]

    output_tensors_names = [
        "loss:0",
        "rel_label_acc:0",
        "rel_arg_binding_acc:0",
        "total_acc:0",
        "rel_label_probs:0",
        "rel_label_predictions:0",
        "rel_arg_binding_probs:0",
        "rel_arg_binding_predictions:0"
    ]

    with tf.compat.v1.Session() as session:

        create_model(
            len(rel_labels), 
            is_trainable=is_trainable,
            num_arg_bindings=num_arg_bindings)

        graph = session.graph
        input_tensors = {}
        for name in input_tensors_names:
            tensor = graph.get_tensor_by_name(name)
            input_tensors[tensor.name] = tensor
        output_tensors = {name: graph.get_tensor_by_name(name) for name in output_tensors_names}

        # Initial value 1 keeps reduce() valid for scalar (shape []) variables.
        size_f = lambda v: reduce(lambda x, y: x*y, v.get_shape().as_list(), 1)
        n = sum(size_f(v) for v in tf.trainable_variables())
        print("{} trainbale parameters.".format(n))    

        tf.train.Saver().restore(session, f"{model_dir}/model")

        shutil.rmtree(BertREConfig.EXPORT_PATH, ignore_errors=True)

        #save model
        tf.saved_model.simple_save(
            session,
            BertREConfig.EXPORT_PATH,
            inputs=input_tensors,
            outputs=output_tensors
        )

        #copy assets (vocab.txt, categories.txt) to the destination folder
        shutil.copytree(
            f"{model_dir}/assets", 
            f"{BertREConfig.EXPORT_PATH}/assets")


## Training
Function to train the model for RE

In [None]:
#Train a Bert RE model and save it in the checkpoints folder
def train_model(model_id, train_features, test_features, rel_labels, num_arg_bindings = 3):
    """Fine-tune a BERT RE model and save it as a checkpoint.

    The function:

    * builds a fresh training graph initialised from
      ``BERT_MODEL_CHECKPOINT_PATH``;
    * trains for ``NUM_TRAIN_EPOCHS`` epochs over full ``BATCH_SIZE`` batches
      (a trailing partial batch is skipped);
    * after each epoch, prints mean accuracy over the full ``V_BATCH_SIZE``
      batches of ``test_features``;
    * writes the checkpoint, ``assets/vocab.txt`` and
      ``assets/categories.txt`` (one label per line) under
      ``CHKPOINT_PATH/model_id``, replacing any previous content.

    Args:
        model_id: name of the checkpoint sub-folder to write.
        train_features: list of ``REFeatures`` for training (not modified;
            a shuffled copy is used).
        test_features: list of ``REFeatures`` used for validation.
        rel_labels: relation labels; ``rel_label_id`` values index into it.
        num_arg_bindings: number of argument-binding classes.
    """
    # `ops` is not imported by this notebook, so the former
    # `ops.reset_default_graph()` raised NameError. This is the tf.compat.v1
    # equivalent.
    tf.reset_default_graph()

    batch_size = BertREConfig.BATCH_SIZE
    v_batch_size = BertREConfig.V_BATCH_SIZE
    num_epochs = BertREConfig.NUM_TRAIN_EPOCHS
    n_batches = len(train_features) // batch_size
    n_val_batches = len(test_features) // v_batch_size

    # Steps actually run, so the linear LR decay ends exactly at the last step.
    num_train_steps = num_epochs * n_batches
    num_warmup_steps = int(round(BertREConfig.WARMUP_PROPORTION * num_train_steps))
    # Fed on EVERY step. When these placeholders are not fed, the graph falls
    # back to its defaults (1000 steps, 2e-5). Feeding them only on the first
    # batch made the schedule decay over 1000 steps, so the learning rate
    # dropped to 0 after step 1000.
    schedule_feed = {
        "num_train_steps:0": num_train_steps,
        "num_warm_up_steps:0": num_warmup_steps,
        "learning_rate:0": BertREConfig.LEARNING_RATE,
    }

    # Shuffle a private copy so the caller's list is not reordered in place.
    train_features = list(train_features)

    model_dir = f"{BertREConfig.CHKPOINT_PATH}/{model_id}"

    train_eval_tensors = [
        "optimizer", 
        "loss:0", 
        "rel_label_mean_acc:0",
        "rel_arg_binding_mean_acc:0",
        "total_mean_acc:0"
    ]
    val_eval_tensors = [
        "rel_label_mean_acc:0",
        "rel_arg_binding_mean_acc:0",
        "total_mean_acc:0",
    ]

    with tf.compat.v1.Session() as session:

        create_model(
            len(rel_labels), 
            chkpoint_path=BertREConfig.BERT_MODEL_CHECKPOINT_PATH, 
            num_arg_bindings=num_arg_bindings, 
            is_trainable=True
        )
        # "init" is the variable initializer created by create_model; the BERT
        # weights come in through its init_from_checkpoint call.
        session.run("init")

        clear_output(wait=True)

        print("{:^11}{:^11}{:>10}{:>10}{:>10}{:>10}{:>10}{:>10}{:>10}".format(
            "Epoch", "Batch", 
            "Loss", 
            "L_ACC", "Arg_ACC", "ACC",
            "vL_ACC", "vArg_ACC", "vACC"))

        for e in range(num_epochs):

            np.random.shuffle(train_features)            

            b_loss = []
            b_rel_label_acc = []
            b_rel_arg_binding_acc = []
            b_total_acc = []
            for b in range(n_batches):

                data = make_batch(train_features[b * batch_size: (b + 1) * batch_size])
                data.update(schedule_feed)

                _, loss, rel_label_acc, rel_arg_bind_acc, total_acc = session.run(
                    train_eval_tensors, feed_dict=data)
                b_loss.append(loss)
                b_rel_label_acc.append(rel_label_acc)
                b_rel_arg_binding_acc.append(rel_arg_bind_acc)
                b_total_acc.append(total_acc)

                # Running epoch averages, rewritten in place on one line.
                print("\r{:>5}/{:<5}{:>5}/{:<5}{:>10.4f}{:>10.3f}{:>10.3f}{:>10.3f}".format(
                        e+1,
                        num_epochs,
                        b + 1, 
                        n_batches,
                        np.mean(b_loss), 
                        np.mean(b_rel_label_acc), 
                        np.mean(b_rel_arg_binding_acc), 
                        np.mean(b_total_acc)  
                    ), end="")

            v_rel_label_acc = []
            v_rel_arg_binding_acc = []
            v_total_acc = []

            for v_b in range(n_val_batches):
                data = make_batch(test_features[v_b * v_batch_size: (v_b + 1) * v_batch_size])

                rel_label_acc, rel_arg_bind_acc, total_acc = session.run(val_eval_tensors, feed_dict=data)
                v_rel_label_acc.append(rel_label_acc)
                v_rel_arg_binding_acc.append(rel_arg_bind_acc)
                v_total_acc.append(total_acc)

            print("{:>10.3f}{:>10.3f}{:>10.3f}".format(
                np.mean(v_rel_label_acc), 
                np.mean(v_rel_arg_binding_acc), 
                np.mean(v_total_acc)))                

        shutil.rmtree(model_dir, ignore_errors=True)
        # makedirs also creates CHKPOINT_PATH itself when it does not exist yet.
        os.makedirs(f"{model_dir}/assets")

        tf.train.Saver().save(session, f"{model_dir}/model")

        shutil.copy(
            BertREConfig.BERT_VOCAB_PATH, 
            f"{model_dir}/assets/vocab.txt")

        with open(f"{model_dir}/assets/categories.txt", "wt") as categories_file:
            categories_file.write("\n".join(str(label) for label in rel_labels))


## Evaluation
Functions to get the metrics on the RE model and print them

In [None]:
def _summarize_rel_metrics(metrics_data, rel_labels, exclude_rels):
    """Reduce the per-relation flag lists collected by `eval_metrics` to
    ``{rel: (recall, precision, f1, support)}``, skipping names in `exclude_rels`."""
    results = {}
    for rel in rel_labels:
        if rel in exclude_rels:
            continue
        correct_as_target, correct_as_pred, target_counts = metrics_data[rel]

        # An empty list means the relation never occurred as target (resp. prediction).
        recall = np.mean(correct_as_target) if len(correct_as_target) else 0
        precision = np.mean(correct_as_pred) if len(correct_as_pred) else 0
        if recall + precision:
            f1 = 2 * (recall * precision) / (recall + precision)
        else:
            # F1 is undefined when both are 0. Use `np.nan`: the `np.NaN` alias
            # was removed in NumPy 2.0 and raises AttributeError there.
            f1 = np.nan

        support = np.sum(target_counts)
        results[rel] = (recall, precision, f1, support)
    return results


def eval_metrics(model_id, features, rel_labels, num_arg_bindings=3, exclude_rels=None):
    """Compute per-relation recall, precision, F1 and support for a trained RE model.

    Rebuilds the graph with ``create_model``, restores the checkpoint saved at
    ``{BertREConfig.CHKPOINT_PATH}/{model_id}/model`` and runs it over ``features``
    in batches of ``BertREConfig.V_BATCH_SIZE``.

    Args:
        model_id: name of the checkpoint sub-directory to restore.
        features: evaluation features (e.g. the output of ``make_features``).
        rel_labels: relation names; ``rel_labels[i]`` is the name of label id ``i``.
        num_arg_bindings: forwarded to ``create_model``.
        exclude_rels: relation names to leave out of the result (``None`` = keep all).

    Returns:
        dict mapping relation name -> (recall, precision, f1, support). ``f1`` is
        ``nan`` when recall and precision are both 0.

    NOTE(review): only ``len(features) // V_BATCH_SIZE`` full batches are evaluated.
    Trailing examples that do not fill a batch are skipped, so the total support can
    be smaller than ``len(features)``.
    """
    # None instead of a mutable [] default.
    exclude_rels = exclude_rels if exclude_rels is not None else []

    with tf.compat.v1.Session() as session:
        # Builds the graph whose tensors are fetched by name below; the returned
        # object itself is not needed here.
        create_model(
            len(rel_labels),
            num_arg_bindings=num_arg_bindings,
            is_trainable=False)

        tf.train.Saver().restore(session, f"{BertREConfig.CHKPOINT_PATH}/{model_id}/model")

        # Per relation, three lists:
        #   [0] 1/0 correctness of each example whose TARGET is rel   -> recall
        #   [1] 1/0 correctness of each example PREDICTED as rel      -> precision
        #   [2] a 1 for each example whose target is rel              -> support
        metrics_data = {rel: ([], [], []) for rel in rel_labels}

        batch_size = BertREConfig.V_BATCH_SIZE
        eval_tensors = ["total_acc:0", "rel_label_ids:0", "rel_label_predictions:0"]

        for v_b in range(len(features) // batch_size):
            batch = make_batch(features[v_b * batch_size: (v_b + 1) * batch_size])
            total_acc, rel_label_ids, rel_label_preds = session.run(eval_tensors, feed_dict=batch)

            for i in range(len(rel_label_ids)):
                correct = 1 if total_acc[i] else 0
                rel_target = rel_labels[rel_label_ids[i]]
                rel_pred = rel_labels[rel_label_preds[i]]

                metrics_data[rel_target][2].append(1)
                metrics_data[rel_target][0].append(correct)
                metrics_data[rel_pred][1].append(correct)

    return _summarize_rel_metrics(metrics_data, rel_labels, exclude_rels)



In [None]:
def print_metrics(results):
    """Print a per-relation metrics table followed by macro and support-weighted averages.

    Args:
        results: dict mapping relation name -> (recall, precision, f1, support),
            e.g. the return value of ``eval_metrics``.

    The relation-name column is at least 15 characters wide and grows to fit the
    longest relation name. This keeps long names (e.g. ``has_collective_alias``)
    from shifting their row out of line with the header.
    """
    name_width = max([15] + [len(str(rel)) for rel in results])

    print("\n")
    print(f"{'Relation':<{name_width}}{'Recall':>10}{'Precision':>10}{'F1':>10}{'Support':>10}\n")

    for rel, (recall, precision, f1, support) in results.items():
        print(f"{rel:<{name_width}}{recall:>10.3f}{precision:>10.3f}{f1:>10.3f}{support:>10}")

    mean_recall = np.mean([results[rel][0] for rel in results])
    mean_precision = np.mean([results[rel][1] for rel in results])
    mean_f1 = np.mean([results[rel][2] for rel in results])

    support_sum = np.sum([results[rel][3] for rel in results])

    def _weighted_mean(idx):
        # Support-weighted mean of metric `idx`; nan (instead of a 0/0 warning)
        # when the total support is 0.
        if not support_sum:
            return np.nan
        return np.sum([results[rel][idx] * results[rel][3] for rel in results]) / support_sum

    w_mean_recall = _weighted_mean(0)
    w_mean_precision = _weighted_mean(1)
    w_mean_f1 = _weighted_mean(2)

    metrics_name = "Avg."
    print(f"\n{metrics_name:<{name_width}}{mean_recall:>10.3f}{mean_precision:>10.3f}{mean_f1:>10.3f}")

    metrics_name = "Weighted Avg."
    print(f"\n{metrics_name:<{name_width}}{w_mean_recall:>10.3f}{w_mean_precision:>10.3f}{w_mean_f1:>10.3f}")

## MAIN: STEP-BY-STEP RE MODEL TRAINING EXECUTION

### Update Bert vocabulary with special tokens

In [None]:
# Update the BERT vocabulary with the special tokens used by the RE model
update_vocab()

### Creating a Tokenizer

In [None]:
# Create the BERT tokenizer from the vocabulary at BertREConfig.BERT_VOCAB_PATH.
# do_lower_case=False because the BERT checkpoint downloaded at the top of the
# notebook is the cased model (cased_L-12_H-768_A-12).
tokenizer = tokenization.FullTokenizer(vocab_file=BertREConfig.BERT_VOCAB_PATH, do_lower_case=False)

### Read the input data.
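The general REDL training format uses the columns listed below. The `relations.csv` file used in this notebook contains only a subset of them (`text`, `chunk1`, `chunk2`, `firstCharEnt1`, `lastCharEnt1`, `firstCharEnt2`, `lastCharEnt2`, `label1`, `label2`, `rel`) plus a `direction` column; see the output below.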

| Column      |          Explanation                     |
|:-----------:|:----------------------------------------:|
|dataset      | train/test                               |
|source       | data provider                            |
|txt_file     | .txt file                                |
|sentence     | tokenized text sentence                  |
|sent_id      | sentence id                              |
|chunk1       | first entity                             |
|begin1       | first token number of the first entity   |
|end1         | last token number of the first entity    |
|rel          | relation (O for no-relation)             |
|chunk2       | second entity                            |
|begin2       | first token number of the second entity  |
|end2         | last token number of the second entity   |
|label1       | label of the first entity                |
|label2       | label of the second entity               |
|lastCharEnt1 | last char number of the first entity     |
|firstCharEnt1| first char number of the first entity    |
|lastCharEnt2 | last char number of the second entity    |
|firstCharEnt2| first char number of the second entity   |
|words_in_ent1| number of words in first entity          |
|words_in_ent2| number of words in second entity         |
|words_between| word between entities                    |
|is_train     | is it used for training?                 |

We are now ready to train a REDL model. We will train it to extract relations between **DOC**, **PARTY**, **ALIAS** and **EFFDATE** entities.

Let's look at our dataset. 

Your dataset must follow this format.

In [None]:
# Load the relation-extraction dataset (relations.csv, downloaded at the top of the notebook)
import pandas as pd
data = pd.read_csv("/content/relations.csv")
data


Unnamed: 0,text,firstCharEnt1,firstCharEnt2,lastCharEnt1,lastCharEnt2,chunk1,chunk2,label1,label2,rel,direction
0,EXHIBIT 10.43 Dated 29/3/18\n\nDistributorship...,29,65,54,95,Distributorship agreement,Signature Orthopaedics Pty Ltd,DOC,PARTY,signed_by,1
1,EXHIBIT 10.43 Dated 29/3/18\n\nDistributorship...,29,102,54,129,Distributorship agreement,CPM Medical Consultants LLC,DOC,PARTY,signed_by,1
2,EXHIBIT 10.43 Dated 29/3/18\n\nDistributorship...,20,29,27,54,29/3/18,Distributorship agreement,EFFDATE,DOC,dated_as,2
3,Sections 200.80(b)(4) and Rule 406 of the Secu...,173,236,196,247,Collaboration Agreement,"Xencor, Inc",DOC,PARTY,signed_by,1
4,"Monrovia, CA 91016 USA\n\n(hereinafter called ...",63,170,97,173,Boehringer Ingelheim International,BII,PARTY,ALIAS,has_alias,1
...,...,...,...,...,...,...,...,...,...,...,...
3496,By execution of this Supplier/Subcontractor Co...,85,85,93,93,Supplier,Supplier,PARTY,ROLE,other,0
3497,/s Liu Gang Name: LIU GANG Title: Authoriz...,20,38,28,58,LIU GANG,Authorized Signatory,SIGNING_PERSON,SIGNING_TITLE,other,0
3498,"HOFV: HOF VILLAGE, LLC By: /s / Brian Parisi N...",193,212,204,227,David Baker,President & CEO,SIGNING_PERSON,SIGNING_TITLE,other,0
3499,By: /s/ Robert Mattacchione Name: Robert Matta...,34,61,53,64,Robert Mattacchione,CEO,SIGNING_PERSON,SIGNING_TITLE,other,0


In [None]:
# Class balance of the relation labels
data.value_counts('rel')

rel
other                   1639
signed_by                865
has_alias                471
dated_as                 435
has_collective_alias      91
dtype: int64

In [None]:
# Distribution of the entity label of the first argument
data.value_counts('label1')

label1
PARTY                         1284
DOC                           1267
SIGNING_PERSON                 737
ORG                             72
ALIAS                           57
ROLE                            41
EFFDATE                         18
SIGNING_TITLE                   11
TITLE                            7
NAME                             6
PERMISSION_INDIRECT_OBJECT       1
dtype: int64

In [None]:
# Distribution of the entity label of the second argument
data.value_counts('label2')

label2
PARTY                882
SIGNING_TITLE        724
SIGNING_PERSON       498
ALIAS                483
ROLE                 389
EFFDATE              376
ORG                   63
AGRDATE               52
DOC                   19
NAME                   7
TITLE                  6
FORMER_PARTY_NAME      1
PERMISSION             1
dtype: int64

In [None]:
# Valid relation labels: those with at least MIN_REL_COUNT examples. Rarer labels are
# probably annotation errors, or at least too under-represented to learn.
# Labels keep their order of first appearance in the data.
MIN_REL_COUNT = 10
rel_counts = data['rel'].value_counts()
valid_rel_labels = np.array(
    [rel for rel in data['rel'].unique() if rel_counts.get(rel, 0) >= MIN_REL_COUNT],
    dtype=object)
valid_rel_labels

array(['signed_by', 'dated_as', 'has_alias', 'has_collective_alias',
       'other'], dtype=object)

### Create the training and test datasets

In [None]:
# Reproducible 90/10 split. NOTE: the split is not stratified by `rel`, so rare
# relations may end up with very few test examples.
TRAIN_FRACTION = 0.9
df_train = data.sample(frac=TRAIN_FRACTION, random_state=1)
# Every row that was not sampled into df_train becomes test data.
df_test = data.drop(index=df_train.index)


In [None]:
# Convert the train/test DataFrames into sentences, relation label ids and argument bindings.
# `rel_labels` comes from the training split only; the test split's own label list is discarded,
# so this assumes both calls map relation names to the same ids (TODO confirm).
train_sentences, train_rel_label_ids, train_rel_arg_bindings, rel_labels = (
    collect_data_from_pandas_dataset(df_train))

test_sentences, test_rel_label_ids, test_rel_arg_bindings, _ = (
    collect_data_from_pandas_dataset(df_test))

In [None]:
# Convert the training containers to NumPy arrays. np.asarray returns an existing
# ndarray unchanged, so re-running this cell no longer fails with
# AttributeError ('numpy.ndarray' object has no attribute 'values').
train_sentences = np.asarray(train_sentences)
train_rel_label_ids = np.asarray(train_rel_label_ids)
train_rel_arg_bindings = np.asarray(train_rel_arg_bindings)



In [None]:
# Convert the test containers to NumPy arrays. np.asarray returns an existing
# ndarray unchanged, so re-running this cell no longer fails with
# AttributeError ('numpy.ndarray' object has no attribute 'values').
test_sentences = np.asarray(test_sentences)
test_rel_label_ids = np.asarray(test_rel_label_ids)
test_rel_arg_bindings = np.asarray(test_rel_arg_bindings)

### 7. Create the features from the datasets

In [None]:
def _features_for(sentences, rel_label_ids, rel_arg_bindings, is_test):
    """Build RE features for one split; a fresh (label id, arg bindings) list is zipped per call."""
    return make_features(
        sentences,
        list(zip(rel_label_ids, rel_arg_bindings)),
        tokenizer,
        is_test=is_test)


train_features = _features_for(
    train_sentences, train_rel_label_ids, train_rel_arg_bindings, is_test=False)

# When argument replacement and dataset replication are both enabled, append a second
# copy of the training data built with is_test=True (presumably without argument
# replacement; TODO confirm against make_features).
if BertREConfig.REPLACE_ARG_PROB and BertREConfig.REPLICATE_DATASET:
    train_features += _features_for(
        train_sentences, train_rel_label_ids, train_rel_arg_bindings, is_test=True)

test_features = _features_for(
    test_sentences, test_rel_label_ids, test_rel_arg_bindings, is_test=True)

In [None]:
# Report how many feature examples each split produced.
print(f"{len(train_features)} training examples")
print(f"{len(test_features)} test examples")

2383 training examples
266 test examples


In [None]:
# Split the training features into two halves; the second half gets the extra
# example when the count is odd.
half = len(train_features) // 2
train_features_part1, train_features_part2 = train_features[:half], train_features[half:]

In [None]:
# Sizes of the two training halves and of the test set.
print(f"{len(train_features_part1)} training part 1 examples")
print(f"{len(train_features_part2)} training part 2 examples")
print(f"{len(test_features)} test examples")

1191 training part 1 examples
1192 training part 2 examples
266 test examples


In [None]:
from tensorflow.python.framework import ops

In [None]:
# Relation names; the name at position i is the one used for label id i (see eval_metrics)
rel_labels

['dated_as', 'has_alias', 'has_collective_alias', 'other', 'signed_by']

### Training the model
Here is where all the fun happens! 🏄

In [None]:


# Train on the 90% split and report accuracy on the held-out 10% (the v* columns).
# The checkpoint is saved under BertREConfig.CHKPOINT_PATH/CONTRACT_DOC_PARTIES_SPLIT.
train_model(
    "CONTRACT_DOC_PARTIES_SPLIT",
    train_features=train_features, 
    test_features=test_features, 
    rel_labels=rel_labels,
    num_arg_bindings=3)

   Epoch      Batch         Loss     L_ACC   Arg_ACC       ACC    vL_ACC  vArg_ACC      vACC
    1/5      148/148      0.4418     0.866     0.959     0.857     0.975     1.000     0.975
    2/5      148/148      0.0478     0.987     1.000     0.987     0.975     1.000     0.975
    3/5      148/148      0.0143     0.996     1.000     0.996     0.975     1.000     0.975
    4/5      148/148      0.0049     0.999     1.000     0.999     0.970     1.000     0.970
    5/5      148/148      0.0022     1.000     1.000     1.000     0.975     1.000     0.975


### Evaluating the model

In [None]:
# Start from a clean graph before rebuilding the model for evaluation.
tf.reset_default_graph()

# num_arg_bindings=0 differs from training (3); restoring presumably works because the
# checkpoint merely holds extra variables the smaller graph does not use (TODO confirm).
metrics = eval_metrics(
    "CONTRACT_DOC_PARTIES_SPLIT", test_features, rel_labels, num_arg_bindings=0, exclude_rels=[])
print_metrics(metrics)



Relation           Recall Precision        F1   Support

dated_as            0.976     0.953     0.965        42
has_alias           0.964     1.000     0.982        28
has_collective_alias     1.000     1.000     1.000         1
other               1.000     0.965     0.982        55
signed_by           0.959     0.986     0.973        74

Avg.                0.980     0.981     0.980

Weighted Avg.       0.975     0.975     0.975


# Train with all data

In [None]:
# Start from a clean graph and train the final model on train + test data.
tf.reset_default_graph()

# No validation data is passed (test_features=[]), which is why the v* columns
# in the training log are nan.
train_model(
    "CONTRACT_DOC_PARTIES", 
    train_features=train_features+test_features,
    test_features=[], 
    rel_labels=rel_labels,
    num_arg_bindings=3)

   Epoch      Batch         Loss     L_ACC   Arg_ACC       ACC    vL_ACC  vArg_ACC      vACC
    1/5      165/165      0.3778     0.887     0.969     0.880       nan       nan       nan
    2/5      165/165      0.0545     0.989     1.000     0.989       nan       nan       nan
    3/5      165/165      0.0169     0.995     1.000     0.995       nan       nan       nan
    4/5      165/165      0.0115     0.998     1.000     0.998       nan       nan       nan
    5/5      165/165      0.0088     0.997     1.000     0.997       nan       nan       nan


In [None]:
# List the checkpoints saved by train_model (one directory per model id)
!ls -l /content/trained/*

/content/trained/CONTRACT_DOC_PARTIES:
total 1269052
drwxr-xr-x 2 root root       4096 Feb 12 18:22 assets
-rw-r--r-- 1 root root         67 Feb 12 18:22 checkpoint
-rw-r--r-- 1 root root 1295219816 Feb 12 18:22 model.data-00000-of-00001
-rw-r--r-- 1 root root      22852 Feb 12 18:22 model.index
-rw-r--r-- 1 root root    4248536 Feb 12 18:22 model.meta

/content/trained/CONTRACT_DOC_PARTIES_SPLIT:
total 1269052
drwxr-xr-x 2 root root       4096 Feb 12 18:15 assets
-rw-r--r-- 1 root root         67 Feb 12 18:15 checkpoint
-rw-r--r-- 1 root root 1295219816 Feb 12 18:15 model.data-00000-of-00001
-rw-r--r-- 1 root root      22852 Feb 12 18:15 model.index
-rw-r--r-- 1 root root    4248536 Feb 12 18:15 model.meta


### Finally saving it!

In [None]:
# Start from a clean graph before rebuilding the model for export.
tf.reset_default_graph()

# Export the final model as a non-trainable graph with num_arg_bindings=0; the
# exported files presumably land in BertREConfig.EXPORT_PATH (listed in the next cell).
export_model("CONTRACT_DOC_PARTIES", is_trainable=False, num_arg_bindings=0)

108321797 trainbale parameters.


Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.


In [None]:
# Inspect the exported model directory (saved_model.pb, variables/ and assets/)
!ls -l  {BertREConfig.EXPORT_PATH}/*

-rw-r--r-- 1 root root 1013294 Feb 12 18:22 /models/basebert_re/saved_model.pb

/models/basebert_re/assets:
total 216
-rw-r--r-- 1 root root     55 Feb 12 18:22 categories.txt
-rw-r--r-- 1 root root 213422 Feb 12 18:22 vocab.txt

/models/basebert_re/variables:
total 423148
-rw-r--r-- 1 root root 433287188 Feb 12 18:22 variables.data-00000-of-00001
-rw-r--r-- 1 root root      8306 Feb 12 18:22 variables.index


In [None]:
# Make sure the zip utility used in the next cell is installed
!sudo apt-get -y install zip unzip

Reading package lists... Done
Building dependency tree       
Reading state information... Done
zip is already the newest version (3.0-11build1).
unzip is already the newest version (6.0-25ubuntu1.1).
The following package was automatically installed and is no longer required:
  libnvidia-common-510
Use 'sudo apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 21 not upgraded.


In [None]:
# Package the exported model directory into redl.zip
!zip -r redl.zip /content/models/basebert_re

  adding: content/models/basebert_re/ (stored 0%)
  adding: content/models/basebert_re/saved_model.pb (deflated 92%)
  adding: content/models/basebert_re/variables/ (stored 0%)
  adding: content/models/basebert_re/variables/variables.index (deflated 68%)
  adding: content/models/basebert_re/variables/variables.data-00000-of-00001 (deflated 7%)
  adding: content/models/basebert_re/assets/ (stored 0%)
  adding: content/models/basebert_re/assets/categories.txt (deflated 18%)
  adding: content/models/basebert_re/assets/vocab.txt (deflated 49%)


# Testing the model in Spark NLP

Let's test our model with SparkNLP

In [None]:
import pandas as pd

def get_relations_df(results, col='relations'):
    """Flatten the relation annotations of the first annotated document into a DataFrame.

    Args:
        results: output of ``LightPipeline.fullAnnotate``; only ``results[0][col]`` is read.
        col: name of the output column that holds the relation annotations.

    Returns:
        pandas.DataFrame with one row per relation annotation: the relation label, then
        entity type, begin, end and chunk text for both arguments, then the confidence.
        Metadata values are copied as-is (no type conversion).
    """
    metadata_keys = ['entity1', 'entity1_begin', 'entity1_end', 'chunk1',
                     'entity2', 'entity2_begin', 'entity2_end', 'chunk2', 'confidence']
    rows = [
        (annotation.result, *(annotation.metadata[key] for key in metadata_keys))
        for annotation in results[0][col]
    ]
    return pd.DataFrame(rows, columns=['relation'] + metadata_keys)

Here, we import the exported TensorFlow model into Spark NLP and save it in Spark NLP format.

In [None]:
# Import the exported TensorFlow SavedModel into Spark NLP and persist it in Spark NLP format.
# The variable is named `re_model` (not `re`) so the standard-library `re` module is not
# shadowed in the notebook's global namespace for functions or cells that rely on it.
re_model = legal.RelationExtractionDLModel().loadSavedModel('/content/models/basebert_re', spark)
re_model.write().overwrite().save('legre_contract_doc_parties')

In [None]:

# Sample contract preamble used to try out the trained pipeline
text='''
This INTELLECTUAL PROPERTY AGREEMENT (this "Agreement"), dated as of December 31, 2018 (the "Effective Date") is entered into by and between Armstrong Flooring, Inc., a Delaware corporation ("Seller") and AFI Licensing LLC, a Delaware limited liability company ("Licensing" and together with Seller, "Arizona") and AHF Holding, Inc. (formerly known as Tarzan HoldCo, Inc.), a Delaware corporation ("Buyer") and Armstrong Hardwood Flooring Company, a Tennessee corporation (the "Company" and together with Buyer the "Buyer Entities") (each of Arizona on the one hand and the Buyer Entities on the other hand, a "Party" and collectively, the "Parties").
'''

Now before getting relations, we have to extract entities from the given text. For this, we will use `legner_contract_doc_parties_lg` NER model.

In [None]:
# Spark NLP inference pipeline: document -> tokens -> RoBERTa embeddings -> NER -> NER chunks -> relations.
document_assembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

# Named `spark_tokenizer` so the BERT `tokenizer` used by `make_features` is not overwritten.
spark_tokenizer = nlp.Tokenizer()\
    .setInputCols("document")\
    .setOutputCol("token")

embeddings = nlp.RoBertaEmbeddings.pretrained("roberta_embeddings_legal_roberta_base", "en") \
    .setInputCols("document", "token") \
    .setOutputCol("embeddings")\
    .setMaxSentenceLength(512)

# No stage produces a "sentence" column (there is no SentenceDetector), so the NER model
# reads "document", like the embeddings and the NER converter do.
ner_model = legal.NerModel.pretrained('legner_contract_doc_parties_lg', 'en', 'legal/models')\
    .setInputCols(["document", "token", "embeddings"])\
    .setOutputCol("ner")

ner_converter = nlp.NerConverter()\
    .setInputCols(["document","token","ner"])\
    .setOutputCol("ner_chunk")

# Load the relation-extraction model saved in Spark NLP format above.
reDL = legal.RelationExtractionDLModel().load('legre_contract_doc_parties')\
    .setPredictionThreshold(0.5)\
    .setInputCols(["ner_chunk", "document"])\
    .setOutputCol("relations")

nlpPipeline = nlp.Pipeline(stages=[
    document_assembler,
    spark_tokenizer,
    embeddings,
    ner_model,
    ner_converter,
    reDL
    ])

# Named `text_df` so the pandas training DataFrame `data` is not overwritten.
text_df = spark.createDataFrame([[text]]).toDF("text")

model = nlpPipeline.fit(text_df)

roberta_embeddings_legal_roberta_base download started this may take some time.
Approximate size to download 447.2 MB
[OK!]
legner_contract_doc_parties_lg download started this may take some time.
[OK!]


In [None]:
# Wrap the fitted pipeline in a LightPipeline and annotate the sample text directly
light_model = nlp.LightPipeline(model)


results = light_model.fullAnnotate(text)

In [None]:
# Tabulate the extracted relations, dropping the "other" class
# (presumably the no-relation label; see the rel counts above).
rel_df = get_relations_df(results).query("relation != 'other'")
rel_df

Unnamed: 0,relation,entity1,entity1_begin,entity1_end,chunk1,entity2,entity2_begin,entity2_end,chunk2,confidence
0,dated_as,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,EFFDATE,70,86,"December 31, 2018",0.9998894
1,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,142,164,"Armstrong Flooring, Inc",0.99805707
2,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,193,198,Seller,0.54615384
3,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,206,222,AFI Licensing LLC,0.9963246
4,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,264,272,Licensing,0.5839682
5,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,293,298,Seller,0.9191695
6,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,316,331,"AHF Holding, Inc",0.99637455
8,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,PARTY,412,446,Armstrong Hardwood Flooring Company,0.9782145
9,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,479,485,Company,0.93827885
10,signed_by,DOC,6,36,INTELLECTUAL PROPERTY AGREEMENT,ALIAS,517,530,Buyer Entities,0.93827885
