# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration

class WhisperWithATMAN(nn.Module):
    """Wrap a Hugging Face Whisper model with AtMan-style token suppression.

    The wrapper encodes the audio, optionally suppresses one encoder position
    (plus every position whose embedding is cosine-similar to it), and runs
    the Whisper decoder on the resulting encoder states.
    """

    def __init__(self, original_whisper_model):
        super().__init__()
        self.whisper = original_whisper_model

    def forward(self, x, token_index=None, suppression_factor=0.1, threshold=0.5,
                decoder_input_ids=None):
        """
        Encode the audio, optionally suppress one encoder token, then decode.

        Args:
            x (torch.Tensor): Input tensor representing the Log-Mel spectrograms.
            token_index (int, optional): Index of the encoder token to suppress.
                ``None`` runs the model unperturbed.
            suppression_factor (float, optional): Fraction removed from the
                suppressed tokens (they are scaled by ``1 - suppression_factor``).
            threshold (float, optional): Cosine similarity threshold for correlated token suppression.
            decoder_input_ids (torch.Tensor, optional): Decoder prompt token ids.
                Defaults to a single ``decoder_start_token_id`` per batch item.

        Returns:
            The output of ``self.whisper.model.decoder`` for the (possibly
            perturbed) encoder states.
        """
        encoder = self.whisper.model.encoder

        # Match the encoder's parameter dtype so float64 features (e.g. built
        # from a NumPy array) do not trigger a Float/Double mismatch.
        reference_param = next(encoder.parameters(), None)
        if reference_param is not None and x.is_floating_point():
            x = x.to(dtype=reference_param.dtype)

        # Step 1: Encode the input audio
        encoder_outputs = encoder(x)
        embeddings = encoder_outputs.last_hidden_state

        # Step 2: Suppress the selected token and its correlated tokens.
        # NOTE(review): the Whisper decoder's forward does not appear to take
        # an encoder attention mask, so the suppression is applied to the
        # encoder states consumed by cross-attention. This approximates
        # AtMan's per-layer attention manipulation rather than reproducing it.
        if token_index is not None:
            similarity_matrix = self.calculate_cosine_similarity(embeddings)
            scale = self._suppression_scale(
                embeddings.size(-2), token_index, suppression_factor,
                similarity_matrix, threshold, embeddings.dtype, embeddings.device,
            )
            embeddings = embeddings * scale.unsqueeze(-1)

        # Step 3: The decoder needs token inputs of its own; the encoder
        # states are only valid as cross-attention memory.
        if decoder_input_ids is None:
            decoder_input_ids = torch.full(
                (embeddings.size(0), 1),
                self.whisper.config.decoder_start_token_id,
                dtype=torch.long,
                device=embeddings.device,
            )

        # Step 4: Decode against the (possibly perturbed) encoder states
        decoder_outputs = self.whisper.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=embeddings,
        )

        return decoder_outputs

    def compute_attention_scores(self, embeddings):
        """
        Computes attention scores from embeddings using scaled dot-product attention.

        The query/key projections of the first encoder layer are applied to
        ``embeddings`` and a single softmax is taken over the full model
        dimension, i.e. this is a single-head approximation, not the model's
        own multi-head attention.

        Args:
            embeddings (torch.Tensor): Embeddings from the encoder.

        Returns:
            torch.Tensor: Attention score matrix (softmax over the last axis).
        """
        # Whisper encoder layers live in `layers`, with projections named
        # `self_attn.q_proj` / `self_attn.k_proj`.
        first_layer_attention = self.whisper.model.encoder.layers[0].self_attn
        Q = first_layer_attention.q_proj(embeddings)
        K = first_layer_attention.k_proj(embeddings)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (embeddings.size(-1) ** 0.5)
        return F.softmax(scores, dim=-1)

    def modify_attention_scores(self, H, token_index, suppression_factor, similarity_matrix=None, threshold=0.5):
        """
        Modifies attention scores to suppress specific tokens and their correlated tokens.

        The column of every suppressed token (the attention paid *to* it by
        all queries) is multiplied by ``1 - suppression_factor``; all other
        entries are left unchanged. Rows are not re-normalised afterwards.

        Args:
            H (torch.Tensor): Original attention score matrix, ``(..., s, s)``.
            token_index (int): Index of the token to suppress.
            suppression_factor (float): Factor by which to suppress attention scores.
            similarity_matrix (torch.Tensor, optional): Cosine similarity matrix
                for embeddings, ``(..., s, s)``. Tokens whose similarity to
                ``token_index`` exceeds ``threshold`` are suppressed as well.
            threshold (float, optional): Threshold for identifying correlated tokens.

        Returns:
            torch.Tensor: Modified attention score matrix, same shape as ``H``
            (broadcast against the batch dimensions of ``similarity_matrix``).
        """
        scale = self._suppression_scale(
            H.shape[-1], token_index, suppression_factor,
            similarity_matrix, threshold, H.dtype, H.device,
        )
        # (..., s) -> (..., 1, s): one factor per key column, shared by all queries.
        return H * scale.unsqueeze(-2)

    @staticmethod
    def _suppression_scale(seq_len, token_index, suppression_factor,
                           similarity_matrix, threshold, dtype, device):
        """Return per-token factors: ``1 - suppression_factor`` if suppressed, else 1."""
        suppressed = torch.zeros(seq_len, dtype=torch.bool, device=device)
        suppressed[token_index] = True
        if similarity_matrix is not None:
            # Only similarities *to the selected token* matter; thresholding the
            # whole matrix would also flag unrelated token pairs.
            correlated = similarity_matrix[..., token_index, :] > threshold
            suppressed = suppressed | correlated.to(device)
        return 1.0 - suppression_factor * suppressed.to(dtype)

    def calculate_cosine_similarity(self, embeddings):
        """
        Calculates the cosine similarity matrix for the given embeddings.

        Args:
            embeddings (torch.Tensor): Encoder embeddings, ``(..., s, d)``.

        Returns:
            torch.Tensor: Cosine similarity matrix, ``(..., s, s)``.
        """
        normalized_embeddings = F.normalize(embeddings, p=2, dim=-1)
        similarity_matrix = torch.matmul(normalized_embeddings, normalized_embeddings.transpose(-2, -1))
        return similarity_matrix

