In [None]:
import fairseq
import torch 
import torch.nn as nn
from fairseq.models import (
    FairseqEncoder, 
    register_model, 
    register_model_architecture
)
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model, Wav2Vec2Config
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    TransformerEncoder,
    TransformerSentenceEncoderLayer,
    Wav2Vec2Model,
    Wav2VecEncoder    
)
from fairseq.data.audio.speech_to_text_dataset import _collate_frames


In [None]:
# original_forward = TransformerSentenceEncoderLayer.forward

# def generate_2d_causal_mask(seq_len, device='cpu'):
#     """
#     Generates a 2D causal mask for multi-head attention.
    
#     Args:
#         seq_len (int): The length of the sequence.
#         device (str): The device on which to create the mask.
    
#     Returns:
#         torch.Tensor: A 2D causal attention mask.
#     """
#     mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1)
#     mask = mask.masked_fill(mask == 1, float('-inf'))
#     return mask

# def causal_forward(
#     self,
#     x: torch.Tensor,
#     self_attn_mask: torch.Tensor = None,
#     self_attn_padding_mask: torch.Tensor = None,
#     need_weights: bool = False,
#     att_args=None,
# ):
#     # Generate the causal mask
#     # print(x)
#     # print(x.size(2))
#     # print(self_attn_mask)
#     causal_mask = generate_2d_causal_mask(x.size(0), device=x.device)
    
#     if self_attn_mask is not None:
#         self_attn_mask = self_attn_mask + causal_mask
#     else:
#         self_attn_mask = causal_mask

#     return original_forward(
#         self, x, 
#         self_attn_mask=self_attn_mask, 
#         self_attn_padding_mask=self_attn_padding_mask, 
#         need_weights=need_weights,
#         att_args=att_args)


In [None]:
def generate_2d_causal_mask(seq_len, dtype, device='cpu'):
    """
    Build an additive causal (look-ahead) mask for self-attention.

    Entry ``[i, j]`` is ``0`` when ``j <= i`` and ``-inf`` when ``j > i``, so
    adding the mask to attention scores blocks every query position from
    attending to later key positions.

    Args:
        seq_len (int): Number of positions; the mask is ``seq_len x seq_len``.
        dtype (torch.dtype): Floating-point dtype of the returned mask.
        device (str or torch.device): Device on which to create the mask.

    Returns:
        torch.Tensor: A ``(seq_len, seq_len)`` additive attention mask.
    """
    positions = torch.arange(seq_len, device=device)
    # True strictly above the main diagonal: key index (columns) > query index (rows).
    is_future = positions.unsqueeze(0) > positions.unsqueeze(1)
    causal_mask = torch.zeros((seq_len, seq_len), dtype=dtype, device=device)
    return causal_mask.masked_fill(is_future, float('-inf'))

def causal_forward(
    self,
    x: torch.Tensor,
    self_attn_mask: torch.Tensor = None,
    self_attn_padding_mask: torch.Tensor = None,
    need_weights: bool = False,
    att_args=None,
):
    """
    Causal drop-in replacement for ``TransformerSentenceEncoderLayer.forward``.

    A strictly upper-triangular ``-inf`` mask sized by ``x.size(0)`` is added to
    ``self_attn_mask`` (or used on its own when none is given) and passed to the
    self-attention module in BOTH normalisation modes, so each position attends
    only to itself and earlier positions. LayerNorm is applied before the
    self-attention/FFN sub-blocks when ``self.layer_norm_first`` is set, and
    after them otherwise.

    Args:
        x: Layer input; dimension 0 is used as the sequence length
            (assumes fairseq's time-first layout -- TODO confirm).
        self_attn_mask: Optional additive attention mask, summed with the causal mask.
        self_attn_padding_mask: Optional key padding mask forwarded to the attention module.
        need_weights: Accepted for signature compatibility but ignored; attention
            weights are always requested so they can be inspected afterwards.
        att_args: Accepted for signature compatibility; unused.

    Returns:
        tuple: ``(x, (attn, layer_result))`` -- the layer output, the attention
        weights returned by ``self.self_attn``, and the FFN output captured
        before the last dropout/residual.
    """
    causal_mask = generate_2d_causal_mask(x.size(0), dtype=x.dtype, device=x.device)

    if self_attn_mask is not None:
        self_attn_mask = self_attn_mask + causal_mask
    else:
        self_attn_mask = causal_mask

    residual = x

    if self.layer_norm_first:
        x = self.self_attn_layer_norm(x)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            attn_mask=self_attn_mask,
            need_weights=True,
        )
        x = self.dropout1(x)
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)

        layer_result = x

        x = self.dropout3(x)
        x = residual + x
    else:
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            # Previously omitted here, which left post-LN models fully bidirectional.
            attn_mask=self_attn_mask,
            need_weights=True,
        )

        x = self.dropout1(x)
        x = residual + x

        x = self.self_attn_layer_norm(x)

        residual = x
        x = self.activation_fn(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)

        layer_result = x

        x = self.dropout3(x)
        x = residual + x
        x = self.final_layer_norm(x)

    return x, (attn, layer_result)

