Skip to content

Commit

Permalink
combination of loss
Browse files Browse the repository at this point in the history
  • Loading branch information
taigw committed May 21, 2021
1 parent 83b6cf6 commit 5cbbca7
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions pymic/loss/seg/combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division

import torch
import torch.nn as nn

class CombinedLoss(nn.Module):
def __init__(self, params, loss_dict):
super(CombinedLoss, self).__init__()
loss_names = params['loss_type']
self.loss_weight = params['loss_weight']
assert (len(loss_names) == len(self.loss_weight))
self.loss_list = []
for loss_name in loss_names:
if(loss_name in loss_dict):
one_loss = loss_dict[loss_name](params)
self.loss_list.append(one_loss)
else:
raise ValueError("{0:} is not defined, or has not been added to the \
loss dictionary".format(loss_name))

def forward(self, loss_input_dict):
loss_value = 0.0
for i in range(len(self.loss_list)):
loss_value = self.loss_weight[i] + self.loss_list[i](loss_input_dict)
return loss_value

0 comments on commit 5cbbca7

Please sign in to comment.