## Use Case: Locating Conceptual Groupings in T2V-ResNet18

In [None]:
# Import libraries

import os
import sys
sys.path.append("../..")

import clip
import torch
import utils
import data_utils
import DnD_models
import scoring_function

import pandas as pd
import random
from time import time
import math

from diffusers import StableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler

from transformers import BlipProcessor, BlipForConditionalGeneration

### Experiment Settings

In [None]:
# --- Models and probing data -------------------------------------------
clip_name = 'ViT-B/16'              # CLIP variant passed to clip.load
target_name = 'resnet18_tile2vec'   # network whose neurons are dissected
target_layer = 'layer1'             # layer of the target network to inspect
d_probe = 'NAIP'                    # probing dataset identifier

# --- Runtime -------------------------------------------------------------
device = 'cuda:0'
batch_size = 200
blip_batch_size = 10
pool_mode = 'avg'

# --- Output --------------------------------------------------------------
results_dir = 'exp_results'
saved_acts_dir = 'saved_activations'
tag = "concept_grouping"

# Number of top-activating images considered per neuron
num_images_to_check = 10

In [None]:
# Number of neuron ids enumerated for each selectable layer
layer_dims = dict(layer1=64, layer2=128, layer3=256, layer4=512, layer5=512)

# Inspect every neuron id of the chosen layer: 0 .. layer_dims[target_layer] - 1
ids_to_check = [*range(layer_dims[target_layer])]
print(ids_to_check)
print(len(ids_to_check))

### Load Models

In [None]:
# Load BLIP model

BLIP_PATH = "/expanse/lustre/scratch/nbai/temp_project/model_base_capfilt_large.pth"

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# Deserialize onto the CPU: without map_location the tensors are restored to
# the device they were saved from, which fails when that device is absent and
# otherwise holds a second copy of the weights on the GPU. load_state_dict
# copies the values onto the model's device afterwards.
pretrained_dict = torch.load(BLIP_PATH, map_location="cpu")
model_dict = model.state_dict()

# Keep only the checkpoint entries whose names exist in the Hugging Face model.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
if not pretrained_dict:
    # NOTE(review): if the checkpoint nests its weights (e.g. under a "model"
    # key) or uses different parameter names, nothing matches and the model
    # silently keeps the Hugging Face weights - confirm the checkpoint layout.
    print("WARNING: no parameters in {} matched the BLIP model; "
          "keeping the pretrained Hugging Face weights.".format(BLIP_PATH))

model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.eval()

In [None]:
# Initialize Stable Diffusion

def _passthrough_safety_checker(images, **kwargs):
    """Return the images unchanged together with one False flag per image."""
    return images, [False] * len(images)


model_id = "CompVis/stable-diffusion-v1-4"

# Fixed seed so that generation is reproducible
generator = torch.Generator(device=device).manual_seed(0)

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# Replace the default scheduler with DPM-Solver, reusing the existing config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Bypass the pipeline's safety checker
pipe.safety_checker = _passthrough_safety_checker
pipe = pipe.to(device)

### Set Up Results File

In [None]:
# Column layouts of the two result tables
pot_column_names = ['Neuron ID', *map('Concept {}'.format, range(5))]
result_column_names = ['Neuron ID', 'Label 1', 'Label 2', 'Label 3']

# Empty tables: one for the candidate concepts of each neuron, one for final labels
all_concepts = pd.DataFrame(columns=pot_column_names)
final_concepts = pd.DataFrame(columns=result_column_names)

In [None]:
# Create the folder that receives this run's outputs and remember its path
results_path = utils.create_layer_folder(
    results_dir=results_dir,
    base_dir=".",
    target_name=target_name,
    d_probe=d_probe,
    layer=target_layer,
    tag=tag,
)

### Construct Augmented Probing Data (DnD Step 1)

In [None]:
# Get activations

# Keyword arguments common to both utils calls below
shared_args = dict(
    target_name=target_name,
    d_probe=d_probe,
    pool_mode=pool_mode,
    base_dir='.',
    saved_acts_dir=saved_acts_dir,
)

# File name under which the activations of target_layer are stored
target_save_name = utils.get_save_names(target_layer=target_layer, **shared_args)