In [None]:
# TODO: try replacing MultiheadAttention with a causal multi-head attention module


In [None]:
# TODO: try replacing the wav2vec forward pass

In [None]:
def replace_forward():
    """Monkey-patch ``TransformerSentenceEncoderLayer`` so every layer runs ``causal_forward``."""
    # Class-level patch: it affects instances created before and after this call.
    setattr(TransformerSentenceEncoderLayer, "forward", causal_forward)

In [None]:
# Load the fine-tuned wav2vec 2.0 checkpoint on CPU and rebuild its encoder wrapper.
speech_tower_path = '/mnt/taurus/data/xixu/models/wav2_vec_vox_960h_pl.pt'
state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(speech_tower_path)
model = Wav2VecEncoder(state['cfg']['model'], None)

# Strip the 'w2v_encoder.' wrapper from parameter names and drop every entry
# whose stripped name starts with 'proj'.
renamed = ((key.replace('w2v_encoder.', ''), weight) for key, weight in state['model'].items())
new = {new_key: weight for new_key, weight in renamed if not new_key.startswith('proj')}
model.load_state_dict(new, strict=True)

# Keep only the inner wav2vec 2.0 model, then switch its transformer layers to causal attention.
model = model.w2v_model
replace_forward()

In [None]:
#Attention Weight
import matplotlib.pyplot as plt
import torchaudio
import numpy
from train.dataset import PromptSpeechToTextDatasetCreator, SpeechToTextDatasetItem


# Re-apply the causal patch so this cell also works on its own; the patch is a
# plain class-attribute assignment, so repeating it is harmless.
replace_forward()
def visualize_attention_weights(
    model,
    plot_size=10,
    data_root="/mnt/data/xixu/datasets/must-c-v1.0/en-es/",
    split='tst-COMMON_1',
    max_examples=None,
):
    """
    Plot the top-left corner of the first-layer attention map for dataset utterances.

    Args:
        model: Speech encoder exposing ``eval()`` and
            ``extract_features(source, padding_mask=...)``.
        plot_size (int): Number of query/key positions to show along each axis.
        data_root (str): Directory passed to ``PromptSpeechToTextDatasetCreator.from_tsv``.
        split (str): Split name passed to ``from_tsv``.
        max_examples (int or None): Stop after this many utterances; ``None``
            (default) plots every utterance in the split.

    Returns:
        None. One matplotlib figure is shown per utterance and the feature
        size is printed.
    """
    test_dataset = PromptSpeechToTextDatasetCreator.from_tsv(data_root, split)

    # Inference mode is loop-invariant, so set it once.
    model.eval()

    for index, test_data in enumerate(test_dataset):
        if max_examples is not None and index >= max_examples:
            break

        speech_batch = _collate_frames([test_data.source], is_audio_input=True)

        # Forward pass through the model
        with torch.no_grad():
            result = model.extract_features(speech_batch, padding_mask=None)

        # NOTE(review): relies on a locally modified fairseq whose layer_results
        # entries carry the attention weights at index 1; see
        # https://github.com/facebookresearch/fairseq/blob/fad2c4d1ebe14d974876de52dcb06db6d99b0b4a/fairseq/models/wav2vec/wav2vec2.py#L1330C34-L1330C34
        attn = result['layer_results'][0][1]
        feature = result["x"]
        print(feature.size())

        # A 3-D tensor is reduced to its first slice (presumably the single
        # batch element -- TODO confirm); a 2-D tensor is used as is.
        attn = attn[0] if attn.ndim == 3 else attn

        # Select a smaller portion of the attention matrix to visualize
        small_attn = attn[:plot_size, :plot_size].cpu().numpy()

        # Visualize the attention weights
        plt.matshow(small_attn)
        plt.title(f"Attention Weights - First {plot_size} Timesteps")
        plt.xlabel("Key Positions")
        plt.ylabel("Query Positions")
        plt.colorbar()
        plt.show()

