In [2]:
import numpy as np
from typing import List
import torch
import torch.nn.functional as F
from typing import List, Optional
from collections import Counter
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
def compute_uce(probs, targets, n_bins=100):
    _, nattrs =probs.size()
    nattrs = torch.tensor(nattrs)
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    uce = 0
    bin_uncertainties = []
    bin_errors = []
    prop_in_bin_values = []
    bin_n_samples = []
    bin_variances = []
    # Compute the uncertainty values (entropy)
    uncertainties = (1/torch.log(nattrs))*(-torch.sum(probs * torch.log(probs + 1e-12), dim=1))
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (uncertainties >= bin_lower) * (uncertainties < bin_upper)
        prop_in_bin = in_bin.float().mean()
        prop_in_bin_values.append(prop_in_bin.item() if prop_in_bin.item() > 0 else None)
        if prop_in_bin.item() > 0:
            sample_indices = torch.where(in_bin)[0]
            bin_targets = targets[sample_indices]
            bin_probs = probs[sample_indices]
            error_in_bin = (bin_targets != torch.argmax(bin_probs, dim=1)).float().mean()
            avg_uncertainty_in_bin = uncertainties[in_bin].mean()
            uce += torch.abs(avg_uncertainty_in_bin - error_in_bin) * prop_in_bin
            bin_uncertainties.append(avg_uncertainty_in_bin.item())
            bin_errors.append(error_in_bin.item())
            n_samples_in_bin = sample_indices.size(0)
            bin_n_samples.append(n_samples_in_bin)
            bin_variances.append(torch.var((bin_targets != torch.argmax(bin_probs, dim=1)).float()).item())
        else:
            bin_uncertainties.append(None)
            bin_errors.append(None)
            bin_n_samples.append(None)
            bin_variances.append(None)

    return uce, bin_uncertainties, bin_errors, prop_in_bin_values, bin_n_samples, bin_variances


def find_error_rates(uncertainties, bin_uncertainties, bin_errors):
    error_rates = []
    for uncertainty in uncertainties:
        found = False
        for idx, (bin_uncertainty_lower, bin_uncertainty_upper, bin_error) in enumerate(zip(bin_uncertainties[:-1], bin_uncertainties[1:], bin_errors)):
            if bin_uncertainty_lower is not None and bin_uncertainty_upper is not None and bin_error is not None:
                if bin_uncertainty_lower <= uncertainty.item() < bin_uncertainty_upper:

                    error_rates.append(bin_error)
                    found = True
                    break
        if not found:
            found_= False
            if bin_uncertainties[0] is not None:
                if 0 <=uncertainty< bin_uncertainties[0]:
                    error_rates.append(bin_errors[0])
                    found_= True
                else:
                    for bin_error in reversed(bin_errors):
                        if bin_error is not None:
                            error_rates.append(bin_error)
                            found_= True
                            break
            else:
                if bin_uncertainties[1] is not None:
                    error_rates.append(bin_errors[1])
                    found_= True
                else:
                    error_rates.append(bin_errors[2])
                    found_= True

            if not found_:
                if bin_uncertainties[4] is not None and 0 <=uncertainty< bin_uncertainties[4]:
                    error_rates.append(bin_errors[4])
                    found_= True
            if not found_:
                if bin_uncertainties[8] is not None and bin_uncertainties[8] <=uncertainty< 1:
                    error_rates.append(bin_errors[8])
                    found_= True


    return torch.tensor(error_rates)

def accuracy(y_true, y_pred):

    # 計算正確預測的數量
    correct_predictions = torch.sum(y_true == y_pred)

    # 計算準確度
    accuracy = correct_predictions.item() / y_true.size(0)

    return accuracy
def choose_best_expert_ex(probs_expert1, probs_expert2, targets,val_uce_list_ep1,val_uce_list_ep2, n_bins=10):
    # Compute UCE and bin values for both experts

#     uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1, bin_variances_ep1 = compute_uce(probs_expert1, targets, n_bins)
#     uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2, bin_variances_ep2 = compute_uce(probs_expert2, targets, n_bins)

    uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1,bin_variances_ep1 = val_uce_list_ep1[0],val_uce_list_ep1[1],val_uce_list_ep1[2],val_uce_list_ep1[3],val_uce_list_ep1[4],val_uce_list_ep1[5]
    uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2,bin_variances_ep2 = val_uce_list_ep2[0],val_uce_list_ep2[1],val_uce_list_ep2[2],val_uce_list_ep2[3],val_uce_list_ep2[4],val_uce_list_ep2[5]

    # Compute uncertainties for both experts
    _, nattrs = probs_expert1.size()
    nattrs = torch.tensor(nattrs)
    uncertainties_expert1 = (1/torch.log(nattrs))*(-torch.sum(probs_expert1 * torch.log(probs_expert1 + 1e-12), dim=1))
    uncertainties_expert2 = (1/torch.log(nattrs))*(-torch.sum(probs_expert2 * torch.log(probs_expert2 + 1e-12), dim=1))
    # Find error rates for both experts
    error_rates_expert1 = find_error_rates(uncertainties_expert1, bin_uncertainties_expert1, bin_errors_expert1)
    error_rates_expert2 = find_error_rates(uncertainties_expert2, bin_uncertainties_expert2, bin_errors_expert2)
    # Choose the expert with lower error rate for each sample
    chosen_expert = (error_rates_expert1 < error_rates_expert2)

    # Get the predictions from both experts
    preds_expert1 = torch.argmax(probs_expert1, dim=1)
    preds_expert2 = torch.argmax(probs_expert2, dim=1)

    # Choose the final prediction based on the chosen expert
    final_predictions = torch.where(chosen_expert, preds_expert1, preds_expert2)

    return final_predictions


def choose_best_three_expert(probs_expert1,probs_expert2,probs_expert3 ,targets,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3, n_bins=10):


#     uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1, bin_variances_ep1 = compute_uce(probs_expert1, targets, n_bins)
#     uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2, bin_variances_ep2 = compute_uce(probs_expert2, targets_pairs, n_bins)
#     uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3, bin_variances_ep3 = compute_uce(probs_expert3, targets, n_bins)
    uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1,bin_variances_ep1 = val_uce_list_ep1[0],val_uce_list_ep1[1],val_uce_list_ep1[2],val_uce_list_ep1[3],val_uce_list_ep1[4],val_uce_list_ep1[5]
    uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2,bin_variances_ep2 = val_uce_list_ep2[0],val_uce_list_ep2[1],val_uce_list_ep2[2],val_uce_list_ep2[3],val_uce_list_ep2[4],val_uce_list_ep2[5]
    uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3,bin_variances_ep3 = val_uce_list_ep3[0],val_uce_list_ep3[1],val_uce_list_ep3[2],val_uce_list_ep3[3],val_uce_list_ep3[4],val_uce_list_ep3[5]



    # Compute uncertainties for both experts
    _, nattrs = probs_expert1.size()
    nattrs = torch.tensor(nattrs)
    uncertainties_expert1 = (1/torch.log(nattrs))*(-torch.sum(probs_expert1 * torch.log(probs_expert1 + 1e-12), dim=1))
    uncertainties_expert2 = (1/torch.log(nattrs))*(-torch.sum(probs_expert2 * torch.log(probs_expert2 + 1e-12), dim=1))
    uncertainties_expert3 = (1/torch.log(nattrs))*(-torch.sum(probs_expert3 * torch.log(probs_expert3 + 1e-12), dim=1))

    # Find error rates for both experts
    error_rates_expert1 = find_error_rates(uncertainties_expert1, bin_uncertainties_expert1, bin_errors_expert1)
    error_rates_expert2 = find_error_rates(uncertainties_expert2, bin_uncertainties_expert2, bin_errors_expert2)
    error_rates_expert3 = find_error_rates(uncertainties_expert3, bin_uncertainties_expert3, bin_errors_expert3)
    # Choose the expert with lower error rate for each sample

    # Get the predictions from both experts
    preds_expert1 = torch.argmax(probs_expert1, dim=1)
    preds_expert2 = torch.argmax(probs_expert2, dim=1)
    preds_expert3 = torch.argmax(probs_expert3, dim=1)

    # 將三個錯誤率堆疊成一個張量
    error_rates = torch.stack([error_rates_expert1, error_rates_expert2, error_rates_expert3])

    # 找出最小錯誤率的索引
    _, min_error_rate_indices = torch.min(error_rates, dim=0)

    # 根據最小錯誤率的索引選擇最終的預測
    final_predictions = torch.where(min_error_rate_indices == 0, preds_expert1,
                                    torch.where(min_error_rate_indices == 1, preds_expert2, preds_expert3))
    
    
    POE_probs_ = torch.stack([probs_expert1, probs_expert2,probs_expert3])
    POE_probs = product_of_experts(POE_probs_)
    POE_pred = np.argmax(POE_probs, axis=1)
    
    POE_pred = torch.tensor(POE_pred)  # Convert numpy array to torch tensor

    SOE_probs_ = (probs_expert1+probs_expert2+probs_expert3)/3
    SOE_pred = np.argmax(SOE_probs_, axis=1)
    
    
    

    initial_predictions_probs = torch.where(min_error_rate_indices.unsqueeze(-1) == 0, probs_expert1,
                                           torch.where(min_error_rate_indices.unsqueeze(-1) == 1, probs_expert2, probs_expert3))
    
    uce_expert_SPE, _, _, _,_,_ =compute_uce(initial_predictions_probs, targets)
    uce_expert_SOE, _, _, _,_,_ =compute_uce(SOE_probs_, targets)
    uce_expert_POE, _, _, _,_,_ =compute_uce(POE_probs, targets)
    print("SPE_UCE: ",uce_expert_SPE,"SOE_UCE: ",uce_expert_SOE,"POE_UCE: ",uce_expert_POE)
    return final_predictions

