In [None]:
#| default_exp banchmark.covid

In [None]:
#| hide
from nbdev.showdoc import show_doc
from IPython.display import display, HTML
# reload edited library modules automatically while developing
%load_ext autoreload
%autoreload 2
# set up itables without turning every DataFrame into an interactive table (all_interactive=False)
from itables import init_notebook_mode,show
init_notebook_mode(all_interactive=False,connected=False)


In [None]:
#| exports
# importing dependencies
import re
import csv
import time
import pandas as pd
from pandas import DataFrame
from pathlib import Path
from spannerlib import get_magic_session,Session
from spannerlib.ie_func.basic import rgx, rgx_is_match, rgx_split, span_arity, span_contained

# Which span / IE backend this benchmark runs against. Options:
#   "OLD"                   -> Span from spannerlib
#   "SPANNERFLOW"           -> Span from spannerflow
#   "SPANNERFLOW_PYTHON_IE" -> Span from spannerflow
VERSION = "SPANNERFLOW_PYTHON_IE"
_KNOWN_VERSIONS = ("OLD", "SPANNERFLOW", "SPANNERFLOW_PYTHON_IE")
if VERSION not in _KNOWN_VERSIONS:
    # fail fast: a misspelled value would otherwise silently take the "OLD" branches everywhere
    raise ValueError(f"Unknown VERSION {VERSION!r}; expected one of {_KNOWN_VERSIONS}")

if VERSION in ["SPANNERFLOW", "SPANNERFLOW_PYTHON_IE"]:
    from spannerflow.span import Span
else:
    from spannerlib import Span
sess = get_magic_session()

# spannerlib's regex IE functions, exposed to spannerlog under `py_` names
sess.register('py_rgx', rgx, [str, Span], span_arity)
sess.register('py_rgx_split', rgx_split, [str, Span], [Span,Span])
sess.register('py_rgx_is_match', rgx_is_match, [str, Span], [bool])
sess.register('py_span_contained', span_contained, [Span, Span], [bool])

import spacy
# shared spaCy pipeline used by the sentence splitter, lemmatizer and POS annotator below
nlp = spacy.load("en_core_web_sm")


In [None]:
# wall-clock start of the benchmark; the elapsed time is printed in the last cells of the notebook
start_time = time.time()

In [None]:
#|exports
# configurations: location of the sample notes and of the rule / lexicon files
# (relative paths, resolved against the current working directory)
input_dir = Path('covid_data') / 'sample_inputs'
data_dir = Path('covid_data') / 'rules_data'

In [None]:
#| exports
def split_sentence(text):
    """
    Splits a text into individual sentences using spaCy's sentence detection.

    Args:
        text: the document to split. It is passed through ``str()`` before being
            parsed, so the offsets below are relative to ``str(text)``.

    Yields:
        Span: one ``Span(text, start, end)`` per sentence detected by spaCy.

    The offsets are taken from spaCy's ``start_char``/``end_char``. The previous
    version summed sentence lengths and assumed exactly one separator character
    between sentences. That drifted as soon as sentences were separated by a
    newline or repeated whitespace, so every later span covered the wrong text.
    Because of this, consecutive sentence spans are not guaranteed to be exactly
    one character apart: they can also abut directly.
    """
    doc = nlp(str(text))
    for sentence in doc.sents:
        # note that we yield a Span object, so we can keep track of the locations of the sentences
        yield Span(text, sentence.start_char, sentence.end_char)

In [None]:
#| exports
class LemmaFromList():
    """IE function that tags words whose spaCy lemma appears in a given word list.

    Calling an instance on a text yields ``(Span, label)`` pairs. ``label`` is the
    word's lemma when that lemma is in the list, or ``'like_num'`` for number-like
    words whose lemma is not listed. All other words are skipped.
    """

    def __init__(self,lemma_list):
        # the original list is kept as given for anyone inspecting it
        self.lemma_list = lemma_list
        # every token of every document is checked, so use O(1) set lookups instead of list scans
        # (built once here: later in-place edits to `lemma_list` are not picked up)
        self._lemma_set = frozenset(lemma_list)

    def __call__(self,text):
        doc = nlp(str(text))
        for word in doc:
            start = word.idx
            end = start + len(word.text)
            if word.lemma_ in self._lemma_set:
                yield (Span(text,start,end),word.lemma_)
            elif word.like_num:
                yield (Span(text,start,end),'like_num')

