In [2]:
import os

# Expose only physical GPU 1 to this process. This has to happen before CUDA is
# initialised; afterwards that GPU is addressed as device 0 ("cuda:0").
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [3]:
from safetensors.torch import load_file

import transformers
import torch
from torch.utils.data import Dataset
from datasets import load_from_disk, Audio

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    BitsAndBytesConfig,
)
from peft import LoraConfig, PeftModelForCausalLM

import dac
from audiotools import AudioSignal



In [5]:
model_path = "results2/checkpoint-8000"  # fine-tuned checkpoint (tokenizer + LoRA adapter)
base_model = "google/gemma-2-2b"

# LoRA settings covering every attention and MLP projection.
# NOTE(review): presumably mirrors the fine-tuning run — confirm against the
# checkpoint's adapter_config.json.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# The tokenizer comes from the checkpoint because its vocabulary is larger than
# the base model's (it includes the added audio tokens).
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    cache_dir=".",
    attn_implementation="eager",
    device_map={"": 0},
)
# Grow the embedding matrix to the enlarged vocabulary before the adapter is attached.
model.resize_token_embeddings(len(tokenizer))

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Embedding(257025, 2304, padding_idx=0)

In [6]:
# Attach the LoRA adapter saved in `model_path` to the resized base model, then merge
# the adapter weights into the base weights so generation runs on a single plain model.
# NOTE(review): PEFT's `PeftModel.from_pretrained` receives an explicit adapter config via
# `config=`; `peft_config=` is not that parameter, so `lora_config` is most likely ignored
# and the adapter config stored in the checkpoint is used instead — confirm.
peft_model = PeftModelForCausalLM.from_pretrained(model, model_path, adapter_name="lora", peft_config=lora_config)
merged_model = peft_model.merge_and_unload()

In [14]:
max_seq_length = 512  # token budget per example: prompt + audio markers + reference audio


class TestDataset(Dataset):
    """Validation dataset producing generation prompts for the text-to-audio LM.

    Each item contains:
      * ``input_ids`` / ``attention_mask``: the tokenized text followed by the
        start-of-audio token (the prompt handed to ``generate``);
      * ``labels``: the reference audio as a flat token sequence. Frames are in
        time order, and each frame's ``n_codebooks`` codes are followed by the
        end-of-frame token. The sequence is truncated to whole frames so that
        prompt + audio + markers fit in ``max_seq_length``, and it is terminated
        by the end-of-audio token.

    Reads the notebook globals ``start_audio_token``, ``end_audio_token``,
    ``end_frame_token``, ``n_codebooks`` and ``max_seq_length`` at call time, so
    they must be defined before the dataset is constructed and indexed.
    """

    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer

        # `[:, 1:]` drops the first id, assuming the tokenizer prepends a BOS token.
        self.start_audio_token = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, 1:]
        self.end_audio_token = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, 1:]
        # The per-frame separator is the end-of-*frame* token (previously the
        # end-of-audio token was tokenized here by mistake).
        self.end_frame_token = tokenizer(end_frame_token, return_tensors="pt")["input_ids"][:, 1:].item()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        text_input_tokens = self.tokenizer(row["text"], return_tensors="pt")["input_ids"]

        # Stored codes are (1, codebooks, frames); keep only the first n_codebooks.
        raw_audio_tokens = torch.tensor(row["audio_tokens"])[:, :n_codebooks]
        n_frames = raw_audio_tokens.shape[-1]
        # Append an extra "codebook" row of end-of-frame tokens so every frame ends with it.
        end_frame_row = torch.full((1, 1, n_frames), self.end_frame_token, dtype=raw_audio_tokens.dtype)
        raw_audio_tokens = torch.cat([raw_audio_tokens, end_frame_row], dim=1)

        # (1, codebooks + 1, frames) -> (frames, 1, codebooks + 1) -> flat, frame-major sequence.
        audio_input_tokens = raw_audio_tokens.permute(2, 0, 1).contiguous().view(1, -1)

        # Every frame occupies n_codebooks codes plus one end-of-frame token.
        frame_size = n_codebooks + 1
        # Reserve two slots for the start/end-of-audio markers; clamp at 0 so an
        # over-long text yields no audio instead of a slice counted from the end.
        audio_budget = max_seq_length - text_input_tokens.shape[-1] - 2
        audio_length = max(0, min(audio_budget, audio_input_tokens.shape[-1]))
        # Keep whole frames only.
        audio_length -= audio_length % frame_size

        input_tokens = torch.cat([text_input_tokens, self.start_audio_token], dim=1)
        labels = torch.cat([audio_input_tokens[:, :audio_length], self.end_audio_token], dim=1)
        attention_mask = torch.ones_like(input_tokens)

        return {
            "input_ids": input_tokens,
            "attention_mask": attention_mask,
            "labels": labels,
        }

In [8]:
def get_audio_padding_tokens(quantizer, n_original_tokens):
    """Encode a single silent sample and return its DAC codes, for use as padding.

    Completing an incomplete final frame (generated length not divisible by
    ``n_codebooks``) with the codes of silence seems to work better than using
    random codes.

    Args:
        quantizer: DAC model; its ``device``, ``sample_rate``, ``preprocess`` and
            ``encode`` are used.
        n_original_tokens: unused; kept so existing call sites keep working.

    Returns:
        dict with key ``"audio_tokens"``: the codes of the first batch item,
        transposed. Assuming DAC returns codes as (batch, codebooks, frames), this
        is (frames, codebooks). The tensor lives on ``quantizer.device``.
    """
    silence = torch.zeros((1, 1, 1), device=quantizer.device)

    # Pure inference: avoid building an autograd graph for the encoder activations.
    with torch.no_grad():
        x = quantizer.preprocess(silence, quantizer.sample_rate)
        _, codes, _, _, _ = quantizer.encode(x)

    # Release the intermediate inputs and return cached GPU blocks (no-op without CUDA).
    del silence, x
    torch.cuda.empty_cache()

    return {"audio_tokens": codes[0].t()}
    


