In [None]:
def anchored_npair_loss_meta_v2(inputs_logit, y_true, hidden, net):
    """Anchored N-pair loss with a gradient-agreement (meta) regulariser.

    Label-derived positive/negative pairs are split by an unsupervised gate
    (cosine similarity of ``hidden`` compared with ``corr_threshold``) into a
    "clean" and a "noisy" part. Only the clean part is optimised directly; on
    top of it, the negative cosine similarity between the parameter gradients
    of the clean and noisy parts is added with a small weight. A soft-target
    cross-entropy against the unsupervised similarity distribution is added
    for both the positive and the negative side.

    Args:
        inputs_logit: 2-D ``(batch, dim)`` embeddings being trained.
        y_true: 2-D ``(batch, num_classes)`` label matrix; presumably one-hot,
            so that ``y_true @ y_true.T`` is a same-class indicator — TODO
            confirm.
        hidden: 2-D ``(batch, dim')`` features used only to build the
            unsupervised targets and gates; they are detached, so no gradient
            flows through them.
        net: module whose ``parameters()`` the gradient-agreement term is
            computed against.

    Returns:
        Scalar loss tensor (L2 penalty + positive and negative terms).

    Note:
        Relies on ``get_grad_norm`` / ``get_grad_prod`` defined elsewhere.
    """
    # Control parameters
    eps = 1e-10
    with_l2reg = True
    alpha = 1
    corr_threshold = 0.2
    meta_weight = 0.001

    def _row_normalize(m):
        # Divide each row by its sum. A row summing to zero (an anchor with no
        # positives / no negatives, e.g. a single-class batch) would give
        # 0/0 = NaN and poison the whole loss; its numerator is all zeros, so
        # dividing by 1 instead simply leaves it at 0.
        denom = torch.sum(m, dim=1, keepdim=True)
        return m / torch.where(denom == 0, torch.ones_like(denom), denom)

    def _masked_mean(per_anchor, weight):
        # Mean of per-anchor losses over anchors whose 0/1 weight is 1; a batch
        # with no valid anchor yields 0 instead of 0/0 = NaN.
        return torch.sum(per_anchor * weight) / torch.sum(weight).clamp(min=1.)

    def _grad_match(losses):
        # Negative cosine similarity between the parameter gradients of
        # losses[0] and losses[1]. create_graph=True keeps these gradients
        # differentiable; without it they are constants and this term would
        # contribute nothing to backpropagation.
        # NOTE(review): divides by zero if either gradient is all-zero unless
        # get_grad_norm adds an epsilon — confirm.
        params = list(net.parameters())
        grads = []
        norms = []
        for loss_r in losses:
            grad_r = torch.autograd.grad(loss_r, params, retain_graph=True,
                                         create_graph=True, allow_unused=True)
            grads.append(grad_r)
            norms.append(get_grad_norm(grad_r))
        return -get_grad_prod(grads[0], grads[1]) / (norms[0] * norms[1])

    # Label-derived pair masks: mask[i, j] is the label overlap of i and j.
    mask = torch.einsum("ik,jk->ij", y_true, y_true)
    mask_neg = 1. - mask
    mask_pos_avg = _row_normalize(mask)
    mask_neg_avg = _row_normalize(mask_neg)

    # Valid anchors: at least one positive besides the anchor itself / at
    # least one negative.
    sample_weight_pos = torch.sign(torch.clamp(torch.sum(mask, dim=1) - 1., min=0.))
    sample_weight_neg = torch.sign(torch.clamp(torch.sum(mask_neg, dim=1), min=0.))

    # Unsupervised similarity and soft targets. Detached: these features only
    # provide targets and gates, so no gradient should flow back into them.
    input_unspv = F.normalize(hidden, p=2, dim=1).detach()
    sim_unspv = torch.einsum("ik,jk->ij", input_unspv, input_unspv)
    prob_target_unspv_pos = F.softmax(alpha * sim_unspv, dim=1)
    prob_target_unspv_neg = F.softmax(alpha * -sim_unspv, dim=1)

    # Binary gates: 1 where the unsupervised similarity is above the threshold
    # (positive side) or below -threshold (negative side).
    gate_unspv_pos = torch.sign(torch.clamp(sim_unspv - corr_threshold, min=0.))
    gate_unspv_neg = torch.sign(torch.clamp(-sim_unspv - corr_threshold, min=0.))

    if with_l2reg:
        reg = torch.mean(torch.sum(torch.pow(inputs_logit, 2), dim=1)).float()
        l2loss = torch.mul(0.25 * 0.002, reg)
    else:
        l2loss = 0.0

    # Cross-entropy of positive and negative pairs against the noisy labels.
    similarity_matrix = torch.einsum("ik,jk->ij", inputs_logit, inputs_logit)
    log_prob_pos = torch.log(F.softmax(similarity_matrix, dim=1) + eps)
    log_prob_neg = torch.log(F.softmax(-similarity_matrix, dim=1) + eps)
    ce_loss_pos = -log_prob_pos * mask_pos_avg
    ce_loss_neg = -log_prob_neg * mask_neg_avg

    # Positive-pair loss: clean part + gradient-agreement with the noisy part.
    pos_clean = _masked_mean(torch.sum(gate_unspv_pos * ce_loss_pos, dim=1), sample_weight_pos)
    pos_noise = _masked_mean(torch.sum((1. - gate_unspv_pos) * ce_loss_pos, dim=1), sample_weight_pos)
    xent_loss_pos = pos_clean + _grad_match([pos_clean, pos_noise]) * meta_weight

    # Negative-pair loss, built from the *negative* clean loss.
    neg_clean = _masked_mean(torch.sum(gate_unspv_neg * ce_loss_neg, dim=1), sample_weight_neg)
    neg_noise = _masked_mean(torch.sum((1. - gate_unspv_neg) * ce_loss_neg, dim=1), sample_weight_neg)
    xent_loss_neg = neg_clean + _grad_match([neg_clean, neg_noise]) * meta_weight

    # Soft-target cross-entropy against the unsupervised similarity.
    unspv_loss_pos = _masked_mean(torch.sum(-log_prob_pos * prob_target_unspv_pos, dim=1), sample_weight_pos)
    unspv_loss_neg = _masked_mean(torch.sum(-log_prob_neg * prob_target_unspv_neg, dim=1), sample_weight_neg)

    xent_loss = (xent_loss_neg + unspv_loss_neg) + (xent_loss_pos + unspv_loss_pos)
    return l2loss + xent_loss