# Plot the first 10 x 10 query/key positions of the attention map for each test utterance.
visualize_attention_weights(model, plot_size=10)


# Inference

## Incremental w2v2 encoding

In [None]:
from fairseq.modules import MultiheadAttention

In [None]:
# Stand-alone fairseq self-attention module (512, 8 -> presumably embed_dim and
# number of heads) used to try out incremental (cached) attention.
mha = MultiheadAttention(512, 8, dropout=0.0, self_attention=True)

In [None]:
# Mutable cache passed to both mha calls via incremental_state=...; starts empty.
incremental_state = {}

In [None]:
# First chunk of random input: 2 positions along dim 0 (the mask is sized from
# x.size(0)); presumably laid out as (time, batch, dim) -- TODO confirm.
x = torch.rand(2, 1, 512)

In [None]:
# 2 x 2 causal mask for the first chunk.
attn_mask = generate_2d_causal_mask(x.size(0), dtype=x.dtype, device=x.device)

In [None]:
# First incremental step. Call the module itself rather than .forward so that
# nn.Module.__call__ (and any registered hooks) is used.
attn, attn_weights = mha(x, x, x, incremental_state=incremental_state, attn_mask=attn_mask)

In [None]:
# Second chunk of random input: 3 new positions along dim 0.
y = torch.rand(3, 1, 512)

In [None]:
# 5 = 2 earlier positions (x) + 3 new ones (y). Rows [2:] give a 3 x 5 mask:
# the 3 new queries against all 5 keys (assuming mha kept the first 2 keys in
# incremental_state -- TODO confirm).
attn_mask = generate_2d_causal_mask(5, dtype=x.dtype, device=x.device)[2:]

In [None]:
# Second incremental step on the new chunk, reusing the same incremental_state.
# Call the module itself rather than .forward so that nn.Module.__call__
# (and any registered hooks) is used.
attn, attn_weights = mha(y, y, y, incremental_state=incremental_state, attn_mask=attn_mask)

In [None]:
# Inspect the shape of the 'prev_key' cache held in the first incremental_state entry.
list(incremental_state.values())[0]['prev_key'].size()

## Incremental llama encoding with w2v2 input

In [None]:
%load_ext autoreload

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import argparse, time, json
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import torch, transformers
import torch.nn as nn
from eval.utils import disable_torch_init
from model.model import SpeechLlamaForCausalLM, SpeechLlamaModel, SpeechLlamaConfig
from model.utils import KeywordsStoppingCriteria
from fairseq.data.audio.speech_to_text_dataset import _collate_frames
from train.dataset import PromptSpeechToTextDatasetCreator, SpeechToTextDatasetItem
import conversation as conversation_lib
from conversation import SeparatorStyle

import requests

import torch.nn.functional as F

import importlib
import numpy as np
from fairseq.models.speech_to_text import lengths_to_padding_mask
from train.uni_wav2vec_monkey_patch import replace_uni_train, replace_uni_decode, uni_self_attn_forward, uni_w2v2_extract_features

# Fix the RNG seed and require deterministic kernels so repeated runs give the same numbers.
transformers.set_seed(998244353)
torch.use_deterministic_algorithms(True)

