-
Notifications
You must be signed in to change notification settings - Fork 24
/
tmp.py
90 lines (77 loc) · 2.74 KB
/
tmp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from base.corpus import *
import numpy as np
import chainer.links as L
import chainer.functions as F
# probability = np.array([0,1,2,1],dtype=np.int32)
# vv = np.array([[0,1,2,1],[0,1,2,1],[0,1,2,1],[0,1,2,1]],dtype=np.float32)
#
# slices=[]
# for i, v in enumerate(probability):
# slices.append((i,v))
# print slices[:2]
# print F.get_item(vv,slices[:2])
# indices = np.argsort([-p for p in probability]).astype(dtype=np.int32)
# results=[result[:2] for result in indices]
# print results
# n=0
# o=0
# with codecs.open('data/diginetica/digi_test.txt', encoding='utf-8') as f:
# for line in f:
# lines = line.strip('\n').strip('\r').split('\t')
# input = ast.literal_eval(lines[0])
# if int(lines[1]) not in input:
# n+=1
# else:
# o+=1
#
# print n/float(n+o),o/float(n+o),n,o
# Split data/yoo_1_4/test_1_over_4.txt into two files:
#   * test_1_over_4.repeat.txt    -- lines whose target already occurs in the input
#   * test_1_over_4.nonrepeat.txt -- all remaining lines
# Each line is tab-separated: field 0 is a Python literal holding the input
# items, field 1 is the integer target. Lines are copied through verbatim.
#
# NOTE(review): `codecs` and `ast` are not imported explicitly in this file;
# presumably they are re-exported by `from base.corpus import *` -- confirm.
#
# The output files are opened first (as before) and all three handles are
# managed by one `with`, so they are flushed and closed even if a malformed
# line raises part-way through.
with codecs.open('data/yoo_1_4/test_1_over_4.repeat.txt', encoding='utf-8', mode='w') as repeat_out, \
        codecs.open('data/yoo_1_4/test_1_over_4.nonrepeat.txt', encoding='utf-8', mode='w') as nonrepeat_out, \
        codecs.open('data/yoo_1_4/test_1_over_4.txt', encoding='utf-8') as f:
    for line in f:
        # Skip blank lines (e.g. a trailing newline at EOF) instead of
        # failing with IndexError on the missing target field.
        if not line.strip():
            continue
        fields = line.strip('\n').strip('\r').split('\t')
        session_items = ast.literal_eval(fields[0])
        target = int(fields[1])
        if target in session_items:
            repeat_out.write(line)
        else:
            nonrepeat_out.write(line)
# item2id,id2item=load_item(file='data/yoo_items_1_4.txt')
# train_set=SessionCorpus(file_path='data/yoo_test_1_4.txt',item2id=item2id).load()
# items=set()
# with codecs.open('data/diginetica/digi_train.txt', encoding='utf-8') as f:
# for line in f:
# lines = line.strip('\n').strip('\r').split('\t')
# input = ast.literal_eval(lines[0])
# for item in input:
# if item not in items:
# items.add(item)
# if lines[1] not in items:
# items.add(int(lines[1]))
#
# with codecs.open('data/diginetica/digi_valid.txt', encoding='utf-8') as f:
# for line in f:
# lines = line.strip('\n').strip('\r').split('\t')
# input = ast.literal_eval(lines[0])
# for item in input:
# if item not in items:
# items.add(item)
# if lines[1] not in items:
# items.add(int(lines[1]))
#
# with codecs.open('data/diginetica/digi_test.txt', encoding='utf-8') as f:
# for line in f:
# lines = line.strip('\n').strip('\r').split('\t')
# input = ast.literal_eval(lines[0])
# for item in input:
# if item not in items:
# items.add(item)
# if lines[1] not in items:
# items.add(int(lines[1]))
#
# f = codecs.open('data/diginetica/digi_items.txt', encoding='utf-8', mode='w')
# for k in items:
# f.write(str(k)+ os.linesep)
# f.close()