# whitespace-separated lemma words; decoded explicitly as UTF-8 so the result does not
# depend on the platform's default encoding
lemma_list = (data_dir/'lemma_words.txt').read_text(encoding='utf-8').split()
lemmatizer = LemmaFromList(lemma_list)

In [None]:
#| exports
class PosFromList():
    """IE function that tags words whose spaCy coarse POS tag is in a given list.

    Calling an instance on a text yields ``(Span, pos_tag)`` pairs for the matching words.
    """

    def __init__(self,pos_list):
        self.pos_list = pos_list

    def __call__(self,text):
        parsed = nlp(str(text))
        for token in parsed:
            if token.pos_ not in self.pos_list:
                continue
            token_start = token.idx
            yield (Span(text, token_start, token_start + len(token.text)), token.pos_)

pos_annotator = PosFromList(["NOUN", "PROPN", "PRON", "ADJ"])

In [None]:
#| exports
if VERSION not in ["SPANNERFLOW", "SPANNERFLOW_PYTHON_IE"]:
    # "OLD" version: the document argument is declared as a plain str
    sess.register('split_sentence',split_sentence,[str],[Span])
    sess.register('pos',pos_annotator,[str],[Span,str])
    sess.register('lemma',lemmatizer,[str],[Span,str])
else:
    # spannerflow versions: the document argument is declared as a Span
    sess.register('split_sentence',split_sentence,[Span],[Span])
    sess.register('pos',pos_annotator,[Span],[Span,str])
    sess.register('lemma',lemmatizer,[Span],[Span,str])

In [None]:
#| exports
def rewrite(text,span_label_pairs):
    """Rewrites a string by replacing the given spans with their labels.

    Assumes that the spans belong to ``text``.

    Args:
        text (str like): string to rewrite; a ``Span`` is first converted with ``as_str()``.
        span_label_pairs (pd.DataFrame): dataframe with two columns. The first holds the
            spans in the doc to rewrite; the second holds the string each span is rewritten to.

    Returns:
        str: the rewritten string.

    Spans are applied left to right. Among spans starting at the same offset, the
    longest one wins. A span that overlaps text already replaced by an earlier span
    is skipped. Applying it would either append a second label for the same text,
    or move the read position backwards and duplicate part of the document.
    """
    if isinstance(text,Span):
        text = text.as_str()
    ordered_pairs = sorted(
        span_label_pairs.itertuples(index=False, name=None),
        # leftmost first; for equal starts prefer the longer span
        key=lambda pair: (pair[0].start, -pair[0].end),
    )

    pieces = []
    current_pos = 0
    for span, label in ordered_pairs:
        if span.start < current_pos:
            # overlaps a region that has already been rewritten
            continue
        pieces.append(text[current_pos:span.start])
        pieces.append(label)
        current_pos = span.end

    pieces.append(text[current_pos:])
    return ''.join(pieces)


In [None]:
#| export
def rewrite_docs(docs,span_label,new_version):
    """Given a dataframe of documents of the form (path,doc,version) and a dataframe of spans to rewrite
    of the form (path,doc,span,label), rewrites the documents and returns a new dataframe of the form
    (path,doc,new_version).

    Args:
        docs (pd.DataFrame): three columns (path, document text, version); the version is ignored.
        span_label (pd.DataFrame): four columns (path, doc, span, label), e.g. the result of
            ``sess.export``. It is not modified.
        new_version (str): version string written into every output row.

    Returns:
        pd.DataFrame: columns ``['P','D','V']``, one row per input document. Documents
        without any span to rewrite are copied unchanged.
    """
    # rename on a copy: assigning to `span_label.columns` renamed the caller's dataframe as a side effect
    span_label = span_label.set_axis(['P','D','W','L'], axis=1)
    # bucket the (span, label) pairs per document once, instead of scanning the whole frame per document
    spans_by_path = {
        path: group[['W','L']]
        for path, group in span_label.groupby('P', sort=False)
    }
    no_spans = span_label[['W','L']].iloc[0:0]

    new_tuples = []
    for path,doc,_ in docs.itertuples(index=False,name=None):
        new_text = rewrite(doc, spans_by_path.get(path, no_spans))
        new_tuples.append((path,new_text,new_version))
    return pd.DataFrame(new_tuples,columns=['P','D','V'])
    

