### Train a ALM in Google Colab!

### Clone the repository if you don't have it already

In [1]:
import os

if not os.path.isdir('nanoALM'):
    !git clone https://github.com/LWL220184016/nanoVLM_From_Huggingface.git
%cd nanoVLM_From_Huggingface/
!ls

fatal: destination path 'nanoVLM_From_Huggingface' already exists and is not an empty directory.
/content/nanoVLM_From_Huggingface
assets			checkpoints  generate.py      nanoALM.ipynb  test
benchmark-inference.py	data	     measure_vram.py  old
benchmark_suite.py	debug	     models	      README.md


### Imports and Setup

In [2]:
# Let's authentificate with the Hugging Face Hub so you can push your model
# from huggingface_hub import notebook_login
# notebook_login()
# !huggingface-cli login


In [3]:
import os
from google.colab import drive
drive.mount('/content/drive')

check_dir = '/content/stage1'
source_dir = '/content/drive/MyDrive/nanoALM/7/stage1'

# 檢查來源資料夾是否存在且非空
if os.path.isdir(check_dir) and os.listdir(check_dir):
    print(f"偵測到檔案存在於 '{check_dir}'。將停止複製檔案和後續 pip 安裝。")
else:
    print(f"'{check_dir}' 是空的或不存在。將複製檔案並繼續執行 pip 安裝。")
    !cp -r {source_dir} /content

    # If you get an "Error" from pip's dependency resolver but the cell complets fine, this is not an issue, you can continue :)
    !pip -q install torch
    !pip -q install gcsfs
    !pip -q install tqdm
    !pip -q install huggingface_hub
    !pip -q install librosa
    !pip install soundfile librosa -q
    # !pip install --upgrade transformers
    !pip install datasets==3.6.0

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
偵測到檔案存在於 '/content/stage1'。將停止複製檔案和後續 pip 安裝。


In [4]:
# Decide on the name of your model here!
# You will need your HF user name and the name you want to give to it
# For me, this would be "lusxvr/nanoALM"
# hf_model_name = "YOUR_HF_USER_NAME/nanoALM"

In [5]:
# nanoALM Imports (please check out the implementations in detail, that's where all the interessting stuff is!)
from data.collators import AlignmentCollator, AudioQACollator
from data.datasets import SAVEEDataset, AudioQADataset
from data.processors import get_audio_processor
from data.processors import get_tokenizer
from models.audio_language_model import AudioLanguageModel
import models.utils as utils

# Libraries
import math
import time
import torch

from tqdm import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt
from dataclasses import dataclass
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets
#Otherwise, the tokenizer will through a warning
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

torch.autograd.set_detect_anomaly(True)

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# To reload the modules if you change something in the code
%reload_ext autoreload
%autoreload 2

Using device: cuda


## Functions

### Get the dataloaders

In [6]:
def get_dataloaders(train_cfg, alm_cfg, tokenizer):
    # Create datasets
    audio_processor = get_audio_processor(alm_cfg)

    # text = "splitting datasets, disable in get_dataloaders function"
    # print(f"\n\033[38;5;05m{text}05m\033[0m")
    # Load and combine all training datasets
    combined_train_data = []
    for dataset_name in train_cfg.train_dataset_name:
        train_ds = load_dataset(
        path = train_cfg.train_dataset_path,
        name = dataset_name,
    )
        combined_train_data.append(train_ds['train'])
    train_ds = concatenate_datasets(combined_train_data)

    test_ds = load_dataset(train_cfg.test_dataset_path)
    train_ds = train_ds.shuffle(seed=0) # Shuffle the training dataset, so train and val get equal contributions from all concatinated datasets

    # Apply cutoff if specified
    if train_cfg.data_cutoff_idx is None:
        total_samples = len(train_ds)  # Use the entire dataset
    else:
        total_samples = min(len(train_ds), train_cfg.data_cutoff_idx)

    val_size = int(total_samples * train_cfg.val_ratio)
    train_size = total_samples - val_size

    train_dataset = AudioQADataset(train_ds.select(range(train_size)), tokenizer, audio_processor)
    val_dataset = AudioQADataset(train_ds.select(range(train_size, total_samples)), tokenizer, audio_processor)
    test_dataset = SAVEEDataset(test_ds, tokenizer, audio_processor)

    # Create collators
    alignment_collator = AlignmentCollator(tokenizer, alm_cfg.lm_max_length, audio_processor)
    aqa_collator = AudioQACollator(tokenizer)
    savee_collator = AudioQACollator(tokenizer)

    # Create dataloaders
    alignment_train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        collate_fn=alignment_collator,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        collate_fn=aqa_collator,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        collate_fn=aqa_collator,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=train_cfg.savee_batch_size,
        shuffle=False,
        collate_fn=savee_collator,
        pin_memory=True,
        )

    return alignment_train_loader, train_loader, val_loader, test_loader

### Prepare the testing function

In [7]:
def test_savee(model, tokenizer, test_loader, device):
    total_examples = 0
    correct_predictions = 0
    with torch.no_grad():
        for batch in test_loader:
            audio = batch['audio'].to(device)
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            correct_answer = tokenizer.batch_decode(labels, skip_special_tokens=True)

            gen = model.generate(input_ids, audio, attention_mask)
            model_output = tokenizer.batch_decode(gen, skip_special_tokens=True)

            is_correct = utils.check_multiple_choice_with_regex(model_output, correct_answer)

            total_examples += len(is_correct)
            if is_correct:
                correct_predictions += sum(is_correct)
    accuracy = correct_predictions / total_examples if total_examples > 0 else 0
    return accuracy

def get_avg_alignment(model, val_loader, device, epoch):
    """
    Validate the model's audio-text alignment on the validation set.
    This function computes the average alignment score over the validation set.
    It runs for a maximum of 20 batches to save time during training.
    """
    model.eval()
    total_alignment_score = 0

    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i >= 20:  # 只驗證前20個batch以節省時間
                break
            audios = batch["audio"].to(device)
            input_ids = batch["input_ids"].to(device)
            alignment_score = model.validate_audio_text_alignment(input_ids, audios)
            total_alignment_score += alignment_score

    avg_alignment = total_alignment_score / min(20, len(val_loader))
    print(f"Epoch {epoch+1}: Average alignment score: {avg_alignment:.4f}")

    print(" ")
    model.train()
    return avg_alignment

### Prepare the training loop

#### Three-stage training (contrast training, generative training, instruction fine-tuning) 三段式訓練(對比訓練, 生成式訓練, 指令微調)

In [8]:
import torch.nn.functional as F
import torch.nn as nn
import torch.amp as GradScaler

from debug.debug_func import debug_contrastive_learning

# 改進對比學習訓練
def get_lr(it, max_lr, max_steps):
    min_lr = max_lr * 0.1
    warmup_steps = max_steps * 0.03
    # 1) linear warmup for warmup_iters steps
    if it < warmup_steps:
        return max_lr * (it+1) / warmup_steps
    # 2) if it > lr_decay_iters, return min learning rate
    if it > max_steps:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
    return min_lr + coeff * (max_lr - min_lr)

def get_contrastive_loss(audio_embeds, text_embeds, temperature=0.07):
    """
    標準、高效的對比學習損失 (CLIP Loss)。
    注意：輸入的 embeds 應該是池化後的 [B, D] 維度向量。
    """
    # --- 開始計算尺度對齊損失 ---

    # 1. 計算每個向量的 L2 範數 (沿著最後一個維度)
    #    detach() 是為了確保這個輔助損失的梯度只流向投影器，而不影響上游的編碼器或文字嵌入層的穩定性（可選但推薦）
    audio_norms = torch.norm(audio_embeds, p=2, dim=-1)
    text_norms = torch.norm(text_embeds, p=2, dim=-1)

    # 2. 計算批次內的平均範數
    mean_audio_norm = torch.mean(audio_norms)
    mean_text_norm = torch.mean(text_norms)

    # 3. 計算尺度對齊損失 (使用 MSE)
    #    目標是讓 mean_audio_norm 和 mean_text_norm 盡可能接近
    scale_loss = F.mse_loss(mean_audio_norm, mean_text_norm)

    # --- 結束計算尺度對齊損失 ---

    # --- 開始計算交叉熵損失 ---
    # 歸一化
    audio_embeds = F.normalize(audio_embeds, p=2, dim=-1)
    text_embeds = F.normalize(text_embeds, p=2, dim=-1)

    # 計算相似度矩陣
    # temperature 是一個重要的超參數，CLIP 論文中是可學習的，但固定值也可以
    logits_per_audio = torch.matmul(audio_embeds, text_embeds.T) / temperature
    logits_per_text = logits_per_audio.T

    # 創建標籤 (0, 1, 2, ..., B-1)
    labels = torch.arange(audio_embeds.shape[0]).to(logits_per_audio.device)

    # 對稱的交叉熵損失
    loss_a = F.cross_entropy(logits_per_audio, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)

    contrastive_loss = (loss_a + loss_t) / 2

    # --- 結束計算交叉熵損失 ---


    # 4. 組合損失
    #    lambda_scale 是一個需要調整的超參數，用來平衡兩個損失的權重
    lambda_scale = 0.001  # 範例值，可以從 0.01, 0.1, 1.0 等開始嘗試
    total_loss = contrastive_loss + lambda_scale * scale_loss

    # 監控指標 (可選但推薦)
    with torch.no_grad():
        pos_sim = torch.diagonal(logits_per_audio * temperature).mean()
        mask = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
        neg_sim = (logits_per_audio * temperature)[mask].mean()

    return total_loss, contrastive_loss, scale_loss, {
        "loss": total_loss.item(),
        "pos_sim": pos_sim.item(), # 正樣本對的餘弦相似度
        "neg_sim": neg_sim.item()  # 負樣本對的餘弦相似度
    }

