In [1]:
from snac import SNAC
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, AutoTokenizer
import torch
import librosa
from IPython.display import display, Audio
import soundfile as sf


class ModelInfer:
    def __init__(self, model_name, tokenizer_name="canopylabs/orpheus-3b-0.1-ft"):
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model = self.snac_model.to("cpu")

        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        self.model.cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    
    def redistribute_codes(self, code_list):
        layer_1 = []
        layer_2 = []
        layer_3 = []
        for i in range((len(code_list)+1)//7):
            layer_1.append(code_list[7*i])
            layer_2.append(code_list[7*i+1]-4096)
            layer_3.append(code_list[7*i+2]-(2*4096))
            layer_3.append(code_list[7*i+3]-(3*4096))
            layer_2.append(code_list[7*i+4]-(4*4096))
            layer_3.append(code_list[7*i+5]-(5*4096))
            layer_3.append(code_list[7*i+6]-(6*4096))
        codes = [torch.tensor(layer_1).unsqueeze(0),
                torch.tensor(layer_2).unsqueeze(0),
                torch.tensor(layer_3).unsqueeze(0)]
        audio_hat = self.snac_model.decode(codes)
        return audio_hat

    def tokenise_audio(self, waveform):
        waveform = torch.from_numpy(waveform).unsqueeze(0)
        waveform = waveform.to(dtype=torch.float32)


        waveform = waveform.unsqueeze(0)

        with torch.inference_mode():
            codes = self.snac_model.encode(waveform)

        all_codes = []
        for i in range(codes[0].shape[1]):
            all_codes.append(codes[0][0][i].item()+128266)
            all_codes.append(codes[1][0][2*i].item()+128266+4096)
            all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
            all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
            all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
            all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
            all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))


        return all_codes
    
    def output_to_speech(self, generated_ids):
        token_to_find = 128257
        token_to_remove = 128258

        token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
        else:
            cropped_tensor = generated_ids

        mask = cropped_tensor != token_to_remove

        processed_rows = []

        for row in cropped_tensor:
            masked_row = row[row != token_to_remove]
            processed_rows.append(masked_row)

        code_lists = []

        for row in processed_rows:
            row_length = row.size(0)
            new_length = (row_length // 7) * 7
            trimmed_row = row[:new_length]
            trimmed_row = [t - 128266 for t in trimmed_row]
            code_lists.append(trimmed_row)

        my_samples = []
        for code_list in code_lists:
            samples = self.redistribute_codes(code_list)
            my_samples.append(samples)
        
        return my_samples


    def generate_from_voice(self, texts: list, voice: str, temp=0.6, repetition_penalty=1.1):
        # 1
        prompts = [f"{voice}: " + t for t in texts]

        all_input_ids = []

        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            all_input_ids.append(input_ids)

        start_token = torch.tensor([[ 128259]], dtype=torch.int64) # Start of human
        end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End of text, End of human

        all_modified_input_ids = []
        for input_ids in all_input_ids:
            modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1) # SOH SOT Text EOT EOH
            all_modified_input_ids.append(modified_input_ids)

        all_padded_tensors = []
        all_attention_masks = []
        max_length = max([modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids])
        for modified_input_ids in all_modified_input_ids:
            padding = max_length - modified_input_ids.shape[1]
            padded_tensor = torch.cat([torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1)
            attention_mask = torch.cat([torch.zeros((1, padding), dtype=torch.int64), torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)], dim=1)
            all_padded_tensors.append(padded_tensor)
            all_attention_masks.append(attention_mask)

        all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
        all_attention_masks = torch.cat(all_attention_masks, dim=0)

        input_ids = all_padded_tensors.to("cuda")
        attention_mask = all_attention_masks.to("cuda")

        # 2
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=2500,
                do_sample=True,
                temperature=temp,
                top_p=0.95,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,
            )

        # 3
        return self.output_to_speech(generated_ids)


    def generate_from_audio(self, texts: list, audio_path, audio_text, temp=0.5, repetition_penalty=1.1):
        audio_array, sample_rate = librosa.load(audio_path, sr=24000)

        myts = self.tokenise_audio(audio_array)
        start_tokens = torch.tensor([[ 128259]], dtype=torch.int64)
        end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
        final_tokens = torch.tensor([[128258, 128262]], dtype=torch.int64)
        prompt_tokked = self.tokenizer(audio_text, return_tensors="pt")

        input_ids = prompt_tokked["input_ids"]

        zeroprompt_input_ids = torch.cat([start_tokens, input_ids, end_tokens, torch.tensor([myts]), final_tokens], dim=1) # SOH SOT Text EOT EOH


        all_modified_input_ids = []
        for text in texts:
            input_ids = self.tokenizer(text, return_tensors="pt").input_ids
            second_input_ids = torch.cat([zeroprompt_input_ids, start_tokens, input_ids, end_tokens], dim=1)
            all_modified_input_ids.append(second_input_ids)


        all_padded_tensors = []
        all_attention_masks = []

        max_length = max([modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids])

        for modified_input_ids in all_modified_input_ids:
            padding = max_length - modified_input_ids.shape[1]
            padded_tensor = torch.cat([torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1)
            attention_mask = torch.cat([torch.zeros((1, padding), dtype=torch.int64), torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)], dim=1)
            all_padded_tensors.append(padded_tensor)
            all_attention_masks.append(attention_mask)

        all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
        all_attention_masks = torch.cat(all_attention_masks, dim=0)

        input_ids = all_padded_tensors.to("cuda")
        attention_mask = all_attention_masks.to("cuda")

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                # attention_mask=attention_mask,
                max_new_tokens=1500,
                do_sample=True,
                temperature=temp,
                # top_k=40,
                top_p=0.9,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,
                # end_token_id=128009
            )

        return self.output_to_speech(generated_ids)
        

In [2]:
model = ModelInfer('intedont/orpheus_woman_5000')

  state_dict = torch.load(model_path, map_location="cpu")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
audios = model.generate_from_audio(texts=["<Angry> Mister lawson saw george last night. <Happy> And this is great, let's add some words to make sentence longer!"], 
                        audio_path="0015_000033.wav",
                        audio_text="This used to be Jerry's occupation.",
                        temp=0.6,
                        repetition_penalty=1.1)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128258 for open-end generation.


In [16]:
from IPython.display import display, Audio
for i in range(len(audios)):
    samples = audios[i]
    display(Audio(samples.detach().squeeze().to("cpu").numpy(), rate=24000))


In [17]:
import soundfile as sf

sf.write("output.wav", audios[0].detach().squeeze().to("cpu").numpy(), 24000)