In [None]:
import os
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps


In [None]:
import torch
import torchaudio
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from PIL import Image
from torchvision.transforms.functional import to_pil_image
import random
#from plot_audio import plot_specgram, plot_waveform
#os.getcwd()

In [None]:
# Location of the pickled spectrogram dictionary. Set the DICT_MATS_PATH
# environment variable to run the notebook on another machine; the default is
# the path that was previously hard-coded here.
dict_mats_path = os.environ.get(
    'DICT_MATS_PATH', '/Users/jansta/learn/acoustics/dict_mats_dB.npy'
)
# SECURITY: allow_pickle=True unpickles arbitrary Python objects and can run
# code embedded in the file -- only load files from a trusted source.
# .item() extracts the stored Python object; later cells index it like a
# nested dictionary (dict_mats['A'][label][i]).
dict_mats = np.load(dict_mats_path, allow_pickle=True).item()


In [None]:
# Sanity check on one stored entry. A bare expression in the middle of a cell
# is evaluated but never displayed, so print it explicitly.
print(len(dict_mats['A']['can_opening'][3]))

# Every label available in group 'A' (a live view of the dictionary keys).
all_labels = dict_mats['A'].keys()
print(all_labels)

In [None]:
# Hand-picked label set: five labels are numbered 0..4, but only the first
# four are selected. Note that the following cell reassigns both names.
encoded_labels = dict(zip(
    ('crickets', 'can_opening', 'chirping_birds', 'dog', 'chainsaw'),
    range(5),
))
chosen_labels = list(encoded_labels)[:4]

In [None]:
# Select the first 20 labels (in the key order of the source dictionary) and
# number them consecutively from 0 in that same order.
chosen_labels = list(all_labels)[:20]
print(chosen_labels)
encoded_labels = {label: index for index, label in enumerate(chosen_labels)}

In [None]:
class AudioDataset(Dataset):
    """Dataset of per-class matrices paired with integer class labels.

    Args:
        dict_mats: Mapping from class name to an integer-indexable collection
            of matrices.
        chosen_labels: Class names to include; all other keys are skipped.
        encoded_labels: Mapping from class name to integer class index.
        transform: Optional callable applied to each sample tensor.
    """

    def __init__(self, dict_mats, chosen_labels, encoded_labels, transform=None):
        self.transform = transform

        features, targets = [], []
        # Walk the classes in the dictionary's own order, keeping only the
        # requested ones.
        for class_name, matrices in dict_mats.items():
            if class_name not in chosen_labels:
                continue
            for position in range(len(matrices)):
                features.append(matrices[position])
                targets.append(encoded_labels[class_name])

        self.X = np.array(features)
        self.y = np.array(targets)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for position ``idx``.

        The sample is a float tensor with a leading channel axis added; the
        label is a ``torch.long`` tensor.
        """
        # Prepend the channel axis, then convert to a float tensor.
        spectrogram = torch.FloatTensor(np.expand_dims(self.X[idx], axis=0))
        target = torch.tensor(self.y[idx], dtype=torch.long)

        if not self.transform:
            return spectrogram, target
        return self.transform(spectrogram), target

In [None]:
# Per-sample preprocessing pipeline: resize to 64 x 431, force a single
# channel, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(64, 431)),
    transforms.Grayscale(num_output_channels=1),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [None]:
# Build the dataset from group 'A' of the loaded dictionary; the transform is
# applied to every sample when it is fetched.
dataset = AudioDataset(dict_mats['A'], chosen_labels, encoded_labels, transform=transform)

# 80/20 train/validation split. A seeded generator makes the split (and so the
# reported validation metrics) reproducible across kernel restarts.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Create dataloaders
batch_size = 4
# The classifier uses BatchNorm1d, which raises in training mode when a batch
# holds a single sample. Drop the trailing batch only when it would be exactly
# one sample, so no data is discarded otherwise.
drop_trailing_single = len(train_dataset) % batch_size == 1
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=drop_trailing_single,
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)



In [None]:
# Smoke test: pull a single batch from the training loader and report its
# shape and labels. Nothing is printed if the loader yields no batch.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    i = 0
    inputs, labels = first_batch
    print(f"Batch {i+1}:")
    print(f"Input batch size: {inputs.size()}")
    print(f"Labels: {labels}")
    print("-" * 30)

In [None]:
class AudioClassifNetXAI(nn.Module):
    """CNN classifier for single-channel spectrogram-like inputs.

    Architecture: four Conv-BatchNorm-ReLU-MaxPool stages (1 -> 16 -> 32 ->
    64 -> 128 channels, grouped into ``conv_block1`` and ``conv_block2``),
    global average pooling, and a three-layer fully connected head that emits
    ``n_classes`` logits.

    Attributes:
        n_classes: Number of output logits.
        feature_maps: ``None`` until ``forward`` is called with
            ``store_feature_maps=True``; afterwards the detached output of
            ``conv_block2`` from the most recent such call.
    """

    def __init__(self, n_classes: int) -> None:
        super().__init__()
        self.n_classes = n_classes
        # Defined up front so that reading it before any storing forward pass
        # yields None instead of raising AttributeError.
        self.feature_maps = None

        # First convolutional block: 1 -> 16 -> 32 channels, two 2x2 poolings.
        self.conv_block1 = nn.Sequential(
            *self._conv_stage(1, 16),
            *self._conv_stage(16, 32),
        )

        # Second convolutional block: 32 -> 64 -> 128 channels, two 2x2 poolings.
        self.conv_block2 = nn.Sequential(
            *self._conv_stage(32, 64),
            *self._conv_stage(64, 128),
        )

        # Global average pooling collapses the spatial dimensions to 1x1, so
        # the flattened feature size (128) does not depend on the input size.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected head starting from the 128 pooled features.
        self.fc_block = nn.Sequential(
            nn.Flatten(),              # (B, 128, 1, 1) -> (B, 128)
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, n_classes)
        )

    @staticmethod
    def _conv_stage(in_channels: int, out_channels: int) -> list:
        """Return the layers of one 3x3 Conv -> BatchNorm -> ReLU -> 2x2 MaxPool stage."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # halves height and width
        ]

    def forward(self, x: torch.Tensor, store_feature_maps: bool = False) -> torch.Tensor:
        """
        Forward pass:
          - Applies the two convolutional blocks
          - Optionally stores the conv_block2 output (e.g. for Grad-CAM-style
            visualisation)
          - Applies global average pooling to get a fixed-size feature vector
          - Propagates through the fully connected block

        Args:
          x (torch.Tensor): Input tensor of shape [batch_size, 1, H, W]
                            (the notebook uses H=64, W=431).
          store_feature_maps (bool, optional): If True, saves a detached copy
                            of the conv_block2 output in ``self.feature_maps``.
                            Defaults to False.

        Returns:
          torch.Tensor: Raw output logits of shape [batch_size, n_classes].
        """
        x = self.conv_block1(x)
        x = self.conv_block2(x)

        if store_feature_maps:
            # Detached: the stored maps are for inspection only and do not
            # take part in backpropagation.
            self.feature_maps = x.detach()

        # (B, 128, H', W') -> (B, 128, 1, 1)
        x = self.global_pool(x)
        x = self.fc_block(x)
        # No softmax here: nn.CrossEntropyLoss expects raw logits.

        return x

