In [1]:
import os

# Expose only GPU index 0 to this process. This is set here, before torch is
# imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
from safetensors.torch import load_file

import transformers
import torch
from torch.utils.data import Dataset
from datasets import load_dataset, Audio, Value

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
)

from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal

In [3]:
# Device string used by the dataset and the encode/decode helpers below when
# they move tensors with `.to(device)`.
device = "cuda:0"
# Number of sequence positions Vikhr4oDataset.__getitem__ reserves for special
# tokens when it computes how many audio tokens fit under max_seq_length.
n_special_tokens = 3

In [4]:
# Checkpoint directory; both the tokenizer and the model weights are read from it.
model_path = "results_asr_semantic/checkpoint-16000"

tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")

# device_map={"": 0} places the whole model on device index 0.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    cache_dir=".",
    attn_implementation="eager",
    device_map={"": 0},
)

In [5]:
class Vikhr4oDataset(Dataset):
    """Yield, for each example, an audio-token prompt and a text-token prompt.

    Each item is a dict with:
        "audio_tokens": [<soa>, audio tokens..., <eoa>]
        "text_tokens":  [text tokens..., <soa>]

    The class reads several notebook-level globals instead of taking them as
    arguments: ``device``, ``start_audio_token``, ``end_audio_token`` and
    ``end_sequence_token`` at construction time, and ``n_codebooks``,
    ``max_seq_length`` and ``n_special_tokens`` on every item access. All of
    them must be defined before the corresponding call.

    Args:
        dataset: indexable collection whose rows provide ``row["text"]`` and
            ``row["audio"]["array"]``.
        tokenizer: callable tokenizer returning ``{"input_ids": tensor}`` when
            called with ``return_tensors="pt"``; must support ``len()``.
        quantizer: object exposing ``encode(audio)`` that returns discrete codes.
    """

    def __init__(self, dataset, tokenizer, quantizer):
        self.dataset = dataset
        self.tokenizer = tokenizer

        self.soa = self._special_token_id(start_audio_token)
        self.eoa = self._special_token_id(end_audio_token)
        self.eos = self._special_token_id(end_sequence_token)

        # Audio code k is mapped to token id k + n_original_tokens.
        # NOTE(review): 1024 is presumably the number of audio tokens appended
        # to the vocabulary (the quantizer codebook size) - confirm.
        self.n_original_tokens = len(tokenizer) - 1024
        self.quantizer = quantizer

    def _special_token_id(self, token):
        """Return the last input id of ``token`` as a (1, 1) tensor on ``device``."""
        return self.tokenizer(token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    def __len__(self):
        return len(self.dataset)

    def quantize(self, example):
        """Encode one example's waveform into audio token ids.

        Returns the quantizer codes with ``n_original_tokens`` added, so they
        index the audio part of the tokenizer vocabulary.
        """
        # NOTE(review): the waveform is passed to the quantizer as-is; its
        # sampling rate is not checked or converted here - confirm the dataset
        # rate matches what the quantizer expects.
        audio_data = example["audio"]["array"]

        # audio -> discrete codes; the waveform is shaped (1, 1, n_samples)
        audio = torch.tensor(audio_data).view(1, 1, len(audio_data)).float()
        audio = audio.to(device)
        # Encoding is inference only, so no autograd graph is needed.
        with torch.no_grad():
            codes = self.quantizer.encode(audio)
        codes = codes.squeeze(1)

        # Drop the on-device waveform and let the allocator release cached blocks
        del audio
        torch.cuda.empty_cache()

        # shift code ids past the original vocabulary
        return codes + self.n_original_tokens

    def __getitem__(self, idx):
        row = self.dataset[idx]

        # text -> token ids
        text_tokenized = self.tokenizer(row["text"], return_tensors="pt")
        text_input_tokens = text_tokenized["input_ids"].to(device)

        # audio -> token ids, keeping the first n_codebooks rows flattened into one sequence
        codes = self.quantize(row)
        audio_input_tokens = codes[:n_codebooks].contiguous().view(1, -1)

        # Number of audio tokens that fit next to the text and the special tokens.
        # Clamp at 0: a negative value would make the slice below count from the
        # END of the audio and keep almost all of it instead of nothing.
        budget = max_seq_length - text_input_tokens.shape[-1] - n_special_tokens
        audio_length = max(0, min(budget, audio_input_tokens.shape[-1]))
        # keep a whole number of n_codebooks-sized groups
        audio_length -= audio_length % n_codebooks

        audio_tokens = torch.cat([self.soa, audio_input_tokens[:, :audio_length], self.eoa], dim=1)
        text_tokens = torch.cat([text_input_tokens, self.soa], dim=1)

        return {
            "audio_tokens": audio_tokens,
            "text_tokens": text_tokens,
        }

In [6]:
def get_audio_padding_tokens(quantizer):
    """Encode a silent input and return its codes for use as padding.

    A single zero-valued sample is encoded; per the original author this seems
    to work better than random padding when the length of the generated audio
    is not divisible by n_codebooks.

    Returns:
        dict with key "audio_tokens" holding ``quantizer.encode(...)`` output
        with dimension 1 squeezed out.
    """
    silence = torch.zeros((1, 1, 1)).to(device)
    silent_codes = quantizer.encode(silence)

    # The input tensor is no longer needed; release it and any cached blocks
    del silence
    torch.cuda.empty_cache()

    return {"audio_tokens": silent_codes.squeeze(1)}
    


def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    """Decode the audio span of a generated token sequence into an AudioSignal.

    Reads the notebook-level globals ``start_audio_token_id``,
    ``end_audio_token_id``, ``n_codebooks`` and ``device``.

    Args:
        tokens: token ids, sliced here as a 1-D sequence (prompt + generation).
        quantizer: object exposing ``decode(codes)`` and ``sample_rate``.
        pad_tokens: codes used to fill up the sequence when its length is not
            a multiple of ``n_codebooks``; any shape, it is flattened here.
        n_original_tokens: offset that was added to audio codes to obtain
            token ids.

    Returns:
        AudioSignal built from the decoded waveform and ``quantizer.sample_rate``.
    """
    # locate the audio span: after the first <soa>, up to the first <eoa>
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)

    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # Map token ids back to code ids. Modulo equals subtracting the offset for
    # ids in [n_original_tokens, 2 * n_original_tokens); ids below the offset
    # (non-audio tokens) pass through unchanged.
    audio_tokens = tokens[start:end] % n_original_tokens
    reminder = audio_tokens.shape[-1] % n_codebooks

    if reminder:
        # Complete the last group with the missing n_codebooks - reminder codes.
        # The padding is flattened and moved to the tokens' device so that
        # torch.cat receives 1-D tensors on the same device.
        padding = pad_tokens.reshape(-1)[reminder:n_codebooks].to(audio_tokens.device)
        audio_tokens = torch.cat([audio_tokens, padding], dim=0)

    # NOTE(review): codes are laid out as (1, n_codebooks, T) here - confirm
    # this matches the layout quantizer.decode expects when n_codebooks > 1.
    codes = audio_tokens.view(1, n_codebooks, -1).to(device)
    # Decoding is inference only, so no autograd graph is needed.
    with torch.no_grad():
        audio = quantizer.decode(codes).squeeze(0)

    # release the on-device codes before returning
    del codes
    torch.cuda.empty_cache()

    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)

    

