# Initialization

In [None]:
# Install required libraries
!pip install torch torchvision pillow einops



In [None]:
from google.colab import drive

# Mount Google Drive so the project folder under MyDrive is reachable from this runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Work from the project root so the relative paths used below (data/..., output_audios/...) resolve.
%cd /content/drive/MyDrive/XAI_project

/content/drive/MyDrive/XAI_project


# Setup

## Imports

In [None]:
from vit_model.baselines.ViT.ViT_explanation_generator_CPU import LRP
from vit_model.VIT_LRP import vit_base_patch16_224_spectrogram as vit_LRP
from PIL import Image
from dataset.GTZAN import GTZAN
import torchvision.transforms as transforms
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from audio_preprocess.audio_preprocess import AudioPreprocessor
import soundfile as sf
import librosa
from IPython.display import Audio, display
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import math
from dataset.GTZAN_SNR import GTZAN_SNR

## Functions

In [None]:
def add_noise_to_audio_get_spectrogram(audio, snr_db, audiopreprocessor, rng=None):
    """Add white Gaussian noise at a target SNR and return the log spectrogram of the result.

    Args:
        audio: waveform array (any shape; noise is drawn with the same shape).
        snr_db: desired signal-to-noise ratio in dB; the noise RMS is
            ``rms(audio) / 10 ** (snr_db / 20)``.
        audiopreprocessor: object providing ``compute_log_spectrogram(waveform)``.
        rng: optional ``numpy.random.Generator`` for reproducible noise; when ``None`` the
            global ``np.random`` state is used, as before.

    Returns:
        Whatever ``audiopreprocessor.compute_log_spectrogram`` returns for the noisy waveform.
    """
    audio = np.asarray(audio)
    # Square in float64: squaring an integer PCM array in its own dtype silently overflows.
    rms_signal = math.sqrt(np.mean(audio.astype(np.float64) ** 2))
    rms_noise = rms_signal / (10 ** (snr_db / 20))
    draw_normal = np.random.normal if rng is None else rng.normal
    # Match the full shape so multi-channel input gets independent noise per sample.
    noise = draw_normal(0, rms_noise, audio.shape)
    return audiopreprocessor.compute_log_spectrogram(audio + noise)

# create heatmap from mask on image
def show_cam_on_image(img, mask):
    """Blend a JET-colored rendering of ``mask`` onto ``img`` and rescale so the peak is 1.

    Args:
        img: image array that broadcasts against the (H, W, 3) colored heatmap;
            presumably floats in [0, 1] — TODO confirm for new callers.
        mask: 2-D map; it is converted with ``np.uint8(mask * 255)``, so values in [0, 1]
            are expected.

    Returns:
        float32 array ``heatmap + img`` divided by its maximum value.
    """
    colored = np.float32(cv2.applyColorMap(np.uint8(mask * 255), cv2.COLORMAP_JET)) / 255
    blended = np.float32(img) + colored
    return blended / blended.max()

# Attribution visualization generation
def _normalize_min_max(values):
    """Linearly rescale ``values`` to [0, 1].

    A constant input has zero range; dividing by it would produce NaNs (which become
    arbitrary bytes after ``np.uint8``), so all zeros are returned instead.
    """
    low, high = values.min(), values.max()
    if high == low:
        return np.zeros_like(values, dtype=np.float32)
    return (values - low) / (high - low)


def _attribution_map(original_image, attribution_generator, class_index, out_size):
    """Run LRP on one image and return its relevance map resized to ``out_size`` and scaled to [0, 1].

    Args:
        original_image: model-ready image tensor without a batch axis (one is added here).
        attribution_generator: object exposing ``generate_LRP`` (the ``LRP`` wrapper).
        class_index: class to explain; passed through to ``generate_LRP`` as ``index``.
        out_size: ``(height, width)`` of the returned map.

    Returns:
        numpy float array of shape ``out_size``.

    Raises:
        ValueError: if the number of relevance scores is not a perfect square.
    """
    relevance = attribution_generator.generate_LRP(
        original_image.unsqueeze(0), method="transformer_attribution", index=class_index
    ).detach()
    # One score per patch on a square grid (14 x 14 = 196 for the ViT used in this notebook).
    n_scores = relevance.numel()
    grid = int(round(math.sqrt(n_scores)))
    if grid * grid != n_scores:
        raise ValueError(f"expected a square number of patch scores, got {n_scores}")
    relevance = relevance.reshape(1, 1, grid, grid)
    height, width = (int(size) for size in out_size)
    relevance = torch.nn.functional.interpolate(relevance, size=(height, width), mode='bilinear')
    return _normalize_min_max(relevance.reshape(height, width).cpu().numpy())


