In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen2ForCausalLM, Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import ALL_ATTENTION_FUNCTIONS, eager_attention_forward

import torch
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
from typing import Tuple
from pathlib import Path
from sparktts.utils.file import load_config
from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
import torch.cuda.nvtx as nvtx
import time
from transformers import generation

In [3]:
""""
python -m cli.inference \
    --text "text to synthesis." \
    --device 0 \
    --save_dir output \
    --model_dir /home/vishwa/small_projects/pretrained_model \
    --prompt_speech_path example/prompt_audio.wav
"""

'"\npython -m cli.inference     --text "text to synthesis."     --device 0     --save_dir output     --model_dir /home/vishwa/small_projects/pretrained_model     --prompt_speech_path example/prompt_audio.wav\n'

In [4]:
model_dir = "/home/vishwa/small_projects/pretrained_model"
device = torch.device("cuda:0")
text = "Hi! How are you?"
prompt_speech_path = "example/prompt_audio.wav"
temperature = 0.8
top_k = 50
top_p = 0.95

In [5]:
tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}/LLM")
model = AutoModelForCausalLM.from_pretrained(f"{model_dir}/LLM", torch_dtype="float32", _attn_implementation="sdpa")
audio_tokenizer = BiCodecTokenizer(model_dir, device=device)
model.to(device)

  WeightNorm.apply(module, name, dim)


