# Efficient Unsupervised Shortcut Learning Detection and Mitigation in Transformers

This is the official code for the paper "Efficient Unsupervised Shortcut Learning Detection and Mitigation in Transformers" by Lukas Kuhn, Sari Sadiya, Joerg Schloetterer, Christin Seifert, Gemma Roig.

Please contact Lukas Kuhn (lukas.kuhn@dkfz-heidelberg.de) for any questions.

In [None]:
import cv2
import os
import numpy as np
import replicate
import torch

from matplotlib import pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from surgeon_pytorch import Inspect, Extract
from torch import nn
from torch.utils.data import ConcatDataset
from torchvision import transforms, datasets
from torchvision.models import vit_b_16

from tqdm import tqdm
import openvino_xai as xai

We are using the ISIC dataset for this example. Please download and sort the ISIC dataset into the following folders:

- ./data/isic/val_wo_patches/
- ./data/isic/val_w_patches/
- ./data/isic/test_wo_patches/
- ./data/isic/test_w_patches/


In [2]:
seed = 1

# Seed the torch RNG first, then NumPy's global RNG, for reproducibility.
torch.manual_seed(seed)
np.random.seed(seed)

# Per-channel normalization statistics (the commonly used ImageNet mean / std).
mean_ds = [0.485, 0.456, 0.406]
std_ds = [0.229, 0.224, 0.225]

# Every split goes through the same pipeline: fixed 224x224 resize,
# conversion to a tensor, then per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_ds, std=std_ds),
    ]
)

# Validation splits (without / with patches) and their union.
val_dataset_wo_patches = datasets.ImageFolder(root='./data/isic/val_wo_patches/', transform=transform)
val_dataset_w_patches = datasets.ImageFolder(root='./data/isic/val_w_patches/', transform=transform)
concatenated_dataset = ConcatDataset([val_dataset_wo_patches, val_dataset_w_patches])

# Test splits (without / with patches) and their union.
test_dataset_wo_patches = datasets.ImageFolder(root='./data/isic/test_wo_patches/', transform=transform)
test_dataset_w_patches = datasets.ImageFolder(root='./data/isic/test_w_patches/', transform=transform)
test_concatenated_dataset = ConcatDataset([test_dataset_wo_patches, test_dataset_w_patches])

We ran this code on a MacBook Pro with an M3 chip and 16 GB of RAM, hence the use of the MPS backend (the code falls back to the CPU when MPS is unavailable and can easily be adapted to other devices). The following code loads the pre-trained ViT model, replaces its final layer to match the number of classes in the dataset, and loads the trained weights from `models/`. It also inserts the XAI layer to obtain saliency maps, which we used only for visualization during our research and not for shortcut detection.

In [None]:
MODEL_NAME = f"vit_isic_{seed}"
NUM_CLASSES = 2

# Use Apple's MPS backend when it is available, otherwise run on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# ViT-B/16 with torchvision's default pretrained weights; its classification
# head is swapped for a NUM_CLASSES-way linear layer. The model is then moved
# to `device` and put into evaluation mode (Module.to returns the module).
model = vit_b_16(weights='DEFAULT')
head_in_features = model.heads.head.in_features
model.heads.head = nn.Linear(head_in_features, NUM_CLASSES)
model.to(device).eval()

# Restore the trained weights for this seed (tensor data only: weights_only=True).
checkpoint_path = f"models/{MODEL_NAME}.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

# Hook on the last encoder block; calling `inspect(x)` yields the model output
# together with that block's output.
inspect = Inspect(model, layer="encoder.layers.encoder_layer_11")

# Model with an XAI layer inserted (saliency maps, used for visualization only).
model_xai: torch.nn.Module = xai.insert_xai(model, xai.Task.CLASSIFICATION).to(device)


We split the test set into subsets by patch presence (with / without patches) and by target class, which makes it easy to identify the worst-performing group. This grouping is used only for evaluation and not for shortcut detection.

In [None]:
def load_subset_by_target(ds, target_class):
    """Return a ``torch.utils.data.Subset`` of ``ds`` holding the samples whose target equals ``target_class``.

    When ``ds`` exposes a ``targets`` list (as torchvision's ``ImageFolder``
    does) and has no ``target_transform``, the labels are read from it
    directly. This avoids loading and transforming every image just to read
    its label. Otherwise the dataset is iterated as ``(sample, target)`` pairs.

    Args:
        ds: map-style dataset yielding ``(sample, target)`` pairs.
        target_class: label value to keep.

    Returns:
        ``torch.utils.data.Subset`` over the matching indices, in dataset order.
    """
    targets = getattr(ds, 'targets', None)
    # A target_transform would make the raw `targets` list disagree with what
    # indexing the dataset returns, so only take the fast path without one.
    if targets is None or getattr(ds, 'target_transform', None) is not None:
        targets = [target for _, target in ds]

    indices = [i for i, target in enumerate(targets) if target == target_class]
    return torch.utils.data.Subset(ds, indices)

def inference_model(model, loader):
    """Evaluate ``model`` on ``loader``, print the accuracy and return it.

    Batches are moved to the module-level ``device``. The forward passes run
    under ``torch.no_grad()`` because only predictions are needed.

    Args:
        model: classifier returning per-class scores of shape (batch, classes).
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        The accuracy as a float. When the loader yields no samples, returns
        ``float('nan')`` instead of raising ZeroDivisionError.
    """
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for img, label in tqdm(loader):
            label = label.to(device)
            out = model(img.to(device))
            pred = torch.argmax(out, dim=1)

            correct += (pred == label).sum().item()
            total += img.size(0)

    if total == 0:
        # e.g. a (patch, class) group that has no samples in this split
        print("Accuracy: n/a (empty loader)")
        return float("nan")

    accuracy = correct / total
    print(f"Accuracy: {accuracy:.4f}")
    return accuracy

# One subset per (patch presence, target class) group of the test data.
test_w_patches_cls_0, test_w_patches_cls_1 = (
    load_subset_by_target(test_dataset_w_patches, cls) for cls in (0, 1)
)
test_wo_patches_cls_0, test_wo_patches_cls_1 = (
    load_subset_by_target(test_dataset_wo_patches, cls) for cls in (0, 1)
)


def _ordered_loader(subset):
    """Wrap ``subset`` in a DataLoader with batches of 16 in dataset order."""
    return torch.utils.data.DataLoader(subset, batch_size=16, shuffle=False)


test_w_patches_cls_0_loader = _ordered_loader(test_w_patches_cls_0)
test_w_patches_cls_1_loader = _ordered_loader(test_w_patches_cls_1)
test_wo_patches_cls_0_loader = _ordered_loader(test_wo_patches_cls_0)
test_wo_patches_cls_1_loader = _ordered_loader(test_wo_patches_cls_1)

# Per-group accuracy: with patches (class 0, class 1), then without patches.
for _group_loader in (
    test_w_patches_cls_0_loader,
    test_w_patches_cls_1_loader,
    test_wo_patches_cls_0_loader,
    test_wo_patches_cls_1_loader,
):
    inference_model(model, _group_loader)

In [4]:
# Self-attention module of the last encoder block (an nn.MultiheadAttention).
self_attention_layer = model.encoder.layers[-1].self_attention

# Packed [W_Q; W_K; W_V] input projection, shape (3 * embed_dim, embed_dim).
in_proj_weight = self_attention_layer.in_proj_weight

# Extract the dimensions from the layer itself rather than hard-coding
# ViT-B/16's 768, so the slicing below stays correct for other widths.
embed_dim = self_attention_layer.embed_dim
kdim = self_attention_layer.kdim

# Extract the key weight matrix (W_K): the middle third of the packed matrix.
# Detached because it is only used for feature extraction, so key
# computations do not build autograd graphs through the parameter.
W_K = in_proj_weight[embed_dim:2*embed_dim, :].detach()

# Sub-model that returns the output of the `encoder.dropout` node.
conv_out_model = Extract(model, node_out="encoder.dropout")
conv_out_model.eval()
conv_out_model = conv_out_model.to(device)

We extract the keys from the model for each image in the dataset following _Bolya et al._

In [5]:
def extract_keys(dataset):
    """Compute head-averaged key vectors for the patch tokens of every image in ``dataset``.

    For each image, the token sequence from ``conv_out_model`` (the
    ``encoder.dropout`` node) is taken and the leading CLS token is dropped.
    The remaining tokens are projected with ``W_K`` (no bias). Each 768-d key
    is viewed as 12 x 64 (one 64-d slice per attention head) and averaged over
    the 12 slices.

    Args:
        dataset: map-style dataset whose items are ``(image_tensor, target)``.

    Returns:
        np.ndarray of shape (len(dataset), num_patch_tokens, 64); 196 patch
        tokens for 224x224 ViT-B/16 inputs.
    """
    keys = []

    # Pure feature extraction: no gradients needed.
    with torch.no_grad():
        for i in range(len(dataset)):
            image = dataset[i][0]
            out = conv_out_model(image.unsqueeze(0).to(device))

            # NOTE(review): in torchvision's ViT, `encoder.dropout` is presumably
            # applied to the encoder *input* (before any transformer block and
            # without LayerNorm), while W_K belongs to the last block. Confirm
            # that these are the intended tokens to project.
            ln_1 = out[:, 1:, :]
            k = ln_1 @ W_K.T

            k = k.reshape(1, -1, 12, 64)
            k_mean = k.mean(dim=2)

            keys.append(k_mean.squeeze().cpu().numpy())

    return np.array(keys)

# Head-averaged patch keys for the validation splits (with / without patches)...
keys_w_patches, keys_wo_patches = (
    extract_keys(split) for split in (val_dataset_w_patches, val_dataset_wo_patches)
)

# ...and for the matching test splits.
test_keys_w_patches, test_keys_wo_patches = (
    extract_keys(split) for split in (test_dataset_w_patches, test_dataset_wo_patches)
)

In [6]:
def _collect_activations(dataset):
    """Run ``inspect`` on every image of ``dataset`` and stack its two outputs.

    ``inspect(x)`` returns the model output and the output of the hooked
    encoder block. Both are squeezed per image and stacked along a new first axis.

    Returns:
        (activations, outputs) as NumPy arrays. Despite the ``probs_*`` names
        used below, ``outputs`` are the raw model head outputs; the metrics cell
        applies a softmax to them.
    """
    activations, outputs = [], []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for i in range(len(dataset)):
            model_out, layer_out = inspect(dataset[i][0].unsqueeze(0).to(device))
            activations.append(layer_out.squeeze().cpu().numpy())
            outputs.append(model_out.squeeze().cpu().numpy())
    return np.array(activations), np.array(outputs)


activations_wo_patches, probs_wo_patches = _collect_activations(val_dataset_wo_patches)
activations_w_patches, probs_w_patches = _collect_activations(val_dataset_w_patches)

We perform PCA and K-means clustering on the activations. To evaluate the clustering, we also plot a confusion matrix against the patch / no-patch labels; this evaluation is only used for analysis and not for shortcut detection.

In [None]:
# Perform PCA and K-means clustering
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix


# Stack both validation subsets. The patch / no-patch membership (0 / 1) is the
# ground truth used only to *evaluate* the clustering.
acts = np.concatenate([activations_wo_patches, activations_w_patches], axis=0)
labels = np.concatenate([np.zeros(len(activations_wo_patches)), np.ones(len(activations_w_patches))], axis=0)

# One vector per image: drop the CLS token (index 0) and average the 196 patch tokens.
k = acts.reshape(len(acts), 197, 768)
k = k[:, 1:]
k = np.mean(k, axis=1)

# Global min-max scaling to [0, 1] (a single min / max over all values).
k_min = k.min()
k_max = k.max()
k_normalized = (k - k_min) / (k_max - k_min)

# Ensure a 2-D (n_samples, n_features) array for PCA.
k_reshaped = k_normalized.reshape(len(k_normalized), -1)

# Project onto the first 3 principal components.
pca = PCA(n_components=3)
pca_results = pca.fit_transform(k_reshaped)

# Perform K-means clustering in PCA space.
n_clusters = 2  # You can adjust this number based on your needs
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(pca_results)

# Compare the clusters with the patch / no-patch labels.
true_labels = labels

cm = confusion_matrix(true_labels, cluster_labels)

# K-means cluster ids are arbitrary. Rows of `cm` are true labels and columns
# are cluster ids, so if most label-0 samples landed in cluster 1, swap the ids.
if cm[0][1] > cm[0][0]:
    cluster_labels = 1 - cluster_labels
    # Recompute so the plotted matrix matches the labels that the accuracy
    # below is computed on (previously the pre-flip matrix was plotted).
    cm = confusion_matrix(true_labels, cluster_labels)
# NOTE(review): only `cluster_labels` is flipped; `kmeans.labels_` and the
# centroid order (used by later cells) keep the original K-means ids.

# Plot the confusion matrix
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix of K-means Clustering')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()


accuracy = accuracy_score(true_labels, cluster_labels)
print(f"Clustering accuracy: {accuracy:.2f}")

We employ the Brier score and class homogeneity to evaluate the clustering. This is described in detail in section 3.3.1 of the paper.

In [None]:
from torch.nn import functional as F

# Calculate Brier score and class homogeneity for each cluster
def calculate_metrics(probs, true_labels):
    """Compute per-label Brier scores, label homogeneity and probability stats for one cluster.

    Args:
        probs: array of shape (n_samples, 2) with raw model outputs (logits);
            a softmax over axis 1 turns them into probabilities.
        true_labels: array of shape (n_samples,) with labels in {0, 1}.

    Returns:
        ``(brier_scores, homogeneity, dominant_class, prob_stats)``:
        ``brier_scores`` maps ``'class0'`` / ``'class1'`` to the mean two-class
        Brier score over the samples carrying that label (``None`` if there are
        none). ``homogeneity`` is the fraction of the most frequent label, and
        ``dominant_class`` is that label. ``prob_stats`` holds the mean / std of
        each softmax column.
    """
    p = F.softmax(torch.tensor(probs), dim=1).numpy()

    # Brier score per label, against the one-hot target of that label.
    brier_scores = {'class0': None, 'class1': None}
    for cls, (t0, t1) in ((0, (1, 0)), (1, (0, 1))):
        mask = true_labels == cls
        if np.any(mask):
            brier_scores[f'class{cls}'] = np.mean((p[mask, 0] - t0) ** 2 + (p[mask, 1] - t1) ** 2)

    # Spread of the predicted probabilities for each output column.
    prob_stats = {}
    for col in (0, 1):
        prob_stats[f'mean_prob_class{col}'] = np.mean(p[:, col])
        prob_stats[f'std_prob_class{col}'] = np.std(p[:, col])

    counts = np.bincount(true_labels.astype(int))
    homogeneity = counts.max() / len(true_labels)
    dominant_class = np.argmax(counts)

    return brier_scores, homogeneity, dominant_class, prob_stats

# Sample indices belonging to each (possibly flipped) cluster label.
cluster_0_indices, cluster_1_indices = (np.where(cluster_labels == c)[0] for c in (0, 1))

# Model outputs and patch labels, in the same sample order as `cluster_labels`.
all_probs = np.concatenate([probs_wo_patches, probs_w_patches], axis=0)
all_labels = labels

# Metrics restricted to each cluster's samples.
(cluster_0_brier, cluster_0_homogeneity,
 cluster_0_dominant, cluster_0_prob_stats) = calculate_metrics(
    all_probs[cluster_0_indices], all_labels[cluster_0_indices]
)
(cluster_1_brier, cluster_1_homogeneity,
 cluster_1_dominant, cluster_1_prob_stats) = calculate_metrics(
    all_probs[cluster_1_indices], all_labels[cluster_1_indices]
)

# Equal weights for the three terms of the cluster score.
lambda1 = 1/3
lambda2 = 1/3
lambda3 = 1/3


def calculate_cluster_score(homogeneity, brier_scores, dominant_class):
    """Combine homogeneity and the two per-label Brier scores into one cluster score.

    score = lambda1 * homogeneity
          + lambda2 * exp(-lower Brier score)
          + lambda3 * (1 - exp(-higher Brier score))

    Returns 0.0 when either Brier score is ``None`` (one label absent from the cluster).
    """
    b0 = brier_scores['class0']
    b1 = brier_scores['class1']
    if b0 is None or b1 is None:
        return 0.0

    # NOTE(review): the `dominant_class` argument is ignored. The "dominant"
    # label is re-derived as the one with the lower Brier score (ties -> class1).
    if b0 < b1:
        best, worst = b0, b1
    else:
        best, worst = b1, b0

    return (lambda1 * homogeneity
            + lambda2 * np.exp(-best)
            + lambda3 * (1 - np.exp(-worst)))

# Score both clusters with their own homogeneity / Brier results.
cluster_0_score, cluster_1_score = (
    calculate_cluster_score(hom, brier, dom)
    for hom, brier, dom in (
        (cluster_0_homogeneity, cluster_0_brier, cluster_0_dominant),
        (cluster_1_homogeneity, cluster_1_brier, cluster_1_dominant),
    )
)

_best_cluster = 0 if cluster_0_score > cluster_1_score else 1

print("\nCluster Scores:")
print(f"Cluster 0 Score: {cluster_0_score:.4f}")
print(f"Cluster 1 Score: {cluster_1_score:.4f}")
print(f"Cluster with higher score: {_best_cluster}")

For the samples closest to each cluster centroid, we plot a per-patch distance heatmap. We use this information to select the most relevant patch regions for the subsequent shortcut detection and naming step.

In [None]:
# Define a function to revert the normalization transform
def revert_transform(tensor, mean=None, std=None):
    """Undo ``transforms.Normalize`` and return a displayable (H, W, C) image.

    Args:
        tensor: normalized image tensor of shape (C, H, W).
        mean: per-channel mean used for normalization; defaults to the
            module-level ``mean_ds``.
        std: per-channel std used for normalization; defaults to the
            module-level ``std_ds``.

    Returns:
        ``torch.Tensor`` of shape (H, W, C) with values clipped to [0, 1].
    """
    if mean is None:
        mean = mean_ds
    if std is None:
        std = std_ds

    # detach() so tensors that require grad can be converted to NumPy too.
    img = tensor.detach().cpu().numpy()
    # (C, H, W) -> (H, W, C)
    img = np.transpose(img, (1, 2, 0))
    # Invert x_norm = (x - mean) / std
    img = img * np.array(std) + np.array(mean)
    img = np.clip(img, 0, 1)

    return torch.tensor(img)

NUM_SELECTED = 20

# Distance of every sample (in PCA space) to each K-means centroid: (n_samples, n_clusters).
distances = kmeans.transform(pca_results)

# For each centroid, keep the NUM_SELECTED samples nearest to it. The first
# NUM_SELECTED entries belong to centroid 0, the next NUM_SELECTED to centroid 1.
closest_indices = []
for centroid in range(n_clusters):
    closest_indices.extend(distances[:, centroid].argsort()[:NUM_SELECTED])

print(closest_indices)

concatenated_dataset = ConcatDataset([val_dataset_wo_patches, val_dataset_w_patches])

# Fetch the selected (normalized) images and map them back to displayable arrays.
closest_images = [concatenated_dataset[idx][0] for idx in closest_indices]
closest_images_np = [revert_transform(picture).numpy() for picture in closest_images]

# Grid: row 0 = centroid-0 samples, row 1 = centroid-1 samples, NUM_SELECTED columns.
fig, axes = plt.subplots(2, NUM_SELECTED, figsize=(25, 4))
fig.suptitle('Images Closest to Cluster Centroids', fontsize=16)

for row in range(2):
    for col in range(NUM_SELECTED):
        ax = axes[row, col]
        ax.imshow(closest_images_np[row * NUM_SELECTED + col])
        ax.axis('off')

plt.tight_layout()
plt.show()

# Also write every selected image to disk.
import os
output_dir = f'outputs/{MODEL_NAME}_{seed}/cluster_images/'
os.makedirs(output_dir, exist_ok=True)

for row in range(2):
    for col in range(NUM_SELECTED):
        img = closest_images_np[row * NUM_SELECTED + col]
        # Float images in [0, 1] are converted to uint8 in [0, 255].
        if img.dtype in (np.float32, np.float64):
            img = (img * 255).astype(np.uint8)
        plt.imsave(os.path.join(output_dir, f'cluster{row}_sample_{col}.png'), img)

In [None]:
keys = np.concatenate([keys_wo_patches, keys_w_patches], axis=0)

output_base = f"outputs/{MODEL_NAME}_{seed}"

# Mean key per token position for each K-means cluster. Keys are
# (n_samples, 196, 64), so each mean is (196, 64). Uses the raw K-means ids.
cluster_0_keys = keys[kmeans.labels_ == 0]
cluster_1_keys = keys[kmeans.labels_ == 1]
cluster_0_mean = np.mean(cluster_0_keys, axis=0)
cluster_1_mean = np.mean(cluster_1_keys, axis=0)


# Get the indices of the NUM_SELECTED closest samples for each centroid
cluster_0_indices = closest_indices[:NUM_SELECTED]
cluster_1_indices = closest_indices[NUM_SELECTED:]


# Per-token Euclidean distance of each selected sample's keys to the mean keys
# of the *other* cluster. This is vectorized over samples and tokens (instead of
# 2 * NUM_SELECTED * 196 Python-level norm calls); each part is (NUM_SELECTED, 196).
distances = np.concatenate(
    [
        np.linalg.norm(keys[cluster_0_indices] - cluster_1_mean, axis=-1),
        np.linalg.norm(keys[cluster_1_indices] - cluster_0_mean, axis=-1),
    ],
    axis=0,
)

# Reshape distances for heatmap
distances_matrix = distances.reshape(2*NUM_SELECTED, 196)

# Normalize distances_matrix to [0, 1] using one global min / max; a constant
# matrix maps to all zeros instead of dividing by zero.
d_min = distances_matrix.min()
d_max = distances_matrix.max()
if d_max > d_min:
    distances_matrix = (distances_matrix - d_min) / (d_max - d_min)
else:
    distances_matrix = np.zeros_like(distances_matrix)

# Create a figure with subplots for each sample
fig, axes = plt.subplots(2, NUM_SELECTED, figsize=(15, 5))
fig.suptitle('Distance Heatmaps for Each Sample', fontsize=16)

# Iterate over each sample
for i in range(2*NUM_SELECTED):
    row = i // NUM_SELECTED
    col = i % NUM_SELECTED

    # Reshape the distances for this sample into the 14x14 patch grid
    heatmap_data = distances_matrix[i].reshape(14, 14)

    # Save heatmap
    cluster = 'cluster_0' if i < NUM_SELECTED else 'cluster_1'
    sample_num = i % NUM_SELECTED + 1
    heatmap_dir = f"{output_base}/cluster_heatmaps/{cluster}"
    os.makedirs(heatmap_dir, exist_ok=True)

    # Upsample the 14x14 heatmap to 256x256 with nearest-neighbour
    # interpolation (keeps blocky per-patch cells, no smoothing).
    upscaled_heatmap = cv2.resize(heatmap_data, (256, 256), interpolation=cv2.INTER_NEAREST)
    plt.imsave(f"{heatmap_dir}/sample_{sample_num}_heatmap.png", upscaled_heatmap, cmap='hot', vmin=0, vmax=1)

    # Plot the heatmap
    axes[row, col].imshow(heatmap_data, cmap='hot', interpolation='nearest', vmin=0, vmax=1)
    axes[row, col].axis('off')

    # Set title for each subplot
    if row == 0:
        axes[row, col].set_title(f'Cluster 0\nSample {col+1}')
    else:
        axes[row, col].set_title(f'Cluster 1\nSample {col+1}')

plt.tight_layout()  # Adjust layout to accommodate suptitle
plt.show()

Here we extract the patches used in the shortcut detection and naming step. For each sample, we locate the patch with the highest value in its distance heatmap and crop a grid of the surrounding patches, which gives the LLM more context about the potential shortcut.

In [None]:
# One output directory per selected sample (numbered from 1 within each cluster).
for i in range(2*NUM_SELECTED):
    cluster = 'cluster_1' if i >= NUM_SELECTED else 'cluster_0'
    os.makedirs(f"{output_base}/{cluster}/sample_{i % NUM_SELECTED + 1}/high_activations", exist_ok=True)

patch_size = 16  # pixels per side of one ViT patch
grid_size = 6    # patches per side of the crop window

for i in range(2*NUM_SELECTED):
    img = closest_images_np[i]
    heatmap = distances_matrix[i].reshape(14, 14)

    # Patch-grid coordinates of the largest heatmap value.
    peak_y, peak_x = np.unravel_index(np.argmax(heatmap), heatmap.shape)

    cluster = 'cluster_1' if i >= NUM_SELECTED else 'cluster_0'
    sample_num = i % NUM_SELECTED + 1

    # grid_size x grid_size window around the peak (2 patches above / left,
    # 3 below / right). It is truncated, not shifted, where it would leave the
    # 14x14 grid.
    first_y = peak_y - grid_size // 2 + 1
    first_x = peak_x - grid_size // 2 + 1
    top, bottom = max(0, first_y), min(14, first_y + grid_size)
    left, right = max(0, first_x), min(14, first_x + grid_size)

    # The window's patches are contiguous in the image, so stitching them
    # together is simply a crop of the image.
    combined_patch = img[top * patch_size:bottom * patch_size,
                         left * patch_size:right * patch_size].astype(np.float64)

    patch_path = f"{output_base}/{cluster}/sample_{sample_num}/high_activations/centered_patch_grid.png"
    plt.imsave(patch_path, combined_patch)

We use the Replicate API to call the models: LLaVA-13B for captioning, and Mixtral-8x7B-Instruct, Llama-3-8B-Instruct and Llama-3-70B-Instruct for shortcut detection.

In [None]:
# Paths of the centered patch grids. Sample folders are numbered from 1
# (`sample_{i % NUM_SELECTED + 1}` when they are created), so the numbering
# here must also start at 1; starting at 0 pointed to a non-existent folder.
patch_files = [
    f"{output_base}/{cluster}/sample_{sample_num}/high_activations/centered_patch_grid.png"
    for cluster in ['cluster_0', 'cluster_1']
    for sample_num in range(1, NUM_SELECTED + 1)
]

# Use a token that is already configured in the environment. The placeholder
# is only a fallback and no longer overwrites a real token.
os.environ.setdefault("REPLICATE_API_TOKEN", "<INSERT TOKEN HERE>")
api = replicate.Client(api_token=os.environ["REPLICATE_API_TOKEN"])

caption_models = [("yorickvp/llava-13b:b5f6212d032508382d61ff00469ddda3e32fd8a0e75dc39d8a4191bb742157fb", "llava-13b")]

# Captioning via the LLaVa-13B model: one short caption per patch grid.
for cap_mod in caption_models:
    cluster_descriptions = []

    for cluster in ['cluster_0', 'cluster_1']:
        descriptions = []
        start_idx = 0 if cluster == 'cluster_0' else NUM_SELECTED

        for i in range(start_idx, start_idx + NUM_SELECTED):
            # The context manager closes the image file once the caption is read.
            with open(patch_files[i], 'rb') as image_file:
                output = api.run(
                    cap_mod[0],
                    input={"image": image_file, "prompt": "What is in this picture? Describe in a few words."}
                )
                descriptions.append("".join(output))

        cluster_descriptions.append(descriptions)

        # Create descriptions directory and save
        os.makedirs(f"{output_base}/{cluster}/descriptions", exist_ok=True)
        with open(f"{output_base}/{cluster}/descriptions/{cap_mod[1]}_descriptions.txt", "w", encoding="utf-8") as f:
            f.write("\n".join(descriptions))

shortcut_models = [("meta/meta-llama-3-8b-instruct", "llama8b"), ("meta/meta-llama-3-70b-instruct", "llama70b")]

# Calling the LLama Models
for short_mod in shortcut_models:
    for cluster_idx, cluster in enumerate(['cluster_0', 'cluster_1']):
        # Captions are joined one per line; joining them with "" ran them
        # together into a single unreadable sentence.
        llm_input = {
            "prompt": f"I extracted patches from images in my dataset where my model seems to focus on the most. I let an LLM caption these images for you. I am searching for potential shortcuts in the dataset. Can you identify one or more possible shortcuts in this dataset? Describe it in one sentence (only!) and pick the most significant. No other explanations are needed. Descriptions: \n" + "\n".join(cluster_descriptions[cluster_idx]),
            "prompt_template": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        }

        output = api.run(
            short_mod[0],
            input=llm_input
        )

        text = "".join(output)

        os.makedirs(f"{output_base}/{cluster}/descriptions", exist_ok=True)
        with open(f"{output_base}/{cluster}/descriptions/{short_mod[1]}_shortcut.txt", "w", encoding="utf-8") as f:
            f.write(text)

# Mixtral model
for cluster_idx, cluster in enumerate(['cluster_0', 'cluster_1']):
    llm_input = {
        "top_k": 50,
        "top_p": 0.9,
        "prompt": f"I extracted patches from images in my dataset where my model seems to focus on the most. I let an LLM caption these images for you. I am searching for potential shortcuts in the dataset. Can you identify one or more possible shortcuts in this dataset? Describe it in one sentence (only!) and pick the most significant. No other explanations are needed. If there isn't any shortcut obvious to you then just say so. Descriptions: \n" + "\n".join(cluster_descriptions[cluster_idx]),
        "temperature": 0.6,
        "max_new_tokens": 1024,
        "prompt_template": "<s>[INST] {prompt} [/INST] "
    }

    output = api.run(
        "mistralai/mixtral-8x7b-instruct-v0.1",
        input=llm_input
    )

    text = "".join(output)

    os.makedirs(f"{output_base}/{cluster}/descriptions", exist_ok=True)
    with open(f"{output_base}/{cluster}/descriptions/mixtral_shortcut.txt", "w", encoding="utf-8") as f:
        f.write(text)

In [12]:
output_base = f"outputs/{MODEL_NAME}_{seed}"

# One `patches` directory per selected sample (numbered from 1 within each cluster).
for i in range(2*NUM_SELECTED):
    cluster = 'cluster_1' if i >= NUM_SELECTED else 'cluster_0'
    os.makedirs(f"{output_base}/{cluster}/sample_{i % NUM_SELECTED + 1}/patches", exist_ok=True)

patch_size = 16  # pixels per side of one ViT patch

for i in range(2*NUM_SELECTED):
    img = closest_images_np[i]
    # Row i of distances_matrix is the sample's 14x14 heatmap flattened row-major.
    scores = distances_matrix[i]

    # Flat positions of the 5 largest values (in argpartition's order, not sorted).
    top5 = np.argpartition(scores, -5)[-5:]

    cluster = 'cluster_1' if i >= NUM_SELECTED else 'cluster_0'
    sample_num = i % NUM_SELECTED + 1

    for patch_idx, flat in enumerate(top5, start=1):
        grid_y, grid_x = divmod(int(flat), 14)
        top = grid_y * patch_size
        left = grid_x * patch_size
        patch = img[top:top + patch_size, left:left + patch_size]
        plt.imsave(f"{output_base}/{cluster}/sample_{sample_num}/patches/patch_{patch_idx}.png", patch)


These parts are only used for visualization and qualitative evaluation.

In [None]:
# Make sure the per-cluster output folders exist.
import os

output_base = f"outputs/{MODEL_NAME}_{seed}"
for _cluster_id in (0, 1):
    os.makedirs(f"{output_base}/cluster_{_cluster_id}", exist_ok=True)

def plot_and_save_heatmaps_and_patches(closest_images_np, distances_matrix, NUM_SELECTED, grid_size=5):
    """Plot each sample's distance heatmap above the image crop around its peak, and save the crops.

    The first ``NUM_SELECTED`` samples are treated as cluster 0 and the rest as
    cluster 1. For each sample, the 14x14 heatmap cell with the largest value
    is located, and the 16x16 image patches in a window of
    ``2 * (grid_size // 2) + 1`` patches per side around it (truncated at the
    image border) are stitched into one image. That image is saved to
    ``{output_base}/cluster_{c}/sample_{k}_patches.png``, where ``output_base``
    is the module-level global.

    Args:
        closest_images_np: 2 * NUM_SELECTED images of shape (H, W, 3).
        distances_matrix: array of shape (2 * NUM_SELECTED, 196) with values in [0, 1].
        NUM_SELECTED: number of samples per cluster.
        grid_size: nominal window size in patches (even values behave like grid_size + 1).
    """
    patch_size = 16
    offset = grid_size // 2

    # Rows 0/1: cluster 0 heatmaps/crops; rows 2/3: cluster 1 heatmaps/crops.
    fig, axes = plt.subplots(4, NUM_SELECTED, figsize=(20, 8))
    fig.suptitle(f'Distance Heatmaps and {grid_size}x{grid_size} Surrounding Patches', fontsize=16)

    for i in range(2*NUM_SELECTED):
        row = (i // NUM_SELECTED) * 2
        col = i % NUM_SELECTED
        cluster = 0 if row == 0 else 1

        img = closest_images_np[i]
        heatmap_data = distances_matrix[i].reshape(14, 14)

        # Patch-grid position of the maximum value, converted to pixel coordinates.
        max_pos = np.unravel_index(np.argmax(heatmap_data), heatmap_data.shape)
        center_y = max_pos[0] * patch_size
        center_x = max_pos[1] * patch_size

        # Collect the in-bounds patches of the window and their grid positions.
        valid_patches = []
        min_y, max_y = float('inf'), -float('inf')
        min_x, max_x = float('inf'), -float('inf')

        for dy in range(-offset, offset + 1):
            for dx in range(-offset, offset + 1):
                y = center_y + dy * patch_size
                x = center_x + dx * patch_size

                # A patch starting at (y, x) lies fully inside the image iff
                # y + patch_size <= height (likewise for x); `<=` keeps the
                # last patch row / column, which a strict `<` would drop.
                if (0 <= y <= img.shape[0] - patch_size and
                        0 <= x <= img.shape[1] - patch_size):
                    patch = img[y:y+patch_size, x:x+patch_size]
                    valid_patches.append((dy+offset, dx+offset, patch))
                    min_y = min(min_y, dy+offset)
                    max_y = max(max_y, dy+offset)
                    min_x = min(min_x, dx+offset)
                    max_x = max(max_x, dx+offset)

        # Heatmap panel (always drawn).
        axes[row, col].imshow(heatmap_data, cmap='hot', interpolation='nearest', vmin=0, vmax=1)
        axes[row, col].axis('off')
        axes[row, col].set_title(f'Cluster {cluster}\nSample {col+1}')
        axes[row+1, col].axis('off')

        if not valid_patches:
            continue

        # Composite of exactly the size spanned by the valid patches.
        height = (max_y - min_y + 1) * patch_size
        width = (max_x - min_x + 1) * patch_size
        composite_patch = np.zeros((height, width, 3))

        for py, px, patch in valid_patches:
            y_pos = (py - min_y) * patch_size
            x_pos = (px - min_x) * patch_size
            composite_patch[y_pos:y_pos+patch_size, x_pos:x_pos+patch_size] = patch

        # Save the composite patch without white space
        plt.imsave(
            f"{output_base}/cluster_{cluster}/sample_{col+1}_patches.png",
            composite_patch
        )

        axes[row+1, col].imshow(composite_patch)

    plt.tight_layout()
    plt.show()

# Render heatmaps plus 5x5 patch crops for all selected samples and save the crops.
plot_and_save_heatmaps_and_patches(
    closest_images_np=closest_images_np,
    distances_matrix=distances_matrix,
    NUM_SELECTED=NUM_SELECTED,
    grid_size=5,
)

For each cluster, we extract the `NUM_SELECTED_PATCHES` (here 200) keys with the largest distance and the 200 keys with the smallest distance. These are used for shortcut detection with a KNN classifier, as described in section 3.4.1 of the paper.

In [None]:
# Number of keys kept per list (the `100` in the variable names is historical).
NUM_SELECTED_PATCHES = 200


def _gather_keys(flat_indices, sample_indices):
    """Map flat indices into a (NUM_SELECTED, 196) distance block to (sample, token) pairs and their key vectors."""
    positions = [divmod(flat, 196) for flat in flat_indices]
    picked = [keys[sample_indices[s]][t] for s, t in positions]
    return positions, picked


cluster_0_distances = distances_matrix[:NUM_SELECTED]
cluster_1_distances = distances_matrix[NUM_SELECTED:]

# Ascending order of all (sample, token) distances within each cluster's block.
order_c0 = np.argsort(cluster_0_distances.flatten())
order_c1 = np.argsort(cluster_1_distances.flatten())

# Largest distances (tail of the order) and smallest distances (head of the order).
top_100_indices_cluster_0 = order_c0[-NUM_SELECTED_PATCHES:]
top_100_indices_cluster_1 = order_c1[-NUM_SELECTED_PATCHES:]
lowest_100_indices_cluster_0 = order_c0[:NUM_SELECTED_PATCHES]
lowest_100_indices_cluster_1 = order_c1[:NUM_SELECTED_PATCHES]

top_100_img_indices_c0, top_100_keys_cluster_0 = _gather_keys(top_100_indices_cluster_0, cluster_0_indices)
top_100_img_indices_c1, top_100_keys_cluster_1 = _gather_keys(top_100_indices_cluster_1, cluster_1_indices)
_positions, lowest_100_keys_cluster_0 = _gather_keys(lowest_100_indices_cluster_0, cluster_0_indices)
_positions, lowest_100_keys_cluster_1 = _gather_keys(lowest_100_indices_cluster_1, cluster_1_indices)

# Stack the most distant keys into (NUM_SELECTED_PATCHES, 64) arrays.
top_100_keys_cluster_0 = np.array(top_100_keys_cluster_0).reshape(NUM_SELECTED_PATCHES, 64)
top_100_keys_cluster_1 = np.array(top_100_keys_cluster_1).reshape(NUM_SELECTED_PATCHES, 64)

print("Shape of top 100 keys for cluster 0:", top_100_keys_cluster_0.shape)
print("Shape of top 100 keys for cluster 1:", top_100_keys_cluster_1.shape)

Here we train a KNN classifier to detect shortcuts during inference.

In [None]:
# Patch-level shortcut detector training set:
#   label 0 = cluster-0 keys closest to cluster 1's mean keys,
#   label 1 = cluster-1 keys farthest from cluster 0's mean keys.
X = np.vstack([np.asarray(lowest_100_keys_cluster_0), top_100_keys_cluster_1])
y = np.repeat([0.0, 1.0], NUM_SELECTED_PATCHES)

# Initialize and train the KNN classifier (the number of neighbours can be tuned).
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X, y)

In [15]:
# Re-create the feature extractors with the same setup as the first
# `conv_out_model`: evaluation mode, on `device`.
# conv_out_model returns the `encoder.dropout` output; conv_in_model presumably
# runs the graph from `encoder.dropout` up to `heads.head` (surgeon-pytorch
# node_in / node_out semantics; confirm).
conv_out_model = Extract(model, node_out="encoder.dropout").eval().to(device)
conv_in_model = Extract(model, node_in="encoder.dropout", node_out="heads.head").eval().to(device)

This is helper code to get the surrounding patches for the shortcut detection and to have an inference function that utilizes the KNN classifier to detect shortcuts during inference.

In [16]:
import torch.nn.functional as F

def get_surrounding_indices(index, grid_size, patch_size=16):
    """Return the indices of the (up to 8) patches adjacent to patch ``index``.

    Patches are numbered row-major on a square grid with
    ``grid_size // patch_size`` patches per side (14 for the default 224 px
    image with 16 px patches). Neighbours that would fall outside the grid are
    omitted.

    Args:
        index: row-major patch index (0-based, CLS token not counted).
        grid_size: side length of the square input image in pixels.
        patch_size: side length of one patch in pixels (default 16, as in ViT-B/16).

    Returns:
        List of neighbouring patch indices in the fixed order top, bottom, left,
        right, top-left, top-right, bottom-left, bottom-right.
    """
    num_patches = grid_size // patch_size
    row, col = divmod(index, num_patches)

    # (row offset, column offset): the 4 edge-adjacent neighbours, then the 4 diagonals.
    directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]

    surrounding_indices = []
    for dr, dc in directions:
        new_row, new_col = row + dr, col + dc
        # Keep only neighbours that lie inside the grid.
        if 0 <= new_row < num_patches and 0 <= new_col < num_patches:
            surrounding_indices.append(new_row * num_patches + new_col)
    return surrounding_indices

def add_surrounding_patches(indices, grid_size=224):
    """Return ``indices`` together with all of their in-grid neighbouring patches.

    Every index is expanded with the neighbours reported by
    ``get_surrounding_indices`` for an image of side ``grid_size``. The result
    contains no duplicates; its order is not meaningful.
    """
    expanded = set(indices)
    for patch_index in indices:
        # Neighbours already in `indices` are no-ops for the set.
        expanded.update(get_surrounding_indices(patch_index, grid_size))
    return list(expanded)

def get_mask_from_indices(indices, num_tokens=197):
    """Build a boolean keep-mask over the encoder token sequence.

    Args:
        indices: token positions to drop (list, NumPy array or tensor of ints).
            An empty NumPy array, which defaults to float dtype, is accepted.
        num_tokens: length of the token sequence; the default 197 is the CLS
            token plus 14 x 14 patches of a 224 px image with 16 px patches.

    Returns:
        ``torch.bool`` tensor of shape ``(num_tokens,)`` that is ``False`` at
        ``indices`` and ``True`` everywhere else.
    """
    mask = torch.ones(num_tokens, dtype=torch.bool)
    # Convert explicitly so empty (float-typed) arrays index as integer positions.
    mask[torch.as_tensor(indices, dtype=torch.long)] = False
    return mask

def inference_for_loader(loader):
    """Compare predictions with and without KNN-guided patch ablation on ``loader``.

    For each image, the 196 patch tokens produced by ``conv_out_model`` are
    projected with ``W_K``, reshaped to (196, 12, 64) and averaged over the 12
    groups (presumably the attention heads). ``knn`` then labels every patch.
    Patches labelled 1, plus their neighbours from ``add_surrounding_patches``,
    are removed from the token sequence (``+1`` skips the CLS token at position 0).
    The remaining tokens are classified by ``conv_in_model``; the unablated
    prediction comes from ``model``.

    Args:
        loader: iterable of ``(img, target)`` batches. Assumes batch size 1: only
            the first sample's patches are inspected and ``target.item()`` needs a
            single element.

    Returns:
        ``(correct, correct_abilation, total, correct_images, ablated)``:
        correct predictions without / with ablation, number of images, the input
        images whose prediction changed under ablation (despite the name, not
        necessarily correct ones), and the number of images with at least one
        removed token.
    """
    correct_abilation = 0
    correct = 0
    total = 0
    ablated = 0
    correct_images = []

    # Pure evaluation: no autograd graphs are needed.
    with torch.no_grad():
        for img, target in tqdm(loader):
            x = img.to(device)
            label = target.item()

            # Run the front half once; its output feeds both the patch keys and the
            # ablated forward pass.
            encoder = conv_out_model(x)

            patch_tokens = encoder[0, 1:, :].reshape(196, -1)
            k = (patch_tokens @ W_K.T).reshape(196, 12, 64)
            all_keys = k.mean(dim=1).cpu()

            preds = knn.predict(all_keys)
            flagged = np.where(preds == 1)[0]

            # dtype=int keeps the array integral even when no patch is flagged.
            indices = np.array(add_surrounding_patches(flagged), dtype=int) + 1

            if len(indices) > 0:
                ablated += 1

            mask = get_mask_from_indices(indices)
            enc = encoder[:, mask, :]

            pred_abilation = F.softmax(conv_in_model(x, enc), dim=1).argmax().item()
            pred = F.softmax(model(x), dim=1).argmax().item()

            if pred == label:
                correct += 1
            if pred_abilation == label:
                correct_abilation += 1
            if pred != pred_abilation:
                correct_images.append(img)

            total += 1

    return correct, correct_abilation, total, correct_images, ablated

We now perform the inference on the test set.

In [None]:
test_w_patches_cls_0_loader = torch.utils.data.DataLoader(test_w_patches_cls_0, batch_size=1, shuffle=False)
test_w_patches_cls_1_loader = torch.utils.data.DataLoader(test_w_patches_cls_1, batch_size=1, shuffle=False)
test_wo_patches_cls_0_loader = torch.utils.data.DataLoader(test_wo_patches_cls_0, batch_size=1, shuffle=False)
test_wo_patches_cls_1_loader = torch.utils.data.DataLoader(test_wo_patches_cls_1, batch_size=1, shuffle=False)

# Evaluate the four test splits in a fixed order; this order defines the result columns.
test_loaders = [
    test_w_patches_cls_0_loader,
    test_w_patches_cls_1_loader,
    test_wo_patches_cls_0_loader,
    test_wo_patches_cls_1_loader,
]

corrects, correct_abilations, totals, ablateds = [], [], [], []

for test_loader in test_loaders:
    correct, correct_abilation, total, correct_images, ablated = inference_for_loader(test_loader)
    print(correct, correct_abilation, total, ablated)
    corrects.append(correct)
    correct_abilations.append(correct_abilation)
    totals.append(total)
    ablateds.append(ablated)

import pandas as pd

# Rows are the metrics returned by inference_for_loader; columns are the test splits.
# Column names come from the ImageFolder class names (assumes test_*_cls_<k> holds the
# images whose ImageFolder target is k — confirm where those subsets are built).
class_names = test_dataset_w_patches.classes
res = pd.DataFrame(
    [corrects, correct_abilations, totals, ablateds],
    columns=[
        f"{class_names[0]} w/ Patch",
        f"{class_names[1]} w/ Patch",
        f"{class_names[0]} w/o Patch",
        f"{class_names[1]} w/o Patch",
    ],
    index=["Correct", "Correct (ablation)", "Total", "Ablated"],
)
results_dir = f"outputs/{MODEL_NAME}_{seed}"
os.makedirs(results_dir, exist_ok=True)
res.to_csv(f"{results_dir}/results.csv")

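This next part sorts the validation images into four groups (per class, with and without patches flagged as shortcuts by the KNN classifier), which tells us how many images contain detected shortcuts.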

In [None]:
val_shortcut_cls_0 = []
val_shortcut_cls_1 = []
val_non_shortcut_cls_0 = []
val_non_shortcut_cls_1 = []

# Minimum number of KNN-flagged patches for an image to count as a "shortcut" image.
# Images with exactly one flagged patch fall into none of the four groups, even though
# inference_for_loader ablates any image with a flagged patch.
# NOTE(review): confirm that a threshold of two (rather than one) is intended.
MIN_SHORTCUT_PATCHES = 2

with torch.no_grad():
    for i in tqdm(range(len(concatenated_dataset))):
        # Fetch once: every dataset lookup loads and transforms the image again.
        image, label = concatenated_dataset[i]

        patch_tokens = conv_out_model(image.unsqueeze(0).to(device))[0, 1:, :].reshape(196, -1)
        all_keys = (patch_tokens @ W_K.T).reshape(196, 12, 64).mean(dim=1).cpu()
        num_flagged = knn.predict(all_keys).sum()

        if num_flagged >= MIN_SHORTCUT_PATCHES:
            if label == 0:
                val_shortcut_cls_0.append(i)
            elif label == 1:
                val_shortcut_cls_1.append(i)
        elif num_flagged == 0:
            if label == 0:
                val_non_shortcut_cls_0.append(i)
            elif label == 1:
                val_non_shortcut_cls_1.append(i)

In [None]:
# Group sizes: shortcut cls 0, shortcut cls 1, non-shortcut cls 0, non-shortcut cls 1.
print(*(len(group) for group in (val_shortcut_cls_0, val_shortcut_cls_1, val_non_shortcut_cls_0, val_non_shortcut_cls_1)))

Based on these groups, we follow DFR (Deep Feature Reweighting) and build an equally weighted subset for last-layer retraining: each of the four groups is randomly subsampled to the size of the smallest group.

In [20]:
import random

# Balance the four DFR groups by subsampling each one to the size of the smallest group.
groups = [val_shortcut_cls_0, val_shortcut_cls_1, val_non_shortcut_cls_0, val_non_shortcut_cls_1]
min_length = min(len(group) for group in groups)
if min_length == 0:
    raise ValueError(
        "Cannot build a balanced retraining set: at least one group is empty "
        f"(sizes: {[len(group) for group in groups]})."
    )

# Only torch and numpy are seeded globally, so use a dedicated seeded RNG to make the
# subset reproducible.
rng = random.Random(seed)
sampled_val_shortcut_cls_0 = rng.sample(val_shortcut_cls_0, min_length)
sampled_val_shortcut_cls_1 = rng.sample(val_shortcut_cls_1, min_length)
sampled_val_non_shortcut_cls_0 = rng.sample(val_non_shortcut_cls_0, min_length)
sampled_val_non_shortcut_cls_1 = rng.sample(val_non_shortcut_cls_1, min_length)

# create subset of concatenated_dataset
subset = torch.utils.data.Subset(concatenated_dataset, sampled_val_shortcut_cls_0 + sampled_val_shortcut_cls_1 + sampled_val_non_shortcut_cls_0 + sampled_val_non_shortcut_cls_1)

In [None]:
train_loader = torch.utils.data.DataLoader(subset, batch_size=1, shuffle=False)

train_images = []  # input images on CPU, one (1, C, H, W) tensor per sample
train_encs = []    # ablated encoder tokens as NumPy arrays, one (1, T_i, D) per sample
lbls = []          # integer class labels

num_epochs = 1

selected_images = []

idx = 0
# This cell only extracts features, so run the encoder in eval mode and without autograd.
conv_out_model.eval()
with torch.no_grad():
    for epoch in range(num_epochs):
        with tqdm(train_loader, unit="batch") as tepoch:
            for images, labels in tepoch:
                images = images.to(device)
                lbls.append(labels.item())

                encoded = conv_out_model(images)

                # Per-patch keys: project the 196 patch tokens (CLS excluded) with W_K
                # and average over the 12 groups of 64.
                patch_tokens = encoded[0, 1:, :].reshape(196, -1)
                all_keys = (patch_tokens @ W_K.T).reshape(196, 12, 64).mean(dim=1).cpu()

                # Drop KNN-flagged patches and their neighbours (+1 skips the CLS token).
                flagged = np.where(knn.predict(all_keys) > 0)[0]
                indices = np.array(add_surrounding_patches(flagged), dtype=int) + 1
                mask = get_mask_from_indices(indices)

                enc = encoded.cpu().numpy()[:, mask.numpy(), :]

                train_encs.append(enc)
                train_images.append(images.cpu())

                selected_images.append(idx)
                idx += 1

We build a custom dataset to train the last layer of the model and freeze all other layers. Because ablation is applied during this process, we use surgeon-pytorch's `Extract` function to obtain the intermediate activations.

In [22]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Each entry of train_encs is a (1, T_i, D) token array whose length T_i depends on how
# many patches were ablated: drop only the batch axis, then zero-pad to a common length.
sequences = [torch.as_tensor(arr).squeeze(0).cpu() for arr in train_encs]

# NOTE(review): conv_in_model receives these zero padding tokens during retraining
# without any attention mask, while at test time each image is processed alone (no
# padding). Confirm this train/test mismatch is acceptable.
padded_encs = pad_sequence(sequences, batch_first=True)

# Stack the (1, C, H, W) images into (N, C, H, W). Squeezing only axis 1 keeps the
# leading dimension even for a single-sample set. Targets come from the collected labels.
train_images = torch.stack(train_images).squeeze(1).cpu()
trgts = torch.tensor(lbls)

# Custom dataset for train_images, full_batch, and trgts
class CustomDataset(Dataset):
    """Index-aligned view over images, (ablated) token sequences and targets.

    Item ``i`` is the tuple ``(images[i], encs[i], targets[i])``; the dataset
    length is taken from ``images``.
    """

    def __init__(self, images, encs, targets):
        self.images, self.encs, self.targets = images, encs, targets

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return tuple(getattr(self, field)[idx] for field in ("images", "encs", "targets"))

# Pair each retraining image with its padded, ablated token sequence and its label.
custom_dataset = CustomDataset(train_images, padded_encs, trgts)

In [23]:
train_loader = DataLoader(custom_dataset, batch_size=16, shuffle=True)

# Rebuild the two halves of the split model around the `encoder.dropout` node.
conv_out_model = Extract(model, node_out="encoder.dropout")
conv_in_model = Extract(model, node_in="encoder.dropout", node_out="heads.head")

# Freeze every parameter of conv_in_model, then unfreeze only the final
# classification layer `heads.head` (last-layer retraining).
for param in conv_in_model.parameters():
    param.requires_grad = False

conv_in_model.heads.head.weight.requires_grad = True
conv_in_model.heads.head.bias.requires_grad = True

# Define the loss function and optimizer; the optimizer only receives the trainable
# head parameters instead of the whole frozen network.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    [param for param in conv_in_model.parameters() if param.requires_grad], lr=1e-3
)

We train the last layer of the model for a given number of epochs and then evaluate its performance on the test set.

In [None]:
# Per-batch training losses, in the order they were computed.
losses = []
conv_in_model.train()

num_epochs = 5

for epoch in range(num_epochs):
    for images_b, encs_b, targets_b in tqdm(train_loader):
        optimizer.zero_grad()
        # conv_in_model takes the image and its ablated token sequence.
        logits = conv_in_model(images_b.to(device), encs_b.to(device))
        batch_loss = criterion(logits, targets_b.to(device))
        batch_loss.backward()
        losses.append(batch_loss.item())
        optimizer.step()

conv_in_model.eval()

In [None]:
# Re-evaluate the four test splits with the retrained head, in the same fixed order as
# before (the order defines the result columns).
test_loaders = [
    test_w_patches_cls_0_loader,
    test_w_patches_cls_1_loader,
    test_wo_patches_cls_0_loader,
    test_wo_patches_cls_1_loader,
]

corrects, correct_abilations, totals, ablateds = [], [], [], []

for test_loader in test_loaders:
    correct, correct_abilation, total, correct_images, ablated = inference_for_loader(test_loader)
    print(correct, correct_abilation, total, ablated)
    corrects.append(correct)
    correct_abilations.append(correct_abilation)
    totals.append(total)
    ablateds.append(ablated)

In [26]:
import pandas as pd

# Rows are the metrics returned by inference_for_loader; columns are the test splits.
# Column names come from the ImageFolder class names (assumes test_*_cls_<k> holds the
# images whose ImageFolder target is k — confirm where those subsets are built).
class_names = test_dataset_w_patches.classes
res = pd.DataFrame(
    [corrects, correct_abilations, totals, ablateds],
    columns=[
        f"{class_names[0]} w/ Patch",
        f"{class_names[1]} w/ Patch",
        f"{class_names[0]} w/o Patch",
        f"{class_names[1]} w/o Patch",
    ],
    index=["Correct", "Correct (ablation)", "Total", "Ablated"],
)
results_dir = f"outputs/{MODEL_NAME}_{seed}"
os.makedirs(results_dir, exist_ok=True)
res.to_csv(f"{results_dir}/results_retraining_val.csv")