In [None]:
import os
import time
import copy
import json
import random

import jieba
import torch
import torch.nn.functional as F
from tqdm.notebook import tqdm

from similarityMetrics import alignmentIndex


#### FocalLoss

In [None]:
class FocalLoss(torch.nn.Module):
    """Focal loss built on top of multi-class cross-entropy.

    For every sample the loss is ``alpha * (1 - pt) ** gamma * CE``, where
    ``CE`` is the unreduced ``F.cross_entropy`` value and ``pt = exp(-CE)``
    (the probability the model assigns to the target class). With
    ``gamma=0`` this reduces to ``alpha * CE``.

    Args:
        alpha: scalar factor multiplied into every sample's loss.
        gamma: focusing exponent that down-weights well-classified samples.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``; ``None`` is accepted
            as an alias for ``'none'`` (per-sample losses are returned).

    Raises:
        ValueError: if ``reduction`` is not one of the values above (previously
            a typo silently returned the unreduced loss).
    """

    _REDUCTIONS = ('mean', 'sum', 'none', None)

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (class scores) against ``targets``.

        ``inputs`` and ``targets`` are passed unchanged to
        ``F.cross_entropy``, so they follow its shape conventions. Returns a
        scalar tensor for ``'mean'``/``'sum'`` and the per-sample loss
        tensor otherwise.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # CE = -log(p_target), so exp(-CE) recovers the target-class probability.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

#### 1D IoU

In [None]:
"""
The final version of this code was lost, so an early version is provided here.
It may differ from the final version and contain bugs; please adapt it as needed.

 尤其是这里的计算方式与最终版可能有较大差别，比如miou、wiou等最终并未被使用，请以论文为准
 这里未提供DIoU和ADIoU的计算方式，请自行实现（较为简单）
"""


