## Ablation: Quantitative Evaluation of CLIP Captioning

In [None]:
# Import libraries

import os
import sys
sys.path.append("..")

import clip
import torch
import utils
import data_utils
import DnD_models
import scoring_function

import pandas as pd
import random

from diffusers import StableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler

from transformers import BlipProcessor, BlipForConditionalGeneration

from sentence_transformers import SentenceTransformer
from evaluate import load

### Experiment Settings

In [None]:
# Models, layer and probing data used throughout the notebook
clip_name = 'ViT-B/16'           # CLIP variant handed to clip.load
target_name = 'resnet50'         # target network name passed to the utils/DnD_models helpers
target_layer = 'fc'              # layer of the target network that is dissected
d_probe = 'imagenet_broden'      # probing dataset identifier passed to data_utils.get_data
concept_set = '../data/20k.txt'  # concept vocabulary file, one concept per line

# Read the concept vocabulary. Blank entries (e.g. the empty string that
# str.split leaves behind when the file ends with a newline) are dropped so an
# empty "concept" is never tokenized and matched against images.
with open(concept_set, 'r') as f:
    concepts = [line for line in f.read().split('\n') if line]

batch_size = 200                 # batch size for saving activations and similarity computation
device = 'cuda'
pool_mode = 'avg'                # pooling mode forwarded to the activation helpers

results_dir = 'exp_results'
saved_acts_dir = 'saved_activations'
num_images_to_check = 10         # number of top-activating images kept per neuron
blip_batch_size = 10             # forwarded to utils.get_top_images
tag = "clip_fc"                  # tag forwarded to utils.create_layer_folder

# BERTScore metric from the `evaluate` package, used in the final comparison
bertscore = load("bertscore")
# Neuron indices to describe
ids_to_check = list(range(1000))

### Load Models

In [None]:
# Define embedding models
# - mpnetmodel: SentenceTransformer 'all-mpnet-base-v2', later passed to
#   utils.get_cos_similarity for the mpnet similarity score.
# - clip_model / clip_preprocess: the two values returned by clip.load for
#   `clip_name`, loaded on `device`; used for captioning and similarity.

mpnetmodel = SentenceTransformer('all-mpnet-base-v2')
clip_model, clip_preprocess = clip.load(clip_name, device=device)

In [None]:
# Initialize Stable Diffusion
# Builds the text-to-image pipeline that is later handed to
# DnD_models.generate_sd_images together with `generator`.

model_id = "CompVis/stable-diffusion-v1-4"
# Load the pipeline weights in half precision
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# Swap the default scheduler for DPM-Solver multistep, reusing the existing scheduler config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Replace the safety checker with a pass-through: images are returned unchanged
# and every image is reported as not flagged
pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
# Random generator on `device` with a fixed seed (0)
generator = torch.Generator(device=device).manual_seed(0)
pipe = pipe.to(device)

### Set Up Results File

In [None]:
# Empty tables that collect the results of the run.
# Candidate concepts: one row per neuron with its id and five concept columns.
pot_column_names = ['Neuron ID']
pot_column_names.extend('Concept {}'.format(i) for i in range(5))
all_concepts = pd.DataFrame(columns=pot_column_names)

# Final labels: one row per neuron with its id and three label columns.
result_column_names = ['Neuron ID', 'Label 1', 'Label 2', 'Label 3']
final_concepts = pd.DataFrame(columns=result_column_names)

In [None]:
# Create results folder
# `results_path` is the value returned by utils.create_layer_folder; it is later
# passed to the utils.save_* helpers as the output location.
# NOTE(review): the exact folder naming is decided inside utils — confirm there.
results_path = utils.create_layer_folder(results_dir = results_dir, base_dir = ".", target_name = target_name, 
                          d_probe = d_probe, layer = target_layer, tag = tag)

### Construct Augmented Probing Data (DnD Step 1)

In [None]:
# Get activations
# Name of the saved-activation file for this target/layer/probe/pooling combination;
# it is read back with torch.load below.
target_save_name = utils.get_save_names(target_name = target_name,
                                  target_layer = target_layer, d_probe = d_probe,
                                  pool_mode=pool_mode, base_dir = '.', saved_acts_dir = saved_acts_dir)

# Compute and store the target-layer activations on the probing dataset.
# NOTE(review): presumably this writes the file named by target_save_name — confirm in utils.
utils.save_activations(target_name = target_name, target_layers = [target_layer],
                       d_probe = d_probe, batch_size = batch_size, device = device,
                       pool_mode=pool_mode, base_dir = '.', saved_acts_dir = saved_acts_dir)

# Load the stored activations onto the CPU
target_feats = torch.load(target_save_name, map_location='cpu')

# Probing dataset; indexed below as pil_data[idx] -> (image, label)
pil_data = data_utils.get_data(d_probe)

In [None]:
# For every neuron, pick the `num_images_to_check` probing images with the
# highest activation (top-k along dim 0, then one column per neuron).
top_vals, top_ids = torch.topk(target_feats, k=num_images_to_check, dim=0)