# Attribution visualization generation
def generate_visualization(original_image, not_transformed_image, attribution_generator, class_index=None):
    """Overlay the LRP attribution map on the raw spectrogram image.

    Args:
        original_image: transformed (model-input) image tensor, no batch axis.
        not_transformed_image: raw image tensor shaped (C, H, W); the overlay and the
            returned map are produced at its H x W size.
        attribution_generator: object exposing ``generate_LRP``.
        class_index: class to explain; ``None`` is passed through to ``generate_LRP``.

    Returns:
        ``(vis, transformer_attribution)``: the uint8 (H, W, 3) overlay for ``imshow`` and
        the (H, W) attribution map scaled to [0, 1].
    """
    transformer_attribution = _attribution_map(
        original_image, attribution_generator, class_index, not_transformed_image.shape[-2:]
    )
    image = _normalize_min_max(not_transformed_image.permute(1, 2, 0).data.cpu().numpy())
    cam = show_cam_on_image(image, transformer_attribution)
    vis = np.uint8(255 * cam)
    # OpenCV colormaps are BGR-ordered; swap channels before the overlay is shown with matplotlib.
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis, transformer_attribution


# Generating spectrogram tensor without attention parts
def generate_removal_evaluation(original_image, not_transformed_image, attribution_generator, class_index=None, threshold=0.5):
    """Zero out the most relevant regions of the raw image (feature-removal test).

    Pixels whose normalized attribution is ``>= threshold`` are set to 0; all others are kept.

    Args:
        original_image: transformed (model-input) image tensor, no batch axis.
        not_transformed_image: raw image (C, H, W); the mask is built at its H x W size and
            broadcast over channels.
        attribution_generator: object exposing ``generate_LRP``.
        class_index: class to explain; ``None`` is passed through to ``generate_LRP``.
        threshold: cut-off on the [0, 1]-normalized attribution.

    Returns:
        ``not_transformed_image`` multiplied by the 0/1 keep-mask.
    """
    transformer_attribution = _attribution_map(
        original_image, attribution_generator, class_index, not_transformed_image.shape[-2:]
    )
    keep_mask = np.where(transformer_attribution < threshold, 1, 0)
    if torch.is_tensor(not_transformed_image):
        # Convert explicitly rather than relying on tensor * ndarray interop.
        keep_mask = torch.from_numpy(keep_mask)
    return not_transformed_image * keep_mask

# Print predicted classes
def print_top_classes(predictions, class_map, top_k=5, **kwargs):
    """Print the ``top_k`` highest-scoring classes of the first sample in ``predictions``.

    Args:
        predictions: logits tensor; softmax and topk run over dim 1, and only row 0 is printed.
        class_map: mapping from class index to display name.
        top_k: number of classes to list (default 5). It is clamped to the number of classes,
            so models with fewer outputs no longer make ``topk`` fail.
        **kwargs: accepted for backward compatibility; unused.
    """
    prob = torch.softmax(predictions, dim=1)
    k = min(top_k, predictions.shape[1])
    class_indices = predictions.data.topk(k, dim=1)[1][0].tolist()
    # Pad names to a common width so the value/prob columns line up.
    max_str_len = max((len(class_map[idx]) for idx in class_indices), default=0)

    print(f'Top {k} classes:')
    for cls_idx in class_indices:
        name = class_map[cls_idx]
        output_string = '\t{} : {}'.format(cls_idx, name)
        output_string += ' ' * (max_str_len - len(name)) + '\t\t'
        output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(
            predictions[0, cls_idx].item(), 100 * prob[0, cls_idx].item()
        )
        print(output_string)