def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    """Turn a generated token sequence back into audio with the DAC decoder.

    The audio span runs from just after the first start-of-audio token (or the
    beginning) up to the first EOS after it (or the end). Within that span:
      1. end-of-frame separators are removed;
      2. the remaining ids are mapped to codebook indices via ``% n_original_tokens``;
      3. an incomplete trailing frame is completed from ``pad_tokens``;
      4. the codes are decoded.

    Reads the notebook globals ``start_audio_token_id``, ``end_frame_token_id``,
    ``tokenizer`` (for ``eos_token_id``), ``n_codebooks`` and ``AudioSignal``.

    Args:
        tokens: (1, seq_len) tensor of token ids, e.g. the output of ``generate``.
            It is not modified.
        quantizer: DAC model used for decoding.
        pad_tokens: codes of a silent frame indexed as (frame, codebook), as
            returned by ``get_audio_padding_tokens``.
        n_original_tokens: vocabulary offset of the first audio code; token ids
            are reduced modulo this value.

    Returns:
        AudioSignal with the decoded waveform at ``quantizer.sample_rate``.

    Raises:
        ValueError: if the sequence contains no audio tokens.
    """
    sequence = tokens[0]

    # Audio starts right after the first start-of-audio marker (or at 0 if absent)...
    start_hits = torch.nonzero(sequence == start_audio_token_id).flatten()
    start = start_hits[0].item() + 1 if len(start_hits) else 0

    # ...and stops at the first EOS following it (or at the end of the sequence).
    eos_hits = torch.nonzero(sequence[start:] == tokenizer.eos_token_id).flatten()
    end = start + eos_hits[0].item() if len(eos_hits) else sequence.shape[-1]

    audio_tokens = sequence[start:end]
    # Drop the per-frame separators so only codebook ids remain.
    audio_tokens = audio_tokens[audio_tokens != end_frame_token_id]
    # Subtract the original vocabulary offset -> codebook indices.
    audio_tokens = (audio_tokens % n_original_tokens).unsqueeze(0)

    if audio_tokens.shape[-1] == 0:
        raise ValueError("no audio tokens found in the generated sequence")

    remainder = audio_tokens.shape[-1] % n_codebooks
    if remainder:
        # Complete the last frame with the remaining codebooks of one silent frame.
        filler = pad_tokens[:1, remainder:n_codebooks].to(
            device=audio_tokens.device, dtype=audio_tokens.dtype
        )
        audio_tokens = torch.cat([audio_tokens, filler], dim=1)

    # (1, frames * codebooks) -> (1, codebooks, frames), the layout passed to from_codes.
    codes = audio_tokens.view(1, -1, n_codebooks).permute(0, 2, 1).to(quantizer.device)
    with torch.no_grad():
        z = quantizer.quantizer.from_codes(codes)[0]
        audio = quantizer.decode(z)

    del audio_tokens, codes, z
    torch.cuda.empty_cache()

    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)

In [9]:
n_codebooks = 4  # codes per audio frame used by the LM token layout

# Fetch the 16 kHz DAC weights and keep the codec on the CPU.
quant_path = dac.utils.download(model_type="16khz")
quantizer = dac.DAC.load(quant_path, n_codebooks=n_codebooks)
quantizer = quantizer.to("cpu")

  model_dict = torch.load(location, "cpu")
  WeightNorm.apply(module, name, dim)


In [15]:
# Marker tokens used to lay out audio in the LM sequence.
start_audio_token = "<soa>"  # placed after the text prompt, before the audio codes
end_audio_token = "<eos>"    # terminates the audio sequence
end_frame_token = "<eof>"    # separator after each audio frame
n_tokens = 256_000           # size of the original text vocabulary

val_data = load_from_disk("./data/processed/val")
val_dataset = TestDataset(val_data, tokenizer)

# Codes of a silent frame, used to complete a trailing partial frame when decoding.
padding_tokens = get_audio_padding_tokens(
    quantizer,
    n_tokens + 1,
)["audio_tokens"]

In [None]:
n_examples = 5
# Use exactly the id the dataset appends to every prompt rather than a hard-coded
# 256000, so decode_audio locates the audio span with the same token the model saw.
start_audio_token_id = val_dataset.start_audio_token[0, -1].item()
end_frame_token_id = tokenizer(end_frame_token)["input_ids"][-1]

# Make sure the output directory exists before writing the wav files.
os.makedirs("tests", exist_ok=True)

for i in range(min(n_examples, len(val_dataset))):
    row = val_dataset[i]
    # Only the prompt goes to generate(); `labels` holds the reference audio and
    # would otherwise be forwarded to the model as an extra keyword argument.
    prompt = {key: row[key].to("cuda:0") for key in ("input_ids", "attention_mask")}

    output = merged_model.generate(**prompt, max_new_tokens=max_seq_length)
    audio = decode_audio(output.cpu(), quantizer, padding_tokens, n_tokens + 1)
    audio.write(f"tests/audio_{i}.wav")

    print(tokenizer.decode(row["input_ids"][0], skip_special_tokens=True))
    print()