/
common.py
318 lines (240 loc) · 6.84 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
from dataclasses import dataclass, asdict, field
import inspect
from pathlib import Path
import contextlib
import json
import logging
import shutil
import tempfile
import typing
from modal import Stub, Dict, NetworkFileSystem
import llm
# directory where uploaded media and derived files are stored on the volume
MEDIA_PATH = Path("/media")
# fixed language for now
LANGUAGE = "en"
# fixed english whisper model for now
MODEL_NAME = "large-v2"
# main storage volume — persisted modal network file system named "media"
volume = NetworkFileSystem.persisted("media")
# mount map handed to modal functions: volume mounted at MEDIA_PATH
nfs = {
    str(MEDIA_PATH): volume
}
# modal application stub; transcriptions is a shared modal Dict used as a
# cross-container cache of Transcription records (see Store below)
stub = Stub(name="timething-web")
stub.transcriptions = Dict.new()
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Shared data structures
#
#
@dataclass
class UploadInfo:
    """
    Original file metadata for a file uploaded by the user.
    """

    # client-supplied name of the uploaded file
    filename: typing.Optional[str] = None
    # MIME type as reported by the client
    content_type: typing.Optional[str] = None
    # upload size in bytes
    size_bytes: typing.Optional[int] = None

    @staticmethod
    def from_dict(d):
        """Build an UploadInfo from a dict, silently dropping unknown keys."""
        return UploadInfo(**{
            k: v for k, v in d.items()
            if k in inspect.signature(UploadInfo).parameters
        })
@dataclass
class Track:
    """
    A processed upload. Tracks go through the following states:

    - uploaded (an upload was completed by the user)
    - transcoded (ffmpeg ran and a wav was written out)
    - transcribed (whisper ran and a json result was written out)
    """

    # container tags as read from the media file
    title: typing.Optional[str] = None
    artist: typing.Optional[str] = None
    album: typing.Optional[str] = None
    comment: typing.Optional[str] = None
    description: typing.Optional[str] = None
    date: typing.Optional[str] = None
    # duration in seconds
    duration: typing.Optional[float] = None

    @staticmethod
    def from_dict(d):
        """Build a Track from a dict, silently dropping unknown keys."""
        return Track(**{
            k: v for k, v in d.items()
            if k in inspect.signature(Track).parameters
        })

    @staticmethod
    def from_probe(probe):
        """Build a Track from an ffprobe-style result dict.

        Keeps only tags whose names match Track fields and parses the
        container duration. Raises KeyError/ValueError if the probe has
        no parsable format.duration entry.
        """
        tags = probe.get("format", {}).get("tags", {})
        # same allow-list as from_dict: the dataclass constructor parameters
        allowed_tags = inspect.signature(Track).parameters
        tags = {k: v for (k, v) in tags.items() if k in allowed_tags}
        tags["duration"] = float(probe['format']['duration'])
        return Track(**tags)
@dataclass
class Turn:
    """
    Diarization information: one contiguous stretch of speech by one speaker.
    """
    # name of the speaker
    speaker: str
    # turn start, time in seconds
    start: float
    # turn end, time in seconds
    end: float
@dataclass
class Diarization:
    """
    Diarization of this transcript
    """

    # speaker turns
    turns: typing.List[Turn]

    @staticmethod
    def from_dict(d):
        """Build a Diarization from a dict, silently dropping unknown keys.

        Note: the turns value is passed through as-is; entries are not
        converted into Turn instances here.
        """
        return Diarization(**{
            k: v for k, v in d.items()
            if k in inspect.signature(Diarization).parameters
        })
@dataclass
class Segment:
    """
    A single alignment segment
    """
    # the aligned token (word or sub-word unit)
    label: str
    # segment start, time in seconds
    start: float
    # segment end, time in seconds
    end: float
    # likelihood of this alignment
    score: float
@dataclass
class Alignment:
    """
    A single word level alignment
    """

    # original word segments
    words: typing.List[Segment] = field(default_factory=list)

    @staticmethod
    def from_dict(d):
        """Build an Alignment from a dict, silently dropping unknown keys.

        Note: the words value is passed through as-is; entries are not
        converted into Segment instances here.
        """
        return Alignment(**{
            k: v for k, v in d.items()
            if k in inspect.signature(Alignment).parameters
        })
