Skip to content

Commit eca4a2e

Browse files
Merge pull request #68 from jkrukowski/cli-cleanup
WhisperKit CLI cleanup
2 parents 9c4d8e0 + e556132 commit eca4a2e

File tree

4 files changed

+157
-138
lines changed

4 files changed

+157
-138
lines changed

Sources/WhisperKit/Core/AudioProcessor.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,7 @@ public extension AudioProcessor {
479479
&inputDeviceID,
480480
UInt32(MemoryLayout<AudioDeviceID>.size)
481481
)
482-
483-
let format = inputNode.outputFormat(forBus: 0)
484-
482+
485483
if error != noErr {
486484
Logging.error("Error setting Audio Unit property: \(error)")
487485
} else {
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// For licensing see accompanying LICENSE.md file.
2+
// Copyright © 2024 Argmax, Inc. All rights reserved.
3+
4+
import ArgumentParser
5+
6+
/// Command-line options shared by the WhisperKit CLI subcommands.
///
/// Grouped into a `ParsableArguments` type so commands can pull all of
/// these in with a single `@OptionGroup` (see `WhisperKitCLI.run`).
struct CLIArguments: ParsableArguments {
    @Option(help: "Path to audio file")
    var audioPath: String = "Tests/WhisperKitTests/Resources/jfk.wav"

    @Option(help: "Path of model files")
    var modelPath: String = "Models/whisperkit-coreml/openai_whisper-tiny"

    @Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
    var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine

    @Option(help: "Compute units for text decoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
    var textDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine

    @Flag(help: "Verbose mode")
    var verbose: Bool = false

    @Option(help: "Language spoken in the audio")
    var language: String?

    @Option(help: "Temperature to use for sampling")
    var temperature: Float = 0

    @Option(help: "Temperature to increase on fallbacks during decoding")
    var temperatureIncrementOnFallback: Float = 0.2

    @Option(help: "Number of times to increase temperature when falling back during decoding")
    var temperatureFallbackCount: Int = 5

    @Option(help: "Number of candidates when sampling with non-zero temperature")
    var bestOf: Int = 5

    @Flag(help: "Force initial prompt tokens based on language, task, and timestamp options")
    var usePrefillPrompt: Bool = false

    @Flag(help: "Use decoder prefill data for faster initial decoding")
    var usePrefillCache: Bool = false

    @Flag(help: "Skip special tokens in the output")
    var skipSpecialTokens: Bool = false

    @Flag(help: "Force no timestamps when decoding")
    var withoutTimestamps: Bool = false

    @Flag(help: "Add timestamps for each word in the output")
    var wordTimestamps: Bool = false

    // NOTE: the "supress" spelling is kept intentionally — it must match the
    // `supressTokens:` argument label of WhisperKit's `DecodingOptions`, which
    // the transcribe command forwards this value to.
    @Argument(help: "Suppress given tokens in the output")
    var supressTokens: [Int] = []

    @Option(help: "Gzip compression ratio threshold for decoding failure")
    var compressionRatioThreshold: Float?

    @Option(help: "Average log probability threshold for decoding failure")
    var logprobThreshold: Float?

    @Option(help: "Probability threshold to consider a segment as silence")
    var noSpeechThreshold: Float?

    @Flag(help: "Output a report of the results")
    var report: Bool = false

    @Option(help: "Directory to save the report")
    var reportPath: String = "."

    @Flag(help: "Process audio directly from the microphone")
    var stream: Bool = false
}

Sources/WhisperKitCLI/CLIUtils.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// For licensing see accompanying LICENSE.md file.
2+
// Copyright © 2024 Argmax, Inc. All rights reserved.
3+
4+
import ArgumentParser
5+
import CoreML
6+
import Foundation
7+
import WhisperKit
8+
9+
/// CLI-selectable Core ML compute-unit choices.
///
/// `ExpressibleByArgument` lets ArgumentParser decode the raw string value
/// directly from the command line; `CaseIterable` exposes the full set.
enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
    case all
    case cpuAndGPU
    case cpuOnly
    case cpuAndNeuralEngine
    case random

    /// The `MLComputeUnits` value this choice maps to.
    ///
    /// `.random` is resolved each time this property is read, flipping a coin
    /// between `.cpuAndGPU` and `.cpuAndNeuralEngine`.
    var asMLComputeUnits: MLComputeUnits {
        switch self {
        case .all:
            return .all
        case .cpuAndGPU:
            return .cpuAndGPU
        case .cpuOnly:
            return .cpuOnly
        case .cpuAndNeuralEngine:
            return .cpuAndNeuralEngine
        case .random:
            return Bool.random() ? .cpuAndGPU : .cpuAndNeuralEngine
        }
    }
}

Sources/WhisperKitCLI/transcribe.swift

Lines changed: 64 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -4,82 +4,33 @@
44
import ArgumentParser
55
import CoreML
66
import Foundation
7-
87
import WhisperKit
98

109
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
1110
@main
1211
struct WhisperKitCLI: AsyncParsableCommand {
13-
@Option(help: "Path to audio file")
14-
var audioPath: String = "Tests/WhisperKitTests/Resources/jfk.wav"
15-
16-
@Option(help: "Path of model files")
17-
var modelPath: String = "Models/whisperkit-coreml/openai_whisper-tiny"
18-
19-
@Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
20-
var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
21-
22-
@Option(help: "Compute units for text decoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
23-
var textDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
24-
25-
@Flag(help: "Verbose mode")
26-
var verbose: Bool = false
27-
28-
@Option(help: "Task to perform (transcribe or translate)")
29-
var task: String = "transcribe"
30-
31-
@Option(help: "Language spoken in the audio")
32-
var language: String?
33-
34-
@Option(help: "Temperature to use for sampling")
35-
var temperature: Float = 0
36-
37-
@Option(help: "Temperature to increase on fallbacks during decoding")
38-
var temperatureIncrementOnFallback: Float = 0.2
39-
40-
@Option(help: "Number of times to increase temperature when falling back during decoding")
41-
var temperatureFallbackCount: Int = 5
42-
43-
@Option(help: "Number of candidates when sampling with non-zero temperature")
44-
var bestOf: Int = 5
45-
46-
@Flag(help: "Force initial prompt tokens based on language, task, and timestamp options")
47-
var usePrefillPrompt: Bool = false
48-
49-
@Flag(help: "Use decoder prefill data for faster initial decoding")
50-
var usePrefillCache: Bool = false
51-
52-
@Flag(help: "Skip special tokens in the output")
53-
var skipSpecialTokens: Bool = false
54-
55-
@Flag(help: "Force no timestamps when decoding")
56-
var withoutTimestamps: Bool = false
57-
58-
@Flag(help: "Add timestamps for each word in the output")
59-
var wordTimestamps: Bool = false
60-
61-
@Argument(help: "Supress given tokens in the output")
62-
var supressTokens: [Int] = []
63-
64-
@Option(help: "Gzip compression ratio threshold for decoding failure")
65-
var compressionRatioThreshold: Float?
66-
67-
@Option(help: "Average log probability threshold for decoding failure")
68-
var logprobThreshold: Float?
69-
70-
@Option(help: "Probability threshold to consider a segment as silence")
71-
var noSpeechThreshold: Float?
12+
static let configuration = CommandConfiguration(
13+
commandName: "transcribe",
14+
abstract: "WhisperKit Transcribe CLI",
15+
discussion: "Swift native speech recognition with Whisper for Apple Silicon"
16+
)
7217

73-
@Flag(help: "Output a report of the results")
74-
var report: Bool = false
18+
@OptionGroup
19+
var cliArguments: CLIArguments
7520

76-
@Option(help: "Directory to save the report")
77-
var reportPath: String = "."
78-
79-
@Flag(help: "Process audio directly from the microphone")
80-
var stream: Bool = false
21+
mutating func run() async throws {
22+
if cliArguments.stream {
23+
try await transcribeStream(modelPath: cliArguments.modelPath)
24+
} else {
25+
let audioURL = URL(fileURLWithPath: cliArguments.audioPath)
26+
if cliArguments.verbose {
27+
print("Transcribing audio at \(audioURL)")
28+
}
29+
try await transcribe(audioPath: cliArguments.audioPath, modelPath: cliArguments.modelPath)
30+
}
31+
}
8132

82-
func transcribe(audioPath: String, modelPath: String) async throws {
33+
private func transcribe(audioPath: String, modelPath: String) async throws {
8334
let resolvedModelPath = resolveAbsolutePath(modelPath)
8435
guard FileManager.default.fileExists(atPath: resolvedModelPath) else {
8536
fatalError("Model path does not exist \(resolvedModelPath)")
@@ -91,49 +42,52 @@ struct WhisperKitCLI: AsyncParsableCommand {
9142
}
9243

9344
let computeOptions = ModelComputeOptions(
94-
audioEncoderCompute: audioEncoderComputeUnits.asMLComputeUnits,
95-
textDecoderCompute: textDecoderComputeUnits.asMLComputeUnits
45+
audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits,
46+
textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits
9647
)
9748

9849
print("Initializing models...")
9950
let whisperKit = try await WhisperKit(
10051
modelFolder: modelPath,
10152
computeOptions: computeOptions,
102-
verbose: verbose,
53+
verbose: cliArguments.verbose,
10354
logLevel: .debug
10455
)
10556
print("Models initialized")
10657

10758
let options = DecodingOptions(
108-
verbose: verbose,
59+
verbose: cliArguments.verbose,
10960
task: .transcribe,
110-
language: language,
111-
temperature: temperature,
112-
temperatureIncrementOnFallback: temperatureIncrementOnFallback,
113-
temperatureFallbackCount: temperatureFallbackCount,
114-
topK: bestOf,
115-
usePrefillPrompt: usePrefillPrompt,
116-
usePrefillCache: usePrefillCache,
117-
skipSpecialTokens: skipSpecialTokens,
118-
withoutTimestamps: withoutTimestamps,
119-
wordTimestamps: wordTimestamps,
120-
supressTokens: supressTokens,
121-
compressionRatioThreshold: compressionRatioThreshold,
122-
logProbThreshold: logprobThreshold,
123-
noSpeechThreshold: noSpeechThreshold
61+
language: cliArguments.language,
62+
temperature: cliArguments.temperature,
63+
temperatureIncrementOnFallback: cliArguments.temperatureIncrementOnFallback,
64+
temperatureFallbackCount: cliArguments.temperatureFallbackCount,
65+
topK: cliArguments.bestOf,
66+
usePrefillPrompt: cliArguments.usePrefillPrompt,
67+
usePrefillCache: cliArguments.usePrefillCache,
68+
skipSpecialTokens: cliArguments.skipSpecialTokens,
69+
withoutTimestamps: cliArguments.withoutTimestamps,
70+
wordTimestamps: cliArguments.wordTimestamps,
71+
supressTokens: cliArguments.supressTokens,
72+
compressionRatioThreshold: cliArguments.compressionRatioThreshold,
73+
logProbThreshold: cliArguments.logprobThreshold,
74+
noSpeechThreshold: cliArguments.noSpeechThreshold
12475
)
12576

126-
let transcribeResult = try await whisperKit.transcribe(audioPath: resolvedAudioPath, decodeOptions: options)
77+
let transcribeResult = try await whisperKit.transcribe(
78+
audioPath: resolvedAudioPath,
79+
decodeOptions: options
80+
)
12781

12882
let transcription = transcribeResult?.text ?? "Transcription failed"
12983

130-
if report, let result = transcribeResult {
84+
if cliArguments.report, let result = transcribeResult {
13185
let audioFileName = URL(fileURLWithPath: audioPath).lastPathComponent.components(separatedBy: ".").first!
13286

13387
// Write SRT (SubRip Subtitle Format) for the transcription
134-
let srtReportWriter = WriteSRT(outputDir: reportPath)
88+
let srtReportWriter = WriteSRT(outputDir: cliArguments.reportPath)
13589
let savedSrtReport = srtReportWriter.write(result: result, to: audioFileName)
136-
if verbose {
90+
if cliArguments.verbose {
13791
switch savedSrtReport {
13892
case let .success(reportPath):
13993
print("\n\nSaved SRT Report: \n\n\(reportPath)\n")
@@ -143,9 +97,9 @@ struct WhisperKitCLI: AsyncParsableCommand {
14397
}
14498

14599
// Write JSON for all metadata
146-
let jsonReportWriter = WriteJSON(outputDir: reportPath)
100+
let jsonReportWriter = WriteJSON(outputDir: cliArguments.reportPath)
147101
let savedJsonReport = jsonReportWriter.write(result: result, to: audioFileName)
148-
if verbose {
102+
if cliArguments.verbose {
149103
switch savedJsonReport {
150104
case let .success(reportPath):
151105
print("\n\nSaved JSON Report: \n\n\(reportPath)\n")
@@ -155,47 +109,47 @@ struct WhisperKitCLI: AsyncParsableCommand {
155109
}
156110
}
157111

158-
if verbose {
112+
if cliArguments.verbose {
159113
print("\n\nTranscription: \n\n\(transcription)\n")
160114
} else {
161115
print(transcription)
162116
}
163117
}
164118

165-
func transcribeStream(modelPath: String) async throws {
119+
private func transcribeStream(modelPath: String) async throws {
166120
let computeOptions = ModelComputeOptions(
167-
audioEncoderCompute: audioEncoderComputeUnits.asMLComputeUnits,
168-
textDecoderCompute: textDecoderComputeUnits.asMLComputeUnits
121+
audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits,
122+
textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits
169123
)
170124

171125
print("Initializing models...")
172126
let whisperKit = try await WhisperKit(
173127
modelFolder: modelPath,
174128
computeOptions: computeOptions,
175-
verbose: verbose,
129+
verbose: cliArguments.verbose,
176130
logLevel: .debug
177131
)
178132
print("Models initialized")
179133

180134
let decodingOptions = DecodingOptions(
181-
verbose: verbose,
135+
verbose: cliArguments.verbose,
182136
task: .transcribe,
183-
language: language,
184-
temperature: temperature,
185-
temperatureIncrementOnFallback: temperatureIncrementOnFallback,
137+
language: cliArguments.language,
138+
temperature: cliArguments.temperature,
139+
temperatureIncrementOnFallback: cliArguments.temperatureIncrementOnFallback,
186140
temperatureFallbackCount: 3, // limit fallbacks for realtime
187141
sampleLength: 224, // reduced sample length for realtime
188-
topK: bestOf,
189-
usePrefillPrompt: usePrefillPrompt,
190-
usePrefillCache: usePrefillCache,
191-
skipSpecialTokens: skipSpecialTokens,
192-
withoutTimestamps: withoutTimestamps,
142+
topK: cliArguments.bestOf,
143+
usePrefillPrompt: cliArguments.usePrefillPrompt,
144+
usePrefillCache: cliArguments.usePrefillCache,
145+
skipSpecialTokens: cliArguments.skipSpecialTokens,
146+
withoutTimestamps: cliArguments.withoutTimestamps,
193147
clipTimestamps: [],
194148
suppressBlank: false,
195-
supressTokens: supressTokens,
196-
compressionRatioThreshold: compressionRatioThreshold ?? 2.4,
197-
logProbThreshold: logprobThreshold ?? -1.0,
198-
noSpeechThreshold: noSpeechThreshold ?? 0.6
149+
supressTokens: cliArguments.supressTokens,
150+
compressionRatioThreshold: cliArguments.compressionRatioThreshold ?? 2.4,
151+
logProbThreshold: cliArguments.logprobThreshold ?? -1.0,
152+
noSpeechThreshold: cliArguments.noSpeechThreshold ?? 0.6
199153
)
200154

201155
let audioStreamTranscriber = AudioStreamTranscriber(
@@ -222,29 +176,4 @@ struct WhisperKitCLI: AsyncParsableCommand {
222176
print("Transcribing audio stream, press Ctrl+C to stop.")
223177
try await audioStreamTranscriber.startStreamTranscription()
224178
}
225-
226-
mutating func run() async throws {
227-
if stream {
228-
try await transcribeStream(modelPath: modelPath)
229-
} else {
230-
let audioURL = URL(fileURLWithPath: audioPath)
231-
if verbose {
232-
print("Transcribing audio at \(audioURL)")
233-
}
234-
try await transcribe(audioPath: audioPath, modelPath: modelPath)
235-
}
236-
}
237-
}
238-
239-
enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
240-
case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine, random
241-
var asMLComputeUnits: MLComputeUnits {
242-
switch self {
243-
case .all: return .all
244-
case .cpuAndGPU: return .cpuAndGPU
245-
case .cpuOnly: return .cpuOnly
246-
case .cpuAndNeuralEngine: return .cpuAndNeuralEngine
247-
case .random: return Bool.random() ? .cpuAndGPU : .cpuAndNeuralEngine
248-
}
249-
}
250179
}

0 commit comments

Comments
 (0)