Missing tensor: mel_transformer.spectrogram.window
Missing tensor: mel_transformer.mel_scale.fb


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(166000, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [6]:
print(f"Model dtype: {next(model.parameters()).dtype}")

Model dtype: torch.float32


In [7]:
def process_prompt(
    text: str,
    prompt_speech_path: Path,
    prompt_text: str = None,
) -> Tuple[str, torch.Tensor]:
    """
    Process input for voice cloning.

    Args:
        text (str): The text input to be converted to speech.
        prompt_speech_path (Path): Path to the audio file used as a prompt.
        prompt_text (str, optional): Transcript of the prompt audio.

    Return:
        Tuple[str, torch.Tensor]: Input prompt; global tokens
    """

    global_token_ids, semantic_token_ids = audio_tokenizer.tokenize(
        prompt_speech_path
    )
    global_tokens = "".join(
        [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
    )

    # Prepare the input tokens for the model
    if prompt_text is not None:
        semantic_tokens = "".join(
            [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
        )
        inputs = [
            TASK_TOKEN_MAP["tts"],
            "<|start_content|>",
            prompt_text,
            text,
            "<|end_content|>",
            "<|start_global_token|>",
            global_tokens,
            "<|end_global_token|>",
            "<|start_semantic_token|>",
            semantic_tokens,
        ]
    else:
        inputs = [
            TASK_TOKEN_MAP["tts"],
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_global_token|>",
            global_tokens,
            "<|end_global_token|>",
        ]

    inputs = "".join(inputs)

    return inputs, global_token_ids


In [8]:
prompt, global_token_ids = process_prompt(text, prompt_speech_path)
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

In [9]:
model_inputs

{'input_ids': tensor([[165137, 165146,  13048,      0,   2585,    525,    498,     30, 165152,
         165150, 155028, 154032, 153256, 155290, 151959, 155221, 152863, 154311,
         155062, 155443, 154124, 154774, 152682, 155525, 154843, 155079, 153392,
         152226, 152761, 154798, 155312, 154843, 154432, 151738, 154083, 153503,
         155310, 154003, 154092, 153809, 151971, 154464, 165156]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}

In [10]:
# with torch.no_grad():
#     with profile(
#             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#             schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=3),
#             record_shapes=True,
#             profile_memory=True,
#             with_stack=True,
#             with_flops=True,
#             with_modules=True,
#             on_trace_ready=tensorboard_trace_handler('./profiler/forwardpass_sdpa')
#         ) as prof:
#             for _ in range(10):
#                 ip = torch.randint(low = 0, high=166000,size=(1,1000)).cuda()
#                 generated_ids = model(ip)
#                 prof.step()

In [11]:
generated_ids = model(model_inputs['input_ids'])

In [14]:
next_token_logits = generated_ids.logits[0, -1, :]

In [15]:
next_token_logits.shape

torch.Size([166000])

In [16]:
if temperature > 0:
        temp_logits = next_token_logits / temperature
    
    # Apply softmax to get probabilities
probs = torch.softmax(temp_logits, dim=-1)

sorted_probs, sorted_indices = torch.sort(probs, descending=False)

if top_k > 0:
    # Keep only top-k tokens
    sorted_probs = sorted_probs[:top_k]
    sorted_indices = sorted_indices[:top_k]

In [None]:
import triton
import triton.language as tl

@triton.jit
def fused_temperature_softmax_kernel(
    logits_ptr, out_ptr, 
    temp, top_k, top_p,
    n_vocab, BLOCK_SIZE: tl.constexpr
):
    """
    Fused kernel for temperature scaling, top-k and top-p sampling.
    
    Args:
        logits_ptr: Pointer to logits tensor
        out_ptr: Pointer to output tensor
        temp: Temperature value
        top_k: k value for top-k sampling
        top_p: p value for top-p sampling
        n_vocab: Size of vocabulary
    """
    # Get program ID
    pid = tl.program_id(0)
    
    # Compute offsets
    offset_base = pid * n_vocab
    
    # Load logits
    mask = tl.arange(0, BLOCK_SIZE) < n_vocab
    logits = tl.load(logits_ptr + offset_base + tl.arange(0, BLOCK_SIZE), mask=mask, other=-float('inf'))
    
    # Apply temperature scaling
    logits_scaled = logits / temp
    
    row_minus_max = logits_scaled - tl.max(logits_scaled, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    
    
    # Store result
    tl.store(out_ptr + offset_base + tl.arange(0, BLOCK_SIZE), softmax_output, mask=mask)

@triton.jit
def fused_topk_topp_kernel(
    sorted_probs_ptr, out_ptr, 
    top_k, top_p,
    n_vocab, BLOCK_SIZE: tl.constexpr
):
    """
    Fused kernel for top-k and top-p sampling on already sorted probabilities.
    
    Args:
        sorted_probs_ptr: Pointer to sorted probabilities tensor
        out_ptr: Pointer to output tensor
        top_k: k value for top-k sampling
        top_p: p value for top-p sampling
        n_vocab: Size of vocabulary
    """
    # Get program ID
    pid = tl.program_id(0)
    
    # Compute offsets
    offset_base = pid * n_vocab
    
    # Load sorted probs
    mask = tl.arange(0, BLOCK_SIZE) < n_vocab
    offsets = tl.arange(0, BLOCK_SIZE)
    sorted_probs = tl.load(sorted_probs_ptr + offset_base + offsets, mask=mask, other=0.0)
    
    # Initialize result mask (start with all tokens)
    result_mask = mask
    
    # Apply top-k filtering if specified
    if top_k > 0:
        k_val = tl.minimum(n_vocab, top_k)
        k_mask = offsets < k_val
        result_mask = mask & k_mask
    
    # # Apply top-p filtering if specified
    # if top_p > 0 and top_p < 1.0:
    #     cumulative_probs = tl.cumsum(sorted_probs)
    #     # Keep tokens whose cumulative probability is <= top_p
    #     p_mask = cumulative_probs <= top_p
    #     # Always include at least the first token
    #     p_mask = p_mask | (offsets == 0)
    #     result_mask = result_mask & p_mask
    
    # Apply the combined mask to the probabilities
    filtered_probs = tl.where(result_mask, sorted_probs, 0.0)
    
    # Renormalize the remaining probabilities (optional)
    # sum_probs = tl.sum(filtered_probs, axis=0)
    # filtered_probs = tl.where(sum_probs > 0.0, filtered_probs / sum_probs, filtered_probs)
    
    # Store result
    tl.store(out_ptr + offset_base + offsets, filtered_probs, mask=mask)


# Wrapper function to call the kernel
def sample_next_token_triton(logits, temperature=1.0, top_k=0, top_p=1.0):
    vocab_size = logits.shape[-1]
    output = torch.empty_like(logits)
    print(logits.shape)
    # Configure grid
    grid = (1,)
    # Launch kernel
    fused_temperature_softmax_kernel[grid](
        logits, output, 
        temperature, top_k, top_p,
        vocab_size, triton.next_power_of_2(vocab_size)
    )
    print("softmax done")
    sorted_probs, sorted_indices = torch.sort(output, descending=True)
    # output = torch.empty_like(logits)
    # fused_topk_topp_kernel[grid](
    #     sorted_probs, output, 
    #     temperature, top_k, top_p,
    #     vocab_size, triton.next_power_of_2(vocab_size)
    # )
    return output


In [None]:
triton_probs = sample_next_token_triton(logits=next_token_logits, temperature=temperature, top_k=top_k, top_p=top_p)

torch.Size([166000])


In [None]:
triton_probs

In [None]:
sorted_probs