In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import math
from collections import defaultdict, namedtuple
from dataclasses import dataclass, field
from enum import Enum
from operator import itemgetter

import nltk

In [3]:
def trigrams(text):
    for snt in nltk.tokenize.sent_tokenize(text):
        words = [word.lower() for word in nltk.tokenize.word_tokenize(snt) if word.isalpha()]
        for trigram in nltk.ngrams(words, 3):
            yield trigram

In [4]:
class Label(Enum):
    ADJ = "ADJ"
    ADV = "ADV"
    NOUN = "NOUN"
    PRO = "PRO"
    DET = "DET"
    CONJ = "CONJ"
    PREP = "PREP"
    VERB = "VERB"
    # VERB_PAST = "VERB_PAST"
    # VERB_PROG = "VERB_PROG"
    # NOUN_PLRL = "NOUN_PLRL"

> The 28 seed words used for English were the following - pronoun: you, we,
me; verb: come, play, put; preposition: on, out, in; determiner: this, these; noun: baby, car, train,
box, house, boy, man, book; adjective: big, silly, green; adverb: well, very, now; conjunction: and,
or, but.

In [5]:
SEEDS = [
    ("you", Label.PRO),
    ("we", Label.PRO),
    ("me", Label.PRO),
    ("come", Label.VERB),
    ("play", Label.VERB),
    ("put", Label.VERB),
    ("on", Label.PREP),
    ("out", Label.PREP),
    ("in", Label.PREP),
    ("this", Label.DET),
    ("these", Label.DET),
    ("baby", Label.NOUN),
    ("car", Label.NOUN),
    ("train", Label.NOUN),
    ("box", Label.NOUN),
    ("house", Label.NOUN),
    ("boy", Label.NOUN),
    ("man", Label.NOUN),
    ("book", Label.NOUN),
    ("big", Label.ADJ),
    ("silly", Label.ADJ),
    ("green", Label.ADJ),
    ("well", Label.ADV),
    ("very", Label.ADV),
    ("now", Label.ADV),
    ("and", Label.CONJ),
    ("or", Label.CONJ),
    ("but", Label.CONJ)
]

In [21]:
Frame = namedtuple("Frame", ["left", "right"])


@dataclass
class Model:
    frames: dict = field(default_factory=lambda: defaultdict(lambda: defaultdict(int)))  # frame -> label -> score
    lexicon: dict = field(default_factory=lambda: defaultdict(lambda: defaultdict(int))) # word  -> label -> score
    fthresh: int = field(default=15)
    wthresh: int = field(default=15)
    
    def __post_init__(self):
        for word, lbl in SEEDS:
            self.lexicon[word][lbl] = math.inf
    
    def train(self, text):
        for left, target, right in trigrams(text):
            frame = Frame(left, right)
            wlabel, flabel = self.wlabel(target), self.flabel(frame)
            # The target word is part of the trusted lexicon.
            if wlabel:
                # Update frame labels.
                bframe = self.best_frame(target, frame, wlabel)
                self.frames[bframe][wlabel] += 1
                for frm in self.applicable_frames(frame):
                    for lbl in self.frames[frm]:
                        if frm == bframe and lbl == wlabel:
                            continue
                        self.frames[frm][lbl] -= 1
                # Update frame labels.
                # self.frames[frame][wlabel] += 1
                # for lbl in self.frames[frame]:
                #     if lbl != wlabel:
                #         self.frames[frame][lbl] -= 1
            # The frame is a trusted context.
            if flabel:
                # Update word labels.
                self.lexicon[target][flabel] += 1
                for lbl in self.lexicon[target]:
                    if lbl != flabel:
                        self.lexicon[target][lbl] -= 0.75
                
    # TODO: This should not be taking the "max"        
    def wlabel(self, word):
        # Retrieve the highest scoring label for the word.
        if word not in self.lexicon:
            return None
        label, score = max(self.lexicon[word].items(), key=itemgetter(1))
        if score <= self.wthresh:
            return None
        return label
    
    def flabel(self, frame):
        if frame not in self.frames:
            # llbl, rlbl = self.wlabel(frame.left), self.wlabel(frame.right)
            # lfrm, rfrm = Frame(llbl, frame.right), Frame(frame.left, rlbl)
            # cfrm = Frame(llbl, rlbl)
            # if llbl and lfrm in self.lframes:
            #     return self.lframes[lfrm]
            # if rlbl and rfrm in self.rframes:
            #     return self.rframes[rfrm]
            # if llbl and rlbl and cfrm in self.cframes:
            #     return self.cframes[cfrm]
            return None
        label, score = max(self.frames[frame].items(), key=itemgetter(1))
        if score <= self.fthresh:
            return None
        return label
    
    def applicable_frames(self, ctx):
        ret = []
        for frm in self.frames:
            if isinstance(frm.left, Label) and self.wlabel(frm.left) != self.wlabel(ctx.left):
                continue
            if isinstance(frm.right, Label) and self.wlabel(frm.right) != self.wlabel(ctx.right):
                continue
            if isinstance(frm.left, str) and frm.left != ctx.left:
                continue
            if isinstance(frm.right, str) and frm.left != ctx.right:
                continue
            ret.append(frm)
        return ret
            
    def best_frame(self, target, ctx, label):
        # Look for lexical frames.
        for frm in self.frames:
            lbl = self.frames[frm]
            if frm == ctx and lbl == label:
                return frm
        for frm in self.applicable_frames(ctx):
            lbl = self.frames[frm]
            if isinstance(frm.left, Label) and isinstance(frm.right, str) and lbl == label:
                return frm
            if isinstance(frm.right, Label) and isinstance(frm.left, str) and lbl == label:
                return frm
        for frm in self.applicable_frames(ctx):
            lbl = self.frames[frm]
            if isinstance(frm.left, Label) and isinstance(frm.right, Label) and lbl == label:
                return frm
        return ctx
    
    def words(self):
        for word in self.lexicon:
            lbl = self.wlabel(word)
            if lbl:
                yield (word, lbl)