def train_step1_alignment(train_cfg, alm_cfg, model=None, tokenizer=None, device=None):
    # 凍結音頻編碼器和語言模型
    model.audio_encoder.audio_encoder.requires_grad_(False)
    model.decoder.requires_grad_(False)
    model.MP.requires_grad_(True)

    alignment_train_loader, _, val_loader, _ = get_dataloaders(train_cfg, alm_cfg, tokenizer)

    optimizer = optim.AdamW(model.MP.parameters(), lr=train_cfg.lr_mp, weight_decay=0.01)

    best_alignment = 0

    for epoch in range(train_cfg.stage1_epochs):
        model.train()
        total_contrastive_loss = 0  # 添加這個變數初始化
        total_scale_loss = 0  # 添加這個變數初始化

        for batch in tqdm(alignment_train_loader, desc=f"Stage1 Epoch {epoch+1}"):
            audios = batch["audio"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            optimizer.zero_grad()

            # 1. 音頻編碼 -> 投影
            with torch.no_grad():
                audio_features = model.audio_encoder.audio_encoder(audios, output_hidden_states=True)
            projected_audio_features = model.MP(audio_features.last_hidden_state)

            # 2. 文本編碼 - 修復這裡的問題
            with torch.no_grad():
                # 檢查 decoder 的 forward 方法簽名
                # 根據 language_model.py，應該傳入 x 而不是分別的參數
                text_embeds = model.decoder.token_embedding(input_ids)  # 直接獲取文本嵌入

                # 如果需要通過完整的 decoder，使用以下方式：
                # text_outputs, _ = model.decoder(text_embeds, attention_mask=attention_mask)
                # text_embeds = text_outputs  # 使用輸出的嵌入

            # 3. 池化操作 (Pooling)
            # 音頻池化
            audio_pooled = projected_audio_features.mean(dim=1)  # [B, D]

            # 文本池化 - 修復維度問題
            # text_embeds 現在是 [B, seq_len, hidden_dim]
            if attention_mask is not None:
                # 根據 attention_mask 來安全地做平均池化
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_embeds.size()).float()
                sum_embeddings = torch.sum(text_embeds * input_mask_expanded, 1)
                sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                text_pooled = sum_embeddings / sum_mask  # [B, D]
            else:
                text_pooled = text_embeds.mean(dim=1)  # [B, D]

            # 如果維度仍然不匹配，添加投影層
            if audio_pooled.shape[-1] != text_pooled.shape[-1]:
                # 創建一個投影層來匹配維度
                if not hasattr(model, 'text_projection'):
                    model.text_projection = nn.Linear(text_pooled.shape[-1], audio_pooled.shape[-1]).to(device)
                text_pooled = model.text_projection(text_pooled)

            # 4. 計算對比損失
            loss, contrastive_loss, scale_loss, metrics = get_contrastive_loss(audio_pooled, text_pooled)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.MP.parameters(), max_norm=1.0)
            optimizer.step()

            total_contrastive_loss += contrastive_loss.item()
            total_scale_loss += scale_loss.item()

        avg_contrastive_loss = total_contrastive_loss / len(alignment_train_loader)
        avg_scale_loss = total_scale_loss / len(alignment_train_loader)
        print(f"\nEpoch {epoch+1}: Total Loss {loss:.4f}, Contrastive Loss {avg_contrastive_loss:.4f}, Scale Loss {avg_scale_loss:.4f}")

        avg_alignment = get_avg_alignment(model, val_loader, device, epoch)

        if avg_alignment > best_alignment:
            best_alignment = avg_alignment
            model.save_pretrained(save_directory=f"{alm_cfg.alm_checkpoint_path}/stage1_best")
            print(f"  New best alignment: {best_alignment:.4f}")

    print(f"Stage 1 completed! Best alignment: {best_alignment:.4f}")
    return model

def train_step2_pretraining(train_cfg, alm_cfg, stage1_model=None, tokenizer=None, device=None):
    print("=== Stage 2: Language Model Pretraining ===")

    # 使用傳入的 tokenizer，不要重建
    _, train_loader, val_loader, test_loader = get_dataloaders(train_cfg, alm_cfg, tokenizer)

    model = stage1_model
    # 冻结/解冻
    model.audio_encoder.audio_encoder.requires_grad_(False)
    model.MP.requires_grad_(True)
    model.decoder.requires_grad_(True)  # 或僅解凍頂層幾層

    # 調小 decoder LR
    param_groups = [
        {'params': [p for p in model.MP.parameters() if p.requires_grad], 'lr': train_cfg.lr_mp * 0.05},
        {'params': [p for p in model.decoder.parameters() if p.requires_grad], 'lr': train_cfg.lr_backbones},
    ]
    optimizer = optim.AdamW(param_groups, weight_decay=0.01)
    scaler = torch.amp.GradScaler(device=device)

    batch_losses = []
    best_loss = float('inf')
    global_step = 0
    total_steps = len(train_loader) * train_cfg.stage2_epochs

    for epoch in range(train_cfg.stage2_epochs):
        model.train()
        total_train_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Stage2 Epoch {epoch+1}"):
            audios = batch["audio"].to(device)
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            optimizer.zero_grad(set_to_none=True)

            if alm_cfg.dtype == torch.float16:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    _, loss = model(input_ids, audios, attention_mask=attention_mask, labels=labels)
            else:
                _, loss = model(input_ids, audios, attention_mask=attention_mask, labels=labels)

            scaler.scale(loss).backward()

            # 調整 LR
            optimizer.param_groups[0]['lr'] = get_lr(global_step, param_groups[0]['lr'], total_steps)
            optimizer.param_groups[1]['lr'] = get_lr(global_step, param_groups[1]['lr'], total_steps)

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [p for p in model.parameters() if p.requires_grad and p.grad is not None],
                max_norm=1.0
            )
            scaler.step(optimizer)
            scaler.update()

            batch_loss = loss.item()
            total_train_loss += batch_loss
            batch_losses.append(batch_loss)
            global_step += 1

        avg_train_loss = total_train_loss / len(train_loader)
        if avg_train_loss < best_loss:
            best_loss = avg_train_loss
            model.save_pretrained(save_directory=f"{alm_cfg.alm_checkpoint_path}/stage2_best")

        avg_alignment = get_avg_alignment(model, val_loader, device, epoch)
        print(f"\nEpoch {epoch+1}/{train_cfg.stage2_epochs} | Loss: {avg_train_loss:.4f} | Alignment: {avg_alignment:.4f}")

    model.save_pretrained(save_directory=f"{alm_cfg.alm_checkpoint_path}/stage2_final")
    print("Stage 2 completed!")
    plt.plot(batch_losses, label='Train Loss')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.grid(True)
    plt.legend()
    plt.show()

    return model

def train_step3_instruction_tuning(train_cfg, alm_cfg, stage2_model=None, tokenizer=None, device=None):
    """第三步：指令微调"""
    print("=== Stage 3: Instruction Tuning ===")

    _, train_loader, val_loader, test_loader = get_dataloaders(train_cfg, alm_cfg, tokenizer)
    scaler = torch.amp.GradScaler(device=device)

    model = stage2_model
    # 全部解冻，使用较小学习率
    for param in model.parameters():
        param.requires_grad = True

    print(f"Stage 3: Training all {sum(p.numel() for p in model.parameters()):,} parameters")

    # 更小的学习率
    param_groups = [
        {'params': model.MP.parameters(), 'lr': train_cfg.lr_mp * 0.01},
        {'params': model.decoder.parameters(), 'lr': train_cfg.lr_backbones * 0.1},
        {'params': model.audio_encoder.audio_encoder.parameters(), 'lr': train_cfg.lr_backbones * 0.01}
    ]
    optimizer = optim.AdamW(param_groups)

    if train_cfg.compile:
        model = torch.compile(model)

    # 这里可以使用原来的训练循环，但数据应该是指令格式
    # 暂时使用相同的数据格式
    best_accuracy = 0
    global_step = 0

    for epoch in range(train_cfg.stage3_epochs):
        model.train()
        total_train_loss = 0

        for batch in tqdm(train_loader, desc=f"Stage3 Epoch {epoch+1}"):
            audios = batch["audio"].to(device)
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            optimizer.zero_grad()

            if alm_cfg.dtype == torch.float16:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    _, loss = model(input_ids, audios, attention_mask=attention_mask, labels=labels)
            else:
                _, loss = model(input_ids, audios, attention_mask=attention_mask, labels=labels)


            scaler.unscale_(optimizer)
            # 2. 對 unscale 後的梯度進行裁剪 (max_norm=1.0 是一個常用的值)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            # --- 修改結束 ---

            # 3. Scaler 執行優化器步驟 (如果梯度沒有 inf/nan)
            scaler.step(optimizer)
            # 4. 更新 scaler 的縮放因子
            scaler.update()

            batch_loss = loss.item()
            total_train_loss += batch_loss

            if global_step % 50 == 0:
                print(f"Stage3 Step: {global_step}, Instruction Loss: {batch_loss:.4f}")

            global_step += 1

        avg_train_loss = total_train_loss / len(train_loader)

        # 评估性能
        if train_cfg.eval_in_epochs:
            accuracy = test_savee(model, tokenizer, test_loader, device)
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                model.save_pretrained(save_directory=f"{alm_cfg.alm_checkpoint_path}/stage3_best")
            print(f"Stage3 Epoch {epoch+1}/{train_cfg.stage3_epochs} | Loss: {avg_train_loss:.4f} | Accuracy: {accuracy:.4f}")
        else:
            print(f"Stage3 Epoch {epoch+1}/{train_cfg.stage3_epochs} | Instruction Loss: {avg_train_loss:.4f}")

    # 保存最终模型
    model.save_pretrained(save_directory=f"{alm_cfg.alm_checkpoint_path}/final_model")
    print("Stage 3 completed!")
    return model