In [None]:
#| export
# rule / lexicon relations: (relation name, file inside data_dir, field delimiter).
# Some files use '#' as delimiter, presumably because their regex patterns contain commas.
_rule_tables = [
    ("ConceptTagRules", "concept_tags_rules.csv", ","),
    ("TargetTagRules", "target_rules.csv", ","),
    ("SectionTags", "section_tags.csv", ","),
    ("PositiveSectionTags", "positive_section_tags.csv", ","),
    ("SentenceContextRules", "sentence_context_rules.csv", "#"),
    ("PostprocessPatternRules", "postprocess_pattern_rules.csv", "#"),
    ("PostprocessRulesWithAttributes", "postprocess_attributes_rules.csv", "#"),
    ("NextSentencePostprocessPatternRules", "postprocess_pattern_next_sentence_rules.csv", ","),
]
for _rel_name, _file_name, _delim in _rule_tables:
    sess.import_rel(_rel_name, data_dir / _file_name, delim=_delim)


In [None]:
#| export
from glob import glob
# sorted: glob returns files in arbitrary, filesystem-dependent order, which made the
# document order of raw_docs (and of every table derived from it) vary between runs/machines
file_paths = sorted(Path(p) for p in glob(str(input_dir/'*.txt')))
# decode explicitly as UTF-8 instead of relying on the platform's default encoding
raw_docs = pd.DataFrame([
    [p.name,p.read_text(encoding='utf-8'),'raw_text'] for p in file_paths
],columns=['Path','Doc','Version']
)
if VERSION in ["SPANNERFLOW", "SPANNERFLOW_PYTHON_IE"]:
    sess.import_rel('Docs',raw_docs, scheme=[str, Span, str])
else:
    sess.import_rel('Docs',raw_docs)

In [None]:
%%spannerlog
# Lemmas: words of the raw documents tagged by the `lemma` IE function, with their lemma (or 'like_num')
Lemmas(P,D,Word,Lem)<-Docs(P,D,"raw_text"),lemma(D)->(Word,Lem).

In [None]:
#| export
lemma_tags = sess.export('?Lemmas(P,D,W,L)')
lemma_docs = rewrite_docs(raw_docs, lemma_tags, 'lemma')
# import the lemmatized documents into Docs under version "lemma"
if VERSION not in ["SPANNERFLOW", "SPANNERFLOW_PYTHON_IE"]:
    sess.import_rel('Docs', lemma_docs)
else:
    # spannerflow versions pass an explicit scheme typing the Doc column as Span
    sess.import_rel('Docs', lemma_docs, scheme=[str, Span, str])


In [None]:

%%spannerlog
# LemmaConceptMatches: regex matches of the concept rules marked "lemma" in the lemmatized docs;
# each matched span is later replaced by its Label via rewrite_docs
LemmaConceptMatches(Path,Doc,Span,Label) <- 
    Docs(Path,Doc,"lemma"),
    ConceptTagRules(Pattern, Label, "lemma"),
    # TODO CHANGE: on different version
    py_rgx(Pattern,Doc) -> (Span).

In [None]:
#| export
lemma_concept_matches = sess.export('?LemmaConceptMatches(Path,Doc,Span,Label)')
lemma_concepts = rewrite_docs(lemma_docs, lemma_concept_matches, 'lemma_concept')
# import the concept-rewritten documents into Docs under version "lemma_concept"
if VERSION not in ["SPANNERFLOW", "SPANNERFLOW_PYTHON_IE"]:
    sess.import_rel('Docs', lemma_concepts)
else:
    sess.import_rel('Docs', lemma_concepts, scheme=[str, Span, str])

In [None]:
%%spannerlog
# here we get the spans of the POS-tagged words of the lemma_concept docs
# (`pos` only yields NOUN, PROPN, PRON and ADJ words; despite its name, Lem holds the POS tag)
Pos(P,D,Word,Lem)<-Docs(P,D,"lemma_concept"),pos(D)->(Word,Lem).

