# Grad-CAM Heatmap to Spectrogram Conversion for Waveform-Input Models

## Introduction

This notebook was created by [Jupyter AI](https://github.com/jupyterlab/jupyter-ai) with the following prompt:

> /generate how to use gard-cam technique to generate heatmap for a waveform-input model, and then convert the heatmap  of waveform into spectrogram using torch

This notebook demonstrates how to apply Grad-CAM to a model that takes raw waveforms as input, and how to convert the resulting waveform heatmap into a spectrogram using PyTorch. It covers importing the required libraries, loading a waveform-input PyTorch model, applying Grad-CAM to generate a heatmap, converting the heatmap into a spectrogram with `torch.stft`, and visualizing the resulting spectrogram with matplotlib and seaborn.

## Import Libraries and Load Data

In [2]:
# Import necessary libraries
import torch
import torchvision
import numpy as np  # for numerical computations
import matplotlib.pyplot as plt  # for plotting (optional)

In [3]:
# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

In [4]:
from ay2.torch.data.audio import WaveDataset

In [6]:
import sys
sys.path.append("/home/ay/Coding2/0-Deepfake/2-Audio")

In [7]:
from models.Aaasist import AASIST_lit

## Define Waveform-Input Model

In [9]:
model = AASIST_lit(ckpt_path="/home/ay/data/DATA/1-model_save/00-Deepfake/1-df-audio/AASIST/Codecfake/version_0/checkpoints/best-epoch=1-val-auc=0.9999.ckpt")

no. model params:297705


In [10]:
x = torch.randn(2, 1, 48000)
model(x)

NotImplementedError: Module [AASIST_lit] is missing the required "forward" function
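
The error above indicates that `AASIST_lit` does not define `forward`, so the object cannot be called directly. One possible workaround, sketched here with a stand-in class, is to wrap the underlying network in a small `nn.Module` that does define `forward`. The classes `LitWithoutForward` and `ForwardWrapper` and the attribute name `model` are hypothetical; the actual attribute that holds the network has to be looked up in the `AASIST_lit` source.

```python
import torch
from torch import nn


class LitWithoutForward(nn.Module):
    """Stand-in for a wrapper that stores a network but defines no forward()."""

    def __init__(self):
        super().__init__()
        # Hypothetical attribute name; the real class may use a different one.
        self.model = nn.Sequential(
            nn.Conv1d(1, 4, kernel_size=9, padding=4),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(4, 2),
        )


class ForwardWrapper(nn.Module):
    """Expose the inner network through a regular forward() method."""

    def __init__(self, wrapped, attr="model"):
        super().__init__()
        self.net = getattr(wrapped, attr)

    def forward(self, x):
        return self.net(x)


lit = LitWithoutForward()
callable_model = ForwardWrapper(lit)
logits = callable_model(torch.randn(2, 1, 48000))
print(logits.shape)  # torch.Size([2, 2])
```

A wrapper of this kind also gives Grad-CAM a regular `nn.Module` whose layers can be used as target layers.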

## Apply Grad-CAM Technique to Generate Heatmap

In [None]:
# Import necessary libraries
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
def apply_grad_cam(model, waveform_data):
    """
    Apply Grad-CAM to generate a heatmap for each input tensor in the given data.

    Args:
        model (nn.Module): The PyTorch model to use.
        waveform_data (TensorDataset or list of tensors): The waveform data to process.

    Returns:
        heatmaps (list of numpy arrays): A list of heatmaps, one for each input tensor.
    """
    # Initialize Grad-CAM with the model and the layer to explain.
    # `model.conv1` is a placeholder: replace it with a convolutional layer of your model.
    # Note: pytorch_grad_cam is written for image-shaped inputs, so a raw waveform
    # may need reshaping or a custom 1D implementation.
    grad_cam = GradCAM(model=model, target_layers=[model.conv1])

    # Create a data loader for the waveform data
    data_loader = DataLoader(waveform_data, batch_size=1)

    heatmaps = []
    model.eval()
    for (input_tensor,) in data_loader:
        # Predict the class index without tracking gradients
        with torch.no_grad():
            output = model(input_tensor)
        pred_idx = int(output.argmax(dim=1))

        # Grad-CAM needs gradients, so it must run outside torch.no_grad()
        targets = [ClassifierOutputTarget(pred_idx)]
        heatmap = grad_cam(input_tensor=input_tensor, targets=targets)

        # The library already returns a numpy array
        heatmaps.append(heatmap)

    return heatmaps

In [None]:
# Usage example:
model = our_model  # Replace with your actual model
waveform_data = ...  # Replace with your actual waveform data

In [None]:
heatmaps = apply_grad_cam(model, waveform_data)

In [None]:
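# Visualize the generated heatmap (optional)
# A waveform heatmap is one-dimensional, so plot it as a curve over time
plt.plot(heatmaps[0].squeeze())
plt.xlabel('Sample index')
plt.ylabel('Grad-CAM relevance')
plt.show()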
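
Because the waveform input has a single time axis, Grad-CAM can also be computed directly with PyTorch hooks. The following is a minimal sketch under stated assumptions, not the method used by any particular library: `TinyWaveNet` is a toy model invented for illustration, and `grad_cam_1d` is a hypothetical helper that weights each channel of the target layer by its time-averaged gradient, applies a ReLU, and linearly upsamples the result to the input length.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyWaveNet(nn.Module):
    """Toy waveform classifier used only for illustration."""

    def __init__(self, n_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 8, kernel_size=64, stride=16)
        self.conv2 = nn.Conv1d(8, 16, kernel_size=16, stride=4)
        self.head = nn.Linear(16, n_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.head(x.mean(dim=-1))


def grad_cam_1d(model, target_layer, waveform, class_idx=None):
    """Return a [batch, samples] heatmap in [0, 1] for a [batch, 1, samples] input."""
    store = {}

    def save_activation(module, inputs, output):
        store["act"] = output

    # Capture the target layer's output during the forward pass
    handle = target_layer.register_forward_hook(save_activation)
    try:
        logits = model(waveform)
    finally:
        handle.remove()

    # Default to the predicted class of each example
    if class_idx is None:
        class_idx = logits.argmax(dim=1)
    score = logits.gather(1, class_idx.view(-1, 1)).sum()

    # Gradient of the class score with respect to the stored activations
    grads = torch.autograd.grad(score, store["act"])[0]

    # Channel weights are the time-averaged gradients; combine and apply ReLU
    weights = grads.mean(dim=-1, keepdim=True)
    cam = F.relu((weights * store["act"]).sum(dim=1, keepdim=True))

    # Upsample to the input length and normalize each example to [0, 1]
    cam = F.interpolate(cam, size=waveform.shape[-1], mode="linear", align_corners=False)
    cam = cam.squeeze(1).detach()
    cam_min = cam.amin(dim=-1, keepdim=True)
    cam_max = cam.amax(dim=-1, keepdim=True)
    return (cam - cam_min) / (cam_max - cam_min + 1e-8)


torch.manual_seed(0)
toy_model = TinyWaveNet().eval()
waveform = torch.randn(2, 1, 16000)
heatmap = grad_cam_1d(toy_model, toy_model.conv2, waveform)
print(heatmap.shape)  # torch.Size([2, 16000])
```

The returned tensor has one relevance value per input sample, which is the `[1, timesteps]` layout (per example) that the spectrogram conversion in the next section expects.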

## Convert Waveform Heatmap into Spectrogram

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader

In [None]:
def heatmap_to_spectrogram(heatmap: torch.Tensor,
                           n_fft: int = 2048,
                           hop_length: int = 512,
                           win_length: int = 2048) -> torch.Tensor:
    """
    Convert a waveform heatmap into a magnitude spectrogram.

    Args:
        heatmap (torch.Tensor): Heatmap of shape [1, timesteps]
        n_fft (int, optional): Size of FFT window. Defaults to 2048.
        hop_length (int, optional): Hop length between successive frames. Defaults to 512.
        win_length (int, optional): Window size. Defaults to 2048.

    Returns:
        spectrogram (torch.Tensor): Spectrogram of shape [1, freq_bins, time_steps]
    """
    # Right-pad the heatmap so that its length is a multiple of hop_length
    pad = -heatmap.shape[-1] % hop_length
    padded_heatmap = torch.nn.functional.pad(heatmap, (0, pad))

    # Compute the Short-Time Fourier Transform (STFT).
    # return_complex=True is required for real inputs in recent PyTorch versions.
    window = torch.hamming_window(win_length, device=padded_heatmap.device)
    stft = torch.stft(padded_heatmap.squeeze(0), n_fft=n_fft, hop_length=hop_length,
                      win_length=win_length, window=window, onesided=True,
                      pad_mode='constant', normalized=False, return_complex=True)

    # Take the magnitude of the complex-valued STFT and add a batch dimension
    spectrogram = stft.abs().unsqueeze(0)

    return spectrogram

In [None]:

# Load or generate the heatmap
heatmap: torch.Tensor = ...  

In [None]:
# Convert the heatmap to spectrogram
spectrogram = heatmap_to_spectrogram(heatmap)

In [None]:
print(spectrogram.shape)

In [None]:
The function uses type hints for its arguments and return value. `heatmap` is expected to be a `torch.Tensor`; `n_fft`, `hop_length`, and `win_length` are optional parameters with default values.
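
To make the output shape concrete, the sketch below repeats the padding and STFT steps on a random stand-in heatmap of 48,000 samples (the input length used earlier in this notebook). It passes `return_complex=True`, which recent PyTorch versions require for real inputs. With the default `center=True`, `torch.stft` produces `n_fft // 2 + 1` frequency bins and `1 + padded_length // hop_length` frames.

```python
import torch

n_fft, hop_length, win_length = 2048, 512, 2048
heatmap = torch.rand(1, 48000)  # stand-in for a Grad-CAM heatmap over 48,000 samples

# Right-pad so the length is a multiple of hop_length
pad = -heatmap.shape[-1] % hop_length
padded = torch.nn.functional.pad(heatmap, (0, pad))

stft = torch.stft(
    padded.squeeze(0),
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    window=torch.hamming_window(win_length),
    pad_mode="constant",
    return_complex=True,
)
spectrogram = stft.abs().unsqueeze(0)

expected_bins = n_fft // 2 + 1
expected_frames = 1 + padded.shape[-1] // hop_length
print(pad, tuple(spectrogram.shape), (1, expected_bins, expected_frames))
# 128 (1, 1025, 95) (1, 1025, 95)
```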

## Visualize the Resulting Spectrogram

In [None]:
# Import necessary libraries
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np  # Ensure numpy is imported for array operations

In [None]:
def visualize_spectrogram(spectrogram, cmap='inferno'):
    """
    Visualize the resulting spectrogram.

    Parameters:
    - spectrogram (2D numpy array): The spectrogram data of shape [freq_bins, time_steps].
    - cmap (str or matplotlib colormap, optional): Colormap to use for visualization. Defaults to 'inferno'.
    """
    # Create a figure with a single subplot
    fig, ax = plt.subplots(figsize=(12, 6))

    # Use seaborn's heatmap function to create the spectrogram visualization
    sns.heatmap(spectrogram, cmap=cmap, ax=ax)

    # seaborn draws row 0 at the top; flip the axis so low frequencies sit at the bottom
    ax.invert_yaxis()

    # Set title and labels for the plot
    ax.set_title('Resulting Spectrogram')
    ax.set_xlabel('Time')
    ax.set_ylabel('Frequency')

    # Display the plot
    plt.tight_layout()  # Ensure plot fits within figure area
    plt.show()

In [None]:
# Example usage:
spectrogram_data = np.random.rand(256, 512)  # Replace with actual spectrogram data
visualize_spectrogram(spectrogram_data)
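
The plot above labels its axes by bin and frame index. As a rough sketch, the indices can be mapped to physical units once a sample rate is known. The value `sample_rate = 16000` is an assumption made for illustration (the notebook does not state it), and `n_frames = 95` corresponds to a 48,000-sample input with the default STFT parameters.

```python
import torch

sample_rate = 16000  # assumed for illustration; not stated in the notebook
n_fft, hop_length = 2048, 512
n_frames = 95  # frames for a 48,000-sample input with the defaults above

# Center frequency of each one-sided FFT bin, in Hz
freqs_hz = torch.fft.rfftfreq(n_fft, d=1.0 / sample_rate)

# Start time of each frame, in seconds
times_s = torch.arange(n_frames) * hop_length / sample_rate

print(freqs_hz.shape[0], freqs_hz[-1].item())  # 1025 8000.0
print(round(times_s[-1].item(), 3))  # 3.008
```

These arrays can be used as tick labels, or their first and last values can be passed as the `extent` argument of `matplotlib`'s `imshow` to draw the spectrogram with axes in seconds and Hz.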