def train_three_stages(train_cfg, alm_cfg, device=None):
    """完整的三阶段训练"""
    print("Starting Three-Stage Training Pipeline")

    # 第一阶段：模态投影器对齐
    stage1_model = train_step1_alignment(train_cfg, alm_cfg, device=device)

    # 第二阶段：语言模型预训练
    stage2_model = train_step2_pretraining(train_cfg, alm_cfg, stage1_model, device=device)

    # 第三阶段：指令微调
    final_model = train_step3_instruction_tuning(train_cfg, alm_cfg, stage2_model, device=device)

    print("=== Training Pipeline Completed! ===")
    return stage1_model, stage2_model, final_model


# # 替换原来的训练调用
# alm_cfg = ALMConfig()
# train_cfg = TrainConfig()

# # 运行三阶段训练
# final_model = train_three_stages(train_cfg, alm_cfg)

## Debug

In [9]:
# !python ./debug/debug_forward.py

## Lets run the training!

In [10]:
import os
from models.config import ALMConfig, TrainConfig

# 要創建的目錄路徑
dir_name = ALMConfig.alm_checkpoint_path

try:
    os.mkdir(dir_name)
    print(f"Directory '{dir_name}' created successfully.")
except FileExistsError:
    print(f"Directory '{dir_name}' already exists.")
except FileNotFoundError:
    print(f"Parent directory does not exist for '{dir_name}'.")
except Exception as e:
    print(f"An error occurred: {e}")

alm_cfg = ALMConfig()
train_cfg = TrainConfig()
dtype = alm_cfg.dtype
device = alm_cfg.device
print(f"Using device: {device}")


tokenizer = get_tokenizer(alm_cfg.lm_tokenizer)

# 創建一個帶有新 token 的模型實例
model = None
stage1_model = None

if train_cfg.resume_from_alm_checkpoint:
    checkpoint_path = "../stage1"
    # checkpoint_path = "./checkpoints/stage1_best"
    stage1_model = AudioLanguageModel.from_pretrained(checkpoint_path, tokenizer=tokenizer)

else:
    model = AudioLanguageModel(alm_cfg, load_from_HF=True, tokenizer=tokenizer)


Directory 'checkpoints' already exists.
Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


load model from local
Initializing empty audio encoder: openai/whisper-small.en


In [11]:
if model != None:
    stage1_model = train_step1_alignment(train_cfg, alm_cfg, model, tokenizer, device)
    stage1_model.save_pretrained("/content/")

In [12]:
stage2_model = train_step2_pretraining(train_cfg, alm_cfg, stage1_model, tokenizer, device)
stage2_model.save_pretrained("/content/")

=== Stage 2: Language Model Pretraining ===
AudioProcessor_from_HF initialized with model: <class 'transformers.models.whisper.processing_whisper.WhisperProcessor'>
  Target feature frames from cfg: 1500
  Using model sampling rate: 16000, hop_length: 160, n_fft: 400
  Calculated max raw audio samples for processor: 240240


Stage2 Epoch 1:   0%|          | 0/205 [00:00<?, ?it/s]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 54, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 78, 2048]), dtype=torch.float32, min=-4.2626, max=4.2608, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 78, 2048]), dtype=torch.float32, min=-54.5700, max=48.1333, mean=0.0052
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 54]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 34763,    28,
          1303,   549,   820,  1272,    30, 16912,    28,  6820,   617,   260,
         12541,   314,    47,   657,   506,   549,    30,     0,     2,   198,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47

Stage2 Epoch 1:   0%|          | 1/205 [00:03<11:48,  3.47s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 69, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 93, 2048]), dtype=torch.float32, min=-4.7663, max=4.9278, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 93, 2048]), dtype=torch.float32, min=-54.5700, max=48.7536, mean=0.0036
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 69]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6910,    28,
           346,   699,   787,  1542,  2166, 15411,   502,   359,  1035,  1165,
         13645,    30,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:   1%|          | 2/205 [00:04<07:28,  2.21s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 57, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 81, 2048]), dtype=torch.float32, min=-4.3181, max=4.4578, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 81, 2048]), dtype=torch.float32, min=-54.5700, max=48.4671, mean=0.0056
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 57]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   669,   314,
           915,   451,    30,   216,  1206,   699,    47,   216,   339,  1441,
           357,   506,   253,  2341,   284,  3117,   564,   357,  3247,   982,
         10105,  1745,    30,   216,  3315,   346,   699,   732,   339,  5248,
          3396,   288,  1643,    47,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281

Stage2 Epoch 1:   1%|▏         | 3/205 [00:06<06:02,  1.80s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 32, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 56, 2048]), dtype=torch.float32, min=-4.1223, max=4.2858, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 56, 2048]), dtype=torch.float32, min=-54.5700, max=49.1008, mean=0.0082
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 32]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,  2316,
           441,  2045,   702,   451,    30,     0,     2,   198,     0,     0,
             0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1073,   800,
            47,     0,     2,   198,     0,     0,     0,     0,     0,     0,
            

Stage2 Epoch 1:   2%|▏         | 4/205 [00:07<05:17,  1.58s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 50, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-4.5048, max=4.4251, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-54.5700, max=49.8333, mean=0.0067
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 50]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6910,    28,
          1206,   359,   588,   973, 13442,   515,    30,   216,  3315,   346,
          2853,    30,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198

Stage2 Epoch 1:   2%|▏         | 5/205 [00:08<04:56,  1.48s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 33, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 57, 2048]), dtype=torch.float32, min=-4.4336, max=4.5488, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 57, 2048]), dtype=torch.float32, min=-54.5700, max=50.1076, mean=0.0062
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 33]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812,    29,
          1812,  6688,  1250,   346,  1690,   281,   335,    47,     0,     2,
           198,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   732,    47,
             0,     2,   198,     0,     0,     0,     0,     0,     0,     0,
     

Stage2 Epoch 1:   3%|▎         | 6/205 [00:09<04:38,  1.40s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 82, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 106, 2048]), dtype=torch.float32, min=-4.5759, max=4.4522, mean=-0.0005
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 106, 2048]), dtype=torch.float32, min=-54.5700, max=49.4828, mean=0.0064
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 82]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,  3514,
           335, 22865,  1434,   327,  2009,    30,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     

Stage2 Epoch 1:   3%|▎         | 7/205 [00:11<04:34,  1.39s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 46, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-4.4812, max=4.2905, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-54.5700, max=48.5151, mean=0.0044
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 46]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812, 39527,
            47,     0,     2,   198,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198

Stage2 Epoch 1:   4%|▍         | 8/205 [00:12<04:26,  1.35s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 46, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-4.6086, max=4.4076, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-54.5700, max=49.3177, mean=0.0081
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 46]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6910,    28,
           957,  1839,    30,     0,     2,   198,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198

Stage2 Epoch 1:   4%|▍         | 9/205 [00:13<04:19,  1.32s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 61, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 85, 2048]), dtype=torch.float32, min=-4.9556, max=4.4992, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 85, 2048]), dtype=torch.float32, min=-54.5700, max=49.1751, mean=0.0088
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 61]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3009,  8320,
           699,    47,  9413,   282,   260,  3191, 12053,    30,     0,     2,
           198,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [    1,  40

Stage2 Epoch 1:   5%|▍         | 10/205 [00:15<04:16,  1.32s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 44, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 68, 2048]), dtype=torch.float32, min=-4.4558, max=4.3838, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 68, 2048]), dtype=torch.float32, min=-54.5700, max=48.9279, mean=0.0041
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 44]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6810,   732,
            47,     0,     2,   198,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,  1250

Stage2 Epoch 1:   5%|▌         | 11/205 [00:16<04:13,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 41, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-4.4119, max=4.4565, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-54.5700, max=48.4802, mean=0.0085
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 41]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3315,   346,
          1277,   288,   820,  7553,  1163,    47,    29,  1812,    47,   216,
          1812,    28,   536,   346,  1277,   253, 16532,    47,     0,     2,
           198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1644,   346,
          8540,   9

Stage2 Epoch 1:   6%|▌         | 12/205 [00:17<04:10,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 49, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 73, 2048]), dtype=torch.float32, min=-4.2790, max=4.5555, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 73, 2048]), dtype=torch.float32, min=-54.5700, max=48.1333, mean=0.0055
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 49]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 22297,   253,
          8427,    30,   216,   339,  3543,  3363,   357,    30,     0,     2,
           198,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1

Stage2 Epoch 1:   6%|▋         | 13/205 [00:18<04:08,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 28, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 52, 2048]), dtype=torch.float32, min=-4.6008, max=4.3760, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 52, 2048]), dtype=torch.float32, min=-54.5700, max=48.8435, mean=0.0063
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 28]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812,    47,
             0,     2,   198,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 10007,    30,
             0,     2,   198,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   45

Stage2 Epoch 1:   7%|▋         | 14/205 [00:20<04:04,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 78, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 102, 2048]), dtype=torch.float32, min=-4.6761, max=4.5160, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 102, 2048]), dtype=torch.float32, min=-54.5700, max=48.4132, mean=0.0041
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 78]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1326,
           982,  1176,   732,   346,   536,    30,   216,  4787,    28,   346,
           856,  7670,   469,  1248,  5681,  2654,   284,   284,  1658, 18557,
           738,   260,  9373,   717,   486,   465,   284,  1658,   351,   897,
           555,   281,   260,   905,   284,   339,   443,   276,   982,  1643,
           253,  2229,    30,   21

Stage2 Epoch 1:   7%|▋         | 15/205 [00:21<04:07,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 36, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 60, 2048]), dtype=torch.float32, min=-4.2638, max=4.6210, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 60, 2048]), dtype=torch.float32, min=-54.5700, max=50.0431, mean=0.0092
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 36]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,  2316,
           253,  2872,   284, 20891,  4166,    30,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1934,   500,
           355, 39523, 18871,    28, 25847, 25957, 25957, 39523