# Swapping label map to print predicted classes
def swap_label_map(label_map):
    """Invert ``label_map`` so its values become keys (e.g. class index -> genre name).

    If several keys share a value, the key iterated last wins.
    """
    return {index: name for name, index in label_map.items()}

#Sonification function
def sonify(genre, number, transformer_attribution):
    """Render an attribution map as audio by weighting the track's STFT with it.

    Args:
        genre: GTZAN genre folder/name of the test track (e.g. ``'metal'``).
        number: track number as used in the file name (appended after ``'000'``).
        transformer_attribution: 2-D map multiplied element-wise with the STFT, so it must
            broadcast against the (1 + N_FFT // 2, frames) STFT — assumes the row order of the
            map matches the STFT frequency order; TODO confirm against the image rendering.

    Returns:
        ``(audio_f, output_audio_path)``: path of the original wav and of the written sonification.

    Relies on the notebook globals ``audiopreprocessor``, ``N_FFT``, ``HOP_LENGTH`` and ``SAMPLE_RATE``.
    """
    import os

    # Use the arguments, not the GENRE/NUMBER globals, so callers get the track they asked for.
    audio_f = f'data/genres_original/test/{genre}/{genre}.000{number}.wav'
    audio_section = audiopreprocessor.load_audio(audio_f, select_section=False)
    audio_section = audiopreprocessor.normalize_amplitude(audio_section)
    stft_result = librosa.stft(audio_section, n_fft=N_FFT, hop_length=HOP_LENGTH)
    # Weight every STFT bin by the squared attribution; the complex STFT keeps the phase.
    output_audio = (transformer_attribution ** 2) * stft_result
    # Reconstruct audio from the intermediate (STFT) result, where phase information is preserved.
    reconstruction_audio = librosa.istft(output_audio, n_fft=N_FFT, hop_length=HOP_LENGTH)
    output_audio_path = f'output_audios/output_audio_{genre}_{number}.wav'
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    sf.write(output_audio_path, reconstruction_audio, SAMPLE_RATE)
    return audio_f, output_audio_path



## Constants

In [None]:
# Sample shown by the attribution-map cell; later sections set their own GENRE/NUMBER.
GENRE = 'disco'
NUMBER = '96'

# Figure sizes (inches) are derived from the spectrogram image's pixel size divided by DPI.
DPI = 100
IM_DATA = plt.imread(f'data/images_1600_224/test/{GENRE}/{GENRE}000{NUMBER}.png')
HEIGHT, WIDTH, _ = IM_DATA.shape

FIGSIZE = (WIDTH / DPI * 2, HEIGHT / DPI)           # spectrogram + overlay side by side
FIGSIZE_ONLY_MAP = (WIDTH / DPI, HEIGHT / DPI)      # a single overlay
FIGSIZE_REMOVAL = (WIDTH / DPI * 3, HEIGHT / DPI)   # before / filtered / after removal

DATASET = GTZAN(mode='test', folder='images_1600_224')
LABEL_MAP = swap_label_map(DATASET.label_map)       # class index -> genre name

transform_tensor = transforms.Compose([transforms.ToTensor()])
transform_toPIL = transforms.Compose([transforms.ToPILImage()])

DEVICE = torch.device("cpu")
# Put the model on DEVICE as well: the evaluation loops move their batches to DEVICE,
# so model and inputs must live on the same device if DEVICE is ever changed.
MODEL = vit_LRP(pretrained=True).to(DEVICE)
MODEL.eval()
ATTRIBUTION_GENERATOR = LRP(MODEL)

SAMPLE_RATE = 22050
DURATION = 30
HOP_LENGTH_FACTOR = 7.45
N_FFT = 1599*2   # 1 + N_FFT // 2 = 1600 frequency bins, matching the 1600-row attribution map
MONO = True
HOP_LENGTH = int(SAMPLE_RATE / HOP_LENGTH_FACTOR)

audiopreprocessor = AudioPreprocessor(SAMPLE_RATE, DURATION, HOP_LENGTH_FACTOR, N_FFT)

# Attribution Map generation

