-
Notifications
You must be signed in to change notification settings - Fork 103
/
model.py
183 lines (133 loc) · 5.65 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import tensorflow as tf
"""
Class representing a hierarchically organized model to be initialized in a dependency-injection-like manner.
"""
class Model:
    """One link in a chain of model components.

    Each component holds a reference to ``next_component`` and answers a
    request by combining an optional local implementation (a method named
    ``local_<name>``) with the answer of the next component, using one of
    the three delegation helpers defined at the bottom of the class.

    ``settings`` must provide the keys ``'EntityCount'``, ``'RelationCount'``
    and ``'EdgeCount'`` (each convertible with ``int``); ``save`` additionally
    reads ``'ExperimentName'``.
    """

    # Prediction ops, built lazily on the first call to the matching
    # score* method and reused afterwards.
    score_all_subjects_graph = None
    score_all_objects_graph = None
    score_graph = None

    # Used by the score* methods and by save(); never assigned inside this
    # class, so it has to be set from outside before those are called.
    session = None

    next_component = None

    # Created on the first save(). Declared here so that the ``is None``
    # check in save() cannot fail with AttributeError on a fresh instance.
    saver = None

    # Passed to the saver as ``global_step``; bumped after every save().
    save_iter = 0

    def __init__(self, next_component, settings):
        """Store the wiring and settings, then let subclasses parse settings.

        Args:
            next_component: The component to delegate to, or ``None`` if this
                is the last link of the chain.
            settings: Mapping holding the configuration values listed in the
                class docstring.
        """
        self.next_component = next_component
        self.settings = settings

        self.entity_count = int(self.settings['EntityCount'])
        self.relation_count = int(self.settings['RelationCount'])
        self.edge_count = int(self.settings['EdgeCount'])

        self.parse_settings()

    def parse_settings(self):
        """Hook called at the end of ``__init__``; does nothing by default."""
        pass

    def save(self):
        """Write the chain's weights through a ``tf.train.Saver``.

        The saver is created on first use over ``get_weights()`` and then
        reused. ``save_iter`` is used as the ``global_step`` and incremented
        after each call.
        """
        variables_to_save = self.get_weights()

        if self.saver is None:
            self.saver = tf.train.Saver(var_list=variables_to_save)

        self.saver.save(self.session,
                        self.settings['ExperimentName'],
                        global_step=self.save_iter)
        self.save_iter += 1

    # ------------------------------------------------------------------
    # High-level functions
    # ------------------------------------------------------------------

    def _test_feed_dict(self, graph_attribute, triplets):
        """Build the feed dict shared by the score* methods.

        When the chain needs a graph, the first test input variable is fed
        the value of ``self.<graph_attribute>`` and the second is fed
        ``triplets``; otherwise the first is fed ``triplets`` directly. The
        attribute is only read when it is actually needed.
        """
        graph_needed = self.needs_graph()

        # Resolve once: every call walks the whole component chain.
        inputs = self.get_test_input_variables()

        if graph_needed:
            return {inputs[0]: getattr(self, graph_attribute),
                    inputs[1]: triplets}
        return {inputs[0]: triplets}

    def score(self, triplets):
        """Run the ``predict()`` op on ``triplets`` and return the result."""
        if self.score_graph is None:
            self.score_graph = self.predict()

        # NOTE(review): feeds the triplets stored by preprocess(), whereas
        # the score_all_* methods feed the ones from register_for_test().
        # Kept as found - confirm this difference is intentional.
        feed = self._test_feed_dict('train_triplets', triplets)
        return self.session.run(self.score_graph, feed_dict=feed)

    def score_all_subjects(self, triplets):
        """Run the ``predict_all_subject_scores()`` op on ``triplets``."""
        if self.score_all_subjects_graph is None:
            self.score_all_subjects_graph = self.predict_all_subject_scores()

        feed = self._test_feed_dict('test_graph', triplets)
        return self.session.run(self.score_all_subjects_graph, feed_dict=feed)

    def score_all_objects(self, triplets):
        """Run the ``predict_all_object_scores()`` op on ``triplets``."""
        if self.score_all_objects_graph is None:
            self.score_all_objects_graph = self.predict_all_object_scores()

        feed = self._test_feed_dict('test_graph', triplets)
        return self.session.run(self.score_all_objects_graph, feed_dict=feed)

    # ------------------------------------------------------------------
    # Component interface
    # ------------------------------------------------------------------

    def register_for_test(self, triplets):
        """Remember ``triplets`` as the graph fed by the score_all_* methods."""
        self.test_graph = triplets

    def preprocess(self, triplets):
        """Remember ``triplets`` as the graph fed by ``score``.

        Only stores the value; nothing is forwarded to ``next_component``.
        """
        self.train_triplets = triplets

    def initialize_train(self):
        """Run ``initialize_train`` locally (if defined), then down the chain."""
        return self.__local_run_delegate__('initialize_train')

    def get_weights(self):
        """Return the next component's weights plus local ones (default ``[]``)."""
        return self.__local_expand_delegate__('get_weights')

    def set_variable(self, name, value):
        """Run ``set_variable`` locally (if defined), then down the chain."""
        return self.__local_run_delegate__('set_variable', name, value)

    def get_train_input_variables(self):
        """Return the chain's train input variables plus local ones."""
        return self.__local_expand_delegate__('get_train_input_variables')

    def get_test_input_variables(self):
        """Return the chain's test input variables plus local ones."""
        return self.__local_expand_delegate__('get_test_input_variables')

    def get_loss(self, mode='train'):
        """Return the next component's loss for ``mode`` (``None`` if none)."""
        return self.__delegate__('get_loss', mode)

    def get_regularization(self):
        """Return the summed regularization of the chain (local default ``0``)."""
        return self.__local_expand_delegate__('get_regularization', base=0)

    def get_all_subject_codes(self, mode='train'):
        """Forward ``get_all_subject_codes(mode)`` to the next component."""
        return self.__delegate__('get_all_subject_codes', mode)

    def get_all_object_codes(self, mode='train'):
        """Forward ``get_all_object_codes(mode)`` to the next component."""
        return self.__delegate__('get_all_object_codes', mode)

    def get_all_codes(self, mode='train'):
        """Forward ``get_all_codes(mode)`` to the next component."""
        return self.__delegate__('get_all_codes', mode)

    def predict(self):
        """Forward ``predict()`` to the next component."""
        return self.__delegate__('predict')

    def predict_all_subject_scores(self):
        """Forward ``predict_all_subject_scores()`` to the next component."""
        return self.__delegate__('predict_all_subject_scores')

    def predict_all_object_scores(self):
        """Forward ``predict_all_object_scores()`` to the next component."""
        return self.__delegate__('predict_all_object_scores')

    def get_graph(self):
        """Forward ``get_graph()`` to the next component."""
        return self.__delegate__('get_graph')

    def get_additional_ops(self):
        """Return the chain's additional ops plus local ones (default ``[]``)."""
        return self.__local_expand_delegate__('get_additional_ops')

    def needs_graph(self):
        """Ask the chain whether a graph input is needed; ``False`` at its end."""
        if self.next_component is None:
            return False
        return self.next_component.needs_graph()

    # ------------------------------------------------------------------
    # Delegation helpers
    # ------------------------------------------------------------------

    def __delegate__(self, name, *args):
        """Call ``name(*args)`` on the next component and return its result.

        Returns ``None`` when there is no next component.
        """
        if self.next_component is None:
            return None
        return getattr(self.next_component, name)(*args)

    def __local_run_delegate__(self, name, *args):
        """Call ``local_<name>`` here if it exists, then ``name`` on the next component.

        Both calls are made for their side effects only; nothing is returned.
        """
        local_function = getattr(self, 'local_' + name, None)
        if local_function is not None:
            local_function(*args)

        if self.next_component is not None:
            getattr(self.next_component, name)(*args)

    def __local_expand_delegate__(self, name, *args, base=None):
        """Combine the local result with the next component's result using ``+``.

        The local result is ``local_<name>(*args)`` when that method exists,
        otherwise ``base`` (a fresh empty list when ``base`` is ``None``).
        With a next component the return value is
        ``next_component.<name>(*args) + local_result``; without one it is
        the local result alone.
        """
        if base is None:
            base = []

        local_function_name = 'local_' + name
        if hasattr(self, local_function_name):
            local_result = getattr(self, local_function_name)(*args)
        else:
            local_result = base

        if self.next_component is None:
            return local_result
        return getattr(self.next_component, name)(*args) + local_result