In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.mixture import GaussianMixture
from scipy.stats import norm

# Load the saved per-sample loss and probability tensors.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary objects,
# which can execute code -- only load these files from a trusted source.
# map_location='cpu' lets tensors that were saved on a GPU load on a CPU-only
# machine; the data is moved to the CPU below anyway.
cirr_loss = torch.load('loss_2.pth', map_location='cpu', weights_only=False)
cirr_prob = torch.load('prob_2.pth', map_location='cpu', weights_only=False)

# Boolean mask: True where the saved probability exceeds 0.5.
gt_mask = (cirr_prob > 0.5).cpu().numpy()
# detach() first: Tensor.numpy() raises if the tensor still requires grad.
# NOTE: `input` shadows the builtin; the name is kept because the
# density_plot(...) call later in this script reads this variable.
input = cirr_loss.detach().cpu().numpy()

def _draw_class_histogram(ax, values, class_fraction, bins, color):
    """Draw one class's density histogram, scaled by its share of all samples."""
    if values.size == 0:
        # np.histogram(..., density=True) on an empty array divides by zero
        # and produces NaN bar heights, so an absent class is not drawn.
        return
    heights, edges = np.histogram(values, bins=bins, density=True)
    centers = (edges[:-1] + edges[1:]) / 2  # centre of each bar
    ax.bar(centers, heights * class_fraction, width=np.diff(edges),
           align='center', color=color, alpha=0.4)


def density_plot(input, gt_mask, name='loss', density=True, *, bins=1000,
                 save_path="visualization.pdf", random_state=None):
    """Plot a two-component GMM fit of `input` over per-class histograms.

    A 2-component Gaussian mixture is fitted to all values of `input`. The
    mixture PDF is drawn in black; each weighted component is drawn dashed,
    green for the component with the smaller mean and red for the other.
    Behind the curves, density histograms of the samples selected by
    `gt_mask` are drawn (green where the mask equals 1, red where it equals
    0), each scaled by that class's fraction of all samples.

    Args:
        input: Array of values to model; it is flattened for the GMM fit.
        gt_mask: Array whose entries equal to 1/True select the "clean"
            samples and entries equal to 0/False the "noisy" samples.
        name: Label for the x axis.
        density: Only selects the y-axis label ('density' if true, 'num'
            otherwise); the histograms are always density-normalised.
        bins: Number of histogram bins per class.
        save_path: Where the figure is written; pass None to skip saving.
        random_state: Forwarded to GaussianMixture for a reproducible fit.

    Raises:
        ValueError: If `input` is empty.
    """
    input = np.asarray(input)
    gt_mask = np.asarray(gt_mask)
    if input.size == 0:
        raise ValueError("density_plot requires at least one sample")

    clean_index = np.where(gt_mask == 1)[0]
    noisy_index = np.where(gt_mask == 0)[0]

    # Fit before creating the figure so a failed fit leaves no figure open.
    gmm = GaussianMixture(n_components=2, random_state=random_state)
    gmm.fit(input.reshape(-1, 1))

    fig, ax = plt.subplots(figsize=(7, 5))

    # Mixture PDF over the observed value range.
    x = np.linspace(np.min(input), np.max(input), 1000)
    pdf = np.exp(gmm.score_samples(x.reshape(-1, 1)))
    ax.plot(x, pdf, color='black', linestyle='-', label='Mixture PDF')

    # Individual weighted components; the lower-mean component is green.
    component_means = gmm.means_[:, 0]
    lowest_mean = component_means.min()
    for i, mean in enumerate(component_means):
        std_dev = np.sqrt(gmm.covariances_[i, 0, 0])
        component_pdf = norm.pdf(x, mean, std_dev)
        color = 'green' if mean == lowest_mean else 'red'
        ax.plot(x, gmm.weights_[i] * component_pdf, color=color,
                linestyle='--', label=f'Component {i+1}')

    # Per-class histograms, each weighted by the class's share of the data
    # so that together they are comparable to the mixture PDF.
    total = len(input)
    _draw_class_histogram(ax, input[clean_index], len(clean_index) / total,
                          bins, 'green')
    _draw_class_histogram(ax, input[noisy_index], len(noisy_index) / total,
                          bins, 'red')

    ax.tick_params(axis='both', labelsize=15)
    ax.set_xlabel(name, fontsize=15)
    ax.set_ylabel('density' if density else 'num', fontsize=15)
    ax.legend(loc='upper right', fontsize=12, frameon=True)
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)
    
# Render the loss density figure for the loaded tensors, with the
# function's default axis label and y-label mode spelled out.
density_plot(input, gt_mask, name='loss', density=True)