In [None]:
def check_for_nans(tensor, name):
    """Report whether ``tensor`` contains any NaN value.

    Prints a message mentioning ``name`` when a NaN is found.

    Returns:
        bool: True if at least one element is NaN, otherwise False.
    """
    has_nan = bool(torch.isnan(tensor).any())
    if has_nan:
        print(f"NaNs found in {name}")
    return has_nan

In [None]:
## Create an instance of the model.
# One output logit per selected label: the encoded label indices run from 0 to
# len(chosen_labels) - 1, and CrossEntropyLoss rejects any target index that is
# not smaller than the number of logits.
model = AudioClassifNetXAI(len(chosen_labels))



In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# %% TRAINING
losses_epoch_mean = []
NUM_EPOCHS = 500
# Log roughly ten times over the run; max(1, ...) avoids a modulo-by-zero when
# NUM_EPOCHS is set below 10.
LOG_EVERY = max(1, NUM_EPOCHS // 10)

# Put dropout / batch-norm layers in training mode explicitly: an evaluation
# cell run earlier in the same kernel would otherwise leave the model in eval mode.
model.train()
for epoch in range(NUM_EPOCHS):
    losses_epoch = []
    for i, data in enumerate(train_loader):
        inputs, labels = data
        # Guard against NaNs in the input batch: keep the offending batch in
        # `i_err` for inspection and abandon the rest of this epoch.
        # NOTE(review): only the current epoch is cut short; later epochs still run.
        if torch.isnan(inputs).any():
            print(f"NaN input at epoch {epoch}, batch {i}")
            i_err = inputs
            break

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        losses_epoch.append(loss.item())

    # An epoch may finish without a single processed batch (NaN in the first
    # batch, or an empty loader). Record NaN explicitly instead of calling
    # np.mean on an empty list, which emits a RuntimeWarning.
    epoch_loss = float(np.mean(losses_epoch)) if losses_epoch else float('nan')
    losses_epoch_mean.append(epoch_loss)
    if epoch % LOG_EVERY == 0:
        print(f'Epoch {epoch}/{NUM_EPOCHS}, Loss: {epoch_loss:.12f}')

# Training curve: mean batch loss per epoch.
ax = sns.lineplot(x=list(range(len(losses_epoch_mean))), y=losses_epoch_mean)
ax.set(xlabel='Epoch', ylabel='Mean training loss');

In [None]:
# Collect ground-truth labels and raw model outputs over the validation set.
model.eval()
y_val, y_val_hat = [], []
with torch.no_grad():  # inference only, no gradients needed
    for inputs, y_val_temp in val_loader:
        y_val_hat_temp = model(inputs)
        y_val.extend(y_val_temp.cpu().numpy())
        y_val_hat.extend(y_val_hat_temp.cpu().numpy())

In [None]:

# Predicted class = index of the largest logit for each validation sample.
y_val_pred = np.argmax(y_val_hat, axis=1)

# Accuracy
acc = accuracy_score(y_val, y_val_pred)
print(f'Accuracy: {acc*100:.2f} %')

# Confusion matrix. Passing the full range of label indices keeps the matrix
# len(chosen_labels) x len(chosen_labels) even when a class is absent from both
# the validation targets and the predictions, so rows/columns always line up
# with the tick labels below.
cm = confusion_matrix(y_val, y_val_pred, labels=list(range(len(chosen_labels))))
sns.heatmap(cm, annot=True, xticklabels=chosen_labels, yticklabels=chosen_labels)

In [None]:
from gradCAM import gradCAM

In [None]:
# model.load_state_dict(torch.load("path_to_weights.pth"))
model.eval()

# Grad-CAM target: the last layer of conv_block2 (its final MaxPool2d).
target_layer = model.conv_block2[-1]

# Use the class under the name it was imported with
# (`from gradCAM import gradCAM`); `GradCAM` is not defined in this notebook.
grad_cam = gradCAM(model, target_layer)

# Take the first validation sample -- dataset items are (input, target)
# tuples -- and add a batch dimension.
test_inp, _ = val_loader.dataset[0]
test_inp = test_inp.unsqueeze(0)
#test_inp.requires_grad_(True)
test_inp.size()




In [None]:
# Generate the Grad-CAM heatmap for the prepared sample. With target_class=None
# the class predicted by the model is used.
cam_heatmap, pred_class = grad_cam.generate_cam(test_inp, target_class=None)

# Reverse lookup: find the label name whose encoded index equals pred_class.
label_names = list(encoded_labels.keys())
label_indices = list(encoded_labels.values())
predicted_label = label_names[label_indices.index(pred_class)]
print(f"Predicted class: {predicted_label}")

# Drop singleton dimensions and move to numpy for plotting.
heatmap = cam_heatmap.squeeze().cpu().numpy()
fig, ax = plt.subplots()
heatmap_image = ax.imshow(heatmap, cmap='jet', interpolation='bilinear')
ax.set_title("Grad-CAM Heatmap")
fig.colorbar(heatmap_image, ax=ax)
plt.show()

# Remove the hooks once finished to avoid potential memory leaks.
grad_cam.remove_hooks()

In [None]:
# Per predicted class: all Grad-CAM heatmaps, plus the first input that was
# assigned to that class.
cams = {}
samples = {}
model.eval()
# The target layer is the same for every sample, so look it up once.
target_layer = model.conv_block2[-1]
for i, data in enumerate(val_loader):
    inputs, y_val_temp = data
    for j in range(inputs.shape[0]):
        grad_cam = gradCAM(model, target_layer)
        single_input = inputs[j].unsqueeze(0)
        try:
            cam_hm, pred_class = grad_cam.generate_cam(single_input, target_class=None)
        finally:
            # A fresh gradCAM object is created per sample; remove its hooks
            # right away so they do not pile up on the target layer.
            grad_cam.remove_hooks()

        predicted_label = list(encoded_labels.keys())[list(encoded_labels.values()).index(pred_class)]
        print(f"Predicted class: {predicted_label}")

        # Store heatmaps as plain numpy arrays so the averaging below does not
        # depend on numpy converting a list of torch tensors.
        # Assumes generate_cam returns a torch.Tensor, as its other use in this
        # notebook (.squeeze().cpu().numpy()) implies -- TODO confirm.
        cams.setdefault(predicted_label, []).append(cam_hm.detach().cpu().numpy())
        samples.setdefault(predicted_label, inputs[j])


# Element-wise mean heatmap per predicted class.
class_cams = {key: np.mean(heatmaps, axis=0) for key, heatmaps in cams.items()}

In [None]:


# Plot and save the mean class activation map of every predicted class.
for key in class_cams.keys():
    fig = plt.figure(figsize=(5, 10))
    plt.imshow(class_cams[key][0,0,:,:], cmap='jet')
    plt.title(f"Class Activation Map for class: {key}")
    #plt.colorbar()
    # Save before show(): once the figure has been shown, savefig would write
    # an empty canvas.
    plt.savefig(f'cam_{key}.png')
    plt.show()
    # Release the figure so figures do not accumulate over the loop.
    plt.close(fig)

# if save_output:
#     np.save('class_cams.npy', class_cams)

    # for key in class_cams.keys():
    #     plt.figure(figsize=(5, 10))
    #     plt.imshow(class_cams[key:[0,0,:,:], cmap='jet')
    #     plt.title(f"Class Activation Map for class: {key}")
    #     #plt.colorbar()
    #     plt.show()
    #     plt.savefig(f'cam_{key}.png')


