Skip to content

Commit

Permalink
Merge pull request #55 from Archaic-Atom/refator_code_in_background_jack
Browse files Browse the repository at this point in the history
fix some bugs
  • Loading branch information
ZhiboRao committed Jan 20, 2022
2 parents 5fbc8ff + c734850 commit 1e85ec3
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions Source/JackFramework/Core/build_training_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,18 @@ def __pass_data2device(self, data: list) -> list:

return data

def __calculation_process(self, model_item: object, input_data: list,
model_id: int) -> tuple:
output_data, _, _ = self.__init_calculation_result()
def __calculation_process(self, model_item: object, input_data: list, label_data: list,
model_id: int, is_training: bool = True) -> tuple:
output_data, loss, acc = self.__init_calculation_result()

# get ouput
output_data = self.__jf_model.inference(model_item, input_data, model_id)
# loss and acc
if is_training:
loss = self.__jf_model.loss(output_data, label_data, model_id)
acc = self.__jf_model.accuary(output_data, label_data, model_id)

return output_data
return output_data, loss, acc

def __variable2tensor(self, data: list) -> None:
res = []
Expand Down Expand Up @@ -174,15 +178,9 @@ def train_model(self, input_data: list, label_data: list) -> list:
input_data = self.__pass_data2device(input_data)
label_data = self.__pass_data2device(label_data)

output_data_list = []
for i, model_item in enumerate(self.__model):
self.__opt[i].zero_grad()
output_data = self.__calculation_process(model_item, input_data, i)
output_data_list.append(output_data)

for i, output_data in enumerate(output_data_list):
loss = self.__jf_model.loss(output_data, label_data, i)
acc = self.__jf_model.accuary(output_data, label_data, i)
_, loss, acc = self.__calculation_process(model_item, input_data, label_data, i)

loss[self.OPT_LOSS_ID].backward()
self.__opt[i].step()
Expand All @@ -202,14 +200,8 @@ def val_model(self, input_data: list, label_data: list) -> list:
label_data = self.__pass_data2device(label_data)

with torch.no_grad():
output_data_list = []
for i, model_item in enumerate(self.__model):
output_data = self.__calculation_process(model_item, input_data, i)
output_data_list.append(output_data)

for i, output_data in enumerate(output_data_list):
loss = self.__jf_model.loss(output_data, label_data, i)
acc = self.__jf_model.accuary(output_data, label_data, i)
_, loss, acc = self.__calculation_process(model_item, input_data, label_data, i)
tower_loss_iteration.append(self.__variable2tensor(loss))
tower_acc_iteration.append(self.__variable2tensor(acc))

Expand Down

0 comments on commit 1e85ec3

Please sign in to comment.