In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub without embedding a token in the
# notebook. When the HF_TOKEN environment variable is set it is used;
# otherwise token=None makes login() ask for the token interactively.
# An empty HF_TOKEN is treated the same as an unset one.
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; the adapter directory loaded later
# in this notebook lives underneath this mount point.
drive.mount(mountpoint="/content/drive")

In [None]:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint locations: the base model id and the adapter directory on the
# mounted Drive.
base_model = "canopylabs/orpheus-3b-0.1-pretrained"
lora_path = "/content/drive/MyDrive/orpheus-tts-adapter"

# The tokenizer is read from the adapter directory, not from the base repo.
tokenizer = AutoTokenizer.from_pretrained(lora_path)

# Load the base weights in float16 on the GPU, then wrap them with the
# adapter stored at lora_path.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16).to("cuda"),
    lora_path,
)

In [None]:
# Uncomment and run once on a fresh runtime: the next cell imports the `snac` package.
#!pip install git+https://github.com/hubertsiuzdak/snac.git

In [None]:
from IPython.display import Audio, display
from snac import SNAC

# Load the SNAC checkpoint and move it to the GPU in a single expression.
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to("cuda")

# Texts to synthesise; one audio sample is produced per entry.
prompts = [
    "ASAN Xidmət mərkəzinə xoş gəlmisiniz! Büyürun, müraciyətiniz nə məsələ ilə bağlıdır?",
]

# Optional voice name. When set, it is prepended to every prompt as
# "<voice>: " before tokenisation; None leaves the prompts untouched.
chosen_voice = None

In [None]:
# Prefix each prompt with the chosen voice, if one was selected.
prompts_ = [(f"{chosen_voice}: " + text) if chosen_voice else text for text in prompts]

# Tokenise every prompt separately; each result keeps its own length.
all_input_ids = [tokenizer(text, return_tensors="pt").input_ids for text in prompts_]

# Tokens
start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human

# Prepare input sequences: <start of human> + prompt + <end of text, end of human>
all_modified_input_ids = [
    torch.cat([start_token, ids, end_tokens], dim=1) for ids in all_input_ids
]

# Padding and attention mask: sequences are left-padded with token 128263
# up to the longest one, and the mask is 0 over the padding and 1 elsewhere.
max_length = max(seq.shape[1] for seq in all_modified_input_ids)
all_padded_tensors = []
all_attention_masks = []
for seq in all_modified_input_ids:
    seq_len = seq.shape[1]
    pad_len = max_length - seq_len
    pad_ids = torch.full((1, pad_len), 128263, dtype=torch.int64)
    pad_mask = torch.zeros((1, pad_len), dtype=torch.int64)
    keep_mask = torch.ones((1, seq_len), dtype=torch.int64)
    all_padded_tensors.append(torch.cat([pad_ids, seq], dim=1))
    all_attention_masks.append(torch.cat([pad_mask, keep_mask], dim=1))

# Stack the per-prompt rows into one batch and move it to the GPU.
input_ids = torch.cat(all_padded_tensors, dim=0).to("cuda")
attention_mask = torch.cat(all_attention_masks, dim=0).to("cuda")

# Generate output
# Sampling options are gathered in one mapping and unpacked into generate();
# generation stops at token 128258 or after 1200 new tokens.
generation_kwargs = dict(
    max_new_tokens=1200,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    num_return_sequences=1,
    eos_token_id=128258,
    use_cache=True,
)
generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    **generation_kwargs,
)

# Postprocessing
token_to_find = 128257    # everything up to and including this token is discarded
token_to_remove = 128258  # stripped from every row (same id as eos_token_id above)
indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

if len(indices[1]) > 0:
    # Crop after the last reported occurrence of token_to_find.
    # NOTE(review): a single column index is applied to every row of the
    # batch — confirm that is intended when several prompts are generated.
    start_idx = indices[1][-1].item()
    cropped = generated_ids[:, start_idx+1:]
else:
    # Marker never generated: keep the rows as they are.
    cropped = generated_ids

processed_rows = [row[row != token_to_remove] for row in cropped]

# Decode logic
# Each row is cut down to a whole number of 7-code frames, and the offset
# 128266 is subtracted in one vectorised operation. `.tolist()` yields plain
# Python ints, which avoids creating one 0-d tensor per code.
code_lists = []
for row in processed_rows:
    usable_len = (row.size(0) // 7) * 7
    code_lists.append((row[:usable_len] - 128266).tolist())

def redistribute_codes(code_list):
    """Decode a flat list of audio codes into audio with ``snac_model``.

    The codes are consumed in frames of seven. Within a frame, position 0
    goes to layer 1, positions 1 and 4 go to layer 2, and positions 2, 3, 5
    and 6 go to layer 3. The code at position ``k`` has ``k * 4096``
    subtracted before it is stored.

    Args:
        code_list: Sequence of integer codes (Python ints or 0-d tensors).
            Items after the last complete frame of seven are ignored.

    Returns:
        The result of ``snac_model.decode`` for the three layer tensors,
        computed without gradient tracking.

    Raises:
        ValueError: If ``code_list`` contains fewer than seven codes.
    """
    frame_count = len(code_list) // 7
    if frame_count == 0:
        raise ValueError(
            "redistribute_codes needs at least 7 codes (one full frame), "
            f"got {len(code_list)}"
        )

    layer_1, layer_2, layer_3 = [], [], []
    for i in range(frame_count):
        base = 7 * i
        layer_1.append(code_list[base])
        layer_2.append(code_list[base + 1] - 4096)
        layer_3.append(code_list[base + 2] - 8192)
        layer_3.append(code_list[base + 3] - 12288)
        layer_2.append(code_list[base + 4] - 16384)
        layer_3.append(code_list[base + 5] - 20480)
        layer_3.append(code_list[base + 6] - 24576)

    # Build the three (1, n) tensors once, after all frames are collected,
    # and place them on the same device as the decoder's weights.
    device = next(snac_model.parameters()).device
    codes = [
        torch.tensor(layer).unsqueeze(0).to(device)
        for layer in (layer_1, layer_2, layer_3)
    ]

    # Inference only: skip building an autograd graph through the decoder.
    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat

# Decode every code list into an audio tensor.
my_samples = list(map(redistribute_codes, code_lists))

# Every prompt must have exactly one decoded sample.
if len(prompts) != len(my_samples):
    raise Exception("Number of prompts and samples do not match")

# Print each prompt followed by an inline player for its audio (24 kHz).
for idx, text in enumerate(prompts):
    print(text)
    display(Audio(my_samples[idx].detach().squeeze().cpu().numpy(), rate=24000))

# Drop the reference to the decoded audio tensors.
del my_samples