Imprecise conversion for custom log_softmax converter #21
Hi! First, here's a proper converter:

```python
@converter(F.log_softmax, torch.log_softmax, torch.Tensor.log_softmax,
           channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def converter_log_softmax(input: Tensor, dim, *, dtype: Optional[_dtype] = None):
    num_dims = input.dim()

    def func(input, dim, *, dtype=None):
        if get_channel_order(input) == ChannelOrder.TENSORFLOW:
            dim = dim_pytorch2keras(dim, num_dims)
        return tf.nn.log_softmax(input, axis=dim)

    return func
```

Turns out, the converter is already there, and I registered it for […]. Your implementation does not work correctly because, due to […], the […]
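The key subtlety in the converter above is the `dim_pytorch2keras` call: when the tensor arrives in TensorFlow's channel-last layout, the PyTorch axis index must be remapped before being passed to `tf.nn.log_softmax`. As a rough sketch of what such a remapping does (this is a hypothetical re-implementation for illustration, not Nobuco's actual code; the function name and edge-case behavior are assumptions):

```python
def remap_dim_pytorch_to_keras(dim: int, num_dims: int) -> int:
    # PyTorch tensors are channel-first (N, C, D1, D2, ...),
    # Keras tensors are channel-last (N, D1, D2, ..., C).
    if dim < 0:
        dim += num_dims          # normalize negative indices first
    if dim == 0:
        return 0                 # batch axis stays in place
    if dim == 1:
        return num_dims - 1      # channel axis moves to the end
    return dim - 1               # remaining spatial axes shift left by one
```

For a 4D NCHW tensor, the channel axis 1 maps to 3, and the width axis 3 (or -1) maps to 2. Skipping this remap and reducing over the wrong axis is exactly the kind of bug that produces large output discrepancies while still running without errors.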
Awesome, thanks! Got LightGlue built, and I think it is working... want any of these ChatGPT-sourced converters?
Yep, I'll take it. A word of warning: there's a good chance the model was not converted properly. See, LightGlue is quite dynamic: the number of iterations depends on the difficulty of the task, and Nobuco cannot automatically capture control flows. It might still work as-is, but you'd be missing out on performance. Speaking of performance: TensorFlow lags behind PyTorch considerably when it comes to transformers. PyTorch already integrated fast attention operations (notably, […])
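To illustrate why tracing-based conversion can't capture this (a toy example, not LightGlue's actual code): when a loop's exit condition depends on runtime values, tracing one example input simply records the iterations taken for that input and freezes the count into the exported graph.

```python
def match_with_early_stop(confidences, threshold=0.95, max_iters=9):
    # Data-dependent loop: the early exit depends on the values seen
    # at runtime, so a tracer that observes one run records only the
    # iterations that run happened to execute.
    for i, c in enumerate(confidences[:max_iters]):
        if c >= threshold:       # early exit: depends on the data
            return i + 1         # number of iterations actually executed
    return max_iters
```

An easy input exits after two iterations, a hard one runs all nine; a traced graph would bake in whichever count the example input produced.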
Thanks for the heads-up. I'm guessing that the conversion was suboptimal. The h5 file is ~80 MB, but I'm at least getting the same correspondences as the PyTorch model. I did have to make some changes (disabled pruning, fixed the number of iterations / disabled early stopping, etc.).
Hello, I am working to convert LightGlue to TensorFlow (I ultimately want to get to TFLite) using Nobuco plus some help from ChatGPT to create some of the conversion functions ;)
I am still in the process of performing the conversion, but I have a question: I'm seeing a very imprecise conversion and am not sure why this would be the case. I'm trying to rule out any issues in my implementation.
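For what it's worth, one classic source of `log_softmax` imprecision, independent of any conversion framework, is computing `log(softmax(x))` as two separate steps instead of the numerically stable fused form. A minimal sketch in plain Python (no framework assumed) showing the difference:

```python
import math

def log_softmax_naive(xs):
    # log(softmax(x)) in two steps: exp() can overflow for large
    # inputs, and log() of a tiny ratio loses precision
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def log_softmax_stable(xs):
    # equivalent fused form x - logsumexp(x), shifted by max(x)
    # so exp() never sees a large positive argument
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]
```

The naive form raises `OverflowError` already for inputs around 1000, while the stable form handles them fine. Frameworks implement the stable variant, so a converter that rebuilds `log_softmax` from separate `log` and `softmax` ops can show exactly this kind of discrepancy.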
Here is the conversion function I am using for log_softmax:
Nobuco is indicating a significant discrepancy for log_softmax in the log:
Here is the code snippet calling log_softmax:
I also have a colab notebook with my progress so far:
https://colab.research.google.com/gist/coxep/65ac46a1edc6d262c302efa1813625df/demo.ipynb
Thank you for any assistance :)