Merge pull request #31 from allenai/converters-main-code-refactoring
converter's main code refactoring
aryehgigi committed Feb 8, 2021
2 parents 708a39a + 7d51bda commit c280589
Showing 17 changed files with 2,501 additions and 2,736 deletions.
Binary file removed dist/pybart-nlp-2.2.6.tar.gz
Binary file removed dist/pybart_nlp-2.2.6-py3-none-any.whl
24 changes: 12 additions & 12 deletions pybart/api.py
@@ -1,10 +1,10 @@
import math

from .conllu_wrapper import parse_conllu, serialize_conllu, parse_odin, conllu_to_odin, parsed_tacred_json
from .converter import convert, ConvsCanceler
from .converter import convert, get_conversion_names as inner_get_conversion_names


def convert_bart_conllu(conllu_text, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, preserve_comments=False, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=ConvsCanceler()):
def convert_bart_conllu(conllu_text, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, preserve_comments=False, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=None):
parsed, all_comments = parse_conllu(conllu_text)
converted, _ = convert(parsed, enhance_ud, enhanced_plus_plus, enhanced_extra, conv_iterations, remove_eud_info, remove_extra_info, remove_node_adding_conversions, remove_unc, query_mode, funcs_to_cancel)
return serialize_conllu(converted, all_comments, preserve_comments)
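
Not part of the diff: a minimal, hedged example of the CoNLL-U entry point after this change — `funcs_to_cancel` now defaults to `None` instead of a `ConvsCanceler` instance. The two-token sentence below is invented for illustration.

    # Sketch only: the input sentence is made up; all other options keep their defaults.
    from pybart.api import convert_bart_conllu

    conllu = (
        "1\tJohn\tJohn\tPROPN\tNNP\t_\t2\tnsubj\t_\t_\n"
        "2\tsleeps\tsleep\tVERB\tVBZ\t_\t0\troot\t_\t_\n"
        "\n"
    )
    enhanced = convert_bart_conllu(conllu, funcs_to_cancel=None)  # None is the new default
    print(enhanced)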
@@ -16,7 +16,7 @@ def _convert_bart_odin_sent(doc, enhance_ud, enhanced_plus_plus, enhanced_extra,
return conllu_to_odin(converted_sents, doc)


def convert_bart_odin(odin_json, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=ConvsCanceler()):
def convert_bart_odin(odin_json, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=None):
if "documents" in odin_json:
for doc_key, doc in odin_json["documents"].items():
odin_json["documents"][doc_key] = _convert_bart_odin_sent(doc, enhance_ud, enhanced_plus_plus, enhanced_extra, conv_iterations, remove_eud_info, remove_extra_info, remove_node_adding_conversions, remove_unc, query_mode, funcs_to_cancel)
@@ -26,36 +26,36 @@ def convert_bart_odin(odin_json, enhance_ud=True, enhanced_plus_plus=True, enhan
return odin_json


def convert_bart_tacred(tacred_json, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=ConvsCanceler()):
def convert_bart_tacred(tacred_json, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=None):
sents = parsed_tacred_json(tacred_json)
converted_sents, _ = convert(sents, enhance_ud, enhanced_plus_plus, enhanced_extra, conv_iterations, remove_eud_info, remove_extra_info, remove_node_adding_conversions, remove_unc, query_mode, funcs_to_cancel)

return converted_sents


def convert_spacy_doc(doc, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=ConvsCanceler()):
def convert_spacy_doc(doc, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=None):
from .spacy_wrapper import parse_spacy_sent, serialize_spacy_doc
parsed_doc = [parse_spacy_sent(sent) for sent in doc.sents]
converted, convs_done = convert(parsed_doc, enhance_ud, enhanced_plus_plus, enhanced_extra, conv_iterations, remove_eud_info, remove_extra_info, remove_node_adding_conversions, remove_unc, query_mode, funcs_to_cancel)
return serialize_spacy_doc(doc, converted), parsed_doc, convs_done
return serialize_spacy_doc(doc, converted), converted, convs_done


class Converter:
def __init__(self, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=ConvsCanceler()):
def __init__(self, enhance_ud=True, enhanced_plus_plus=True, enhanced_extra=True, conv_iterations=math.inf, remove_eud_info=False, remove_extra_info=False, remove_node_adding_conversions=False, remove_unc=False, query_mode=False, funcs_to_cancel=None):
self.config = (enhance_ud, enhanced_plus_plus, enhanced_extra, conv_iterations, remove_eud_info, remove_extra_info, remove_node_adding_conversions, remove_unc, query_mode, funcs_to_cancel)

def __call__(self, doc):
serialized_spacy_doc, parsed_doc, convs_done = convert_spacy_doc(doc, *self.config)
self._parsed_doc = parsed_doc
serialized_spacy_doc, converted_sents, convs_done = convert_spacy_doc(doc, *self.config)
self._converted_sents = converted_sents
self._convs_done = convs_done
return serialized_spacy_doc

def get_parsed_doc(self):
return self._parsed_doc
def get_converted_sents(self):
return self._converted_sents

def get_max_convs(self):
return self._convs_done


def get_conversion_names():
return ConvsCanceler.get_conversion_names()
return inner_get_conversion_names()
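
Not part of the diff: a hedged usage sketch of the refactored `Converter` API. It assumes spaCy is installed and a model with a dependency parser is available; the model name and pipeline wiring are illustrative only and not taken from this commit.

    # Sketch only: "en_core_web_sm" is an assumed model choice, not part of this diff.
    import spacy
    from pybart.api import Converter, get_conversion_names

    nlp = spacy.load("en_core_web_sm")              # any spaCy model with a parser
    converter = Converter(enhance_ud=True, funcs_to_cancel=None)

    doc = converter(nlp("The quick brown fox jumped over the lazy dog."))
    sents = converter.get_converted_sents()         # replaces the old get_parsed_doc()
    print(get_conversion_names())                   # now delegates to pybart.converter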
91 changes: 48 additions & 43 deletions pybart/conllu_wrapper.py
@@ -1,5 +1,5 @@
import uuid
from .graph_token import Token, add_basic_edges
from .graph_token import Token, add_basic_edges, TokenId


def parse_conllu(text):
@@ -26,7 +26,7 @@ def parse_conllu(text):
if not lines:
continue
comments = []
sentence = dict()
sentence = []

# for each line (either comment or token)
for line in lines:
@@ -54,11 +54,11 @@ def parse_conllu(text):
xpos = upos if xpos == '_' else xpos

# add current token to current sentence
sentence[int(new_id)] = Token(
int(new_id), form, lemma, upos, xpos, feats, int(head), deprel, deps, misc)
sentence.append(Token(
TokenId(int(new_id)), form, lemma, upos, xpos, feats, TokenId(int(head)), deprel, deps, misc))

# add root
sentence[0] = Token(0, None, None, None, None, None, None, None, None, None)
sentence.append(Token(TokenId(0), None, None, None, None, None, None, None, None, None))

# after parsing entire sentence, add basic deprel edges,
# and add sentence to output list
@@ -85,8 +85,7 @@ def serialize_conllu(converted, all_comments, preserve_comments=False):
if preserve_comments:
comments = ["\n".join(per_sent_comments)]

# TODO - fix case of more than 9 copy nodes - needs special ordering e.g 1.1 ... 1.9 1.10 and not 1.1 1.10 ... 1.9
text.append(comments + [token.get_conllu_string() for (cur_id, token) in sorted(sentence.items()) if cur_id != 0])
text.append(comments + [token.get_conllu_string() for token in sorted(sentence, key=lambda tok: tok.get_conllu_field("id")) if token.get_conllu_field("id").major != 0])

return "\n".join(["\n".join(sent) + "\n" for sent in text])

Expand All @@ -96,48 +95,53 @@ def serialize_conllu(converted, all_comments, preserve_comments=False):
def parse_odin(odin_json):
sentences = []
for sent in odin_json['sentences']:
sentence = {0: Token(0, None, None, None, None, None, None, None, None, None)}
sentence = list()
for i, (word, tag, lemma) in enumerate(zip(sent['words'], sent['tags'], sent['lemmas'])):
sentence[i + 1] = Token(i + 1, word, lemma, "_", tag, "_", "_", "_", "_", "_")
sentence.append(Token(TokenId(i + 1), word, lemma, "_", tag, "_", None, "_", "_", "_"))
for edge in sent['graphs']['universal-basic']['edges']:
sentence[edge['destination'] + 1].set_conllu_field('head', edge['source'] + 1)
sentence[edge['destination'] + 1].set_conllu_field('head', TokenId(edge['source'] + 1))
sentence[edge['destination'] + 1].set_conllu_field('deprel', edge['relation'])
for root in sent['graphs']['universal-basic']['roots']:
sentence[root + 1].set_conllu_field('head', 0)
sentence[root + 1].set_conllu_field('head', TokenId(0))
sentence[root + 1].set_conllu_field('deprel', "root")

sentence.append(Token(TokenId(0), None, None, None, None, None, None, None, None, None))

add_basic_edges(sentence)
sentences.append(sentence)

return sentences


def _fix_sentence_keep_order(conllu_sentence):
sorted_sent = sorted(conllu_sentence.items())
sorted_sent = sorted(conllu_sentence)
addon = 0
fixed = dict()
fixed = list()

for iid, token in sorted_sent:
if round(iid) != iid:
for token in sorted_sent:
iid = token.get_conllu_field("id")
if token.get_conllu_field("id").minor != 0:
if "CopyOf" in token.get_conllu_field("misc"):
token.set_conllu_field("form", token.get_conllu_field("form") + "[COPY_NODE]")
addon += 1

new_id = round(iid) + addon
token.set_conllu_field("id", new_id)
fixed[new_id] = token
new_id = iid.major + addon
token.set_conllu_field("id", TokenId(new_id))
fixed.append(token)

return fixed


def _fix_sentence_push_to_end(conllu_sentence):
fixed = dict()

for i, (iid, token) in enumerate([(iid2, t) for (iid2, t) in conllu_sentence.items() if iid2 != 0]):
if round(iid) != iid:
token.set_conllu_field("id", i + 1)
fixed = list()

for i, token in enumerate(conllu_sentence):
iid = token.get_conllu_field("id")
if iid.major == 0:
continue
if iid.minor != 0:
token.set_conllu_field("id", TokenId(i + 1))

fixed[i + 1] = token
fixed.append(token)

return fixed

@@ -158,24 +162,25 @@ def fix_graph(conllu_sentence, odin_sentence, is_basic):
else:
odin_sentence["graphs"] = {"universal-enhanced": {"edges": [], "roots": []}}

for iid, token in conllu_sentence.items():
if iid == 0:
for iid, token in enumerate(conllu_sentence):
if token.get_conllu_field("id").major == 0:
continue

if is_basic:
if token.get_conllu_field("deprel").lower().startswith("root"):
odin_sentence["graphs"]["universal-basic"]["roots"].append(iid - 1)
odin_sentence["graphs"]["universal-basic"]["roots"].append(iid)
else:
odin_sentence["graphs"]["universal-basic"]["edges"].append(
{"source": token.get_conllu_field("head") - 1, "destination": iid - 1,
{"source": token.get_conllu_field("head").major - 1, "destination": iid,
"relation": token.get_conllu_field("deprel")})
else:
for head, rel in token.get_new_relations():
if rel.lower().startswith("root"):
odin_sentence["graphs"]["universal-enhanced"]["roots"].append(iid - 1)
else:
odin_sentence["graphs"]["universal-enhanced"]["edges"].append(
{"source": head.get_conllu_field("id") - 1, "destination": iid - 1, "relation": rel})
for head, rels in token.get_new_relations():
for rel in rels:
if rel.to_str().lower().startswith("root"):
odin_sentence["graphs"]["universal-enhanced"]["roots"].append(iid)
else:
odin_sentence["graphs"]["universal-enhanced"]["edges"].append(
{"source": head.get_conllu_field("id").major - 1, "destination": iid, "relation": rel.to_str()})

return odin_sentence
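
Not part of the diff: in the enhanced branch, `get_new_relations()` now yields a head together with a list of relation objects, and labels are rendered with `to_str()`; the destination is the dependent's 0-based position in the sentence list. A hedged illustration of the edge entry this loop emits, with invented token values.

    # Illustrative only: the values are made up; the dict shape follows the new fix_graph code.
    edge = {
        "source": 0,         # head.get_conllu_field("id").major - 1
        "destination": 2,    # the dependent's enumerate() index in the sentence list
        "relation": "nsubj"  # rel.to_str()
    }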

@@ -184,8 +189,8 @@ def append_odin(odin_sent, fixed_sentence, text):
cur_sent_text = text
cur_offset = 0

for node in list(fixed_sentence.values())[len(odin_sent['words']):]:
if node.get_conllu_field('id') == 0:
for node in fixed_sentence[len(odin_sent['words']):]:
if node.get_conllu_field('id').major == 0:
continue

if 'words' in odin_sent:
@@ -232,7 +237,7 @@ def conllu_to_odin(conllu_sentences, odin_to_enhance=None, is_basic=False, push_
fix_offsets(odin_to_enhance['sentences'][i], summed_offset)

# when added nodes appear fix sent
if any([round(iid) != iid for iid in conllu_sentence.keys()]):
if any([tok.get_conllu_field("id").minor != 0 for tok in conllu_sentence]):
fixed_sentence = fix_sentence(fixed_sentence, push_new_to_end)
if odin_to_enhance:
odin_to_enhance['sentences'][i], text, cur_offset = append_odin(odin_to_enhance['sentences'][i], fixed_sentence, text)
@@ -246,8 +251,8 @@ def conllu_to_odin(conllu_sentences, odin_to_enhance=None, is_basic=False, push_
fixed_sentences.append(fixed_sentence)
odin_sentences.append(fix_graph(
fixed_sentence, odin_to_enhance['sentences'][i] if odin_to_enhance else
{'words': [token.get_conllu_field("form") for token in fixed_sentence.values() if token.get_conllu_field("id") != 0],
'tags': [token.get_conllu_field("xpos") for token in fixed_sentence.values() if token.get_conllu_field("id") != 0]},
{'words': [token.get_conllu_field("form") for token in fixed_sentence if token.get_conllu_field("id").major != 0],
'tags': [token.get_conllu_field("xpos") for token in fixed_sentence if token.get_conllu_field("id").major != 0]},
is_basic))

if odin_to_enhance:
@@ -257,8 +262,8 @@ def conllu_to_odin(conllu_sentences, odin_to_enhance=None, is_basic=False, push_
else:
odin = {"documents": {"": {
"id": str(uuid.uuid4()),
"text": " ".join([token.get_conllu_field("form") for conllu_sentence in fixed_sentences for (_, token) in
(sorted(conllu_sentence.items()) if not push_new_to_end else conllu_sentence.items()) if token.get_conllu_field("id") != 0]),
"text": " ".join([token.get_conllu_field("form") for conllu_sentence in fixed_sentences for token in
(sorted(conllu_sentence) if not push_new_to_end else conllu_sentence) if token.get_conllu_field("id").major != 0]),
"sentences": odin_sentences
}}, "mentions": []}

Expand All @@ -274,7 +279,7 @@ def parsed_tacred_json(data):
sentence[i + 1] = Token(i + 1, t, t, p, p, "_", int(h), dep, "_", "_")
sentence[0] = Token(0, None, None, None, None, None, None, None, None, None)
add_basic_edges(sentence)
[child.remove_edge(rel, sentence[0]) for child, rel in sentence[0].get_children_with_rels()]
[child.remove_edge(rel, sentence[0]) for child, rels in sentence[0].get_children_with_rels() for rel in rels]
_ = sentence.pop(0)
sentences.append(sentence)

