In [24]:
from sparknlp.internal import ExtendedJavaWrapper
from sparknlp.common import ExternalResource, ReadAs
from pyspark.sql import SparkSession, DataFrame
import conll_eval

class CoNLL(ExtendedJavaWrapper):
    def __init__(self,
                 documentCol = 'document',
                 sentenceCol = 'sentence',
                 tokenCol = 'token',
                 posCol = 'pos',
                 conllLabelIndex = 3,
                 conllPosIndex = 1,
                 textCol = 'text',
                 labelCol = 'label',
                 explodeSentences = True,
                 ):
        super(CoNLL, self).__init__("com.johnsnowlabs.nlp.training.CoNLL",
                                    documentCol,
                                    sentenceCol,
                                    tokenCol,
                                    posCol,
                                    conllLabelIndex,
                                    conllPosIndex,
                                    textCol,
                                    labelCol,
                                    explodeSentences)

    def readDataset(self, spark, path, read_as=ReadAs.TEXT):

        # ToDo Replace with std pyspark
        jSession = spark._jsparkSession

        jdf = self._java_obj.readDataset(jSession, path, read_as)
        return DataFrame(jdf, spark._wrapped)
    
    
    def calculate_metrics(self, preds_df: DataFrame):

        df = preds_df.select(F.explode(F.arrays_zip('label.result','ner.result')).alias("cols")) \
        .select(F.expr("cols['0']").alias("ground_truth"), F.expr("cols['1']").alias("prediction")).toPandas()

        # We are going to use sklearn to evalute the results on test dataset
        from sklearn.metrics import classification_report

        classification_report = classification_report(list(df.ground_truth), list(df.prediction))
        
        print ('Chunk level metrics:')
        (prec, rec, f1), metrics_dict = conll_eval.evaluate(list(df.ground_truth), list(df.prediction), verbose=True)

        return {'precision':prec, 'recall':rec, 'F1':f1, 'chunk_metrics':metrics_dict, 'label_classification_report':classification_report}


class POS(ExtendedJavaWrapper):
    def __init__(self):
        super(POS, self).__init__("com.johnsnowlabs.nlp.training.POS")

    def readDataset(self, spark, path, delimiter="|", outputPosCol="tags", outputDocumentCol="document", outputTextCol="text"):

        # ToDo Replace with std pyspark
        jSession = spark._jsparkSession

        jdf = self._java_obj.readDataset(jSession, path, delimiter, outputPosCol, outputDocumentCol, outputTextCol)
        return DataFrame(jdf, spark._wrapped)


In [26]:
spark_preds_df = spark.read.parquet('/Users/vkocaman/Python_Projects/John_Snow_Labs/Genentech/NER_train/data/deid_preds_df.parquet')

In [27]:
c = CoNLL()

metrics = c.calculate_metrics(spark_preds_df)

Chunk level metrics:
processed 289323 tokens with 7450 phrases; found: 7321 phrases; correct: 7007.
accuracy:  92.97%; (non-O)
accuracy:  99.65%; precision:  95.71%; recall:  94.05%; FB1:  94.88
              AGE: precision:  95.83%; recall:  93.54%; FB1:  94.67  408
          CONTACT: precision:  90.65%; recall:  87.39%; FB1:  88.99  107
             DATE: precision:  97.50%; recall:  97.06%; FB1:  97.28  3321
               ID: precision:  92.97%; recall:  93.27%; FB1:  93.12  313
         LOCATION: precision:  93.41%; recall:  86.43%; FB1:  89.78  1077
             NAME: precision:  95.35%; recall:  95.54%; FB1:  95.44  1998
       PROFESSION: precision:  81.44%; recall:  68.70%; FB1:  74.53  97


In [30]:
print (metrics['chunk_metrics'])

{'AGE': {'precision': 95.83333333333334, 'recall': 93.54066985645933, 'F1': 94.67312348668281}, 'CONTACT': {'precision': 90.65420560747664, 'recall': 87.38738738738738, 'F1': 88.9908256880734}, 'DATE': {'precision': 97.50075278530564, 'recall': 97.06235011990407, 'F1': 97.28105753342346}, 'ID': {'precision': 92.97124600638978, 'recall': 93.26923076923077, 'F1': 93.12}, 'LOCATION': {'precision': 93.40761374187558, 'recall': 86.42611683848797, 'F1': 89.78134761267292}, 'NAME': {'precision': 95.34534534534535, 'recall': 95.53660982948846, 'F1': 95.44088176352705}, 'PROFESSION': {'precision': 81.44329896907216, 'recall': 68.69565217391305, 'F1': 74.52830188679246}}