class IOU_1D:
    """Overlap metrics (IoU / Dice) between two 1-D segmentations.

    A segmentation is passed as a list of inclusive segment-end indices that
    must form a contiguous cover starting at index 0. For example,
    ``[4, 15, 20]`` means the intervals ``[0, 4]``, ``[5, 15]`` and
    ``[16, 20]`` (see ``num2l``). ``label[-1] + 1`` is used as the total
    number of sentences.

    Args:
        cal_type: how each label interval is paired with a predicted interval.
            ``'max_iou'`` picks the prediction with the highest IoU.
            ``'max_i'`` picks the largest intersection, breaking ties by the
            smallest union.
        hit_threshold: a label interval counts as a hit when its matched IoU is
            strictly greater than this value.
    """

    def __init__(self, cal_type='max_iou', hit_threshold=0.5):
        self.cal_type = cal_type
        self.hit_threshold = hit_threshold

    @staticmethod
    def _overlaps(label_, pred_):
        """Return True if inclusive intervals ``label_`` and ``pred_`` share an index."""
        return pred_[0] <= label_[1] and pred_[1] >= label_[0]

    def block2ids(self, block):
        """Convert block sizes to inclusive end indices: ``[2, 1, 3] -> [1, 2, 5]``."""
        ids = []
        total = 0
        for size in block:
            total += size
            ids.append(total - 1)
        return ids

    def cal_iou(self, label_, pred_):
        """Return the IoU of two overlapping inclusive intervals ``[start, end]``.

        Raises:
            ValueError: if the intervals do not overlap. The union is computed as
                the covering span, which is only correct for overlapping intervals.
        """
        a, b = label_
        c, d = pred_
        if c > b or a > d:
            raise ValueError('wrong pred index: ', pred_, ', label: ', label_)
        its = min(b, d) - max(a, c) + 1
        uni = max(b, d) - min(a, c) + 1
        return its / uni

    def cal_dice(self, label_, pred_):
        """Return ``intersection / (len(label_) + len(pred_))`` for overlapping intervals.

        NOTE(review): the standard Dice coefficient is
        ``2 * |A ∩ B| / (|A| + |B|)``. This formula (documented as
        intersection / (pred length + label length) in the original comment) is
        half of that. It is kept as-is; confirm against the paper.

        Raises:
            ValueError: if the intervals do not overlap.
        """
        smooth = 1e-6
        a, b = label_
        c, d = pred_
        if c > b or a > d:
            raise ValueError('wrong pred index: ', pred_, ', label: ', label_)
        its = min(b, d) - max(a, c) + 1 + smooth
        length_sum = b - a + d - c + 2 + smooth
        return its / length_sum

    def intersection_area(self, label_, pred_):
        """Return the number of shared indices of two inclusive intervals (0 if disjoint)."""
        a, b = label_
        c, d = pred_
        if c > b or a > d:
            return 0
        return min(b, d) - max(a, c) + 1

    def union_area(self, label_, pred_):
        """Return the union length of two overlapping inclusive intervals.

        Raises:
            ValueError: if the intervals do not overlap (the span formula would
                otherwise over-count the gap between them).
        """
        a, b = label_
        c, d = pred_
        if c > b or a > d:
            raise ValueError('no intersection')
        return max(b, d) - min(a, c) + 1

    def num2l(self, l):
        """Convert inclusive end indices to ``[start, end]`` intervals (both ends inclusive).

        ``[4, 15, 20] -> [[0, 4], [5, 15], [16, 20]]``; an empty input gives ``[]``.
        """
        interval = []
        for i, end in enumerate(l):
            start = 0 if i == 0 else l[i - 1] + 1
            interval.append([start, end])
        return interval

    def cal_maxiou_iou(self, label, pred, sents_num):
        """Match every label interval with the overlapping prediction of highest IoU.

        Returns:
            dict with ``mean_iou`` (average matched IoU over label intervals),
            ``weighted_iou`` (matched IoU weighted by label-interval length,
            divided by ``sents_num``) and ``precision`` (fraction of label
            intervals whose matched IoU exceeds ``hit_threshold``). A label
            interval with no overlapping prediction scores 0.
        """
        hit = 0
        sum_iou = 0
        weighted_iou_sum = 0
        for lab in label:
            max_iou = 0
            for cand in pred:
                if self._overlaps(lab, cand):
                    max_iou = max(max_iou, self.cal_iou(lab, cand))
            if max_iou > self.hit_threshold:
                hit += 1
            sum_iou += max_iou
            weighted_iou_sum += (max_iou * (lab[1] - lab[0] + 1))
        precision = hit / len(label)
        return {'mean_iou': sum_iou / len(label),
                'weighted_iou': weighted_iou_sum / sents_num,
                'precision': precision}

    def cal_maxi_iou(self, label, pred, sents_num):
        """Like ``cal_maxiou_iou``, but pick the prediction with the largest intersection.

        Ties in intersection are broken by the smallest union, then by the
        earliest prediction. A label interval with no overlapping prediction
        (including when ``pred`` is empty, which previously raised
        ``IndexError``) scores an IoU of 0. Returns the same keys as
        ``cal_maxiou_iou``.
        """
        hit = 0
        sum_iou = 0
        weighted_iou_sum = 0
        for lab in label:
            best = None
            for cand in pred:
                if not self._overlaps(lab, cand):
                    continue
                if best is None:
                    best = cand
                    continue
                its_cand = self.intersection_area(lab, cand)
                its_best = self.intersection_area(lab, best)
                if its_cand > its_best or (
                        its_cand == its_best
                        and self.union_area(lab, cand) < self.union_area(lab, best)):
                    best = cand
            the_iou = 0 if best is None else self.cal_iou(lab, best)
            if the_iou > self.hit_threshold:
                hit += 1
            sum_iou += the_iou
            weighted_iou_sum += (the_iou * (lab[1] - lab[0] + 1))
        precision = hit / len(label)
        return {'mean_iou': sum_iou / len(label),
                'weighted_iou': weighted_iou_sum / sents_num,
                'precision': precision}

    def cal_label_pred_iou(self, label, pred):
        """Score ``pred`` against ``label`` (both as end-index lists) using ``cal_type``.

        "precision" treats IoU as a score and counts a label interval as hit
        when IoU > ``hit_threshold``. ``label`` must describe a contiguous
        cover, because ``label[-1] + 1`` is taken as the sentence count.

        Returns:
            the dict from ``cal_maxiou_iou`` / ``cal_maxi_iou``, or ``None``
            (after printing a message) for an unknown ``cal_type``.
        """
        sents_num = label[-1] + 1
        pred_l = self.num2l(pred)
        label_l = self.num2l(label)
        if self.cal_type == 'max_iou':
            return self.cal_maxiou_iou(label_l, pred_l, sents_num)

        # Largest intersection first, then smallest union. Example with
        # label [4, 15, 20] and pred [10, 17, 20]: for label interval [5, 15]
        # against predictions [0, 10] and [11, 17], 'max_i' picks [0, 10]
        # (intersection 6 > 5), whereas 'max_iou' picks [11, 17]
        # (IoU 5/13 > 6/16).
        elif self.cal_type == 'max_i':
            return self.cal_maxi_iou(label_l, pred_l, sents_num)
        # Idea not implemented: IoU minus (centre distance / union size).
        else:
            print('cal_type must be "max_i" or "max_iou"!')
            return

    def cal_maxiou_iou_with_mask(self, label, pred, sents_num):
        """Max-IoU matching that also records which predictions were matched.

        ``sents_num`` is accepted for signature parity and is not used.

        Returns:
            ``(iou_list, precision, miss_seg)``:
            - ``iou_list``: the max IoU per label interval.
            - ``precision``: the fraction of label intervals above
              ``hit_threshold``.
            - ``miss_seg``: a 0/1 list over ``pred``, where 1 marks a
              prediction that was assigned to some label interval.
        """
        hit = 0
        iou_list = []
        miss_seg = [0] * len(pred)  # 0 = prediction not matched yet
        for lab in label:
            max_iou = 0
            max_iou_pred_id = []
            for j, cand in enumerate(pred):
                if self._overlaps(lab, cand):
                    iou_score = self.cal_iou(lab, cand)
                    # NOTE(review): every candidate that is a running maximum
                    # when visited is recorded, not only ties at the final
                    # maximum. The original comment calls this a "lowest miss
                    # score" rule; confirm that this is intended.
                    if iou_score >= max_iou:
                        max_iou_pred_id.append(j)
                        max_iou = iou_score
            # Mark the first recorded candidate that is not matched yet.
            for pred_id in max_iou_pred_id:
                if miss_seg[pred_id] == 0:
                    miss_seg[pred_id] = 1
                    break
            if max_iou > self.hit_threshold:
                hit += 1
            iou_list.append(max_iou)
        precision = hit / len(label)
        return iou_list, precision, miss_seg

    def cal_label_pred_iou_with_punish(self, label, pred):
        """Max-IoU scores plus a penalty for predicted segments left unmatched.

        With ``cal_type == 'max_iou'`` the result has ``mean_iou``,
        ``weighted_iou``, ``precision``, ``weighted_iou_p`` (weighted IoU minus
        the penalty) and ``miss_segment_punishment``. With ``'max_i'`` the
        result is that of ``cal_maxi_iou`` (no penalty keys). For an unknown
        ``cal_type`` a message is printed and ``None`` is returned.
        """
        sents_num = label[-1] + 1
        pred_l = self.num2l(pred)
        label_l = self.num2l(label)
        if self.cal_type == 'max_iou':
            iou_list, precision, mask = self.cal_maxiou_iou_with_mask(label_l, pred_l, sents_num)
            mask = [1 - i for i in mask]  # 1 = prediction never matched
            rev_iou_list, _, _ = self.cal_maxiou_iou_with_mask(pred_l, label_l, sents_num)

            iou_mean = sum(iou_list) / len(label_l)
            iou_weight = sum([iou_list[i] * (label_l[i][1] - label_l[i][0] + 1)
                              for i in range(len(iou_list))]) / sents_num
            # Penalty for missed (unmatched) predicted segments.
            pred_weight = [x[1] - x[0] + 1 for x in pred_l]
            miss_sum = sum([a * b * c for a, b, c in zip(rev_iou_list, pred_weight, mask)])
            # NOTE(review): this penalty is 1 / (miss_sum * sents_num), so it
            # shrinks as the unmatched area grows. If the penalty should grow
            # with it, this was probably meant to be miss_sum / sents_num.
            # Confirm against the paper.
            miss_seg_punish_w = 1/miss_sum/sents_num if miss_sum != 0 else 0
            iou_weight_with_punish = iou_weight - miss_seg_punish_w
            return {'mean_iou': iou_mean,
                    'weighted_iou': iou_weight,
                    'precision': precision,
                    'weighted_iou_p': iou_weight_with_punish,
                    'miss_segment_punishment': miss_seg_punish_w}

        elif self.cal_type == 'max_i':
            return self.cal_maxi_iou(label_l, pred_l, sents_num)
        # Idea not implemented: IoU minus (centre distance / union size).
        else:
            print('cal_type must be "max_i" or "max_iou"!')
            return

    def cal_label_pred_iou_rev(self, label, pred, w=2/3):
        """Blend forward (label vs pred) and reverse (pred vs label) IoU scores.

        Each returned score is ``forward * w + reverse * (1 - w)``. When
        ``pred`` is empty, the reverse direction contributes 0. Previously this
        raised ``TypeError`` by subscripting the integer 0.
        """
        iou = self.cal_label_pred_iou(label, pred)
        if len(pred) > 0:
            iou_rev = self.cal_label_pred_iou(pred, label)
            rev_mean, rev_weighted = iou_rev['mean_iou'], iou_rev['weighted_iou']
        else:
            rev_mean = rev_weighted = 0

        return {'mean_iou_revavg': iou['mean_iou'] * w + rev_mean * (1 - w),
                'weighted_iou_revavg': iou['weighted_iou'] * w + rev_weighted * (1 - w)}

    def cal_label_pred_dice(self, label, pred):
        """Return the length-weighted best ``cal_dice`` score per label interval.

        For each label interval, the best score over overlapping predictions
        (0 if none) is weighted by the interval length. The sum is divided by
        the sentence count ``label[-1] + 1`` (label must be contiguous).
        """
        sents_num = label[-1] + 1
        pred = self.num2l(pred)
        label = self.num2l(label)
        weighted_sum_dice = 0
        for lab in label:
            max_dice = 0
            for cand in pred:
                if self._overlaps(lab, cand):
                    max_dice = max(max_dice, self.cal_dice(lab, cand))
            weighted_sum_dice += (max_dice * (lab[1] - lab[0] + 1))
        return weighted_sum_dice / sents_num