# Compute and save the activations (presumably skipped when already cached -
# confirm in utils.save_activations)
utils.save_activations(
    target_layers=[target_layer],
    batch_size=batch_size,
    device=device,
    **shared_args,
)

# Load the saved activations onto the CPU
target_feats = torch.load(target_save_name, map_location='cpu')

pil_data = data_utils.get_data(d_probe)

In [None]:
# Find top activating images
top_vals, top_ids = torch.topk(target_feats, k=num_images_to_check, dim=0)

# all_imgs is the flat pool of images (originals followed by their crops);
# all_img_ids maps each neuron to the positions of its images in that pool.
all_imgs = []
all_img_ids = {}
for neuron_id in ids_to_check:
    all_img_ids[neuron_id] = []

# Find top activating image crops
total_neurons = len(ids_to_check)
for position, orig_id in enumerate(ids_to_check, start=1):
    print("Cropping for Neuron {}/{}".format(position, total_neurons))

    # Top-k probing images for this neuron, converted to float tensors
    activating_images = [
        pil_data[top_id][0].type(torch.FloatTensor)
        for top_id in top_ids[:, orig_id]
    ]
    first_slot = len(all_imgs)
    all_img_ids[orig_id].extend(range(first_slot, first_slot + len(activating_images)))
    all_imgs.extend(activating_images)

    # Crops are only produced for layers other than 'fc'
    if target_layer == 'fc':
        cropped_images = []
    else:
        cropped_images = DnD_models.get_attention_crops(
            target_name, activating_images, orig_id,
            num_crops_per_image = 4, target_layers = [target_layer], device = device)

    for crop in cropped_images:
        all_img_ids[orig_id].append(len(all_imgs))
        all_imgs.append(crop)

### Generative Captioning (DnD Step 2)

In [None]:
# Get target activations with D_probe + D_cropped
target_feats = utils.get_target_activations(target_name, all_imgs, [target_layer])

# Sort the activations along dim 0, largest first
top_vals, top_ids = torch.sort(target_feats, dim=0, descending=True)

comp_words = {}
top_images = {}
for orig_id in ids_to_check:
    comp_words[orig_id] = []
    top_images[orig_id] = []

# Step 2 - Generate Candidate Concepts
total_neurons = len(ids_to_check)
for neuron_num, orig_id in enumerate(ids_to_check, start=1):
    print("Neuron: {} ({}/{})".format(orig_id, neuron_num, total_neurons))

    # Plot and save highest activating images
    fig, images, top_images = utils.get_top_images(
        orig_id, top_ids, top_images,
        all_imgs, all_img_ids, num_images_to_check,
        blip_batch_size, convert_from_np = True)
    utils.save_activating_fig(fig, results_path, orig_id)

    # Caption the images with BLIP, then simplify every caption
    captions = DnD_models.blip_caption(model, processor, images, blip_batch_size, device)
    descriptions = [DnD_models.GPT_simplify(caption) for caption in captions]

    # Summarize the descriptions five times, shuffling their order between
    # attempts so each summary sees a different ordering
    candidates = comp_words[orig_id]
    for attempt in range(1, 6):
        cand_concept = DnD_models.GPT_model_naip(descriptions)
        candidates.append(cand_concept)
        random.shuffle(descriptions)
        print("Candidate Concept {}: {}".format(attempt, cand_concept))
    all_concepts.loc[len(all_concepts)] = [orig_id] + candidates

# Save candidate concepts
utils.save_potential_concepts(all_concepts, results_path)

### Find Neuron Groups

In [None]:
def text_embed(model, labels):
    """Embed ``labels`` as unit-length text features.

    The labels are tokenized with ``clip.tokenize``, moved to the notebook's
    global ``device`` and encoded with ``model.encode_text`` in slices of the
    global ``batch_size``, with gradients disabled.

    Args:
        model: object exposing ``encode_text`` (in this notebook, the model
            returned by ``clip.load``).
        labels: iterable of items to embed; each one is formatted to a string.

    Returns:
        Tensor with one row per label, each row divided by its L2 norm.
    """
    tokens = clip.tokenize(list(map("{}".format, labels))).to(device)

    encoded_batches = []
    with torch.no_grad():
        for start in range(0, len(tokens), batch_size):
            encoded_batches.append(model.encode_text(tokens[start:start + batch_size]))

    embeddings = torch.cat(encoded_batches, dim=0)
    # Normalise every row so that dot products are cosine similarities
    return embeddings / embeddings.norm(dim=-1, keepdim=True)