In [None]:
%%spannerlog
# here we look for concept rule matches where the matched word is also tagged via POS
# (only concept rules marked "pos" are used; the match must be exactly a Pos span,
# but the POS tag itself, POSLabel, is not checked)
PosConceptMatches(Path,Doc,Span,Label) <- 
    Docs(Path,Doc,"lemma_concept"),
    ConceptTagRules(Pattern, Label, "pos"),
    # TODO CHANGE: on different version
    py_rgx(Pattern,Doc) -> (Span),
    Pos(Path,Doc,Span,POSLabel).

In [None]:
#| export
pos_concept_matches = sess.export('?PosConceptMatches(P,D,W,L)')
pos_concept_docs = rewrite_docs(lemma_concepts, pos_concept_matches, 'pos_concept')
# import the POS-concept-rewritten documents into Docs under version "pos_concept"
if VERSION not in ["SPANNERFLOW", "SPANNERFLOW_PYTHON_IE"]:
    sess.import_rel('Docs', pos_concept_docs)
else:
    sess.import_rel('Docs', pos_concept_docs, scheme=[str, Span, str])

In [None]:
%%spannerlog
# TargetMatches: regex matches of every target rule in the pos_concept docs, with the rule's label
TargetMatches(Path,Doc, Span, Label) <- 
    Docs(Path,Doc,"pos_concept"),
    # TODO CHANGE: on different version
    TargetTagRules(Pattern, Label), py_rgx(Pattern,Doc) -> (Span).

In [None]:
#| export
target_matches = sess.export('?TargetMatches(P,D,W,L)')
target_rule_docs = rewrite_docs(pos_concept_docs, target_matches, 'target_concept')
# import the target-rewritten documents into Docs under version "target_concept"
if VERSION not in ["SPANNERFLOW", "SPANNERFLOW_PYTHON_IE"]:
    sess.import_rel('Docs', target_rule_docs)
else:
    sess.import_rel('Docs', target_rule_docs, scheme=[str, Span, str])

In [None]:
#| export
# section header table: every row is data (no header row); column 0 is the header literal, column 1 its tag
section_tags = pd.read_csv(
    data_dir / 'section_tags.csv',
    header=None,
    names=['literal', 'tag'],
)

In [None]:
#| export
# we programmatically build a regex that matches any of the section header literals.
# - each literal is re.escape'd: the column holds literal header text, and the
#   SectionTags(Sec,Tag) join requires the matched text to equal the literal exactly
#   (unescaped metacharacters such as "(" or "." would change what matches)
# - longer literals come first: regex alternation takes the first alternative that
#   matches, so a shorter header would otherwise shadow a longer one it prefixes
#   (sorted() is stable, so CSV order is kept among literals of equal length)
# - missing values are dropped, as str.cat did before
section_literals = sorted(section_tags['literal'].dropna(), key=len, reverse=True)
section_delimeter_pattern = '|'.join(re.escape(literal) for literal in section_literals)
sess.import_var('section_delimeter_pattern',section_delimeter_pattern)
section_delimeter_pattern

In [None]:
%%spannerlog
# we get section spans and their content using our regex pattern and the py_rgx_split ie function
# (Sec is the header text as a string, so it can be joined with the literals in SectionTags)
Sections(P,D,Sec,Content)<-Docs(P,D,"target_concept"),
    py_rgx_split($section_delimeter_pattern,D)->(SecSpan,Content),
    as_str(SecSpan)->(Sec).

# a section is positive when its header's tag is listed in PositiveSectionTags
PositiveSections(P,D,Sec,Content)<-Sections(P,D,Sec,Content),SectionTags(Sec,Tag),PositiveSectionTags(Tag).


In [None]:
%%spannerlog
# Sents: every sentence span produced by split_sentence for each target_concept document
Sents(P,S)<-Docs(P,D,"target_concept"),split_sentence(D)->(S).


In [None]:
from itertools import pairwise