In [22]:
mdl = Model()
mdl.train(" ".join([" ".join(snt) for snt in nltk.corpus.brown.sents()]))

KeyboardInterrupt: 

In [11]:
set(mdl.words()) - set(SEEDS)

{('a', <Label.DET: 'DET'>),
 ('about', <Label.PREP: 'PREP'>),
 ('absence', <Label.NOUN: 'NOUN'>),
 ('across', <Label.PREP: 'PREP'>),
 ('after', <Label.PREP: 'PREP'>),
 ('against', <Label.PREP: 'PREP'>),
 ('all', <Label.PREP: 'PREP'>),
 ('also', <Label.PREP: 'PREP'>),
 ('among', <Label.PREP: 'PREP'>),
 ('amount', <Label.NOUN: 'NOUN'>),
 ('any', <Label.DET: 'DET'>),
 ('are', <Label.PREP: 'PREP'>),
 ('area', <Label.NOUN: 'NOUN'>),
 ('around', <Label.PREP: 'PREP'>),
 ('as', <Label.PREP: 'PREP'>),
 ('at', <Label.PREP: 'PREP'>),
 ('basis', <Label.NOUN: 'NOUN'>),
 ('be', <Label.PREP: 'PREP'>),
 ('because', <Label.PREP: 'PREP'>),
 ('before', <Label.PREP: 'PREP'>),
 ('bottle', <Label.NOUN: 'NOUN'>),
 ('bottom', <Label.NOUN: 'NOUN'>),
 ('by', <Label.PREP: 'PREP'>),
 ('called', <Label.PREP: 'PREP'>),
 ('can', <Label.PREP: 'PREP'>),
 ('center', <Label.NOUN: 'NOUN'>),
 ('corner', <Label.NOUN: 'NOUN'>),
 ('could', <Label.PREP: 'PREP'>),
 ('couple', <Label.NOUN: 'NOUN'>),
 ('day', <Label.NOUN: 'NOUN'

In [28]:
set(mdl.words()) - set(SEEDS)

{('is', <Label.PREP: 'PREP'>),
 ('long', <Label.ADV: 'ADV'>),
 ('much', <Label.ADV: 'ADV'>),
 ('not', <Label.PREP: 'PREP'>),
 ('of', <Label.PREP: 'PREP'>),
 ('that', <Label.PREP: 'PREP'>),
 ('to', <Label.PREP: 'PREP'>)}