In [None]:
def anchored_npair_loss_v2(inputs_logit, y_true, hidden):
    """Anchored N-pair loss with unsupervised gating (no meta term).

    For each anchor, softmax distributions over the batch are built from the
    supervised similarity (``sim`` for positives, ``-sim`` for negatives).
    Pairs whose unsupervised cosine similarity passes ``corr_threshold`` are
    trained against the (possibly noisy) label targets; the remaining pairs
    are trained against soft targets from the unsupervised similarity.

    Args:
        inputs_logit: 2-D ``(batch, dim)`` embeddings being trained.
        y_true: 2-D ``(batch, num_classes)`` label matrix; presumably one-hot
            — TODO confirm.
        hidden: 2-D ``(batch, dim')`` features used only for targets and
            gates; detached, so no gradient flows through them.

    Returns:
        Scalar loss tensor. Degenerate batches (anchors with no negatives, or
        no anchor with another positive) give a finite value instead of NaN.
    """
    # Control parameters
    eps = 1e-10
    with_l2reg = True
    alpha = 1
    corr_threshold = 0.2

    def _row_normalize(m):
        # Row-normalise; zero-sum rows have an all-zero numerator and would
        # otherwise become 0/0 = NaN, so divide them by 1 instead.
        denom = torch.sum(m, dim=1, keepdim=True)
        return m / torch.where(denom == 0, torch.ones_like(denom), denom)

    def _masked_mean(per_anchor, weight):
        # Mean over anchors with weight 1; 0 (not NaN) when none is valid.
        return torch.sum(per_anchor * weight) / torch.sum(weight).clamp(min=1.)

    # Label-derived pair masks.
    mask = torch.einsum("ik,jk->ij", y_true, y_true)
    mask_neg = 1. - mask
    mask_pos_avg = _row_normalize(mask)
    mask_neg_avg = _row_normalize(mask_neg)

    # Valid anchors: another positive besides itself / at least one negative.
    sample_weight_pos = torch.sign(torch.clamp(torch.sum(mask, dim=1) - 1., min=0.))
    sample_weight_neg = torch.sign(torch.clamp(torch.sum(mask_neg, dim=1), min=0.))

    # Unsupervised similarity, soft targets and gates (detached targets).
    input_unspv = F.normalize(hidden, p=2, dim=1).detach()
    sim_unspv = torch.einsum("ik,jk->ij", input_unspv, input_unspv)
    prob_target_unspv_pos = F.softmax(alpha * sim_unspv, dim=1)
    prob_target_unspv_neg = F.softmax(alpha * -sim_unspv, dim=1)
    gate_unspv_pos = torch.sign(torch.clamp(sim_unspv - corr_threshold, min=0.))
    gate_unspv_neg = torch.sign(torch.clamp(-sim_unspv - corr_threshold, min=0.))

    if with_l2reg:
        reg = torch.mean(torch.sum(torch.pow(inputs_logit, 2), dim=1)).float()
        l2loss = torch.mul(0.25 * 0.002, reg)
    else:
        l2loss = 0.0

    # Cross-entropy of positive/negative pairs against the noisy labels.
    similarity_matrix = torch.einsum("ik,jk->ij", inputs_logit, inputs_logit)
    log_prob_pos = torch.log(F.softmax(similarity_matrix, dim=1) + eps)
    log_prob_neg = torch.log(F.softmax(-similarity_matrix, dim=1) + eps)
    ce_loss_pos = -log_prob_pos * mask_pos_avg
    ce_loss_neg = -log_prob_neg * mask_neg_avg
    unspv_loss_pos = -log_prob_pos * prob_target_unspv_pos
    unspv_loss_neg = -log_prob_neg * prob_target_unspv_neg

    # Gated pairs use the label targets, the rest use the unsupervised ones.
    xent_loss_pos = gate_unspv_pos * ce_loss_pos + (1. - gate_unspv_pos) * unspv_loss_pos
    xent_loss_neg = gate_unspv_neg * ce_loss_neg + (1. - gate_unspv_neg) * unspv_loss_neg

    # Average over valid anchors only.
    xent_loss = (_masked_mean(torch.sum(xent_loss_pos, dim=1), sample_weight_pos)
                 + _masked_mean(torch.sum(xent_loss_neg, dim=1), sample_weight_neg))
    return l2loss + xent_loss

