# In [26]:
#! /usr/bin/python

__author__="Alexander Rush <srush@csail.mit.edu>"
__date__ ="$Sep 12, 2012"

import sys, json
from collections import defaultdict

"""
Count rule frequencies in a binarized CFG and rewrite the corpus with
rare words (seen fewer than 5 times) replaced by the _RARE_ token.
"""
# 

class Counts:
  """
  Rule and word frequency counts for a corpus of binarized CFG trees.

  Trees are the nested lists produced by json.loads: a binary rule is
  [symbol, left_subtree, right_subtree] and a unary rule is [symbol, word].
  """

  def __init__(self):
    # (symbol, word) -> number of times the unary rule was seen.
    self.unary = {}
    # (symbol, left_symbol, right_symbol) -> number of times the binary rule was seen.
    self.binary = {}
    # symbol -> number of times the non-terminal was seen.
    self.nonterm = {}
    # word -> total occurrences over all unary rules (filled by check_rare).
    self.count_word = defaultdict(int)
    # Words whose total count is below the rare threshold (filled by check_rare).
    self.rare = set()

  def show(self):
    """
    Print every non-terminal, unary rule and binary rule with its count.
    """
    for symbol, count in self.nonterm.items():
      print(count, "NONTERMINAL", symbol)

    for (sym, word), count in self.unary.items():
      print(count, "UNARYRULE", sym, word)

    for (sym, y1, y2), count in self.binary.items():
      print(count, "BINARYRULE", sym, y1, y2)

  def count(self, tree):
    """
    Count the frequencies of non-terminals and rules in the tree.

    A string (bare terminal) is ignored. Nodes whose length is neither 2
    nor 3 only contribute to the non-terminal count.
    """
    if isinstance(tree, str):
      return

    # Count the non-terminal symbol.
    symbol = tree[0]
    self.nonterm[symbol] = self.nonterm.get(symbol, 0) + 1

    if len(tree) == 3:
      # Binary rule: keyed by the parent and the two child symbols.
      key = (symbol, tree[1][0], tree[2][0])
      self.binary[key] = self.binary.get(key, 0) + 1

      # Recursively count the children.
      self.count(tree[1])
      self.count(tree[2])
    elif len(tree) == 2:
      # Unary rule: keyed by the parent and the emitted word.
      key = (symbol, tree[1])
      self.unary[key] = self.unary.get(key, 0) + 1

  def check_rare(self, threshold=5):
    """
    Rebuild the per-word totals and the set of rare words.

    A word is rare when its total count over all unary rules is strictly
    below `threshold`. Both containers are recomputed from self.unary on
    every call, so calling this more than once does not inflate the counts.
    """
    # Clear in place so external references to these containers stay valid.
    self.count_word.clear()
    self.rare.clear()

    for (_sym, word), count in self.unary.items():
      self.count_word[word] += count

    self.rare.update(
      word for word, total in self.count_word.items() if total < threshold
    )

  def replace(self, infile, outfile):
    """
    Copy the trees in `infile` to `outfile` with rare words replaced.

    `infile` holds one JSON tree per line; blank lines are skipped. Each
    tree is written back as compact JSON on its own line. Prints 'Done'
    when finished.
    """
    # Context managers close both files even if a line fails to parse.
    with open(infile) as source, open(outfile, 'w') as output:
      for line in source:
        if not line.strip():
          continue
        tree = self.replace_word(json.loads(line))
        # (',', ':') keeps list output identical to the previous
        # (',', ',') while remaining valid JSON if an object ever appears.
        json.dump(tree, output, separators=(',', ':'))
        output.write('\n')
    print('Done')

  def replace_word(self, tree):
    """
    Replace every rare word in the tree with '_RARE_'.

    The tree is modified in place and also returned. A bare string is
    returned unchanged.
    """
    if isinstance(tree, str):
      return tree

    if len(tree) == 3:
      # Binary rule: only the subtrees can contain words.
      self.replace_word(tree[1])
      self.replace_word(tree[2])
    elif len(tree) == 2:
      # Unary rule: the second element is the word itself.
      word = tree[1]
      # Lists/dicts are unhashable and can never be members of the rare set.
      if not isinstance(word, (list, dict)) and word in self.rare:
        tree[1] = '_RARE_'
    return tree
    
    
        
def main(parse_file,replace_file):
  """
  Count the rules in `parse_file` and write a rare-word-replaced copy.

  `parse_file` holds one JSON tree per line. After counting, the rare
  words are determined and the same corpus is rewritten to `replace_file`
  with those words replaced.
  """
  counter = Counts()
  # The context manager closes the corpus before it is reopened by replace().
  with open(parse_file) as trees:
    for line in trees:
      # json.loads rejects empty input, so skip blank lines.
      if not line.strip():
        continue
      counter.count(json.loads(line))

  counter.check_rare()
  counter.replace(parse_file, replace_file)
    

def usage():
    """Write the command-line usage message to standard error."""
    message = """
    Usage: python count_cfg_freq.py [tree_file]
        Print the counts of a corpus of trees.\n"""
    sys.stderr.write(message)

    
if __name__ == "__main__":
    # Input corpus (one JSON tree per line) and the output path for the
    # copy in which rare words have been replaced.
    parse_file, replace_file = 'parse_train.dat', 'replace.dat'

    main(parse_file, replace_file)
    # NOTE(review): this run only writes replace.dat; the cfg.replace.counts
    # file mentioned previously is presumably produced by a separate
    # counting step over replace.dat -- confirm.
  


# Output: Done
