In [1]:
# Switch the working directory to the parent folder; the relative paths used later
# ("data/...", "summaries", ...) are presumably resolved from there — TODO confirm.
# NOTE: not idempotent — re-running this cell moves up one more directory each time.
%cd ..

/work/yuxiang1234/sandbox-AND


## Import packages

In [2]:
import json
import os
import pickle
import random
import sys
from collections import defaultdict

import numpy as np
import torch
import torch.nn.utils.prune as prune
from pruning.audio_dataset import ESC50Dataset, collate_batch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from args import parser
from data_utils import (get_label_to_cls, get_target_model, read_json, get_cls_label, mean)
from sentence_utils import get_basename


[nltk_data] Downloading package stopwords to
[nltk_data]     /home/yuxiang1234/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


### Arguments

In [7]:
class Arguments:
    """Notebook-wide configuration, mirroring the command-line options of the project.

    Every setting is a plain instance attribute so that the cells below can read
    it as ``args.<name>``. Edit the values here instead of scattering constants
    through the notebook.
    """

    def __init__(self):

        self.target_name = "ast-esc50" # Model to dissect (target model)

        # Which layer neurons to describe. String list of layer names to describe, separated by comma (no spaces).
        # Follows the naming format of the Pytorch module used.
        # Generated programmatically so that no stray whitespace ends up inside the
        # names (a backslash-continued string literal keeps each line's leading indentation).
        self.target_layers = ",".join(
            f"layer{layer_idx}_{sublayer}"
            for layer_idx in range(12)
            for sublayer in ("output", "intermediate", "attention_output")
        )

        # For beats
        # self.target_layers = "layer0_1,layer0_2,layer1_1,layer1_2,layer2_1,layer2_2,layer3_1,layer3_2,layer4_1,layer4_2,layer5_1,layer5_2,layer6_1,layer6_2,layer7_1,layer7_2,layer8_1,layer8_2,layer9_1,layer9_2,layer10_1,layer10_2,layer11_1,layer11_2"

        self.probing_dataset = "esc50"  # Probing dataset to probe the target model
        self.concept_set_file = "data/concept_set/esc50.txt"  # Path to txt file of concept set
        self.network_class_file = "data/network_class/esc50.txt"  # Path to txt file of network's classification class
        self.clip_model = "ViT-B/32"  # CLIP model version to use
        self.clap_model = "ViT-B/32"  # CLAP model version to use
        self.sentence_transformer = 'all-MiniLM-L12-v2'  # Sentence transformer to use
        self.batch_size = 1  # Batch size when running CLIP/target model
        self.device = "cuda"  # Whether to use GPU/which GPU
        self.seed = 20  # Seed number
        self.num_of_gpus = 1  # Number of available GPUs for vllm
        self.pool_mode = "avg"  # Aggregation function for channels
        self.scoring_func = False  # Scoring function flag

        # Directory paths
        self.audio_description_dir = "audio_description"  # Directory to save audio descriptions
        self.audio_dir = "save_audios"  # Directory to save audio
        self.save_activation_dir = "saved_activations"  # Directory to save activation values
        self.save_summary_dir = "summaries"  # Directory to save summaries
        self.save_discriminative_sample_dir = "discriminative_samples"  # Directory to save discriminative samples
        self.save_prediction_dir = "prediction"  # Directory to save prediction
        self.save_interpretability_dir = 'interpretability'  # Directory to save interpretability experiments

        # Discriminative settings
        self.discriminative_type = "highly"  # Type of discriminative samples
        self.post_process_type = "sim"  # Post-processing type
        self.mutual_info_threshold = 0.6  # Mutual information threshold
        self.K = 5  # Top-K highly/lowly-activated audio
        self.clusters = 11  # Number of clusters

        # LLM settings
        self.llm = "meta-llama/Llama-2-13b-chat-hf"  # LLM to use
        self.top_p = 1.0  # Sampling parameter: top-p
        self.temperature = 1.0  # Sampling parameter: temperature
        self.max_tokens = 128  # Sampling parameter: max tokens
        self.ICL_topk = 1  # Experiments of top5 or top1 accuracy of ICL

        # Pruning settings
        self.save_pruning_dir = "pruning_result"  # Directory to save pruning results
        self.max_pruned_num = 3000  # Maximum number of pruned neurons
        self.pruned_concepts = ["water_drops"]  # Concepts to be ablated
        self.pruning_strategy = "ocp"  # Method to decide pruned neurons (random, db, tab, ocp)

