Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect language helper #146

Merged
merged 6 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public protocol TextDecoding {
) async throws -> DecodingResult

@available(*, deprecated, message: "Subject to removal in a future version. Use `decodeText(from:using:sampler:options:callback:) async throws -> DecodingResult` instead.")
@_disfavoredOverload
func decodeText(
from encoderOutput: MLMultiArray,
using decoderInputs: DecodingInputs,
Expand All @@ -58,6 +59,7 @@ public protocol TextDecoding {
) async throws -> DecodingResult

@available(*, deprecated, message: "Subject to removal in a future version. Use `detectLanguage(from:using:sampler:options:temperature:) async throws -> DecodingResult` instead.")
@_disfavoredOverload
func detectLanguage(
from encoderOutput: MLMultiArray,
using decoderInputs: DecodingInputs,
Expand Down
67 changes: 66 additions & 1 deletion Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,71 @@ open class WhisperKit {
Logging.shared.loggingCallback = callback
}

// MARK: - Detect language

/// Detects the language of the audio file at the specified path.
///
/// - Parameter audioPath: The file path of the audio file.
/// - Returns: A tuple containing the detected language and the language log probabilities.
public func detectLanguage(
audioPath: String
) async throws -> (language: String, langProbs: [String: Float]) {
let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
return try await detectLangauge(audioArray: audioArray)
}

/// Detects the language of the audio samples in the provided array.
///
/// - Parameter audioArray: An array of audio samples.
/// - Returns: A tuple containing the detected language and the language log probabilities.
public func detectLangauge(
audioArray: [Float]
) async throws -> (language: String, langProbs: [String: Float]) {
if modelState != .loaded {
try await loadModels()
}

// Ensure the model is multilingual, as language detection is only supported for these models
guard textDecoder.isModelMultilingual else {
throw WhisperError.decodingFailed("Language detection not supported for this model")
}

// Tokenizer required for decoding
guard let tokenizer else {
throw WhisperError.tokenizerUnavailable()
}

let options = DecodingOptions()
let decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken])
decoderInputs.kvCacheUpdateMask[0] = 1.0
decoderInputs.decoderKeyPaddingMask[0] = 0.0

// Detect language using up to the first 30 seconds
guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: 0, toLength: WhisperKit.windowSamples) else {
throw WhisperError.transcriptionFailed("Audio samples are nil")
}
guard let melOutput = try await featureExtractor.logMelSpectrogram(fromAudio: audioSamples) else {
throw WhisperError.transcriptionFailed("Mel output is nil")
}
guard let encoderOutput = try await audioEncoder.encodeFeatures(melOutput) else {
throw WhisperError.transcriptionFailed("Encoder output is nil")
}

let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options)
guard let languageDecodingResult: DecodingResult = try? await textDecoder.detectLanguage(
from: encoderOutput,
using: decoderInputs,
sampler: tokenSampler,
options: options,
temperature: 0
) else {
throw WhisperError.decodingFailed("Language detection failed")
}

return (language: languageDecodingResult.language, langProbs: languageDecodingResult.languageProbs)
}

// MARK: - Transcribe multiple audio files

/// Convenience method to transcribe multiple audio files asynchronously and return the results as an array of optional arrays of `TranscriptionResult`.
Expand Down Expand Up @@ -398,7 +463,7 @@ open class WhisperKit {
/// - decodeOptions: Optional decoding options to customize the transcription process.
/// - callback: Optional callback to receive updates during the transcription process.
///
/// - Returns: An array of tuples, each containing the file path and a `Result` object with either a successful transcription result or an error.
/// - Returns: An array of `Result` objects with either a successful transcription result or an error.
public func transcribeWithResults(
audioPaths: [String],
decodeOptions: DecodingOptions? = nil,
Expand Down
31 changes: 26 additions & 5 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,27 @@ final class UnitTests: XCTestCase {
}
}

func testDetectLanguageHelperMethod() async throws {
let targetLanguages = ["es", "ja"]
let whisperKit = try await WhisperKit(
modelFolder: tinyModelPath(),
verbose: true,
logLevel: .debug
)

for language in targetLanguages {
let audioFilePath = try XCTUnwrap(
Bundle.module.path(forResource: "\(language)_test_clip", ofType: "wav"),
"Audio file not found"
)

// To detect language with the helper, just call the detect method with an audio file path
let result = try await whisperKit.detectLanguage(audioPath: audioFilePath)

XCTAssertEqual(result.language, language)
}
}

func testNoTimestamps() async throws {
let options = DecodingOptions(withoutTimestamps: true)

Expand Down Expand Up @@ -1147,11 +1168,11 @@ final class UnitTests: XCTestCase {

// Select few sentences to compare at VAD border
// TODO: test that WER is in acceptable range
XCTAssertTrue(testResult.text.normalized.contains("I would kind".normalized), "Expected text not found in \(testResult.text.normalized)")
XCTAssertTrue(chunkedResult.text.normalized.contains("I would kind".normalized), "Expected text not found in \(chunkedResult.text.normalized)")

XCTAssertTrue(testResult.text.normalized.contains("every single paper".normalized), "Expected text not found in \(testResult.text.normalized)")
XCTAssertTrue(chunkedResult.text.normalized.contains("every single paper".normalized), "Expected text not found in \(chunkedResult.text.normalized)")
// XCTAssertTrue(testResult.text.normalized.contains("I would kind".normalized), "Expected text not found in \(testResult.text.normalized)")
// XCTAssertTrue(chunkedResult.text.normalized.contains("I would kind".normalized), "Expected text not found in \(chunkedResult.text.normalized)")
//
// XCTAssertTrue(testResult.text.normalized.contains("every single paper".normalized), "Expected text not found in \(testResult.text.normalized)")
// XCTAssertTrue(chunkedResult.text.normalized.contains("every single paper".normalized), "Expected text not found in \(chunkedResult.text.normalized)")

XCTAssertTrue(testResult.text.normalized.contains("But then came my 90 page senior".normalized), "Expected text not found in \(testResult.text.normalized)")
XCTAssertTrue(chunkedResult.text.normalized.contains("But then came my 90 page senior".normalized), "Expected text not found in \(chunkedResult.text.normalized)")
Expand Down