def voting(preds_expert1: List[int], preds_expert2: List[int], preds_expert3: List[int], default_expert: int) -> torch.Tensor:
    assert default_expert in [1, 2, 3], "Default expert must be either 1, 2, or 3"

    final_preds = []
    for p1, p2, p3 in zip(preds_expert1, preds_expert2, preds_expert3):
        vote_counts = Counter([p1, p2, p3])
        max_vote_count = max(vote_counts.values())
        most_common = [k for k, v in vote_counts.items() if v == max_vote_count]

        if len(most_common) > 1:
            if default_expert == 1:
                final_preds.append(p1)
            elif default_expert == 2:
                final_preds.append(p2)
            else:  # default_expert == 3
                final_preds.append(p3)
        else:
            final_preds.append(most_common[0])

    final_preds = torch.tensor(final_preds, dtype=torch.int64)  # or your desired data type
    return final_preds
def weighted_voting(preds_expert1: List[int], preds_expert2: List[int], preds_expert3: List[int], weights: List[float], default_expert: int) -> torch.Tensor:
    assert default_expert in [1, 2, 3], "Default expert must be either 1, 2, or 3"

    final_preds = []
    for p1, p2, p3 in zip(preds_expert1, preds_expert2, preds_expert3):
        weighted_vote_counts = Counter()
        for pred, weight in zip([p1, p2, p3], weights):
            weighted_vote_counts[pred] += weight

        max_vote_count = max(weighted_vote_counts.values())
        most_common = [k for k, v in weighted_vote_counts.items() if v == max_vote_count]

        if len(most_common) > 1:
            if default_expert == 1:
                final_preds.append(p1)
            elif default_expert == 2:
                final_preds.append(p2)
            else:  # default_expert == 3
                final_preds.append(p3)
        else:
            final_preds.append(most_common[0])

    final_preds = torch.tensor(final_preds, dtype=torch.int64)  # or your desired data type
    return final_preds
def product_of_experts(predictions):
    # Multiply predictions together
    product = torch.prod(predictions, dim=0)

    # Normalize result
    product /= torch.sum(product)

    return product

def weighted_voting_(prob_expert1: List[float], prob_expert2: List[float], weights: List[float], default_expert: int) -> torch.Tensor:
    assert default_expert in [1, 2], "Default expert must be either 1 or 2"

    final_probs = []
    for p1, p2 in zip(prob_expert1, prob_expert2):
        weighted_probs = Counter()
        for prob, weight in zip([p1, p2], weights):
            weighted_probs[prob] += weight

        max_prob_count = max(weighted_probs.values())
        most_common = [k for k, v in weighted_probs.items() if v == max_prob_count]

        if len(most_common) > 1:
            if default_expert == 1:
                final_probs.append(p1)
            else:  # default_expert == 2
                final_probs.append(p2)
        else:
            final_probs.append(most_common[0])

    final_probs = torch.tensor(final_probs, dtype=torch.float32)  # or your desired data type
    return final_probs

def voting_(prob_expert1: Optional[List[float]], prob_expert2: Optional[List[float]], default_expert: int) -> torch.Tensor:
    assert default_expert in [1, 2], "Default expert must be either 1 or 2"

    final_probs = []
    for p1, p2 in zip(prob_expert1, prob_expert2):
        prob_counts = Counter([p for p in [p1, p2] if p is not None])
        if not prob_counts:  # All predictions are None
            final_probs.append(None)
            continue

        max_prob_count = max(prob_counts.values())
        most_common = [k for k, v in prob_counts.items() if v == max_prob_count]

        if len(most_common) > 1:
            if default_expert == 1:
                final_probs.append(p1)
            else:  # default_expert == 2
                final_probs.append(p2)
        else:
            final_probs.append(most_common[0])

    final_probs = torch.tensor(final_probs, dtype=torch.float32)  # or your desired data type
    return final_probs

# def plot_dot_UCE_diagram(uce_value, bin_uncertainties, bin_errors, prop_in_bin_values, bin_n_samples, bin_variances, model_index, threshold=0.005):
#     global save_name
#     plt.figure(figsize=(6, 6))
#     plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect calibration")
    
#     # 筛选prop_in_bin值大于等于threshold的点
#     valid_indices = [i for i, prop in enumerate(prop_in_bin_values) if prop is not None and prop >= threshold]
#     valid_bin_uncertainties = [bin_uncertainties[i] for i in valid_indices]
#     valid_bin_errors = [bin_errors[i] for i in valid_indices]
#     valid_prop_in_bin_values = [prop_in_bin_values[i] for i in valid_indices]
#     valid_bin_n_samples  = [bin_n_samples[i] for i in valid_indices]
#     valid_bin_variances  = [bin_variances[i] for i in valid_indices]
    
#     plt.scatter(valid_bin_uncertainties, valid_bin_errors, marker='o', color='blue', label="Model {}".format(model_index ))
#     plt.xlabel("Uncertainty", fontsize=14)
#     plt.ylabel("Error", fontsize=14)
#     plt.title("Reliability Diagram for Model {} (UCE={:.4f})".format(model_index , uce_value.item()), fontsize=16)
#     plt.xlim(0, 1)
#     plt.ylim(0, 1)
#     plt.xticks(np.arange(0, 1.1, 0.1), fontsize=12)
#     plt.yticks(np.arange(0, 1.1, 0.1), fontsize=12)
#     plt.grid(color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
#     plt.gca().set_axisbelow(True)
#     plt.legend(fontsize=12)
#     plt.tight_layout()

#     for i, txt in enumerate(valid_bin_n_samples):
#         plt.annotate("n={}".format(txt), (valid_bin_uncertainties[i], valid_bin_errors[i]), fontsize=8, ha='center', va='bottom', textcoords="offset points", xytext=(0,5))
#         plt.annotate("var={:.2f}".format(valid_bin_variances[i]), (valid_bin_uncertainties[i], valid_bin_errors[i]), fontsize=8, ha='center', va='bottom', textcoords="offset points", xytext=(0,20))
#     Path('plt/').mkdir(parents=True, exist_ok=True)
#     plt.savefig('plt/'+save_name +"UCE_model_{}.svg".format(model_index))
#     plt.close()
    
from matplotlib.colors import Normalize
import matplotlib.cm as cm
def plot_dot_UCE_diagram(uce_value, bin_uncertainties, bin_errors, prop_in_bin_values, bin_n_samples, bin_variances, model_index, threshold=0.005):
    global save_name
    plt.figure(figsize=(6, 6))
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect calibration")
    
    # 筛选prop_in_bin值大于等于threshold的点
    valid_indices = [i for i, prop in enumerate(prop_in_bin_values) if prop is not None and prop >= threshold]
    valid_bin_uncertainties = [bin_uncertainties[i] for i in valid_indices]
    valid_bin_errors = [bin_errors[i] for i in valid_indices]
    valid_prop_in_bin_values = [prop_in_bin_values[i] for i in valid_indices]
    valid_bin_n_samples  = [bin_n_samples[i] for i in valid_indices]
    valid_bin_variances  = [bin_variances[i] for i in valid_indices]
    
    # 计算中心点
    centers = [0.05 + 0.1 * i for i in range(10)]
    bins = [0 + 0.1 * i for i in range(11)]  # Including the rightmost edge for binning
    
    # 对bin_uncertainties值进行分箱，并找到对应的中心点
    hist_values = np.digitize(valid_bin_uncertainties, bins) - 1
    hist_centers = [centers[i] for i in hist_values]
    
    # Use Normalize and colormap to change the color of the bars based on bin_n_samples
    norm = Normalize(vmin=min(valid_bin_n_samples), vmax=max(valid_bin_n_samples))
    colormap = cm.cividis   # Changed to plasma colormap
    colors = [colormap(norm(value)) for value in valid_bin_n_samples]
    
    plt.bar(hist_centers, valid_bin_errors, width=0.05, color=colors)
    plt.xlabel("Uncertainty", fontsize=14)
    plt.ylabel("Error", fontsize=14)
    plt.title("Reliability Diagram for Model {} (UCE={:.4f})".format(model_index , uce_value.item()), fontsize=16)
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xticks(np.arange(0, 1.1, 0.1), fontsize=12)
    plt.yticks(np.arange(0, 1.1, 0.1), fontsize=12)
    plt.grid(color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
    plt.gca().set_axisbelow(True)
    plt.legend(fontsize=12)
    plt.tight_layout()
    
    sm = cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])  # You can set a dummy array
    cbar = plt.colorbar(sm, orientation='vertical', label='Number of Samples')
    cbar.set_label('Number of Samples', rotation=270, labelpad=15)

    for i, txt in enumerate(valid_bin_n_samples):
        plt.annotate("n={}".format(txt), (hist_centers[i], valid_bin_errors[i]), fontsize=8, ha='center', va='bottom', textcoords="offset points", xytext=(0,5))