In [8]:
# Instantiate the shared configuration object; the pruning and evaluation cells read it as `args`.
args = Arguments()

### Pruning

In [9]:
# Class-wise neuron ablation.
# For every entry of `label_to_cls` (including the `None` entry, which acts as the
# unpruned baseline) a fresh model is loaded, the neurons selected by the chosen
# strategy are zeroed out, and the model is evaluated on the ESC-50 dev split.
# All per-class statistics are written to one JSON file at the end.
random.seed(args.seed)

label_to_cls = get_label_to_cls(args.network_class_file)

# Choose the neuron-description file that drives the pruning decision.
if args.pruning_strategy == "tab":
	prediction_file = os.path.join(args.save_prediction_dir, f"tab-{args.target_name}-top{args.K}.json")
elif args.pruning_strategy == "db":
	prediction_file = os.path.join(args.save_prediction_dir, f"db-{args.target_name}-top{args.K}.json")
elif args.pruning_strategy in ("ocp", "random"):
	# "random" only needs the neuron names, so it reads the same summary file as "ocp".
	prediction_file = os.path.join(args.save_summary_dir, f"calibration_{args.target_name}_{args.probing_dataset}_{get_basename(args.concept_set_file)}_top{args.K}.json")
else:
	raise ValueError(f"Unknown pruning strategy: {args.pruning_strategy!r}")

prediction = read_json(prediction_file)

# For the "random" baseline, draw the set of neurons to prune up front and keep
# iterating `prediction` in file order below. The mask rows are appended in
# iteration order, so shuffling the dict itself would zero rows that do not
# match the neuron ids written to the result file.
random_pruned_keys = set()
if args.pruning_strategy == "random":
	shuffled_keys = list(prediction)
	random.Random(args.seed).shuffle(shuffled_keys)
	prunable_keys = [key for key in shuffled_keys if key.split('#')[0] != 'fc']
	random_pruned_keys = set(prunable_keys[:args.max_pruned_num])

# Length of one weight row (the layer's input width) per sub-layer type.
if "ast" in args.target_name:
	input_dimension = {"attention_output":768, "intermediate": 768, "output":3072}
	processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
elif "beats" in args.target_name:
	input_dimension = {"1": 768, "2": 3072}
	processor = None
else:
	raise ValueError(f"Unsupported target model: {args.target_name!r}")

dataset = load_dataset("ashraq/esc50")
dev_dataset = ESC50Dataset(dataset, processor, mode='dev')

# NOTE(review): the evaluation batch size is fixed at 128 and does not follow args.batch_size.
dev_loader = DataLoader(dev_dataset, batch_size=128, shuffle=False, pin_memory=True, collate_fn=collate_batch)

results = {}

