In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import re
import numpy as np
import torch


# Number of consecutive sample ids assigned to each poisoned category.
poison_per_category = 100

# Poisoned category names. A category's position in this list selects its
# block of ids: [position * poison_per_category, (position + 1) * poison_per_category).
poisoners = [
    "desk", "palace", "necklace", "balloon",
    "pillow", "candle", "pizza", "umbrella",
    "television", "baseball", "ice cream", "suit",
    "mountain", "beach", "plate", "orange",
]

# Exclusive upper bound of the ids treated as poison by the 'full' setting.
full_poison_range = len(poisoners) * poison_per_category

def plot_poison_distribution(file_path, poison_category='full', filter_ratios=(0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7)):
    """Plot rank and similarity histograms of the poisoned rows in a saved tensor.

    The tensor at ``file_path`` is loaded with ``torch.load`` and viewed as a
    table: column 0 is compared against the poison id ranges and column 1 is
    plotted as the similarity. The row position is used as the "poison rank"
    (presumably the rows are stored in ranked order -- confirm against the
    code that writes these files).

    Args:
        file_path: Path to a ``.pt`` file; its base name (without ``.pt``)
            becomes part of the output file name.
        poison_category: ``'full'`` selects ids below ``full_poison_range``,
            ``'less'`` selects ids below half of it, and any other value must
            be an entry of ``poisoners`` and selects that category's block of
            ``poison_per_category`` ids.
        filter_ratios: Fractions of the original row count; for each one, the
            number of selected rows whose rank falls inside that leading
            fraction is written onto the rank histogram.

    Raises:
        ValueError: If ``poison_category`` is not ``'full'``, ``'less'`` or an
            entry of ``poisoners``.

    The figure is saved as
    ``post_pretraining_analysis/dist_<name>_<poison_category>.png`` and then
    closed; the output directory is not created here.
    """
    t = torch.load(file_path).cpu().numpy()
    df = pd.DataFrame(t, columns=None)
    # Both statistics are taken over all rows, before the poison filter.
    mean_similarity = df[1].mean()
    orig_len = len(df)

    if poison_category == 'full':
        condition = df[0] < full_poison_range
    elif poison_category == 'less':
        condition = df[0] < (full_poison_range // 2)
    else:
        category_start = poison_per_category * poisoners.index(poison_category)
        # Element-wise `&` is required here: Python's `and` asks for the truth
        # value of a whole Series and raises ValueError.
        condition = (df[0] >= category_start) & (df[0] < category_start + poison_per_category)
    df = df[condition]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))

    # Filtering keeps the original row labels, so df.index still holds ranks.
    n, bins, _ = ax1.hist(df.index.tolist(), bins=50, color='blue', alpha=0.5)
    ax1.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
    ax1.set_title('Poison Rank Distribution')
    ax1.set_xlabel('Poison Rank')
    ax1.set_ylabel('Frequency')

    comments = [
        'poison num at top %f: %d' % (ratio, (df.index < orig_len * ratio).sum())
        for ratio in filter_ratios
    ]
    # Anchor the annotation at the tallest bar. np.argmax gives a bin index,
    # which must be mapped through the bin edges to get an x value in rank units.
    peak_bin = np.argmax(n)
    comment_x = 0.5 * (bins[peak_bin] + bins[peak_bin + 1])
    comment_y = np.max(n)
    ax1.text(comment_x, comment_y, '\n'.join(comments), fontsize=12, ha='center')

    ax2.hist(df[1].tolist(), bins=30, color='green', alpha=0.5)
    ax2.invert_xaxis()
    ax2.set_title('Poison Similarity Distribution')
    ax2.set_xlabel('Poison Similarity')
    ax2.set_ylabel('Frequency')
    ax2.axvline(mean_similarity, color='red', linestyle='--', label='mean_similarity')
    ax2.legend()

    fig.tight_layout()
    # Escaped dot and end anchor: the pattern can only match the final path
    # component, never an earlier directory whose name happens to contain "pt".
    base_name = re.search(r"([^/]+)\.pt$", file_path).group(1)
    fig.savefig('post_pretraining_analysis/dist_%s_%s.png' % (base_name, poison_category))
    plt.close(fig)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Plot the poison distribution for a single saved index file.
# A loop header `for i in [7,12,17,22,27]:` used to sit here and is disabled;
# presumably it swept the `update<i>` suffix of the file name -- confirm before re-enabling.
# NOTE(review): the path below is absolute and machine-specific.
plot_poison_distribution("/home/hyang/NNCLIP/CyCLIP/jigao_indices/NNCLIP_1M_100_16_w_NN_wo_intersection_update7.pt")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import re
import numpy as np
import torch
from scipy.interpolate import make_interp_spline
from matplotlib.lines import Line2D  # I