# all_imgs is the augmented probing set; all_img_ids maps each neuron to the
# positions of its images inside all_imgs.
all_imgs = []
all_img_ids = dict((neuron_id, []) for neuron_id in ids_to_check)

# Gather the top activating images (and, for non-FC layers, their crops)
for position, orig_id in enumerate(ids_to_check, start=1):
    print("Cropping for Neuron {}/{}".format(position, len(ids_to_check)))

    activating_images = []
    for top_id in top_ids[:, orig_id]:
        # Fetch the image (label is not needed) and bring it to a fixed size
        picture, _ = pil_data[top_id]
        picture = picture.resize([375,375])

        # Record the image's position in the augmented set before storing it
        all_img_ids[orig_id].append(len(all_imgs))
        all_imgs.append(picture)
        activating_images.append(picture)

    # Get crops - FC layers do not have feature maps
    if target_layer == 'fc':
        cropped_images = []
    else:
        cropped_images = DnD_models.get_attention_crops(target_name, activating_images, orig_id, num_crops_per_image = 4, target_layers = [target_layer], device = device)

    # Crops join the augmented set under the same neuron id
    for crop in cropped_images:
        all_img_ids[orig_id].append(len(all_imgs))
        all_imgs.append(crop)

### Captioning with CLIP (DnD Step 2)

In [None]:
# Get target activations with D_probe + D_cropped
target_feats = utils.get_target_activations(target_name, all_imgs, [target_layer])

# Rank all images per neuron, most activating first
top_vals, top_ids = torch.sort(target_feats, dim=0, descending = True)
comp_words = {orig_id : [] for orig_id in ids_to_check}
top_images = {orig_id:[] for orig_id in ids_to_check}

# The concept vocabulary does not depend on the neuron, so it is tokenized and
# embedded once here rather than on every loop iteration.
# NOTE: assumes utils.get_clip_text_features is deterministic for a fixed model/input.
text = clip.tokenize(["{}".format(word) for word in concepts]).to(device)
concept_embeds = utils.get_clip_text_features(clip_model, text)
concept_embeds /= concept_embeds.norm(dim=-1, keepdim = True)

# Step 2 - Generate Candidate Concepts
for neuron_num, orig_id in enumerate(ids_to_check):
    print("Neuron: {} ({}/{})".format(orig_id, neuron_num+1, len(ids_to_check)))

    # Plot and save highest activating images
    fig, images, top_images = utils.get_top_images(orig_id, top_ids, top_images, 
                                                   all_imgs, all_img_ids, num_images_to_check, 
                                                   blip_batch_size)
    utils.save_activating_fig(fig, results_path, orig_id)

    # Unit-normalize the image embeddings so the matrix product with the
    # (already normalized) concept embeddings is a cosine similarity
    image_embeds = utils.get_clip_image_features(clip_model, clip_preprocess, images, device = device)
    image_embeds /= image_embeds.norm(dim=-1, keepdim = True)

    clip_feats = image_embeds @ concept_embeds.T

    # Caption each image with its single most similar concept
    top_words_idx = torch.argmax(clip_feats,1)
    descriptions = [concepts[idx] for idx in top_words_idx]
    for t, label in enumerate(descriptions):
        print("Image {}: {}".format(t,label))

    # Summarize CLIP descriptions: query the summarizer five times, shuffling
    # the description order between queries
    for i in range(5):
        cand_concept = DnD_models.GPT_model_single(descriptions)
        comp_words[orig_id].append(cand_concept)
        random.shuffle(descriptions)
        print("Candidate Concept {}: {}".format(i+1, cand_concept))
    # One row per neuron: its id followed by the five candidate concepts
    all_concepts.loc[len(all_concepts)] = [orig_id] + comp_words[orig_id]

# Save candidate concepts
utils.save_potential_concepts(all_concepts, results_path)

### Best Concept Selection (DnD Step 3)

In [None]:
# We adjust concepts with certain vague words to help SD generation.
# Each candidate is lower-cased and loses one trailing period. If its last word
# is in `replace_set`, an extra "<concept> background" variant is added next to
# the original candidate.

replace_set = ['design','designs','graphic','graphics']
for orig_id in ids_to_check:
    cleaned = []
    background_variants = []
    for concept in comp_words[orig_id]:
        word = concept.lower()
        # Strip the period first so the last-word check below sees "design",
        # not "design."
        if word.endswith('.'):
            word = word[:-1]
        tokens = word.split()
        if not tokens:
            # Empty or whitespace-only candidate: there is nothing to prompt SD with
            continue
        cleaned.append(word)
        if tokens[-1] in replace_set:
            background_variants.append(word + ' background')
    # De-duplicate while keeping first-seen order. A set would give a
    # hash-dependent order, which changes which seeded-generator draws each
    # label receives in the image generation step.
    comp_words[orig_id] = list(dict.fromkeys(cleaned + background_variants))

