diff --git a/dataflow/example/SpeechTranscription/audio/test.wav b/dataflow/example/SpeechTranscription/audio/test.wav
new file mode 100644
index 00000000..4439a1e5
Binary files /dev/null and b/dataflow/example/SpeechTranscription/audio/test.wav differ
diff --git a/dataflow/example/SpeechTranscription/pipeline_speechtranscription.jsonl b/dataflow/example/SpeechTranscription/pipeline_speechtranscription.jsonl
new file mode 100644
index 00000000..87f7674f
--- /dev/null
+++ b/dataflow/example/SpeechTranscription/pipeline_speechtranscription.jsonl
@@ -0,0 +1,2 @@
+{"raw_content": "../example_data/SpeechTranscription/audio/test.wav"}
+{"raw_content": "https://raw.githubusercontent.com/FireRedTeam/FireRedASR/main/examples/wav/IT0011W0001.wav"}
\ No newline at end of file
diff --git a/dataflow/operators/generate/SpeechTranscription/__init__.py b/dataflow/operators/generate/SpeechTranscription/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/dataflow/operators/generate/SpeechTranscription/speech_transcriptor.py b/dataflow/operators/generate/SpeechTranscription/speech_transcriptor.py
new file mode 100644
index 00000000..13a774fd
--- /dev/null
+++ b/dataflow/operators/generate/SpeechTranscription/speech_transcriptor.py
@@ -0,0 +1,165 @@
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow import get_logger
+
+from dataflow.utils.storage import DataFlowStorage
+from dataflow.core import OperatorABC
+from dataflow.core import LLMServingABC
+
+import os
+import math
+import warnings
+import base64
+from io import BytesIO
+from typing import List, Optional, Union, Dict, Tuple
+
+import librosa
+import numpy as np
+import requests
+
+# None means no resampling: librosa keeps each file's native sample rate.
+DEFAULT_SR = None
+
+def _read_audio_remote(path: str, sr: Optional[int] = DEFAULT_SR) -> Tuple[np.ndarray, int]:
+    url = path
+    resp = requests.get(url, stream=True)
+    resp.raise_for_status()
+
+    audio_bytes = BytesIO(resp.content)
+    y, sr = librosa.load(audio_bytes, sr=sr)
+    return y, sr
+
+def _read_audio_local(path: str, sr: Optional[int] = DEFAULT_SR) -> Tuple[np.ndarray, int]:
+    return librosa.load(path, sr=sr, mono=True)
+
+def _read_audio_bytes(data: bytes, sr: Optional[int] = DEFAULT_SR) -> Tuple[np.ndarray, int]:
+    return librosa.load(BytesIO(data), sr=sr, mono=True)
+
+def _read_audio_base64(b64: str, sr: Optional[int] = DEFAULT_SR) -> Tuple[np.ndarray, int]:
+    header, b64data = b64.split(",", 1)
+    data = base64.b64decode(b64data)
+    return _read_audio_bytes(data, sr=sr)
+
+def process_audio_info(
+    conversations: List[dict] | List[List[dict]],  # vLLM-style messages (the output of conversation_to_message)
+    sampling_rate: Optional[int]
+) -> Tuple[
+    Optional[List[np.ndarray]],
+    Optional[List[int]]
+]:
+    """
+    Analogous to the vision-side process_vision_info: extracts audio inputs from a list of messages.
+    Three input formats are supported:
+    - local paths or http(s) URLs (handled through librosa)
+    - base64-encoded audio (data:audio/...;base64,...)
+    - raw bytes objects
+    Returns a 2-tuple:
+    - audio_arrays: decoded waveforms (List[np.ndarray])
+    - sampling_rates: sample rates (List[int])
+    """
+    if isinstance(conversations, list) and conversations and isinstance(conversations[0], dict):
+        # A single conversation: normalize to List[List[dict]]
+        conversations = [conversations]
+
+    audio_arrays = []
+    sampling_rates = []
+
+    for conv in conversations:
+        for msg in conv:
+            if not isinstance(msg.get("content"), list):
+                continue
+            for ele in msg["content"]:
+                if ele.get("type") != "audio":
+                    continue
+                aud = ele.get("audio")
+                if isinstance(aud, str):
+                    if aud.startswith("data:audio") and "base64," in aud:
+                        arr, sr = _read_audio_base64(aud, sr=sampling_rate)
+                        audio_arrays.append(arr)
+                        sampling_rates.append(sr)
+                    elif aud.startswith("http://") or aud.startswith("https://"):
+                        # Remote path, fetched over HTTP(S) and decoded with librosa
+                        arr, sr = _read_audio_remote(aud, sr=sampling_rate)
+                        audio_arrays.append(arr)
+                        sampling_rates.append(sr)
+                    else:
+                        # Local file path
+                        arr, sr = _read_audio_local(aud, sr=sampling_rate)
+                        audio_arrays.append(arr)
+                        sampling_rates.append(sr)
+                elif isinstance(aud, (bytes, bytearray)):
+                    arr, sr = _read_audio_bytes(bytes(aud), sr=sampling_rate)
+                    audio_arrays.append(arr)
+                    sampling_rates.append(sr)
+                else:
+                    raise ValueError(f"Unsupported audio type: {type(aud)}")
+
+    if not audio_arrays:
+        return None, None
+    return audio_arrays, sampling_rates
+
+@OPERATOR_REGISTRY.register()
+class SpeechTranscriptor(OperatorABC):
+    def __init__(
+        self,
+        llm_serving: LLMServingABC,
+        system_prompt: str = "You are a helpful assistant",
+    ):
+        self.logger = get_logger()
+        self.llm_serving = llm_serving
+        self.system_prompt = system_prompt
+
+    def run(self, storage: DataFlowStorage, input_key: str = "raw_content", output_key: str = "generated_content"):
+        self.input_key, self.output_key = input_key, output_key
+        self.logger.info("Running Speech Transcriptor...")
+
+        dataframe = storage.read('dataframe')
+        self.logger.info(f"Loading, number of rows: {len(dataframe)}")
+
+        # Build one chat-style conversation per row; the audio element carries the path, URL,
+        # base64 string or raw bytes that process_audio_info knows how to decode.
+        conversations = []
+        for index, row in dataframe.iterrows():
+            path_or_url = row.get(self.input_key, '')
+            conversation = [
+                {
+                    "role": "system",
+                    "content": self.system_prompt
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "audio",
+                            "audio": path_or_url
+                        },
+                        {
+                            "type": "text",
+                            "text": "请把语音转录为中文文本"
+                        }
+                    ]
+                }
+            ]
+            conversations.append(conversation)
+
+        user_inputs = [self.llm_serving.processor.apply_chat_template(
+            conversation,
+            tokenize=False,
+            add_generation_prompt=True,
+            add_audio_id=True
+        ) for conversation in conversations]
+        self.logger.debug(user_inputs)
+
+        # Decode all referenced audio and resample it to 16 kHz for the audio language model.
+        audio_arrays, sampling_rates = process_audio_info(conversations=conversations, sampling_rate=16000)
+        audio_inputs = [(audio_array, sampling_rate) for audio_array, sampling_rate in zip(audio_arrays, sampling_rates)]
+
+        transcriptions = self.llm_serving.generate_from_input(
+            user_inputs=user_inputs,
+            audio_inputs=audio_inputs,
+            system_prompt=self.system_prompt
+        )
+
+        dataframe[self.output_key] = transcriptions
+        output_file = storage.write(dataframe)
+        self.logger.info(f"Saving to {output_file}")
+        self.logger.info("Speech Transcriptor done")
+
+        return output_key
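For orientation, a minimal sketch of how the process_audio_info helper above behaves when driven directly; the wav path is a placeholder, and this assumes librosa and soundfile from the new audio extra are installed:

    from dataflow.operators.generate.SpeechTranscription.speech_transcriptor import process_audio_info

    # One conversation with a single audio element; the local path below is hypothetical.
    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": [
            {"type": "audio", "audio": "audio/test.wav"},
            {"type": "text", "text": "Transcribe the speech."},
        ]},
    ]

    # Mirrors what SpeechTranscriptor.run does: decode the audio and resample it to 16 kHz.
    audio_arrays, sampling_rates = process_audio_info([conversation], sampling_rate=16000)
    # audio_arrays -> [decoded waveform as np.ndarray], sampling_rates -> [16000]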
diff --git a/dataflow/operators/generate/__init__.py b/dataflow/operators/generate/__init__.py
index c2e94f16..70506d1c 100644
--- a/dataflow/operators/generate/__init__.py
+++ b/dataflow/operators/generate/__init__.py
@@ -49,6 +49,9 @@
     #VQA
     from .VQA.PromptedVQAGenerator import PromptedVQAGenerator
+
+    # SpeechTranscription
+    from .SpeechTranscription.speech_transcriptor import SpeechTranscriptor
 
 else:
     import sys
     from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking
diff --git a/dataflow/serving/LocalModelLALMServing.py b/dataflow/serving/LocalModelLALMServing.py
new file mode 100644
index 00000000..ac5ba04c
--- /dev/null
+++ b/dataflow/serving/LocalModelLALMServing.py
@@ -0,0 +1,125 @@
+import os
+import torch
+from dataflow import get_logger
+from huggingface_hub import snapshot_download
+from dataflow.core import LLMServingABC
+from transformers import AutoProcessor
+from typing import Optional, Union, List, Dict, Any
+
+class LocalModelLALMServing_vllm(LLMServingABC):
+    '''
+    A serving class for local audio language models using the vLLM backend,
+    loading the model from HuggingFace Hub or a local directory.
+    '''
+    def __init__(self,
+                 hf_model_name_or_path: str = None,
+                 hf_cache_dir: str = None,
+                 hf_local_dir: str = None,
+                 vllm_tensor_parallel_size: int = 1,
+                 vllm_temperature: float = 0.7,
+                 vllm_top_p: float = 0.9,
+                 vllm_max_tokens: int = 1024,
+                 vllm_top_k: int = 40,
+                 vllm_repetition_penalty: float = 1.0,
+                 vllm_seed: int = 42,
+                 vllm_max_model_len: int = None,
+                 vllm_gpu_memory_utilization: float = 0.9,
+                 ):
+
+        self.load_model(
+            hf_model_name_or_path=hf_model_name_or_path,
+            hf_cache_dir=hf_cache_dir,
+            hf_local_dir=hf_local_dir,
+            vllm_tensor_parallel_size=vllm_tensor_parallel_size,
+            vllm_temperature=vllm_temperature,
+            vllm_top_p=vllm_top_p,
+            vllm_max_tokens=vllm_max_tokens,
+            vllm_top_k=vllm_top_k,
+            vllm_repetition_penalty=vllm_repetition_penalty,
+            vllm_seed=vllm_seed,
+            vllm_max_model_len=vllm_max_model_len,
+            vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
+        )
+
+    def load_model(self,
+                   hf_model_name_or_path: str = None,
+                   hf_cache_dir: str = None,
+                   hf_local_dir: str = None,
+                   vllm_tensor_parallel_size: int = 1,
+                   vllm_temperature: float = 0.7,
+                   vllm_top_p: float = 0.9,
+                   vllm_max_tokens: int = 1024,
+                   vllm_top_k: int = 40,
+                   vllm_repetition_penalty: float = 1.0,
+                   vllm_seed: int = 42,
+                   vllm_max_model_len: int = None,
+                   vllm_gpu_memory_utilization: float = 0.9,
+                   ):
+        self.logger = get_logger()
+        if hf_model_name_or_path is None:
+            raise ValueError("hf_model_name_or_path is required")
+        elif os.path.exists(hf_model_name_or_path):
+            self.logger.info(f"Using local model path: {hf_model_name_or_path}")
+            self.real_model_path = hf_model_name_or_path
+        else:
+            self.logger.info(f"Downloading model from HuggingFace: {hf_model_name_or_path}")
+            self.real_model_path = snapshot_download(
+                repo_id=hf_model_name_or_path,
+                cache_dir=hf_cache_dir,
+                local_dir=hf_local_dir,
+            )
+        # get the model name from the real_model_path
+        self.model_name = os.path.basename(self.real_model_path)
+        # The processor is exposed so operators can render chat templates before generation.
+        self.processor = AutoProcessor.from_pretrained(self.real_model_path, cache_dir=hf_cache_dir)
+        self.logger.info(f"Model name: {self.model_name}")
+
+        # Import vLLM and set up the environment for multiprocessing;
+        # vLLM requires the multiprocessing start method to be "spawn".
+        try:
+            from vllm import LLM, SamplingParams
+        except ImportError:
+            raise ImportError("please install vllm first, e.g. 'pip install open-dataflow[vllm]'")
+        # See https://docs.vllm.ai/en/v0.7.1/design/multiprocessing.html
+        os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "spawn"
+
+        self.sampling_params = SamplingParams(
+            temperature=vllm_temperature,
+            top_p=vllm_top_p,
+            max_tokens=vllm_max_tokens,
+            top_k=vllm_top_k,
+            repetition_penalty=vllm_repetition_penalty,
+            seed=vllm_seed
+        )
+
+        self.llm = LLM(
+            model=self.real_model_path,
+            tensor_parallel_size=vllm_tensor_parallel_size,
+            max_model_len=vllm_max_model_len,
+            gpu_memory_utilization=vllm_gpu_memory_utilization,
+        )
+        self.logger.success(f"Model loaded from {self.real_model_path} by vLLM backend")
+
+    def generate_from_input(self,
+                            user_inputs: list[str],
+                            audio_inputs: list,
+                            system_prompt: str = "You are a helpful assistant",
+                            ) -> list[str]:
+        # Pair each rendered prompt with its (waveform, sample_rate) tuple as vLLM multimodal input.
+        full_prompts = []
+        for user_input, audio_input in zip(user_inputs, audio_inputs):
+            full_prompts.append({
+                'prompt': user_input,
+                'multi_modal_data': {'audio': audio_input}
+            })
+
+        responses = self.llm.generate(full_prompts, self.sampling_params)
+        return [output.outputs[0].text for output in responses]
+
+    def cleanup(self):
+        del self.llm
+        import gc
+        gc.collect()
+        torch.cuda.empty_cache()
\ No newline at end of file
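As a hedged sketch, the serving class can also be driven directly, outside the SpeechTranscriptor operator; the checkpoint id and wav path below are placeholders, and the message format simply mirrors what the operator builds:

    import librosa
    from dataflow.serving import LocalModelLALMServing_vllm

    serving = LocalModelLALMServing_vllm(
        hf_model_name_or_path="Qwen/Qwen2-Audio-7B-Instruct",  # placeholder; a local checkpoint dir also works
        vllm_tensor_parallel_size=1,
    )

    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": [
            {"type": "audio", "audio": "audio/test.wav"},
            {"type": "text", "text": "Transcribe the speech."},
        ]},
    ]
    prompt = serving.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    waveform, sr = librosa.load("audio/test.wav", sr=16000)

    # generate_from_input expects one (waveform, sample_rate) tuple per rendered prompt.
    print(serving.generate_from_input(user_inputs=[prompt], audio_inputs=[(waveform, sr)])[0])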
diff --git a/dataflow/serving/__init__.py b/dataflow/serving/__init__.py
index c7c56c2b..87d486a7 100644
--- a/dataflow/serving/__init__.py
+++ b/dataflow/serving/__init__.py
@@ -3,11 +3,13 @@
 from .LocalModelLLMServing import LocalModelLLMServing_sglang
 from .GoogleAPIServing import PerspectiveAPIServing
 from .LiteLLMServing import LiteLLMServing
+from .LocalModelLALMServing import LocalModelLALMServing_vllm
 
 __all__ = [
     "APILLMServing_request",
     "LocalModelLLMServing_vllm",
     "LocalModelLLMServing_sglang",
     "PerspectiveAPIServing",
-    "LiteLLMServing"
+    "LiteLLMServing",
+    "LocalModelLALMServing_vllm"
 ]
\ No newline at end of file
diff --git a/dataflow/statics/pipelines/gpu_pipelines/speechtranscription_pipeline.py b/dataflow/statics/pipelines/gpu_pipelines/speechtranscription_pipeline.py
new file mode 100644
index 00000000..9d577bcb
--- /dev/null
+++ b/dataflow/statics/pipelines/gpu_pipelines/speechtranscription_pipeline.py
@@ -0,0 +1,32 @@
+from dataflow.operators.generate.SpeechTranscription.speech_transcriptor import SpeechTranscriptor
+from dataflow.serving import LocalModelLALMServing_vllm
+from dataflow.utils.storage import FileStorage
+
+class SpeechTranscription_GPUPipeline():
+    def __init__(self):
+        self.storage = FileStorage(
+            first_entry_file_name="../example_data/SpeechTranscription/pipeline_speechtranscription.jsonl",
+            cache_path="./cache",
+            file_name_prefix="dataflow_cache_step",
+            cache_type="jsonl",
+        )
+
+        self.llm_serving = LocalModelLALMServing_vllm(
+            hf_model_name_or_path='/data0/gty/models/Qwen2-Audio-7B-Instruct',
+            vllm_tensor_parallel_size=4,
+            vllm_max_tokens=8192,
+        )
+        self.speech_transcriptor = SpeechTranscriptor(
+            llm_serving=self.llm_serving,
+            system_prompt="你是一个专业的翻译员,你需要将语音转录为文本。"
+        )
+
+    def forward(self):
+        self.speech_transcriptor.run(
+            storage=self.storage.step(),
+            input_key="raw_content"
+        )
+
+if __name__ == "__main__":
+    pipeline = SpeechTranscription_GPUPipeline()
+    pipeline.forward()
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index d96cd5d2..98805e65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,3 +73,4 @@ agent = [
     "uvicorn",
     "sseclient-py",
 ]
+audio = ['librosa', 'soundfile']
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index e9d8af2b..4a0fae3f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -63,3 +63,7 @@ requests
 termcolor
 uvicorn
 sseclient-py
+
+# speech
+librosa
+soundfile
\ No newline at end of file
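As a rough end-to-end usage note: with the new audio extra plus vllm installed (e.g. pip install open-dataflow[audio,vllm]) and hf_model_name_or_path pointed at an accessible Qwen2-Audio checkpoint instead of the machine-specific path above, the pipeline would be exercised along these lines; the import path assumes the statics package is importable, and running the script directly works as well:

    from dataflow.statics.pipelines.gpu_pipelines.speechtranscription_pipeline import SpeechTranscription_GPUPipeline

    # Run from a directory where ../example_data/SpeechTranscription/pipeline_speechtranscription.jsonl resolves,
    # mirroring the FileStorage configuration above.
    pipeline = SpeechTranscription_GPUPipeline()
    pipeline.forward()
    # Each row's "raw_content" audio is transcribed into a new "generated_content" column and
    # written by FileStorage to ./cache as a jsonl step file.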