Stage2 Epoch 1:   8%|▊         | 16/205 [00:22<04:03,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 41, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-4.5694, max=4.2034, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-54.5700, max=50.6867, mean=0.0081
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 41]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1068,  1839,
           506, 15215,    28,  1296,   929,    47,   216, 32304,  2216,  1056,
           990,  1296,   929,    43,   338,   506, 38695,    30,     0,     2,
           198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  2242,  2141,
         11476,    

Stage2 Epoch 1:   8%|▊         | 17/205 [00:24<04:01,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 55, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 79, 2048]), dtype=torch.float32, min=-4.3632, max=4.4165, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 79, 2048]), dtype=torch.float32, min=-54.5700, max=48.9112, mean=0.0051
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 55]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   514,    88,
         24651,    28,  8707,    81,    28,  3590,   634,   282,   957, 21074,
            28,  9805,    28,   588,   585,   339,  3590,  5196,   253, 21120,
           338,   436,   281,   260,  9066,  7544,   284,   702,   253,  3631,
            29,   258,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389

Stage2 Epoch 1:   9%|▉         | 18/205 [00:25<04:00,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 48, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 72, 2048]), dtype=torch.float32, min=-4.3834, max=4.4539, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 72, 2048]), dtype=torch.float32, min=-54.5700, max=50.3546, mean=0.0054
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 48]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1073, 31404,
            30,  1073, 26963,    28, 26963, 31404,    30,     0,     2,   198,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520

Stage2 Epoch 1:   9%|▉         | 19/205 [00:26<03:59,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 46, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-4.4783, max=4.2886, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-54.5700, max=48.8796, mean=0.0005
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 46]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  5248,
         22657,    30,   216,   657,   506,   915,    28,   357,   506,   357,
           506,   915,  2698,   288,   549,    30,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198

Stage2 Epoch 1:  10%|▉         | 20/205 [00:27<03:57,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 36, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 60, 2048]), dtype=torch.float32, min=-4.4110, max=4.5488, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 60, 2048]), dtype=torch.float32, min=-54.5700, max=49.1094, mean=0.0050
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 36]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 10007,    30,
          9230,    28, 12102,    23,   332,    30,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 42679,    28,
          5872,    30, 20718,    30,     0,     2,   198,     0

Stage2 Epoch 1:  10%|█         | 21/205 [00:29<03:54,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 31, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 55, 2048]), dtype=torch.float32, min=-4.4113, max=4.4567, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 55, 2048]), dtype=torch.float32, min=-59.7188, max=48.3166, mean=0.0040
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 31]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,   699,
           338,  3627,   519,    28,   339,  2159,   536,    30,     0,     2,
           198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6910, 14119,
            28,   868,   392,   702,   338,    47,     0,     2,   198,     0,
             0],
  

Stage2 Epoch 1:  11%|█         | 22/205 [00:30<03:52,  1.27s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 43, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-4.3879, max=4.3231, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-54.5700, max=50.0291, mean=0.0070
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 43]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,    29,
          1206,   416,  1998,   874,   585,   346,  1277,    28,   564,  1041,
           506, 10229,   284,  1041,  3105,  3763,   982,  2988,   346,    30,
             0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   378,  7241,
     

Stage2 Epoch 1:  11%|█         | 23/205 [00:31<03:52,  1.27s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 37, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 61, 2048]), dtype=torch.float32, min=-4.5105, max=4.4439, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 61, 2048]), dtype=torch.float32, min=-54.5700, max=48.6157, mean=0.0040
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 37]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6910,   957,
           310,  8619,    30,   216,  1046, 13228,   601,  6757,    28,   338,
           506, 12765,  2042,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1644,   346,
           699,   732,   346,   359,  2045,   288,   536

Stage2 Epoch 1:  12%|█▏        | 24/205 [00:33<03:50,  1.27s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 34, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 58, 2048]), dtype=torch.float32, min=-4.2693, max=4.3801, mean=-0.0001
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 58, 2048]), dtype=torch.float32, min=-54.5700, max=49.1606, mean=0.0063
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 34]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  6820,
           346,   699,   451,   314,  1701,   339,  3413,   346,   288,  1690,
            30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,   699,
            30,   216,   339,   699,   357,   506,   441,  3506,    30,     0

Stage2 Epoch 1:  12%|█▏        | 25/205 [00:34<03:48,  1.27s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 40, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 64, 2048]), dtype=torch.float32, min=-4.6164, max=4.8775, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 64, 2048]), dtype=torch.float32, min=-54.5700, max=49.7080, mean=0.0068
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 40]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,    28,
           585,   346, 18918,   982,   719,   588, 19882,   284,  8166,   105,
           346,   736,  3543,  2093,  1343,  3534,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,   744,
           441,   260, 37068,  3241

Stage2 Epoch 1:  13%|█▎        | 26/205 [00:35<03:46,  1.26s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 58, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 82, 2048]), dtype=torch.float32, min=-4.3775, max=4.6233, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 82, 2048]), dtype=torch.float32, min=-54.5700, max=49.4150, mean=0.0081
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 58]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  2959,   506,
            29,  2959,   506,   915,  6081,   327,   253,  3847,    28, 12102,
            23,   332,    30,   216,  1206,  1044,    28, 12032,    28,   392,
          2316,  2045,   288,   536,  3117,   392,   416,   288,  1446,   511,
           282,   469, 47161,  1535,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137

Stage2 Epoch 1:  13%|█▎        | 27/205 [00:36<03:47,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 28, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 52, 2048]), dtype=torch.float32, min=-4.2940, max=4.6520, mean=-0.0001
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 52, 2048]), dtype=torch.float32, min=-54.5700, max=48.6690, mean=0.0089
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 28]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1326,
           982,   699,    30,     0,     2,   198,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1277,
           346,  1209,    28, 36779,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   45

Stage2 Epoch 1:  14%|█▎        | 28/205 [00:38<03:44,  1.27s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 69, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 93, 2048]), dtype=torch.float32, min=-4.3799, max=4.2565, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 93, 2048]), dtype=torch.float32, min=-54.5700, max=49.5764, mean=0.0051
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 69]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,    29,
           339,   416,  1188,   469,  8443,   506,  9768,   564,   339,   737,
          1372,   910,   282,  8927,   281,  1686,   288,   980,   451,  1941,
           746,    30,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:  14%|█▍        | 29/205 [00:39<03:46,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 76, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 100, 2048]), dtype=torch.float32, min=-4.5701, max=4.3922, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 100, 2048]), dtype=torch.float32, min=-54.5700, max=50.5453, mean=0.0054
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 76]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812,    47,
             0,     2,   198,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     

Stage2 Epoch 1:  15%|█▍        | 30/205 [00:40<03:47,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 45, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 69, 2048]), dtype=torch.float32, min=-4.3910, max=4.4055, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 69, 2048]), dtype=torch.float32, min=-54.5700, max=50.1226, mean=0.0085
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 45]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339, 21111,
           277,   346,    30,   216,   339, 21111,   277, 36779,    30,   216,
          1073,   986,    28,   638,   986,   339,  3543, 25693,   288, 21111,
           346,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  9230

Stage2 Epoch 1:  15%|█▌        | 31/205 [00:42<03:45,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 65, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 89, 2048]), dtype=torch.float32, min=-4.7198, max=4.7233, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 89, 2048]), dtype=torch.float32, min=-54.5700, max=48.5275, mean=0.0038
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 65]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  8114,   655,
            28,   346,  3060,   597,  2915,    28,   392,  3363,   653, 10582,
          5793,    28,   339,  3363,  4903,   281,   957, 23247,   370,    28,
           346,  3363,   253, 14501, 10772,   284,   392,  3363,   618,   253,
          2066,  3568,   284,   392,  1250,   441,   963,   260,  1665,   419,
           438,    30,     0,     2,

Stage2 Epoch 1:  16%|█▌        | 32/205 [00:43<03:45,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 36, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 60, 2048]), dtype=torch.float32, min=-4.5637, max=4.4330, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 60, 2048]), dtype=torch.float32, min=-55.4199, max=49.1829, mean=0.0055
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 36]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,    28,
          1206,   359,  1147,  1083,    28,   339,   416,   982,  5605,   357,
            30,     0,     2,   198,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   909,   506,
         31471,   563,  1272,  1163,    30,   216,   909,   506

Stage2 Epoch 1:  16%|█▌        | 33/205 [00:44<03:42,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 70, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 94, 2048]), dtype=torch.float32, min=-4.6037, max=4.3706, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 94, 2048]), dtype=torch.float32, min=-54.5700, max=49.6104, mean=0.0071
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 70]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,    29,
           339,  3060,  1928,   346,   253,  1838, 13942,    28,  1867,   655,
           585,   357, 11421,   585,   346,  2316,   335,   253, 33258,  5460,
           346,   699,    28,   915,  2951,   335,   260,  1761,  2489,   282,
           731,  2489,   282, 18303,   582,   355,   827,   288,  1643,   585,
           346,  2316,   702,   335,

Stage2 Epoch 1:  17%|█▋        | 34/205 [00:45<03:43,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 66, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 90, 2048]), dtype=torch.float32, min=-4.4329, max=4.6719, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 90, 2048]), dtype=torch.float32, min=-54.5700, max=49.9013, mean=0.0085
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 66]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 22096,    28,
         40782,    28, 40782,   346,  1365,   325, 22657,    30,     0,     2,
           198,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:  17%|█▋        | 35/205 [00:47<03:42,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 45, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 69, 2048]), dtype=torch.float32, min=-4.2699, max=4.4586, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 69, 2048]), dtype=torch.float32, min=-54.5700, max=48.9658, mean=0.0054
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 45]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1065,  2771,
            28,   965,    28,  2910,    30,   216,  9725,   392, 14577,  1488,
            47,   216,  2838,    30,   216,  1065,  2771,    28,   965,    28,
          2910,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 15465

