generated from KBVE/nodepy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
145 lines (119 loc) · 4.19 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Example JSON request body for POST /generate_music.
"""
{
    "model_name": "facebook/musicgen-small",
    "duration": 15,
    "prompt": "I love",
    "strategy": "loudness",
    "sampling": true,
    "top_k": 0,
    "top_p": 0.9,
    "temperature": 0.9,
    "use_diffusion": true,
    "use_custom": false
}
"""
import io
import os
import uuid

import torch
from audiocraft.data.audio import audio_write
from audiocraft.models import MultiBandDiffusion
from audiocraft.models import musicgen
from flask import Flask, request, send_file, abort
# The Flask application; the /generate_music route below is registered on it.
app = Flask(__name__)
def _is_int(value):
    """Return True if *value* is a JSON integer (bools are excluded)."""
    # bool is a subclass of int, so JSON true/false would otherwise pass.
    return isinstance(value, int) and not isinstance(value, bool)


def _is_number(value):
    """Return True if *value* is a JSON number, int or float (bools excluded)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


@app.route("/generate_music", methods=["POST"])
def generate_music():
    """Generate a MusicGen clip from a JSON request and return it as a WAV file.

    The request body must be a JSON object with these fields:

    * ``model_name`` -- ``facebook/musicgen-small``, ``facebook/musicgen-medium``
      or ``facebook/musicgen-large``.
    * ``duration`` -- clip length in seconds: 15, 30, 60, 90 or 120.
    * ``prompt`` -- text prompt (string).
    * ``strategy`` -- ``loudness``, ``peak`` or ``clip``; forwarded to
      ``audio_write`` via ``create_wav``.
    * ``sampling`` -- bool, passed as ``use_sampling``.
    * ``top_k`` -- integer.
    * ``top_p`` / ``temperature`` -- numbers (integers are accepted and
      converted to float).
    * ``use_diffusion`` -- bool; decode the tokens with MultiBandDiffusion.
    * ``use_custom`` -- bool; load ``models/lm_final.pt`` into the model's LM.

    Any missing or invalid field aborts the request with HTTP 400.

    Returns:
        A ``send_file`` response carrying ``audio/wav`` data as an attachment
        named ``<uuid>.wav``.
    """
    # silent=True yields None instead of raising for a missing or non-JSON
    # body; a JSON array or scalar would also break the .get() calls below.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        abort(400, "Invalid request body (JSON object)")

    model_name = body.get("model_name")
    duration = body.get("duration")
    prompt = body.get("prompt")
    strategy = body.get("strategy")
    sampling = body.get("sampling")
    top_k = body.get("top_k")
    top_p = body.get("top_p")
    temperature = body.get("temperature")
    use_diffusion = body.get("use_diffusion")
    use_custom = body.get("use_custom")

    if model_name not in (
        "facebook/musicgen-small",
        "facebook/musicgen-medium",
        "facebook/musicgen-large",
    ):
        abort(
            400,
            "Invalid model name (facebook/musicgen-small, facebook/musicgen-medium, facebook/musicgen-large)",
        )
    if duration not in (15, 30, 60, 90, 120):
        abort(400, "Invalid duration in seconds (15, 30, 60, 90, 120)")
    if not isinstance(prompt, str):
        abort(400, "Invalid prompt (string)")
    if strategy not in ("loudness", "peak", "clip"):
        abort(400, "Invalid strategy (loudness, peak, clip)")
    if not isinstance(sampling, bool):
        abort(400, "Invalid sampling (true, false)")
    if not _is_int(top_k):
        abort(400, "Invalid top_k (int)")
    if not _is_number(top_p):
        abort(400, "Invalid top_p (float)")
    if not _is_number(temperature):
        abort(400, "Invalid temperature (float)")
    if not isinstance(use_diffusion, bool):
        abort(400, "Invalid use_diffusion (true, false)")
    if not isinstance(use_custom, bool):
        abort(400, "Invalid use_custom (true, false)")

    print(body)

    # Unique stem for the staged .wav file.
    myuuid = uuid.uuid4()

    model = musicgen.MusicGen.get_pretrained(model_name, device="cuda")
    if use_custom:
        model.lm.load_state_dict(torch.load("models/lm_final.pt"))
    model.set_generation_params(
        duration=duration,
        use_sampling=sampling,
        top_k=top_k,
        top_p=float(top_p),
        temperature=float(temperature),
    )

    # With return_tokens=True the result holds the decoded audio at [0] and
    # the generated tokens at [1]; the tokens are what diffusion re-decodes.
    result = model.generate([prompt], progress=True, return_tokens=True)
    if use_diffusion:
        print("Using diffusion")
        mbd = MultiBandDiffusion.get_mbd_musicgen()
        output = mbd.tokens_to_wav(result[1])
    else:
        print("Not using diffusion")
        output = result[0]
    create_wav(output, myuuid, model, strategy)

    wav_path = f"{myuuid}.wav"
    try:
        with open(wav_path, "rb") as f:
            wav_data = f.read()
    finally:
        # The on-disk file is only a staging copy. Remove it so each request
        # does not leave another clip behind in the working directory.
        if os.path.exists(wav_path):
            os.remove(wav_path)

    return send_file(
        io.BytesIO(wav_data),
        mimetype="audio/wav",
        as_attachment=True,
        download_name=wav_path,
    )
def create_wav(output, myuuid, model, strategy):
    """Write each generated audio item with ``audio_write`` under the stem ``str(myuuid)``.

    Args:
        output: Iterable of audio tensors; each item is moved to the CPU
            (``.cpu()``) before writing.
        myuuid: Identifier whose string form is used as the output file stem.
        model: Generating model; only its ``sample_rate`` attribute is read.
        strategy: Normalization strategy forwarded to ``audio_write``.

    Note:
        Every item is written to the same stem. With more than one item, each
        write presumably replaces the previous file, so only the last one
        survives. The HTTP endpoint generates a single prompt.
    """
    stem = str(myuuid)
    for one_wav in output:
        audio_write(
            stem,
            one_wav.cpu(),
            model.sample_rate,
            strategy=strategy,
            loudness_compressor=True,
        )
if __name__ == "__main__":
    # Start Flask's built-in development server when run as a script.
    app.run()