forked from ethancaballero/first_stmn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
skip_thought_pruning.py
66 lines (52 loc) · 1.86 KB
/
skip_thought_pruning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from gensim.models import Word2Vec
import numpy
import re
import os
import theano
import theano.tensor as tensor
import cPickle as pkl
import numpy
import copy
import nltk
from collections import OrderedDict, defaultdict
from scipy.linalg import norm
from nltk.tokenize import word_tokenize
import skipthoughts
# TODO: find a way to iterate over each statement/question pair; see the "for line in f:" loop in theano.util.
def prune_thoughts(dataset, questions, input_dir):
    """Split a text file into sentences and run them through skip-thoughts.

    The file is read in full, every ASCII digit is stripped from it, and the
    remaining text is split into sentences with NLTK's English Punkt
    tokenizer. The sentences are then encoded with the skip-thoughts model.

    Args:
        dataset: Currently unused by this function.
        questions: Currently unused by this function.
        input_dir: Despite the name, this value is passed straight to
            ``open()``, so it must be the path of a readable text file.

    Returns:
        The sentences produced by the Punkt tokenizer from the digit-free
        text.
    """
    # A context manager guarantees the handle is closed even if read() fails;
    # the previous bare open() never closed the file.
    with open(input_dir) as source:
        text = source.read()

    # Remove all digits before sentence splitting.
    clean = re.sub(r"[0-9]", "", text)

    sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
    sents = sent_detector.tokenize(clean)

    # NOTE(review): the encoded vectors are computed but not consumed yet.
    # The calls are kept so the function's observable work is unchanged.
    # TODO: rank the sentences against `dataset`, e.g.
    #   skipthoughts.nn(model, sents, vectors, dataset, k=5)
    model = skipthoughts.load_model()
    vectors = skipthoughts.encode(model, sents)

    return sents
def prune_statements(dataset, questions, top_k=5):
    """Keep only the statements most similar to each question.

    A Word2Vec model is trained on ``dataset``. For every entry in
    ``questions``, the last element of ``entry[2]`` is treated as the query
    and all preceding elements as candidate statements. Each candidate is
    scored with ``n_similarity`` against the query, and ``entry[2]`` is
    replaced in place by the ``top_k`` best-scoring candidates, ordered from
    most to least similar. The replacement list holds only the kept
    candidates; the query element itself is not re-appended.

    Args:
        dataset: Training corpus handed directly to ``Word2Vec``.
        questions: Mutable sequence of entries whose index ``2`` holds a
            sequence of statements followed by the query. Modified in place.
        top_k: Number of highest-ranked statements to keep per entry.
            Defaults to 5, the previously hard-coded cutoff.

    Returns:
        The same ``questions`` object, after in-place pruning.
    """
    total_old = 0
    total_new = 0

    wvs = Word2Vec(dataset, min_count=0)

    for entry in questions:
        old_statements = entry[2][:-1]
        query = entry[2][-1]

        raw_scores = [wvs.n_similarity(query, statement)
                      for statement in old_statements]
        # Any result that is not exactly a numpy.float64 scalar is scored
        # as 0.0 so that the ranking below only compares plain scalars.
        scores = [score if type(score) is numpy.float64 else 0.0
                  for score in raw_scores]

        # Candidate indices ordered by descending similarity. Real lists are
        # built (rather than map()) because len() and indexing are required.
        ranked = sorted(range(len(scores)), key=scores.__getitem__,
                        reverse=True)
        new_statements = [old_statements[idx] for idx in ranked[:top_k]]

        entry[2] = new_statements
        total_old += len(old_statements)
        total_new += len(new_statements)

    print("Before %d After %d" % (total_old, total_new))
    return questions