Stage2 Epoch 1:  18%|█▊        | 36/205 [00:48<03:39,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 42, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 66, 2048]), dtype=torch.float32, min=-4.4666, max=4.5145, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 66, 2048]), dtype=torch.float32, min=-54.5700, max=49.9577, mean=0.0057
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 42]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,  6737,
          1811,   325,   588, 23486,    30,     0,     2,   198,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1032,  1869,
           3

Stage2 Epoch 1:  18%|█▊        | 37/205 [00:49<03:36,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 43, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-4.2821, max=4.5037, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-54.5700, max=49.0952, mean=0.0072
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 43]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 36779,    28,
         36779,    30,   216,   339,   744,  2045,   288,   919,   253, 18598,
           327,   346,    30,     0,     2,   198,     0,     0,     0,     0,
             0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 36727,    30,
     

Stage2 Epoch 1:  19%|█▊        | 38/205 [00:51<03:34,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 56, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 80, 2048]), dtype=torch.float32, min=-4.4576, max=4.4983, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 80, 2048]), dtype=torch.float32, min=-55.5934, max=50.3856, mean=0.0045
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 56]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  5713,  2009,
           339,  3060,   325,   281,  4653, 14105,   284,   638,   506,   357,
          2045,   288,   820,   288,   549,   281,  4653, 14105,    47,   216,
          1206,  2045,   288,   919,  2090,   357,  4364,   288,   260,  4653,
         14105, 17678,    47,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451

Stage2 Epoch 1:  19%|█▉        | 39/205 [00:52<03:33,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 47, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 71, 2048]), dtype=torch.float32, min=-4.3654, max=4.3425, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 71, 2048]), dtype=torch.float32, min=-55.4796, max=48.5990, mean=0.0050
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 47]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1046,  3543,
          3363,  2428,    28,   338,   506,  1701,   392,  4935,  1535,    28,
         24651,    47,   216,  4651,   665,   314,   253,  1194,  1176,  1187,
           260,  7280,    30,     0,     2,   198,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531

Stage2 Epoch 1:  20%|█▉        | 40/205 [00:53<03:31,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 55, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 79, 2048]), dtype=torch.float32, min=-4.3039, max=4.2833, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 79, 2048]), dtype=torch.float32, min=-54.5700, max=48.9987, mean=0.0075
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 55]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 29503,   732,
           506,   451,    47,   216,  1350,   339,   685,    28,   339,  1326,
           982,   699,    30,   216,  2306,  3935,   423,     0,     2,   198,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389

Stage2 Epoch 1:  20%|██        | 41/205 [00:54<03:31,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 61, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 85, 2048]), dtype=torch.float32, min=-4.4919, max=4.5308, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 85, 2048]), dtype=torch.float32, min=-54.5700, max=49.8794, mean=0.0054
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 61]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1326,
           982,   699,    30,     0,     2,   198,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [    1,  40

Stage2 Epoch 1:  20%|██        | 42/205 [00:56<03:31,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 52, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 76, 2048]), dtype=torch.float32, min=-4.4142, max=4.3526, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 76, 2048]), dtype=torch.float32, min=-54.5700, max=49.0428, mean=0.0098
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 52]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6910,   338,
           506,  1942,   282,  6243,    30,     0,     2,   198,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           2

Stage2 Epoch 1:  21%|██        | 43/205 [00:57<03:30,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 43, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-4.3242, max=4.5519, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-54.5700, max=49.9541, mean=0.0035
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 43]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6910,   338,
           506,  1109,    30,  3361,  2910,  1137,   392,   416,  1485,   281,
           351,   874,  1163,   585,   392,  1658,   578,   282,  2478,    30,
             0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  2838,    30,
     

Stage2 Epoch 1:  21%|██▏       | 44/205 [00:58<03:29,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 50, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-4.2787, max=4.2666, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-54.5700, max=49.2638, mean=0.0050
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 50]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1848,   506,
           732,  5901,   702,    28,  1048,    47,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198

Stage2 Epoch 1:  22%|██▏       | 45/205 [01:00<03:49,  1.43s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 37, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 61, 2048]), dtype=torch.float32, min=-4.3962, max=4.4728, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 61, 2048]), dtype=torch.float32, min=-54.5700, max=49.1904, mean=0.0052
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 37]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1385,  4775,
           982,   750,   701,   614, 40404,    30,   216,   657,   506,   253,
         20847,   506, 18373,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812,   339,
           737,   346,   288,   536,   314,   685,   288

Stage2 Epoch 1:  22%|██▏       | 46/205 [01:01<03:38,  1.38s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 62, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 86, 2048]), dtype=torch.float32, min=-4.8581, max=4.3848, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 86, 2048]), dtype=torch.float32, min=-54.5700, max=51.3680, mean=0.0091
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 62]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,   699,
           732,   339,  1441,    47,   216,  1626,  1138,   338,    28,   288,
          2635,   338,  4581,   260,  2591,  1163,    28,   702,   634,  1942,
           282, 12248,   284,  2573,   736,  1407,   357,  6479,  2893,  1272,
           284,   357,   736,   919,   253,  3193,   288,  1272,    30,     0,
             2,   198],
        [   

Stage2 Epoch 1:  23%|██▎       | 47/205 [01:03<03:34,  1.36s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 53, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-4.4937, max=4.4726, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-54.5700, max=49.7728, mean=0.0067
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 53]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  3543,
         13197,  2025,  2221,    30,   216,   339,  5248,   441,  1535, 13532,
            30,     0,     2,   198,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
     

Stage2 Epoch 1:  23%|██▎       | 48/205 [01:04<03:30,  1.34s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 43, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-4.3123, max=4.4144, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-54.5700, max=48.7378, mean=0.0036
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 43]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812,    47,
             0,     2,   198,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1069,  2316,
     

Stage2 Epoch 1:  24%|██▍       | 49/205 [01:05<03:26,  1.32s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 58, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 82, 2048]), dtype=torch.float32, min=-4.4181, max=4.3883, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 82, 2048]), dtype=torch.float32, min=-54.5700, max=50.0118, mean=0.0039
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 58]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1839,    28,
           357,   506,   346,   699,    28,   392,  3543,  4957,   588,  1083,
           655,   351,   601,   357,   506,   915,  2159, 33275,    28,  2159,
         19774,   338,   392,  2316,   441,  2045,   288,   325,  1730,   288,
           963,   874, 13532,   284,   685,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137

Stage2 Epoch 1:  24%|██▍       | 50/205 [01:07<03:24,  1.32s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 41, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-4.5225, max=4.1853, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-54.5700, max=49.6082, mean=0.0005
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 41]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  3763,
           982,    30,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 25606,   982,
           357,   2

Stage2 Epoch 1:  25%|██▍       | 51/205 [01:08<03:21,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 57, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 81, 2048]), dtype=torch.float32, min=-4.3578, max=4.6817, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 81, 2048]), dtype=torch.float32, min=-54.5700, max=49.1229, mean=0.0086
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 57]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  2018,  1048,
            30,   216,  3756,    30,   216,  3929,   392,   457,   253,  1035,
          1440,  6688,   338,  3711,   418,  7462,   253,    30,    93,    30,
           284, 46651,   326,    88,   260,   655,  1363,  2390,  3935, 26969,
         22271,   429,   338,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281

Stage2 Epoch 1:  25%|██▌       | 52/205 [01:09<03:20,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 44, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 68, 2048]), dtype=torch.float32, min=-4.4166, max=4.6795, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 68, 2048]), dtype=torch.float32, min=-54.5700, max=49.1759, mean=0.0041
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 44]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,   416,
           963,   357,   281,   469,  3497,    30,   216,  1206,  1277,   288,
           685,    30,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,   506

Stage2 Epoch 1:  26%|██▌       | 53/205 [01:10<03:17,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 42, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 66, 2048]), dtype=torch.float32, min=-4.4543, max=4.4499, mean=-0.0001
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 66, 2048]), dtype=torch.float32, min=-54.5700, max=51.0116, mean=0.0080
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 42]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1073,   359,
           346,  2045,   288,  2290,   327,  1125,    47,  9725,    29,  1250,
           346,   423,    47,     0,     2,   198,     0,     0,     0,     0,
             0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 12947,    28,
          18

Stage2 Epoch 1:  26%|██▋       | 54/205 [01:12<03:14,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 60, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 84, 2048]), dtype=torch.float32, min=-4.3075, max=4.5113, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 84, 2048]), dtype=torch.float32, min=-54.5700, max=49.0094, mean=0.0034
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 60]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   588,   357,
           314,   253,  1838,   253,  1838,  1759,   281,   338,  3223,   585,
           346,  1277,  3009,   346,   699,   346,   416,   731,   382,    29,
           392,   416,    29,   392,   416,  2843,  1372,  5186,  3421,    30,
             0,     2,   198,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780

Stage2 Epoch 1:  27%|██▋       | 55/205 [01:13<03:14,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 49, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 73, 2048]), dtype=torch.float32, min=-4.5242, max=4.3736, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 73, 2048]), dtype=torch.float32, min=-54.5700, max=49.1290, mean=0.0051
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 49]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,   506,
           441,   469, 39527,    30,     0,     2,   198,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1

Stage2 Epoch 1:  27%|██▋       | 56/205 [01:14<03:12,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 60, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 84, 2048]), dtype=torch.float32, min=-4.3986, max=4.2997, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 84, 2048]), dtype=torch.float32, min=-54.5700, max=50.1260, mean=0.0088
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 60]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,   506,
          3363,   338,   563,   357,    30,     0,     2,   198,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780

Stage2 Epoch 1:  28%|██▊       | 57/205 [01:16<03:12,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 51, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 75, 2048]), dtype=torch.float32, min=-4.1995, max=4.2603, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 75, 2048]), dtype=torch.float32, min=-54.5700, max=48.5716, mean=0.0043
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 51]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 16222,   384,
            29,   255,   506,  8947,   915,   702,   384,  1250,   990,   384,
          4242,    28,   346,   699,    47,     0,     2,   198,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 491

