Imprecise conversion for custom log_softmax converter #21

Closed · coxep opened this issue Jan 16, 2024 · 4 comments
coxep commented Jan 16, 2024

Hello, I am working on converting LightGlue to TensorFlow (I ultimately want to get to TFLite) using Nobuco, plus some help from ChatGPT to create some of the conversion functions ;)

I am still in the process of performing the conversion, but I have a question. I'm seeing a very imprecise conversion and am not sure why that would be the case; I'm trying to rule out any issues in my implementation.

Here is the conversion function I am using for log_softmax:

```python
@converter(torch.nn.functional.log_softmax, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def converter_log_softmax(input, dim, dtype=None):
    def func(input, dim, dtype=None):
        # Adjust 'dim' if it's negative to handle PyTorch's negative indexing
        if dim < 0:
            dim += len(input.shape)

        # Apply TensorFlow's log_softmax
        # If dtype is specified, cast the input tensor to this dtype first
        if dtype is not None:
            input = tf.cast(input, dtype)
        return tf.nn.log_softmax(input, axis=dim)

    return func
```

Nobuco is indicating a significant discrepancy for log_softmax in the log:

```
/usr/local/lib/python3.10/dist-packages/nobuco/converters/validation.py:55: RuntimeWarning: [<class 'lightglue.lightglue.TransformerLayer'>|LightGlue] conversion procedure might be incorrect: max. discrepancy for output #1 is 0.00010 (0.004%)
  warnings.warn(warn_string, category=RuntimeWarning)
/usr/local/lib/python3.10/dist-packages/nobuco/converters/validation.py:55: RuntimeWarning: [<class 'lightglue.lightglue.TransformerLayer'>|LightGlue] conversion procedure might be incorrect: max. discrepancy for output #0 is 0.00012 (0.005%)
  warnings.warn(warn_string, category=RuntimeWarning)
/usr/local/lib/python3.10/dist-packages/nobuco/converters/validation.py:55: RuntimeWarning: [<function log_softmax at 0x7bc0c800de10>|LightGlue->MatchAssignment] conversion procedure might be incorrect: max. discrepancy for output #0 is 38.75780 (103.477%)
  warnings.warn(warn_string, category=RuntimeWarning)
/usr/local/lib/python3.10/dist-packages/nobuco/converters/validation.py:55: RuntimeWarning: [<class 'lightglue.lightglue.MatchAssignment'>|LightGlue] conversion procedure might be incorrect: max. discrepancy for output #0 is 38.75780 (43.686%)
  warnings.warn(warn_string, category=RuntimeWarning)
```

Here is the code snippet calling log_softmax:

```python
# Original implementation
# def sigmoid_log_double_softmax(
#     sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
# ) -> torch.Tensor:
#     """create the log assignment matrix from logits and similarity"""
#     b, m, n = sim.shape
#     certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
#     scores0 = F.log_softmax(sim, 2)
#     scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
#     scores = sim.new_full((b, m + 1, n + 1), 0)
#     scores[:, :m, :n] = scores0 + scores1 + certainties
#     scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
#     scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
#     return scores

# My implementation with some modifications to eliminate slicing
def sigmoid_log_double_softmax(sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity"""
    b, m, n = sim.shape

    # Calculate certainties and scores0, scores1 as before
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    scores0 = F.log_softmax(sim, 2)
    scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)

    # Create scores tensor
    scores = sim.new_full((b, m + 1, n + 1), 0)

    # Combine scores0, scores1, and certainties into the main block of scores
    scores_main = scores0 + scores1 + certainties
    scores[:, :m, :n] = scores_main

    # Compute the scores for the last column and row
    last_col_scores = F.logsigmoid(-z0.squeeze(-1)).unsqueeze(2)
    last_row_scores = F.logsigmoid(-z1.squeeze(-1)).unsqueeze(1)

    # Update last column and row in scores
    scores[:, :-1, -1:] = last_col_scores
    scores[:, -1:, :-1] = last_row_scores

    return scores
```
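
For completeness, here is a fully slice-free sketch of the same assembly, built with `torch.cat` instead of slice assignment (a hypothetical rewrite, not tested against the model):

```python
import torch
import torch.nn.functional as F

def sigmoid_log_double_softmax_cat(sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
    """Same log assignment matrix, assembled with torch.cat instead of slice assignment."""
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    scores0 = F.log_softmax(sim, 2)
    scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)

    scores_main = scores0 + scores1 + certainties            # (b, m, n)
    last_col = F.logsigmoid(-z0.squeeze(-1)).unsqueeze(-1)   # (b, m, 1)
    last_row = F.logsigmoid(-z1.squeeze(-1)).unsqueeze(1)    # (b, 1, n)
    corner = sim.new_zeros(sim.shape[0], 1, 1)               # (b, 1, 1), stays 0 as in the original

    top = torch.cat([scores_main, last_col], dim=2)          # (b, m, n+1)
    bottom = torch.cat([last_row, corner], dim=2)            # (b, 1, n+1)
    return torch.cat([top, bottom], dim=1)                   # (b, m+1, n+1)
```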

I also have a colab notebook with my progress so far:
https://colab.research.google.com/gist/coxep/65ac46a1edc6d262c302efa1813625df/demo.ipynb

Thank you for any assistance :)

coxep closed this as completed Jan 16, 2024
coxep reopened this Jan 16, 2024
coxep changed the title from "Imprecise conversion - how to debug?" to "Imprecise conversion for custom log_softmax converter" Jan 16, 2024
AlexanderLutsenko (Owner) commented

Hi! First, here's a proper converter:

```python
@converter(F.log_softmax, torch.log_softmax, torch.Tensor.log_softmax, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def converter_log_softmax(input: Tensor, dim, *, dtype: Optional[_dtype]=None):
    num_dims = input.dim()

    def func(input, dim, *, dtype=None):
        if get_channel_order(input) == ChannelOrder.TENSORFLOW:
            dim = dim_pytorch2keras(dim, num_dims)
        return tf.nn.log_softmax(input, axis=dim)
    return func
```

Turns out, the converter is already there: I had registered it for `torch.log_softmax` and `torch.Tensor.log_softmax`, but forgot about `F.log_softmax`.

Your implementation does not work correctly because, due to `ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS`, it can receive the input tensor in channel-last layout, and in that case `dim` has to be adapted accordingly.
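
That axis mismatch also explains the size of the discrepancy in your log: with the wrong axis, the normalization runs over an entirely different dimension. A quick standalone illustration (made-up shapes, not LightGlue's actual tensors):

```python
import numpy as np
import tensorflow as tf

x_nchw = np.random.randn(1, 4, 8, 8).astype(np.float32)  # PyTorch-style NCHW
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))              # channel-last NHWC

# Correct: PyTorch dim=1 (channels) corresponds to the last axis in NHWC
right = tf.nn.log_softmax(x_nhwc, axis=-1)
# Wrong: reusing the PyTorch dim unchanged normalizes over a spatial axis
wrong = tf.nn.log_softmax(x_nhwc, axis=1)
print(float(tf.reduce_max(tf.abs(right - wrong))))       # large, like the 38.7 above
```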

The `dim < 0` check is redundant, as that's already done inside `dim_pytorch2keras`. The output type cast can also be omitted because, if needed, it's performed automatically inside the wrapper layer.
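
For intuition, the remapping that `dim_pytorch2keras` performs amounts to something like this simplified sketch (not the actual Nobuco implementation):

```python
def dim_pytorch2keras_sketch(dim: int, num_dims: int) -> int:
    # Normalize negative indices (this is why the explicit dim < 0 check is redundant)
    if dim < 0:
        dim += num_dims
    if dim == 0:
        return 0              # batch dimension stays first
    elif dim == 1:
        return num_dims - 1   # channel dimension moves to the last position
    else:
        return dim - 1        # remaining (spatial) dimensions shift left by one
```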

coxep (Author) commented Jan 16, 2024

Awesome, thanks! Got LightGlue built, and I think it is working... want any of these ChatGPT-sourced converters?
coxep@5afcdf0

coxep closed this as completed Jan 16, 2024
AlexanderLutsenko (Owner) commented

> Awesome, thanks! Got LightGlue built, and I think it is working... want any of these ChatGPT-sourced converters? coxep@5afcdf0

Yep, I'll take it.

A word of warning: there's a good chance the model was not converted properly. See, LightGlue is quite dynamic: the number of iterations depends on the difficulty of the task, and Nobuco cannot automatically capture control flow. It might still work as-is, but you'd be missing out on performance. Speaking of performance, TensorFlow lags behind PyTorch considerably when it comes to transformers. PyTorch has already integrated fast attention operations (notably `F.scaled_dot_product_attention`), and there's no such thing in TFLite.
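
To illustrate what "cannot capture control flow" means, here's a toy module (hypothetical, not LightGlue's actual code). A tracer records only the path taken for the example input, so the data-dependent early exit gets baked into the graph:

```python
import torch
import torch.nn as nn

class IterativeMatcher(nn.Module):
    """Toy stand-in for LightGlue's transformer loop."""
    def __init__(self, n_layers: int = 9, dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
            # Data-dependent early stop: tracing freezes whichever branch the
            # example input happened to take, so the exported graph runs a
            # fixed number of layers regardless of the real input's difficulty.
            if x.abs().mean() < 0.1:
                break
        return x
```

Fixing the number of iterations makes the traced graph valid for all inputs, at the cost of always running every layer.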

coxep (Author) commented Jan 17, 2024

Thanks for the heads-up. I'm guessing that the conversion was suboptimal. The h5 file is ~80 MB, but I'm at least getting the same correspondences as the PyTorch model.

I did have to make some changes (disabled pruning, fixed the number of iterations, disabled early stopping, etc.).