#         plt.annotate("var={:.2f}".format(valid_bin_variances[i]), (hist_centers[i], valid_bin_errors[i]), fontsize=8, ha='center', va='bottom', textcoords="offset points", xytext=(0,20))

    Path('plt/').mkdir(parents=True, exist_ok=True)
    plt.savefig('plt/'+save_name +"UCE_model_{}.svg".format(model_index))
    plt.close()

### 原本的

In [3]:
def cal_val_state(ep1_logits,ep2_logits,ep3_logits,labels,phase):
  global val_uce_list_ep1
  global val_uce_list_ep2
  global val_uce_list_ep3

  ep1_logits = F.softmax(ep1_logits, dim=1)
  ep2_logits = F.softmax(ep2_logits, dim=1)
  ep3_logits = F.softmax(ep3_logits, dim=1)

  if phase=='val':
    uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1, bin_variances_ep1 = compute_uce(ep1_logits, labels)
    val_uce_list_ep1.append(uce_expert1)
    val_uce_list_ep1.append(bin_uncertainties_expert1)
    val_uce_list_ep1.append(bin_errors_expert1)
    val_uce_list_ep1.append(prop_in_bin_values_expert1)
    val_uce_list_ep1.append(bin_n_samples_ep1)
    val_uce_list_ep1.append(bin_variances_ep1)

    uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2, bin_variances_ep2 = compute_uce(ep2_logits, labels)
    val_uce_list_ep2.append(uce_expert2)
    val_uce_list_ep2.append(bin_uncertainties_expert2)
    val_uce_list_ep2.append(bin_errors_expert2)
    val_uce_list_ep2.append(prop_in_bin_values_expert2)
    val_uce_list_ep2.append(bin_n_samples_ep2)
    val_uce_list_ep2.append(bin_variances_ep2)

    uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3, bin_variances_ep3 = compute_uce(ep3_logits, labels)
    val_uce_list_ep3.append(uce_expert3)
    val_uce_list_ep3.append(bin_uncertainties_expert3)
    val_uce_list_ep3.append(bin_errors_expert3)
    val_uce_list_ep3.append(prop_in_bin_values_expert3)
    val_uce_list_ep3.append(bin_n_samples_ep3)
    val_uce_list_ep3.append(bin_variances_ep3)
    
#     tabel_pred_ep123 = choose_best_three_expert(ep1_logits,ep2_logits,ep3_logits,labels,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3)
#     tabel_acc_ep123 = accuracy(tabel_pred_ep123,labels)
    
#     tabel_pred_ep123_new = choose_best_three_expert_new(ep1_logits,ep2_logits,ep3_logits,labels,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3)
#     tabel_acc_ep123_new = accuracy(tabel_pred_ep123_new,labels)
    
#     print("123: " ,tabel_acc_ep123 ,"123_new: ",tabel_acc_ep123_new)
  else:
    uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1, bin_variances_ep1 = compute_uce(ep1_logits, labels)
    uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2, bin_variances_ep2 = compute_uce(ep2_logits, labels)
    uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3, bin_variances_ep3 = compute_uce(ep3_logits, labels)
    
    plot_dot_UCE_diagram( uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1, bin_variances_ep1,1)
    plot_dot_UCE_diagram( uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2, bin_variances_ep2,2)
    plot_dot_UCE_diagram( uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3, bin_variances_ep3,3)
    
    preds_expert1 = torch.argmax(ep1_logits, dim=1)
    expert1_acc= accuracy(preds_expert1,labels)

    preds_expert2 = torch.argmax(ep2_logits, dim=1)
    expert2_acc= accuracy(preds_expert2,labels)

    preds_expert3 = torch.argmax(ep3_logits, dim=1)
    expert3_acc= accuracy(preds_expert3,labels)

    print(phase,'expert1_acc: ',expert1_acc,'expert2_acc: ',expert2_acc,'expert3_acc: ',expert3_acc)
    print("1: ",uce_expert1," 2: ",uce_expert2,' 3: ',uce_expert3)


    #計算table準確度
    table_pred_ep12 = choose_best_expert_ex(ep1_logits, ep2_logits,labels,val_uce_list_ep1,val_uce_list_ep2)
    table_acc_ep12 = accuracy(table_pred_ep12,labels)

    table_expert_ep13 = choose_best_expert_ex(ep1_logits, ep3_logits, labels,val_uce_list_ep1,val_uce_list_ep3)
    table_acc_ep13 = accuracy(table_expert_ep13,labels)

    table_pred_ep23 = choose_best_expert_ex(ep2_logits, ep3_logits,labels,val_uce_list_ep2,val_uce_list_ep3)
    table_acc_ep23 = accuracy(table_pred_ep23,labels)

    # 計算3位專家綜合準確度
    tabel_pred_ep123 = choose_best_three_expert(ep1_logits,ep2_logits,ep3_logits,labels,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3)
    tabel_acc_ep123 = accuracy(tabel_pred_ep123,labels)
    
    tabel_pred_ep123_new = choose_best_three_expert_new(ep1_logits,ep2_logits,ep3_logits,labels,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3)
    tabel_acc_ep123_new = accuracy(tabel_pred_ep123_new,labels)


    # 簡單投票，都不一樣選2
    simple_voting_pred_12 = voting_(preds_expert1, preds_expert2, default_expert=1)
    simple_voting_acc_12 = accuracy(simple_voting_pred_12, labels)

    simple_voting_pred_23 = voting_(preds_expert2, preds_expert3, default_expert=1)
    simple_voting_acc_23 = accuracy(simple_voting_pred_23, labels)

    simple_voting_pred_13 = voting_(preds_expert1, preds_expert3, default_expert=1)
    simple_voting_acc_13 = accuracy(simple_voting_pred_13, labels)

    # 回到原始函數，進行三個專家的簡單投票
    simple_voting_pred_123 = voting(preds_expert1, preds_expert2, preds_expert3, default_expert=1)
    simple_voting_acc_123 = accuracy(simple_voting_pred_123, labels)

    # 加權投票，都不一樣選2
    weighted_voting_pred_12 = weighted_voting_(preds_expert1, preds_expert2, weights=[0.5, 0.5], default_expert=1)
    weighted_voting_acc_12 = accuracy(weighted_voting_pred_12, labels)

    weighted_voting_pred_23 = weighted_voting_(preds_expert2, preds_expert3, weights=[0.5, 0.5], default_expert=1)
    weighted_voting_acc_23 = accuracy(weighted_voting_pred_23, labels)

    weighted_voting_pred_13 = weighted_voting_(preds_expert1, preds_expert3, weights=[0.5, 0.5], default_expert=1)
    weighted_voting_acc_13 = accuracy(weighted_voting_pred_13, labels)

    # 回到原始函數，進行三個專家的加權投票
    weighted_voting_pred_123 = weighted_voting(preds_expert1, preds_expert2, preds_expert3, weights=[0.3, 0.4, 0.3], default_expert=2)
    weighted_voting_acc_123 = accuracy(weighted_voting_pred_123, labels)

    SOE_probs_ = (ep1_logits+ep2_logits+ep3_logits)/3
    SOE_pred = np.argmax(SOE_probs_, axis=1)
    SOE_acc_123 =accuracy(SOE_pred,labels)

    POE_probs_ = torch.stack([ep1_logits, ep2_logits,ep3_logits])
    POE_probs = product_of_experts(POE_probs_)
    POE_pred = np.argmax(POE_probs, axis=1)
    POE_acc_123 =accuracy(POE_pred,labels)

    def calculate_SOE_POE(ep1, ep2, ep3=None):
        if ep3 is not None:
            SOE_probs = (ep1 + ep2 + ep3) / 3
            POE_probs = product_of_experts(torch.stack([ep1, ep2, ep3]))
        else:
            SOE_probs = (ep1 + ep2) / 2
            POE_probs = product_of_experts(torch.stack([ep1, ep2]))

        SOE_pred = torch.argmax(SOE_probs, dim=1)
        POE_pred = torch.argmax(POE_probs, dim=1)

        return SOE_pred, POE_pred
    SOE_pred_12, POE_pred_12 = calculate_SOE_POE(ep1_logits, ep2_logits)
    SOE_pred_23, POE_pred_23 = calculate_SOE_POE(ep2_logits, ep3_logits)
    SOE_pred_13, POE_pred_13 = calculate_SOE_POE(ep1_logits, ep3_logits)

    SOE_acc_12 = accuracy(SOE_pred_12, labels)
    SOE_acc_23 = accuracy(SOE_pred_23, labels)
    SOE_acc_13 = accuracy(SOE_pred_13, labels)

    POE_acc_12 = accuracy(POE_pred_12, labels)
    POE_acc_23 = accuracy(POE_pred_23, labels)
    POE_acc_13 = accuracy(POE_pred_13, labels)

    print(f"Table Accuracy:\n12: {table_acc_ep12:.3f} 23: {table_acc_ep23:.3f} 13: {table_acc_ep13:.3f} 123: {tabel_acc_ep123:.4f} 123_new: {tabel_acc_ep123_new:.4f}\n")

    print(f"Simple Voting Accuracy:\n12: {simple_voting_acc_12:.3f} 23: {simple_voting_acc_23:.3f} 13: {simple_voting_acc_13:.3f} 123: {simple_voting_acc_123:.3f}\n")
    print(f"weighted Voting Accuracy:\n12: {weighted_voting_acc_12:.3f} 23: {weighted_voting_acc_23:.3f} 13: {weighted_voting_acc_13:.3f} 123: {weighted_voting_acc_123:.3f}\n")
    print(f"SOE Accuracy:\n12: {SOE_acc_12:.3f} 23: {SOE_acc_23:.3f} 13: {SOE_acc_13:.3f} 123: {SOE_acc_123:.4f}\n")
    print(f"POE Accuracy:\n12: {POE_acc_12:.3f} 23: {POE_acc_23:.3f} 13: {POE_acc_13:.3f} 123: {POE_acc_123:.4f}\n")

    