In [7]:
def freeze_entire_model(model):
    """Set requires_grad to False on every parameter of `model` and return it."""
    for param in model.parameters():
        param.requires_grad = False
    return model

In [8]:
# Number of quantizer codebooks used per audio frame throughout the notebook.
n_codebooks = 1

# SpeechTokenizer config and weights
config_path = "./audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "./audiotokenizer/SpeechTokenizer.pt"

quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()

# Move every direct submodule to the language model's device and freeze its
# parameters (Module.to returns the module itself).
for n, child in quantizer.named_children():
    freeze_entire_model(child.to(model.device))

codebook_size = quantizer.quantizer.bins

  WeightNorm.apply(module, name, dim)
  params = torch.load(ckpt_path, map_location='cpu')


In [9]:
def prepare_librispeech():
    """Load the "clean" LibriSpeech configuration and normalise its columns.

    The "chapter_id" column is dropped and "speaker_id" is cast to string.
    """
    librispeech = load_dataset("openslr/librispeech_asr", "clean", cache_dir=".")
    return (
        librispeech
        .remove_columns(["chapter_id"])
        .cast_column("speaker_id", Value("string"))
    )

In [10]:
# Special-token strings; Vikhr4oDataset reads these globals in its constructor.
start_audio_token, end_audio_token, end_sequence_token = "<soa>", "<eoa>", "<eos>"