@dataclass
class Transcription:
    """
    A complete transcription and metadata of a processed Upload.
    """

    # canonical id
    transcription_id: str
    # initial upload information
    upload: UploadInfo
    # track metadata
    track: Track = None
    # transcript text; None until whisper has run (see `transcribed`)
    transcript: str = None
    # diarization
    diarization: typing.Optional[Diarization] = None
    # alignment
    alignment: typing.Optional[Alignment] = None
    # is it already transcoded
    transcoded: bool = False
    # path to the original upload
    path: str = None
    # source language
    language: str = None

    @property
    def transcribed(self):
        # transcribed means a transcript is present, even an empty one
        return self.transcript is not None

    @property
    def uploaded_file(self):
        return Path(self.path) if self.path else None

    @property
    def transcoded_file(self):
        # NOTE(review): raises AttributeError when path is unset — callers
        # appear to set path before transcoding; confirm.
        return self.uploaded_file.with_suffix('.wav')

    @property
    def transcribed_file(self):
        # NOTE(review): raises AttributeError when path is unset, as above.
        return self.uploaded_file.with_suffix('.json')

    @property
    def content_type(self):
        if self.upload:
            return self.upload.content_type

    @staticmethod
    def from_dict(d: dict):
        """Build a Transcription from a dict, tolerating absent sections.

        Missing or falsy sub-sections become empty default instances.
        Raises KeyError if transcription_id is missing.
        """
        track = Track()
        if d.get('track'):
            track = Track.from_dict(d['track'])
        upload_info = UploadInfo()
        if d.get('upload'):
            upload_info = UploadInfo.from_dict(d['upload'])
        alignment = Alignment()
        if d.get('alignment'):
            alignment = Alignment.from_dict(d['alignment'])
        diarization = Diarization(turns=[])
        if d.get('diarization'):
            diarization = Diarization.from_dict(d['diarization'])
        return Transcription(
            transcription_id=d['transcription_id'],
            track=track,
            upload=upload_info,
            alignment=alignment,
            diarization=diarization,
            transcoded=d.get('transcoded', False),
            # default None, not {}: a truthy default would make the
            # `transcribed` property report True for untranscribed records
            transcript=d.get('transcript'),
            path=d.get('path'),
            language=d.get('language')
        )
# Modal abstractions
#
#
class Store:
    """Keep a data layer here so we can move it out of modal later.

    Records live in two places: the shared ``stub.transcriptions`` dict
    (in-memory index) and a sidecar JSON file next to the media on disk.
    """

    def __init__(self, media_path: Path):
        # root directory holding uploads and their .json metadata sidecars
        self.media_path = media_path

    def create(self, t: Transcription):
        """Persist a transcription to the shared dict and to disk.

        Raises:
            ValueError: if the transcription has no id.
        """
        if not t.transcription_id:
            raise ValueError('id not specified')
        stub.transcriptions[t.transcription_id] = t
        content = json.dumps(asdict(t), cls=JSONEncoder)
        with open(t.transcribed_file, 'w', encoding='utf-8') as f:
            f.write(content)

    def select(self, transcription_id: str) -> typing.Optional[Transcription]:
        """Load a transcription by id, refreshing the in-memory index.

        Returns None for ids not present in ``stub.transcriptions``; this
        membership check also guards against path traversal attacks, since
        the id is interpolated into a filesystem path below.

        Raises:
            ValueError: if no id was given.
            Exception: if the id is known but its metadata file is gone.
        """
        if not transcription_id:
            raise ValueError('id not specified')
        if transcription_id not in stub.transcriptions:
            return None
        path = self.media_path / transcription_id
        meta = path.with_suffix('.json')
        if not meta.exists():
            raise Exception('id not found')
        with open(meta, 'r', encoding='utf-8') as f:
            t_dict = json.load(f)
        t = Transcription.from_dict(t_dict)
        stub.transcriptions[transcription_id] = t
        return t
# single module-level store instance, backed by the nfs-mounted media path
db = Store(MEDIA_PATH)
def whisper_gpt():
    """Construct a ChatGPT client configured for whisper post-processing."""
    prompt = llm.WHISPER_SYSTEM_PROMPT
    max_tokens = llm.WHISPER_MAX_TOKENS
    return llm.ChatGPT(prompt, max_tokens)
# Utils
#
#
def dataclass_to_event(x):
    """Serialize a dataclass instance as a server-sent event string.

    The SSE event name is the dataclass's type name; the data line is
    the instance serialized as JSON (non-ASCII characters kept as-is).
    """
    payload = json.dumps(asdict(x), ensure_ascii=False)
    name = type(x).__name__
    return f"event: {name}\ndata: {payload}\n\n"
def get_device():
    """Return the torch device string: the first CUDA GPU if one is
    available, otherwise the CPU."""
    import torch

    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
@contextlib.contextmanager
def tmpdir_scope():
    """Yield a freshly created temporary directory path, deleting the
    directory and its contents when the context exits."""
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        # always clean up, even if the body raised
        shutil.rmtree(path)
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes pathlib.Path values as plain strings."""

    def default(self, obj):
        # only Path gets special handling; anything else defers to the base
        # class, which raises TypeError for unserializable values
        if isinstance(obj, Path):
            return str(obj)
        return super().default(obj)