In [4]:
import os
def choose_best_three_expert(probs_expert1,probs_expert2,probs_expert3 ,targets,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3, n_bins=10):

    
#     uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1, bin_variances_ep1 = compute_uce(probs_expert1, targets, n_bins)
#     uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2, bin_variances_ep2 = compute_uce(probs_expert2, targets_pairs, n_bins)
#     uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3, bin_variances_ep3 = compute_uce(probs_expert3, targets, n_bins)
    uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1,bin_variances_ep1 = val_uce_list_ep1[0],val_uce_list_ep1[1],val_uce_list_ep1[2],val_uce_list_ep1[3],val_uce_list_ep1[4],val_uce_list_ep1[5]
    uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2,bin_variances_ep2 = val_uce_list_ep2[0],val_uce_list_ep2[1],val_uce_list_ep2[2],val_uce_list_ep2[3],val_uce_list_ep2[4],val_uce_list_ep2[5]
    uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3,bin_variances_ep3 = val_uce_list_ep3[0],val_uce_list_ep3[1],val_uce_list_ep3[2],val_uce_list_ep3[3],val_uce_list_ep3[4],val_uce_list_ep3[5]



    # Compute uncertainties for both experts
    _, nattrs = probs_expert1.size()
    nattrs = torch.tensor(nattrs)
    uncertainties_expert1 = (1/torch.log(nattrs))*(-torch.sum(probs_expert1 * torch.log(probs_expert1 + 1e-12), dim=1))
    uncertainties_expert2 = (1/torch.log(nattrs))*(-torch.sum(probs_expert2 * torch.log(probs_expert2 + 1e-12), dim=1))
    uncertainties_expert3 = (1/torch.log(nattrs))*(-torch.sum(probs_expert3 * torch.log(probs_expert3 + 1e-12), dim=1))

    # Find error rates for both experts
    error_rates_expert1 = find_error_rates(uncertainties_expert1, bin_uncertainties_expert1, bin_errors_expert1)
    error_rates_expert2 = find_error_rates(uncertainties_expert2, bin_uncertainties_expert2, bin_errors_expert2)
    error_rates_expert3 = find_error_rates(uncertainties_expert3, bin_uncertainties_expert3, bin_errors_expert3)
    # Choose the expert with lower error rate for each sample

    # Get the predictions from both experts
    preds_expert1 = torch.argmax(probs_expert1, dim=1)
    preds_expert2 = torch.argmax(probs_expert2, dim=1)
    preds_expert3 = torch.argmax(probs_expert3, dim=1)

    # 將三個錯誤率堆疊成一個張量
    error_rates = torch.stack([error_rates_expert1, error_rates_expert2, error_rates_expert3])

    # 找出最小錯誤率的索引
    _, min_error_rate_indices = torch.min(error_rates, dim=0)

    # 根據最小錯誤率的索引選擇最終的預測
    final_predictions = torch.where(min_error_rate_indices == 0, preds_expert1,
                                    torch.where(min_error_rate_indices == 1, preds_expert2, preds_expert3))
    
    compute_mae_error_and_uncertainty(probs_expert1, probs_expert2, probs_expert3, targets,
    uncertainties_expert1, uncertainties_expert2, uncertainties_expert3, error_rates_expert1, error_rates_expert2, error_rates_expert3)
    
    
    
    global save_name
    df = analyze_errors(error_rates, final_predictions, targets)
    
    # Check if the directory 'df' exists, if not, create it
    if not os.path.exists('df'):
        os.makedirs('df')
    
    # Save the dataframe to a CSV file in the 'df' directory
    df.to_csv('df/'+save_name+'_error_analysis.csv', index=True)
    
    
    
    return final_predictions

std_deviation_per_position = None
error_rates = None
POE_pred = None
final_predictions = None

def choose_best_three_expert_new(probs_expert1,probs_expert2,probs_expert3 ,targets,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3, n_bins=10):
    global std_deviation_per_position
    global error_rates
    global threshold_
    global POE_pred
    global final_predictions
    uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1,bin_variances_ep1 = val_uce_list_ep1[0],val_uce_list_ep1[1],val_uce_list_ep1[2],val_uce_list_ep1[3],val_uce_list_ep1[4],val_uce_list_ep1[5]
    uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2,bin_variances_ep2 = val_uce_list_ep2[0],val_uce_list_ep2[1],val_uce_list_ep2[2],val_uce_list_ep2[3],val_uce_list_ep2[4],val_uce_list_ep2[5]
    uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3,bin_variances_ep3 = val_uce_list_ep3[0],val_uce_list_ep3[1],val_uce_list_ep3[2],val_uce_list_ep3[3],val_uce_list_ep3[4],val_uce_list_ep3[5]



    # Compute uncertainties for both experts
    _, nattrs = probs_expert1.size()
    nattrs = torch.tensor(nattrs)
    uncertainties_expert1 = (1/torch.log(nattrs))*(-torch.sum(probs_expert1 * torch.log(probs_expert1 + 1e-12), dim=1))
    uncertainties_expert2 = (1/torch.log(nattrs))*(-torch.sum(probs_expert2 * torch.log(probs_expert2 + 1e-12), dim=1))
    uncertainties_expert3 = (1/torch.log(nattrs))*(-torch.sum(probs_expert3 * torch.log(probs_expert3 + 1e-12), dim=1))

    # Find error rates for both experts
    error_rates_expert1 = find_error_rates(uncertainties_expert1, bin_uncertainties_expert1, bin_errors_expert1)
    error_rates_expert2 = find_error_rates(uncertainties_expert2, bin_uncertainties_expert2, bin_errors_expert2)
    error_rates_expert3 = find_error_rates(uncertainties_expert3, bin_uncertainties_expert3, bin_errors_expert3)
    # Choose the expert with lower error rate for each sample

    # Get the predictions from both experts
    preds_expert1 = torch.argmax(probs_expert1, dim=1)
    preds_expert2 = torch.argmax(probs_expert2, dim=1)
    preds_expert3 = torch.argmax(probs_expert3, dim=1)

    # 將三個錯誤率堆疊成一個張量
    error_rates = torch.stack([error_rates_expert1, error_rates_expert2, error_rates_expert3])
    
    std_deviation_per_position = torch.std(error_rates, dim=0)
    mean_value = torch.mean(std_deviation_per_position)
    print("mean: ",mean_value)