# Function to load and preprocess audio
def preprocess_audio(file_path, log_floor=1e-10):
    """Load an audio file and return its log2-scaled Mel spectrogram.

    Args:
        file_path: Audio source passed straight to ``torchaudio.load``.
        log_floor (float, optional): Lower bound applied to the Mel power
            values before the logarithm, so zero-energy bins produce a finite
            value instead of ``-inf``.

    Returns:
        torch.Tensor: ``log2`` Mel spectrogram (80 Mel bins, FFT size 400,
        hop 160) with a leading batch dimension added. Presumably shaped
        ``(1, channels, 80, frames)``; confirm before feeding it to a model.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    # NOTE(review): the file's native sample rate is used as-is; no resampling
    # is done here, so callers needing a fixed rate must resample themselves.
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_mels=80,
        hop_length=160,
        n_fft=400,
    )
    mel_spec = transform(waveform)
    # Silent frames have exactly zero power; log2(0) would be -inf and poison
    # every downstream computation with inf/NaN.
    log_mel_spec = torch.clamp(mel_spec, min=log_floor).log2()
    return log_mel_spec.unsqueeze(0)  # Add batch dimension

# Load a small LibriSpeech sample set from the Hugging Face Hub
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# Take the first audio sample (raw waveform and its sampling rate)
first_audio = dataset[0]["audio"]
audio_sample = first_audio["array"]
sample_rate = first_audio["sampling_rate"]

# Build the model input with Whisper's own feature extractor instead of a
# hand-rolled Mel transform: it returns float32 features padded/trimmed to the
# fixed window the encoder expects, which avoids the float64 waveform that
# caused "expected scalar type Float but found Double".
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
input_audio = feature_extractor(
    audio_sample,
    sampling_rate=sample_rate,
    return_tensors="pt",
).input_features

# Load the Whisper model and create an instance of WhisperWithATMAN
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
atman_whisper = WhisperWithATMAN(whisper_model)

# Run the preprocessed audio sample through the model (inference only, so no
# autograd graph is needed)
with torch.no_grad():
    output = atman_whisper(input_audio, token_index=5, suppression_factor=0.3)

# Print the decoder output
print(output)


# Output recorded when this cell was originally run:
#   RuntimeError: expected scalar type Float but found Double