In [3]:
# Scratch arithmetic: (16 * 4 * 4) * 3 * 6 / 96, shown as the cell output.
(16 * 4 * 4) * (1 + 2) * 6 / (24 * 4)

48

In [None]:
def anchored_npair_loss_v3(inputs_logit, y_true, hidden):
    """Anchored N-pair loss: gated label term plus an always-on unsupervised term.

    Unlike ``anchored_npair_loss_v2``, the unsupervised soft-target
    cross-entropy is applied to every pair, and the label cross-entropy is
    added only for pairs whose unsupervised cosine similarity passes
    ``corr_threshold``.

    Args:
        inputs_logit: 2-D ``(batch, dim)`` embeddings being trained.
        y_true: 2-D ``(batch, num_classes)`` label matrix; presumably one-hot
            — TODO confirm.
        hidden: 2-D ``(batch, dim')`` features used only for targets and
            gates; detached, so no gradient flows through them.

    Returns:
        Scalar loss tensor. Degenerate batches give a finite value instead of
        NaN.
    """
    # Control parameters
    eps = 1e-10
    with_l2reg = True
    alpha = 1
    corr_threshold = 0.2

    def _row_normalize(m):
        # Row-normalise; zero-sum rows have an all-zero numerator and would
        # otherwise become 0/0 = NaN, so divide them by 1 instead.
        denom = torch.sum(m, dim=1, keepdim=True)
        return m / torch.where(denom == 0, torch.ones_like(denom), denom)

    def _masked_mean(per_anchor, weight):
        # Mean over anchors with weight 1; 0 (not NaN) when none is valid.
        return torch.sum(per_anchor * weight) / torch.sum(weight).clamp(min=1.)

    # Label-derived pair masks.
    mask = torch.einsum("ik,jk->ij", y_true, y_true)
    mask_neg = 1. - mask
    mask_pos_avg = _row_normalize(mask)
    mask_neg_avg = _row_normalize(mask_neg)

    # Valid anchors: another positive besides itself / at least one negative.
    sample_weight_pos = torch.sign(torch.clamp(torch.sum(mask, dim=1) - 1., min=0.))
    sample_weight_neg = torch.sign(torch.clamp(torch.sum(mask_neg, dim=1), min=0.))

    # Unsupervised similarity, soft targets and gates (detached targets).
    input_unspv = F.normalize(hidden, p=2, dim=1).detach()
    sim_unspv = torch.einsum("ik,jk->ij", input_unspv, input_unspv)
    prob_target_unspv_pos = F.softmax(alpha * sim_unspv, dim=1)
    prob_target_unspv_neg = F.softmax(alpha * -sim_unspv, dim=1)
    gate_unspv_pos = torch.sign(torch.clamp(sim_unspv - corr_threshold, min=0.))
    gate_unspv_neg = torch.sign(torch.clamp(-sim_unspv - corr_threshold, min=0.))

    if with_l2reg:
        reg = torch.mean(torch.sum(torch.pow(inputs_logit, 2), dim=1)).float()
        l2loss = torch.mul(0.25 * 0.002, reg)
    else:
        l2loss = 0.0

    # Cross-entropy of positive/negative pairs against the noisy labels.
    similarity_matrix = torch.einsum("ik,jk->ij", inputs_logit, inputs_logit)
    log_prob_pos = torch.log(F.softmax(similarity_matrix, dim=1) + eps)
    log_prob_neg = torch.log(F.softmax(-similarity_matrix, dim=1) + eps)
    ce_loss_pos = -log_prob_pos * mask_pos_avg
    ce_loss_neg = -log_prob_neg * mask_neg_avg
    unspv_loss_pos = -log_prob_pos * prob_target_unspv_pos
    unspv_loss_neg = -log_prob_neg * prob_target_unspv_neg

    # Gated label term plus the unsupervised term on every pair.
    xent_loss_pos = gate_unspv_pos * ce_loss_pos + unspv_loss_pos
    xent_loss_neg = gate_unspv_neg * ce_loss_neg + unspv_loss_neg

    # Average over valid anchors only.
    xent_loss = (_masked_mean(torch.sum(xent_loss_pos, dim=1), sample_weight_pos)
                 + _masked_mean(torch.sum(xent_loss_neg, dim=1), sample_weight_neg))
    return l2loss + xent_loss