#### Segeval

In [None]:
# 平均label长度的一半
def trans_id2block(seg_ids):
    """Turn inclusive segment-end indices into block sizes.

    ``[4, 15, 20] -> [5, 11, 5]``: a virtual end at -1 is prepended, so the
    first block starts at index 0. ``seg_ids`` must be a list; it is not
    modified.
    """
    bounds = [-1] + seg_ids
    return [end - start for start, end in zip(bounds, bounds[1:])]


def cal_segeval(test_pred_b,
                test_label_b,
                tolerate_dist=21,
                type='block',
                end_process=True,
                decimal_places=4):
    """Compute segeval-based segmentation scores averaged over documents.

    Args:
        test_pred_b: per-document predictions. These are block sizes when
            ``type == 'block'``, otherwise inclusive segment-end indices.
        test_label_b: per-document references in the same format.
        tolerate_dist: with ``end_process``, predicted boundaries strictly later
            than ``last_label_index - tolerate_dist`` are dropped.
        type: ``'block'`` for block sizes; any other value means end indices.
        end_process: whether to apply the trailing-boundary clean-up above.
        decimal_places: rounding applied to the final percentages.

    The document end (last label index) is always ensured as the final
    predicted boundary, exactly once. Previously, with ``end_process`` and
    ``tolerate_dist <= 0``, it could be appended twice, creating a 0-size
    block. The window size / ``n_t`` per document is half the average
    reference block length, rounded.

    Returns:
        dict mapping ``'pk'``, ``'wd'``, ``'ss'``, ``'bs'`` and ``'a'`` to the mean
        per-document score x 100, rounded. ``pk``/``wd`` are reported as
        ``1 - error``. Returns ``{}`` when ``test_label_b`` is empty.

    NOTE(review): ``segeval`` is referenced as a module but is not imported in
    the visible imports; make sure ``import segeval`` runs before this.
    """
    if len(test_label_b) == 0:
        return {}

    res_all = []
    full_predictions_list = []
    full_labels_list = []

    for doc_idx in tqdm(range(len(test_label_b))):
        if type == 'block':
            p_b = test_pred_b[doc_idx][:]
            l_b = test_label_b[doc_idx][:]
            # block sizes -> inclusive end indices
            p_id = [sum(p_b[:k + 1]) - 1 for k in range(len(p_b))]
            l_id = [sum(l_b[:k + 1]) - 1 for k in range(len(l_b))]
        else:
            p_id = test_pred_b[doc_idx][:]
            l_id = test_label_b[doc_idx][:]
        end = l_id[-1]
        # Tail handling: drop predicted boundaries too close to the document end.
        if end_process:
            p_id = [x for x in p_id if x <= end - tolerate_dist]
        # The document end must be the final boundary, added only once.
        if end not in p_id:
            p_id += [end]

        # Per-sentence B-EOP/O tag sequences, kept for the (disabled) F1 metric below.
        full_labels = ['O'] * (end + 1)
        for x in l_id:
            full_labels[x] = 'B-EOP'
        full_predictions = ['O'] * (end + 1)
        for x in p_id:
            full_predictions[x] = 'B-EOP'
        full_predictions[-1] = 'B-EOP'
        full_predictions_list.append(full_predictions)
        full_labels_list.append(full_labels)

        # end indices -> block sizes
        new_pred_block = trans_id2block(p_id)
        label_block = trans_id2block(l_id)

        # Half the average reference block length.
        avg_block_size = round(sum(label_block)/len(label_block)/2)

        res = {
            'pk':float(1-segeval.pk(new_pred_block, label_block, window_size=avg_block_size)),
            'wd':float(1-segeval.window_diff(new_pred_block, label_block, window_size=avg_block_size)),
            'ss':float(segeval.segmentation_similarity(new_pred_block, label_block, n_t=avg_block_size)),
            'bs':float(segeval.boundary_similarity(new_pred_block, label_block, n_t=avg_block_size)),
            'a' :alignmentIndex(new_pred_block, label_block)
        }
        res_all.append(res)

    # Disabled F1 over B-EOP tags (`metric` is not defined in the visible code):
    #results = metric.compute(predictions=full_predictions_list, references=full_labels_list)
    res = {k: round(sum([r[k] for r in res_all])/len(res_all)*100, decimal_places)
           for k in res_all[0]}
    #res['f1'] = round(results['EOP']['f1']*100, decimal_places)
    return res