for cls_id, cls_name in label_to_cls.items():

	mask_cnt = 0
	pruned_neuron_record = defaultdict(int)
	masked_out_neurons = []
	# Per layer: one row per neuron, all zeros (pruned) or all ones (kept).
	masked_neuron = defaultdict(list)
	masked_neuron_bias = defaultdict(list)

	# Pruning modifies the module in place, so every class starts from a freshly loaded model.
	model = get_target_model(args.target_name, device=args.device)

	# NOTE(review): rows are appended in iteration order, which assumes the file lists
	# every neuron of a layer in ascending neuron-id order — confirm against the summary file.
	for key in tqdm(prediction.keys()):
		layer = key.split('#')[0]

		# Neurons of the final classifier are never pruned.
		if layer == 'fc':
			continue

		layer = layer.split('_')

		layer_name = layer[1]
		layer_num = layer[0]

		# "layerN_attention_output" splits into three parts; re-join the sub-layer name.
		if len(layer) == 3:
			layer_name = layer[1] + '_' + layer[2]

		layer_id = layer_num.replace("layer", "")
		layer_key = layer_num + "_" + layer_name

		neuron_id = key.split('#')[1]
		dim = input_dimension[layer_name]

		flag = False
		if args.pruning_strategy == "ocp":
			nouns = prediction[key]['nouns']

			# A neuron is pruned when any of its nouns is a substring of the class name.
			if cls_name is not None:
				for n in nouns:
					if n in cls_name:
						flag = True
						break

		# We select the best similarity function by last layer dissection accuracy
		elif args.pruning_strategy == "tab":
			if cls_name is not None and cls_name in prediction[key]["soft_wpmi"]["prediction"][:3]:
				flag = True
		elif args.pruning_strategy == "db":
			if cls_name is not None and cls_name in prediction[key]["cos_similarity_cubed"]["prediction"][:3]:
				flag = True
		elif args.pruning_strategy == "random":
			# The `None` entry is the unpruned baseline, so nothing is masked for it.
			if cls_name is not None and key in random_pruned_keys:
				flag = True

		if flag:
			pruned_neuron_record[layer_key] += 1
			mask_cnt += 1
			masked_neuron[layer_key].append([0] * dim)
			masked_out_neurons.append(f"{layer_name}_{layer_id}#{neuron_id}")
			masked_neuron_bias[layer_key].append(0)
		else:
			masked_neuron[layer_key].append([1] * dim)
			masked_neuron_bias[layer_key].append(1)

	# Apply the collected masks to the matching linear layers of the model.
	for layer_key, mask in masked_neuron.items():
		layer_id = int(layer_key.split("_")[0].replace("layer", ""))
		# Only the first part of the sub-layer name is needed ("attention_output" -> "attention").
		layer_name = layer_key.split("_")[1]

		if "ast" in args.target_name:
			if layer_name == "attention":
				module = model.audio_spectrogram_transformer.encoder.layer[layer_id].attention.output.dense
			elif layer_name == "intermediate":
				module = model.audio_spectrogram_transformer.encoder.layer[layer_id].intermediate.dense
			elif layer_name == "output":
				module = model.audio_spectrogram_transformer.encoder.layer[layer_id].output.dense
			else:
				raise ValueError(f"Unexpected layer name in {layer_key!r}")
		elif "beats" in args.target_name:
			if layer_name == "1":
				module = model.beats.encoder.layers[layer_id].fc1
			elif layer_name == "2":
				module = model.beats.encoder.layers[layer_id].fc2
			else:
				raise ValueError(f"Unexpected layer name in {layer_key!r}")

		# Keep the masks on the same device as the model.
		weight_mask = torch.tensor(mask).to(args.device)
		bias_mask = torch.tensor(masked_neuron_bias[layer_key]).to(args.device)

		prune.custom_from_mask(module, 'weight', mask=weight_mask)
		prune.custom_from_mask(module, 'bias', mask=bias_mask)

	# Misclassified file names; collected for inspection only, not written to the result file.
	wrong_record = []
	correct_by_class, total_by_class = defaultdict(int), defaultdict(int)
	pred_by_class = defaultdict(int)
	correct, total = 0, 0
	# Raw model outputs (one vector per sample), grouped by ground-truth class.
	confidence_by_class = defaultdict(list)

	with torch.no_grad():
		for batch in tqdm(dev_loader):
			batch["input_values"] = batch["input_values"].to(args.device)
			batch["labels"] = batch["labels"].to(args.device)
			outputs = model(batch["input_values"])
			if "ast" in args.target_name:
				outputs = outputs.logits
			outputs_list = outputs.detach().cpu().tolist()
			outputs = torch.argmax(outputs, dim = -1)
			labels = batch["labels"]
			correct += torch.sum(outputs == labels).detach().cpu().item()
			total += outputs.shape[0]

			outputs = outputs.detach().cpu().tolist()
			labels = labels.detach().cpu().tolist()
			for idx, (pred, gt, filename) in enumerate(zip(outputs, labels, batch['filenames'])):
				pred = label_to_cls[pred]
				gt = label_to_cls[gt]
				if (pred == gt):
					correct_by_class[pred] += 1
				else:
					wrong_record.append(filename)
				pred_by_class[pred] += 1
				total_by_class[gt] += 1
				confidence_by_class[gt].append(outputs_list[idx])

	# The baseline entry has the key None, which json.dump serialises as "null".
	results[cls_name] = {}
	results[cls_name]["masked_count"] = len(masked_out_neurons)
	results[cls_name]["masked_neuron"] = masked_out_neurons
	results[cls_name]["correct"] = correct
	results[cls_name]["total"] = total
	results[cls_name]["correct_by_class"] = correct_by_class
	results[cls_name]["pred_by_class"] = pred_by_class
	results[cls_name]["total_by_class"] = total_by_class
	results[cls_name]["confidence"] = confidence_by_class
	print(cls_name, 'mask_cnt: ', mask_cnt)

