# 2.8 Relation Extraction

This notebook processes the English-translated documents from a batch and outputs a dataframe of relation predictions that can be used to construct a knowledge graph.
Note the addition of the session-scoped Java package needed to get Spark NLP to work: spark-nlp-assembly-4.0.2.jar. Spark NLP is used to perform sentence splitting and may eventually be used to replace the API-based coreference resolution.

In [ ]:
%%configure -f
{
"conf": {
     "spark.jars.packages": "com.microsoft.azure:synapseml_2.12:0.10.0-19-c3a445c5-SNAPSHOT",
      "spark.jars.repositories": "https://mmlspark.azureedge.net/maven",
      "spark.jars": "abfss://synapse@<<STORAGE_ACCOUNT_NAME>>.dfs.<<AZURE_STORAGE_DOMAIN>>/jars/spark-nlp.jar",
      "spark.jars.excludes": "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalactic:scalactic_2.12,org.scalatest:scalatest_2.12,com.fasterxml.jackson.core:jackson-databind",
      "spark.yarn.user.classpath.first": "true"
   }
}

In [ ]:
# Pipeline parameters. The empty defaults are presumably overridden by the calling pipeline — TODO confirm.
manifest_file_path = ''  # path of the batch manifest inside `manifest_container`; its folder becomes the batch root
manifest_container = ''  # container holding the manifest (used to build `batch_path`)
summarized_text_dailymail_tbl_name = ''  # name of the summarized-text table under `minted_tables_output_path`
batch_num = ''  # batch identifier: used in output table/blob names and as `batch_id` in `batch_status`
file_system = ''  # not referenced elsewhere in this notebook
batch_root = ''  # NOTE: overwritten from `manifest_file_path` in the config cell
blob_account_name = ''  # storage account name used in abfss:// URLs and the blob connection string
azure_storage_domain = ''  # storage DNS suffix used in abfss:// URLs and as the blob EndpointSuffix
minted_tables_output_path = ''  # path prefix under which parquet tables are read/written
input_container = ''  # container of the input documents (used to strip the file_path prefix)
output_container= ''  # container that receives the generated graph HTML

display_dataframes = False # set to True for debugging (Note, this triggers execution multiple times, only use on small runs)

In [ ]:
import pyodbc
from azure.storage.blob import generate_blob_sas, BlobSasPermissions, generate_container_sas, ContainerSasPermissions, BlobClient
from datetime import datetime, timedelta
# Static SQL connection settings.
dedicated_database = "dedicated"
database = 'minted'
driver = '{ODBC Driver 17 for SQL Server}'

# Credentials, endpoints and the storage key, all read through the "keyvault" linked service.
_get_secret = mssparkutils.credentials.getSecretWithLS
sql_user_name = _get_secret("keyvault", "SynapseSQLUserName")
sql_user_pwd = _get_secret("keyvault", "SynapseSQLPassword")
serverless_sql_endpoint = _get_secret("keyvault", "SynapseServerlessSQLEndpoint")
dedicated_sql_endpoint = _get_secret("keyvault", "SynapseDedicatedSQLEndpoint")
storage_account_key = _get_secret('keyvault', 'StorageAccountKey')

In [ ]:
import os
import ast
import json
from types import SimpleNamespace

# Root URL of the manifest container; the batch lives in the manifest's folder.
# NOTE: this replaces any `batch_root` value supplied via the parameters cell.
batch_path = f'abfss://{manifest_container}@{blob_account_name}.dfs.{azure_storage_domain}'
batch_folder = os.path.dirname(manifest_file_path)
batch_root = batch_path + batch_folder

def copy_global_config(config_path: str, global_config_path: str):
    """
    Copy the global config file from ``global_config_path`` to ``config_path``.

    ``read_batch_config`` calls this only when the batch root has no ``config.json`` yet.
    If the copy fails (for example because nothing exists at ``global_config_path``),
    the error is logged and re-raised: that indicates a problem in the pipeline set-up.

    :param config_path: destination path of the batch config.
    :param global_config_path: source path of the global config.
    :raises Py4JJavaError: if the underlying file-system copy fails.
    """
    # Local imports: this cell runs before the logging cell defines the module-level
    # `logger`, and Py4JJavaError is not imported earlier in the notebook. Using
    # getLogger(__name__) yields the same logger object that the logging cell configures.
    import logging
    from py4j.protocol import Py4JJavaError

    log = logging.getLogger(__name__)
    log.info("Loading global config")
    try:
        mssparkutils.fs.cp(global_config_path, config_path)
    except Py4JJavaError as e:
        log.exception(e)
        raise  # bare raise keeps the original traceback

def read_batch_config(batch_root: str, global_config_path: str):
    """
    Load ``{batch_root}/config.json`` and return it as nested SimpleNamespace objects.

    The file is read on the driver through the Hadoop FileSystem Java API (via py4j)
    rather than through Spark. It is a single small JSON document, so there is no
    benefit in having executors read individual lines and join them back together.

    If ``{batch_root}/config.json`` does not exist, the global config at
    ``global_config_path`` is first copied into place with ``copy_global_config``.

    Side effect: sets the session-wide Hadoop ``fs.defaultFS`` to the module-level
    ``batch_path`` (the root of the manifest container).

    :param batch_root: folder of the batch that should contain ``config.json``.
    :param global_config_path: path of the fallback global config file.
    :return: the parsed config. Every JSON object becomes a SimpleNamespace, so values
        are read as attributes (e.g. ``config.relex.key``).
    """
    jvm = sc._jvm
    hadoop_conf = sc._jsc.hadoopConfiguration()

    # Point the default file system at the manifest container (see side effect above).
    hadoop_conf.set("fs.defaultFS", batch_path)

    fs = jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
    config_location = f'{batch_root}/config.json'
    config_path = jvm.org.apache.hadoop.fs.Path(config_location)

    # If we don't have a batch config, copy the global one.
    if not fs.exists(config_path):
        copy_global_config(config_location, global_config_path)

    # Open the file directly rather than through Spark, and always release the stream.
    input_stream = fs.open(config_path)  # FSDataInputStream
    try:
        reader = jvm.java.io.BufferedReader(
            jvm.java.io.InputStreamReader(input_stream, jvm.java.nio.charset.StandardCharsets.UTF_8)
        )
        # Collectors.joining returns a single Java String, which py4j hands back as a Python str.
        config_string = reader.lines().collect(jvm.java.util.stream.Collectors.joining("\n"))
    finally:
        input_stream.close()

    return json.loads(config_string, object_hook=lambda dictionary: SimpleNamespace(**dictionary))

# Location of the global fallback config.
# NOTE: this path should be in sync with Terraform configuration which uploads this file
_global_config_location = f'abfss://configuration@{blob_account_name}.dfs.{azure_storage_domain}/config.global.json'

# Batch config (copied from the global one when the batch has none yet).
config = read_batch_config(batch_root, _global_config_location)

In [ ]:
# Load Cog Services text analytics keys, set defaults
# (The App Insights connection string is fetched once, in the logging cell that uses it.)
text_analytics_keys = mssparkutils.credentials.getSecretWithLS("keyvault", 'TextAnalyticsKeys').split(',')


def _parse_bool_setting(value):
    """
    Interpret a boolean config setting.

    Accepts a real JSON boolean, or a string "True"/"False" in any letter case. Any
    other string is evaluated with ``ast.literal_eval`` and its truthiness is returned,
    so values such as "0"/"1" also work.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip()
    if text.lower() in ('true', 'false'):
        return text.lower() == 'true'
    return bool(ast.literal_eval(text))


# Get config values
web_app_uri = config.relex.relex_webapp_uri
relex_endpoint_api_key = config.relex.key

# Coreference resolution and relation extraction model endpoints
relex_endpoint_healthcheck_url = f'{web_app_uri}/api/healthcheck'
coref_endpoint_url = f'{web_app_uri}/api/coref_batch_pred'
relex_endpoint_url = f'{web_app_uri}/api/relex_batch_pred'

# Pipeline parameters (chunking large documents/summaries and batching the chunks for fewer API calls)
max_coref_doc_char_length = 3500
min_coref_doc_char_length = 2000
coref_batch_size = 30 # number of document chunks to perform coreference resolution per API call

max_entity_link_doc_char_length = 3000
min_entity_link_doc_char_length = 2000
entity_link_min_confidence_score = float(config.relex.el_confidence_score) # (0)

relex_batch_size = 1000 # number of relation predictions per call to API

# Graph Visualization Filters
filter_graph_viz = _parse_bool_setting(config.relex.filter_graph_viz) # JSON bool or "True"/"False" string; defaults to False
relex_min_probability = float(config.relex.relex_min_probability) # filters predictions displayed in graph visualizations, (0.7)
exclude_relations = config.relex.exclude_relations # ["main subject", "followed by", "follows", "said to be the same as", "instance of"]
clique_min_nodes = int(config.relex.clique_min_nodes) # plots only connected cliques with at least min # nodes per clique, (3)

# Cognitive Services Entity Linker API Config
cog_svc_concurrency = 1
cog_svc_batch_size = 15 # The /analyze endpoint that TextAnalyze uses is documented to allow batches of up to 25 documents
cog_svc_intial_polling_delay = 15000 # Time (in ms) to wait before first poll for results (name kept as-is: referenced by the TextAnalyze cell)
cog_svc_polling_delay = 10000 # Time (in ms) to wait between repeated polling for results
cog_svc_maximum_retry_count = 100 # Maximum number of polling retries. 100 => 100 * 10s + 15s = 1015s ~= 17 mins allowed for a job to complete

# Column names
file_path_col = 'file_path'
text_col = 'summarized_text_dailymail'

In [ ]:
# Initiate Logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

config_integration.trace_integrations(['logging'])

instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

logger = logging.getLogger(__name__)
# logging.getLogger returns the same logger object each time this cell is re-run. Attach the
# App Insights handler only once; otherwise every record would be exported once per re-run.
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Tracer used by the `with tracer.span(...)` blocks; every span is sampled and exported to App Insights.
tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

# Spool parameters (attached as custom dimensions to the initialisation log record)
run_time_parameters = {'custom_dimensions': {
    'summarized_text_dailymail_tbl_name': summarized_text_dailymail_tbl_name,
    'batch_root': batch_root,
    'batch_num': batch_num,
    'cog_svc_concurrency': cog_svc_concurrency,
    'cog_svc_batch_size': cog_svc_batch_size,
    'notebook_name': mssparkutils.runtime.context['notebookname']
} }

logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)

In [ ]:
from collections import defaultdict
import itertools
import re
import random
import requests
import fsspec
import numpy as np
import pandas as pd
import pyvis
from pyvis.network import Network
import networkx as nx
from networkx.algorithms.community import k_clique_communities

import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline

import synapse.ml
from synapse.ml.io import *
from sparknlp.pretrained import PretrainedPipeline
from synapse.ml.cognitive import *
from synapse.ml.featurize.text import PageSplitter
from synapse.ml.stages import FixedMiniBatchTransformer
from synapse.ml.stages import FlattenBatch

from pyspark import SparkContext
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import udf, col, length, size, posexplode, explode, trim
from pyspark.sql.types import ArrayType, MapType, StructType, StructField, StringType, IntegerType, FloatType

# Initialise session and config.
# `spark` must already exist here (provided by the notebook runtime); take its SparkContext first.
sc = spark.sparkContext
spark = (
    SparkSession.builder
    .appName(f"RelationExtraction {mssparkutils.runtime.context}")
    .getOrCreate()
)

### Load documents table from 2_2_Text_Summarization_dailymail notebook output

In [ ]:
#with tracer.span(name='Load documents contents table'):
#
#    prefix_length = len(f'abfss://{input_container}@{blob_account_name}.dfs.{azure_storage_domain}/')
#
#    df = spark.read.parquet(f'{minted_tables_output_path}{summarized_text_dailymail_tbl_name}') \
#                .withColumn('file_name', F.expr(f"substring(file_path, {prefix_length + 1}, length(file_path))"))
#    
#    if display_dataframes:
#        df.show(n=5)

### Resolve document coreferences
Find all expressions that refer to the same entity and replace with entity name.
Documents are split into smaller, more manageable chunks to avoid model out-of-memory (OOM) errors, and the chunks are then batched so that coreference resolution needs fewer API calls.

In [ ]:
#with tracer.span(name='Split large documents into chunks for coreference resolution'):
#    page_splitter = (PageSplitter()
#        .setInputCol(text_col)
#        .setMinimumPageLength(min_coref_doc_char_length)
#        .setMaximumPageLength(max_coref_doc_char_length)
#        .setOutputCol("extracted_doc_split"))
#
#    df_split = page_splitter.transform(df)
#
#    df_split = df_split.select("file_name", posexplode("extracted_doc_split") \
#                        .alias("chunk_number", "extracted_doc_split")) \
#                        .withColumn("split_doc_char_length", length("extracted_doc_split"))
#
#with tracer.span(name='Group rows into batches for coreference resolution'):
#      fmbt = (FixedMiniBatchTransformer()
#            .setBatchSize(coref_batch_size))
#
#      df_batched = fmbt.transform(df_split)

In [ ]:
#with tracer.span(name='Perform coference resolution on document batches'):
#    @udf(returnType=ArrayType(StringType()))
#    def get_coref(docs_list):
#        try:
#            print(coref_endpoint_url)
#            headers = {'Content-Type': 'application/json',
#            'Ocp-Apim-Subscription-Key': relex_endpoint_api_key}
#            response = requests.post(url=coref_endpoint_url, headers=headers, json={"doc_corefs_request":docs_list})
#            return response.json()
#        except Exception as e:
#            logger.error(f"Exception during call to batch coref resolution endpoint: {e}")
#            return docs_list # fall back to non-coreffed text if exception
#
#    health_check = requests.get(relex_endpoint_healthcheck_url)
#
#    if health_check.json()=="Ready":
#        coreffed_batch_df = df_batched.withColumn("coref_output", get_coref(col("extracted_doc_split")))
#    else:
#        logger.error("Relex endpoint inactive")
#
#with tracer.span(name='Unroll batched coreffed documents and re-combine chunks'):
#    flattener = FlattenBatch()
#    coreffed_df = flattener.transform(coreffed_batch_df)
#    coreffed_df = coreffed_df.select("file_name", "chunk_number", "coref_output") \
#                                .groupby("file_name") \
#                                .agg(F.sort_array(F.collect_list(F.struct("chunk_number", "coref_output"))) \
#                                .alias("sorted_coref_text_col")) \
#                                .withColumn("coreffed_text", F.concat_ws("", col(f"sorted_coref_text_col.coref_output")))\
#                                .drop("sorted_coref_text_col")

### Extract named entities and link them to a knowledge base

First, split and batch the coreference-resolved documents before calling the Cognitive Services named entity recognition (NER) and entity linking (EL) endpoint.

In [ ]:
#with tracer.span(name='Split coreffed text into chunks'):
#    page_splitter = (PageSplitter()
#        .setInputCol("coreffed_text")
#        .setMaximumPageLength(max_entity_link_doc_char_length)
#        .setMinimumPageLength(min_entity_link_doc_char_length)
#        .setOutputCol("coreffed_text_split"))
#
#    df_split = page_splitter.transform(coreffed_df)
#
#    df_split = df_split.select("file_name", posexplode("coreffed_text_split") \
#                        .alias("chunk_number", "coreffed_text_split")) \
#                        .withColumn("split_doc_char_length", length("coreffed_text_split"))
#
#with tracer.span(name='Group coreffed text chunks into batches'):
#      fmbt = (FixedMiniBatchTransformer()
#            .setBatchSize(cog_svc_batch_size))
#
#      df_batched = fmbt.transform(df_split)
#
#with tracer.span(name='Distribute cognitive service keys across rows'):
#    @udf(returnType=StringType())
#    def rand_key() :
#        index = random.randint(0, len(text_analytics_keys)-1)
#        return text_analytics_keys[index]
#
#    df_batched = df_batched.withColumn("text_analytics_key", rand_key())

In [ ]:
#with tracer.span(name='Call cognitive services NER and entity linker'):
#    text_analyze = (TextAnalyze()
#        .setTextCol("coreffed_text_split")
#        .setLocation(config.location)
#        .setLanguage(config.prep.target_language)
#        .setSubscriptionKeyCol("text_analytics_key")
#        .setOutputCol("batch_entity_link_results")
#        .setErrorCol("text_analysis_error")
#        .setConcurrency(cog_svc_concurrency)
#        .setInitialPollingDelay(cog_svc_intial_polling_delay)
#        .setPollingDelay(cog_svc_polling_delay)
#        .setMaxPollingRetries(cog_svc_maximum_retry_count)
#        .setSuppressMaxRetriesExceededException(True)
#        .setEntityRecognitionTasks([{"parameters": {"model-version": "latest"}}])
#        .setEntityLinkingTasks([{"parameters": { "model-version": "latest"}}])
#        <<SYNAPSE_ML_TEXT_ANALYZE_ENDPOINT_CMD>>
#        )
#
#    df_results_batched = text_analyze.transform(df_batched)
#
#with tracer.span(name='Unroll batched entity linking API responses'):
#    flattener = FlattenBatch()
#    entity_results_df = flattener.transform(df_results_batched)
#
#    # split out text analysis results
#    el_cols_to_drop = ("entityRecognitionPii","keyPhraseExtraction","sentimentAnalysis")
#    entity_results_df = entity_results_df.select("file_name","chunk_number","coreffed_text_split",
#                                                "batch_entity_link_results.*","text_analysis_error")\
#                            .drop(*el_cols_to_drop)
#
#with tracer.span(name='Collate split rows back to single row per document'):
#    error_response_schema = StructType(
#        [StructField("error", StructType(
#            [StructField("code", StringType()), StructField("message", StringType())]
#        ))]
#    )
#
#    entity_results_df = entity_results_df.select(
#                    "file_name",
#                    "chunk_number",
#                    "coreffed_text_split",
#                    # we only have a single task for each task type, so unpack it
#                    col("entityRecognition")[0]["result"]["entities"].alias("named_entities"),
#                    col("entityLinking")[0]["result"]["entities"].alias("linked_entities"),
#                    # set up as an array for the grouping step
#                    F.from_json(entity_results_df["text_analysis_error"]["response"], error_response_schema)["error"]["message"].alias("text_analysis_error"),
#                )\
#                .groupby("file_name")\
#                .agg(
#                    F.flatten(F.collect_list("named_entities")).alias("named_entities"),
#                    F.flatten(F.collect_list("linked_entities")).alias("linked_entities"),
#                    F.sort_array(F.collect_list(F.struct("chunk_number", "coreffed_text_split"))).alias("sorted_text"),
#                    F.max(col("text_analysis_error")).alias("text_analysis_error"))\
#                .withColumn(
#                    "coreffed_text",
#                    F.concat_ws("", col("sorted_text.coreffed_text_split"))
#                )\
#                .drop("sorted_text")
#
#    if display_dataframes:
#        entity_results_df.show(n=5)

Replace entity mentions in the coreference-resolved text with the linked-entity names returned by Cognitive Services.

In [ ]:
#with tracer.span(name='Replace entity mentions with linked entities'):
#
#    def link_entities(coreffed_text, linked_entities, min_confidence_score=entity_link_min_confidence_score):
#        """
#        Finds and links entities within a document to a knowledge base (KB) using MS Cog Service Entity Linking API
#        :param min_confidence_score: Limit's the API's matching entities base
#        :return: Document text replaced with KB entities (str), lookup of nicknames mapped to KB entities (dict)
#        """
#        coreffed_text = coreffed_text.strip() # we trimmed the doc earlier in spark but just in case
#
#        # store a mapping of the entity mention to kb name
#        entity_mentions = defaultdict(None) # for replacing mentions with wiki names
#
#        for entity in linked_entities:
#            for match in entity["matches"]:
#                if match["confidenceScore"] >= min_confidence_score:
#                    entity_mentions[match["text"]] = entity["name"]
#                    
#        # store kb entity names
#        _linked_entity_names = list(set(entity_mentions.values()))
#
#        # need underscore instead of ws to do proper multi-name entity replacement
#        string_to_modify = coreffed_text.replace(" ", "_")
#        _ent_mentions = defaultdict(None)
#
#        for k, v in entity_mentions.items():
#            _ent_mentions[k.replace(" ", "_")] = v.replace(" ", "_")
#            
#        # substitute all mentions with the wiki name
#        escaped_key_names = [re.escape(key) for key in _ent_mentions.keys()] # handles special chars in wiki names
#        uncompiled_pattern = "|".join(escaped_key_names)
#        pattern = re.compile(uncompiled_pattern)
#
#        output_text = pattern.sub(lambda x: _ent_mentions.get(x.group(0),x.group(0)), string_to_modify)
#        output_text = output_text.replace("_"," ")
#
#        return output_text, _linked_entity_names
#
#    el_schema = StructType([StructField("entity_linked_text", StringType()), StructField("linked_entity_names", ArrayType(StringType()))])
#    el_udf = udf(link_entities, el_schema)
#
#    entity_linked_text_df = entity_results_df.withColumn('el_results', el_udf(col('coreffed_text'),col('linked_entities'))).select(col('file_name'),col('coreffed_text'), col("el_results.*"))
#
#    if display_dataframes:
#        entity_linked_text_df.show(n=5)

### Prepare pre-processed documents for sentence-level relation extraction
Split documents up into sentences and preserve only sentences containing pairs of identified entities from previous steps.
Note: we cap extremely long sentence lengths at 512 tokens to avoid spark OOM issues. Abnormally long sentences can occur because texts may be missing punctuation.

In [ ]:
#with tracer.span(name='Split documents into sentences for relation extraction'):
#    max_sent_token_length = 512
#
#    document_assembler = DocumentAssembler().setInputCol("entity_linked_text").setOutputCol("document")
#    sentence_detector = SentenceDetectorDLModel.pretrained("sentence_detector_dl", "en")\
#                                                .setSplitLength(max_sent_token_length)\
#                                                .setInputCols(["document"])\
#                                                .setOutputCol("sentences")
#
#    sent_pipeline = Pipeline(stages=[document_assembler, sentence_detector])
#
#    sent_model = sent_pipeline.fit(entity_linked_text_df)
#    sent_df = sent_model.transform(entity_linked_text_df)
#    sent_df = sent_df.withColumn("sent_list",col("sentences.result"))\
#                        .drop("coreffed_text","entity_linked_text","document","sentences")\
#                        .select("file_name", "linked_entity_names", explode(col("sent_list")).alias("sentence"))

In [ ]:
#with tracer.span(name='Generate all pair spans per sentence'):
#
#    @udf(returnType=ArrayType(MapType(StringType(), StringType())))
#    def get_sent_pairs(sentence, linked_entity_names):
#        """Returns all entity permutation information per sentence.
#        The entities are the set of knowledge base entities returned by link_entities() above"""
#        
#        sent_pair_spans = [] # store the entity pair info for all sentences in the doc
#
#        # keep track of linked entity names and their spans in the sentence
#        entity_spans = defaultdict(None)
#        
#        # if entity name is in the sentence, store it along with its start and end spans
#        for linked_name in linked_entity_names:
#            for match in re.finditer(re.escape(linked_name), sentence):
#                entity_spans[(match.start(), match.end())] = linked_name
#        
#        # only extract entity spans if more than one unique entity detected in sentence (else can't predict a relation)
#        if len(set(entity_spans.values())) > 1:
#            
#            # generate all pair permutations excluding same-name pairs
#            perms_list = list(itertools.permutations(entity_spans.items(), 2))
#            pair_span_list = list(perm for perm in perms_list if perm[0][1] != perm[1][1])
#
#            for pair in pair_span_list:
#                head_ent, head_span = pair[0][1], pair[0][0]
#                tail_ent, tail_span = pair[1][1], pair[1][0]
#
#                # convert char spans to string to avoid mixed value types and be able to use MapType() schema
#                # also, the list arraytype(structtype()) representation does not guarantee order of the inputs
#                head_span_start, head_span_end = str(head_span[0]), str(head_span[1])
#                tail_span_start, tail_span_end = str(tail_span[0]), str(tail_span[1])
#
#                sent_pair_dict = {
#                    "sentence":sentence,
#                    "head_ent":head_ent,"head_span_start":head_span_start,"head_span_end":head_span_end,
#                    "tail_ent":tail_ent,"tail_span_start":tail_span_start,"tail_span_end":tail_span_end
#                    }
#
#                sent_pair_spans.append(sent_pair_dict)
#        
#        return sent_pair_spans
# 
#    sent_df = sent_df.withColumn("pair_spans", get_sent_pairs(col("sentence"),col("linked_entity_names")))\
#                    .filter(size("pair_spans") > 0)\
#                    .select("file_name",explode("pair_spans").alias("pair_spans"))
#
#    if display_dataframes:
#        sent_df.show(n=5)

In [ ]:
#with tracer.span(name='Batch pair spans to reduce API calls to relex endpoint'):
#    fmbt = (FixedMiniBatchTransformer()
#        .setBatchSize(relex_batch_size))
#
#    sent_df = fmbt.transform(sent_df)
#
#    if display_dataframes:
#        sent_df.show(n=5)

### Predict relation for each pair span in the doc.
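The relation extraction endpoint accepts a list of entity pair spans per request. The previous step groups up to `relex_batch_size` pair spans into each batch, so one request may contain pairs from several documents rather than a single document.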

In [ ]:
#with tracer.span(name='Generate relation predictions per entity pair and unbatch results'):
#    def get_relation(spans_list):
#        try:
#            headers = {'Content-Type': 'application/json', 'Ocp-Apim-Subscription-Key': relex_endpoint_api_key}
#            response = requests.post(url=relex_endpoint_url, headers=headers, json={"doc_spans_request":spans_list})
#            return response.json()
#        except Exception as e:
#            logger.error(f"Exception during call to batch relation extraction endpoint: {e}")
#            return None
#
#    relex_output_schema = ArrayType(StructType([
#        StructField("sentence", StringType(), True),
#        StructField("head_ent", StringType(), True),
#        StructField("tail_ent", StringType(), True),
#        StructField("relation", StringType(), True),
#        StructField("probability", FloatType(), True)
#    ]))
#
#    gr_udf = udf(get_relation, relex_output_schema)
#
#    health_check = requests.get(relex_endpoint_healthcheck_url)
#
#    if health_check.json()=="Ready":
#        sent_df = sent_df.withColumn("relex_output", gr_udf(col("pair_spans")))
#    else:
#        logger.error("Relex endpoint inactive")
#
#    # unbatch the response
#    flattener = FlattenBatch()
#    sent_df = flattener.transform(sent_df)
#    relex_results_df = sent_df.select("file_name","relex_output.*")
#
#    if display_dataframes:
#        relex_results_df.show(n=10)

In [ ]:
#with tracer.span(name='Trigger Relex Pipeline Execution'):
#    # sort by descending probability and preserve only highest probability relation among duplicate relations
#    relex_results_df = relex_results_df.sort(col("probability").desc()) \
#                                        .dropDuplicates(["head_ent","tail_ent","relation"]) \
#                                        .sort(col("probability").desc())
#
#    # save relex_results_df to parquet and create external table
#    edgelist_tbl_name = f'{batch_num}_relex_edgelist'
#    edgelist_df = relex_results_df.toPandas()
#    relex_results_df.write.mode("overwrite").parquet(f'{minted_tables_output_path}{edgelist_tbl_name}')
#
#    relex_sql_command = f"""IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{edgelist_tbl_name}') 
#    CREATE EXTERNAL TABLE [{edgelist_tbl_name}] 
#    (
#        [file_name] nvarchar(4000), 
#        [sentence] nvarchar(4000), 
#        [head_ent] nvarchar(4000),
#        [tail_ent] nvarchar(4000),
#        [relation] nvarchar(4000),
#        [probability] float
#    ) WITH (
#            LOCATION = 'minted_tables/{edgelist_tbl_name}/**', 
#            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
#            FILE_FORMAT = [SynapseParquetFormat]
#            )"""
#
#    with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
#        with conn.cursor() as cursor:
#            cursor.execute(relex_sql_command)

### Visualize the relations in a knowledge graph.

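If the user sets `filter_graph_viz` to True in the `relex` section of the config file, the graph is filtered before plotting. Only predictions with probability above `relex_min_probability` and a relation not in `exclude_relations` are kept. The graph is then restricted to k-clique communities with at least `clique_min_nodes` nodes, and k is lowered until at least one community is found. Otherwise, all the unfiltered predictions are graphed.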

In [ ]:
#with tracer.span(name='Filter graph for plotting'):
#
#    # Curate edgelist for visualization
#    edgelist_df.dropna(inplace=True) # drop nulls
#
#    # grab head and tail entity document information for graph tooltips
#    head_docs_df = edgelist_df.groupby("head_ent").agg(head_found_in=('file_name','unique')).reset_index()
#    tail_docs_df = edgelist_df.groupby("tail_ent").agg(tail_found_in=('file_name','unique')).reset_index()
#
#    # preserve only unique documents of a given node (entity) whether it's a head or tail entity
#    lookup_df = pd.merge(head_docs_df, tail_docs_df, left_on="head_ent", right_on="tail_ent", how="outer")
#    lookup_df["entity"] = lookup_df["head_ent"].combine_first(lookup_df["tail_ent"])
#
#    # ensure na becomes empty list and arrays get converted to list before merging
#    lookup_df['head_found_in'] = [[] if x is np.NaN else x.tolist() for x in lookup_df['head_found_in']]
#    lookup_df['tail_found_in'] = [[] if x is np.NaN else x.tolist() for x in lookup_df['tail_found_in']]
#    lookup_df["found_in"] = (lookup_df["head_found_in"] + lookup_df["tail_found_in"]).map(set).map(list)
#    lookup_df = lookup_df[["entity", "found_in"]]
#
#    if filter_graph_viz:
#        edgelist_df = edgelist_df[edgelist_df["probability"] > relex_min_probability]
#        edgelist_df = edgelist_df[~edgelist_df["relation"].str.contains('|'.join(exclude_relations))]
#        
#        # grab all connected cliques of minimum node size
#        G = nx.from_pandas_edgelist(edgelist_df, "head_ent","tail_ent",["relation","probability"])
#
#        k = list(k_clique_communities(G, clique_min_nodes))
#
#        # ensure we have something to plot by taking the largest non-zero k for k_clique_communities from the unfiltered graph
#        while len(k) == 0:
#            if clique_min_nodes==1 and len(k)==0:
#                print("No relations to graph!")
#                break
#            
#            clique_min_nodes -= 1
#            k = list(k_clique_communities(G, clique_min_nodes))
#
#        k_clique_nodes = [list(x) for x in k]
#        k_clique_nodes = set(n for clq in k_clique_nodes for n in clq)
#        K_clique_graph = G.subgraph(k_clique_nodes)
#        clique_df = nx.to_pandas_edgelist(K_clique_graph) # grab edgelist instead of actual clique graph for more viz control
#
#        if display_dataframes:
#            print(clique_df.head())

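Build an interactive pyvis graph from the predictions and write it to the output container as an HTML file. When filtering is enabled, the graph uses the densely-connected clique subgraph; otherwise it uses all predictions.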

In [ ]:
#with tracer.span(name='Populate a pyvis graph object from filtered or unfiltered predictions'):
#
#    # generate the pyvis graph as .html
#    nt = Network(directed=True, height='1080px', width='100%', bgcolor='#222222', font_color='white')
#    nt.force_atlas_2based(spring_strength=0.01, overlap=1)
#    nt.toggle_physics(True) # False for static graph
#
#    if filter_graph_viz:
#        edge_data = zip(clique_df['source'], clique_df['target'], clique_df['relation'])
#    else:
#        edge_data = zip(edgelist_df['head_ent'], edgelist_df['tail_ent'], edgelist_df['relation'])
#
#    for e in edge_data:
#        src, dst, edge_label = e[0], e[1], e[2]
#
#        nt.add_node(src, src, title=src)
#        nt.add_node(dst, dst, title=dst)
#        nt.add_edge(src, dst, title=edge_label, arrowStrikethrough=True)
#
#    for node in nt.nodes:
#        outbound_rels_df = edgelist_df.loc[edgelist_df["tail_ent"] == node["id"], ["relation","head_ent"]]
#        outbound_rels = (outbound_rels_df["relation"] + " -> " + outbound_rels_df["head_ent"]).tolist()
#        node_docs = lookup_df.loc[lookup_df["entity"] == node["id"], "found_in"].item()
#        
#        node['title'] += '\n\nSuggested relations:\n'+'\n'.join(outbound_rels)+'\n\nFound in:\n'+'\n'.join(node_docs)
#        node['value'] = len(node_docs) # circle diameter based on document frequency
#
#    out_cont_sas_tkn = generate_container_sas(account_name=blob_account_name, 
#                            container_name=output_container,
#                            account_key=storage_account_key,
#                            permission=ContainerSasPermissions(read=True, list=True, write=True, add=True, create=True, update=True),
#                            expiry=datetime.utcnow() + timedelta(hours=1))
#    connection_string = f'DefaultEndpointsProtocol=https;AccountName={blob_account_name};AccountKey={storage_account_key};EndpointSuffix={azure_storage_domain}'
#    blob = BlobClient.from_connection_string(conn_str=connection_string, container_name=f'{output_container}', blob_name=f'{batch_num}/{batch_num}_relex_graph.html', credential=out_cont_sas_tkn)
#
#    file_name = f'{batch_num}_relex_graph.html'
#    #Write XML to file 
#    with open(file_name, mode='w') as f:
#        f.write(nt.generate_html())
#    
#    # Write xml to output_container
#    with open(file_name, "rb") as data:
#        blob.upload_blob(data)

In [ ]:
from time import sleep
import datetime
# Update Status Table
def build_odbc_connection_string(driver, server, database, user, password):
    """
    Build a pyodbc connection string for a SQL endpoint on port 1433.

    ``driver`` is inserted as given; it is expected to be already brace-wrapped
    (e.g. ``{ODBC Driver 17 for SQL Server}``). The user name and password are
    wrapped in braces, with any ``}`` doubled. This means values containing ``;``
    or other special characters cannot break the connection string.
    """
    def _braced(value):
        return '{' + str(value).replace('}', '}}') + '}'

    return (
        f'DRIVER={driver};SERVER=tcp:{server};PORT=1433;DATABASE={database};'
        f'UID={_braced(user)};PWD={_braced(password)}'
    )


def get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """
    Return ``(num_stages_complete, description)`` of the latest status row for a batch.

    "Latest" means the ``dbo.batch_status`` row for ``batch_num`` with the highest
    ``num_stages_complete``.

    :raises LookupError: if ``batch_status`` has no row for ``batch_num``.
    """
    query = """
        SELECT TOP (1) 
        [num_stages_complete], [description]
        FROM [dbo].[batch_status] 
        WHERE [batch_id] = ?
        ORDER BY [num_stages_complete] DESC;
    """
    conn_str = build_odbc_connection_string(driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
    with pyodbc.connect(conn_str, autocommit=True) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, batch_num)
            row = cursor.fetchone()
    if row is None:
        raise LookupError(f"No batch_status row found for batch_id {batch_num!r}")
    num_stages_complete, description = row
    return num_stages_complete, description


def update_status_table(status_text, minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd):
    """
    Increment the batch's completed-stage counter in ``batch_status`` and record a status.

    Reads the current ``num_stages_complete`` for ``batch_num`` and adds one. It then
    writes ``status = '[<n>/10] <status_text>'``, the current local time stamp and the
    new stage count.

    Any exception during an attempt is logged and the attempt is retried, up to 10
    attempts with a 3 second pause between them. If every attempt fails, the last
    exception is re-raised.

    ``minted_tables_path`` is accepted for interface compatibility but is not used.
    The database name comes from the module-level ``dedicated_database``.
    """
    # Local import: the notebook binds the global name `datetime` to both the class
    # (first cells) and the module (this cell), so do not depend on which one is current.
    from datetime import datetime as _DateTime

    max_attempts = 10
    sql_command = "UPDATE batch_status SET status = ?, update_time_stamp = ?, num_stages_complete = ? WHERE batch_id = ?"
    conn_str = build_odbc_connection_string(driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)

    last_exc = None
    for attempt in range(1, max_attempts + 1):
        try:
            stages_complete, _description = get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
            stages_complete += 1
            status = f'[{stages_complete}/10] {status_text}'
            time_stamp = _DateTime.now().strftime("%Y-%m-%d %H:%M:%S")

            with pyodbc.connect(conn_str, autocommit=True) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql_command, status, time_stamp, stages_complete, batch_num)
                    cursor.commit()
            return
        except Exception as e:
            last_exc = e
            logger.warning(f'Failed to update status table: {e}, retrying . . .')
            if attempt < max_attempts:  # no point waiting after the final attempt
                sleep(3)

    raise last_exc

# Record that this stage of the batch has finished.
update_status_table(
    'Text Relation Extraction Complete',
    minted_tables_output_path,
    batch_num,
    driver,
    dedicated_sql_endpoint,
    sql_user_name,
    sql_user_pwd,
)