# Checkpoint directory and the adapter / speech-tower weight files stored inside it.
args = argparse.Namespace()
args.model_name = '/mnt/taurus/data/xixu/runs/sllama/en-es/7b/uni/stage2/checkpoint-2000'
args.length_adapter_path = os.path.join(args.model_name, 'length_adapter.bin')
args.mlp_adapter_path = os.path.join(args.model_name, 'mlp_adapter.bin')
args.speech_tower_path = os.path.join(args.model_name, 'speech_tower.bin')

load_type = torch.float32
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    args.model_name,
    padding_side="right",
    use_fast=False,
)

# Write a copy of the checkpoint config with large_model=True. Context managers
# guarantee the file is flushed and closed before from_pretrained reads it back.
with open(os.path.join(args.model_name, 'config.json')) as config_file:
    config = json.load(config_file)
config['large_model'] = True
update_config = os.path.join(args.model_name, 'config_large.json')
with open(update_config, 'w') as config_file:
    json.dump(config, config_file, indent=2)
# replace_llama_attn_with_flash_attn()

# Apply the project's monkey patch (train variant) before the model is built.
replace_uni_train()

model = SpeechLlamaForCausalLM.from_pretrained(args.model_name,
                                                torch_dtype=load_type,
                                                low_cpu_mem_usage=True,
                                                device_map='cpu',
                                                config=update_config,).eval()

device_input = device_output = 'cpu'

# Build the speech modules; the two returned callables map input lengths to the
# lengths after the SSL encoder and after the length adapter (see their use in process()).
length_after_ssl, length_after_adp = model.model.initialize_speech_modules(
    speech_tower_path='/mnt/taurus/data/xixu/models/wav2_vec_vox_960h_pl.pt',
    speech_tower_type=None,
    len_adapter_channels=model.config.len_adapter_channels,
    len_adapter_kernel_sizes=model.config.len_adapter_kernel_sizes,
    ssl_fintuned=model.config.ssl_fintuned,
)
model.model.speech_tower.to(dtype=load_type, device=device_input)

# Overwrite the freshly initialised speech modules with the fine-tuned weights from the checkpoint.
length_adapter_weights = torch.load(args.length_adapter_path, map_location='cpu')
mlp_adapter_weights = torch.load(args.mlp_adapter_path, map_location='cpu')
speech_tower_weights = torch.load(args.speech_tower_path, map_location='cpu')


model.model.mm_length_adapter.load_state_dict(length_adapter_weights)
model.model.mm_mlp_adapter.load_state_dict(mlp_adapter_weights)
model.model.speech_tower.load_state_dict(speech_tower_weights)

model.model.mm_length_adapter.to(dtype=load_type, device=device_input).eval()
model.model.mm_mlp_adapter.to(dtype=load_type, device=device_input).eval()
model.model.speech_tower.to(dtype=load_type, device=device_input).eval()

In [1]:
import sys
sys.path.append('/mnt/taurus/home/siqiouyang/work/projects/SimulEval/')
from eval.agents.tt_waitk_sllama import S2TAgentStates
from eval.agents.tt_waitk_sllama_incremental import IncrementalS2TAgentStates

# Label id presumably excluded from the loss (-100 is PyTorch's default
# ignore_index) -- TODO confirm where it is consumed.
IGNORE_INDEX = -100
# Special-token strings, presumably for the tokenizer -- TODO confirm where they are consumed.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
# Speech placeholder markers. process() / incremental_process() build the prompt as
# DEFAULT_SPEECH_START_TOKEN + N * DEFAULT_SPEECH_PATCH_TOKEN + DEFAULT_SPEECH_END_TOKEN.
DEFAULT_SPEECH_TOKEN = "<speech>"
DEFAULT_SPEECH_PATCH_TOKEN = "<sp_patch>"
DEFAULT_SPEECH_START_TOKEN = "<sp_start>"
DEFAULT_SPEECH_END_TOKEN = "<sp_end>"

  from .autonotebook import tqdm as notebook_tqdm


[2024-02-02 10:32:52,021] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


ModuleNotFoundError: No module named 'simuleval'

In [None]:
# Scratch check: floor() on a float tensor rounds 1.1 down to 1.0.
torch.FloatTensor([1.1]).floor()

