In [1]:
import nltk
import pickle
import pandas as pd
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [2]:
pTrees = nltk.corpus.treebank.parsed_sents()
%time len(pTrees) #hm... so counting through all of the trees takes about 5 seconds, but this isn't even touching any part of the tree.
pTrees[0].pretty_print()

Wall time: 6.11 s


3914

                                                     S                                                                         
                         ____________________________|_______________________________________________________________________   
                        |                                               VP                                                   | 
                        |                        _______________________|___                                                 |  
                      NP-SBJ                    |                           VP                                               | 
         _______________|___________________    |     ______________________|______________________________________          |  
        |          |              ADJP      |   |    |        |                PP-CLR                              |         | 
        |          |           ____|____    |   |    |        |          ________|_________          

In [3]:
pTreeProductions = []
for tree in pTrees:
    pTreeProductions.append(tree.productions())

In [4]:
len(pTreeProductions)
pTreeProductions[0]

3914

[S -> NP-SBJ VP .,
 NP-SBJ -> NP , ADJP ,,
 NP -> NNP NNP,
 NNP -> 'Pierre',
 NNP -> 'Vinken',
 , -> ',',
 ADJP -> NP JJ,
 NP -> CD NNS,
 CD -> '61',
 NNS -> 'years',
 JJ -> 'old',
 , -> ',',
 VP -> MD VP,
 MD -> 'will',
 VP -> VB NP PP-CLR NP-TMP,
 VB -> 'join',
 NP -> DT NN,
 DT -> 'the',
 NN -> 'board',
 PP-CLR -> IN NP,
 IN -> 'as',
 NP -> DT JJ NN,
 DT -> 'a',
 JJ -> 'nonexecutive',
 NN -> 'director',
 NP-TMP -> NNP CD,
 NNP -> 'Nov.',
 CD -> '29',
 . -> '.']

In [5]:
def getAllInstances(wordType):
    instances = []
    for tree in pTreeProductions:
        for rule in tree:
            text = str(rule)
            left, right = text.split(" -> ")
            #print("Left, Rule:")
            #print(left)
            #print(rule)
            #print()
            if(left == wordType):
                instances.append(rule)
    return instances

In [6]:
adjs = getAllInstances('JJ')
len(adjs)

5834

In [7]:
file = open('cats.pkl','rb') 
cats = pickle.load(file)

In [8]:
for cat in cats:
    if(cat == 'NN'):
        print(cat)

NN


In [9]:
cats = list(cats)
type(cats)

list

In [10]:
cat_lists = [] #list of lists
for cat in sorted(cats):
    print("Loading all instances of: ", cat)
    %time cat_lists.append(getAllInstances(cat))

Loading all instances of:  ,
Wall time: 1.23 s
Loading all instances of:  .
Wall time: 1.22 s
Loading all instances of:  ADJP
Wall time: 1.3 s
Loading all instances of:  ADVP
Wall time: 1.34 s
Loading all instances of:  CC
Wall time: 1.29 s
Loading all instances of:  CD
Wall time: 1.26 s
Loading all instances of:  DT
Wall time: 1.33 s
Loading all instances of:  EX
Wall time: 1.21 s
Loading all instances of:  FRAG
Wall time: 1.24 s
Loading all instances of:  IN
Wall time: 1.28 s
Loading all instances of:  JJ
Wall time: 1.69 s
Loading all instances of:  JJR
Wall time: 1.26 s
Loading all instances of:  JJS
Wall time: 1.23 s
Loading all instances of:  MD
Wall time: 1.24 s
Loading all instances of:  NN
Wall time: 1.3 s
Loading all instances of:  NNP
Wall time: 1.26 s
Loading all instances of:  NNPS
Wall time: 1.23 s
Loading all instances of:  NNS
Wall time: 1.26 s
Loading all instances of:  NP
Wall time: 1.29 s
Loading all instances of:  NP-TMP
Wall time: 1.32 s
Loading all instances of:  P

