diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3d3920d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,107 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+
+*.bin
diff --git a/OpenGNN b/OpenGNN
index 5a4dc28..3c1229e 160000
--- a/OpenGNN
+++ b/OpenGNN
@@ -1 +1 @@
-Subproject commit 5a4dc2862dbc68e41e5168fca4c0beb34b7defbc
+Subproject commit 3c1229ef58c0d95fcbe58082e89eb9a2a2694011
diff --git a/README.md b/README.md
index 6c2fc5c..ab3c45e 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ ognn-build-vocab --no_pad_token --field_name edges --string_index 0 \
     /data/naturallanguage/cnn_dailymail/split/train/inputs.jsonl.gz
 ognn-build-vocab --with_sequence_tokens \
     --save_vocab /data/naturallanguage/cnn_dailymail/output.vocab \
-    /data/naturallanguage/cnn_dailymail/split/train/inputs.jsonl.gz
+    /data/naturallanguage/cnn_dailymail/split/train/targets.jsonl.gz
 ```
 
 Then run
diff --git a/parsers/__init__.py b/parsers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/parsers/naturallanguage/dmcnn/convert2graph.py b/parsers/naturallanguage/dmcnn/convert2graph.py
index c76d8b3..d33c437 100644
--- a/parsers/naturallanguage/dmcnn/convert2graph.py
+++ b/parsers/naturallanguage/dmcnn/convert2graph.py
@@ -15,13 +15,154 @@
 import pdb
 import codecs
 import traceback
-from typing import Callable, List
+from collections import OrderedDict
+from typing import Any, List, Optional, Tuple, Callable
 
 from docopt import docopt
 
-from data.utils import load_xml
-from parsers.naturallanguage.gigaword.loadgigacorpus import parse_sample
+from nltk.tree import Tree
+
+from parsers.naturallanguage.graphtextrepr import (DependencyEdge,
+                                                   GraphTextRepresentation,
+                                                   Token)
+from parsers.naturallanguage.textsummary import TextSummary
+from data.utils import load_xml_gz
+
+
+def parse_tree_to_sentence(parse_tree: str) -> List[str]:
+    return Tree.fromstring(parse_tree).leaves()
+
+
+def try_find_RB_span(tokens: List[str]) -> Optional[Tuple[int, int]]:
+    try:
+        lrb_idx = tokens.index('-LRB-')
+        rrb_idx = tokens.index('-RRB-')
+        if lrb_idx > rrb_idx:
+            return None  # Malformed title, parentheses misplaced
+        return (lrb_idx, rrb_idx + 1)
+    except ValueError:
+        return None
+
+
+def parse_sample(datapoint, provenance: str, headline_getter: Optional[Callable[[Any], List[str]]] = None) -> Optional[TextSummary]:
+    if headline_getter is None and (datapoint.get('HEADLINE') is None or len(datapoint['HEADLINE']) == 0):
+        return None
+    try:
+        if headline_getter is None:
+            headline_tokens = parse_tree_to_sentence(datapoint['HEADLINE'])
+        else:
+            headline_tokens = headline_getter(datapoint)
+        # Remove LRB-RRB chunks
+        rb_span = try_find_RB_span(headline_tokens)
+        while rb_span is not None:
+            headline_tokens = headline_tokens[:rb_span[0]] + headline_tokens[rb_span[1]:]
+            rb_span = try_find_RB_span(headline_tokens)
+        if len(headline_tokens) <= 1:
+            return None
+
+    except Exception as e:
+        print('Could not parse %s. Ignoring sample.' % datapoint.get('HEADLINE'))
+        print(e)
+        return None
+
+    if 'sentences' not in datapoint or datapoint['sentences'] is None:
+        return None
+
+    all_sentences = datapoint['sentences']['sentence']
+    if type(all_sentences) is not list:
+        all_sentences = [all_sentences]
+
+    tokenized_sentences = []  # type: List[List[Token]]
+    for sentence in all_sentences:
+        sentence_tokens = []
+        if type(sentence['tokens']['token']) is not list:
+            # May happen in single-word sentences
+            sentence['tokens']['token'] = [sentence['tokens']['token']]
+        for i, token in enumerate(sentence['tokens']['token']):
+            assert int(token['@id']) == i + 1
+            sentence_tokens.append(Token(word=token['word'], lemma=token['lemma'], pos_tag=token['POS']))
+        tokenized_sentences.append(sentence_tokens)
+
+    graph_text_representation = GraphTextRepresentation(tokenized_sentences, provenance=provenance)
+
+    # Add named entities, by finding consecutive annotations
+    for sentence_idx, sentence in enumerate(all_sentences):
+        sentence_tokens = sentence['tokens']['token']
+        for token_idx, token in enumerate(sentence_tokens):
+            if 'NER' not in token:
+                return None  # Ignore samples that don't have NER output.
+            if token['NER'] == 'O':
+                continue
+            if token_idx + 1 < len(sentence_tokens) - 1 and sentence_tokens[token_idx + 1]['NER'] != token['NER']:
+                # Create an entity that includes this token as the last one
+                before_start_token_idx = token_idx - 1
+                while before_start_token_idx > 0 and sentence_tokens[before_start_token_idx]['NER'] == token['NER']:
+                    before_start_token_idx -= 1
+                graph_text_representation.add_entity(token['NER'], sentence_idx, before_start_token_idx + 1, token_idx + 1)
+
+    def get_collapsed_dependencies(sentence):
+        if 'dependencies' not in sentence or sentence['dependencies'] is None:
+            return None
+        for dependencies in sentence['dependencies']:
+            if dependencies['@type'] == 'collapsed-dependencies':
+                return dependencies
+        return None
+
+    # Add dependencies
+    for sentence_idx, sentence in enumerate(all_sentences):
+        if ('collapsed-dependencies' not in sentence or sentence['collapsed-dependencies'] is None) and get_collapsed_dependencies(sentence) is None:
+            continue
+        if 'collapsed-dependencies' in sentence:
+            collapsed_deps = sentence['collapsed-dependencies']
+        else:
+            collapsed_deps = get_collapsed_dependencies(sentence)
+
+        if type(collapsed_deps['dep']) is not list:
+            collapsed_deps['dep'] = [collapsed_deps['dep']]
+        for dependency in collapsed_deps['dep']:
+            if dependency['@type'] == 'root':
+                continue  # Root is not useful for us
+            dependency_type = dependency['@type']
+            underscore_location = dependency_type.find('_')
+            if underscore_location != -1:
+                dependency_type = dependency_type[:underscore_location]
+            if isinstance(dependency['dependent'], OrderedDict):
+                dependency['dependent'] = dependency['dependent']['@idx']
+            if isinstance(dependency['governor'], OrderedDict):
+                dependency['governor'] = dependency['governor']['@idx']
+
+            graph_text_representation.add_dependency_edge(DependencyEdge(
+                dependency_type=dependency_type,
+                sentence_idx=sentence_idx,
+                from_idx=int(dependency['dependent']) - 1,
+                to_idx=int(dependency['governor']) - 1
+            ))
+
+    # Add co-references
+    coreferences = None
+    if 'coreferences' in datapoint and datapoint['coreferences'] is not None:
+        coreferences = datapoint['coreferences']
+    elif 'coreference' in datapoint and datapoint['coreference'] is not None:
+        coreferences = datapoint['coreference']
+
+    if coreferences is not None:
+        if type(coreferences['coreference']) is not list:
+            coreferences['coreference'] = [coreferences['coreference']]
+        for coreference in coreferences['coreference']:
+            all_mentions = coreference['mention']
+            representative = [m for m in all_mentions if '@representative' in m and m['@representative'] == 'true'][0]
+
+            for mention in all_mentions:
+                if mention.get('@representative') == 'true' or (mention['sentence'] == representative['sentence'] and mention['head'] == representative['head']):
+                    continue
+                graph_text_representation.add_coreference(int(mention['sentence']) - 1, int(mention['head']) - 1,
+                                                          int(representative['sentence']) - 1, int(representative['head']) - 1)
+
+    return TextSummary(
+        summary_sentence=headline_tokens,
+        main_text=graph_text_representation
+    )
 
 
 def parse_cnndm_file(filename: str, write_sample_callback: Callable, summaries_folder: str) -> None:
     def process_sample(location, sample):
diff --git a/parsers/sourcecode/barone/ast_graph_generator.py b/parsers/sourcecode/barone/ast_graph_generator.py
index a9e6941..2639fb5 100644
--- a/parsers/sourcecode/barone/ast_graph_generator.py
+++ b/parsers/sourcecode/barone/ast_graph_generator.py
@@ -638,7 +638,6 @@ def visit_Raise(self, node):
         self.terminal('raise')
         if hasattr(node, 'exc') and node.exc is not None:
-            self.terminal(' ')
             self.visit(node.exc)
             if node.cause is not None:
                 self.terminal('from')
@@ -792,8 +791,6 @@ def visit_UnaryOp(self, node):
         self.terminal('(')
         op = UNARYOP_SYMBOLS[type(node.op)]
         self.terminal(op)
-        if op == 'not':
-            self.terminal(' ')
         self.visit(node.operand)
         self.terminal(')')
         self.parent = gparent
@@ -938,7 +935,6 @@ def visit_excepthandler(self, node):
         self.terminal('except')
         if node.type is not None:
-            self.terminal(' ')
             self.visit(node.type)
             if node.name is not None:
                 self.terminal(' as ')
diff --git a/train_and_eval.py b/train_and_eval.py
index f9ff6ef..06eaea5 100644
--- a/train_and_eval.py
+++ b/train_and_eval.py
@@ -286,7 +286,9 @@ def train_and_eval(model, args):
             worse_epochs = 0
             print("saving best model...")
            saver.save(session, os.path.join(args.checkpoint_dir, "best.ckpt"))
-            worse_epochs += 1
+        else:
+            worse_epochs += 1
+
         # and stop training if triggered patience
         if worse_epochs >= args.patience:
             print("early stopping triggered...")
@@ -424,8 +426,7 @@ def infer(model, args):
         mode=tf.estimator.ModeKeys.PREDICT,
         batch_size=args.batch_size,
         metadata=metadata,
-        features_file=args.train_source_file,
-        labels_file=args.train_target_file,
+        features_file=args.infer_source_file,
         features_bucket_width=args.bucket_width,
         sample_buffer_size=args.sample_buffer_size)
     session_config = tf.ConfigProto(
@@ -509,11 +510,11 @@ def build_optimizer(args):
     optimizer_class = getattr(tf.train, optimizer, None)
     if optimizer_class is None:
         raise ValueError("Unsupported optimizer %s" % optimizer)
-
+    kwargs = {}
     # TODO: optimizer params
     # optimizer_params = params.get("optimizer_params", {})
-    def optimizer(lr): return optimizer_class(lr)  # **optimizer_params)
+    def optimizer(lr): return optimizer_class(lr, **kwargs)
 
     learning_rate = args.learning_rate
     if args.lr_decay_rate is not None: