-
Notifications
You must be signed in to change notification settings - Fork 6
/
generateEggs.py
76 lines (67 loc) · 2.46 KB
/
generateEggs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import numpy as np
import tensorflow as tf
class GenerateEggs():
    """Score unlabeled samples with a saved TF1 model and write the results to a file.

    The output file ``<FLAGS.path4SaveEggsFile>/eggsfile.txt`` is
    tab-separated:

    * the first row lists ``insDataPro.drugName``;
    * each following row starts with one ``insDataPro.tpName`` entry.
      Each row then has one cell per entry of the matching ``mapMatrix``
      row:
      * ``< 0.5`` cells get the next value from ``calcDistance(...)``;
      * ``> 0.5`` cells get the literal placeholder ``6666666``.
    """

    def __init__(self, FLAGS, insDataPro, modelSavePath):
        """Store the run configuration.

        Args:
            FLAGS: configuration object; this class reads ``FLAGS.batchSize``
                and ``FLAGS.path4SaveEggsFile``.
            insDataPro: data-processing object. This class uses its
                ``generateUnlabeledData4Eggs()`` and ``calcDistance()``
                methods, and its ``allUnlabeledData``, ``drugName``,
                ``tpName`` and ``mapMatrix`` attributes.
            modelSavePath: checkpoint prefix given to ``Saver.restore``;
                ``modelSavePath + ".meta"`` must be the exported meta graph.
        """
        self.FLAGS = FLAGS
        self.insDataPro = insDataPro
        self.modelSavePath = modelSavePath

    def generateEggs2Files(self):
        """Predict on all unlabeled data and write ``eggsfile.txt``.

        Steps:
            1. Ask ``insDataPro`` to build ``allUnlabeledData``.
            2. Run the restored model over it in batches of
               ``FLAGS.batchSize``.
            3. Write the distance matrix file.
            4. Print the fraction of positive (``> 0``) distances.

        Raises:
            ValueError: if ``generateUnlabeledData4Eggs`` produced no rows.
                Previously this surfaced as a ``NameError`` after the output
                file had already been truncated.
        """
        self.insDataPro.generateUnlabeledData4Eggs()
        unlabeled = self.insDataPro.allUnlabeledData
        numUnlabeled = unlabeled.shape[0]
        if numUnlabeled == 0:
            raise ValueError("No unlabeled data was generated; nothing to score.")

        probRes = self._predictUnlabeled(unlabeled)
        # Format into a single string so Python 2 does not print a tuple;
        # the text matches what Python 3's print(a, b) emitted before.
        print("The number of unlabeled data and feature dimension are: {}".format(
            unlabeled.shape))

        count4Eggs = self._writeEggsFile(probRes)
        # count4Eggs is a float, so this is true division on Python 2 as well.
        print("The percentage of positive cases is: {}".format(
            count4Eggs / numUnlabeled))

    def _predictUnlabeled(self, unlabeled):
        """Restore the saved graph and return ``outputLayer/hOutput`` for every row of ``unlabeled``."""
        batchSize = self.FLAGS.batchSize
        # Dummy labels: only the output tensor is fetched, but the yLabel
        # placeholder is still fed as before. Built once instead of per batch.
        dummyLabels = np.zeros((batchSize, 2))
        tf.reset_default_graph()
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph(self.modelSavePath + ".meta")
            graph = tf.get_default_graph()
            saver.restore(sess, self.modelSavePath)
            xData = graph.get_operation_by_name("xData").outputs[0]
            yLabel = graph.get_operation_by_name("yLabel").outputs[0]
            yOutput = graph.get_operation_by_name("outputLayer/hOutput").outputs[0]
            keepProb = graph.get_operation_by_name("dropOut/keepProb").outputs[0]

            batchOutputs = []
            # range (not xrange) so this runs on both Python 2 and 3.
            for start in range(0, unlabeled.shape[0], batchSize):
                feedData = {
                    xData: unlabeled[start:start + batchSize],
                    yLabel: dummyLabels,
                    keepProb: 1.0,  # no dropout at inference time
                }
                batchOutputs.append(sess.run(yOutput, feed_dict=feedData))
        # One concatenate instead of np.append per batch (which re-copied the
        # accumulated result every iteration).
        return np.concatenate(batchOutputs, axis=0)

    def _writeEggsFile(self, probRes):
        """Write ``eggsfile.txt`` from ``calcDistance(probRes)``; return the number of distances ``> 0`` (as float)."""
        distance = self.insDataPro.calcDistance(probRes)
        outPath = os.path.join(self.FLAGS.path4SaveEggsFile, "eggsfile.txt")
        count4Eggs = 0.0
        # Next unused row of ``distance``. It advances once per mapMatrix
        # cell below 0.5, scanning rows in order. NOTE(review): assumes
        # calcDistance's rows follow that same order — confirm against
        # generateUnlabeledData4Eggs.
        ind4dis = 0
        with open(outPath, 'w') as filePointer:
            # Header row: drug names.
            filePointer.write('\t'.join(str(name) for name in self.insDataPro.drugName) + '\n')
            # One row per tp name, with distances or the placeholder.
            for ind, tpName in enumerate(self.insDataPro.tpName):
                fields = [str(tpName)]
                for jele in self.insDataPro.mapMatrix[ind]:
                    if jele < 0.5:
                        fields.append(str(distance[ind4dis][0]))
                        if distance[ind4dis][0] > 0:
                            count4Eggs += 1
                        ind4dis += 1
                    elif jele > 0.5:
                        fields.append("6666666")
                    # NOTE(review): a cell exactly equal to 0.5 emits no
                    # column and shifts the row; presumably mapMatrix is 0/1.
                filePointer.write('\t'.join(fields) + '\n')
        return count4Eggs