# D-RISE Histogram (Figure 8b)


In [None]:
import torch
import json
import matplotlib.pyplot as plt

In [None]:
def keep_changes(model_name="llava", correctness="corr"):
    """
    Split bbox-blur sensitivity values by the verdict stored in the metadata.

    For each sample in the selected result set the sensitivity is
    ``min(random-blur values) - object-blur value``. Samples without a
    metadata record, and samples whose result entry is malformed, are skipped.

    Args:
        model_name: Model identifier, e.g. "llava-7b" or "llama-11b". The part
            before the first "-" selects the blur file
            (``embeds/blur/<prefix>_blur.pt``); the full name selects the
            metadata file (``metadata/<model_name>_<layer>.json``).
        correctness: Key into ``results_dict`` of the blur file
            ("corr" or "wrong").

    Returns:
        group_true: Sensitivity values of samples whose metadata verdict is True
        group_false: Sensitivity values of all other samples that have metadata
    """
    # Load blur experiment results.
    # NOTE(review): torch.load unpickles its input -- only point this at
    # trusted, locally generated files.
    blurred_res = torch.load(f"embeds/blur/{model_name.split('-')[0]}_blur.pt")

    # Metadata file suffix: 12 for names containing "llava", 14 otherwise
    # (presumably the probed layer index -- confirm against the export script).
    layer = 12 if "llava" in model_name else 14
    with open(f"metadata/{model_name}_{layer}.json") as json_file:
        data = json.load(json_file)

    # Lookup: sample id -> verdict (the only metadata field used here)
    verdicts = {d['id']: d['verdict'] for d in data}

    corrlist = blurred_res['results_dict'][correctness]
    group_true = []   # samples with verdict == True
    group_false = []  # every other sample that has metadata

    for key_, cc in corrlist.items():
        if key_ not in verdicts:
            continue

        try:
            # Entry layout as used here: cc[1] is the value with the object
            # bbox blurred, cc[3:7] are the values for the random-location
            # blurs. cc[0] (original, no blur) is not needed.
            # NOTE(review): cc[2] is skipped -- confirm that is intentional.
            blurobj1 = cc[1]
            others_min = min(cc[3:7])  # lowest value among the random blurs

            # Positive when the object-blur value is below every random-blur
            # value (presumably: object blur hurts more -- confirm what the
            # stored values measure).
            diff_val = others_min - blurobj1
        except (IndexError, KeyError, TypeError, ValueError):
            # Best-effort: skip malformed entries (too short, empty random
            # slice, non-numeric values) instead of aborting the whole run.
            continue

        # Equality (not truthiness) is kept on purpose so that non-boolean
        # verdict values are classified exactly as before.
        if verdicts[key_] == True:  # noqa: E712
            group_true.append(diff_val)
        else:
            group_false.append(diff_val)

    return group_true, group_false

## Generate Histogram Plot

In [None]:
# Models to compare; each gets one histogram panel
model1 = "llava-7b"
model2 = "llama-11b"

# Sensitivity values of the 'corr' result set, split by metadata verdict
group_true1, group_false1 = keep_changes(model_name=model1, correctness='corr')
group_true2, group_false2 = keep_changes(model_name=model2, correctness='corr')

# (subplot position, panel title, verdict-true values, verdict-false values)
panels = (
    (121, model1, group_true1, group_false1),
    (122, model2, group_true2, group_false2),
)

plt.figure(figsize=(8, 4))

for position, panel_title, verdict_true, verdict_false in panels:
    plt.subplot(position)
    plt.hist(verdict_true, bins=15, density=True, alpha=0.5, label='positive')
    plt.hist(verdict_false, bins=15, density=True, alpha=0.3, color='red', label='negative')
    plt.title(f"{panel_title}")
    if position == 121:
        # The y-axis label is shown on the left panel only
        plt.ylabel("Density")
    else:
        # The legend is shown on the right panel only
        plt.legend()

plt.tight_layout()
plt.show()