In [1]:
import fairseq
import torch 
import torch.nn as nn
from fairseq.models import (
    FairseqEncoder, 
    register_model, 
    register_model_architecture
)
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model, Wav2Vec2Config
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    TransformerEncoder,
    TransformerSentenceEncoderLayer,
    Wav2Vec2Model,
    Wav2VecEncoder    
)
from fairseq.data.audio.speech_to_text_dataset import _collate_frames


[2024-01-13 17:44:20,107] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
original_forward = TransformerSentenceEncoderLayer.forward

def generate_2d_causal_mask(seq_len, device='cpu'):
    """
    Generates a 2D causal mask for multi-head attention.
    
    Args:
        seq_len (int): The length of the sequence.
        device (str): The device on which to create the mask.
    
    Returns:
        torch.Tensor: A 2D causal attention mask.
    """
    mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask

def causal_forward(
    self,
    x: torch.Tensor,
    self_attn_mask: torch.Tensor = None,
    self_attn_padding_mask: torch.Tensor = None,
    need_weights: bool = False,
    att_args=None,
):
    # Generate the causal mask
    # print(x)
    # print(x.size(2))
    # print(self_attn_mask)
    causal_mask = generate_2d_causal_mask(x.size(0), device=x.device)
    
    if self_attn_mask is not None:
        self_attn_mask = self_attn_mask + causal_mask
    else:
        self_attn_mask = causal_mask

    return original_forward(
        self, x, 
        self_attn_mask=self_attn_mask, 
        self_attn_padding_mask=self_attn_padding_mask, 
        need_weights=need_weights,
        att_args=att_args)


In [3]:
# try replace multihead attention with causal multihead attention


In [4]:
# try replace wav2vec forward

In [3]:
def replace_forward():
    TransformerSentenceEncoderLayer.forward = causal_forward

In [4]:
speech_tower_path = '/mnt/data/xixu/models/wav2_vec_vox_960h_pl.pt'
state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(speech_tower_path)
model = Wav2VecEncoder(state['cfg']['model'], None)
new = {}
for key in state['model'].keys():
    new_key = key.replace('w2v_encoder.', '')
    if not new_key.startswith('proj'):
        new[new_key] = state['model'][key]
model.load_state_dict(new, strict=True)
model = model.w2v_model
replace_forward()



In [8]:
#Attention Weight
import matplotlib.pyplot as plt
import torchaudio
import numpy
from train.dataset import PromptSpeechToTextDatasetCreator, SpeechToTextDatasetItem



def visualize_attention_weights(model, plot_size=10):
    test_dataset = PromptSpeechToTextDatasetCreator.from_tsv("/mnt/data/xixu/datasets/must-c-v1.0/en-es/", 'tst-COMMON_1' )
    for test_data in test_dataset:
        source, ref, id = test_data.source, test_data.target, test_data.id                  
        speech_batch = _collate_frames([source], is_audio_input=True)

        model.eval()
        # Forward pass through the model
        with torch.no_grad():
            result = model.extract_features(speech_batch, padding_mask=None)
        # ((x, z, lr))
        attn = result['layer_results'][0][1]
        print(attn.size())
        attn = attn[0] if attn.ndim == 3 else attn
        print(attn.size())
        # Select a smaller portion of the attention matrix to visualize
        small_attn = attn[:plot_size, :plot_size].cpu().numpy()

        # Visualize the attention weights
        plt.matshow(small_attn)
        plt.title(f"Attention Weights - First {plot_size} Timesteps")
        plt.xlabel("Key Positions")
        plt.ylabel("Query Positions")
        plt.colorbar()
        plt.show()

# Call the function with the model and desired input length
visualize_attention_weights(model, plot_size=10)


torch.Size([1, 466, 466])
torch.Size([466, 466])


IndexError: too many indices for tensor of dimension 2

In [None]:
#Attention Weight
import matplotlib.pyplot as plt

def visualize_attention_weights(model, input_length, plot_size=10):
    # Create a dummy input tensor
    input_tensor = torch.rand(input_length, state['cfg']['model']['w2v_args']['model'].encoder_embed_dim)
    model.eval()
    # Forward pass through the model
    with torch.no_grad():
        result = model.extract_features(input_tensor, padding_mask=None, mask=True)
    
    attn = result['layer_results'][0][2]  # Adjust based on your model's structure

    # Select a smaller portion of the attention matrix to visualize
    small_attn = attn[0, :plot_size, :plot_size].cpu().numpy()

    # Visualize the attention weights of the first head of the first layer
    plt.matshow(small_attn)
    plt.title(f"Attention Weights (Head 1, Layer 1) - First {plot_size} Timesteps")
    plt.xlabel("Key Positions")
    plt.ylabel("Query Positions")
    plt.colorbar()
    plt.show()

visualize_attention_weights(model, input_length=100, plot_size=10)


RuntimeError: The shape of the 3D attn_mask is torch.Size([1, 100, 100]), but should be (1600, 2, 2).

In [None]:
#  Test2 Mask Future Timesteps and Check Output

def test_mask_future_timesteps(model, input_length, mask_start):
    # Create a dummy input tensor
    input_tensor = torch.rand(input_length, state['cfg']['model']['w2v_args']['model'].encoder_embed_dim)

a
    # Forward pass with original input
     with torch.no_grad():
        out_orig, _ = model.extract_features(input_tensor, padding_mask=None, mask=True)

    # Mask future timesteps in the input
    input_tensor[:, mask_start:, :] = 0

    # Forward pass with masked input
    with torch.no_grad():
        out_masked, _ = model.extract_features(input_tensor, padding_mask=None, mask=True)

    # Check if outputs are the same for unmasked timesteps
    assert torch.allclose(out_orig[:, :mask_start, :], out_masked[:, :mask_start, :], atol=1e-6), "Outputs do not match for unmasked timesteps."

test_mask_future_timesteps(model, input_length=100, mask_start=50)


IndentationError: unexpected indent (3195207612.py, line 9)