os.makedirs(args.save_pruning_dir, exist_ok=True)

with open(os.path.join(args.save_pruning_dir, f"class-{args.target_name}-{args.pruning_strategy}.json"), "w") as f:
	json.dump(results, fp = f,indent=2)

Found cached dataset parquet (/work/yuxiang1234/cache/ashraq___parquet/ashraq--esc50-1000c3b73cc1500f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /work/yuxiang1234/cache/ashraq___parquet/ashraq--esc50-1000c3b73cc1500f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-8aa9d6610af22639.arrow
Loading cached processed dataset at /work/yuxiang1234/cache/ashraq___parquet/ashraq--esc50-1000c3b73cc1500f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-d13104c871fa4d92.arrow
100%|██████████| 55346/55346 [00:02<00:00, 21023.04it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


None mask_cnt:  0


100%|██████████| 55346/55346 [00:03<00:00, 17639.99it/s]
100%|██████████| 4/4 [00:08<00:00,  2.20s/it]


dog mask_cnt:  2158


100%|██████████| 55346/55346 [00:02<00:00, 20561.18it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


chirping_birds mask_cnt:  4315


100%|██████████| 55346/55346 [00:02<00:00, 20298.13it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


vacuum_cleaner mask_cnt:  732


100%|██████████| 55346/55346 [00:02<00:00, 20525.65it/s]
100%|██████████| 4/4 [00:08<00:00,  2.20s/it]


thunderstorm mask_cnt:  2003


100%|██████████| 55346/55346 [00:03<00:00, 17505.46it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


door_wood_knock mask_cnt:  5307


100%|██████████| 55346/55346 [00:02<00:00, 20002.63it/s]
100%|██████████| 4/4 [00:08<00:00,  2.22s/it]


can_opening mask_cnt:  257


100%|██████████| 55346/55346 [00:02<00:00, 20380.36it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


crow mask_cnt:  614


100%|██████████| 55346/55346 [00:02<00:00, 20501.56it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


clapping mask_cnt:  1885


100%|██████████| 55346/55346 [00:03<00:00, 17072.33it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


fireworks mask_cnt:  2202


100%|██████████| 55346/55346 [00:02<00:00, 20318.19it/s]
100%|██████████| 4/4 [00:08<00:00,  2.20s/it]


chainsaw mask_cnt:  4535


100%|██████████| 55346/55346 [00:02<00:00, 20146.36it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


airplane mask_cnt:  1473


100%|██████████| 55346/55346 [00:02<00:00, 20545.63it/s]
100%|██████████| 4/4 [00:08<00:00,  2.19s/it]


mouse_click mask_cnt:  11897


100%|██████████| 55346/55346 [00:03<00:00, 16205.08it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


pouring_water mask_cnt:  2838


100%|██████████| 55346/55346 [00:02<00:00, 20497.62it/s]
100%|██████████| 4/4 [00:08<00:00,  2.22s/it]


train mask_cnt:  2775


100%|██████████| 55346/55346 [00:02<00:00, 19461.45it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


sheep mask_cnt:  251


100%|██████████| 55346/55346 [00:02<00:00, 20640.69it/s]
100%|██████████| 4/4 [00:11<00:00,  2.89s/it]


water_drops mask_cnt:  2651


100%|██████████| 55346/55346 [00:02<00:00, 19268.18it/s]
100%|██████████| 4/4 [00:09<00:00,  2.34s/it]


church_bells mask_cnt:  2484


100%|██████████| 55346/55346 [00:03<00:00, 17442.09it/s]
100%|██████████| 4/4 [00:12<00:00,  3.07s/it]


clock_alarm mask_cnt:  6132


100%|██████████| 55346/55346 [00:02<00:00, 20357.54it/s]
100%|██████████| 4/4 [00:13<00:00,  3.25s/it]


keyboard_typing mask_cnt:  2313


100%|██████████| 55346/55346 [00:02<00:00, 20214.28it/s]
100%|██████████| 4/4 [00:12<00:00,  3.10s/it]


wind mask_cnt:  1856


100%|██████████| 55346/55346 [00:02<00:00, 20006.49it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


footsteps mask_cnt:  384


100%|██████████| 55346/55346 [00:03<00:00, 17223.20it/s]
100%|██████████| 4/4 [00:08<00:00,  2.22s/it]


frog mask_cnt:  894


100%|██████████| 55346/55346 [00:02<00:00, 20353.78it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


cow mask_cnt:  964


100%|██████████| 55346/55346 [00:02<00:00, 20052.51it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


brushing_teeth mask_cnt:  1057


100%|██████████| 55346/55346 [00:02<00:00, 20507.00it/s]
100%|██████████| 4/4 [00:08<00:00,  2.20s/it]


car_horn mask_cnt:  2093


100%|██████████| 55346/55346 [00:03<00:00, 16948.82it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


crackling_fire mask_cnt:  1329


100%|██████████| 55346/55346 [00:02<00:00, 20207.89it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


helicopter mask_cnt:  4698


100%|██████████| 55346/55346 [00:02<00:00, 20237.44it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


drinking_sipping mask_cnt:  245


100%|██████████| 55346/55346 [00:02<00:00, 20238.31it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


rain mask_cnt:  1448


100%|██████████| 55346/55346 [00:03<00:00, 16417.29it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


insects mask_cnt:  336


100%|██████████| 55346/55346 [00:02<00:00, 20042.84it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


laughing mask_cnt:  370


100%|██████████| 55346/55346 [00:02<00:00, 20520.19it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


hen mask_cnt:  5


100%|██████████| 55346/55346 [00:02<00:00, 20313.75it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


engine mask_cnt:  5867


100%|██████████| 55346/55346 [00:02<00:00, 20126.76it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


breathing mask_cnt:  857


100%|██████████| 55346/55346 [00:03<00:00, 17367.89it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


crying_baby mask_cnt:  1234


100%|██████████| 55346/55346 [00:02<00:00, 20261.95it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


hand_saw mask_cnt:  636


100%|██████████| 55346/55346 [00:02<00:00, 20521.34it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


coughing mask_cnt:  983


100%|██████████| 55346/55346 [00:02<00:00, 20747.49it/s]
100%|██████████| 4/4 [00:08<00:00,  2.22s/it]


glass_breaking mask_cnt:  3883


100%|██████████| 55346/55346 [00:03<00:00, 16914.38it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


snoring mask_cnt:  1273


100%|██████████| 55346/55346 [00:02<00:00, 20189.72it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


toilet_flush mask_cnt:  1073


100%|██████████| 55346/55346 [00:02<00:00, 20409.93it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


pig mask_cnt:  1005


100%|██████████| 55346/55346 [00:02<00:00, 20029.20it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


washing_machine mask_cnt:  2809


100%|██████████| 55346/55346 [00:03<00:00, 16548.51it/s]
100%|██████████| 4/4 [00:08<00:00,  2.23s/it]


clock_tick mask_cnt:  4478


100%|██████████| 55346/55346 [00:02<00:00, 19995.44it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


sneezing mask_cnt:  248


100%|██████████| 55346/55346 [00:02<00:00, 20851.41it/s]
100%|██████████| 4/4 [00:08<00:00,  2.22s/it]


rooster mask_cnt:  1237


100%|██████████| 55346/55346 [00:02<00:00, 20111.71it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


sea_waves mask_cnt:  1027


100%|██████████| 55346/55346 [00:02<00:00, 20473.88it/s]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]


siren mask_cnt:  3639


100%|██████████| 55346/55346 [00:03<00:00, 17380.37it/s]
100%|██████████| 4/4 [00:08<00:00,  2.25s/it]


cat mask_cnt:  842


100%|██████████| 55346/55346 [00:02<00:00, 20535.48it/s]
100%|██████████| 4/4 [00:08<00:00,  2.20s/it]


door_wood_creaks mask_cnt:  5295


100%|██████████| 55346/55346 [00:02<00:00, 20299.12it/s]
100%|██████████| 4/4 [00:08<00:00,  2.22s/it]


crickets mask_cnt:  1622


### Evaluation

In [10]:
# Summarise the ablation: compare the model output for the true class before and
# after pruning, both for the ablated class itself and for all remaining classes.
result_file = os.path.join(args.save_pruning_dir, f"class-{args.target_name}-{args.pruning_strategy}.json")

results = read_json(result_file)

# NOTE(review): the position of a name in `classes` is used as the index into each
# stored output vector — this assumes the list order equals the model's label ids; confirm.
classes = get_cls_label(args.network_class_file)

# confidence on ablating class samples
before_confidence = {}
after_confidence = {}
neuron_number = {}

# confidence on remaining class samples
remaining_class_before_confidence = defaultdict(list)
remaining_class_after_confidence = defaultdict(list)

# "null" is the JSON form of the None key, i.e. the run stored as the unpruned reference.
origin = results["null"]
origin_acc = origin["correct"] / origin["total"]
origin_confidence_by_class = origin["confidence"]

# `cls_result` (not `object`) so the builtin is not shadowed for the rest of the kernel session.
for cls_name, cls_result in results.items():

	if cls_name == "null":
		continue

	cls_idx = classes.index(cls_name)
	pruned_confidence_by_class = cls_result["confidence"]

	before_confidence[cls_name] = mean([logit[cls_idx] for logit in origin_confidence_by_class[cls_name]])
	after_confidence[cls_name] = mean([logit[cls_idx] for logit in pruned_confidence_by_class[cls_name]])

	# Same measurement for every other class, using that class's own samples and output index.
	for remaining_idx, remaining_cls_name in enumerate(classes):
		if remaining_cls_name == cls_name:
			continue
		remaining_class_before_confidence[cls_name].append(mean([logit[remaining_idx] for logit in origin_confidence_by_class[remaining_cls_name]]))
		remaining_class_after_confidence[cls_name].append(mean([logit[remaining_idx] for logit in pruned_confidence_by_class[remaining_cls_name]]))

	# Collapse the per-class lists into a single average per ablated class.
	remaining_class_before_confidence[cls_name] = mean(remaining_class_before_confidence[cls_name])
	remaining_class_after_confidence[cls_name] = mean(remaining_class_after_confidence[cls_name])
	neuron_number[cls_name] = cls_result["masked_count"]

ablating_class_before = list(before_confidence.values())
ablating_class_after = list(after_confidence.values())
ablating_delta = (sum(ablating_class_after) - sum(ablating_class_before)) / len(ablating_class_before)

remaining_class_before = list(remaining_class_before_confidence.values())
remaining_class_after = list(remaining_class_after_confidence.values())
remaining_delta = (sum(remaining_class_after) - sum(remaining_class_before)) / len(remaining_class_before)

neuron_number = list(neuron_number.values())

print("ablating class before", mean(ablating_class_before))
print("ablating class after", mean(ablating_class_after))  
print("ablating class delta", ablating_delta) 
print("neuron_number", mean(neuron_number))
print("remaining class before", mean(remaining_class_before))
print("remaining class after", mean(remaining_class_after))  
print("remaining class delta", remaining_delta) 


ablating class before -1.6850083828228526
ablating class after -11.091819011196494
ablating class delta -9.406810628373641
neuron_number 2210.18
remaining class before -1.6850083828228524
remaining class after -6.894974808408415
remaining class delta -5.209966425585562
