-
Notifications
You must be signed in to change notification settings - Fork 170
/
predict.py
65 lines (54 loc) · 2.36 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import cog
from typing import Dict
from pathlib import Path
import tempfile
import torch
import torchaudio
import librosa
import subprocess
import os
import soundfile as sf
from scipy.io.wavfile import write as write_wav
SAMPLE_RATE = 16000


def _pick_device():
    """Return the torch device name to run MARS5 on: "cuda", then "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # The hasattr guard keeps this working on torch builds that have no `backends.mps`.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


device = _pick_device()
print(f"Mars5 device: {device}")

# NOTE(review): not referenced by any visible code in this file; kept so external
# importers (if any) still find it.
filePath = '/tmp/output.wav'
class Predictor(cog.BasePredictor):
    """Cog predictor that clones a reference voice with the MARS5 English TTS model.

    The model is loaded once in ``setup`` and reused by every ``predict`` call.
    """

    def setup(self):
        """Load the MARS5 English model and its inference-config class onto ``device``."""
        # torch.hub.load returns (model, config_class) for this entry point.
        # trust_repo=True skips the interactive "trust this repo?" prompt so the
        # container can start unattended.
        self.mars5, self.config_class = torch.hub.load(
            'Camb-ai/mars5-tts', 'mars5_english', device=device, trust_repo=True
        )
        print(">>>>> Model Loaded")

    def predict(
        self,
        text: str = cog.Input(description="Text to synthesize"),
        ref_audio_file: cog.Path = cog.Input(description='Reference audio file to clone from <= 10 seconds', default="https://files.catbox.moe/be6df3.wav"),
        ref_audio_transcript: str = cog.Input(description='Text in the reference audio file', default="We actually haven't managed to meet demand."),
        output_format: str = cog.Input(
            description="Output format",
            choices=["wav", "mp3"],
            default="mp3"),
    ) -> cog.Path:
        """Synthesize ``text`` in the voice of ``ref_audio_file``.

        Args:
            text: Text to speak.
            ref_audio_file: Reference clip whose voice is cloned.
            ref_audio_transcript: Transcript of the reference clip, passed to the model.
            output_format: ``"wav"`` or ``"mp3"``.

        Returns:
            Path to the generated audio file. The file lives in a fresh temporary
            directory created for this call.
        """
        print(f">>>> Ref Audio file: {ref_audio_file}; ref_transcript: {ref_audio_transcript}")

        # Load the reference clip resampled to the model's rate and downmixed to mono.
        ref_wav, _ = librosa.load(ref_audio_file, sr=self.mars5.sr, mono=True)
        ref_wav = torch.from_numpy(ref_wav)

        # Inference settings are fixed here rather than exposed as Cog inputs.
        cfg = self.config_class(
            deep_clone=True,
            rep_penalty_window=100,
            top_k=100,
            temperature=0.7,
            freq_penalty=3,
        )

        print(">>> Running inference")
        _, wav_out = self.mars5.tts(text, ref_wav, ref_audio_transcript, cfg=cfg)
        print(">>>>> Done with inference")

        # Use a fresh directory per call, instead of a fixed /tmp file plus a
        # CWD-relative mp3. Successive predictions then never overwrite each
        # other's output, and nothing is written into the working directory.
        out_dir = Path(tempfile.mkdtemp())
        output_path = out_dir / "output.wav"
        # .detach().cpu() makes .numpy() safe even if the model returned a GPU
        # tensor or one that tracks gradients. It is a no-op for plain CPU tensors.
        write_wav(str(output_path), self.mars5.sr, wav_out.detach().cpu().numpy())

        if output_format == 'mp3':
            from pydub import AudioSegment
            mp3_path = out_dir / "output.mp3"
            # export() returns the open output file object; close it so the
            # handle is not leaked.
            AudioSegment.from_wav(str(output_path)).export(str(mp3_path), format="mp3").close()
            output_path = mp3_path

        print(f">>>> Output file url: {output_path}")
        return cog.Path(output_path)