import logging
import sys

from collections import defaultdict


def get_windowed(seq, window_length=16, overlap=8):
    """Split *seq* into windows of at most *window_length* elements.

    A buffer is filled element by element.  Every time it reaches
    *window_length* it is emitted as a tuple and its first ``overlap + 1``
    elements are dropped, so consecutive full windows share
    ``window_length - overlap - 1`` elements (``overlap=window_length - 1``
    therefore yields disjoint windows).  Elements left over after the last
    full window are emitted as a final, shorter tuple.

    :param seq: sized, iterable sequence to split
    :param window_length: number of elements in a full window
    :param overlap: one less than the number of leading elements discarded
        after each full window
    :return: ``[seq]`` (the input object itself, not a tuple) when it already
        fits into one window, otherwise a list of tuples
    """
    # Base case: return whatever we have
    if len(seq) <= window_length:
        return [seq]

    # Elements carried over from one full window into the next.  Clamped at
    # zero: when overlap + 1 exceeds window_length the slice below simply
    # empties the buffer, and an unclamped (negative) value would make the
    # trailing check append a spurious empty window.
    carried = max(window_length - overlap - 1, 0)

    buf = []
    ret = []
    for item in seq:
        buf.append(item)
        if len(buf) == window_length:
            ret.append(tuple(buf))
            buf = buf[overlap + 1:]

    # Only emit the remainder if something new arrived since the last window
    if len(buf) != carried:
        ret.append(tuple(buf))

    return ret


def _iter_conll(path):
    """Yield ``(word, tag)`` for every token line and ``None`` for every blank line."""
    with open(path, 'r') as fin:
        for line in fin:
            line = line.strip()
            if len(line) == 0:
                yield None
            else:
                word, tag = line.split('\t')
                yield word, tag


def _assign_ids(mapping, keys):
    """Give every unseen key the next id, ``len(mapping) + 1``; return *mapping*."""
    for key in keys:
        if key not in mapping:
            mapping[key] = len(mapping) + 1
    return mapping


def build_character_dictionary(path, chars=None):
    """Map every character of every word in *path* to a positive integer id.

    :param path: tab-separated ``word<TAB>tag`` file, blank lines are ignored
    :param chars: existing dictionary to extend in place; a fresh one is
        created when omitted
    :return: the (possibly extended) dictionary
    """
    chars = {} if chars is None else chars
    for entry in _iter_conll(path):
        if entry is not None:
            _assign_ids(chars, entry[0])
    return chars


def build_word_dictionary(path, words=None):
    """Map every word in *path* to a positive integer id.

    :param path: tab-separated ``word<TAB>tag`` file, blank lines are ignored
    :param words: existing dictionary to extend in place; a fresh one is
        created when omitted
    :return: the (possibly extended) dictionary
    """
    words = {} if words is None else words
    return _assign_ids(
        words, (entry[0] for entry in _iter_conll(path) if entry is not None))


def build_tag_dictionary(path, tags=None):
    """Map every tag in *path* to a positive integer id.

    :param path: tab-separated ``word<TAB>tag`` file, blank lines are ignored
    :param tags: existing dictionary to extend in place; a fresh one is
        created when omitted
    :return: the (possibly extended) dictionary
    """
    tags = {} if tags is None else tags
    return _assign_ids(
        tags, (entry[1] for entry in _iter_conll(path) if entry is not None))


def load_pos_tagged_data(path, chardict=None, worddict=None, posdict=None,
                         allow_append=True, window_length=16, overlap=15):
    """Load a tagged file and encode it as windows of integer ids.

    Blank lines separate tweets.  Each tweet is cut into windows of tokens
    with :func:`get_windowed`; every window becomes one entry in each of the
    three returned lists.

    :param path: tab-separated ``word<TAB>tag`` file
    :param chardict: character -> id dictionary (fresh one when omitted)
    :param worddict: word -> id dictionary (fresh one when omitted)
    :param posdict: tag -> id dictionary (fresh one when omitted)
    :param allow_append: when true, the three dictionaries are first extended
        in place with everything found in *path*
    :param window_length: passed to :func:`get_windowed`
    :param overlap: passed to :func:`get_windowed`
    :return: ``(chars, words, labels)``; ``chars`` holds one id per character
        with a ``0`` after every token, ``words`` holds the token's word id
        repeated at each of those positions, ``labels`` holds one tag id per
        token.  Anything missing from a dictionary is encoded as ``0``.
    """
    # Never share dictionaries between calls through default arguments.  An
    # explicit (even empty) dictionary is kept so callers see it populated.
    chardict = {} if chardict is None else chardict
    worddict = {} if worddict is None else worddict
    posdict = {} if posdict is None else posdict

    if allow_append:
        build_character_dictionary(path, chardict)
        build_word_dictionary(path, worddict)
        build_tag_dictionary(path, posdict)

    tweets = defaultdict(list)
    tweetidx = 0
    for entry in _iter_conll(path):
        # Looked up for blank lines too, so a run of blank lines still
        # registers an (empty) tweet exactly as before.
        tokens = tweets[tweetidx]
        if entry is None:
            # Tweet boundary
            tweetidx += 1
            continue

        word, pos = entry
        # A space is appended to every word as a separator and always
        # encoded as 0, even if the dictionary happens to contain ' '.
        char_ids = [0 if c == ' ' else chardict.get(c, 0) for c in word + ' ']
        # NOTE(review): the word id is repeated once per character position
        # so chars and words stay aligned - indentation was reconstructed
        # from the patch context, confirm against the repository.
        word_ids = [worddict.get(word, 0)] * len(char_ids)
        label_ids = [posdict.get(pos, 0)]
        tokens.append((char_ids, word_ids, label_ids))

    words, chars, labels = [], [], []
    # sorted(): keep file order without relying on dictionary ordering
    for idx in sorted(tweets):
        # tweets[idx] is a real list, so get_windowed can take its len()
        for window in get_windowed(tweets[idx], window_length, overlap):
            window_chars, window_words, window_labels = [], [], []
            for cs, ws, ls in window:
                window_chars.extend(cs)
                window_words.extend(ws)
                window_labels.extend(ls)
            words.append(window_words)
            chars.append(window_chars)
            labels.append(window_labels)
    return chars, words, labels