In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
from safetensors.torch import load_file

import transformers
import torch
from torch.utils.data import Dataset
from datasets import load_dataset, Audio, Value

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
)

from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal

In [3]:
device = "cuda:0"
n_special_tokens = 3

In [4]:
model_path = "AlexWortega/nemo_asr_mixed"

tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")
model = AutoModelForCausalLM.from_pretrained(model_path, cache_dir=".", attn_implementation="eager", device_map={"":0})

tokenizer_config.json:   0%|          | 0.00/368k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/18.3M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/339 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/721 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.83G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.16G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.94 GiB. GPU 0 has a total capacity of 44.55 GiB of which 1.90 GiB is free. Process 79061 has 10.86 GiB memory in use. Process 93396 has 18.82 GiB memory in use. Including non-PyTorch memory, this process has 12.95 GiB memory in use. Of the allocated memory 12.69 GiB is allocated by PyTorch, and 2.45 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
class Vikhr4oDataset(Dataset):
    def __init__(self, dataset, tokenizer, quantizer):
        self.dataset = dataset
        self.tokenizer = tokenizer

        self.soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
        self.eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
        self.eos = tokenizer(end_sequence_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

        self.n_original_tokens = len(tokenizer) - 1024
        self.quantizer = quantizer 

    def __len__(self):
        return len(self.dataset)

    def quantize(self, example):
        audio_data, sample_rate = example["audio"]["array"], example["audio"]["sampling_rate"]
    
        # audio -> discrete codes
        audio = torch.tensor(audio_data).view(1, 1, len(audio_data)).float()
        audio = audio.to(device)
        codes = self.quantizer.encode(audio)
        codes = codes.squeeze(1)
    
        # Move tensor back to CPU and delete it to free GPU memory
        del audio
        torch.cuda.empty_cache()
    
        # increment tokens' ids 
        return codes + self.n_original_tokens

    def __getitem__(self, idx):
        row = self.dataset[idx]

        # get text tokens 
        text = row["text"]
        text_tokenized = self.tokenizer(text, return_tensors="pt")
        text_input_tokens = text_tokenized["input_ids"].to(device)

        # quantize audio 
        codes = self.quantize(row)
        raw_audio_tokens = codes[:n_codebooks]
        
        audio_input_tokens = raw_audio_tokens.contiguous().view(1, -1)

        # determine number of audio tokens given max_seq_length 
        audio_length = min(max_seq_length - text_input_tokens.shape[-1] - n_special_tokens, audio_input_tokens.shape[-1])
        audio_length -= audio_length % n_codebooks

        audio_tokens = torch.cat([self.soa, audio_input_tokens[:, :audio_length], self.eoa], dim=1)
        text_tokens = torch.cat([text_input_tokens, self.soa], dim=1)

        return {
            "audio_tokens": audio_tokens, 
            "text_tokens": text_tokens,
        }

In [None]:
def get_audio_padding_tokens(quantizer):
    # create audio without any sounds 
    # seems to work better than radom padding if 
    # length of generated audio is not devisible by n_codebooks
    audio = torch.zeros((1, 1, 1))
    audio = audio.to(device)
    
    codes = quantizer.encode(audio)

    # Move tensor back to CPU and delete it to free GPU memory
    del audio
    torch.cuda.empty_cache()
    
    return {"audio_tokens": codes.squeeze(1)}
    


def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    # find start and end indices of audio tokens 
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)
    
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]
    
    # substract length of original vocabulary -> tokens in range [0, 1024)
    audio_tokens = tokens[start:end] % n_original_tokens
    reminder = audio_tokens.shape[-1] % n_codebooks
    
    if reminder:
        # pad if last frame is incomplete 
        audio_tokens = torch.cat([audio_tokens, pad_tokens[reminder:n_codebooks]], dim=0)

    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)

    del tokens 
    del audio_tokens 
    torch.cuda.empty_cache()
    
    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)

    

In [None]:
def freeze_entire_model(model):
    for n, p in model.named_parameters():
        p.requires_grad = False
    return model

In [None]:
n_codebooks = 3

config_path = "./audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "./audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()

for n, child in quantizer.named_children():
    child.to(model.device)
    child = freeze_entire_model(child)

codebook_size = quantizer.quantizer.bins

In [None]:
def prepare_librispeech():
    raw = load_dataset("openslr/librispeech_asr", "clean", cache_dir=".")
    processed = raw.remove_columns(["chapter_id"])
    processed = processed.cast_column('speaker_id', Value('string'))
    return processed 

In [None]:
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"

n_tokens = len(tokenizer) - 1024

dataset = prepare_librispeech()
val_data = dataset["validation"]

val_dataset = Vikhr4oDataset(val_data, tokenizer, quantizer)

padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)

In [None]:
n_examples = 5
max_seq_length = 1024
start_audio_token_id = tokenizer(start_audio_token)["input_ids"][-1]
end_audio_token_id = tokenizer(end_audio_token)["input_ids"][-1]

for i in range(n_examples):
    row = val_dataset[i]

    for k, v in row.items():
        row[k] = v.to("cuda:0")

    # GT
    print("GT:", tokenizer.decode(row["text_tokens"][0], skip_special_tokens=True))

    # audio
    attention_mask=torch.ones(row["text_tokens"].size(), device=model.device)
    output_audio = model.generate(row["text_tokens"], attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=20, do_sample=True)
    audio = decode_audio(output_audio[0], quantizer, padding_tokens.t()[0], n_tokens)
    audio.write(f"tests/audio_mixed_{i}.wav")

    # text
    attention_mask=torch.ones(row["audio_tokens"].size(), device=model.device)
    output_text = model.generate(row["audio_tokens"], attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=20, do_sample=True)
    output_text = output_text.cpu()[0]
    output_text = output_text[output_text < start_audio_token_id]
    text = tokenizer.decode(output_text, skip_special_tokens=True)
    print(text)

    print()

In [24]:
decode_audio(output_audio[0], quantizer, padding_tokens.t()[0], n_tokens)

cuda:0


<audiotools.core.audio_signal.AudioSignal at 0x79293b99a4e0>

In [22]:
padding_tokens.t()

tensor([[ 328,   15,  842,  573,  612,  345, 1003,  399]], device='cuda:0')