In [None]:
# Group neurons whose candidate labels are close in CLIP text-embedding space.
#
# - We skip Step 3 (Best Concept Selection) because Stable Diffusion cannot
#   generate land-coverage-specific synthetic images.
#
# - The optimal threshold can change with the generated labels or the target
#   layer. We suggest experimenting with [0.75, 0.80, 0.85, 0.90, 0.95, 0.97].

THRESH = 0.80
clip_model, _ = clip.load(clip_name, device=device)

# cls[neuron_id] holds a single tuple:
#   (class_labels, descriptions, class_embeddings, description_embeddings)
cls = {neuron_id: [] for neuron_id in ids_to_check}
t0 = time()

for neuron_id in ids_to_check:
    descrip = []
    class_labels = []
    for label in comp_words[neuron_id]:
        # The parsing assumes labels shaped like "description (Class)".
        # partition() leaves the text intact when a parenthesis is missing,
        # whereas slicing on find() == -1 silently chops the last character.
        description, _open_paren, remainder = label.partition('(')
        descrip.append(description)
        class_labels.append(remainder.partition(')')[0])
    cls[neuron_id].append((class_labels, descrip,
                           text_embed(clip_model, class_labels),
                           text_embed(clip_model, descrip)))

print("Embedding time: {:.3f}".format(time()-t0))

# Every neuron seeds one "Class" group and one "Concept" group, named after its
# first two labels. Neurons that produce the same name share the group instead
# of overwriting each other's entry.
class_key = {}
concept_key = {}
record = {}
for neuron_id in ids_to_check:
    neuron_classes, neuron_descrip = cls[neuron_id][0][0], cls[neuron_id][0][1]
    class_key[neuron_id] = "Class: {}, {}".format(neuron_classes[0], neuron_classes[1])
    concept_key[neuron_id] = "Concept: {}, {}".format(neuron_descrip[0], neuron_descrip[1])
    record.setdefault(class_key[neuron_id], []).append(neuron_id)
    record.setdefault(concept_key[neuron_id], []).append(neuron_id)

# Keyed by neuron id (not list position) so any ids_to_check selection works.
grouped_cls = {neuron_id: False for neuron_id in ids_to_check}
grouped_cpt = {neuron_id: False for neuron_id in ids_to_check}

for i, id1 in enumerate(ids_to_check[:-1]):
    for id2 in ids_to_check[i + 1:]:
        # A neuron already absorbed into an earlier group does not start its
        # own; test that cheap flag before computing the similarity.
        if not grouped_cls[id1] and torch.mean(cls[id1][0][2] @ cls[id2][0][2].T) > THRESH:
            record[class_key[id1]].append(id2)
            grouped_cls[id2] = True
        if not grouped_cpt[id1] and torch.mean(cls[id1][0][3] @ cls[id2][0][3].T) > THRESH:
            record[concept_key[id1]].append(id2)
            grouped_cpt[id2] = True

# De-duplicate members first so groups are ranked by their true size.
record = {key: sorted(set(members)) for key, members in record.items()}
record = dict(sorted(record.items(), key=lambda item: len(item[1]), reverse = True))

cls_classified = []
cpt_classified = []
for label, members in record.items():
    if len(members) <= 1:
        continue
    # Match on the key prefix: a substring test would file a key such as
    # "Concept: Classroom, ..." under the class groups.
    if label.startswith("Class: "):
        print("{}\n{}\n".format(label, members))
        cls_classified.extend(members)
    elif label.startswith("Concept: "):
        print("{}\n{}\n".format(label, members))
        cpt_classified.extend(members)

non_classified_by_class = sorted(set(ids_to_check) - set(cls_classified))
non_classified_by_concept = sorted(set(ids_to_check) - set(cpt_classified))

print("No classification by class: {} neurons\n{}".format(len(non_classified_by_class), non_classified_by_class))
print("No classification by concept: {} neurons\n{}".format(len(non_classified_by_concept), non_classified_by_concept))