Stage2 Epoch 1:  28%|██▊       | 58/205 [01:17<03:10,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 42, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 66, 2048]), dtype=torch.float32, min=-4.1056, max=4.4817, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 66, 2048]), dtype=torch.float32, min=-57.4151, max=49.5830, mean=0.0049
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 42]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  8278,   443,
           269,   338,   506,  5485,    30,     0,     2,   198,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1848,   506,
          49

Stage2 Epoch 1:  29%|██▉       | 59/205 [01:18<03:07,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 45, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 69, 2048]), dtype=torch.float32, min=-4.3682, max=4.5089, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 69, 2048]), dtype=torch.float32, min=-54.5700, max=49.8456, mean=0.0058
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 45]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,    28,
           585,   346, 18918,   982,   719,   588,  8166,   105,   284, 19882,
            28,   346,  2093,   736,   457,  3984,   563,   357,    30,     0,
             2,   198,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339

Stage2 Epoch 1:  29%|██▉       | 60/205 [01:19<03:05,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 74, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 98, 2048]), dtype=torch.float32, min=-4.5716, max=4.3983, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 98, 2048]), dtype=torch.float32, min=-54.5700, max=49.2547, mean=0.0037
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 74]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   450,  2243,
            28, 13731,   418,   549,   564,   339,   523,    30,   216,  4651,
           346,   699,   732,  5024,   645,  1041,  2422,  1535,    47,   216,
          1550,  1041,  2422,  1056,  1041, 31633,   281,   650,  3589,    28,
           260, 17278, 10764,   618,  4443,    30, 16632,    28,  1492,    30,
          8432,   418,   357,    30,

Stage2 Epoch 1:  30%|██▉       | 61/205 [01:21<03:07,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 63, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 87, 2048]), dtype=torch.float32, min=-4.4489, max=4.4445, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 87, 2048]), dtype=torch.float32, min=-54.5700, max=50.6151, mean=0.0043
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 63]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,   506,
           441,   253,  3516,   357,   506,   253,  3116,    30,   216,  7903,
           346,   915,  1188,   357,    47,   216, 24282,   335,    30,     0,
             2,   198,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
     

Stage2 Epoch 1:  30%|███       | 62/205 [01:22<03:06,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 41, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-4.3156, max=4.3536, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-54.5700, max=49.6398, mean=0.0058
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 41]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   638,   436,
          6217,    47,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,   338,
           506,  31

Stage2 Epoch 1:  31%|███       | 63/205 [01:23<03:04,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 35, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 59, 2048]), dtype=torch.float32, min=-4.5022, max=4.7118, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 59, 2048]), dtype=torch.float32, min=-54.5700, max=49.4859, mean=0.0011
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 35]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1277,
           346,  1209,    28, 36779,    30,     0,     2,   198,     0,     0,
             0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1068,  1345,
            47,     0,     2,   198,     0,     0,     0,     0,     0

Stage2 Epoch 1:  31%|███       | 64/205 [01:25<03:01,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 35, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 59, 2048]), dtype=torch.float32, min=-4.2952, max=4.3029, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 59, 2048]), dtype=torch.float32, min=-54.5700, max=49.9567, mean=0.0041
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 35]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,   506,
           915,  2698,    28,   357,   506,  3878,   288,   820,  5966, 32982,
           690,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,   506,
         14867,    30,     0,     2,   198,     0,     0,     0,     0

Stage2 Epoch 1:  32%|███▏      | 65/205 [01:26<02:58,  1.27s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 69, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 93, 2048]), dtype=torch.float32, min=-4.7178, max=4.3311, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 93, 2048]), dtype=torch.float32, min=-54.5700, max=50.8546, mean=0.0071
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 69]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  2838,    30,
             0,     2,   198,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:  32%|███▏      | 66/205 [01:27<02:59,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 65, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 89, 2048]), dtype=torch.float32, min=-4.5031, max=4.4582, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 89, 2048]), dtype=torch.float32, min=-54.5700, max=50.2127, mean=0.0093
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 65]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812, 39527,
            47,     0,     2,   198,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:  33%|███▎      | 67/205 [01:29<02:59,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 53, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-4.5836, max=4.6209, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-54.5700, max=49.1202, mean=0.0046
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 53]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1326,
           982,  1277,   288,   325, 30879,  1535,    28, 12102,    23,   332,
            28,   564,  1761,   827,   389,   314,   837,   346,   737,   288,
           325,    30,   216,   669,   314,   260,  1761,   327, 13670,    30,
             0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
     

Stage2 Epoch 1:  33%|███▎      | 68/205 [01:30<02:57,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 65, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 89, 2048]), dtype=torch.float32, min=-4.4577, max=4.3847, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 89, 2048]), dtype=torch.float32, min=-54.5700, max=49.9499, mean=0.0044
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 65]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,    28,
         10668,   288,   260,  1205,  4677,    30,   216,  3315,   346,  1510,
           451,   314,   732,   339,   761,   281,  1945,   645,   339,  5433,
            47,   216,  1848,  1876,   929,  1187,   260,  3398,    28,   392,
          6737,   325,   418,   260, 10724, 45734,   418,   971,   550,   690,
          2698,    47,     0,     2,

Stage2 Epoch 1:  34%|███▎      | 69/205 [01:31<02:57,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 26, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 50, 2048]), dtype=torch.float32, min=-4.3988, max=4.6160, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 50, 2048]), dtype=torch.float32, min=-54.5700, max=50.8773, mean=0.0082
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 26]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 22096,    30,
             0,     2,   198,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1848,   339,
          5248, 22657,    47,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           

Stage2 Epoch 1:  34%|███▍      | 70/205 [01:32<02:53,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 59, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 83, 2048]), dtype=torch.float32, min=-4.7547, max=4.4918, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 83, 2048]), dtype=torch.float32, min=-54.5700, max=50.0569, mean=0.0050
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 59]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 34763,    30,
           216,  1644,   585,   346,   416,   915,  1928,   549,   469,  1462,
            28,   469,  1462,   284,   469,  2014,   284,   469,  5460,  1230,
           339,   416,  2770,   346,   347,  3318,   347,   392,  1042,   578,
           540,  1096,    30,     0,     2,   198,     0,     0,     0],
        [    1,  4093,   198,  1780,   314

Stage2 Epoch 1:  35%|███▍      | 71/205 [01:34<02:54,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 53, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-4.5160, max=4.5474, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-54.5700, max=48.7118, mean=0.0048
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 53]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  3984,
           357,   506,   588,  9228,   288,   820,   618,    30,     0,     2,
           198,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
     

Stage2 Epoch 1:  35%|███▌      | 72/205 [01:35<02:52,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 52, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 76, 2048]), dtype=torch.float32, min=-4.2447, max=4.3129, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 76, 2048]), dtype=torch.float32, min=-54.5700, max=48.8030, mean=0.0067
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 52]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 22096,   588,
            29,  1848,  3009,  5907,  1250,   685,   738,    30,     0,     2,
           198,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           2

Stage2 Epoch 1:  36%|███▌      | 73/205 [01:36<02:51,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 55, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 79, 2048]), dtype=torch.float32, min=-4.6945, max=4.2444, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 79, 2048]), dtype=torch.float32, min=-54.5700, max=48.8083, mean=0.0042
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 55]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  8765,   915,
          1830,   260,   910,   281,    28,  1188,   260,  4177,    28,  1303,
           549,   820,   957,  9768,   284,   820,   260, 20031,   578,   282,
          1535,    30,   216,  8765,    30,     0,     2,   198,     0,     0,
             0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389

Stage2 Epoch 1:  36%|███▌      | 74/205 [01:38<02:49,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 48, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 72, 2048]), dtype=torch.float32, min=-4.4347, max=4.3917, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 72, 2048]), dtype=torch.float32, min=-54.5700, max=50.0152, mean=0.0052
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 48]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,   339,
          3060,  1690,  1056,    30,   216,  1812,   314,   338,  7556,   288,
          1441,    47,   216,  1206,  2316,   441,  6819,   288,  5847,   327,
           549,   288,  1690,  1056,    47,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520

Stage2 Epoch 1:  37%|███▋      | 75/205 [01:39<02:47,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 29, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 53, 2048]), dtype=torch.float32, min=-4.4941, max=4.5660, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 53, 2048]), dtype=torch.float32, min=-54.5700, max=49.1757, mean=0.0054
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 29]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812,   282,
           357,    47,     0,     2,   198,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3305,    28,
           346,   699,    28, 34408,  7576,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  113

Stage2 Epoch 1:  37%|███▋      | 76/205 [01:40<02:44,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 57, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 81, 2048]), dtype=torch.float32, min=-4.2251, max=4.3086, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 81, 2048]), dtype=torch.float32, min=-54.5700, max=48.9848, mean=0.0063
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 57]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1812,   536,
           346,  1441,    47,   216,   909,  6679,   614,  1867,  6644,   288,
           260,  8180,   511,   650,  1029,    30,   216,  4896, 10606,   982,
           384,  1277,   288,  3009,  1701, 10606,   982,   384,  1277,   288,
           963,   874,  1163,    47,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281

Stage2 Epoch 1:  38%|███▊      | 77/205 [01:41<02:44,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 74, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 98, 2048]), dtype=torch.float32, min=-4.8574, max=4.6022, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 98, 2048]), dtype=torch.float32, min=-54.5700, max=49.6960, mean=0.0053
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 74]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1441,
            28,   314,   665, 17101,   346,   416,   946,    47,   216,  1816,
           260, 18904,   338,  1830,   260,  7386,   335,   260,   423,   258,
           260,  8303,    30,   216,   339,  1441,    28,   359,   502,   423,
         11518,  1251,   325,  2159, 44250,   288,   457,   915, 15784,   957,
         45635,    30,   216,   339,