In [None]:
# Display the speech tower encoder's pos_conv module.
model.model.speech_tower.encoder.pos_conv

In [None]:
def process(states):
    """
    Run the speech-LLaMA backbone from scratch on all audio buffered in ``states``.

    Args:
        states: Agent state object; reads ``states.source`` (audio samples) and
            ``states.target_ids`` (a list appended to the prompt token ids).

    Returns:
        The value returned by ``model.model(...)`` for the prompt built from the
        speech placeholders plus ``states.target_ids``.
    """
    waveform = torch.tensor(states.source).to(device=model.device, dtype=model.dtype)
    speech_batch = _collate_frames([waveform], is_audio_input=True)
    n_frames = torch.tensor([waveform.size(0)], dtype=torch.long)
    # Length of the speech sequence after the SSL encoder and then the length adapter.
    speech_lens = length_after_adp(length_after_ssl(n_frames))

    # One placeholder string per utterance: start marker, N patch tokens, end marker.
    speech_placeholders = [
        DEFAULT_SPEECH_START_TOKEN + int(n_patches) * DEFAULT_SPEECH_PATCH_TOKEN + DEFAULT_SPEECH_END_TOKEN
        for n_patches in speech_lens
    ]

    # Fresh conversation: the speech placeholder as the first role's turn and an
    # open (None) turn for the second role.
    conv = conversation_lib.default_conversation.copy()
    conv.messages = []
    conv.append_message(conv.roles[0], speech_placeholders[0])
    conv.append_message(conv.roles[1], None)

    prompt_ids = tokenizer([conv.get_prompt()]).input_ids[0]
    input_ids_tensor = torch.as_tensor([prompt_ids + states.target_ids])
    # Reset the flag so the model does not treat speech features as already extracted.
    model.model.speech_features_extracted = False

    with torch.inference_mode():
        # No attention mask is passed (single unpadded sequence).
        return model.model(
            attention_mask=None,
            input_ids=input_ids_tensor,
            speech_batch=speech_batch,
            src_lengths=n_frames.to(device=model.device),
            after_lens=speech_lens.to(device=model.device),
        )

# Two non-incremental agent states: states1 holds 25600 random samples and
# states2 holds the same audio followed by 5120 more.
states1, states2 = S2TAgentStates([]), S2TAgentStates([])
for pending in (states1, states2):
    pending.source_finished = False
    pending.source_sample_rate = 16000

states1.source = np.random.rand(25600).tolist()
states2.source = states1.source + np.random.rand(5120).tolist()
# A third prefix (states2.source plus another 5120 random samples) can be built
# the same way and run through process() as o3 if needed.

# Reference outputs computed from scratch for each prefix.
o1 = process(states1)
o2 = process(states2)

In [None]:
def incremental_process(states):
    """
    Run the speech-LLaMA backbone on ``states`` with caching enabled.

    Same prompt construction as a from-scratch pass, but the model additionally
    receives ``states`` and ``use_cache=True``, plus an attention mask that
    marks every non-pad token.

    Args:
        states: Incremental agent state; reads ``states.source`` and
            ``states.target_ids`` and is handed to the model as ``states=``.

    Returns:
        The value returned by ``model.model(...)``.
    """
    waveform = torch.tensor(states.source).to(device=model.device, dtype=model.dtype)
    speech_batch = _collate_frames([waveform], is_audio_input=True)
    n_frames = torch.tensor([waveform.size(0)], dtype=torch.long)
    # Length of the speech sequence after the SSL encoder and then the length adapter.
    speech_lens = length_after_adp(length_after_ssl(n_frames))

    # One placeholder string per utterance: start marker, N patch tokens, end marker.
    speech_placeholders = [
        DEFAULT_SPEECH_START_TOKEN + int(n_patches) * DEFAULT_SPEECH_PATCH_TOKEN + DEFAULT_SPEECH_END_TOKEN
        for n_patches in speech_lens
    ]

    conv = conversation_lib.default_conversation.copy()
    conv.messages = []
    conv.append_message(conv.roles[0], speech_placeholders[0])
    conv.append_message(conv.roles[1], None)

    prompt_ids = tokenizer([conv.get_prompt()]).input_ids[0]
    input_ids_tensor = torch.as_tensor([prompt_ids + states.target_ids])
    model.model.speech_features_extracted = False

    # Debugging alternative (disabled): call
    #   uni_w2v2_extract_features(model.model.speech_tower, speech_batch, None,
    #       past_key_values=states.w2v2_past_key_values,
    #       past_features=states.w2v2_past_features)
    # directly, store output["x"] in states.w2v2_past_features, and afterwards
    # set states.num_frames_read = len(states.source).
    with torch.inference_mode():
        return model.model(
            attention_mask=input_ids_tensor.ne(tokenizer.pad_token_id),
            input_ids=input_ids_tensor,
            speech_batch=speech_batch,
            src_lengths=n_frames.to(device=model.device),
            after_lens=speech_lens.to(device=model.device),
            states=states,
            use_cache=True
        )

