@@ -188,8 +188,15 @@ def model(data_flow, train=True):
188188 # Predictions for the training, validation, and test data.
189189 with tf .name_scope ('train' ):
190190 self .train_prediction = tf .nn .softmax (logits , name = 'train_prediction' )
191+ tf .add_to_collection ("prediction" , self .train_prediction )
191192 with tf .name_scope ('test' ):
192193 self .test_prediction = tf .nn .softmax (model (self .tf_test_samples , train = False ), name = 'test_prediction' )
194+ tf .add_to_collection ("prediction" , self .test_prediction )
195+
196+ single_shape = (1 , 32 , 32 , 1 )
197+ single_input = tf .placeholder (tf .float32 , shape = single_shape , name = 'single_input' )
198+ self .single_prediction = tf .nn .softmax (model (single_input , train = False ), name = 'single_prediction' )
199+ tf .add_to_collection ("prediction" , self .single_prediction )
193200
194201 self .merged_train_summary = tf .merge_summary (self .train_summaries )
195202 self .merged_test_summary = tf .merge_summary (self .test_summaries )
@@ -277,17 +284,19 @@ def test(self, test_samples, test_labels, *, data_iterator):
277284 self .define_model ()
278285 if self .writer is None :
279286 self .writer = tf .train .SummaryWriter ('./board' , tf .get_default_graph ())
287+
288+ print ('Before session' )
280289 with tf .Session (graph = tf .get_default_graph ()) as session :
281290 self .saver .restore (session , self .save_path )
282291 ### 测试
283292 accuracies = []
284293 confusionMatrices = []
285294 for i , samples , labels in data_iterator (test_samples , test_labels , chunkSize = self .test_batch_size ):
286- result , summary = session .run (
287- [ self .test_prediction , self . merged_test_summary ] ,
295+ result = session .run (
296+ self .test_prediction ,
288297 feed_dict = {self .tf_test_samples : samples }
289298 )
290- self .writer .add_summary (summary , i )
299+ # self.writer.add_summary(summary, i)
291300 accuracy , cm = self .accuracy (result , labels , need_confusion_matrix = True )
292301 accuracies .append (accuracy )
293302 confusionMatrices .append (cm )
@@ -311,13 +320,13 @@ def accuracy(self, predictions, labels, need_confusion_matrix=False):
311320 return accuracy , cm
312321
313322 def visualize_filter_map (self , tensor , * , how_many , display_size , name ):
314- print (tensor .get_shape )
323+ # print(tensor.get_shape)
315324 filter_map = tensor [- 1 ]
316- print (filter_map .get_shape ())
325+ # print(filter_map.get_shape())
317326 filter_map = tf .transpose (filter_map , perm = [2 , 0 , 1 ])
318- print (filter_map .get_shape ())
327+ # print(filter_map.get_shape())
319328 filter_map = tf .reshape (filter_map , (how_many , display_size , display_size , 1 ))
320- print (how_many )
329+ # print(how_many)
321330 self .test_summaries .append (tf .image_summary (name , tensor = filter_map , max_images = how_many ))
322331
323332 def print_confusion_matrix (self , confusionMatrix ):
0 commit comments