Added MLX Audio Encoder #139
Conversation
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public class MLXAudioEncoder: AudioEncoding {
    public var model: MLModel?
WhisperMLModel requires this property; however, for MLX models it will always be nil. Maybe we could remove it from the protocol requirement, wdyt @ZachNagengast?
Two options I see:
- Pull the MLModel stuff into the CoreML-specific AudioEncoder
- Make model a generic WhisperModel type so we can store it on MLXAudioEncoder as model, so we can use the same protocols that WhisperMLModel has
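For illustration, the first option might look something like this. This is only a sketch of the idea; the exact class and protocol bodies are assumptions, not code from the PR:

```swift
// Sketch of option 1 (hypothetical): only the CoreML encoder keeps the
// MLModel property, so the shared AudioEncoding protocol stays model-agnostic.
public class AudioEncoder: AudioEncoding, WhisperMLModel {
    // CoreML-specific storage lives here, not in the protocol
    public var model: MLModel?
}

public class MLXAudioEncoder: AudioEncoding {
    // no MLModel property needed; MLX weights are stored internally
}
```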
I'd opt for the non-generic approach. It'd mean that WhisperMLModel would look like this:
public protocol WhisperMLModel {
    mutating func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws
    mutating func unloadModel()
}
It'd also mean that we'd have to remove the protocol extension and implement load and unload for every class which conforms to the WhisperMLModel protocol. I guess that's ok; we could pull the current default implementation out into some helper function.
Was just thinking about this in a different PR and remembered we do check the model type there:
if var featureExtractor = featureExtractor as? WhisperMLModel {
So let's actually keep WhisperMLModel as-is and create a new WhisperMLXModel. This will allow us to keep everything generic and handle different types of models explicitly in WhisperKit.loadModels. What do you think?
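A sketch of what that explicit handling in WhisperKit.loadModels could look like. The variable names, paths, and argument values here are assumptions for illustration only:

```swift
// Hypothetical branching on model type inside WhisperKit.loadModels:
// CoreML models get the specialized load call, MLX models a basic one.
if var mlEncoder = audioEncoder as? WhisperMLModel {
    try await mlEncoder.loadModel(at: encoderPath, computeUnits: .all, prewarmMode: false)
} else if var mlxEncoder = audioEncoder as? WhisperMLXModel {
    try await mlxEncoder.loadModel(at: encoderPath)
}
```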
What if we do this:
public protocol WhisperModel {
    mutating func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws
    mutating func unloadModel()
}
public protocol WhisperMLModel: WhisperModel {
    var model: MLModel? { get set }
}
public extension WhisperMLModel {
    // leave it as is
}
- Create a new protocol WhisperModel which contains just the loading and unloading methods
- WhisperMLModel inherits from WhisperModel and contains just the model property
- This way we could (if needed) create WhisperMLXModel, which would inherit from WhisperModel as well. But if it's not needed, MLXAudioEncoder could just conform to WhisperModel
wdyt?
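Under that split, a conformance for the MLX encoder might look like the sketch below. The method bodies are placeholders, not code from the PR:

```swift
// Hypothetical conformance under the proposed split: the MLX encoder only
// needs the load/unload requirements and never exposes an MLModel property.
// (Classes satisfy mutating protocol requirements with ordinary methods.)
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
extension MLXAudioEncoder: WhisperModel {
    public func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws {
        // MLX would ignore computeUnits/prewarmMode and load safetensors weights here
    }

    public func unloadModel() {
        // release the MLX arrays here
    }
}
```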
One more idea:
public protocol WhisperModel {
    // basic load model method
    mutating func loadModel(at modelPath: URL) async throws
    mutating func unloadModel()
}
public protocol WhisperMLModel: WhisperModel {
    var model: MLModel? { get set }
    // specialized load model method
    mutating func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws
}
public protocol WhisperMLXModel: WhisperModel {}
public extension WhisperMLModel {
    // default implementation for the basic load model method
    mutating func loadModel(at modelPath: URL) async throws {
        try await loadModel(at: modelPath, computeUnits: .all, prewarmMode: false)
    }
    // leave the rest as is
}
Or maybe it's just nitpicking, and mutating func loadModel(at modelPath: URL) async throws can be moved to WhisperMLXModel?
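One possible reading of that suggestion, as a sketch rather than the merged code:

```swift
// Sketch (hypothetical final shape): the basic load signature moves down
// into WhisperMLXModel, so WhisperMLModel conformers never implement it.
public protocol WhisperModel {
    mutating func unloadModel()
}

public protocol WhisperMLModel: WhisperModel {
    var model: MLModel? { get set }
    mutating func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws
}

public protocol WhisperMLXModel: WhisperModel {
    mutating func loadModel(at modelPath: URL) async throws
}
```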
I was thinking that too; ok with it as long as WhisperMLModels don't also need to implement the base load model, otherwise there's little benefit IMO. It's not pushed yet, but I'm building a modelState variable into this protocol so we know what is loaded or unloaded as well.
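Since the modelState change isn't pushed yet, its exact shape is an assumption, but the idea might be sketched like this:

```swift
// Hypothetical sketch of the modelState idea: a protocol-level state
// property so callers can check whether a model is loaded.
public enum ModelState {
    case unloaded
    case loading
    case loaded
}

public protocol WhisperModel {
    var modelState: ModelState { get }
    mutating func unloadModel()
}
```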
> or maybe it's just nitpicking and mutating func loadModel(at modelPath: URL) async throws can be moved to WhisperMLXModel?

Totally ok with this because the model itself can only be used by its parent class, so all the interfaces will be specific to the model.
> or maybe it's just nitpicking and mutating func loadModel(at modelPath: URL) async throws can be moved to WhisperMLXModel?
>
> Totally ok with this because the model itself can only be used by its parent class, so all the interfaces will be specific to the model.

ok will do it
This looks great now, will merge it into the mlx-support branch and we can fix the tests there 🙌
This PR adds an MLX Audio Encoder.

The implementation is based on the AudioEncoder from the mlx-examples repository.

To make sure the audio encoder works as expected, I have added weights-loading functionality. The weights are taken from the https://huggingface.co/jkrukowski/whisper-tiny-mlx-safetensors repository, which contains the weights for the whisper-tiny-mlx model converted to the safetensors format (for now, MLX Swift does not have the ability to load .npz files). I have added the MLX weights download functionality to the Makefile to make sure the tests run correctly. This could be removed in the future once the MLX branch is fully integrated into the main repository and we decide on the best way to handle the weights.

I have also changed the project structure a bit. I moved the common test utilities to the WhisperKitTestsUtils target, along with the resources used for testing (audio files and models). This way we can reuse resources in both MLX and non-MLX tests. Additionally, it simplifies the project structure a bit -- the WhisperKitTests target no longer has to have a custom path and a bunch of excluded files.