In [None]:
def anchored_npair_loss_meta_2_tmp_step(inputs_logit, y_true, hidden, net):
    """Anchored N-pair loss with gradient agreement and label-correction terms.

    For each side (positive / negative) three sub-losses are formed:

    * clean: label pairs whose unsupervised similarity passes the gate;
    * noise: the remaining label pairs;
    * label correction: pairs with the *opposite* label that the unsupervised
      similarity strongly contradicts (label-negatives with similarity above
      ``1 - corr_threshold`` are treated as positives, label-positives with
      similarity below ``-(1 - corr_threshold)`` as negatives).

    Only the clean part is optimised directly. The summed negative cosine
    similarity between its parameter gradient and the gradients of the other
    two is added with a small weight. Soft-target unsupervised terms are added
    for both sides.

    Args:
        inputs_logit: 2-D ``(batch, dim)`` embeddings being trained.
        y_true: 2-D ``(batch, num_classes)`` label matrix; presumably one-hot
            — TODO confirm.
        hidden: 2-D ``(batch, dim')`` features used only for targets and
            gates; detached, so no gradient flows through them.
        net: module whose ``parameters()`` the gradient terms use.

    Returns:
        Scalar loss tensor.

    Note:
        Relies on ``get_grad_norm`` / ``get_grad_prod`` defined elsewhere.
    """
    # Control parameters
    eps = 1e-10
    with_l2reg = True
    alpha = 1
    corr_threshold = 0.2
    meta_weight = 0.001

    def _row_normalize(m):
        # Row-normalise; zero-sum rows have an all-zero numerator and would
        # otherwise become 0/0 = NaN, so divide them by 1 instead.
        denom = torch.sum(m, dim=1, keepdim=True)
        return m / torch.where(denom == 0, torch.ones_like(denom), denom)

    def _masked_mean(per_anchor, weight):
        # Mean over anchors with weight 1; 0 (not NaN) when none is valid.
        return torch.sum(per_anchor * weight) / torch.sum(weight).clamp(min=1.)

    def _grad_match(losses):
        # Sum over r >= 1 of -cos(grad(losses[0]), grad(losses[r])) w.r.t. the
        # net parameters. create_graph=True keeps the gradients differentiable;
        # without it this term is a constant and does nothing in backprop.
        # NOTE(review): divides by zero if a gradient is all-zero (e.g. an
        # empty label-correction gate) unless get_grad_norm adds an epsilon.
        params = list(net.parameters())
        grads = []
        norms = []
        for loss_r in losses:
            grad_r = torch.autograd.grad(loss_r, params, retain_graph=True,
                                         create_graph=True, allow_unused=True)
            grads.append(grad_r)
            norms.append(get_grad_norm(grad_r))
        total = 0.
        for r in range(1, len(losses)):
            total += -get_grad_prod(grads[0], grads[r]) / (norms[0] * norms[r])
        return total

    # Label-derived pair masks.
    mask = torch.einsum("ik,jk->ij", y_true, y_true)
    mask_neg = 1. - mask
    mask_pos_avg = _row_normalize(mask)
    mask_neg_avg = _row_normalize(mask_neg)

    # Valid anchors: another positive besides itself / at least one negative.
    sample_weight_pos = torch.sign(torch.clamp(torch.sum(mask, dim=1) - 1., min=0.))
    sample_weight_neg = torch.sign(torch.clamp(torch.sum(mask_neg, dim=1), min=0.))

    # Unsupervised similarity and soft targets (detached targets).
    input_unspv = F.normalize(hidden, p=2, dim=1).detach()
    sim_unspv = torch.einsum("ik,jk->ij", input_unspv, input_unspv)
    prob_target_unspv_pos = F.softmax(alpha * sim_unspv, dim=1)
    prob_target_unspv_neg = F.softmax(alpha * -sim_unspv, dim=1)

    # Gates: clean if sim > thr (pos) / sim < -thr (neg); label correction if
    # sim > 1 - thr (label-negatives) / sim < -(1 - thr) (label-positives).
    gate_unspv_pos = torch.sign(torch.clamp(sim_unspv - corr_threshold, min=0.))
    gate_unspv_neg = torch.sign(torch.clamp(-sim_unspv - corr_threshold, min=0.))
    gate_pos_lc = torch.sign(torch.clamp(sim_unspv - 1. + corr_threshold, min=0.))
    gate_neg_lc = torch.sign(torch.clamp(-sim_unspv - 1. + corr_threshold, min=0.))

    if with_l2reg:
        reg = torch.mean(torch.sum(torch.pow(inputs_logit, 2), dim=1)).float()
        l2loss = torch.mul(0.25 * 0.002, reg)
    else:
        l2loss = 0.0

    # Cross-entropy of positive/negative pairs against the noisy labels, plus
    # the flipped-label variants used for label correction.
    similarity_matrix = torch.einsum("ik,jk->ij", inputs_logit, inputs_logit)
    log_prob_pos = torch.log(F.softmax(similarity_matrix, dim=1) + eps)
    log_prob_neg = torch.log(F.softmax(-similarity_matrix, dim=1) + eps)
    ce_loss_pos = -log_prob_pos * mask_pos_avg
    ce_loss_neg = -log_prob_neg * mask_neg_avg
    ce_loss_pos_lc = -log_prob_pos * mask_neg_avg
    ce_loss_neg_lc = -log_prob_neg * mask_pos_avg

    # Positive side (label-correction term is weighted by anchors that have
    # negatives, since it is built from the negative mask).
    pos_clean = _masked_mean(torch.sum(gate_unspv_pos * ce_loss_pos, dim=1), sample_weight_pos)
    pos_noise = _masked_mean(torch.sum((1. - gate_unspv_pos) * ce_loss_pos, dim=1), sample_weight_pos)
    pos_lc = _masked_mean(torch.sum(gate_pos_lc * ce_loss_pos_lc, dim=1), sample_weight_neg)
    xent_loss_pos = pos_clean + _grad_match([pos_clean, pos_noise, pos_lc]) * meta_weight

    # Negative side (label-correction term built from the positive mask).
    neg_clean = _masked_mean(torch.sum(gate_unspv_neg * ce_loss_neg, dim=1), sample_weight_neg)
    neg_noise = _masked_mean(torch.sum((1. - gate_unspv_neg) * ce_loss_neg, dim=1), sample_weight_neg)
    neg_lc = _masked_mean(torch.sum(gate_neg_lc * ce_loss_neg_lc, dim=1), sample_weight_pos)
    xent_loss_neg = neg_clean + _grad_match([neg_clean, neg_noise, neg_lc]) * meta_weight

    # Soft-target cross-entropy against the unsupervised similarity.
    unspv_loss_pos = _masked_mean(torch.sum(-log_prob_pos * prob_target_unspv_pos, dim=1), sample_weight_pos)
    unspv_loss_neg = _masked_mean(torch.sum(-log_prob_neg * prob_target_unspv_neg, dim=1), sample_weight_neg)

    xent_loss = (xent_loss_neg + unspv_loss_neg) + (xent_loss_pos + unspv_loss_pos)
    return l2loss + xent_loss