In [None]:
# Switch the monkey patch to its decode variant before running incrementally.
replace_uni_decode()
# NOTE(review): the meaning of the six positional arguments is not visible here --
# confirm against IncrementalS2TAgentStates.__init__.
inc_states = IncrementalS2TAgentStates([], [], None, None, -1, 0)
inc_states.source_sample_rate = 16000
inc_states.source_finished = False
# One empty cache dict per speech-tower encoder layer.
inc_states.w2v2_past_key_values = [
    {} for _ in range(model.model.speech_tower.cfg.encoder_layers)
]
# First step: same audio as states1.
inc_states.source = states1.source
io1 = incremental_process(inc_states)

In [None]:
# Second step: feed the longer audio (states2) through the SAME inc_states, so
# whatever the first call stored on it is available to this call.
inc_states.source = states2.source
io2 = incremental_process(inc_states)

# Check speech encoder trained using waco

In [1]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import lightning as L
import transformers
from tqdm.notebook import tqdm
from model.model import SpeechEncoder, SpeechLlamaForCausalLM
from train.dataset import PromptSpeechToTextDatasetCreator, SpeechSampler
from train.stage0 import DataCollatorForSupervisedDataset

[2024-02-13 15:42:31,265] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# GPU that both encoders and the batch tensors in compute_sim are moved to.
device = 'cuda:0'

In [3]:
# Load the LLM on CPU only to copy its token-embedding table, then drop the
# model reference; only llm_embedding is kept.
llm = SpeechLlamaForCausalLM.from_pretrained(
    '/mnt/taurus/data/xixu/llm/llama-2-7b/hf',
    low_cpu_mem_usage=True,
    load_in_8bit=False,
    device_map='cpu',
)
llm_embedding = copy.deepcopy(llm.model.embed_tokens)
del llm

You are using a model of type llama to instantiate a model of type SpeechLlama. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Slow (use_fast=False) tokenizer with right-side padding.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    '/mnt/taurus/data/xixu/llm/llama-2-7b/hf',
    padding_side="right",
    use_fast=False,
)
# Reuse EOS as the padding token (presumably because the tokenizer ships
# without a pad token -- TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

In [5]:
w2v2_path = '/mnt/taurus/data/xixu/models/wav2_vec_vox_960h_pl.pt'
# The two encoders are built with identical arguments except the 6th positional
# one (True for uni_enc, False for bi_enc) -- presumably the unidirectional
# switch; TODO confirm against SpeechEncoder.__init__.
uni_enc = SpeechEncoder(w2v2_path, True, 1024, '3,3', llm_embedding, True, 0.2, 1e-4, 25000)
bi_enc = SpeechEncoder(w2v2_path, True, 1024, '3,3', llm_embedding, False, 0.2, 1e-4, 25000)

/mnt/taurus/home/siqiouyang/anaconda3/envs/sllama_lightning/lib/python3.8/site-packages/torch/nn/utils/weight_norm.py:30: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.