Stage2 Epoch 1:  38%|███▊      | 78/205 [01:43<02:45,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 50, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-4.3336, max=4.3921, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-54.5700, max=48.1333, mean=0.0013
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 50]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,   699,
          1701,   339, 11596, 36779,    28,  1048,    47,     0,     2,   198,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198

Stage2 Epoch 1:  39%|███▊      | 79/205 [01:44<02:43,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 53, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-4.3100, max=4.5840, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-54.5700, max=48.8881, mean=0.0059
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 53]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 28844,   614,
            30,  6910,   957,  1839,    30,   216, 27434,    47,  3305,   702,
           338,    47,   216,  1206,   915,  3413,   260,  1962,   284,   732,
          2026,  6778,  1250,   346,  1998,    47,     0,     2,   198,     0,
             0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
     

Stage2 Epoch 1:  39%|███▉      | 80/205 [01:45<02:43,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 73, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 97, 2048]), dtype=torch.float32, min=-4.5826, max=4.3071, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 97, 2048]), dtype=torch.float32, min=-54.5700, max=49.7785, mean=0.0037
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 73]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  5248,
           915,  4976,   346,  1326,   982,   737,   288,  3015,   288,   549,
           702,   339,    29,     0,     2,   198,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:  40%|███▉      | 81/205 [01:47<02:44,  1.32s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 39, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 63, 2048]), dtype=torch.float32, min=-4.2656, max=4.3223, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 63, 2048]), dtype=torch.float32, min=-54.5700, max=49.4584, mean=0.0041
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 39]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,    28,
           339,   744,  3039,   702,   338,    30,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1249,   585,
           339,   416,   982,   457,   338

Stage2 Epoch 1:  40%|████      | 82/205 [01:48<02:40,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 41, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-4.6168, max=4.8049, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 65, 2048]), dtype=torch.float32, min=-54.5700, max=48.1333, mean=0.0062
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 41]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,   416,
           982,  1012,   874,  1945,    30,     0,     2,   198,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 17866,  7732,
          1870,    

Stage2 Epoch 1:  40%|████      | 83/205 [01:49<02:38,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 34, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 58, 2048]), dtype=torch.float32, min=-4.4076, max=4.5289, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 58, 2048]), dtype=torch.float32, min=-54.5700, max=49.1149, mean=0.0009
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 34]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1385,   506,
           787,   737,   327,   346,   288,  3467,   469,  5274,    28, 38061,
            30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1073,   536,
           346,   699,   384,   506,   908,  3039,   563,   357,    47,     0

Stage2 Epoch 1:  41%|████      | 84/205 [01:51<02:35,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 75, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 99, 2048]), dtype=torch.float32, min=-4.3853, max=4.4948, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 99, 2048]), dtype=torch.float32, min=-54.5700, max=49.8117, mean=0.0065
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 75]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 15753, 38673,
          6392,    30,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:  41%|████▏     | 85/205 [01:52<02:36,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 49, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 73, 2048]), dtype=torch.float32, min=-4.3410, max=4.4598, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 73, 2048]), dtype=torch.float32, min=-54.5700, max=49.0570, mean=0.0064
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 49]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1441,
            28,   346,   699,    28,   418,  2124,   384,   506,   441,   281,
         13532,  1556,  1209,    28,  1048,    47,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1

Stage2 Epoch 1:  42%|████▏     | 86/205 [01:53<02:34,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 62, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 86, 2048]), dtype=torch.float32, min=-4.6999, max=4.7061, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 86, 2048]), dtype=torch.float32, min=-54.5700, max=48.5135, mean=0.0042
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 62]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   669,   732,
            28,   732,   314,   451,    47,   216,   669,  3247,   982,   908,
          3534,    30,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   

Stage2 Epoch 1:  42%|████▏     | 87/205 [01:54<02:33,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 54, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 78, 2048]), dtype=torch.float32, min=-4.5043, max=4.6239, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 78, 2048]), dtype=torch.float32, min=-54.5700, max=50.9009, mean=0.0082
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 54]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,    28,
          1209,   338,   346,  3311,   357,    30,   216,  2838,    28,   339,
          1326,   982,    30,     0,     2,   198,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47

Stage2 Epoch 1:  43%|████▎     | 88/205 [01:56<02:32,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 46, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-4.3661, max=4.4560, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-54.5700, max=49.3927, mean=0.0047
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 46]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 34763,    30,
         34763,    28, 22939,    30,     0,     2,   198,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198

Stage2 Epoch 1:  43%|████▎     | 89/205 [01:57<02:29,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 37, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 61, 2048]), dtype=torch.float32, min=-4.5672, max=4.4656, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 61, 2048]), dtype=torch.float32, min=-54.5700, max=51.1264, mean=0.0081
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 37]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,  3680,
           338,    30,   216,  1350,   876,   346,  3543,   719,  3009,     0,
             2,   198,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  9905,
           732,   339,   310, 27893,  9905,    30,   216

Stage2 Epoch 1:  44%|████▍     | 90/205 [01:58<02:27,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 47, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 71, 2048]), dtype=torch.float32, min=-4.1645, max=4.5015, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 71, 2048]), dtype=torch.float32, min=-54.5700, max=49.8158, mean=0.0095
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 47]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 22096,    28,
         33662,    30,     0,     2,   198,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531

Stage2 Epoch 1:  44%|████▍     | 91/205 [02:00<02:26,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 58, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 82, 2048]), dtype=torch.float32, min=-4.7211, max=4.3850, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 82, 2048]), dtype=torch.float32, min=-54.5700, max=49.7864, mean=0.0012
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 58]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  3543,
          1836,   357,  1296,   929,   282,  2275,    30,   216,   339,  2275,
           338,   585,   339, 25693, 34359,   736,  6616,   563, 29810,   284,
           392,   856,   457,   253,  2649, 18338,   284,  3117,  5587,    30,
             0,     2,   198,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137

Stage2 Epoch 1:  45%|████▍     | 92/205 [02:01<02:25,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 36, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 60, 2048]), dtype=torch.float32, min=-4.4128, max=4.2452, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 60, 2048]), dtype=torch.float32, min=-54.5700, max=49.4683, mean=0.0021
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 36]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 22096,    28,
          5196,   585,   346,   592,   351,  2206,  1745,  1147,    30,     0,
             2,   198,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 22096,    28,
           564,   346,   592,   665,   511,   260,  5625,  2009

Stage2 Epoch 1:  45%|████▌     | 93/205 [02:02<02:23,  1.28s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 33, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 57, 2048]), dtype=torch.float32, min=-4.6614, max=4.4978, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 57, 2048]), dtype=torch.float32, min=-54.5700, max=50.0424, mean=0.0044
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 33]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   909,   436,
           578,  1535,   645,   357, 10764,    30,     0,     2,   198,     0,
             0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,   506,
           990, 20654,    30,  2959,   506,   915,   685,  1478,    30,     0,
     

Stage2 Epoch 1:  46%|████▌     | 94/205 [02:03<02:21,  1.27s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 43, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-4.4695, max=4.5772, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 67, 2048]), dtype=torch.float32, min=-54.5700, max=48.6182, mean=0.0060
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 43]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,   699,
           357,   506,   441,   915,   957,  1861,    30,     0,     2,   198,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  6910,    28,
     

Stage2 Epoch 1:  46%|████▋     | 95/205 [02:05<02:20,  1.27s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 51, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 75, 2048]), dtype=torch.float32, min=-4.5167, max=4.5178, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 75, 2048]), dtype=torch.float32, min=-54.5700, max=48.5858, mean=0.0054
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 51]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1385,   314,
           338,    28, 24368,   787,   338,   506,   915, 30702,    28,  3247,
           982,   357,    47,     0,     2,   198,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 491

Stage2 Epoch 1:  47%|████▋     | 96/205 [02:06<02:29,  1.37s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 76, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 100, 2048]), dtype=torch.float32, min=-4.2847, max=4.3495, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 100, 2048]), dtype=torch.float32, min=-54.5700, max=49.3172, mean=0.0047
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 76]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,   506,
          1759,  2001,  2221,    28, 20165,    28,   339,   416,   982,    30,
             0,     2,   198,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     

Stage2 Epoch 1:  47%|████▋     | 97/205 [02:08<02:26,  1.36s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 52, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 76, 2048]), dtype=torch.float32, min=-4.3623, max=4.4763, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 76, 2048]), dtype=torch.float32, min=-54.5700, max=49.1641, mean=0.0059
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 52]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 38699,    47,
           216, 38699,    47,   216,  1812,    47,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           2

Stage2 Epoch 1:  48%|████▊     | 98/205 [02:09<02:23,  1.34s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 66, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 90, 2048]), dtype=torch.float32, min=-4.4575, max=4.4571, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 90, 2048]), dtype=torch.float32, min=-54.5700, max=48.8548, mean=0.0067
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 66]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   407,  6640,
            30,   216,  4184,   346,  4011,    47,   216,  3315,   346,  1277,
           957, 29111,    47,   216,  2838,    47,   216,   346,   699,    28,
           392,   868,   457,  3621,   260, 21714,    47,   216,  3698, 21714,
            30,     0,     2,   198,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:  48%|████▊     | 99/205 [02:10<02:21,  1.34s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 31, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 55, 2048]), dtype=torch.float32, min=-4.5981, max=4.3901, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 55, 2048]), dtype=torch.float32, min=-54.5700, max=49.1303, mean=0.0110
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 31]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 22297,   253,
          8427,    30,   216,   339,  3543,  3363,   357,    30,     0,     2,
           198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  2838,    30,
             0,     2,   198,     0,     0,     0,     0,     0,     0,     0,
             0],
  

Stage2 Epoch 1:  49%|████▉     | 100/205 [02:12<02:17,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 46, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-4.4081, max=4.5667, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-54.5700, max=48.5129, mean=0.0036
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 46]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3513,  1740,
            30,  1094,   665,   506,  3534,   339,   416,   536,    28,   346,
           915,  1303,   549,   699,    30,   216, 34763,    47,     0,     2,
           198,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198