In [11]:
import numpy as np
def stat_sim_mat_v2(v):
    """Summarise the distribution of the values in ``v``.

    Returns:
        ``(min, max, mean, median, levels, percentiles)`` where ``levels`` is
        the fixed array of percentile levels evaluated and ``percentiles``
        holds ``np.percentile(v, levels)``.
    """
    levels = np.array([2, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 98])
    summary = (np.min(v), np.max(v), np.mean(v), np.median(v))
    return summary + (levels, np.percentile(v, levels))

In [None]:
def stat_sim_mat(v):
    """Bin ``v`` to one decimal place and count how often each bin occurs.

    Values are scaled by 10 and rounded with ``np.round`` (round half to
    even), so the returned bin labels are in units of 0.1.

    Returns:
        ``(bins, counts)`` from ``np.unique(..., return_counts=True)``.
    """
    bins, counts = np.unique(np.round(v * 10.), return_counts=True)
    return bins, counts

In [13]:
# Draw 100k samples uniformly from [-1, 1) and histogram them at 0.1 resolution.
v = 2. * np.random.rand(100000) - 1.
stat_sim_mat(v)

(array([-10.,  -9.,  -8.,  -7.,  -6.,  -5.,  -4.,  -3.,  -2.,  -1.,  -0.,
          1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.]),
 array([2486, 5023, 4963, 4992, 5063, 4883, 5059, 4973, 5043, 4977, 5076,
        4966, 4884, 5033, 4970, 5044, 5043, 5021, 5008, 4994, 2499]))

