Skip to content

Commit

Permalink
skip functional tests for models that are not downloaded. (#48)
Browse files Browse the repository at this point in the history
* skip tests for models that are not downloaded, but assume that openai_whisper-tiny and openai_whisper-large-v3 are downloaded.

* fix typo.

* Revert to before the typo fix

* Revert "Revert to before the typo fix"

This reverts commit 34e26bc.
  • Loading branch information
metropol committed Mar 9, 2024
1 parent bfa357e commit 37cf113
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 5 deletions.
4 changes: 2 additions & 2 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ struct ContentView: View {
localModels.append(model)
}

availableLanguages = whisperKit.tokenizer?.langauges.map { $0.key }.sorted() ?? ["english"]
availableLanguages = whisperKit.tokenizer?.languages.map { $0.key }.sorted() ?? ["english"]
loadingProgressValue = 1.0
modelState = whisperKit.modelState
}
Expand Down Expand Up @@ -1009,7 +1009,7 @@ struct ContentView: View {
func transcribeAudioSamples(_ samples: [Float]) async throws -> TranscriptionResult? {
guard let whisperKit = whisperKit else { return nil }

let languageCode = whisperKit.tokenizer?.langauges[selectedLanguage] ?? "en"
let languageCode = whisperKit.tokenizer?.languages[selectedLanguage] ?? "en"
let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
let seekClip = [lastConfirmedSegmentEndSeconds]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ struct WhisperAXWatchView: View {
try await whisperKit.loadModels()

await MainActor.run {
availableLanguages = whisperKit.tokenizer?.langauges.map { $0.key }.sorted() ?? ["english"]
availableLanguages = whisperKit.tokenizer?.languages.map { $0.key }.sorted() ?? ["english"]
loadingProgressValue = 1.0
modelState = whisperKit.modelState
}
Expand Down Expand Up @@ -491,7 +491,7 @@ struct WhisperAXWatchView: View {
func transcribeAudioSamples(_ samples: [Float]) async throws -> TranscriptionResult? {
guard let whisperKit = whisperKit else { return nil }

let languageCode = whisperKit.tokenizer?.langauges[selectedLanguage] ?? "en"
let languageCode = whisperKit.tokenizer?.languages[selectedLanguage] ?? "en"
let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
let seekClip = [lastConfirmedSegmentEndSeconds]

Expand Down
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ public extension Tokenizer {
return false
}

var langauges: [String: String] { [
var languages: [String: String] { [
"english": "en",
"chinese": "zh",
"german": "de",
Expand Down
53 changes: 53 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,24 @@ final class UnitTests: XCTestCase {
XCTAssertEqual(mergedAlignmentTiming[i].probability, expectedWordTimings[i].probability, "Probability at index \(i) does not match")
}
}

func testGitLFSPointerFile() {
// Assumption:
// 1 - the openai_whisper-tiny is downloaded locally. This means that the proxyFile is an actual data file.
// 2 - the openai_whisper-large-v3_turbo is not downloaded locally. This means that the proxyFile is pointer file.
let proxyFile = "AudioEncoder.mlmodelc/coremldata.bin"

// First, we check that a data file is not considered a git lfs pointer file.
var filePath = URL(filePath: tinyModelPath()).appending(path: proxyFile)
var isPointerFile = isGitLFSPointerFile(url: filePath)
XCTAssertEqual(isPointerFile, false, "Assuming whisper-tiny was downloaded, \(proxyFile) should not be a git-lfs pointer file.")

// Second, we check that a pointer file is considered so.
let modelDir = largev3TurboModelPath()
filePath = URL(filePath: modelDir).appending(path: proxyFile)
isPointerFile = isGitLFSPointerFile(url: filePath)
XCTAssertEqual(isPointerFile, true, "Assuming whisper-large-v3_turbo was not downloaded, \(proxyFile) should be a git-lfs pointer file.")
}
}

// MARK: Helpers
Expand Down Expand Up @@ -907,6 +925,15 @@ extension XCTestCase {
return modelPath
}

func largev3TurboModelPath() -> String {
let modelDir = "whisperkit-coreml/openai_whisper-large-v3_turbo"
guard let modelPath = Bundle.module.urls(forResourcesWithExtension: "mlmodelc", subdirectory: modelDir)?.first?.deletingLastPathComponent().path else {
print("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-models`")
return ""
}
return modelPath
}

func allModelPaths() -> [String] {
let fileManager = FileManager.default
var modelPaths: [String] = []
Expand All @@ -924,6 +951,13 @@ extension XCTestCase {
for folderURL in directoryContents {
let resourceValues = try folderURL.resourceValues(forKeys: Set(resourceKeys))
if resourceValues.isDirectory == true {
// Check if the directory contains actual data files, or if it contains pointer files.
// As a proxy, use the MelSpectrogramc.mlmodel/coredata.bin file.
let proxyFileToCheck = folderURL.appendingPathComponent("MelSpectrogram.mlmodelc/coremldata.bin")
if isGitLFSPointerFile(url: proxyFileToCheck) {
continue
}

// Check if the directory name contains the quantization pattern
// Only test large quantized models
let dirName = folderURL.lastPathComponent
Expand All @@ -938,6 +972,25 @@ extension XCTestCase {

return modelPaths
}

// Function to check if the beginning of the file matches a Git LFS pointer pattern
func isGitLFSPointerFile(url: URL) -> Bool {
do {
let fileHandle = try FileHandle(forReadingFrom: url)
// Read the first few bytes of the file to get enough for the Git LFS pointer signature
let data = fileHandle.readData(ofLength: 512) // Read first 512 bytes
fileHandle.closeFile()

if let string = String(data: data, encoding: .utf8),
string.starts(with: "version https://git-lfs.github.com/") {
return true
}
} catch {
print("Failed to read file: \(error)")
}

return false
}

func trackForMemoryLeaks(on instance: AnyObject, file: StaticString = #filePath, line: UInt = #line) {
addTeardownBlock { [weak instance] in
Expand Down

0 comments on commit 37cf113

Please sign in to comment.