# Offset between audio code ids and their token ids (passed to decode_audio).
n_tokens = len(tokenizer) - 1024

# Only the validation split is used for the generation examples.
dataset = prepare_librispeech()
val_data = dataset["validation"]
val_dataset = Vikhr4oDataset(val_data, tokenizer, quantizer)

# Codes of a silent input, used by decode_audio to fill an incomplete last frame.
padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"]

Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

In [11]:
n_examples = 5
max_seq_length = 1024
# Ids of <soa>/<eoa>; decode_audio reads these globals to find the audio span.
start_audio_token_id = tokenizer(start_audio_token)["input_ids"][-1]
end_audio_token_id = tokenizer(end_audio_token)["input_ids"][-1]

# Generated audio is written under "tests/"; create the folder so the write
# below does not depend on it already existing.
os.makedirs("tests", exist_ok=True)

# Both generations sample (do_sample=True, top_k=20) and no seed is set in the
# visible cells, so outputs differ between runs.
for i in range(n_examples):
    # Build a fresh dict on the shared `device` instead of mutating the row in place.
    row = {name: tensor.to(device) for name, tensor in val_dataset[i].items()}

    # GT
    print("GT:", tokenizer.decode(row["text_tokens"][0], skip_special_tokens=True))

    # audio: text prompt ending in <soa> -> generated audio tokens -> waveform
    attention_mask = torch.ones(row["text_tokens"].size(), device=model.device)
    output_audio = model.generate(
        row["text_tokens"],
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=20,
        do_sample=True,
    )
    audio = decode_audio(output_audio.cpu()[0], quantizer, padding_tokens, n_tokens)
    audio.write(f"tests/audio_{i}.wav")

    # text: <soa> audio <eoa> prompt -> generated text
    attention_mask = torch.ones(row["audio_tokens"].size(), device=model.device)
    output_text = model.generate(
        row["audio_tokens"],
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=20,
        do_sample=True,
    )
    text = tokenizer.decode(output_text.cpu()[0], skip_special_tokens=True)
    print(text)

    print()

GT: HE WAS IN A FEVERED STATE OF MIND OWING TO THE BLIGHT HIS WIFE'S ACTION THREATENED TO CAST UPON HIS ENTIRE FUTURE




<audio_token_334><audio_token_334><audio_token_334><audio_token_334><audio_token_334><audio_token_334><audio_token_338><audio_token_313><audio_token_51><audio_token_405><audio_token_405><audio_token_86><audio_token_567><audio_token_567><audio_token_100><audio_token_91><audio_token_91><audio_token_1001><audio_token_977><audio_token_977><audio_token_828><audio_token_671><audio_token_671><audio_token_185><audio_token_904><audio_token_279><audio_token_532><audio_token_651><audio_token_651><audio_token_124><audio_token_589><audio_token_312><audio_token_248><audio_token_248><audio_token_714><audio_token_714><audio_token_101><audio_token_101><audio_token_332><audio_token_577><audio_token_379><audio_token_379><audio_token_984><audio_token_984><audio_token_249><audio_token_750><audio_token_750><audio_token_884><audio_token_571><audio_token_140><audio_token_140><audio_token_778><audio_token_289><audio_token_293><audio_token_406><audio_token_406><audio_token_855><audio_token_855><audio_token_388>