-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathmodal_app.py
More file actions
104 lines (87 loc) · 2.81 KB
/
modal_app.py
File metadata and controls
104 lines (87 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Base Modal app that serves OpenAI Whisper large-v3 speech-to-text over HTTP.
See the instructions in README.md, or visit https://shush.arihanv.com/#host for more information.
"""
from modal import Image, App, method, asgi_app, enter
from fastapi import Request, FastAPI
import tempfile
import time
# Directory that download_model() writes the Whisper snapshot to and that
# WhisperV3.setup() loads the model and processor from.
MODEL_DIR = "/model"
# FastAPI application that carries the HTTP route(s); returned by entrypoint().
web_app = FastAPI()
def download_model():
    """Download the ``openai/whisper-large-v3`` snapshot into ``MODEL_DIR``."""
    # Imported inside the function so that importing this module does not
    # itself require huggingface_hub to be installed.
    import huggingface_hub

    huggingface_hub.snapshot_download(
        repo_id="openai/whisper-large-v3",
        local_dir=MODEL_DIR,
    )
# Container image definition. The steps are chained in order: base image,
# system packages, Python packages, flash-attn, environment, model download.
image = (
    # CUDA 12.1 + cuDNN 8 "devel" base image with Python 3.9 added on top.
    Image.from_registry("nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04", add_python="3.9")
    .apt_install("git", "ffmpeg")
    .pip_install(
        "transformers",
        "ninja",
        "packaging",
        "wheel",
        "torch",
        "hf-transfer~=0.1",
        "ffmpeg-python",
    )
    # flash-attn is installed in a separate step, after torch and the build
    # helpers above, with an A10G GPU attached to the build step.
    .run_commands("python -m pip install flash-attn --no-build-isolation", gpu="A10G")
    # Set before download_model runs; presumably switches huggingface_hub to
    # the hf-transfer download backend installed above -- TODO confirm.
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    # Runs download_model as an image build step, which writes the snapshot to
    # MODEL_DIR.
    .run_function(
        download_model,
    )
)
# The Modal application; every function/class below is registered on it and
# uses `image` as its default container image.
app = App("whisper-v3-demo", image=image)
@app.cls(
    gpu="A10G",
    allow_concurrent_inputs=80,
    container_idle_timeout=40,
)
class WhisperV3:
    """Modal class that loads Whisper from ``MODEL_DIR`` and transcribes audio.

    ``setup`` builds the speech-recognition pipeline and stores it on
    ``self.pipe``; ``generate`` runs that pipeline on raw audio bytes.
    """

    @enter()
    def setup(self):
        """Load the model and processor and build the ASR pipeline.

        Sets ``self.device``, ``self.torch_dtype`` and ``self.pipe``.
        """
        # Heavy dependencies are imported here rather than at module level, so
        # importing this file does not require torch/transformers.
        import torch
        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

        # Half precision on CUDA, full precision otherwise.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # NOTE(review): newer transformers releases document
        # `attn_implementation="flash_attention_2"` as the replacement for
        # `use_flash_attention_2` -- confirm against the installed version.
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_DIR,
            torch_dtype=self.torch_dtype,
            use_safetensors=True,
            use_flash_attention_2=True,
        )
        processor = AutoProcessor.from_pretrained(MODEL_DIR)
        model.to(self.device)

        # NOTE(review): `device=0` is hard-coded although `self.device` may be
        # "cpu" -- confirm whether a CPU fallback is actually expected to work.
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=24,
            return_timestamps=True,
            torch_dtype=self.torch_dtype,
            model_kwargs={"use_flash_attention_2": True},
            device=0,
        )

    @method()
    def generate(self, audio: bytes):
        """Transcribe ``audio`` and report how long the pipeline call took.

        The bytes are written to a temporary ``.mp3`` file whose path is handed
        to the pipeline. The file is always removed afterwards, including when
        writing or transcription raises.

        Args:
            audio: Raw contents of the audio file to transcribe.

        Returns:
            A ``(output, elapsed)`` tuple: the pipeline's result and the
            wall-clock seconds spent inside the pipeline call.
        """
        import os

        # delete=False so the file survives close() and can be reopened by
        # path; removal is therefore done explicitly in the `finally` below.
        fp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        try:
            with fp:
                fp.write(audio)
            start = time.time()
            output = self.pipe(
                fp.name, chunk_length_s=30, batch_size=24, return_timestamps=True
            )
            elapsed = time.time() - start
        finally:
            # Without this, every request would leave a file behind on disk.
            os.unlink(fp.name)
        return output, elapsed
@app.function()
@web_app.post("/")
async def transcribe(request: Request):
    """Handle ``POST /``: transcribe the uploaded ``audio`` form field.

    Args:
        request: Incoming request whose form data carries the audio file under
            the ``audio`` key.

    Returns:
        The ``(output, elapsed)`` pair produced by ``WhisperV3.generate``.

    Raises:
        HTTPException: 400 when the ``audio`` field is missing or is not an
            uploaded file.
    """
    from fastapi import HTTPException

    form = await request.form()
    upload = form.get("audio")
    # A plain text field arrives as `str` and has no read(); reject it with a
    # client error instead of failing with an unhandled exception.
    if upload is None or not hasattr(upload, "read"):
        raise HTTPException(
            status_code=400,
            detail="Form field 'audio' must contain an uploaded file.",
        )
    audio = await upload.read()
    # Await the remote call so the event loop is not blocked while the
    # transcription runs.
    output, elapsed = await WhisperV3().generate.remote.aio(audio)
    return output, elapsed
@app.function()
@asgi_app()
def entrypoint():
    """Return the module-level FastAPI application as the ASGI app to serve."""
    asgi_application = web_app
    return asgi_application