In [None]:
# Probing dataset and its size
pil_data = data_utils.get_data(d_probe)
d_probe_len = len(pil_data)
# Final labels per neuron, filled in by the loop below
all_final_results = dict((neuron_id, []) for neuron_id in ids_to_check)

# Image generation and scoring hyper-parameters
num_images_per_prompt = 10
top_K_param = 10
beta_images_param = 5
scoring_func = 'topk-sq-mean'

sd_prompt = 'One realistic image of {}'
num_inference_steps = 50

# Step 3 Main Code
for orig_id in ids_to_check:

    # Candidate concepts for this neuron (may include added "background" variants)
    word_list = comp_words[orig_id]
    labels_to_check = len(word_list)

    print("Neuron {}".format(orig_id))
    print("# Labels to Check: {}".format(labels_to_check), "   # Images per Concept: {}".format(num_images_per_prompt))

    add_im, add_im_id, all_sd_imgs = {}, {}, []

    # Generate images for each candidate label
    for label_id, candidate in enumerate(word_list):

        print("Label {}/{}: {}".format(label_id + 1, labels_to_check, candidate))
        add_im_id[label_id] = []
        pred_label = sd_prompt.format(candidate)

        add_im, add_im_id, all_sd_imgs = DnD_models.generate_sd_images(add_im, add_im_id, all_sd_imgs, 
                                                                  pred_label, label_id, pipe, generator,
                                                                  num_images_per_prompt, num_inference_steps)

    # Concept Scoring: target activations on the generated images
    target_feats = utils.get_target_activations(target_name, all_sd_imgs, [target_layer])

    ranks, highest_activating = utils.rank_images(target_feats, orig_id, labels_to_check,
                                                 add_im_id, add_im, top_K_param)

    clip_weight = scoring_function.compare_images(top_images[orig_id], highest_activating, clip_name, 
                                                  device, target_name, top_K_param)

    top_avg_topk = scoring_function.get_score(ranks, mode = scoring_func, hyp_param = beta_images_param)

    # Combine both signals: each entry of clip_weight holds (weight, label index);
    # the label's rank term is multiplied by that weight
    top_avg_comb = []
    for idx in range(len(clip_weight)):
        weight = clip_weight[idx][0]
        label_index = clip_weight[idx][1]
        concept_rank = len(top_avg_topk) - scoring_function.find_by_last(top_avg_topk, label_index)
        top_avg_comb.append((concept_rank * weight, label_index))

    # Highest combined score first
    top_avg_comb.sort(reverse = True)

    # Save results in .csv file: the best three labels, padded with ' ' when
    # fewer than three candidates exist
    best_labels = [word_list[entry[1]] for entry in top_avg_comb[:3]]
    best_labels += [' '] * (3 - len(best_labels))
    all_final_results[orig_id] += best_labels
    final_concepts.loc[len(final_concepts)] = [orig_id] + all_final_results[orig_id]

    # Print results, stopping at the first padding entry
    print('------------------------------\n')
    print('Neuron {}:'.format(orig_id))
    for k, word in enumerate(all_final_results[orig_id]):
        if word == " ":
            break
        print("Label {}: {}".format(k + 1, word))
    print('\n------------------------------')

utils.save_final_results(final_concepts, results_path)

### Find FC Layer Similarity

In [None]:
# Get ground truths (one class name per line). Trailing newlines are stripped
# before splitting so the file's final line terminator does not become an extra
# empty class, which would leave the references one entry longer than the
# per-neuron predictions.
with open('../data/imagenet_labels.txt', 'r') as f:
    classes = f.read().rstrip('\n').split('\n')

# Get DnD labels
# NOTE(review): this hard-coded path carries the "clip_fc" tag used by this
# notebook's own run — confirm it really holds the baseline (BLIP) DnD results.
dnd_labels = pd.read_csv('./exp_results/resnet50_imagenet_broden_layerfc_results_clip_fc/all_result_concepts/final_results.csv')
dnd_preds = list(dnd_labels['Label 1'])

# Get DnD labels with CLIP captioning: the top-ranked label of every neuron
dnd_clip_preds = [all_final_results[orig_id][0] for orig_id in ids_to_check]

In [None]:
# Calculate similarities between labels and ground truths
# For both captioners, report CLIP cosine similarity, mpnet cosine similarity
# and the mean BERTScore F1 against the ground-truth class names.
# The BERTScore field uses '{:.4f}'; the previous '{:4f}' set a field width of 4
# instead of four decimal places.

for captioner, preds in (("BLIP", dnd_preds), ("CLIP", dnd_clip_preds)):
    clip_cos, mpnet_cos = utils.get_cos_similarity(preds, classes, clip_model, mpnetmodel, device, batch_size)
    bert_score = bertscore.compute(predictions=preds, references=classes, lang="en")
    bert_f1 = sum(bert_score["f1"]) / len(bert_score["f1"])
    print("DnD w/ {} - Clip similarity: {:.4f}, mpnet similarity: {:.4f}, BERTScore: {:.4f}".format(captioner, clip_cos, mpnet_cos, bert_f1))