def sentence_pairs(text):
    """Yields ``(sentence, next_sentence)`` for each pair of consecutive sentence spans of ``text``."""
    sentences = split_sentence(text)
    for current_sentence, next_sentence in pairwise(sentences):
        yield current_sentence, next_sentence

if VERSION not in ["SPANNERFLOW", "SPANNERFLOW_PYTHON_IE"]:
    sess.register('sentence_pairs',sentence_pairs,[str],[Span,Span])
else:
    sess.register('sentence_pairs',sentence_pairs,[Span],[Span,Span])

In [None]:
def is_adjacent(span1,span2):
    """Yields whether ``span2`` is the sentence immediately following ``span1``.

    Both spans must have the same ``name`` (same document). ``span2`` must start
    where ``span1`` ends, or one character later (e.g. after the space following a
    full stop). Previously only the one-character gap was accepted, which matched
    spans built by assuming a single separator between sentences. Spans taken from
    real character offsets can also abut directly, e.g. when the separating newline
    belongs to the previous sentence.
    """
    yield span1.name == span2.name and 0 <= span2.start - span1.end <= 1

# expose is_adjacent to spannerlog: two Span inputs, one bool output
sess.register('is_adjacent',is_adjacent,[Span,Span],[bool])

In [None]:
%%spannerlog
# SentPairs: pairs of sentences of the same document where S2 directly follows S1 (per is_adjacent)
# NOTE(review): this evaluates is_adjacent on every sentence pair of a document; the registered
# `sentence_pairs` IE function yields the consecutive pairs directly
SentPairs(P,S1,S2)<-Sents(P,S1),Sents(P,S2),is_adjacent(S1,S2)->(True).


In [None]:
%%spannerlog
# first we get the covid mentions and their surrounding sentences, using the py_span_contained ie function
# CovidMentions: every match of "COVID-19" in the target_concept documents
# TODO CHANGE: on different version
CovidMentions(Path, Span) <- Docs(Path,D,"target_concept"), py_rgx("COVID-19",D) -> (Span).

# CovidMentionSents: each mention paired with the sentence(s) that contain it
# TODO CHANGE: on different version
CovidMentionSents(P,Mention,Sent)<-CovidMentions(P,Mention),Sents(P,Sent),py_span_contained(Mention,Sent)->(True).


In [None]:
%%spannerlog

# note that for ease of debugging, we extended our head to track which rule a fact was derived from
# (the fourth column of CovidTags names the rule that derived the tag)

# a tag is positive if it is contained in a positive section
CovidTags(Path,Mention,'positive','section')<-
    PositiveSections(Path,D,Title,Section),
    CovidMentions(Path,Mention),
    # TODO CHANGE: on different version
    py_span_contained(Mention,Section)->(True).

# Context rules tags: a context pattern matches a part of the mention's sentence that contains
# the mention, and the rule's disambiguation pattern does not match that sentence
CovidTags(Path,Mention,Tag,'sentence context')<-
    CovidMentionSents(Path,Mention,Sent),
    SentenceContextRules(Pattern,Tag,DisambiguationPattern),
    # TODO CHANGE: on different version
    py_rgx(Pattern,Sent)->(ContextSpan),
    # TODO CHANGE: on different version
    py_span_contained(Mention,ContextSpan)->(True),
    # TODO CHANGE: on different version
    py_rgx_is_match(DisambiguationPattern,Sent)->(False).

# post processing based on pattern
# (the pattern must match a part of the mention's sentence that contains the mention)
CovidTags(Path,Mention,Tag,'post pattern')<-
    CovidMentionSents(Path,Mention,Sent),
    PostprocessPatternRules(Pattern,Tag),
    # TODO CHANGE: on different version
    py_rgx(Pattern,Sent)->(ContextSpan),
    # TODO CHANGE: on different version
    py_span_contained(Mention,ContextSpan)->(True).

# post processing based on pattern and existing attributes
# notice the recursive call to CovidTags
# (a mention already tagged OldTag gets Tag when the pattern matches around the mention)
CovidTags(Path,Mention,Tag,"post attribute change")<-
    CovidTags(Path,Mention,OldTag,Derivation),
    PostprocessRulesWithAttributes(Pattern,OldTag,Tag),
    CovidMentionSents(Path,Mention,Sent),
    # TODO CHANGE: on different version
    py_rgx(Pattern,Sent)->(ContextSpan),
    # TODO CHANGE: on different version
    py_span_contained(Mention,ContextSpan)->(True).


