In [None]:
import os
import shutil
import torch
import random
from torch import nn
from torch.nn import CrossEntropyLoss
from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
from modelscope.utils.constant import ConfigFields, ModelFile
from transformers.modeling_outputs import TokenClassifierOutput
from modelscope.models.nlp.ponet import PoNetPreTrainedModel, PoNetModel



In [None]:
@MODELS.register_module("token-classification-task", module_name=Models.ponet)
class PoNetForTokenClassificationWithIOUloss(PoNetPreTrainedModel):
    """PoNet token classifier trained with focal loss plus auxiliary L1/IoU/Dice terms.

    The head is dropout + a linear layer over the PoNet sequence output. When
    ``labels`` are given, ``forward`` sums four weighted terms:

    * focal loss on the logits (the only term that depends on the logits
      differentiably);
    * L1 distance between the argmax prediction and the labels;
    * a "one-boundary" IoU loss and a "one-boundary" Dice loss, which reduce each
      row to the lengths of its leading 0-run and its 1-part (see
      ``_one_boundary_counts``).

    NOTE: the three auxiliary terms are computed from ``torch.argmax`` (an integer
    tensor) or from Python floats wrapped in ``torch.tensor``, so they carry no
    gradient. They shift the reported loss value but do not train the model.
    """
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Only the per-token sequence output is used, so no pooler is built.
        self.ponet = PoNetModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Weights of the individual terms summed into the loss in ``forward``.
        self.ce_loss_coefficient = 1.0
        self.l1_loss_coefficient = 0.5
        self.iou_loss_coefficient = 0.5
        self.dice_loss_coefficient = 0.5
        # IOU_1D and FocalLoss are defined outside this class (not visible here).
        self.TS_iou_dice_fct = IOU_1D()
        self.focal_loss = FocalLoss(alpha=0.75, gamma=2, reduction='mean')

        self.init_weights()

    def l1_loss_fct(self, pred, target, ignore_index=None):
        """Mean absolute error between ``pred`` and ``target``.

        Args:
            pred: float tensor of predictions.
            target: float tensor of targets, same shape as ``pred``.
            ignore_index: if given, positions where ``target == ignore_index``
                (e.g. -100 label padding) are excluded from the mean.

        Returns:
            A 0-dim tensor. It is zero when every position is ignored, which
            avoids the NaN that a mean over an empty selection would produce.
        """
        if ignore_index is not None:
            keep = target != ignore_index
            if not keep.any():
                return pred.new_zeros(())
            pred, target = pred[keep], target[keep]
        return nn.functional.l1_loss(pred, target)

    # NOTE (original author): the implementations of the following methods may
    # differ considerably from the final version. The code was lost (the archive
    # was corrupted) and it was written long ago, so they were not revised.

    def iou_loss_fct(self, pred, target):
        """Index-set IoU loss over positions equal to 0 (currently unused by ``forward``).

        For each row, the indices whose value is 0 are collected for ``pred`` and
        ``target`` and scored with ``self.TS_iou_dice_fct.cal_label_pred_iou_rev``.

        Returns:
            A 0-dim tensor: the batch mean of ``1 - weighted_iou_revavg``.
        """
        iou_loss_sum = 0
        for row in range(len(target)):
            pred_ids = [pos for pos, value in enumerate(pred[row].tolist()) if value == 0]
            target_ids = [pos for pos, value in enumerate(target[row].tolist()) if value == 0]
            # 4095 is a fallback index used when a row has no 0 at all; presumably
            # the last position of a 4096-token input -- TODO confirm.
            pred_ids = pred_ids or [4095]
            target_ids = target_ids or [4095]
            iou_score = self.TS_iou_dice_fct.cal_label_pred_iou_rev(target_ids, pred_ids)
            iou_loss_sum += 1 - iou_score['weighted_iou_revavg']
        return torch.tensor(iou_loss_sum / len(target))

    @staticmethod
    def _one_boundary_counts(pred_row, label_row):
        """Reduce one prediction/label row to ``([l0, l1], [p0, p1])`` counts.

        ``l1`` is the number of 1 labels. ``l0`` is the index of the first 1 label,
        or the number of 0 labels when the row has no 1. ``p0`` is the index of the
        first predicted 1, and ``p1`` is the rest of the labelled length ``l0 + l1``.
        The reduction apparently assumes each row is a run of 0s followed by 1s.
        """
        pred_list = pred_row.tolist()
        label_list = label_row.tolist()
        label_num_1 = label_list.count(1)
        label_num_0 = label_list.index(1) if label_num_1 else label_list.count(0)
        total = label_num_0 + label_num_1
        pred_num_0 = pred_list.index(1) if 1 in pred_list else len(pred_list)
        # The predicted 0-run can reach into padded positions (e.g. no 1 predicted
        # at all, where the fallback is the full row length). Clamp it to the
        # labelled length so that p1 = total - p0 never goes negative.
        pred_num_0 = min(pred_num_0, total)
        return [label_num_0, label_num_1], [pred_num_0, total - pred_num_0]

    def cal_iou_with_one_same_boundary(self, label_num, pred_num):
        """Length-weighted IoU of the 0-part and 1-part of a single-boundary row.

        Args:
            label_num: ``[l0, l1]`` label counts.
            pred_num: ``[p0, p1]`` prediction counts.

        Returns:
            ``(min(p0,l0)/max(p0,l0)*l0 + min(p1,l1)/max(p1,l1)*l1) / (l0 + l1)``.
            A part with a zero count contributes 0. Returns 1.0 (no penalty) when
            the row has no 0/1 labels, instead of dividing by zero.
        """
        l0, l1 = label_num
        p0, p1 = pred_num
        total = l0 + l1
        if total == 0:
            return 1.0
        iou_0 = min(p0, l0) / max(p0, l0) * l0 if p0 != 0 and l0 != 0 else 0
        iou_1 = min(p1, l1) / max(p1, l1) * l1 if p1 != 0 and l1 != 0 else 0
        return (iou_0 + iou_1) / total

    def cal_one_same_boundary_iou_fct(self, predication, label):
        """Return ``1 - mean IoU`` over the batch as a 0-dim tensor (no gradient)."""
        iou_sum = 0.0
        for row in range(len(predication)):
            label_num, pred_num = self._one_boundary_counts(predication[row], label[row])
            iou_sum += self.cal_iou_with_one_same_boundary(label_num, pred_num)
        return torch.tensor(1 - iou_sum / len(predication))

    def dice_loss_fct(self, pred, target):
        """Index-set Dice loss over positions equal to 0 (currently unused by ``forward``).

        Returns:
            A 0-dim tensor: the batch mean of ``1 - 2 * cal_label_pred_dice(...)``.
        """
        dice_loss_sum = 0
        for row in range(len(target)):
            pred_ids = [pos for pos, value in enumerate(pred[row].tolist()) if value == 0]
            target_ids = [pos for pos, value in enumerate(target[row].tolist()) if value == 0]
            # Only the target falls back to the 4095 sentinel here (as originally
            # written); an empty pred_ids list is passed through unchanged.
            target_ids = target_ids or [4095]
            dice_score = self.TS_iou_dice_fct.cal_label_pred_dice(target_ids, pred_ids)
            dice_loss_sum += 1 - 2 * dice_score
        return torch.tensor(dice_loss_sum / len(target))

    def cal_dice_with_one_same_boundary(self, label_num, pred_num):
        """Length-weighted, smoothed overlap of the 0-part and 1-part of a row.

        Each part scores ``(intersection + smooth) / (pred_len + label_len + smooth)``,
        weighted by its label length. A perfect match therefore scores about 0.5,
        which is why callers use ``1 - 2 * score``. Returns 0.5 (no penalty) when
        the row has no 0/1 labels, instead of dividing by zero.
        """
        smooth = 1e-6
        l0, l1 = label_num
        p0, p1 = pred_num
        total = l0 + l1
        if total == 0:
            return 0.5
        dice_0 = (min(p0, l0) + smooth) / (p0 + l0 + smooth) * l0
        dice_1 = (min(p1, l1) + smooth) / (p1 + l1 + smooth) * l1
        return (dice_0 + dice_1) / total

    def cal_one_same_boundary_dice_fct(self, predication, label):
        """Return the batch mean of ``1 - 2 * dice`` as a 0-dim tensor (no gradient).

        The loss is averaged over the batch, like ``cal_one_same_boundary_iou_fct``
        and ``dice_loss_fct``, so its scale does not depend on batch size.
        """
        dice_loss_sum = 0.0
        for row in range(len(predication)):
            label_num, pred_num = self._one_boundary_counts(predication[row], label[row])
            dice_score = self.cal_dice_with_one_same_boundary(label_num, pred_num)
            dice_loss_sum += 1 - 2 * dice_score
        return torch.tensor(dice_loss_sum / len(predication))

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            segment_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in
            ``[0, ..., config.num_labels - 1]``. Positions equal to ``-100``
            (``CrossEntropyLoss().ignore_index``) are excluded from the L1 term; how
            ``FocalLoss`` treats them depends on its implementation (not visible here).

        Returns a ``TokenClassifierOutput``, or a tuple when ``return_dict`` is false.
        When ``labels`` is given, ``loss`` is the sum of the focal, L1, IoU and Dice
        terms, each multiplied by its coefficient.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.ponet(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            segment_ids=segment_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)
        # Hard per-token predictions; integer-valued, so anything derived from
        # them carries no gradient.
        predication = torch.argmax(logits, -1)

        loss = None
        if labels is not None:
            # Instantiated only for its ignore_index (-100); the classification term
            # itself uses focal loss instead of plain cross-entropy.
            loss_fct = CrossEntropyLoss()
            if attention_mask is not None:
                # Only keep active parts of the loss: positions with attention_mask != 1
                # get ignore_index, i.e. labels are effectively padded with -100.
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                ce_loss = self.focal_loss(active_logits, active_labels) * self.ce_loss_coefficient
            else:
                ce_loss = self.focal_loss(logits.view(-1, self.num_labels), labels.view(-1)) * self.ce_loss_coefficient

            # Auxiliary terms (no gradient, see class docstring). Padded (-100) label
            # positions are skipped so they do not each add ~100 to the L1 mean.
            l1_loss = self.l1_loss_fct(
                predication.to(torch.float32),
                labels.to(torch.float32),
                ignore_index=loss_fct.ignore_index,
            ) * self.l1_loss_coefficient
            iou_loss = self.cal_one_same_boundary_iou_fct(predication, labels) * self.iou_loss_coefficient
            dice_loss = self.cal_one_same_boundary_dice_fct(predication, labels) * self.dice_loss_coefficient
            loss = ce_loss + l1_loss + iou_loss + dice_loss

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def save_pretrained(self, output_dir, state_dict=None):
        """Save weights (and config files found in ``self.model_dir``) to ``output_dir``.

        Args:
            output_dir: destination directory; it is created if missing.
            state_dict: weights to save. Defaults to ``self.state_dict()``. Without
                this default, ``torch.save(None, ...)`` wrote a checkpoint with no
                weights.
        """
        if state_dict is None:
            state_dict = self.state_dict()
        os.makedirs(output_dir, exist_ok=True)
        torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        # Config files are only written when the source model directory had them.
        if os.path.isfile(os.path.join(self.model_dir, CONFIG_NAME)):
            self.config.to_json_file(os.path.join(output_dir, CONFIG_NAME))
        configuration_src = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
        if os.path.isfile(configuration_src):
            shutil.copy(configuration_src, os.path.join(output_dir, ModelFile.CONFIGURATION))