#     print(std_deviation_per_position)
    # 找出最小錯誤率的索引
    _, min_error_rate_indices = torch.min(error_rates, dim=0)
    
    POE_probs_ = torch.stack([probs_expert1, probs_expert2,probs_expert3])
    POE_probs = product_of_experts(POE_probs_)
    POE_pred = np.argmax(POE_probs, axis=1)
    
    POE_pred = torch.tensor(POE_pred)  # Convert numpy array to torch tensor

    SOE_probs_ = (probs_expert1+probs_expert2+probs_expert3)/3
    SOE_pred = np.argmax(SOE_probs_, axis=1)

    # 根據最小錯誤率的索引選擇最終的預測
    final_predictions = torch.where(min_error_rate_indices == 0, preds_expert1,
                                      torch.where(min_error_rate_indices == 1, preds_expert2, preds_expert3))

    initial_predictions_probs = torch.where(min_error_rate_indices.unsqueeze(-1) == 0, probs_expert1,
                                           torch.where(min_error_rate_indices.unsqueeze(-1) == 1, probs_expert2, probs_expert3))
    
    POE_probs_ = torch.stack([POE_probs,SOE_probs_,initial_predictions_probs])
    POE_initial_predictions_probs = product_of_experts(POE_probs_)
    SOE_initial_predictions_probs = (SOE_probs_+POE_probs+initial_predictions_probs)/2
    
    POE_final_predictions = np.argmax(POE_initial_predictions_probs, axis=1)
    
    SOE_final_predictions = np.argmax(SOE_initial_predictions_probs, axis=1)
    
    POE_acc =accuracy(POE_final_predictions,targets)
    SOE_acc =accuracy(SOE_final_predictions,targets)
    print("POE_SPE_acc: ",POE_acc,"SOE_SPE_acc: ",SOE_acc)
#     threshold = torch.quantile(std_deviation_per_position, threshold_)

#     threshold = torch.quantile(uncertainties_, threshold_)
#     print("threshold: ",threshold)
    # 根據 std_deviation_per_position 更新預測
#     final_predictions = torch.where(std_deviation_per_position < threshold, POE_pred, initial_predictions)
#     final_predictions = torch.where(std_deviation_per_position < threshold, initial_predictions , POE_pred)

    return SOE_final_predictions

def compute_uce(probs, targets, n_bins=10):
    _, nattrs =probs.size()
    nattrs = torch.tensor(nattrs)
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    uce = 0
    bin_uncertainties = []
    bin_errors = []
    prop_in_bin_values = []
    bin_n_samples = []
    bin_variances = []
    # Compute the uncertainty values (entropy)
    uncertainties = (1/torch.log(nattrs))*(-torch.sum(probs * torch.log(probs + 1e-12), dim=1))
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (uncertainties >= bin_lower) * (uncertainties < bin_upper)
        prop_in_bin = in_bin.float().mean()
        prop_in_bin_values.append(prop_in_bin.item() if prop_in_bin.item() > 0 else None)
        if prop_in_bin.item() > 0:
            sample_indices = torch.where(in_bin)[0]
            bin_targets = targets[sample_indices]
            bin_probs = probs[sample_indices]
            error_in_bin = (bin_targets != torch.argmax(bin_probs, dim=1)).float().mean()
            avg_uncertainty_in_bin = uncertainties[in_bin].mean()
            uce += torch.abs(avg_uncertainty_in_bin - error_in_bin) * prop_in_bin
            bin_uncertainties.append(avg_uncertainty_in_bin.item())
            bin_errors.append(error_in_bin.item())
            n_samples_in_bin = sample_indices.size(0)
            bin_n_samples.append(n_samples_in_bin)
            bin_variances.append(torch.var((bin_targets != torch.argmax(bin_probs, dim=1)).float()).item())
        else:
            bin_uncertainties.append(None)
            bin_errors.append(None)
            bin_n_samples.append(None)
            bin_variances.append(None)

    return uce, bin_uncertainties, bin_errors, prop_in_bin_values, bin_n_samples, bin_variances


def compute_mae_error_and_uncertainty(probs_expert1, probs_expert2, probs_expert3, targets, uncertainties_expert1, uncertainties_expert2, uncertainties_expert3, error_rates_expert1, error_rates_expert2, error_rates_expert3):
    device = torch.device("cuda:0")
    avg_uncertainty = (uncertainties_expert1 + uncertainties_expert2 + uncertainties_expert3) / 3
    error_rates_expert1 = error_rates_expert1.to(device)
    error_rates_expert2 = error_rates_expert2.to(device)
    error_rates_expert3 = error_rates_expert3.to(device)
    
    # Get the predictions from each expert
    preds_expert1 = torch.argmax(probs_expert1, dim=1)
    preds_expert2 = torch.argmax(probs_expert2, dim=1)
    preds_expert3 = torch.argmax(probs_expert3, dim=1)

    # Calculate the real error rates
    real_error_expert1 = (preds_expert1 != targets).float().mean()
    real_error_expert2 = (preds_expert2 != targets).float().mean()
    real_error_expert3 = (preds_expert3 != targets).float().mean()
    
    
    # Calculate MAE for error rates
    mae_error_expert1 = torch.abs(error_rates_expert1 - real_error_expert1).mean()
    mae_error_expert2 = torch.abs(error_rates_expert2 - real_error_expert2).mean()
    mae_error_expert3 = torch.abs(error_rates_expert3 - real_error_expert3).mean()
    

#     # Calculate MAE for uncertainties using average uncertainty as the reference
#     mae_uncertainty_expert1 = torch.abs(uncertainties_expert1 - avg_uncertainty).mean()
#     mae_uncertainty_expert2 = torch.abs(uncertainties_expert2 - avg_uncertainty).mean()
#     mae_uncertainty_expert3 = torch.abs(uncertainties_expert3 - avg_uncertainty).mean()

    # Print results
    print("MAE Error Expert 1:", mae_error_expert1.item())
    print("MAE Error Expert 2:", mae_error_expert2.item())
    print("MAE Error Expert 3:", mae_error_expert3.item())
#     print("MAE Uncertainty Expert 1:", mae_uncertainty_expert1.item())
#     print("MAE Uncertainty Expert 2:", mae_uncertainty_expert2.item())
#     print("MAE Uncertainty Expert 3:", mae_uncertainty_expert3.item())
    
    
def analyze_errors(error_rates, final_predictions, targets, n_samples=500):
    samples_idx = torch.randint(0, len(targets), (n_samples,))
    
    expert1_error_rates = error_rates[0][samples_idx].tolist()
    expert2_error_rates = error_rates[1][samples_idx].tolist()
    expert3_error_rates = error_rates[2][samples_idx].tolist()
    
    spe_results = final_predictions[samples_idx].tolist()
    real_results = targets[samples_idx].tolist()
    
    # Checking if SPE results are correct
    spe_is_correct = [1 if spe_results[i] == real_results[i] else 0 for i in range(n_samples)]
    
    data = {
        "專家1錯誤率": expert1_error_rates,
        "專家2錯誤率": expert2_error_rates,
        "專家3錯誤率": expert3_error_rates,
        "SPE選擇結果": spe_results,
        "真實結果": real_results,
        "SPE結果是否正確": spe_is_correct
    }
    
    df = pd.DataFrame(data, index=[f"樣本{i+1}" for i in range(n_samples)])
    
    # Analyze the variance
    correct_predictions = df[df["SPE結果是否正確"] == 1]
    incorrect_predictions = df[df["SPE結果是否正確"] == 0]
    
    variance_correct = (correct_predictions[["專家1錯誤率", "專家2錯誤率", "專家3錯誤率"]].var(axis=1).mean())
    variance_incorrect = (incorrect_predictions[["專家1錯誤率", "專家2錯誤率", "專家3錯誤率"]].var(axis=1).mean())
    
    print("Average variance for correctly predicted samples:", variance_correct)
    print("Average variance for incorrectly predicted samples:", variance_incorrect)
    
    return df


In [5]:
import pickle

root_path = '../train_bls/research/BalancedMetaSoftmax-Classification/logs/CIFAR10_LT/'
k_fold = ''

