#!/usr/bin/env python2
__author__ = 'chetan'
import os, sys
import errno
import argparse
import traceback
import re, operator
import subprocess
import csv
import math
import datetime
#
# Description of the vw_lda wrapper
#
# Input dataset: accepts two kinds of inputs for datasets
# 0) list of files separated by spaces
# 1) directory containing files and directories
# ...
# The script executes the following steps:
# 0. Read the input file(s) to prepare the dataset
# 1. Convert the input dataset to vw format
# 2. Invoke vw with lda
# 3. Convert the vw format output to a human readable format
# 4. Write the output to file
# 5. Print partial results on the console
# 6. The following intermediate files are generated:
#    a. lda_ip.vw : vw input file in word:count format
#    b. hashed_lda_ip.vw : lda_ip.vw with each word replaced by its hash value
#    c. hash_values.csv : integer-to-word mapping
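#
# Example of the intermediate formats (hypothetical two-word document):
#   lda_ip.vw        : | cat:3 dog
#   hashed_lda_ip.vw : | 0:3 1
#   hash_values.csv  : 0,cat
#                      1,dog
#
# Example invocation (hypothetical script name and paths):
#   python vw_lda.py -s ./docs -t 20 -n 1000 -o topics.txt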
DEFAULT_NUM_TOPICS = 100
def elapsed_time_str(e):
    m, s = divmod(int(round(e.total_seconds())), 60)
    return "%02d:%02d (MM:SS)" % (m, s)
class vw_lda(object):
def __init__(self):
self.fcount = 0
self.hash = {}
self.count = 0
self.new_hash = {}
self.num_topics = 0
self.op_file = ""
self.word_list = []
self.lda_ipfile = "lda_ip.vw"
self.hashed_lda_ipfile = "hashed_lda_ip.vw"
self.hashed_lda_model = ""
self.hash_values = "hash_values.csv"
self.max_file_count = 10000
def create_hash(self, doc):
'''
Creates hash value for each unseen word
'''
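        # Example: for the line "| cat:3 dog", the unseen words 'cat' and
        # 'dog' are assigned the next free integer ids (0 and 1 on an empty hash).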
        words = re.split("\s+", doc)
        words = words[1:]
        for w in words:
            wrd_lst = re.split(":", w)
            #empty words to be ignored
            if wrd_lst[0] == '':
                continue
            if wrd_lst[0] not in self.hash:
                self.hash[wrd_lst[0]] = self.count
                self.count += 1
def calc_bits_required(self):
'''
        Calculates the value of b (vw's bit precision) based on the unique word count
'''
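        # Worked example: 5000 unique words need ceil(log2(5000)) = 13 bits
        # so that every hashed word id fits in vw's feature table.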
        return int(math.ceil(math.log(max(self.count, 2), 2)))
def generate_hash(self, f):
'''
        reads the input file (in vw format, one document per line)
        and assigns a hash value to each unseen word
'''
        with open(f, "r") as input_file:
            for line in input_file:
                self.create_hash(line)
        self.print_hash()
return
def transform_inputs(self,f):
'''
        rewrites the vw file from word:count format to hash_value:count format
'''
        with open(self.hashed_lda_ipfile, "w") as out_file, \
                open(f, "r") as input_file:
            for line in input_file:
                if line != '\n':
                    out_file.write(self.get_hashed_line(line))
return
def get_hashed_line(self, line):
'''
does a lookup for each word to get its corresponding hash value
'''
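        # Example (hypothetical hash): with self.hash = {'cat': 0, 'dog': 1},
        # the line "| cat:3 dog" is rewritten as "| 0:3 1\n".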
words = re.split("\s+", line)
words = words[1:]
op = []
for w in words:
wrd_lst = re.split(":",w)
wrd = wrd_lst[0]
if(wrd == ''):
continue
            if len(wrd_lst) > 1 and int(wrd_lst[1]) > 1:
                v = str(self.hash[wrd]) + ":" + wrd_lst[1]
            else:
                v = str(self.hash[wrd])
op.append(v)
op = "| " + " ".join(op) + "\n"
return op
def print_hash(self):
'''
prints the hash for each unique word to an
intermediate file in csv format
'''
sorted_x = sorted(self.hash.items(), key=operator.itemgetter(1))
        with open(self.hash_values, "w") as ofile:
            for l in sorted_x:
                ofile.write(str(l[1]) + "," + l[0] + "\n")
return
def convert_docs2lda(self,files):
        with open(self.lda_ipfile, "w") as lda_ip:
            for f in files:
                # build the argument list directly so filenames with spaces survive
                str2vw_list = ["vw-doc2lda", "-f", f]
                op = ''
                try:
                    op = subprocess.check_output(str2vw_list)
                except OSError as e:
                    if e.errno == errno.ENOENT:
                        print "ERROR: vw-doc2lda not found in the path"
                        sys.exit(1)
                    raise
                lda_ip.write(op + "\n")
return
def read_files_from_dir(self, src_dir, num_docs):
        if not num_docs:
            num_docs = self.max_file_count
        file_list = []
        for root, dirs, files in os.walk(src_dir):
            for f in files:
                if len(file_list) >= num_docs:
                    return file_list
                file_list.append(os.path.join(root, f))
        return file_list
def process_input_dataset(self, src_dir, file_list, num_docs):
        if not file_list:
file_list = self.read_files_from_dir(src_dir, num_docs)
else:
del file_list[num_docs:]
self.convert_docs2lda(file_list)
return
def run_vw(self, args):
"""
runs vw with required inputs, output of vw is dumped onto the console
vw -d hashed_lda_ip.vw -b 8 --lda 2 --lda_D 5
--readable_model lda.model.vw
"""
        b = self.calc_bits_required()
        vw_cmd = ["vw", "-d", self.hashed_lda_ipfile, "-b", str(b),
                  "--lda", str(args.lda), "--lda_D", str(args.lda_d),
                  "--readable_model", self.hashed_lda_model,
                  "--lda_alpha", str(args.lda_alpha),
                  "--lda_epsilon", str(args.lda_epsilon),
                  "--lda_rho", str(args.lda_rho)]
        print "vw_cmd_line : ", " ".join(vw_cmd)
        try:
            subprocess.check_output(vw_cmd)
        except OSError as e:
            if e.errno == errno.ENOENT:
                print "ERROR: vw not found in the path"
                sys.exit(1)
            raise
return
def process_args(self,args):
'''
reads command line arguments, transforms inputs,
runs vw and transforms outputs
'''
self.op_file = args.out_file
self.hashed_lda_model = args.rd_model
        if args.lda is None:
            print "WARNING: -t <number_of_topics> not given, defaulting to %d" % DEFAULT_NUM_TOPICS
            args.lda = DEFAULT_NUM_TOPICS
self.num_topics = args.lda
        if args.lda_d is None or args.lda_d == 0:
            args.lda_d = self.max_file_count

        def timed_step(msg, step, *step_args):
            '''runs one pipeline step and reports its wall-clock time'''
            print msg
            start = datetime.datetime.now()
            step(*step_args)
            print "Completed in " + elapsed_time_str(datetime.datetime.now() - start)

        timed_step('reading dataset...', self.process_input_dataset,
                   args.src_dir, args.flist, args.lda_d)
        timed_step('creating hash...', self.generate_hash, self.lda_ipfile)
        timed_step('transforming inputs...', self.transform_inputs, self.lda_ipfile)
        timed_step('running vw...', self.run_vw, args)
        timed_step('transforming outputs...', self.transform_outputs, args.max_terms)
return
def populate_hash(self):
'''
        builds the reverse (hash value -> word) mapping used to transform the output
'''
        self.new_hash = {}
        self.max_word_count = 0
        with open(self.hash_values, "r") as hfile:
            reader = csv.reader(hfile, delimiter=",")
            for s in reader:
                self.new_hash[s[0]] = s[1]
                self.max_word_count = int(s[0])
return
def transform_outputs(self, max_terms):
'''
Parse the vw lda output to map integers back to words
'''
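        # The readable model is expected to hold vw header lines followed by a
        # line starting with "Checksum"; each subsequent line is assumed to be
        # "<feature_index> <topic_1_weight> ... <topic_K_weight>".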
        self.populate_hash()
        p = re.compile("^Checksum")
        got_data = False
        topics = []
        num_topics = 0
        with open(self.hashed_lda_model, "r") as model_file:
            for l in model_file:
                if got_data:
                    word_list = re.split("\s+", l.strip())
                    if not word_list[0]:
                        continue
                    # stop once we are past the highest assigned hash value
                    if int(word_list[0]) > self.max_word_count:
                        break
                    top = {'word': self.new_hash[word_list[0]]}
                    num_topics = len(word_list) - 1
                    for i in range(1, num_topics + 1):
                        top["topic" + str(i)] = float(word_list[i])
                    topics.append(top)
                elif p.match(l):
                    got_data = True
        #print topics on the console and write them to file
self.print_topics_summary(topics, num_topics, max_terms)
return
def print_topics_summary(self,topics, num_topics, top_n):
'''
For each topic, print the word and its corresponding weight
in descending order
'''
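        # Output sketch per topic (hypothetical words and weights):
        #   topic1
        #                  cat : 0.420000
        #                  dog : 0.310000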
        print "Topics are..."
        with open(self.op_file, "w") as op_file:
            for i in range(1, num_topics + 1):
                topic = "topic" + str(i)
                sorted_topics = sorted(topics, key=lambda x: x[topic], reverse=True)
                print topic
                op_file.write(topic + "\n")
                count = 0
                for top in sorted_topics:
                    l = "%20s : %5f" % (top['word'], top[topic])
                    op_file.write(l + "\n")
                    if count < top_n:
                        print l
                    count += 1
                print
                op_file.write("\n")
return
def print_sorted_dict(self, d, top_n):
'''
        helper function to print the top_n entries of a dictionary in descending value order
'''
sorted_x = sorted(d.items(), key=operator.itemgetter(1), reverse=True)
count = 0
for x in sorted_x:
            if count >= top_n:
break
print("%20s : %5f" % (x[0], float(x[1])))
count += 1
return sorted_x
def main(argv):
'''
read command line arguments and process them
'''
#mandatory arguments
helper = argparse.ArgumentParser(description='Inputs for using vw with lda')
group = helper.add_mutually_exclusive_group(required=True)
    group.add_argument('-f', '--file_list', dest='flist', nargs='+', type=str,
                       required=False, metavar='', default=[],
                       help='List of input files separated by spaces')
group.add_argument('-s','--src_dir', dest='src_dir', type=str,
required=False, default='.',
help='Directory containing input documents, defaults to current directory', metavar='')
#optional arguments
    helper.add_argument('-t', '--lda', dest='lda', type=int,
                        help='Run lda with <int> topics, defaults to ' + str(DEFAULT_NUM_TOPICS),
                        required=False, metavar='')
    helper.add_argument('-n', '--lda_D', dest='lda_d', type=int, required=False,
                        help='Number of documents; defaults to all files (up to 10000)', metavar='')
    helper.add_argument('-b', '--bit_precision', default=18, dest='b', type=int,
                        help='Number of bits in the feature table, defaults to 18 '
                             '(note: the script recomputes this from the vocabulary size)',
                        metavar='')
    helper.add_argument('-a', '--lda_alpha', dest='lda_alpha', type=float,
                        required=False, default=0.1,
                        help='Prior on sparsity of per-document topic weights, defaults to 0.1',
                        metavar='')
helper.add_argument('-r', '--lda_rho', dest='lda_rho', type=float,
required=False, default=0.1,
help='Prior on sparsity of topic distributions, defaults to 0.1',
metavar='')
helper.add_argument('-e', '--lda_epsilon', dest='lda_epsilon', type=float,
required=False, help='Loop convergence threshold, defaults to 0.1',
default = 0.1,
metavar='')
helper.add_argument('-o','--output_file', dest='out_file', type=str,
required=False, default="op.model", metavar='',
help='Output file which has the list of topics, defaults to op.model')
helper.add_argument('-m','--readable_model', dest='rd_model', type=str,
required=False, default="hashed_lda_model.vw", metavar='',
help='Output human readable final regressor, defaults to hashed_lda_model.vw')
helper.add_argument('-mt','--max_terms', dest='max_terms', default=10,
type=int,required=False, metavar='',
help='Max terms per topic to print on the console, defaults to 10')
    args = None
try:
if(len(sys.argv) <= 1):
helper.print_help()
sys.exit(1)
args = helper.parse_args()
vw_lda().process_args(args)
    except Exception as e:
        print "Exception: " + str(e)
        traceback.print_exc()
        print "The arguments are", args
if __name__ == "__main__":
main(sys.argv[1:])