In [11]:
cat_lists[5][0]

CD -> '61'

In [12]:
def getRuleProb(wordType, key, isLeaf):
    tagIndex = 0
    for tag in cat_lists:
        try: #needs this try/except because there's one tag that's just empty, the line below errors out.
            raw = str(cat_lists[tagIndex][0])
        except:
            continue
        if(wordType == raw.split(" -> ")[0]): #now at this point, we're on the list of entries for the tag. count hits/total
            hits = 0 #^ this must be true for 'hits' to initialize, if it isn't then the tag is not in the sample
            total = 0
            for rule in cat_lists[tagIndex]:
                #print(rule)
                text = str(rule)
                left, right = text.split(" -> ")
                #print(left, right)
                if(isLeaf):
                    if(key == right[1:-1] or key == right[1:-1].lower()):
                        hits += 1
                else:
                        if(key == right):
                            hits += 1
                total += 1
        tagIndex += 1
    try:
        print("Hits: ", hits, "\nTotal: ", total)
    except:
        print("This tag was not found in the sample of the Penn Treebank.") #FW, LS, POS, UH, RP, SYM
        return
    return hits/total

In [13]:
getRuleProb('NN', 'money', True)

Hits:  56 
Total:  13166


0.00425337991797053

In [14]:
getRuleProb('NP', 'DT JJ NN', False)

Hits:  740 
Total:  23724


0.031192041814196596

***Getting Surprisal***

In [15]:
from nltk import StanfordPOSTagger
from nltk.parse import stanford
from nltk.parse import CoreNLPParser
parser = CoreNLPParser(url='http://localhost:9000')

In [37]:
foo = parser.raw_parse("The cow jumped over the moon.")

In [38]:
for tree in foo:
    foo = tree
    tree.pretty_print()
    rules = tree.productions()

                   ROOT                     
                    |                        
                    S                       
      ______________|_____________________   
     |                   VP               | 
     |         __________|___             |  
     |        |              PP           | 
     |        |      ________|___         |  
     NP       |     |            NP       | 
  ___|___     |     |         ___|___     |  
 DT      NN  VBD    IN       DT      NN   . 
 |       |    |     |        |       |    |  
The     cow jumped over     the     moon  . 



In [53]:
getRuleProb("ROOT", "S", False)
getRuleProb("S", "NP VP .", False)

This tag was not found in the sample of the Penn Treebank.
This tag was not found in the sample of the Penn Treebank.


In [69]:
cat_lists[29][:10]
len(cat_lists[31])
sRules = []
for rule in cat_lists[31]:
    text = str(rule)
    if 'NP' in text:
        sRules.append(text)
len(sRules)
sRules[:10]

[]

8650

7857

['S -> NP-SBJ VP .',
 'S -> NP-SBJ VP .',
 'S -> NP-SBJ-1 VP .',
 'S -> NP-SBJ NP-PRD',
 'S -> S-TPC-1 , NP-SBJ VP .',
 'S -> S-TPC-2 , NP-SBJ VP .',
 'S -> NP-SBJ VP',
 'S -> NP-SBJ VP',
 'S -> NP-SBJ VP .',
 'S -> NP-SBJ VP']

In [52]:
len(rules)
rules[:5]
point = getRuleProb("S", "NP VP .", False) + getRuleProb("NP", "DT NN DT", False) + getRuleProb("DT", "The", True)
given = point + getRuleProb("NN", "cow", True)

13

[ROOT -> S, S -> NP VP ., NP -> DT NN, DT -> 'The', NN -> 'cow']

This tag was not found in the sample of the Penn Treebank.
Hits:  0 
Total:  23724


TypeError: unsupported operand type(s) for +: 'NoneType' and 'float'

In [27]:
point = getRuleProb("DT", "the", True)
given = getRuleProb("IN", "over", True)
surprisal = point/given
surprisal

Hits:  4753 
Total:  8165
Hits:  75 
Total:  9857


76.50593345580731