In [29]:
print (metrics['label_classification_report'])

              precision    recall  f1-score   support

       B-AGE       0.94      0.95      0.94       403
   B-CONTACT       0.86      0.87      0.86        84
      B-DATE       0.98      0.97      0.97      3215
        B-ID       0.91      0.90      0.91       217
  B-LOCATION       0.90      0.80      0.85       920
      B-NAME       0.93      0.96      0.94      1834
B-PROFESSION       0.88      0.75      0.81       114
       I-AGE       1.00      0.05      0.09        22
   I-CONTACT       0.93      0.90      0.92        91
      I-DATE       0.92      0.96      0.94       732
        I-ID       0.96      0.91      0.94       117
  I-LOCATION       0.91      0.91      0.91      1205
      I-NAME       0.95      0.94      0.95      1919
I-PROFESSION       0.94      0.70      0.80       110
           O       1.00      1.00      1.00    278340

    accuracy                           1.00    289323
   macro avg       0.93      0.84      0.85    289323
weighted avg       1.00   

In [15]:
%%writefile conll_eval.py
from __future__ import division, print_function, unicode_literals

import sys
from collections import defaultdict
import pyspark.sql.functions as F
from pyspark.sql import DataFrame

'''

    #https://github.com/sighsmile/conlleval
    """
    This script applies to IOB2 or IOBES tagging scheme.
    If you are using a different scheme, please convert to IOB2 or IOBES.
    IOB2:
    - B = begin, 
    - I = inside but not the first, 
    - O = outside
    e.g. 
    John   lives in New   York  City  .
    B-PER  O     O  B-LOC I-LOC I-LOC O
    IOBES:
    - B = begin, 
    - E = end, 
    - S = singleton, 
    - I = inside but not the first or the last, 
    - O = outside
    e.g.
    John   lives in New   York  City  .
    S-PER  O     O  B-LOC I-LOC E-LOC O
    prefix: IOBES
    chunk_type: PER, LOC, etc.
    """

'''
def split_tag(chunk_tag):
    """
    split chunk tag into IOBES prefix and chunk_type
    e.g. 
    B-PER -> (B, PER)
    O -> (O, None)
    """
    if chunk_tag == 'O':
        return ('O', None)
    return chunk_tag.split('-', maxsplit=1)

def is_chunk_end(prev_tag, tag):
    """
    check if the previous chunk ended between the previous and current word
    e.g. 
    (B-PER, I-PER) -> False
    (B-LOC, O)  -> True
    Note: in case of contradicting tags, e.g. (B-PER, I-LOC)
    this is considered as (B-PER, B-LOC)
    """
    prefix1, chunk_type1 = split_tag(prev_tag)
    prefix2, chunk_type2 = split_tag(tag)

    if prefix1 == 'O':
        return False
    if prefix2 == 'O':
        return prefix1 != 'O'

    if chunk_type1 != chunk_type2:
        return True

    return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']

def is_chunk_start(prev_tag, tag):
    """
    check if a new chunk started between the previous and current word
    """
    prefix1, chunk_type1 = split_tag(prev_tag)
    prefix2, chunk_type2 = split_tag(tag)

    if prefix2 == 'O':
        return False
    if prefix1 == 'O':
        return prefix2 != 'O'

    if chunk_type1 != chunk_type2:
        return True

    return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']


def calc_metrics(tp, p, t, percent=True):
    """
    compute overall precision, recall and FB1 (default values are 0.0)
    if percent is True, return 100 * original decimal value
    """
    precision = tp / p if p else 0
    recall = tp / t if t else 0
    fb1 = 2 * precision * recall / (precision + recall) if precision + recall else 0
    if percent:
        return 100 * precision, 100 * recall, 100 * fb1
    else:
        return precision, recall, fb1


