Added MLX Audio Encoder #139

Merged: 3 commits merged on May 22, 2024

jkrukowski (Contributor) commented on May 16, 2024

This PR adds an MLX audio encoder.

The implementation is based on the AudioEncoder from the mlx-examples repository.

To make sure the audio encoder works as expected, I have added weight-loading functionality. The weights come from the https://huggingface.co/jkrukowski/whisper-tiny-mlx-safetensors repository, which contains the whisper-tiny-mlx weights converted to the safetensors format (for now, MLX Swift cannot load .npz files). I have also added a weights download step to the Makefile so that the tests run correctly. This could be removed once the MLX branch is fully integrated into the main repository and we decide on the best way to handle the weights.
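
For context on the format mentioned above: a safetensors file starts with an 8-byte little-endian header length, followed by a JSON header that describes each tensor, followed by the raw tensor bytes. A minimal sketch of listing tensor names and shapes from such a file, using only Foundation, might look like the following. The names `tensorShapes(inSafetensors:)`, `TensorEntry`, and `SafetensorsHeaderError` are hypothetical, and this is not how MLX Swift or this PR loads the weights; it is only a way to sanity-check a downloaded file.

```swift
import Foundation

// Hypothetical error type for this sketch.
enum SafetensorsHeaderError: Error {
    case truncated
}

// One header entry. Both fields are optional so that the special
// "__metadata__" entry (a plain string map) also decodes.
struct TensorEntry: Decodable {
    let dtype: String?
    let shape: [Int]?
}

// Returns tensor name -> shape, read from the JSON header only.
func tensorShapes(inSafetensors data: Data) throws -> [String: [Int]] {
    let lengthBytes = data.prefix(8)
    guard lengthBytes.count == 8 else { throw SafetensorsHeaderError.truncated }

    // Assemble the little-endian UInt64 header length byte by byte.
    var headerLength: UInt64 = 0
    for (index, byte) in lengthBytes.enumerated() {
        headerLength |= UInt64(byte) << UInt64(8 * index)
    }
    guard headerLength <= UInt64(data.count - 8) else { throw SafetensorsHeaderError.truncated }

    // Copy just the header bytes and decode them as JSON.
    let headerStart = data.startIndex + 8
    let headerData = Data(data[headerStart..<(headerStart + Int(headerLength))])
    let header = try JSONDecoder().decode([String: TensorEntry].self, from: headerData)

    var shapes: [String: [Int]] = [:]
    for (name, entry) in header where name != "__metadata__" {
        if let shape = entry.shape {
            shapes[name] = shape
        }
    }
    return shapes
}
```

For a file on disk, the function could be called with `try Data(contentsOf: url, options: .mappedIfSafe)` so that the tensor payload does not have to be read into memory up front.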

I have changed the project structure a bit. The common test utilities now live in the WhisperKitTestsUtils target, and the resources used for testing (audio files and models) have moved there as well. This way we can reuse the resources in both MLX and non-MLX tests. It also simplifies the project structure: the WhisperKitTests target no longer needs a custom path and a bunch of excluded files.

jkrukowski changed the base branch from main to mlx-support on May 16, 2024, 09:04

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public class MLXAudioEncoder: AudioEncoding {

public var model: MLModel?

jkrukowski (Contributor, Author) commented:

WhisperMLModel requires this property, but for MLX models it will always be nil. Maybe we could remove it from the protocol requirements, wdyt @ZachNagengast?

ZachNagengast (Contributor) commented:

Two options I see:

  1. Pull the MLModel stuff into the CoreML-specific AudioEncoder
  2. Make model a generic WhisperModel type, so that MLXAudioEncoder can store its own model in the model property and reuse the same protocols that WhisperMLModel has

jkrukowski (Contributor, Author) commented:

I'd opt for the non-generic approach. It'd mean that WhisperMLModel would look like this:

public protocol WhisperMLModel {
    mutating func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws
    mutating func unloadModel()
}

It'd also mean that we'd have to remove the protocol extension and implement load and unload in every class that conforms to the WhisperMLModel protocol. I guess that's OK; we could pull the current default implementation out into a helper function.
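
A minimal sketch of what that helper-function variant could look like. To stay runnable without CoreML, `StubModel` and `ComputeUnits` are hypothetical stand-ins for `MLModel` and `MLComputeUnits`, `loadStubModel` is a made-up helper name, `async` is dropped, and the prewarm behaviour (load once, do not retain) is an assumption about the current default implementation.

```swift
import Foundation

// Stand-ins for CoreML's MLModel and MLComputeUnits (hypothetical).
enum ComputeUnits { case all, cpuOnly }

final class StubModel {
    let path: URL
    let computeUnits: ComputeUnits
    init(path: URL, computeUnits: ComputeUnits) {
        self.path = path
        self.computeUnits = computeUnits
    }
}

// The protocol reduced to its two methods, with no extension.
protocol WhisperMLModel {
    mutating func loadModel(at modelPath: URL, computeUnits: ComputeUnits, prewarmMode: Bool) throws
    mutating func unloadModel()
}

// Hypothetical helper holding the logic of the former default implementation.
// Assumption: in prewarm mode the model is loaded once and not retained.
func loadStubModel(at modelPath: URL, computeUnits: ComputeUnits, prewarmMode: Bool) throws -> StubModel? {
    let loaded = StubModel(path: modelPath, computeUnits: computeUnits)
    return prewarmMode ? nil : loaded
}

// Each conforming class now spells out load and unload itself.
final class AudioEncoder: WhisperMLModel {
    var model: StubModel?

    func loadModel(at modelPath: URL, computeUnits: ComputeUnits, prewarmMode: Bool) throws {
        model = try loadStubModel(at: modelPath, computeUnits: computeUnits, prewarmMode: prewarmMode)
    }

    func unloadModel() {
        model = nil
    }
}
```

The cost of this layout is the boilerplate repeated in every conforming class; the benefit is that a backend without an `MLModel` is not forced to declare one.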

ZachNagengast (Contributor) commented:

Was just thinking about this in a different PR and remembered that we do check the model type there:

if var featureExtractor = featureExtractor as? WhisperMLModel {

So let's actually keep WhisperMLModel as-is and create a new WhisperMLXModel. This will allow us to keep everything generic and handle the different model types explicitly in WhisperKit.loadModels. What do you think?
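
A minimal sketch of that kind of explicit handling, assuming two sibling protocols and a cast per backend. `loadEncoder`, `Backend`, `CoreMLAudioEncoder`, and the one-argument `loadModel(at:)` on `WhisperMLXModel` are hypothetical, `ComputeUnits` stands in for `MLComputeUnits`, and `async` is dropped so the sketch runs as a plain script.

```swift
import Foundation

// Stand-in for CoreML's MLComputeUnits (hypothetical).
enum ComputeUnits { case all, cpuOnly }

// Capability protocol shared by both backends (left empty for brevity).
protocol AudioEncoding {}

// CoreML-style loading protocol.
protocol WhisperMLModel {
    mutating func loadModel(at modelPath: URL, computeUnits: ComputeUnits, prewarmMode: Bool) throws
}

// Assumption: the MLX variant only needs a path.
protocol WhisperMLXModel {
    mutating func loadModel(at modelPath: URL) throws
}

final class CoreMLAudioEncoder: AudioEncoding, WhisperMLModel {
    var loadedPath: URL?
    func loadModel(at modelPath: URL, computeUnits: ComputeUnits, prewarmMode: Bool) throws {
        loadedPath = modelPath
    }
}

final class MLXAudioEncoder: AudioEncoding, WhisperMLXModel {
    var loadedPath: URL?
    func loadModel(at modelPath: URL) throws {
        loadedPath = modelPath
    }
}

enum Backend { case coreML, mlx, unsupported }

// Hypothetical slice of a loadModels routine: choose the loading call by protocol.
func loadEncoder(_ encoder: AudioEncoding, from modelPath: URL) throws -> Backend {
    if var coreMLEncoder = encoder as? WhisperMLModel {
        try coreMLEncoder.loadModel(at: modelPath, computeUnits: .all, prewarmMode: false)
        return .coreML
    } else if var mlxEncoder = encoder as? WhisperMLXModel {
        try mlxEncoder.loadModel(at: modelPath)
        return .mlx
    }
    return .unsupported
}
```

Because both encoders are classes, the `var` copies made by the casts still refer to the original objects, so the loaded state is visible to the caller.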

jkrukowski (Contributor, Author) commented:

What if we do this:

public protocol WhisperModel {
    mutating func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws
    mutating func unloadModel()
}

public protocol WhisperMLModel: WhisperModel {
    var model: MLModel? { get set }
}

public extension WhisperMLModel {
    // leave it as is
}

  1. Create a new WhisperModel protocol that contains just the loading and unloading methods
  2. WhisperMLModel inherits from WhisperModel and adds just the model property
  3. This way we could (if needed) create a WhisperMLXModel that inherits from WhisperModel as well. But if it's not needed, MLXAudioEncoder could just conform to WhisperModel

wdyt?
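
A minimal sketch of how conformances might look under this hierarchy. `StubModel` and `ComputeUnits` are hypothetical stand-ins for `MLModel` and `MLComputeUnits`, the body of the extension is an assumption about what "leave it as is" contains, the `weightsPath` property is made up, and `async` is dropped so the sketch runs as a plain script.

```swift
import Foundation

// Stand-ins for CoreML's MLModel and MLComputeUnits (hypothetical).
enum ComputeUnits { case all, cpuOnly }

final class StubModel {
    let path: URL
    init(path: URL) { self.path = path }
}

protocol WhisperModel {
    mutating func loadModel(at modelPath: URL, computeUnits: ComputeUnits, prewarmMode: Bool) throws
    mutating func unloadModel()
}

protocol WhisperMLModel: WhisperModel {
    var model: StubModel? { get set }
}

// Default implementations shared by every CoreML-style model.
extension WhisperMLModel {
    mutating func loadModel(at modelPath: URL, computeUnits: ComputeUnits, prewarmMode: Bool) throws {
        let loaded = StubModel(path: modelPath)
        model = prewarmMode ? nil : loaded
    }

    mutating func unloadModel() {
        model = nil
    }
}

// CoreML-backed encoder: only declares storage, inherits load and unload.
final class AudioEncoder: WhisperMLModel {
    var model: StubModel?
}

// MLX-backed encoder: conforms to WhisperModel directly, with no model property.
final class MLXAudioEncoder: WhisperModel {
    var weightsPath: URL?

    func loadModel(at modelPath: URL, computeUnits: ComputeUnits, prewarmMode: Bool) throws {
        weightsPath = modelPath // compute units and prewarm are unused in this stand-in
    }

    func unloadModel() {
        weightsPath = nil
    }
}

// `var` is needed here: the default implementations are `mutating`, even on a class.
var coreMLEncoder = AudioEncoder()
try coreMLEncoder.loadModel(at: URL(fileURLWithPath: "/tmp/AudioEncoder.mlmodelc"), computeUnits: .all, prewarmMode: false)
```

One thing the sketch makes visible is that the MLX encoder still has to accept `computeUnits` and `prewarmMode`, which it may have no use for.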

jkrukowski (Contributor, Author) commented:

One more idea

public protocol WhisperModel {
    // basic load model method
    mutating func loadModel(at modelPath: URL) async throws
    mutating func unloadModel()
}

public protocol WhisperMLModel: WhisperModel {
    var model: MLModel? { get set }
    // specialized load model method
    mutating func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws
}

public protocol WhisperMLXModel: WhisperModel {}

public extension WhisperMLModel {
    // default implementation for the basic load model method
    mutating func loadModel(at modelPath: URL) async throws {
        try await loadModel(at: modelPath, computeUnits: .all, prewarmMode: false)
    }
    // leave the rest as is
}

jkrukowski (Contributor, Author) commented:

or maybe it's just nitpicking and mutating func loadModel(at modelPath: URL) async throws can be moved to WhisperMLXModel?

ZachNagengast (Contributor) commented:

I was thinking that too. I'm OK with it as long as WhisperMLModel conformers don't also need to implement the base load method; otherwise there's little benefit IMO. It's not pushed yet, but I'm also building a modelState variable into this protocol so we know what is loaded or unloaded.
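
Since the modelState change is not pushed yet, its shape is unknown; one possible form, purely as an illustration, is a state derived from whether a model is currently held. `ModelState` with its two cases, `StubModel` (standing in for `MLModel`), and the derived default are all assumptions.

```swift
// Hypothetical state enum; the real one may well have more cases.
enum ModelState { case unloaded, loaded }

// Stand-in for CoreML's MLModel (hypothetical).
final class StubModel {}

protocol WhisperMLModel {
    var model: StubModel? { get set }
    var modelState: ModelState { get }
    mutating func unloadModel()
}

extension WhisperMLModel {
    // Derive the state from whether a model is currently held.
    var modelState: ModelState {
        model == nil ? .unloaded : .loaded
    }

    mutating func unloadModel() {
        model = nil
    }
}

final class AudioEncoder: WhisperMLModel {
    var model: StubModel?
}
```

A stored property updated inside load and unload would be an equally plausible design, and would leave room for intermediate states such as loading.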

ZachNagengast (Contributor) commented:

> or maybe it's just nitpicking and mutating func loadModel(at modelPath: URL) async throws can be moved to WhisperMLXModel?

Totally ok with this because the model itself can only be used by its parent class, so all the interfaces will be specific to the model.

jkrukowski (Contributor, Author) commented:

> > or maybe it's just nitpicking and mutating func loadModel(at modelPath: URL) async throws can be moved to WhisperMLXModel?
>
> Totally ok with this because the model itself can only be used by its parent class, so all the interfaces will be specific to the model.

ok will do it

ZachNagengast (Contributor) left a comment:

This looks great now, will merge it into the mlx-support branch and we can fix the tests there 🙌

ZachNagengast merged commit 7cc004b into argmaxinc:mlx-support on May 22, 2024
13 of 15 checks passed