In [None]:
%run Net_block.ipynb


class Net(object):
    """Helpers for turning an architecture spec into layers and initializing weights."""

    @staticmethod
    def parse_arch(arch):
        """Split an architecture spec into input dims, layer names and layer params.

        ``arch[0][1]`` is taken as the input dimensions; every later entry is a
        ``(layer_name, layer_params)`` pair.

        Returns:
            (input_dims, layer_list, param_list) where ``layer_list[i]`` and
            ``param_list[i]`` come from ``arch[i + 1]``.
        """
        input_dims = arch[0][1]
        layer_list = [entry[0] for entry in arch[1:]]
        param_list = [entry[1] for entry in arch[1:]]
        return input_dims, layer_list, param_list

    @staticmethod
    def parse_layers(arch):
        """Instantiate every layer of ``arch`` via ``Block_mapping.module_mapping``.

        Each layer name is looked up in ``Block_mapping.module_mapping`` and the
        resulting factory is called with that layer's params.

        Returns:
            (input_dims, layers) -- ``layers`` is the list of constructed modules.
        """
        input_dims, layer_list, param_list = Net.parse_arch(arch)
        layers = [Block_mapping.module_mapping[name](params)
                  for name, params in zip(layer_list, param_list)]
        return input_dims, layers

    @staticmethod
    def init_weights(net, init_type='normal', init_gain=0.02):
        """Initialize network weights in place.

        Parameters:
            net (network)   -- network to be initialized (``net.apply`` is used)
            init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
            init_gain (float)    -- scaling factor for normal, xavier and orthogonal.

        Conv*/Linear* modules get the selected init on ``weight`` and a zero bias.
        BatchNorm2d modules get ``weight ~ N(1.0, init_gain)`` and a zero bias; a
        BatchNorm2d built with ``affine=False`` has no weight/bias and is skipped.

        Raises:
            NotImplementedError -- if ``init_type`` is unknown and the net contains
            at least one Conv/Linear module.
        """
        def init_func(m):
            classname = m.__class__.__name__
            weight = getattr(m, 'weight', None)
            if weight is not None and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(weight.data, 0.0, init_gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(weight.data, gain=init_gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(weight.data, gain=init_gain)
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif classname.find('BatchNorm2d') != -1:
                # BatchNorm's weight is a vector, not a matrix, so only the normal
                # init applies. affine=False layers have weight/bias == None.
                if weight is not None:
                    nn.init.normal_(weight.data, 1.0, init_gain)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        print('initialize network with %s' % init_type)
        net.apply(init_func)  # apply the initialization function <init_func> recursively
    
class Network_template(nn.Module):
    """Run a list of layers as one or more ``nn.Sequential`` stages.

    ``attention`` selects how ``arch`` is split into stages:

    * 1, 4, 5 -> ``attention1 = arch[:9]``, ``attention2 = arch[9:]``
    * 2       -> ``attention1 = arch[:1]``, ``attention2 = arch[1:]``
    * 3       -> ``attention1..attention4`` over ``arch[:3]``, ``arch[3:7]``,
                 ``arch[7:11]``, ``arch[11:15]``, then ``attention_remain = arch[15:]``
    * other   -> a single ``main = arch``

    In training mode each ``attentionK`` stage must return an ``(output, mask)``
    pair; ``forward`` then returns the final output followed by the collected
    mask(s). In eval mode each stage's result is passed straight on and only the
    final output is returned. When the input is on CUDA and ``ngpu > 1`` every
    stage is run through ``nn.parallel.data_parallel`` on devices ``range(ngpu)``.
    """

    def __init__(self, ngpu, arch, attention=0):
        super(Network_template, self).__init__()
        self.ngpu = ngpu
        self.attention = attention

        def as_stage(layers):
            return nn.Sequential(*nn.ModuleList(layers))

        if attention in (1, 5, 4):
            # Split after 9 layers (an original note lists 6/9/12, presumably
            # alternative split points that were tried).
            self.attention1 = as_stage(arch[:9])
            self.attention2 = as_stage(arch[9:])
        elif attention == 2:
            self.attention1 = as_stage(arch[:1])
            self.attention2 = as_stage(arch[1:])
        elif attention == 3:
            bounds = (0, 3, 7, 11, 15)
            # Registers attention1..attention4 in order (Module.__setattr__).
            for idx, (lo, hi) in enumerate(zip(bounds, bounds[1:]), start=1):
                setattr(self, 'attention%d' % idx, as_stage(arch[lo:hi]))
            self.attention_remain = as_stage(arch[15:])
        else:
            self.main = as_stage(arch)

    def _stage(self, module, x, parallel):
        """Apply one stage, spreading it over ``range(self.ngpu)`` when ``parallel``."""
        if parallel:
            return nn.parallel.data_parallel(module, x, range(self.ngpu))
        return module(x)

    def forward(self, input):
        # Decided once from the network input, as every stage did before.
        parallel = input.is_cuda and self.ngpu > 1
        mode = self.attention

        if mode in (1, 2, 4, 5):
            first = self._stage(self.attention1, input, parallel)
            if not self.training:
                return self._stage(self.attention2, first, parallel)
            features, mask = first
            return self._stage(self.attention2, features, parallel), mask

        if mode == 3:
            x, masks = input, []
            for stage in (self.attention1, self.attention2, self.attention3, self.attention4):
                x = self._stage(stage, x, parallel)
                if self.training:
                    x, mask = x
                    masks.append(mask)
            out = self._stage(self.attention_remain, x, parallel)
            return (out, *masks) if self.training else out

        return self._stage(self.main, input, parallel)

#########################################zzl
# class BCEFocalLoss(torch.nn.Module):
#     """
#     Binary focal loss with a fixed alpha.
#     alpha weights the loss term for positive labels (label 1, e.g. a sample predicted 0 whose label is 1);
#     we use alpha > 0.5 because most labels are 0.
#     """
#     def __init__(self, gamma, alpha, reduction='mean'):
#         super().__init__()
#         self.gamma = gamma
#         self.alpha = alpha
#         self.reduction = reduction
 
#     def forward(self, _input, target):
#         pt = torch.sigmoid(_input)
#         alpha = self.alpha
#         loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
#                (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
#         if self.reduction == 'mean':
#             loss = torch.mean(loss)
#         elif self.reduction == 'sum':
#             loss = torch.sum(loss)
#         return loss
    
class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss computed from per-class logits.

    Column 1 of ``input`` is the positive class: ``p = softmax(input, dim=1)[:, 1]``
    and, per sample,

        loss = -alpha * (1 - p)**gamma * target * log(p)
               - (1 - alpha) * p**gamma * (1 - target) * log(1 - p)

    Args:
        gamma: focusing exponent applied to the modulating factors.
        alpha: weight of the positive (target == 1) term; the negative term is
            weighted by ``1 - alpha``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-sample losses unreduced.
    """

    def __init__(self, gamma, alpha, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, input, target):
        """Return the focal loss.

        Args:
            input: logits of size M x 2 (M is the batch size).
            target: labels of size M (expected to be 0/1).
        """
        log_probs = torch.log_softmax(input, dim=1)
        log_p = log_probs[:, 1]
        # log(1 - p) as the log of the summed probability of every other column,
        # computed in log space: with saturated logits torch.log(1 - p) was
        # log(0) = -inf, giving inf losses or NaN via 0 * -inf in the masked term.
        log_not_p = torch.logsumexp(
            torch.cat((log_probs[:, :1], log_probs[:, 2:]), dim=1), dim=1)
        p = log_p.exp()
        loss = -self.alpha * (1 - p) ** self.gamma * (target * log_p) - \
               (1 - self.alpha) * p ** self.gamma * ((1 - target) * log_not_p)
        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss

class StandardLoss(nn.Module):
    """Classification loss with optional attention-mask supervision.

    ``mode`` selects the classification criterion: ``'xentropy'`` uses
    ``nn.CrossEntropyLoss`` and ``'focal'`` uses ``BCEFocalLoss`` (``gamma`` and
    ``alpha`` are only used by the focal criterion). When ``theta`` is non-zero,
    an attention term is added: ``loss = class_loss + theta * att_loss``.

    The object is called with a single list ``params``; ``__call__`` picks the
    variant from ``theta`` and ``len(params)``.

    NOTE(review): the attention variants align target masks with
    ``Tensor.resize_``, which changes the caller's tensors in place and does not
    resample values. Kept as-is to preserve training behavior; confirm intent.
    """

    def __init__(self, mode, reduction, gamma, alpha, theta=0.0):
        super(StandardLoss, self).__init__()
        self.loss_mode = mode
        self.theta = theta  # weight of the attention term; 0.0 disables it
        if mode == 'xentropy':
            self.criterion = nn.CrossEntropyLoss(reduction=reduction)
        elif mode == 'focal':
            self.criterion = BCEFocalLoss(gamma= gamma, alpha= alpha, reduction=reduction)
        else:
            raise NotImplementedError('loss mode %s not implemented' % mode)

    def xentropy_loss(self, params):
        """Classification loss only.

        params: [0] net, [1] network input, [2] labels.
        """
        assert(len(params) == 3)
        predictions = params[0](params[1])
        return self.criterion(predictions, params[2])

    def attention_loss(self, params):
        """Classification loss plus MSE on two predicted/target mask pairs.

        params: [0] net, [1] network input, [2] labels, [3] pred_mask1,
        [4] pred_mask2, [5] topo_mask1, [6] topo_mask2.
        The topo masks are resized in place to the predicted masks' sizes.
        """
        assert(len(params) == 7)
        predictions = params[0](params[1])

        size1 = params[3].size()
        size2 = params[4].size()
        params[5] = params[5].resize_(size1)
        params[6] = params[6].resize_(size2)

        att_loss = F.mse_loss(params[3], params[5]) + F.mse_loss(params[4], params[6])
        class_loss = self.criterion(predictions, params[2])
        return class_loss + self.theta * att_loss

    def attention_loss_onestream(self, params):
        """Classification loss plus MSE on one predicted/target mask pair.

        params: [0] net, [1] network input, [2] labels, [3] pred_mask, [4] topo_mask.
        The topo mask is resized in place to the predicted mask's size.
        """
        assert(len(params) == 5)
        predictions = params[0](params[1])

        params[4] = params[4].resize_(params[3].size())
        att_loss = F.mse_loss(params[3], params[4])

        class_loss = self.criterion(predictions, params[2])
        return class_loss + self.theta * att_loss

    def attention_loss_multi(self, params):
        """Classification loss plus Smooth-L1 on two streams of four stage masks.

        params: [0] net, [1] network input, [2] labels,
        [3..6] pred_mask11..pred_mask14 (stream 1, stages 1-4),
        [7..10] pred_mask21..pred_mask24 (stream 2, stages 1-4),
        [11] topo_mask1, [12] topo_mask2.

        Stage k of stream s is compared with topo_mask<s> resized (in place) to
        the size of pred_mask1<k>; both streams use stream 1's sizes.
        """
        assert(len(params) == 13)
        predictions = params[0](params[1])

        sizes = [params[k].size() for k in range(3, 7)]
        criterion = nn.SmoothL1Loss(reduction='mean')

        # Same order as before: topo_mask1 against stages 1-4, then topo_mask2.
        att_loss = 0
        for stream, topo_idx in enumerate((11, 12)):
            for stage, size in enumerate(sizes):
                label = params[topo_idx].resize_(size)
                att_loss = att_loss + criterion(params[3 + 4 * stream + stage], label)

        class_loss = self.criterion(predictions, params[2])
        return class_loss + self.theta * att_loss

    def __call__(self, params):
        """Dispatch on ``theta`` and ``len(params)`` (3, 5, 7 or 13 entries)."""
        # Previously `mode == 'xentropy' or 'focal'`, which is always true.
        if self.loss_mode not in ('xentropy', 'focal'):
            raise NotImplementedError('loss mode %s not implemented' % self.loss_mode)
        if self.theta == 0.0:
            assert(len(params) == 3)
            return self.xentropy_loss(params)
        if len(params) == 7:
            return self.attention_loss(params)
        if len(params) == 5:
            return self.attention_loss_onestream(params)
        assert(len(params) == 13)
        return self.attention_loss_multi(params)
#########################################zzl

# class StandardLoss(nn.Module):
    
#     def __init__(self, mode, reduction, gamma, alpha):
#         super(StandardLoss, self).__init__()
#         self.loss_mode = mode
#         if mode == 'xentropy':
#             self.criterion = nn.CrossEntropyLoss(reduction=reduction)
#         elif mode == 'focal':
#             self.criterion = BCEFocalLoss(gamma= gamma, alpha= alpha, reduction=reduction)
#         else:
#             raise NotImplementedError('loss mode %s not implemented' % mode)
            
#     def xentropy_loss(self, params):
#         # params: [0]net, [1]Dtrain, [2]labels
#         assert(len(params) == 3)
#         predictions = params[0](params[1])
#         loss = self.criterion(predictions, params[2])
#         return loss
    
#     def __call__(self, params):
#         if self.loss_mode == 'xentropy' or 'focal':
#             loss = self.xentropy_loss(params)
#         return loss
    
class GANLoss(nn.Module):
    """Adversarial losses for the generator ("G") and discriminator ("D").

    gan_mode:
        'lsgan'        -- MSE against constant real (1.0) / fake (0.0) targets.
        'vanilla'      -- BCE-with-logits against the same targets.
        'vanilla_topo' -- BCE-with-logits between two image batches (vanilla_topo_loss).
        'wgan'         -- Wasserstein critic loss.
        'wgangp'       -- Wasserstein critic loss plus gradient penalty.
    """

    def __init__(self, gan_mode, reduction):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(1.0))
        self.register_buffer('fake_label', torch.tensor(0.0))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.criterion = nn.MSELoss(reduction=reduction)
        elif gan_mode == 'vanilla' or gan_mode == "vanilla_topo":
            self.criterion = nn.BCEWithLogitsLoss(reduction=reduction)
        elif gan_mode == "wgangp" or gan_mode == "wgan":
            pass  # Wasserstein losses need no criterion
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, is_real):
        """Return the real (1.0) or fake (0.0) label broadcast to ``prediction``'s shape."""
        target_tensor = self.real_label if is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def calc_gradient_penalty(self, netD, device, Dreal, Dfake, constant=1.0, lambda_gp=10.0):
        """WGAN-GP penalty: ``lambda_gp * mean((||grad netD(x_hat)|| - constant)**2)``.

        ``x_hat`` mixes ``Dreal`` and ``Dfake`` with one random coefficient per
        sample. Works for inputs of any rank (e.g. 4-D images or 5-D volumes).
        """
        batch_size = Dreal.shape[0]
        # Shape (B, 1, ..., 1) so one coefficient broadcasts over all other dims.
        alpha = torch.rand(batch_size, *([1] * (Dreal.dim() - 1)), device=device)
        alpha = alpha.expand_as(Dreal)
        interpolated = alpha * Dreal + (1 - alpha) * Dfake
        interpolated.requires_grad_(True)
        out_interp = netD(interpolated)
        gradients = torch.autograd.grad(outputs=out_interp, inputs=interpolated,
                    grad_outputs=torch.ones_like(out_interp),
                    create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients  = gradients.view(batch_size, -1)
        grad_norm  = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)  # eps keeps sqrt differentiable at 0
        gp_penalty = ((grad_norm - constant)**2).mean()
        return gp_penalty * lambda_gp

    def vanilla_topo_loss(self, params):
        '''
        params: Dfake_device, Dfix_device
        both Dfake_device and Dfix_device should be output from tanh() layer,
        which have values between -1.0 and 1.0.
        Both are flattened over their last two dims and compared with the
        BCE-with-logits criterion.
        NOTE(review): BCEWithLogitsLoss expects targets in [0, 1]; targets in
        [-1, 1] as described above look suspicious -- confirm.
        '''
        image_shape = list(params[0].shape)[-2:]
        flat_fake = params[0].view(-1, np.prod(image_shape))
        flat_fix  = params[1].view(-1, np.prod(image_shape))
        return self.criterion(flat_fake, flat_fix)

    def wgan_gp_loss(self, params):
        """Wasserstein loss with gradient penalty.

        "G": params[1] netD, params[2] Dfake -> ``-mean(netD(Dfake))``.
        "D": params[1] netD, params[2] device, params[3] Dreal, params[4] Dfake
             -> ``mean(D(fake)) - mean(D(real)) + gradient penalty``.
        """
        if params[0] == "G":
            assert(len(params) == 3)
            out_fake = params[1](params[2])
            return -out_fake.mean()
        elif params[0] == "D":
            assert(len(params) >= 5)
            out_real = params[1](params[3])
            out_fake = params[1](params[4].detach())
            gp_penalty = self.calc_gradient_penalty(params[1], params[2], params[3], params[4].detach())
            return out_fake.mean() - out_real.mean() + gp_penalty
        else:
            raise NotImplementedError('Unrecognized network %s' % (params[0],))

    def wgan_loss(self, params):
        """Wasserstein loss without gradient penalty (same params as wgan_gp_loss)."""
        if params[0] == "G":
            assert(len(params) == 3)
            out_fake = params[1](params[2])
            return -out_fake.mean()
        elif params[0] == "D":
            assert(len(params) >= 5)
            out_real = params[1](params[3])
            out_fake = params[1](params[4].detach())
            return out_fake.mean() - out_real.mean()
        else:
            raise NotImplementedError('Unrecognized network %s' % (params[0],))

    def gan_loss(self, params):
        """LSGAN / vanilla loss against constant real/fake targets.

        "G": params[1] netD, params[2] Dfake -> criterion(D(fake), real).
        "D": params[1] netD, params[2] device, params[3] Dreal, params[4] Dfake
             -> criterion(D(real), real) + criterion(D(fake.detach()), fake).
        """
        if params[0] == "G":
            assert(len(params) == 3)
            out_fake      = params[1](params[2])
            target_tensor = self.get_target_tensor(out_fake, True)
            return self.criterion(out_fake, target_tensor)
        elif params[0] == "D":
            assert(len(params) >= 5)
            out_real      = params[1](params[3])
            out_fake      = params[1](params[4].detach())
            target_tensor = self.get_target_tensor(out_real, True)
            err_real      = self.criterion(out_real, target_tensor)
            target_tensor = self.get_target_tensor(out_fake, False)
            err_fake      = self.criterion(out_fake, target_tensor)
            return err_real + err_fake
        else:
            raise NotImplementedError('Unrecognized network %s' % (params[0],))

    def __call__(self, params):
        """Dispatch to the loss implementation selected by ``gan_mode``."""
        if self.gan_mode in ['lsgan', 'vanilla']:
            return self.gan_loss(params)
        elif self.gan_mode == 'vanilla_topo':
            return self.vanilla_topo_loss(params)
        elif self.gan_mode == 'wgangp':
            return self.wgan_gp_loss(params)
        elif self.gan_mode == "wgan":
            return self.wgan_loss(params)
        raise NotImplementedError('gan mode %s not implemented' % self.gan_mode)