def count_chunks(true_seqs, pred_seqs):
    """
    true_seqs: a list of true tags
    pred_seqs: a list of predicted tags
    return: 
    correct_chunks: a dict (counter), 
                    key = chunk types, 
                    value = number of correctly identified chunks per type
    true_chunks:    a dict, number of true chunks per type
    pred_chunks:    a dict, number of identified chunks per type
    correct_counts, true_counts, pred_counts: similar to above, but for tags
    """
    correct_chunks = defaultdict(int)
    true_chunks = defaultdict(int)
    pred_chunks = defaultdict(int)

    correct_counts = defaultdict(int)
    true_counts = defaultdict(int)
    pred_counts = defaultdict(int)

    prev_true_tag, prev_pred_tag = 'O', 'O'
    correct_chunk = None

    for true_tag, pred_tag in zip(true_seqs, pred_seqs):
        if true_tag == pred_tag:
            correct_counts[true_tag] += 1
        true_counts[true_tag] += 1
        pred_counts[pred_tag] += 1

        _, true_type = split_tag(true_tag)
        _, pred_type = split_tag(pred_tag)

        if correct_chunk is not None:
            true_end = is_chunk_end(prev_true_tag, true_tag)
            pred_end = is_chunk_end(prev_pred_tag, pred_tag)

            if pred_end and true_end:
                correct_chunks[correct_chunk] += 1
                correct_chunk = None
            elif pred_end != true_end or true_type != pred_type:
                correct_chunk = None

        true_start = is_chunk_start(prev_true_tag, true_tag)
        pred_start = is_chunk_start(prev_pred_tag, pred_tag)

        if true_start and pred_start and true_type == pred_type:
            correct_chunk = true_type
        if true_start:
            true_chunks[true_type] += 1
        if pred_start:
            pred_chunks[pred_type] += 1

        prev_true_tag, prev_pred_tag = true_tag, pred_tag
    if correct_chunk is not None:
        correct_chunks[correct_chunk] += 1

    return (correct_chunks, true_chunks, pred_chunks, 
        correct_counts, true_counts, pred_counts)

def get_result(correct_chunks, true_chunks, pred_chunks,
    correct_counts, true_counts, pred_counts, verbose=True):
    """
    if verbose, print overall performance, as well as preformance per chunk type;
    otherwise, simply return overall prec, rec, f1 scores
    """
    # sum counts
    sum_correct_chunks = sum(correct_chunks.values())
    sum_true_chunks = sum(true_chunks.values())
    sum_pred_chunks = sum(pred_chunks.values())

    sum_correct_counts = sum(correct_counts.values())
    sum_true_counts = sum(true_counts.values())

    nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O')
    nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O')

    chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks))))

    # compute overall precision, recall and FB1 (default values are 0.0)
    prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks)
    res = (prec, rec, f1)
    if not verbose:
        return res

    # print overall performance, and performance per chunk type

    print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='')
    print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='')

    print("accuracy: %6.2f%%; (non-O)" % (100*nonO_correct_counts/nonO_true_counts))
    print("accuracy: %6.2f%%; " % (100*sum_correct_counts/sum_true_counts), end='')
    print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1))

    # for each chunk type, compute precision, recall and FB1 (default values are 0.0)
    metrics_dict = {}
    for t in chunk_types:
        prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t])
        print("%17s: " %t , end='')
        print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" %
                    (prec, rec, f1), end='')
        print("  %d" % pred_chunks[t])
        metrics_dict[t] = {'precision':prec, 'recall':rec, 'F1':f1}

    return res, metrics_dict
    # you can generate LaTeX output for tables like in
    # http://cnts.uia.ac.be/conll2003/ner/example.tex
    # but I'm not implementing this

def evaluate(true_seqs, pred_seqs, verbose=True):
    (correct_chunks, true_chunks, pred_chunks,
        correct_counts, true_counts, pred_counts) = count_chunks(true_seqs, pred_seqs)
    result = get_result(correct_chunks, true_chunks, pred_chunks,
        correct_counts, true_counts, pred_counts, verbose=verbose)
    return result

def evaluate_conll_file(fileIterator):
    true_seqs, pred_seqs = [], []

    for line in fileIterator:
        cols = line.strip().split()
        # each non-empty line must contain >= 3 columns
        if not cols:
            true_seqs.append('O')
            pred_seqs.append('O')
        elif len(cols) < 3:
            raise IOError("conlleval: too few columns in line %s\n" % line)
        else:
            # extract tags from last 2 columns
            true_seqs.append(cols[-2])
            pred_seqs.append(cols[-1])
    return evaluate(true_seqs, pred_seqs)






Writing conll_eval.py


In [3]:
import sparknlp
spark=sparknlp.start()

In [4]:
CoNLL()

<__main__.CoNLL at 0x113327350>

In [5]:
import pandas as pd

df = pd.read_pickle('/Users/vkocaman/Python_Projects/John_Snow_Labs/Genentech/NER_train/data/deid_preds_df.pickle')

In [7]:
spark_preds_df = spark.createDataFrame(df)

In [11]:
ev = NerEval()

TypeError: __init__() missing 1 required positional argument: 'java_obj'