with open(root_path+ 'models/'+ k_fold + 'resnet32_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data1 = pickle.load(f)
ep1_val_logits = data1['logits']
ep1_val_labels = data1['labels']

with open(root_path + 'models/'+ k_fold + 'resnet32_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data2 = pickle.load(f)
ep1_test_logits = data2['logits']
ep1_test_labels = data2['labels']


with open(root_path+ 'models/'+ k_fold + 'resnet32_balanced_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data3 = pickle.load(f)
ep2_val_logits = data3['logits']
ep2_val_labels = data3['labels']

with open(root_path+'models/'+ k_fold + 'resnet32_balanced_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data4 = pickle.load(f)
ep2_test_logits = data4['logits']
ep2_test_labels = data4['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_decouple_balanced_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data5 = pickle.load(f)
ep3_val_logits = data5['logits']
ep3_val_labels = data5['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_decouple_balanced_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data6 = pickle.load(f)
ep3_test_logits = data6['logits']
ep3_test_labels = data6['labels']

with open(root_path+'clslearn/'+ k_fold + 'resnet32_balms_imba100/train_feat_all.pkl', 'rb') as f:
    data7 = pickle.load(f)
ep4_val_logits = data7['logits']
ep4_val_labels = data7['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_balms_imba100/valfeat_all.pkl', 'rb') as f:
    data8 = pickle.load(f)
ep4_test_logits = data8['logits']
ep4_test_labels = data8['labels']



val_uce_list_ep1 =[]
val_uce_list_ep2 =[]
val_uce_list_ep3 =[]

ep1_val_logits_1 = torch.from_numpy(ep1_val_logits)
ep2_val_logits_1 = torch.from_numpy(ep2_val_logits)
ep3_val_logits_1 = torch.from_numpy(ep3_val_logits)
ep4_val_logits_1 = torch.from_numpy(ep4_val_logits)

ep1_test_logits_1 = torch.from_numpy(ep1_test_logits)
ep2_test_logits_1 = torch.from_numpy(ep2_test_logits)
ep3_test_logits_1 = torch.from_numpy(ep3_test_logits)
ep4_test_logits_1 = torch.from_numpy(ep4_test_logits)

val_label = torch.from_numpy(ep3_val_labels)
test_label = torch.from_numpy(ep3_test_labels)


In [6]:
import pickle

root_path = '../train_bls/research/BalancedMetaSoftmax-Classification/logs/CIFAR10_LT/'
k_fold = '1_'

with open(root_path+ 'models/'+ k_fold + 'resnet32_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data1 = pickle.load(f)
ep1_val_logits = data1['logits']
ep1_val_labels = data1['labels']

with open(root_path + 'models/'+ k_fold + 'resnet32_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data2 = pickle.load(f)
ep1_test_logits = data2['logits']
ep1_test_labels = data2['labels']


with open(root_path+ 'models/'+ k_fold + 'resnet32_balanced_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data3 = pickle.load(f)
ep2_val_logits = data3['logits']
ep2_val_labels = data3['labels']

with open(root_path+'models/'+ k_fold + 'resnet32_balanced_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data4 = pickle.load(f)
ep2_test_logits = data4['logits']
ep2_test_labels = data4['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_decouple_balanced_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data5 = pickle.load(f)
ep3_val_logits = data5['logits']
ep3_val_labels = data5['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_decouple_balanced_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data6 = pickle.load(f)
ep3_test_logits = data6['logits']
ep3_test_labels = data6['labels']

with open(root_path+'clslearn/'+ k_fold + 'resnet32_balms_imba100/train_feat_all.pkl', 'rb') as f:
    data7 = pickle.load(f)
ep4_val_logits = data7['logits']
ep4_val_labels = data7['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_balms_imba100/valfeat_all.pkl', 'rb') as f:
    data8 = pickle.load(f)
ep4_test_logits = data8['logits']
ep4_test_labels = data8['labels']



val_uce_list_ep1 =[]
val_uce_list_ep2 =[]
val_uce_list_ep3 =[]

ep1_val_logits_2 = torch.from_numpy(ep1_val_logits)
ep2_val_logits_2 = torch.from_numpy(ep2_val_logits)
ep3_val_logits_2 = torch.from_numpy(ep3_val_logits)
ep4_val_logits_2 = torch.from_numpy(ep4_val_logits)

ep1_test_logits_2 = torch.from_numpy(ep1_test_logits)
ep2_test_logits_2 = torch.from_numpy(ep2_test_logits)
ep3_test_logits_2 = torch.from_numpy(ep3_test_logits)
ep4_test_logits_2 = torch.from_numpy(ep4_test_logits)

val_label = torch.from_numpy(ep3_val_labels)
test_label = torch.from_numpy(ep3_test_labels)


In [7]:
import pickle

root_path = '../train_bls/research/BalancedMetaSoftmax-Classification/logs/CIFAR10_LT/'
k_fold = '3_'

with open(root_path+ 'models/'+ k_fold + 'resnet32_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data1 = pickle.load(f)
ep1_val_logits = data1['logits']
ep1_val_labels = data1['labels']

with open(root_path + 'models/'+ k_fold + 'resnet32_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data2 = pickle.load(f)
ep1_test_logits = data2['logits']
ep1_test_labels = data2['labels']


with open(root_path+ 'models/'+ k_fold + 'resnet32_balanced_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data3 = pickle.load(f)
ep2_val_logits = data3['logits']
ep2_val_labels = data3['labels']

with open(root_path+'models/'+ k_fold + 'resnet32_balanced_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data4 = pickle.load(f)
ep2_test_logits = data4['logits']
ep2_test_labels = data4['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_decouple_balanced_softmax_imba100/train_feat_all.pkl', 'rb') as f:
    data5 = pickle.load(f)
ep3_val_logits = data5['logits']
ep3_val_labels = data5['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_decouple_balanced_softmax_imba100/valfeat_all.pkl', 'rb') as f:
    data6 = pickle.load(f)
ep3_test_logits = data6['logits']
ep3_test_labels = data6['labels']

with open(root_path+'clslearn/'+ k_fold + 'resnet32_balms_imba100/train_feat_all.pkl', 'rb') as f:
    data7 = pickle.load(f)
ep4_val_logits = data7['logits']
ep4_val_labels = data7['labels']


with open(root_path+'clslearn/'+ k_fold + 'resnet32_balms_imba100/valfeat_all.pkl', 'rb') as f:
    data8 = pickle.load(f)
ep4_test_logits = data8['logits']
ep4_test_labels = data8['labels']



val_uce_list_ep1 =[]
val_uce_list_ep2 =[]
val_uce_list_ep3 =[]

ep1_val_logits_3 = torch.from_numpy(ep1_val_logits)
ep2_val_logits_3 = torch.from_numpy(ep2_val_logits)
ep3_val_logits_3 = torch.from_numpy(ep3_val_logits)
ep4_val_logits_3 = torch.from_numpy(ep4_val_logits)

ep1_test_logits_3 = torch.from_numpy(ep1_test_logits)
ep2_test_logits_3 = torch.from_numpy(ep2_test_logits)
ep3_test_logits_3 = torch.from_numpy(ep3_test_logits)
ep4_test_logits_3 = torch.from_numpy(ep4_test_logits)

val_label = torch.from_numpy(ep3_val_labels)
test_label = torch.from_numpy(ep3_test_labels)


In [6]:
ep1_val_logits=(ep1_val_logits_1 + ep1_val_logits_2 + ep1_val_logits_3)/3
ep2_val_logits=(ep2_val_logits_1 + ep2_val_logits_2 + ep2_val_logits_3)/3
ep3_val_logits=(ep3_val_logits_1 + ep3_val_logits_2 + ep3_val_logits_3)/3
ep4_val_logits=(ep4_val_logits_1 + ep4_val_logits_2 + ep4_val_logits_3)/3

ep1_test_logits = (ep1_test_logits_1+ep1_test_logits_2+ep1_test_logits_3)/3
ep2_test_logits = (ep2_test_logits_1+ep2_test_logits_2+ep2_test_logits_3)/3
ep3_test_logits = (ep3_test_logits_1+ep3_test_logits_2+ep3_test_logits_3)/3
ep4_test_logits = (ep4_test_logits_1+ep1_test_logits_2+ep4_test_logits_3)/3

In [8]:
def choose_best_three_expert(probs_expert1,probs_expert2,probs_expert3 ,targets,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3, n_bins=10):


#     uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1, bin_variances_ep1 = compute_uce(probs_expert1, targets, n_bins)
#     uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2, bin_variances_ep2 = compute_uce(probs_expert2, targets_pairs, n_bins)
#     uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3, bin_variances_ep3 = compute_uce(probs_expert3, targets, n_bins)
    uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1,bin_variances_ep1 = val_uce_list_ep1[0],val_uce_list_ep1[1],val_uce_list_ep1[2],val_uce_list_ep1[3],val_uce_list_ep1[4],val_uce_list_ep1[5]
    uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2,bin_variances_ep2 = val_uce_list_ep2[0],val_uce_list_ep2[1],val_uce_list_ep2[2],val_uce_list_ep2[3],val_uce_list_ep2[4],val_uce_list_ep2[5]
    uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3,bin_variances_ep3 = val_uce_list_ep3[0],val_uce_list_ep3[1],val_uce_list_ep3[2],val_uce_list_ep3[3],val_uce_list_ep3[4],val_uce_list_ep3[5]



    # Compute uncertainties for both experts
    _, nattrs = probs_expert1.size()
    nattrs = torch.tensor(nattrs)
    uncertainties_expert1 = (1/torch.log(nattrs))*(-torch.sum(probs_expert1 * torch.log(probs_expert1 + 1e-12), dim=1))
    uncertainties_expert2 = (1/torch.log(nattrs))*(-torch.sum(probs_expert2 * torch.log(probs_expert2 + 1e-12), dim=1))
    uncertainties_expert3 = (1/torch.log(nattrs))*(-torch.sum(probs_expert3 * torch.log(probs_expert3 + 1e-12), dim=1))

    # Find error rates for both experts
    error_rates_expert1 = find_error_rates(uncertainties_expert1, bin_uncertainties_expert1, bin_errors_expert1)
    error_rates_expert2 = find_error_rates(uncertainties_expert2, bin_uncertainties_expert2, bin_errors_expert2)
    error_rates_expert3 = find_error_rates(uncertainties_expert3, bin_uncertainties_expert3, bin_errors_expert3)
    # Choose the expert with lower error rate for each sample

    # Get the predictions from both experts
    preds_expert1 = torch.argmax(probs_expert1, dim=1)
    preds_expert2 = torch.argmax(probs_expert2, dim=1)
    preds_expert3 = torch.argmax(probs_expert3, dim=1)

    # 將三個錯誤率堆疊成一個張量
    error_rates = torch.stack([error_rates_expert1, error_rates_expert2, error_rates_expert3])

    # 找出最小錯誤率的索引
    _, min_error_rate_indices = torch.min(error_rates, dim=0)

    # 根據最小錯誤率的索引選擇最終的預測
    final_predictions = torch.where(min_error_rate_indices == 0, preds_expert1,
                                    torch.where(min_error_rate_indices == 1, preds_expert2, preds_expert3))
    return final_predictions

In [12]:
def choose_best_three_expert_new(probs_expert1,probs_expert2,probs_expert3 ,targets,val_uce_list_ep1,val_uce_list_ep2,val_uce_list_ep3, n_bins=10):
    global std_deviation_per_position
    global error_rates
    global threshold_
    global POE_pred
    global final_predictions
    uce_expert1, bin_uncertainties_expert1, bin_errors_expert1, prop_in_bin_values_expert1,bin_n_samples_ep1,bin_variances_ep1 = val_uce_list_ep1[0],val_uce_list_ep1[1],val_uce_list_ep1[2],val_uce_list_ep1[3],val_uce_list_ep1[4],val_uce_list_ep1[5]
    uce_expert2, bin_uncertainties_expert2, bin_errors_expert2, prop_in_bin_values_expert2,bin_n_samples_ep2,bin_variances_ep2 = val_uce_list_ep2[0],val_uce_list_ep2[1],val_uce_list_ep2[2],val_uce_list_ep2[3],val_uce_list_ep2[4],val_uce_list_ep2[5]
    uce_expert3, bin_uncertainties_expert3, bin_errors_expert3, prop_in_bin_values_expert3,bin_n_samples_ep3,bin_variances_ep3 = val_uce_list_ep3[0],val_uce_list_ep3[1],val_uce_list_ep3[2],val_uce_list_ep3[3],val_uce_list_ep3[4],val_uce_list_ep3[5]



    # Compute uncertainties for both experts
    _, nattrs = probs_expert1.size()
    nattrs = torch.tensor(nattrs)
    uncertainties_expert1 = (1/torch.log(nattrs))*(-torch.sum(probs_expert1 * torch.log(probs_expert1 + 1e-12), dim=1))
    uncertainties_expert2 = (1/torch.log(nattrs))*(-torch.sum(probs_expert2 * torch.log(probs_expert2 + 1e-12), dim=1))
    uncertainties_expert3 = (1/torch.log(nattrs))*(-torch.sum(probs_expert3 * torch.log(probs_expert3 + 1e-12), dim=1))

    # Find error rates for both experts
    error_rates_expert1 = find_error_rates(uncertainties_expert1, bin_uncertainties_expert1, bin_errors_expert1)
    error_rates_expert2 = find_error_rates(uncertainties_expert2, bin_uncertainties_expert2, bin_errors_expert2)
    error_rates_expert3 = find_error_rates(uncertainties_expert3, bin_uncertainties_expert3, bin_errors_expert3)
    # Choose the expert with lower error rate for each sample

    # Get the predictions from both experts
    preds_expert1 = torch.argmax(probs_expert1, dim=1)
    preds_expert2 = torch.argmax(probs_expert2, dim=1)
    preds_expert3 = torch.argmax(probs_expert3, dim=1)

    # 將三個錯誤率堆疊成一個張量
    error_rates = torch.stack([error_rates_expert1, error_rates_expert2, error_rates_expert3])
    
    std_deviation_per_position = torch.std(error_rates, dim=0)
    mean_value = torch.mean(std_deviation_per_position)
    print("mean: ",mean_value)
#     print(std_deviation_per_position)
    # 找出最小錯誤率的索引
    _, min_error_rate_indices = torch.min(error_rates, dim=0)
    
    POE_probs_ = torch.stack([probs_expert1, probs_expert2,probs_expert3])
    POE_probs = product_of_experts(POE_probs_)
    POE_pred = np.argmax(POE_probs, axis=1)
    
    POE_pred = torch.tensor(POE_pred)  # Convert numpy array to torch tensor

    SOE_probs_ = (probs_expert1+probs_expert2+probs_expert3)/3
    SOE_pred = np.argmax(SOE_probs_, axis=1)

    # 根據最小錯誤率的索引選擇最終的預測
    final_predictions = torch.where(min_error_rate_indices == 0, preds_expert1,
                                      torch.where(min_error_rate_indices == 1, preds_expert2, preds_expert3))

    initial_predictions_probs = torch.where(min_error_rate_indices.unsqueeze(-1) == 0, probs_expert1,
                                           torch.where(min_error_rate_indices.unsqueeze(-1) == 1, probs_expert2, probs_expert3))
    
    POE_probs_ = torch.stack([POE_probs,SOE_probs_,initial_predictions_probs])
    POE_initial_predictions_probs = product_of_experts(POE_probs_)
    SOE_initial_predictions_probs = (SOE_probs_+POE_probs+initial_predictions_probs)/2
    
    POE_final_predictions = np.argmax(POE_initial_predictions_probs, axis=1)
    
    SOE_final_predictions = np.argmax(SOE_initial_predictions_probs, axis=1)
    
    POE_acc =accuracy(POE_final_predictions,targets)
    SOE_acc =accuracy(SOE_final_predictions,targets)
#     print("POE_SPE_acc: ",POE_acc,"SOE_SPE_acc: ",SOE_acc)
#     threshold = torch.quantile(std_deviation_per_position, threshold_)

#     threshold = torch.quantile(uncertainties_, threshold_)
#     print("threshold: ",threshold)
    # 根據 std_deviation_per_position 更新預測
#     final_predictions = torch.where(std_deviation_per_position < threshold, POE_pred, initial_predictions)
#     final_predictions = torch.where(std_deviation_per_position < threshold, initial_predictions , POE_pred)

    return SOE_final_predictions

def compute_uce(probs, targets, n_bins=10):
    _, nattrs =probs.size()
    nattrs = torch.tensor(nattrs)
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    uce = 0
    bin_uncertainties = []
    bin_errors = []
    prop_in_bin_values = []
    bin_n_samples = []
    bin_variances = []
    # Compute the uncertainty values (entropy)
    uncertainties = (1/torch.log(nattrs))*(-torch.sum(probs * torch.log(probs + 1e-12), dim=1))
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (uncertainties >= bin_lower) * (uncertainties < bin_upper)
        prop_in_bin = in_bin.float().mean()
        prop_in_bin_values.append(prop_in_bin.item() if prop_in_bin.item() > 0 else None)
        if prop_in_bin.item() > 0:
            sample_indices = torch.where(in_bin)[0]
            bin_targets = targets[sample_indices]
            bin_probs = probs[sample_indices]
            error_in_bin = (bin_targets != torch.argmax(bin_probs, dim=1)).float().mean()
            avg_uncertainty_in_bin = uncertainties[in_bin].mean()
            uce += torch.abs(avg_uncertainty_in_bin - error_in_bin) * prop_in_bin
            bin_uncertainties.append(avg_uncertainty_in_bin.item())
            bin_errors.append(error_in_bin.item())
            n_samples_in_bin = sample_indices.size(0)
            bin_n_samples.append(n_samples_in_bin)
            bin_variances.append(torch.var((bin_targets != torch.argmax(bin_probs, dim=1)).float()).item())
        else:
            bin_uncertainties.append(None)
            bin_errors.append(None)
            bin_n_samples.append(None)
            bin_variances.append(None)

    return uce, bin_uncertainties, bin_errors, prop_in_bin_values, bin_n_samples, bin_variances


In [13]:
std_deviation_per_position = None
error_rates = None
POE_pred = None
final_predictions = None

In [14]:
threshold_ = 0

save_name = 'bls_134_1'

val_uce_list_ep1 , val_uce_list_ep2 ,val_uce_list_ep3 = [],[],[]

cal_val_state(ep1_val_logits_1, ep3_val_logits_1, ep4_val_logits_1,val_label,phase='val')
print('------------test1-------------------')
cal_val_state(ep1_test_logits_1,ep3_test_logits_1,ep4_test_logits_1,test_label,phase='test')

val_uce_list_ep1 , val_uce_list_ep2 ,val_uce_list_ep3 = [],[],[]
cal_val_state(ep1_val_logits_2, ep3_val_logits_2, ep4_val_logits_2, val_label,phase='val')
print('------------test2-------------------')
cal_val_state(ep1_test_logits_2,ep3_test_logits_2,ep4_test_logits_2,test_label,phase='test')

val_uce_list_ep1 , val_uce_list_ep2 ,val_uce_list_ep3 = [],[],[]
cal_val_state(ep1_val_logits_3, ep3_val_logits_3, ep4_val_logits_3, val_label,phase='val')
print('------------test3-------------------')
cal_val_state(ep1_test_logits_3,ep3_test_logits_3,ep4_test_logits_3,test_label,phase='test')

------------test1-------------------
test expert1_acc:  0.7978 expert2_acc:  0.8413 expert3_acc:  0.8418
1:  tensor(0.1027)  2:  tensor(0.0113)  3:  tensor(0.0149)




SPE_UCE:  tensor(0.0656) SOE_UCE:  tensor(0.0104) POE_UCE:  tensor(0.1636)
mean:  tensor(0.0356)
POE_SPE_acc:  0.8329 SOE_SPE_acc:  0.8324




Table Accuracy:
12: 0.829 23: 0.844 13: 0.831 123: 0.8325 123_new: 0.8324

Simple Voting Accuracy:
12: 0.798 23: 0.841 13: 0.798 123: 0.798

weighted Voting Accuracy:
12: 0.798 23: 0.841 13: 0.798 123: 0.841

SOE Accuracy:
12: 0.825 23: 0.841 13: 0.827 123: 0.8385

POE Accuracy:
12: 0.826 23: 0.841 13: 0.830 123: 0.8360

------------test2-------------------
test expert1_acc:  0.7903 expert2_acc:  0.8433 expert3_acc:  0.841
1:  tensor(0.1110)  2:  tensor(0.0133)  3:  tensor(0.0130)
SPE_UCE:  tensor(0.0674) SOE_UCE:  tensor(0.0408) POE_UCE:  tensor(0.1426)
mean:  tensor(0.0520)
POE_SPE_acc:  0.8579 SOE_SPE_acc:  0.8551
Table Accuracy:
12: 0.852 23: 0.844 13: 0.854 123: 0.8539 123_new: 0.8551

Simple Voting Accuracy:
12: 0.790 23: 0.843 13: 0.790 123: 0.790

weighted Voting Accuracy:
12: 0.790 23: 0.843 13: 0.790 123: 0.843

SOE Accuracy:
12: 0.846 23: 0.843 13: 0.845 123: 0.8571

POE Accuracy:
12: 0.846 23: 0.842 13: 0.848 123: 0.8570

------------test3-------------------
test expert1_ac

In [None]:
threshold_ = 0

save_name = 'bls_134_1'

val_uce_list_ep1 , val_uce_list_ep2 ,val_uce_list_ep3 = [],[],[]

cal_val_state(ep1_val_logits_1, ep3_val_logits_1, ep4_val_logits_1,val_label,phase='val')
print('------------test1-------------------')
cal_val_state(ep1_test_logits_1,ep3_test_logits_1,ep4_test_logits_1,test_label,phase='test')

val_uce_list_ep1 , val_uce_list_ep2 ,val_uce_list_ep3 = [],[],[]
cal_val_state(ep1_val_logits_2, ep3_val_logits_2, ep4_val_logits_2, val_label,phase='val')
print('------------test2-------------------')
cal_val_state(ep1_test_logits_2,ep3_test_logits_2,ep4_test_logits_2,test_label,phase='test')

val_uce_list_ep1 , val_uce_list_ep2 ,val_uce_list_ep3 = [],[],[]
cal_val_state(ep1_val_logits_3, ep3_val_logits_3, ep4_val_logits_3, val_label,phase='val')
print('------------test3-------------------')
cal_val_state(ep1_test_logits_3,ep3_test_logits_3,ep4_test_logits_3,test_label,phase='test')

------------test1-------------------
test expert1_acc:  0.7978 expert2_acc:  0.8413 expert3_acc:  0.8418
1:  tensor(0.1027)  2:  tensor(0.0113)  3:  tensor(0.0149)




POE_SPE_acc:  0.8329 SOE_SPE_acc:  0.8324
Table Accuracy:
12: 0.829 23: 0.844 13: 0.831 123: 0.8325 123_new: 0.8518

Simple Voting Accuracy:
12: 0.798 23: 0.841 13: 0.798 123: 0.798

weighted Voting Accuracy:
12: 0.798 23: 0.841 13: 0.798 123: 0.841

SOE Accuracy:
12: 0.825 23: 0.841 13: 0.827 123: 0.8385

POE Accuracy:
12: 0.826 23: 0.841 13: 0.830 123: 0.8360

------------test2-------------------
test expert1_acc:  0.7903 expert2_acc:  0.8433 expert3_acc:  0.841
1:  tensor(0.1110)  2:  tensor(0.0133)  3:  tensor(0.0130)
POE_SPE_acc:  0.8579 SOE_SPE_acc:  0.8551
Table Accuracy:
12: 0.852 23: 0.844 13: 0.854 123: 0.8539 123_new: 0.8518

Simple Voting Accuracy:
12: 0.790 23: 0.843 13: 0.790 123: 0.790

weighted Voting Accuracy:
12: 0.790 23: 0.843 13: 0.790 123: 0.843

SOE Accuracy:
12: 0.846 23: 0.843 13: 0.845 123: 0.8571

POE Accuracy:
12: 0.846 23: 0.842 13: 0.848 123: 0.8570

------------test3-------------------
test expert1_acc:  0.7945 expert2_acc:  0.8444 expert3_acc:  0.8397
1:

In [60]:
save_name = 'bls_124_best'
cal_val_state(ep1_val_logits,ep2_val_logits,ep4_val_logits,val_label,phase='val')
print('------------test-------------------')
cal_val_state(ep1_test_logits,ep2_test_logits,ep4_test_logits,test_label,phase='test')

------------test-------------------
test expert1_acc:  0.8352 expert2_acc:  0.8759 expert3_acc:  0.8592
1:  tensor(0.0683)  2:  tensor(0.0358)  3:  tensor(0.0168)
Table Accuracy:
12: 0.868 23: 0.878 13: 0.855 123: 0.872

Simple Voting Accuracy:
12: 0.835 23: 0.876 13: 0.835 123: 0.835

weighted Voting Accuracy:
12: 0.835 23: 0.876 13: 0.835 123: 0.876

SOE Accuracy:
12: 0.868 23: 0.882 13: 0.852 123: 0.871

POE Accuracy:
12: 0.868 23: 0.883 13: 0.853 123: 0.871



In [18]:
save_name = 'bls_234'
cal_val_state(ep2_val_logits,ep3_val_logits,ep4_val_logits,val_label,phase='val')
print('------------test-------------------')
cal_val_state(ep2_test_logits,ep3_test_logits,ep4_test_logits,test_label,phase='test')

------------test-------------------
test expert1_acc:  0.8629 expert2_acc:  0.8432 expert3_acc:  0.859
1:  tensor(0.0468)  2:  tensor(0.0117)  3:  tensor(0.0177)
Table Accuracy:
12: 0.873 23: 0.860 13: 0.874 123: 0.875

Simple Voting Accuracy:
12: 0.863 23: 0.843 13: 0.863 123: 0.863

weighted Voting Accuracy:
12: 0.863 23: 0.843 13: 0.863 123: 0.843

SOE Accuracy:
12: 0.878 23: 0.859 13: 0.876 123: 0.875

POE Accuracy:
12: 0.879 23: 0.860 13: 0.877 123: 0.879



In [19]:
save_name = 'bls_134'
cal_val_state(ep1_val_logits,ep3_val_logits,ep4_val_logits,val_label,phase='val')
print('------------test-------------------')
cal_val_state(ep1_test_logits,ep3_test_logits,ep4_test_logits,test_label,phase='test')

------------test-------------------
test expert1_acc:  0.8165 expert2_acc:  0.8432 expert3_acc:  0.859
1:  tensor(0.0860)  2:  tensor(0.0117)  3:  tensor(0.0177)
Table Accuracy:
12: 0.855 23: 0.860 13: 0.852 123: 0.857

Simple Voting Accuracy:
12: 0.817 23: 0.843 13: 0.817 123: 0.817

weighted Voting Accuracy:
12: 0.817 23: 0.843 13: 0.817 123: 0.843

SOE Accuracy:
12: 0.856 23: 0.859 13: 0.845 123: 0.861

POE Accuracy:
12: 0.857 23: 0.860 13: 0.846 123: 0.859

