Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/cpp/llama-cpp/Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

LLAMA_VERSION?=e725a1a982ca870404a9c4935df52466327bbd02
LLAMA_VERSION?=a0552c8beef74e843bb085c8ef0c63f9ed7a2b27
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

CMAKE_ARGS?=
Expand Down
94 changes: 86 additions & 8 deletions backend/python/transformers/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

XPU=os.environ.get("XPU", "0") == "1"
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
from scipy.io import wavfile
import outetts
from sentence_transformers import SentenceTransformer
Expand Down Expand Up @@ -90,13 +90,38 @@ def LoadModel(self, request, context):
self.CUDA = torch.cuda.is_available()
self.OV=False
self.OuteTTS=False
self.DiaTTS=False
self.SentenceTransformer = False

device_map="cpu"

quantization = None
autoTokenizer = True

# Parse options from request.Options
self.options = {}
options = request.Options

# Options arrive as a list of strings of the form optname:optvalue.
# Store them all in a dict so they can be used later when generating.
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
for opt in options:
if ":" not in opt:
continue
key, value = opt.split(":", 1)
# if value is a number, convert it to the appropriate type
try:
if "." in value:
value = float(value)
else:
value = int(value)
except ValueError:
# Keep as string if conversion fails
pass
self.options[key] = value

print(f"Parsed options: {self.options}", file=sys.stderr)

if self.CUDA:
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
if request.MainGPU:
Expand Down Expand Up @@ -202,6 +227,11 @@ def LoadModel(self, request, context):
autoTokenizer = False
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
elif request.Type == "DiaForConditionalGeneration":
autoTokenizer = False
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
self.DiaTTS = True
elif request.Type == "OuteTTS":
autoTokenizer = False
options = request.Options
Expand Down Expand Up @@ -262,7 +292,7 @@ def LoadModel(self, request, context):
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
self.max_tokens = self.model.config.max_position_embeddings
else:
self.max_tokens = 512
self.max_tokens = self.options.get("max_new_tokens", 512)

if autoTokenizer:
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
Expand Down Expand Up @@ -485,16 +515,15 @@ def SoundGeneration(self, request, context):
return_tensors="pt",
)

tokens = 256
if request.HasField('duration'):
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
guidance = 3.0
guidance = self.options.get("guidance_scale", 3.0)
if request.HasField('temperature'):
guidance = request.temperature
dosample = True
dosample = self.options.get("do_sample", True)
if request.HasField('sample'):
dosample = request.sample
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
sampling_rate = self.model.config.audio_encoder.sampling_rate
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
Expand All @@ -506,13 +535,59 @@ def SoundGeneration(self, request, context):
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)


def DiaTTS(self, request, context):
    """
    Synthesize dialogue audio with the loaded Dia model and save it.

    The request text is wrapped in a one-element batch, run through
    ``self.processor`` and handed to ``self.model.generate``. The generated
    output is decoded with ``self.processor.batch_decode`` and written to
    ``request.dst`` via ``self.processor.save_audio``.

    ``guidance_scale``, ``temperature``, ``top_p`` and ``top_k`` are read
    from ``self.options`` with fallbacks of 3.0, 1.8, 0.90 and 45;
    ``max_new_tokens`` is taken from ``self.max_tokens``.

    NOTE(review): LoadModel assigns a boolean to ``self.DiaTTS`` and TTS
    then calls ``self.DiaTTS(request, context)``. That instance attribute
    would shadow this method, so the call may hit a bool instead of this
    function - confirm and rename one of the two if so.

    Args:
        request: A TTSRequest; ``request.text`` is the dialogue to speak and
            ``request.dst`` is the output destination.
        context: The gRPC context (not used here).

    Returns:
        ``backend_pb2.Result(success=True)`` when generation completes, or
        ``backend_pb2.Result(success=False, message=...)`` describing the
        exception if anything in the pipeline raises.
    """
    try:
        print("[DiaTTS] generating dialogue audio", file=sys.stderr)

        # The processor takes a batch; dialogue text is expected to carry
        # speaker tags such as [S1] ... [S2] ...
        batch = [request.text]
        inputs = self.processor(text=batch, padding=True, return_tensors="pt")

        # Start from the processor output, then layer the sampling settings
        # on top (request options win over the built-in defaults).
        generation_params = {**inputs}
        generation_params.update(
            max_new_tokens=self.max_tokens,
            guidance_scale=self.options.get("guidance_scale", 3.0),
            temperature=self.options.get("temperature", 1.8),
            top_p=self.options.get("top_p", 0.90),
            top_k=self.options.get("top_k", 45),
        )
        generated = self.model.generate(**generation_params)

        # Turn the generated output into audio and persist it.
        decoded = self.processor.batch_decode(generated)
        self.processor.save_audio(decoded, request.dst)

        print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
        print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
        print("[DiaTTS] Dialogue generation done", file=sys.stderr)
    except Exception as err:
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
    else:
        return backend_pb2.Result(success=True)


def OuteTTS(self, request, context):
try:
print("[OuteTTS] generating TTS", file=sys.stderr)
gen_cfg = outetts.GenerationConfig(
text="Speech synthesis is the artificial production of human speech.",
temperature=0.1,
repetition_penalty=1.1,
temperature=self.options.get("temperature", 0.1),
repetition_penalty=self.options.get("repetition_penalty", 1.1),
max_length=self.max_tokens,
speaker=self.speaker,
# voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
Expand All @@ -529,6 +604,9 @@ def OuteTTS(self, request, context):
def TTS(self, request, context):
if self.OuteTTS:
return self.OuteTTS(request, context)

if self.DiaTTS:
return self.DiaTTS(request, context)

model_name = request.model
try:
Expand Down