Stage2 Epoch 1:  49%|████▉     | 101/205 [02:13<02:15,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 66, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 90, 2048]), dtype=torch.float32, min=-4.5022, max=4.2254, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 90, 2048]), dtype=torch.float32, min=-54.5700, max=48.1333, mean=0.0058
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 66]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339,  1326,
           982,   699,  1701,   357,   314,   564,   897,  2256,   339,  3470,
           578,   327,  1488,   339,    29,   338,   339,  1277,    28,   339,
           457,   288,  6536,  1056,   975,   339,  1124,  9285, 17101,  1745,
            30,   216,  3361,  2444, 23271,  1029,    43,   655,   990,   655,
           990,   655,    30,     0,

Stage2 Epoch 1:  50%|████▉     | 102/205 [02:14<02:14,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 53, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-4.5559, max=4.4547, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 77, 2048]), dtype=torch.float32, min=-54.5700, max=51.2673, mean=0.0065
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 53]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 34763, 34763,
            28,  1350, 43376,    29,  5073,   284,  2966,   284,  7386,    30,
           216,   514,    88,    29,  1978,   346,   731,  1978,   346,   915,
          1502,   281,  6806,  2014,  3179,    47,  1350,  5380,  4137,    30,
             0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
     

Stage2 Epoch 1:  50%|█████     | 103/205 [02:15<02:13,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 37, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 61, 2048]), dtype=torch.float32, min=-4.2855, max=4.4347, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 61, 2048]), dtype=torch.float32, min=-54.5700, max=48.7101, mean=0.0037
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 37]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1046, 11535,
           982,   281,   253,  1123,  4897,    30,     0,     2,   198,     0,
             0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   357,   506,
           702,   339,   699,   347,   986,   347,  1041

Stage2 Epoch 1:  51%|█████     | 104/205 [02:17<02:10,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 50, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-4.2525, max=4.5257, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-56.4467, max=49.2742, mean=0.0061
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 50]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   933,    50,
          3256, 12455,  4728,    77,   339,  1441,    28,   665,   506,  3878,
          2159,   288,  3015,   563,    30,   216,   909,   506,  7270,   284,
           339,  3060,  2093,   963,  1272,  1163,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198

Stage2 Epoch 1:  51%|█████     | 105/205 [02:18<02:08,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 50, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-4.5295, max=4.3484, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-54.5700, max=50.5807, mean=0.0063
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 50]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   339, 21394,
           502,   736,    29,   469,  3379,   418,  2124,    30,     0,     2,
           198,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198

Stage2 Epoch 1:  52%|█████▏    | 106/205 [02:19<02:07,  1.29s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 59, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 83, 2048]), dtype=torch.float32, min=-4.5725, max=4.6533, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 83, 2048]), dtype=torch.float32, min=-54.5700, max=49.5150, mean=0.0042
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 59]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198, 34763,    30,
          1812,   416,   392,   536,   314,   392,   416,   820,   469,  1462,
           284,  1096,   284,   585,   357,  2744,   614,   392,   416,  5091,
           357,   288,   346,   588,   423,  7655,   331,   346,  2951,     0,
             2,   198,     0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314

Stage2 Epoch 1:  52%|█████▏    | 107/205 [02:21<02:06,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 65, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 89, 2048]), dtype=torch.float32, min=-4.5449, max=4.2780, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 89, 2048]), dtype=torch.float32, min=-54.5700, max=48.3770, mean=0.0035
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 65]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,   699,
           732,    47,   216,   339,  3543,   719,  1891,   418, 10321,  5152,
           327,   260,  1896,  2976,  2704,    30,   216,  1206,   699,   638,
         38793,   357,   314,    47,     0,     2,   198,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,

Stage2 Epoch 1:  53%|█████▎    | 108/205 [02:22<02:06,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 57, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 81, 2048]), dtype=torch.float32, min=-4.3010, max=4.5115, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 81, 2048]), dtype=torch.float32, min=-54.5700, max=49.9567, mean=0.0077
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 57]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,   288,
         20031,   351,   338,    30,     0,     2,   198,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281

Stage2 Epoch 1:  53%|█████▎    | 109/205 [02:23<02:05,  1.31s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 46, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-4.4436, max=4.4058, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 70, 2048]), dtype=torch.float32, min=-54.5700, max=48.9445, mean=0.0070
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 46]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  3929,    30,
             0,     2,   198,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198

Stage2 Epoch 1:  54%|█████▎    | 110/205 [02:24<02:03,  1.30s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 89, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 113, 2048]), dtype=torch.float32, min=-4.2963, max=4.2912, mean=-0.0004
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 113, 2048]), dtype=torch.float32, min=-54.5700, max=49.6388, mean=0.0065
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 89]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   284,   339,
           436, 15662,   957,  4416,   578,   284,   357,   436,   702,    28,
         24368,   957,  1839,    30,   216,  1644,   339,   915,   761,   288,
          9199,   260,  1739,   588,  1209,   502,  2316,   702,    28,  1701,
           359,   346,   354,  4665,  1739,    28,   284,  1701,   359,   346,
          1891,    28,  3396,   28

Stage2 Epoch 1:  54%|█████▍    | 111/205 [02:26<02:04,  1.32s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 50, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-4.4236, max=4.2071, mean=-0.0003
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 74, 2048]), dtype=torch.float32, min=-54.5700, max=48.8082, mean=0.0085
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 50]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,  2316,
           441,  1625,   451,   750,  4285,   327,   549,    28,   346,   699,
            30,   216,   339,  5248,   260,   582,   617,   553,   288,   685,
           284,   346,  2316,  2159,   441,  4307,    30,     0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198

Stage2 Epoch 1:  55%|█████▍    | 112/205 [02:27<02:02,  1.32s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 33, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])
Debug(AudioLanguageModel): Input Embeds = 
: shape=torch.Size([4, 57, 2048]), dtype=torch.float32, min=-4.5374, max=4.3747, mean=-0.0002
Debug(AudioLanguageModel): Decoder Output Embeds = 
: shape=torch.Size([4, 57, 2048]), dtype=torch.float32, min=-54.5700, max=50.2288, mean=0.0063
Debug(AudioLanguageModel): input_ids:  torch.Size([4, 33]) input_ids:  tensor([[    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,  1206,  1303,
          1272, 21111,   346,    30,   216,  1206,  1137,   346,  1250,    30,
             0,     2,   198],
        [    1,  4093,   198,  1780,   314,  1137,   281,   451,  8389,    47,
           216, 49152,     2,   198,     1,   520,  9531,   198,   657,  2761,
           982,  2395,   288,   746,   715, 13278,   351,   346,    30,     0,
     

Stage2 Epoch 1:  55%|█████▌    | 113/205 [02:29<02:01,  1.32s/it]

Debug(AudioLanguageModel): text_embeds:  torch.Size([4, 97, 2048])
Debug(AudioLanguageModel): audio_embeds:  torch.Size([4, 25, 2048])





KeyboardInterrupt: 

In [None]:
final_model = train_step3_instruction_tuning(train_cfg, alm_cfg, stage2_model, tokenizer, device)
final_model.save_pretrained("/content/")

As you can see the model trains, so feel free to play around with the architecture or data! Let us know what you build with it!

PS: If you want to test the model, check out generate.py to see how to do inference with it

## Test

In [None]:
!cp /content/drive/MyDrive/nanoALM/output_txt1.wav /content/output_txt1.wav
!cp /content/drive/MyDrive/nanoALM/output_txt2.wav /content/output_txt2.wav
!cp /content/drive/MyDrive/nanoALM/output_txt3.wav /content/output_txt3.wav

In [None]:
!cp ../model.safetensors /content/drive/MyDrive/nanoALM/model.safetensors
!cp ../config.json /content/drive/MyDrive/nanoALM/config.json
!cp ./model.safetensors /content/drive/MyDrive/nanoALM
!cp ./config.json /content/drive/MyDrive/nanoALM

In [None]:
# final_model.save_pretrained("/content/")
!python generate.py --checkpoint ../ --audio ../output_txt1.wav

In [None]:
!python generate.py --checkpoint ../ --audio ../output_txt2.wav

In [None]:
!python generate.py --checkpoint ../ --audio ../output_txt3.wav

In [None]:
!python generate.py --checkpoint ../ --audio ../output_txt3.wav

In [None]:
import librosa

messages = [
    {"role": "user", "content": "What is said in this audio? <AUDIO>"},
    {"role": "assistant", "content": ""},  # 讓模板能加上生成提示
]
full_input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)
full_input_ids = torch.tensor(full_input_ids, dtype=torch.long).to(device)
if full_input_ids.dim() == 1:
    full_input_ids = full_input_ids.unsqueeze(0)

generations = 5
max_new_tokens = 100

try:
    audio_array, sr = librosa.load("../output_txt1.wav", sr=16000)
    # 轉換為 torch tensor 並添加 batch 維度
    audio_tensor = torch.tensor(audio_array, dtype=torch.float32)
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)  # 添加 channel 維度

    # 使用 audio_processor 處理音頻
    ap = get_audio_processor(alm_cfg)
    audio_t = ap(audio_array, sr).unsqueeze(0).to(device)
except Exception as e:
    print(f"Error loading audio file: {e}")
    print("Please check if the audio file exists and is in a supported format.")

tested_model = final_model if 'final_model' in globals() and final_model is not None else stage2_model
tested_model.to(device).eval()

print("\nInput:\n ", messages[0]["content"], "\n\nOutputs:")
for i in range(generations):
    gen = tested_model.generate(
        full_input_ids,
        audio_t,
        max_new_tokens=max_new_tokens,
        greedy=True,
        top_k=10,
        top_p=0.8,
        temperature=0.3
    )
    out = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
    print(f"  >> Generation {i+1}: {out}")