In [None]:
def _smoothed_density(values, bins, num_points=100):
    """Return ``(x, y)`` of a spline drawn through the density histogram of ``values``."""
    density, edges = np.histogram(values, bins=bins, density=True)
    # The spline is fitted at the midpoint of each bin.
    centers = 0.5 * (edges[:-1] + edges[1:])
    # Evaluate on a denser grid than the bins so the filled curve looks smooth.
    x_smooth = np.linspace(centers.min(), centers.max(), num_points)
    y_smooth = make_interp_spline(centers, density)(x_smooth)
    return x_smooth, y_smooth


def plot_clean_n_poison_similarity_distribution(file_path, poison_category='full'):
    """Plot smoothed similarity densities of poisoned versus clean rows.

    The input is read either with ``torch.load`` (when ``file_path`` ends in
    ``pt``) or as a header-less tab-separated file (anything else). Column 0
    is compared against the poison id ranges and column 1 is the value whose
    density is plotted on the 'Cosine Similarity' axis.

    Args:
        file_path: Path to a ``.pt`` or ``.tsv`` file; its base name without
            the extension becomes part of the output file name.
        poison_category: ``'full'`` treats ids below ``full_poison_range`` as
            poison and the rest as clean; ``'less'`` does the same with half
            of that range; any other value must be an entry of ``poisoners``
            and selects that category's block of ``poison_per_category`` ids
            as poison.

    Raises:
        ValueError: If ``poison_category`` is not ``'full'``, ``'less'`` or an
            entry of ``poisoners``.

    The figure is saved as
    ``post_pretraining_analysis/clean_poison_dist_<name>_<poison_category>.png``
    and then closed; the output directory is not created here.
    """
    is_tensor_file = file_path[-2:] == 'pt'
    if is_tensor_file:
        t = torch.load(file_path).cpu().numpy()
        df = pd.DataFrame(t, columns=None)
    else:
        df = pd.read_csv(file_path, sep='\t', header=None)

    if poison_category == 'full':
        poison_condition = df[0] < full_poison_range
        clean_condition = df[0] >= full_poison_range
    elif poison_category == 'less':
        poison_condition = df[0] < (full_poison_range // 2)
        clean_condition = df[0] >= (full_poison_range // 2)
    else:
        category_start = poison_per_category * poisoners.index(poison_category)
        # Element-wise `&` is required here: Python's `and` asks for the truth
        # value of a whole Series and raises ValueError.
        poison_condition = (df[0] >= category_start) & (df[0] < category_start + poison_per_category)
        # NOTE(review): for a single category, "clean" is taken to be every id
        # at or above full_poison_range, as in the 'full' branch, so rows of
        # other poisoned categories are left out of both curves -- confirm.
        clean_condition = df[0] >= full_poison_range
    df_poison = df[poison_condition]
    df_clean = df[clean_condition]

    fig, ax = plt.subplots(figsize=(8, 6))

    # Poison uses 20 bins and clean uses 30, as in the original plot settings.
    poison_x, poison_y = _smoothed_density(df_poison[1].tolist(), bins=20)
    ax.fill_between(poison_x, 0, poison_y, alpha=0.3, color='blue')

    clean_x, clean_y = _smoothed_density(df_clean[1].tolist(), bins=30)
    ax.fill_between(clean_x, 0, clean_y, alpha=0.3, color='green')

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_ylabel('Probability Density', fontsize=14)
    ax.set_xlabel('Cosine Similarity', fontsize=14)

    # Thick line handles stand in for coloured boxes in the legend.
    legend_elements = [
        Line2D([0], [0], color='blue', lw=10, alpha=0.3, label='Poison'),
        Line2D([0], [0], color='green', lw=10, alpha=0.3, label='Clean'),
    ]
    ax.legend(handles=legend_elements, handlelength=4, handleheight=3, fontsize=12)

    fig.tight_layout()

    # Escaped dot and end anchor: the pattern can only match the final path
    # component, never an earlier directory that happens to contain the extension.
    extension = 'pt' if is_tensor_file else 'tsv'
    base_name = re.search(r"([^/]+)\.%s$" % extension, file_path).group(1)
    fig.savefig('post_pretraining_analysis/clean_poison_dist_%s_%s.png' % (base_name, poison_category))
    plt.close(fig)