In [6]:
# Restore uni_enc from the 'state_dict' entry of the uni/stage0 checkpoint (loaded on CPU).
ckpt = torch.load('/mnt/taurus/data1/siqiouyang/runs/sllama/en-es/7b/uni/stage0/epoch=22-step=36000.ckpt', map_location='cpu')
uni_enc.load_state_dict(ckpt['state_dict'])

<All keys matched successfully>

In [7]:
# Restore bi_enc from the 'state_dict' entry of the bi/stage0-bi checkpoint;
# note that this rebinds `ckpt`.
ckpt = torch.load('/mnt/taurus/data1/siqiouyang/runs/sllama/en-es/7b/bi/stage0-bi/epoch=23-step=38000.ckpt', map_location='cpu')
bi_enc.load_state_dict(ckpt['state_dict'])

<All keys matched successfully>

In [8]:
# Move both encoders to `device`; the cell output is the repr of the last expression (bi_enc).
uni_enc.to(device)
bi_enc.to(device)

SpeechEncoder(
  (speech_tower): Wav2Vec2Model(
    (feature_extractor): ConvFeatureExtractionModel(
      (conv_layers): ModuleList(
        (0): Sequential(
          (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (1): Dropout(p=0.0, inplace=False)
          (2): Sequential(
            (0): TransposeLast()
            (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (2): TransposeLast()
          )
          (3): GELU(approximate='none')
        )
        (1-4): 4 x Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (1): Dropout(p=0.0, inplace=False)
          (2): Sequential(
            (0): TransposeLast()
            (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (2): TransposeLast()
          )
          (3): GELU(approximate='none')
        )
        (5-6): 2 x Sequential(
          (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
          (1): Dropout(p=0.0, inplace=Fal

In [9]:
# 'dev_mfa' split of MuST-C en-es (the 'mfa' suffix presumably marks forced
# word alignments -- TODO confirm).
dataset = PromptSpeechToTextDatasetCreator.from_tsv('/mnt/taurus/data/xixu/datasets/must-c-v1.0/en-es', 'dev_mfa')
# The collator is given the uni encoder's length functions; the same batches are also fed to bi_enc.
data_collator = DataCollatorForSupervisedDataset(tokenizer, uni_enc.length_after_ssl, uni_enc.length_after_adp)

In [10]:
# batch_size=1: the evaluation loops read batch["speech_word"][0] and use the
# length tensor directly as a condition.
dataloader = DataLoader(dataset, batch_size=1, collate_fn=data_collator)

In [31]:
def compute_sim(enc, batch, only_left=False, only_right=False):
    """
    Cosine similarity between word-level speech and text embeddings.

    For every aligned word, the speech-encoder outputs inside the word's span
    and the LLM token embeddings inside the word's token span are mean-pooled,
    and the two pooled vectors are compared.

    Args:
        enc: Encoder exposing ``llm_embedding`` and ``get_ssl_feature_w2v``.
        batch (dict): Reads ``src_text``, ``src_speech``, ``src_speech_lengths``,
            ``after_speech_lengths``, ``text_word`` and ``speech_word``. The
            ``*_word`` entries hold, per example, rows of inclusive
            ``(start, end)`` index pairs, or ``None`` for ``speech_word`` when
            the example has no alignment.
        only_left (bool): Pool only from the span start to its midpoint (inclusive).
        only_right (bool): Pool only from the span midpoint to its end (inclusive).
            Ignored when ``only_left`` is also set.

    Returns:
        torch.Tensor: 1-D tensor with one similarity per aligned word; an empty
        tensor when the batch contains no aligned word.
    """
    src_text = batch["src_text"].to(device)
    src_speech = batch["src_speech"].to(device)
    src_speech_lengths = batch["src_speech_lengths"]
    after_speech_lengths = batch["after_speech_lengths"]
    text_word = batch["text_word"]
    speech_word = batch["speech_word"]

    src_text_emb = enc.llm_embedding(src_text).float()
    # transpose(0, 1) so that the first index selects the example, matching src_text_emb[i] below.
    src_speech_emb = enc.get_ssl_feature_w2v(
        src_speech, src_speech_lengths, after_speech_lengths
    ).transpose(0, 1).float()

    speech_word_emb = []
    text_word_emb = []
    for i in range(len(text_word)):
        s_word, t_word = speech_word[i], text_word[i]
        if s_word is None:
            continue
        for j in range(s_word.size(0)):
            s_l, s_r = s_word[j]
            t_l, t_r = t_word[j]
            # Choose the (inclusive) speech span to pool over.
            if only_left:
                lo, hi = s_l, (s_l + s_r) // 2
            elif only_right:
                lo, hi = (s_l + s_r) // 2, s_r
            else:
                lo, hi = s_l, s_r
            speech_word_emb.append(src_speech_emb[i][lo : hi + 1].mean(dim=0))
            text_word_emb.append(src_text_emb[i][t_l : t_r + 1].mean(dim=0))

    # torch.stack raises on an empty list; report "no words" as an empty result instead.
    if not speech_word_emb:
        return src_speech_emb.new_empty(0)

    return F.cosine_similarity(
        torch.stack(speech_word_emb, dim=0),
        torch.stack(text_word_emb, dim=0),
        dim=-1,
    )

In [32]:
# Word-level speech/text similarities for the unidirectional encoder:
# full word span, left half only, right half only.
uni_st_sims, uni_st_sims_left, uni_st_sims_right = [], [], []
with torch.inference_mode():
    for batch in tqdm(dataloader):
        # Skip utterances without word alignment or with a speech length below 5120.
        if batch["speech_word"][0] is None or not (batch["src_speech_lengths"] >= 5120):
            continue
        for sims, span_kwargs in (
            (uni_st_sims, {}),
            (uni_st_sims_left, {"only_left": True}),
            (uni_st_sims_right, {"only_right": True}),
        ):
            sims.append(compute_sim(uni_enc, batch, **span_kwargs))
        torch.cuda.empty_cache()

  0%|          | 0/1312 [00:00<?, ?it/s]

In [33]:
# Word-level speech/text similarities for the bidirectional encoder:
# full word span, left half only, right half only.
bi_st_sims, bi_st_sims_left, bi_st_sims_right = [], [], []
with torch.inference_mode():
    for batch in tqdm(dataloader):
        # Skip utterances without word alignment or with a speech length below 5120.
        if batch["speech_word"][0] is None or not (batch["src_speech_lengths"] >= 5120):
            continue
        for sims, span_kwargs in (
            (bi_st_sims, {}),
            (bi_st_sims_left, {"only_left": True}),
            (bi_st_sims_right, {"only_right": True}),
        ):
            sims.append(compute_sim(bi_enc, batch, **span_kwargs))
        torch.cuda.empty_cache()

  0%|          | 0/1312 [00:00<?, ?it/s]

In [34]:
# Flatten each per-utterance list into a single 1-D tensor. The names are
# rebound from lists to tensors, so this cell is meant to run once per loop run.
uni_st_sims, uni_st_sims_left, uni_st_sims_right = (
    torch.cat(chunks, dim=0)
    for chunks in (uni_st_sims, uni_st_sims_left, uni_st_sims_right)
)

In [35]:
# Flatten each per-utterance list into a single 1-D tensor. The names are
# rebound from lists to tensors, so this cell is meant to run once per loop run.
bi_st_sims, bi_st_sims_left, bi_st_sims_right = (
    torch.cat(chunks, dim=0)
    for chunks in (bi_st_sims, bi_st_sims_left, bi_st_sims_right)
)

In [36]:
# Mean similarity over full word spans: (unidirectional, bidirectional).
uni_st_sims.mean(), bi_st_sims.mean()

(tensor(0.6815, device='cuda:0'), tensor(0.7011, device='cuda:0'))

In [37]:
# Mean similarity using only the left half of each word's speech span: (unidirectional, bidirectional).
uni_st_sims_left.mean(), bi_st_sims_left.mean()

(tensor(0.6180, device='cuda:0'), tensor(0.4562, device='cuda:0'))

In [38]:
# Mean similarity using only the right half of each word's speech span: (unidirectional, bidirectional).
uni_st_sims_right.mean(), bi_st_sims_right.mean()

(tensor(0.6327, device='cuda:0'), tensor(0.7230, device='cuda:0'))