In [1]:
!git clone https://github.com/Aria-K-Alethia/BigCodec.git
!wget https://huggingface.co/Alethia/BigCodec/resolve/main/bigcodec.pt

Cloning into 'BigCodec'...
remote: Enumerating objects: 72, done.[K
remote: Counting objects: 100% (72/72), done.[K
remote: Compressing objects: 100% (60/60), done.[K
remote: Total 72 (delta 14), reused 57 (delta 7), pack-reused 0 (from 0)[K
Receiving objects: 100% (72/72), 30.19 KiB | 572.00 KiB/s, done.
Resolving deltas: 100% (14/14), done.
--2025-03-28 12:17:15--  https://huggingface.co/Alethia/BigCodec/resolve/main/bigcodec.pt
Resolving huggingface.co (huggingface.co)... 18.164.174.17, 18.164.174.23, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.17|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.hf.co/repos/2e/d1/2ed1c2f0d3a6eeb0584a3dd54200d652d17330f9d54e66392623e9d0521364e3/1fba3806e87cc01c1a65bea22fa1becefbbf46881e4219593c4d9f3cf56206b9?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27bigcodec.pt%3B+filename%3D%22bigcodec.pt%22%3B&Expires=1743167835&Policy=eyJTdGF0ZW1lbnQiOlt7IkNv

In [1]:
import os
import sys
sys.path.append("BigCodec")

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import librosa

import torchaudio
import torch

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)


device = torch.device("cuda")
token = ""
base_model = 'ksych/salt-asr-0.5b'

start_sequence_token = "<|im_start|>"
end_sequence_token = "<|im_end|>"
start_audio_token = "<|start_of_audio|>"
end_audio_token = "<|end_of_audio|>"


def resample(audio_data: torch.Tensor, sample_rate: int):
    print("Inout sample rate:", sample_rate)
    if sample_rate != 16000:
      audio_data = torch.tensor(
          librosa.resample(
              audio_data.cpu().detach().numpy(), orig_sr=sample_rate, target_sr=16000
          )
      )

    return audio_data.view(1, -1).float().to(device)


def get_audio_start_end_tokens(
    tokens: torch.Tensor,
    start_audio_token_id=None,
    end_audio_token_id=None,
):
    # find start index of audio tokens
    if start_audio_token_id is not None:
        start = torch.nonzero(tokens == start_audio_token_id)
        start = start[0, -1] + 1 if len(start) else 0
    else:
        start = 0

    # find end index of audio tokens
    if end_audio_token_id is not None:
        end = torch.nonzero(tokens == end_audio_token_id)
        end = end[0, -1] if len(end) else tokens.shape[-1]
    else:
        end = tokens.shape[-1]

    assert start < end, (
        f"Start of audio must be before end. Found: start - {start}, end - {end}"
    )

    return start, end


tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".", use_auth_token=token)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    cache_dir=".",
    attn_implementation="sdpa",
    device_map={"": 0},
    use_auth_token=token
)
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id





In [2]:
bos_id = tokenizer.convert_tokens_to_ids(start_sequence_token)
soa_id = tokenizer.convert_tokens_to_ids(start_audio_token)
eoa_id = tokenizer.convert_tokens_to_ids(end_audio_token)

bos = torch.tensor([[bos_id]], device=device)
soa = torch.tensor([[soa_id]], device=device)
eoa = torch.tensor([[eoa_id]], device=device)

In [3]:
from BigCodec.vq.codec_encoder import CodecEncoder
from BigCodec.vq.codec_decoder import CodecDecoder


class BigCodecTokenizer:
    def __init__(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        encoder = CodecEncoder()
        encoder.load_state_dict(ckpt["CodecEnc"])
        self.encoder = encoder.eval().cuda()

        decoder = CodecDecoder()
        decoder.load_state_dict(ckpt["generator"])
        self.decoder = decoder.eval().cuda()

    def encode(self, wav):
        vq_emb = self.encoder(wav.unsqueeze(1))
        _, vq_code, _ = self.decoder(vq_emb, vq=True)
        return vq_code


n_codebooks_tts = 1
n_codebooks_asr = 1
quantizer = BigCodecTokenizer("bigcodec.pt")
n_audio_tokens = quantizer.decoder.quantizer.layers[0].codebook_size,

  WeightNorm.apply(module, name, dim)


In [10]:
def decode_audio_bigcodec(
    tokens,
    quantizer,
    n_original_tokens,
    n_codebooks,
):
    # find audio start and end tokens
    start, end = get_audio_start_end_tokens(
        tokens, soa, eoa
    )

    # subtract length of original vocabulary -> tokens in range [0, 1024)
    audio_tokens = tokens[start:end] % n_original_tokens
    audio_tokens = audio_tokens.reshape(1, -1, 1)

    with torch.no_grad():
      emb = quantizer.decoder.vq2emb(audio_tokens).transpose(1, 2)
      audio = quantizer.decoder(emb, vq=False).squeeze().detach().cpu()

    del tokens
    del audio_tokens
    torch.cuda.empty_cache()

    return audio, 16000


def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20, temperature=0.5):
    text_tokenized = tokenizer(text.lower(), return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    text_tokens = torch.cat([bos, text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
        temperature=temperature,
    )

    audio_signal = decode_audio_bigcodec(
        output_audio_tokens[0],
        quantizer,
        len(tokenizer),
        n_codebooks_tts
    )

    return audio_signal


def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20, temperature=0.9):
    audio_data, sample_rate = torchaudio.load(audio_path)

    audio = resample(audio_data, sample_rate)

    codes = quantizer.encode(audio.reshape(1, -1))
    raw_tokens = codes + len(tokenizer)
    audio_tokens = raw_tokens[:n_codebooks_asr].view(1, -1)
    tokens = torch.cat([bos, soa, audio_tokens, eoa], dim=1)
    print(tokens)

    audio, sr = decode_audio_bigcodec(tokens[0], quantizer, len(tokenizer), n_codebooks_asr)
    torchaudio.save("decoded_audio.wav", audio.unsqueeze(0), sr)

    attention_mask = torch.ones(tokens.size(), device=device)

    output_text_tokens = model.generate(
        tokens,
        attention_mask=attention_mask,
        do_sample=True,
        temperature=0.95,
        top_k=100,
        top_p=0.99,
        repetition_penalty=1.1,
    )
    output_text_tokens = output_text_tokens.cpu()[0]
    # output_text_tokens = output_text_tokens[output_text_tokens < soa]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)

    return decoded_text



In [5]:
text = "count numbers from one to ten"
audio_signal, sr = infer_text_to_audio(text, model, tokenizer, quantizer, top_k=50, temperature=0.9)
torchaudio.save("male_loud_count_numbers.wav", audio_signal.unsqueeze(0), sr)


In [11]:
audio_path = "/content/asr_test.wav"
text = infer_audio_to_text(audio_path, model, tokenizer, quantizer, top_k=50, temperature=1.0)
print(text)


Inout sample rate: 24000
tensor([[151644, 151665, 157628, 156939, 152519, 153194, 157270, 156910, 152037,
         153581, 158721, 155582, 156218, 155833, 157150, 156359, 154447, 155112,
         156434, 158325, 158615, 154358, 156056, 155565, 156056, 158420, 155471,
         154715, 159834, 155854, 151715, 155448, 153194, 154783, 156801, 153033,
         157183, 152956, 152565, 152729, 154945, 152411, 156098, 158286, 159688,
         152543, 153603, 155098, 157600, 158263, 158403, 156827, 159649, 153978,
         156650, 154258, 155383, 156865, 158550, 157374, 156837, 155794, 154705,
         155973, 157996, 159826, 152960, 158220, 153517, 156679, 156494, 157823,
         154912, 153578, 158814, 159611, 152078, 152122, 158007, 154984, 158894,
         153597, 155189, 157745, 157718, 159509, 156695, 159019, 158308, 152618,
         153139, 151867, 155083, 153250, 154279, 154480, 156733, 158573, 157846,
         157470, 159103, 156127, 153680, 156996, 156375, 159431, 155412, 153497,
   