diff --git a/.github/workflows/development-tests.yml b/.github/workflows/development-tests.yml index dda3843..fff4165 100644 --- a/.github/workflows/development-tests.yml +++ b/.github/workflows/development-tests.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.ref_name }} + group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift index de1294d..7da08c8 100644 --- a/Sources/WhisperKitCLI/CLIArguments.swift +++ b/Sources/WhisperKitCLI/CLIArguments.swift @@ -21,15 +21,19 @@ struct CLIArguments: ParsableArguments { @Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}") var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine + @Option(help: "Compute units for text decoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}") var textDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine @Flag(help: "Verbose mode") var verbose: Bool = false - + + @Option(help: "Task to perform (transcribe or translate)") + var task: String = "transcribe" + @Option(help: "Language spoken in the audio") var language: String? - + @Option(help: "Temperature to use for sampling") var temperature: Float = 0 diff --git a/Sources/WhisperKitCLI/Transcribe.swift b/Sources/WhisperKitCLI/Transcribe.swift index 581aeb0..fb3aaa6 100644 --- a/Sources/WhisperKitCLI/Transcribe.swift +++ b/Sources/WhisperKitCLI/Transcribe.swift @@ -28,8 +28,16 @@ struct Transcribe: AsyncParsableCommand { guard FileManager.default.fileExists(atPath: resolvedAudioPath) else { throw CocoaError.error(.fileNoSuchFile) } + + let task: DecodingTask + if cliArguments.task.lowercased() == "translate" { + task = .translate + } else { + task = .transcribe + } + if cliArguments.verbose { - print("Transcribing audio at \(cliArguments.audioPath)") + print("Task: \(task.description.capitalized) audio at \(cliArguments.audioPath)") } var audioEncoderComputeUnits = cliArguments.audioEncoderComputeUnits.asMLComputeUnits @@ -82,13 +90,13 @@ struct Transcribe: AsyncParsableCommand { let options = DecodingOptions( verbose: cliArguments.verbose, - task: .transcribe, + task: task, language: cliArguments.language, temperature: cliArguments.temperature, temperatureIncrementOnFallback: cliArguments.temperatureIncrementOnFallback, temperatureFallbackCount: cliArguments.temperatureFallbackCount, topK: cliArguments.bestOf, - usePrefillPrompt: cliArguments.usePrefillPrompt, + usePrefillPrompt: cliArguments.usePrefillPrompt || cliArguments.language != nil, usePrefillCache: cliArguments.usePrefillCache, skipSpecialTokens: cliArguments.skipSpecialTokens, withoutTimestamps: cliArguments.withoutTimestamps,