## Model

In [None]:
from asr_model import ASRModel

a, b, c = ASRModel(model_dim=768, mode = 'A').to('cuda'), ASRModel(model_dim=768, mode = 'B').to('cuda'), ASRModel(model_dim=768, mode = 'C').to('cuda')
a.params, b.params, c.params





(149466760, 149466760, 141578888)

In [3]:
import torch
audio_features = torch.randn(10, 20, 768).to('cuda')  # Example input tensor (seq_len, batch_size, acoustic_input_dim)
input_ids = torch.randint(0, 21128, (10, 35)).long().to('cuda')
attention_mask = torch.ones(10, 35).to('cuda')  # Example attention mask (batch_size, seq_len)

In [4]:
input_ids.shape, audio_features.shape, attention_mask.shape

(torch.Size([10, 35]), torch.Size([10, 20, 768]), torch.Size([10, 35]))

In [5]:
a_lala = a(input_ids = input_ids,
           attention_mask = None,
           audio_features = audio_features)
print(a_lala.shape)

torch.Size([10, 35, 21128])


In [6]:
b_lala = b(input_ids = input_ids,
           attention_mask = None,
           audio_features = audio_features)
print(b_lala.shape)

torch.Size([10, 35, 21128])


In [7]:
c_lala = c(input_ids = input_ids,
           attention_mask = None,
           audio_features = audio_features)
print(c_lala.shape)

torch.Size([10, 35, 21128])


# Train Function

In [5]:
import torch
from transformers import AutoTokenizer
from asr_model import ASRModel
from evaluate import load
cer = load("cer")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = ASRModel(model_dim=768, mode='A').to(device)

# Tokenzier
bos_token = tokenizer.cls_token_id
eos_token = tokenizer.sep_token_id
pad_token = tokenizer.pad_token_id

# Example text input and features
text = ["而 对 楼市 成交 抑制 作用 最 大 的 限 购", "也 成为 地方 政府 的 眼中 钉", "自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后", "各地 政府 便 纷纷 跟进"]
input_ids = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
decoded_text = tokenizer.batch_decode(input_ids["input_ids"])
print(f'Original text: {text}')
print(f'Tokenized text: {decoded_text}')

downsampled_features = torch.rand(4, 128, 205, 80) # (B, C, T, F) -> (B, T, C*F)
batch = {
    'input_ids': input_ids.input_ids,
    'attention_mask': input_ids.attention_mask,
    'downsampled_features': downsampled_features,
}
for k in batch:
    batch[k] = batch[k].to(device=device, non_blocking=True)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


Original text: ['而 对 楼市 成交 抑制 作用 最 大 的 限 购', '也 成为 地方 政府 的 眼中 钉', '自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后', '各地 政府 便 纷纷 跟进']
Tokenized text: ['[CLS] 而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购 [SEP] [PAD] [PAD] [PAD]', '[CLS] 也 成 为 地 方 政 府 的 眼 中 钉 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]', '[CLS] 自 六 月 底 呼 和 浩 特 市 率 先 宣 布 取 消 限 购 后 [SEP]', '[CLS] 各 地 政 府 便 纷 纷 跟 进 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']


In [6]:
input_ids = batch['input_ids']  # Batch_size, seq_length
downsampled_features = batch['downsampled_features']  # Batch_size, seq_length, feature_dim
attention_mask = batch['attention_mask']  # Batch_size, seq_length

shifted_left_outputs = torch.cat([input_ids[:, 1:], torch.full((input_ids.size(0), 1), tokenizer.pad_token_id, dtype=torch.long, device=device)], dim=1)
tokenizer.batch_decode(shifted_left_outputs)

['而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购 [SEP] [PAD] [PAD] [PAD] [PAD]',
 '也 成 为 地 方 政 府 的 眼 中 钉 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '自 六 月 底 呼 和 浩 特 市 率 先 宣 布 取 消 限 购 后 [SEP] [PAD]',
 '各 地 政 府 便 纷 纷 跟 进 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']

In [7]:
B, C, T, F = downsampled_features.shape
outputs = model(input_ids, attention_mask, downsampled_features.view(B, T, C*F))  # Batch_size, seq_length, vocab_size
outputs.shape

torch.Size([4, 20, 21128])

In [8]:
# Compute CER and loss
ids_prediction = outputs.argmax(dim=-1)  # Batch_size, seq_length
predictions = tokenizer.batch_decode(ids_prediction, skip_special_tokens=True)  # Batch_size
references = tokenizer.batch_decode(shifted_left_outputs, skip_special_tokens=True)  # Batch_size

In [9]:
print(f'References: {references}')

References: ['而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购', '也 成 为 地 方 政 府 的 眼 中 钉', '自 六 月 底 呼 和 浩 特 市 率 先 宣 布 取 消 限 购 后', '各 地 政 府 便 纷 纷 跟 进']


In [10]:
print(f'Predictions: {predictions}')

Predictions: ['ex 闌 脳 筑晝還綑弈詡 鲑 a灬 helpapp 驍端 笆窗 錳 间ien', '##餡澆 蚤箍奘兰 蜆過崛 辻 畳med 1907锏锏 喟 焘 による 铿 第', '##しています 鈕 fx娄鋅 汶 繩 xddd丐 害将 惘撮訶のは realᅣvr 畳yo', '##琳尻 贻 蒞 捍 co2 淌 电槓 饰骁 飘炼 剣靴淼 hard 间 庠挾']


In [None]:
cer_score = cer.compute(predictions=predictions, references=references)
cer_score

In [15]:
outputs.shape, shifted_left_outputs.shape

(torch.Size([4, 20, 21128]), torch.Size([4, 20]))

In [31]:
def loss_fn(outputs, labels, cer_score, criterion, gamma=1.0, ignore_index=0):
    """
    Compute the weighted loss for the Decoder using alpha weights based on CER.

    Args:
        outputs (torch.Tensor): Model predictions of shape (batch_size, seq_length, vocab_size).
        labels (torch.Tensor): Ground truth labels of shape (batch_size, seq_length).
        cer_score (float): Character Error Rate (CER) for the Decoder.
        gamma (float): Hyperparameter to control the influence of CER in alpha computation.
        ignore_index (int): The padding value in the labels to be ignored in the loss calculation.

    Returns:
        torch.Tensor: The computed weighted loss.
    """
    # Ensure cer_score is a tensor
    cer_score = torch.tensor(cer_score, dtype=torch.float32, device=outputs.device)

    # Compute alpha weight based on CER
    alpha = -(cer_score ** gamma) * torch.log(1 - cer_score + 1e-8)  # Add epsilon to avoid log(0)

    # Compute the base loss
    base_loss = criterion(outputs.mT, labels)  # Flatten outputs and labels
    if base_loss < 1:
        try: 
            alpha = -(cer_score ** gamma) * torch.log(1 - cer_score + 1e-8)  # Add epsilon to avoid log(0)
            base_loss = base_loss * alpha  # Apply alpha weight to the loss
        except:
            base_loss = base_loss # Avoid cer_score is bigger than 1

    return base_loss

criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_token)
loss = loss_fn(outputs, shifted_left_outputs, cer_score, criterion)
loss

  cer_score = torch.tensor(cer_score, dtype=torch.float32, device=outputs.device)


tensor(10.0796, device='cuda:0', grad_fn=<NllLoss2DBackward0>)