Skip to content

Commit

Permalink
CLI Task Handling (#85)
Browse files Browse the repository at this point in the history
* Add task handling to CLI

* Only cancel in-progress jobs for current workflow
  • Loading branch information
ZachNagengast committed Mar 22, 2024
1 parent cf75348 commit ae1cf96
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/development-tests.yml
Expand Up @@ -8,7 +8,7 @@ on:
workflow_dispatch:

concurrency:
group: ${{ github.ref_name }}
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
Expand Down
8 changes: 6 additions & 2 deletions Sources/WhisperKitCLI/CLIArguments.swift
Expand Up @@ -21,15 +21,19 @@ struct CLIArguments: ParsableArguments {

@Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine

@Option(help: "Compute units for text decoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
var textDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine

@Flag(help: "Verbose mode")
var verbose: Bool = false


/// Raw task name supplied on the command line; defaults to "transcribe".
/// The Transcribe command lowercases this value and selects translation only
/// when it equals "translate"; every other value (including typos) falls back
/// to transcription rather than raising a validation error.
@Option(help: "Task to perform (transcribe or translate)")
var task: String = "transcribe"

/// Optional language of the input audio; `nil` when the option is omitted.
/// Forwarded as-is to the decoding options by the Transcribe command, which
/// also turns on `usePrefillPrompt` whenever this value is non-nil, regardless
/// of that flag's own setting.
/// NOTE(review): the expected format (e.g. language code vs. full name) is not
/// visible here — confirm against the decoder's language handling.
@Option(help: "Language spoken in the audio")
var language: String?

@Option(help: "Temperature to use for sampling")
var temperature: Float = 0

Expand Down
14 changes: 11 additions & 3 deletions Sources/WhisperKitCLI/Transcribe.swift
Expand Up @@ -28,8 +28,16 @@ struct Transcribe: AsyncParsableCommand {
guard FileManager.default.fileExists(atPath: resolvedAudioPath) else {
throw CocoaError.error(.fileNoSuchFile)
}

let task: DecodingTask
if cliArguments.task.lowercased() == "translate" {
task = .translate
} else {
task = .transcribe
}

if cliArguments.verbose {
print("Transcribing audio at \(cliArguments.audioPath)")
print("Task: \(task.description.capitalized) audio at \(cliArguments.audioPath)")
}

var audioEncoderComputeUnits = cliArguments.audioEncoderComputeUnits.asMLComputeUnits
Expand Down Expand Up @@ -82,13 +90,13 @@ struct Transcribe: AsyncParsableCommand {

let options = DecodingOptions(
verbose: cliArguments.verbose,
task: .transcribe,
task: task,
language: cliArguments.language,
temperature: cliArguments.temperature,
temperatureIncrementOnFallback: cliArguments.temperatureIncrementOnFallback,
temperatureFallbackCount: cliArguments.temperatureFallbackCount,
topK: cliArguments.bestOf,
usePrefillPrompt: cliArguments.usePrefillPrompt,
usePrefillPrompt: cliArguments.usePrefillPrompt || cliArguments.language != nil,
usePrefillCache: cliArguments.usePrefillCache,
skipSpecialTokens: cliArguments.skipSpecialTokens,
withoutTimestamps: cliArguments.withoutTimestamps,
Expand Down

0 comments on commit ae1cf96

Please sign in to comment.