In [14]:
np.round(10. * (-0.35))  # numpy rounds half to even: -3.5 -> -4.0

-4.0

In [15]:
np.round(10. * (-0.34))  # -3.4 rounds to -3.0

-3.0

In [16]:
stat_sim_mat_v2(v)  # (min, max, mean, median, percentile levels, percentile values)

(-0.9999843372547392,
 0.9999855061806806,
 0.0006476926796734654,
 0.00041907590330292344,
 array([ 2,  5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 98]),
 array([-9.59512230e-01, -9.00404185e-01, -7.99478725e-01, -5.99975564e-01,
        -3.98982849e-01, -1.98410591e-01,  4.19075903e-04,  2.00990077e-01,
         4.02902092e-01,  6.01259393e-01,  8.00619520e-01,  9.00633024e-01,
         9.60487017e-01]))

In [17]:
# Scratch arithmetic: 128 * 60000 after cancelling one factor of 128.
128 * 128 * 60000 / 128

7680000

In [18]:
np.random.beta(a=32, b=32, size=100)  # 100 draws from Beta(32, 32), centred on 0.5

array([0.51659246, 0.49376934, 0.3792147 , 0.44338098, 0.60585649,
       0.55737313, 0.5348576 , 0.53472207, 0.40176676, 0.50525578,
       0.58404559, 0.47501527, 0.47971671, 0.4993767 , 0.40948152,
       0.53372301, 0.55832021, 0.51466041, 0.48120701, 0.52186655,
       0.53313782, 0.50681102, 0.51702175, 0.43891825, 0.50798977,
       0.37407591, 0.36704216, 0.64532432, 0.36081241, 0.51436255,
       0.47910773, 0.57868547, 0.4169693 , 0.4900508 , 0.46897939,
       0.50849507, 0.53130475, 0.33823885, 0.42791037, 0.50275684,
       0.48166469, 0.43064536, 0.56555076, 0.61207521, 0.4913266 ,
       0.50763468, 0.40263366, 0.50088329, 0.52251777, 0.54730101,
       0.47770008, 0.47790514, 0.51243745, 0.53064291, 0.55369189,
       0.48523546, 0.42491078, 0.56233045, 0.5209299 , 0.47378245,
       0.55242614, 0.36478064, 0.54527767, 0.39869996, 0.45112624,
       0.61577146, 0.53706447, 0.45521491, 0.44147735, 0.55630267,
       0.58552156, 0.53391271, 0.55576919, 0.50212352, 0.47082

In [None]:
def gen_loss_sub(mask,log_prob,bias=0.):
    """Masked cross-entropy, averaged per row and then over valid rows.

    Each row's targets are ``mask[i]`` divided by its row sum (plus 1e-6),
    and the row loss is ``-sum_j target_ij * log_prob_ij``. A row is valid
    when ``sum_j mask_ij + bias > 0``. The result is the sum of valid row
    losses divided by (number of valid rows + 1e-6).
    """
    row_total = torch.sum(mask, dim=1, keepdim=True)
    targets = mask / (row_total + 1e-6)
    per_row = torch.sum(-log_prob * targets, dim=1)

    # 1.0 for rows whose (mask count + bias) is positive, 0 otherwise.
    shifted = row_total.squeeze(1) + float(bias)
    valid = torch.sign(torch.maximum(0.0 * shifted, shifted))

    return torch.sum(per_row * valid) / (torch.sum(valid) + 1e-6)

def gen_loss_sub_unspv(log_prob,prob_target):
    """Soft-target cross-entropy ``-sum_j t_ij log p_ij``, averaged over rows."""
    per_row = (-log_prob * prob_target).sum(dim=1)
    return per_row.mean()

def gen_meta_loss(xent_loss_pos_list,net):
    """Gradient-agreement penalty between a reference loss and the others.

    Takes the gradient of every loss in ``xent_loss_pos_list`` with respect
    to ``net.parameters()`` and returns
    ``sum_{r>=1} -<g_0, g_r> / (|g_0| * |g_r|)``. This is the summed negative
    cosine similarity between the first (reference, "clean") gradient and
    each of the others.

    The reference gradient is a fixed target: it is computed without a graph.
    The other gradients are built with ``create_graph=True``, so the penalty
    is differentiable with respect to the network parameters. Without
    ``create_graph`` every gradient is a constant and the penalty has no
    effect on training.

    Args:
        xent_loss_pos_list: sequence of scalar losses (reference first).
        net: module whose parameters the gradients are taken against.

    Returns:
        The penalty (a tensor; ``0.`` if only one loss is given).

    Note:
        Relies on ``get_grad_norm`` / ``get_grad_prod`` defined elsewhere.
        NOTE(review): divides by zero if a gradient is all-zero unless
        ``get_grad_norm`` adds an epsilon — confirm.
    """
    params = list(net.parameters())
    grad_list = []
    grad_norm_list = []
    for idx, loss_r in enumerate(xent_loss_pos_list):
        # Only non-reference gradients need to stay differentiable.
        grad_r = torch.autograd.grad(loss_r, params, retain_graph=True,
                                     create_graph=idx > 0, allow_unused=True)
        grad_list.append(grad_r)
        grad_norm_list.append(get_grad_norm(grad_r))

    loss_grad_match = 0.
    for r in range(1, len(grad_list)):
        grad_prod = get_grad_prod(grad_list[0], grad_list[r])
        loss_grad_match += -grad_prod / (grad_norm_list[0] * grad_norm_list[r])
    return loss_grad_match
    
    
def anchored_npair_loss_meta_2_tmp_step_fix_clean(inputs_logit, y_true, hidden, net):
    """Anchored N-pair loss with fixed similarity thresholds and meta terms.

    Pairs are gated by the unsupervised cosine similarity of ``hidden``
    against two fixed thresholds (``low_threshold`` < ``high_threshold``):

    * positive clean: label-positives with sim > low_threshold;
    * positive noise: the other label-positives;
    * positive label correction: label-negatives with sim > high_threshold;
    * negative clean: label-negatives with sim < high_threshold;
    * negative noise: the other label-negatives;
    * negative label correction: label-positives with sim < low_threshold.

    Each side's clean loss is combined with ``gen_meta_loss`` over its three
    sub-losses (weight 0.001). Unsupervised soft-target terms are added with
    weight 0.01.

    Args:
        inputs_logit: 2-D ``(batch, dim)`` embeddings being trained.
        y_true: 2-D ``(batch, num_classes)`` label matrix; presumably one-hot
            — TODO confirm.
        hidden: 2-D ``(batch, dim')`` features used only for targets and
            gates; detached, so no gradient flows through them.
        net: module passed to ``gen_meta_loss``.

    Returns:
        Scalar loss tensor.
    """
    # Control parameters
    eps = 1e-10
    with_l2reg = True
    alpha = 1
    meta_weight = 0.001
    unspv_weight = 0.01
    # Fixed gate thresholds on the unsupervised similarity; presumably taken
    # from its percentile statistics (cf. stat_sim_mat_v2) — TODO confirm.
    low_threshold = 0.75454967
    high_threshold = 0.85592163

    # Label-derived pair masks.
    mask = torch.einsum("ik,jk->ij", y_true, y_true)
    mask_neg = 1. - mask

    # Unsupervised similarity and target probabilities. Detached: these
    # features only provide targets and gates.
    input_unspv = F.normalize(hidden, p=2, dim=1).detach()
    sim_unspv = torch.einsum("ik,jk->ij", input_unspv, input_unspv)
    prob_target_unspv_pos = F.softmax(alpha * sim_unspv, dim=1)
    prob_target_unspv_neg = F.softmax(alpha * -sim_unspv, dim=1)

    # Supervised similarity and log-probabilities.
    similarity_matrix = torch.einsum("ik,jk->ij", inputs_logit, inputs_logit)
    log_prob_pos = torch.log(F.softmax(similarity_matrix, dim=1) + eps)
    log_prob_neg = torch.log(F.softmax(-similarity_matrix, dim=1) + eps)

    # Binary gates (1 where the inequality holds strictly).
    gate_unspv_pos = torch.sign(torch.clamp(sim_unspv - low_threshold, min=0.))
    gate_unspv_neg = torch.sign(torch.clamp(high_threshold - sim_unspv, min=0.))
    gate_pos_label_correction = torch.sign(torch.clamp(sim_unspv - high_threshold, min=0.))
    gate_neg_label_correction = torch.sign(torch.clamp(low_threshold - sim_unspv, min=0.))

    # L2 penalty on the embeddings.
    if with_l2reg:
        reg = torch.mean(torch.sum(torch.pow(inputs_logit, 2), dim=1)).float()
        l2loss = torch.mul(0.25 * 0.002, reg)
    else:
        l2loss = 0.0

    # Unsupervised soft-target losses.
    unspv_loss_pos = gen_loss_sub_unspv(log_prob_pos, prob_target_unspv_pos)
    unspv_loss_neg = gen_loss_sub_unspv(log_prob_neg, prob_target_unspv_neg)

    # Positive side; bias=-1 excludes anchors whose only gated positive is
    # (presumably) themselves.
    xent_loss_pos_clean = gen_loss_sub(mask * gate_unspv_pos, log_prob_pos, bias=-1)
    xent_loss_pos_noise = gen_loss_sub(mask * (1. - gate_unspv_pos), log_prob_pos, bias=0.)
    xent_loss_pos_label_correction = gen_loss_sub(mask_neg * gate_pos_label_correction, log_prob_pos, bias=0.)
    loss_grad_match = gen_meta_loss(
        [xent_loss_pos_clean, xent_loss_pos_noise, xent_loss_pos_label_correction], net)
    xent_loss_pos = xent_loss_pos_clean + loss_grad_match * meta_weight

    # Negative side; label correction uses the positive-label mask `mask`
    # (the previous code referenced an undefined name `mask_pos`).
    xent_loss_neg_clean = gen_loss_sub(mask_neg * gate_unspv_neg, log_prob_neg, bias=0.)
    xent_loss_neg_noise = gen_loss_sub(mask_neg * (1. - gate_unspv_neg), log_prob_neg, bias=0.)
    xent_loss_neg_label_correction = gen_loss_sub(mask * gate_neg_label_correction, log_prob_neg, bias=0.)
    loss_grad_match = gen_meta_loss(
        [xent_loss_neg_clean, xent_loss_neg_noise, xent_loss_neg_label_correction], net)
    xent_loss_neg = xent_loss_neg_clean + loss_grad_match * meta_weight

    # Final loss.
    xent_loss_neg_all = xent_loss_neg + unspv_weight * unspv_loss_neg
    xent_loss_pos_all = xent_loss_pos + unspv_weight * unspv_loss_pos
    return l2loss + (xent_loss_neg_all + xent_loss_pos_all)


