4
4
import ArgumentParser
5
5
import CoreML
6
6
import Foundation
7
-
8
7
import WhisperKit
9
8
10
9
@available ( macOS 13 , iOS 16 , watchOS 10 , visionOS 1 , * )
11
10
@main
12
11
struct WhisperKitCLI : AsyncParsableCommand {
13
- @Option ( help: " Path to audio file " )
14
- var audioPath : String = " Tests/WhisperKitTests/Resources/jfk.wav "
15
-
16
- @Option ( help: " Path of model files " )
17
- var modelPath : String = " Models/whisperkit-coreml/openai_whisper-tiny "
18
-
19
- @Option ( help: " Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random} " )
20
- var audioEncoderComputeUnits : ComputeUnits = . cpuAndNeuralEngine
21
-
22
- @Option ( help: " Compute units for text decoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random} " )
23
- var textDecoderComputeUnits : ComputeUnits = . cpuAndNeuralEngine
24
-
25
- @Flag ( help: " Verbose mode " )
26
- var verbose : Bool = false
27
-
28
- @Option ( help: " Task to perform (transcribe or translate) " )
29
- var task : String = " transcribe "
30
-
31
- @Option ( help: " Language spoken in the audio " )
32
- var language : String ?
33
-
34
- @Option ( help: " Temperature to use for sampling " )
35
- var temperature : Float = 0
36
-
37
- @Option ( help: " Temperature to increase on fallbacks during decoding " )
38
- var temperatureIncrementOnFallback : Float = 0.2
39
-
40
- @Option ( help: " Number of times to increase temperature when falling back during decoding " )
41
- var temperatureFallbackCount : Int = 5
42
-
43
- @Option ( help: " Number of candidates when sampling with non-zero temperature " )
44
- var bestOf : Int = 5
45
-
46
- @Flag ( help: " Force initial prompt tokens based on language, task, and timestamp options " )
47
- var usePrefillPrompt : Bool = false
48
-
49
- @Flag ( help: " Use decoder prefill data for faster initial decoding " )
50
- var usePrefillCache : Bool = false
51
-
52
- @Flag ( help: " Skip special tokens in the output " )
53
- var skipSpecialTokens : Bool = false
54
-
55
- @Flag ( help: " Force no timestamps when decoding " )
56
- var withoutTimestamps : Bool = false
57
-
58
- @Flag ( help: " Add timestamps for each word in the output " )
59
- var wordTimestamps : Bool = false
60
-
61
- @Argument ( help: " Supress given tokens in the output " )
62
- var supressTokens : [ Int ] = [ ]
63
-
64
- @Option ( help: " Gzip compression ratio threshold for decoding failure " )
65
- var compressionRatioThreshold : Float ?
66
-
67
- @Option ( help: " Average log probability threshold for decoding failure " )
68
- var logprobThreshold : Float ?
69
-
70
- @Option ( help: " Probability threshold to consider a segment as silence " )
71
- var noSpeechThreshold : Float ?
12
+ static let configuration = CommandConfiguration (
13
+ commandName: " transcribe " ,
14
+ abstract: " WhisperKit Transcribe CLI " ,
15
+ discussion: " Swift native speech recognition with Whisper for Apple Silicon "
16
+ )
72
17
73
- @Flag ( help : " Output a report of the results " )
74
- var report : Bool = false
18
+ @OptionGroup
19
+ var cliArguments : CLIArguments
75
20
76
- @Option ( help: " Directory to save the report " )
77
- var reportPath : String = " . "
78
-
79
- @Flag ( help: " Process audio directly from the microphone " )
80
- var stream : Bool = false
21
+ mutating func run( ) async throws {
22
+ if cliArguments. stream {
23
+ try await transcribeStream ( modelPath: cliArguments. modelPath)
24
+ } else {
25
+ let audioURL = URL ( fileURLWithPath: cliArguments. audioPath)
26
+ if cliArguments. verbose {
27
+ print ( " Transcribing audio at \( audioURL) " )
28
+ }
29
+ try await transcribe ( audioPath: cliArguments. audioPath, modelPath: cliArguments. modelPath)
30
+ }
31
+ }
81
32
82
- func transcribe( audioPath: String , modelPath: String ) async throws {
33
+ private func transcribe( audioPath: String , modelPath: String ) async throws {
83
34
let resolvedModelPath = resolveAbsolutePath ( modelPath)
84
35
guard FileManager . default. fileExists ( atPath: resolvedModelPath) else {
85
36
fatalError ( " Model path does not exist \( resolvedModelPath) " )
@@ -91,49 +42,52 @@ struct WhisperKitCLI: AsyncParsableCommand {
91
42
}
92
43
93
44
let computeOptions = ModelComputeOptions (
94
- audioEncoderCompute: audioEncoderComputeUnits. asMLComputeUnits,
95
- textDecoderCompute: textDecoderComputeUnits. asMLComputeUnits
45
+ audioEncoderCompute: cliArguments . audioEncoderComputeUnits. asMLComputeUnits,
46
+ textDecoderCompute: cliArguments . textDecoderComputeUnits. asMLComputeUnits
96
47
)
97
48
98
49
print ( " Initializing models... " )
99
50
let whisperKit = try await WhisperKit (
100
51
modelFolder: modelPath,
101
52
computeOptions: computeOptions,
102
- verbose: verbose,
53
+ verbose: cliArguments . verbose,
103
54
logLevel: . debug
104
55
)
105
56
print ( " Models initialized " )
106
57
107
58
let options = DecodingOptions (
108
- verbose: verbose,
59
+ verbose: cliArguments . verbose,
109
60
task: . transcribe,
110
- language: language,
111
- temperature: temperature,
112
- temperatureIncrementOnFallback: temperatureIncrementOnFallback,
113
- temperatureFallbackCount: temperatureFallbackCount,
114
- topK: bestOf,
115
- usePrefillPrompt: usePrefillPrompt,
116
- usePrefillCache: usePrefillCache,
117
- skipSpecialTokens: skipSpecialTokens,
118
- withoutTimestamps: withoutTimestamps,
119
- wordTimestamps: wordTimestamps,
120
- supressTokens: supressTokens,
121
- compressionRatioThreshold: compressionRatioThreshold,
122
- logProbThreshold: logprobThreshold,
123
- noSpeechThreshold: noSpeechThreshold
61
+ language: cliArguments . language,
62
+ temperature: cliArguments . temperature,
63
+ temperatureIncrementOnFallback: cliArguments . temperatureIncrementOnFallback,
64
+ temperatureFallbackCount: cliArguments . temperatureFallbackCount,
65
+ topK: cliArguments . bestOf,
66
+ usePrefillPrompt: cliArguments . usePrefillPrompt,
67
+ usePrefillCache: cliArguments . usePrefillCache,
68
+ skipSpecialTokens: cliArguments . skipSpecialTokens,
69
+ withoutTimestamps: cliArguments . withoutTimestamps,
70
+ wordTimestamps: cliArguments . wordTimestamps,
71
+ supressTokens: cliArguments . supressTokens,
72
+ compressionRatioThreshold: cliArguments . compressionRatioThreshold,
73
+ logProbThreshold: cliArguments . logprobThreshold,
74
+ noSpeechThreshold: cliArguments . noSpeechThreshold
124
75
)
125
76
126
- let transcribeResult = try await whisperKit. transcribe ( audioPath: resolvedAudioPath, decodeOptions: options)
77
+ let transcribeResult = try await whisperKit. transcribe (
78
+ audioPath: resolvedAudioPath,
79
+ decodeOptions: options
80
+ )
127
81
128
82
let transcription = transcribeResult? . text ?? " Transcription failed "
129
83
130
- if report, let result = transcribeResult {
84
+ if cliArguments . report, let result = transcribeResult {
131
85
let audioFileName = URL ( fileURLWithPath: audioPath) . lastPathComponent. components ( separatedBy: " . " ) . first!
132
86
133
87
// Write SRT (SubRip Subtitle Format) for the transcription
134
- let srtReportWriter = WriteSRT ( outputDir: reportPath)
88
+ let srtReportWriter = WriteSRT ( outputDir: cliArguments . reportPath)
135
89
let savedSrtReport = srtReportWriter. write ( result: result, to: audioFileName)
136
- if verbose {
90
+ if cliArguments . verbose {
137
91
switch savedSrtReport {
138
92
case let . success( reportPath) :
139
93
print ( " \n \n Saved SRT Report: \n \n \( reportPath) \n " )
@@ -143,9 +97,9 @@ struct WhisperKitCLI: AsyncParsableCommand {
143
97
}
144
98
145
99
// Write JSON for all metadata
146
- let jsonReportWriter = WriteJSON ( outputDir: reportPath)
100
+ let jsonReportWriter = WriteJSON ( outputDir: cliArguments . reportPath)
147
101
let savedJsonReport = jsonReportWriter. write ( result: result, to: audioFileName)
148
- if verbose {
102
+ if cliArguments . verbose {
149
103
switch savedJsonReport {
150
104
case let . success( reportPath) :
151
105
print ( " \n \n Saved JSON Report: \n \n \( reportPath) \n " )
@@ -155,47 +109,47 @@ struct WhisperKitCLI: AsyncParsableCommand {
155
109
}
156
110
}
157
111
158
- if verbose {
112
+ if cliArguments . verbose {
159
113
print ( " \n \n Transcription: \n \n \( transcription) \n " )
160
114
} else {
161
115
print ( transcription)
162
116
}
163
117
}
164
118
165
- func transcribeStream( modelPath: String ) async throws {
119
+ private func transcribeStream( modelPath: String ) async throws {
166
120
let computeOptions = ModelComputeOptions (
167
- audioEncoderCompute: audioEncoderComputeUnits. asMLComputeUnits,
168
- textDecoderCompute: textDecoderComputeUnits. asMLComputeUnits
121
+ audioEncoderCompute: cliArguments . audioEncoderComputeUnits. asMLComputeUnits,
122
+ textDecoderCompute: cliArguments . textDecoderComputeUnits. asMLComputeUnits
169
123
)
170
124
171
125
print ( " Initializing models... " )
172
126
let whisperKit = try await WhisperKit (
173
127
modelFolder: modelPath,
174
128
computeOptions: computeOptions,
175
- verbose: verbose,
129
+ verbose: cliArguments . verbose,
176
130
logLevel: . debug
177
131
)
178
132
print ( " Models initialized " )
179
133
180
134
let decodingOptions = DecodingOptions (
181
- verbose: verbose,
135
+ verbose: cliArguments . verbose,
182
136
task: . transcribe,
183
- language: language,
184
- temperature: temperature,
185
- temperatureIncrementOnFallback: temperatureIncrementOnFallback,
137
+ language: cliArguments . language,
138
+ temperature: cliArguments . temperature,
139
+ temperatureIncrementOnFallback: cliArguments . temperatureIncrementOnFallback,
186
140
temperatureFallbackCount: 3 , // limit fallbacks for realtime
187
141
sampleLength: 224 , // reduced sample length for realtime
188
- topK: bestOf,
189
- usePrefillPrompt: usePrefillPrompt,
190
- usePrefillCache: usePrefillCache,
191
- skipSpecialTokens: skipSpecialTokens,
192
- withoutTimestamps: withoutTimestamps,
142
+ topK: cliArguments . bestOf,
143
+ usePrefillPrompt: cliArguments . usePrefillPrompt,
144
+ usePrefillCache: cliArguments . usePrefillCache,
145
+ skipSpecialTokens: cliArguments . skipSpecialTokens,
146
+ withoutTimestamps: cliArguments . withoutTimestamps,
193
147
clipTimestamps: [ ] ,
194
148
suppressBlank: false ,
195
- supressTokens: supressTokens,
196
- compressionRatioThreshold: compressionRatioThreshold ?? 2.4 ,
197
- logProbThreshold: logprobThreshold ?? - 1.0 ,
198
- noSpeechThreshold: noSpeechThreshold ?? 0.6
149
+ supressTokens: cliArguments . supressTokens,
150
+ compressionRatioThreshold: cliArguments . compressionRatioThreshold ?? 2.4 ,
151
+ logProbThreshold: cliArguments . logprobThreshold ?? - 1.0 ,
152
+ noSpeechThreshold: cliArguments . noSpeechThreshold ?? 0.6
199
153
)
200
154
201
155
let audioStreamTranscriber = AudioStreamTranscriber (
@@ -222,29 +176,4 @@ struct WhisperKitCLI: AsyncParsableCommand {
222
176
print ( " Transcribing audio stream, press Ctrl+C to stop. " )
223
177
try await audioStreamTranscriber. startStreamTranscription ( )
224
178
}
225
-
226
- mutating func run( ) async throws {
227
- if stream {
228
- try await transcribeStream ( modelPath: modelPath)
229
- } else {
230
- let audioURL = URL ( fileURLWithPath: audioPath)
231
- if verbose {
232
- print ( " Transcribing audio at \( audioURL) " )
233
- }
234
- try await transcribe ( audioPath: audioPath, modelPath: modelPath)
235
- }
236
- }
237
- }
238
-
239
- enum ComputeUnits : String , ExpressibleByArgument , CaseIterable {
240
- case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine, random
241
- var asMLComputeUnits : MLComputeUnits {
242
- switch self {
243
- case . all: return . all
244
- case . cpuAndGPU: return . cpuAndGPU
245
- case . cpuOnly: return . cpuOnly
246
- case . cpuAndNeuralEngine: return . cpuAndNeuralEngine
247
- case . random: return Bool . random ( ) ? . cpuAndGPU : . cpuAndNeuralEngine
248
- }
249
- }
250
179
}
0 commit comments