In [None]:
image_spec = Image.open(f'data/images_1600_224/test/{GENRE}/{GENRE}000{NUMBER}.png').convert('RGB')
image_spec_tensor = transform_tensor(image_spec)    # raw pixels: background of the overlay
image_spec_transf = DATASET.transform(image_spec)   # model input

output = MODEL(image_spec_transf.unsqueeze(0))
print_top_classes(output, LABEL_MAP)

visualization, _ = generate_visualization(image_spec_transf, image_spec_tensor, ATTRIBUTION_GENERATOR)

combined_fig, (ax1, ax2) = plt.subplots(1, 2, figsize=FIGSIZE)
ax1.imshow(image_spec)
ax1.axis('off')
ax2.imshow(visualization)
ax2.axis('off')

# rect is (left, bottom, right, top) in figure fractions: (0, 0, 1, 1) is the whole figure,
# whereas (0, 0, 0, 0) would describe an empty region.
combined_fig.tight_layout(pad=0, h_pad=0, w_pad=0, rect=(0, 0, 1, 1))
combined_fig.savefig('combined_image.png')
plt.show()

print("Combined image saved as combined_image.png")

# Attribution Map Sonification

In [None]:
GENRE = 'metal'
NUMBER = '95'

image_spec = Image.open(f'data/images_1600_224/test/{GENRE}/{GENRE}000{NUMBER}.png').convert('RGB')
image_spec_tensor = transform_tensor(image_spec)
image_spec_transf = DATASET.transform(image_spec)

output = MODEL(image_spec_transf.unsqueeze(0))
print_top_classes(output, LABEL_MAP)

visualization, transformer_attribution = generate_visualization(image_spec_transf, image_spec_tensor, ATTRIBUTION_GENERATOR)

map_fig, map_ax = plt.subplots(figsize=FIGSIZE_ONLY_MAP)
map_ax.imshow(visualization)
map_ax.axis('off')
# rect (0, 0, 1, 1) = the whole figure in (left, bottom, right, top) fractions.
map_fig.tight_layout(pad=0, h_pad=0, w_pad=0, rect=(0, 0, 1, 1))
# Show the map before the audio players are displayed, so it appears above them.
plt.show()

audio_f, output_audio_path = sonify(GENRE, NUMBER, transformer_attribution)

display(Audio(audio_f))
display(Audio(output_audio_path))


# Evaluation by Feature Removal

In [None]:
GENRE = 'jazz'
NUMBER = '95'
# generate_removal_evaluation zeroes every pixel whose normalized attribution is >= this value.
REMOVAL_THRESHOLD = 0.05

image_spec = Image.open(f'data/images_1600_224/test/{GENRE}/{GENRE}000{NUMBER}.png').convert('RGB')
image_spec_tensor = transform_tensor(image_spec)

print("CLASSIFICATION BEFORE REMOVAL")
image_spec_transf = DATASET.transform(image_spec)
output = MODEL(image_spec_transf.unsqueeze(0))
print_top_classes(output, LABEL_MAP)

visualization, _ = generate_visualization(image_spec_transf, image_spec_tensor, ATTRIBUTION_GENERATOR)
image_filtered = generate_removal_evaluation(image_spec_transf, image_spec_tensor, ATTRIBUTION_GENERATOR, threshold=REMOVAL_THRESHOLD)
image_filtered_pil = transform_toPIL(image_filtered)

print("CLASSIFICATION AFTER REMOVAL")
image_filtered_pil_transf = DATASET.transform(image_filtered_pil)
output = MODEL(image_filtered_pil_transf.unsqueeze(0))
print_top_classes(output, LABEL_MAP)

visualization_filtered, transformer_attribution_filtered = generate_visualization(image_filtered_pil_transf, image_filtered, ATTRIBUTION_GENERATOR)

# Left to right: original overlay, filtered spectrogram, overlay recomputed on the filtered input.
removal_fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=FIGSIZE_REMOVAL)
for ax, panel in zip((ax1, ax2, ax3), (visualization, image_filtered_pil, visualization_filtered)):
    ax.imshow(panel)
    ax.axis('off')