# post processing based on pattern in the next sentence
# (ContextSpan is unconstrained: any match of the pattern anywhere in the following sentence tags the mention)
CovidTags(Path,Mention,Tag,"next sentence")<-
    CovidMentionSents(Path,Mention,Sent),
    SentPairs(Path,Sent,NextSent),
    PostprocessPatternRules(Pattern,Tag),
    # TODO CHANGE: on different version
    py_rgx(Pattern,NextSent)->(ContextSpan).


In [None]:
#| export
def agg_mention(group):
    """
    Aggregates the attribute tags derived for a single covid mention into one tag.

    Args:
        group: the tags of one mention. Any iterable of str works, e.g. a list or a
            pandas Series; a single str is treated as one tag.

    Returns:
        str: 'IGNORE', 'negated', 'positive' or 'uncertain'. The first matching rule wins:
        IGNORE; negated (unless no_negated); future (unless no_future) -> negated;
        other experiencer / not relevant -> negated; positive (unless uncertain or
        no_positive); otherwise uncertain.
    """
    # test membership on the values: `'x' in series` checks a pandas Series' index labels, not its values
    tags = {group} if isinstance(group, str) else set(group)
    if 'IGNORE' in tags:
        return 'IGNORE'
    if 'negated' in tags and 'no_negated' not in tags:
        return 'negated'
    if 'future' in tags and 'no_future' not in tags:
        return 'negated'
    if 'other experiencer' in tags or 'not relevant' in tags:
        return 'negated'
    if 'positive' in tags and 'uncertain' not in tags and 'no_positive' not in tags:
        return 'positive'
    return 'uncertain'

#| export
def AggDocumentTags(group):
    """
    Classifies a document as 'POS', 'UNK', or 'NEG' based on COVID-19 attributes.

    Args:
        group: the aggregated tags of the document's mentions. Any iterable of str works,
            e.g. a list or a pandas Series; a single str is treated as one tag.

    Returns:
        str: 'POS' if any mention is positive, else 'UNK' if any is uncertain,
        else 'NEG' if any is negated, else 'UNK'.
    """
    # test membership on the values: `'x' in series` checks a pandas Series' index labels, not its values
    tags = {group} if isinstance(group, str) else set(group)
    if 'positive' in tags:
        return 'POS'
    if 'uncertain' in tags:
        return 'UNK'
    if 'negated' in tags:
        return 'NEG'
    return 'UNK'


# expose the aggregation functions to spannerlog: each takes a column of str tags and returns one str
sess.register_agg('agg_mention',agg_mention,[str],[str])
sess.register_agg('agg_doc_tags',AggDocumentTags,[str],[str])

In [None]:
%%spannerlog
# collapse all tags derived for a mention (from any rule) into a single tag with agg_mention
AggregatedCovidTags(Path,Mention,agg_mention(Tag))<-
    CovidTags(Path,Mention,Tag,Derivation).

# classify each document from the aggregated tags of its mentions
DocumentTags(Path,agg_doc_tags(Tag))<-
    AggregatedCovidTags(Path,Mention,Tag).


In [None]:
#| export
# one classification per document that has at least one tagged mention (columns P, T)
doc_tags = sess.export('?DocumentTags(P,T)')

In [None]:
#| export
paths = pd.DataFrame([path.name for path in file_paths], columns=['P'])
# outer merge keeps every input file; files with no DocumentTags fact get T = NaN, reported as 'UNK'
classification = (
    paths
    .merge(doc_tags, on='P', how='outer')
    .assign(T=lambda frame: frame['T'].fillna('UNK'))
)
classification

In [None]:
end_time = time.time()
elapsed_seconds = end_time - start_time
# benchmark summary: corpus size and total wall-clock time since start_time
print(f"Number of Documents: {len(file_paths)}")
print(f"Time taken: {elapsed_seconds:.2f} seconds")

In [None]:
#|hide
# regenerate the library module(s) from the `#| export`-marked cells
import nbdev; nbdev.nbdev_export()
     