# rect (0, 0, 1, 1) = the whole figure in (left, bottom, right, top) fractions.
removal_fig.tight_layout(pad=0, h_pad=0, w_pad=0, rect=(0, 0, 1, 1))
removal_fig.savefig(f'spectrogram_{GENRE}_filtered.png')
plt.show()

audio_f, output_audio_path = sonify(GENRE, NUMBER, transformer_attribution_filtered)

display(Audio(audio_f))
display(Audio(output_audio_path))


# Model Evaluation and Confusion Matrix

In [None]:
# Per-genre tallies keyed by genre name; the loader yields integer class indices.
genre_corrects = {genre: 0 for genre in DATASET.label_map}
genre_tot = {genre: 0 for genre in DATASET.label_map}
all_preds = []
all_labels = []

test_loader = DataLoader(DATASET, batch_size=32, shuffle=False)
TEST_LABEL_MAP = DATASET.label_map

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating"):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = MODEL(images)
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        for pred, label in zip(preds.tolist(), labels.tolist()):
            genre = LABEL_MAP[label]
            genre_corrects[genre] += int(pred == label)
            genre_tot[genre] += 1

# Skip genres that have no test samples instead of dividing by zero.
genre_accuracy = {
    genre: genre_corrects[genre] / genre_tot[genre]
    for genre in TEST_LABEL_MAP
    if genre_tot[genre] > 0
}
# Micro accuracy: correct predictions / all samples.
overall_accuracy = sum(genre_corrects.values()) / max(sum(genre_tot.values()), 1)
# Macro accuracy: unweighted mean of the per-genre accuracies.
balanced_accuracy = np.mean(list(genre_accuracy.values()))

for genre, accuracy in genre_accuracy.items():
    print(f"{genre}: {accuracy:.4f}")

print(f"Balanced Accuracy (mean per-genre): {balanced_accuracy:.4f}")
print(f"Overall Accuracy: {overall_accuracy:.4f}")


In [None]:
# Fix the row/column order to the integer class ids so the tick labels line up with the matrix,
# even if the label_map iteration order differs or some genre never occurs.
class_ids = sorted(LABEL_MAP)
class_names = [LABEL_MAP[class_id] for class_id in class_ids]
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

cm_fig, cm_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, cbar=False, ax=cm_ax)
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('Actual')
cm_fig.savefig('confusion_matrix.svg', format='svg')
plt.show()


# Noise Injection (SNR 5–100 dB)

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader

snr_values = range(5, 101, 5)
accuracies = []

for snr in snr_values:
    correct_total = 0
    total_samples = 0
    # Own names so this sweep does not overwrite the clean-test loader/accuracy from earlier cells.
    noisy_dataset = GTZAN(mode='', folder=f"noisy_1600_224_snr{snr}")
    noisy_loader = DataLoader(noisy_dataset, batch_size=32, shuffle=False)

    with torch.no_grad():
        for images, labels in noisy_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = MODEL(images)
            _, preds = torch.max(outputs, 1)

            correct_total += torch.sum(preds == labels).item()
            total_samples += labels.size(0)

    # An empty folder yields NaN (a gap in the plot) instead of ZeroDivisionError.
    snr_accuracy = correct_total / total_samples if total_samples else float('nan')
    accuracies.append(snr_accuracy)

    print(f"SNR {snr}: Overall Accuracy: {snr_accuracy:.4f}")


In [None]:
snr_fig, snr_ax = plt.subplots(figsize=(10, 6))
snr_ax.plot(list(snr_values), accuracies, marker='o')
snr_ax.set_xlabel('SNR (dB)')
snr_ax.set_ylabel('Accuracy')
snr_ax.set_title('Accuracy vs SNR')
snr_ax.grid(True)
snr_ax.set_xticks(list(snr_values))
# Ticks start at 0.3 as before, but extend down to the worst observed accuracy and up to 1.0,
# so no data point falls outside the labelled range.
finite_accuracies = [acc for acc in accuracies if not math.isnan(acc)]
y_low = min([0.3] + [math.floor(acc * 10) / 10 for acc in finite_accuracies])
snr_ax.set_yticks(np.arange(y_low, 1.01, 0.1))
snr_fig.savefig('snr_vs_accuracy.png')
plt.show()