Question regarding paper/code correspondence #36

@tokkiwa

Hello authors, thank you for sharing your insightful work!
I have a question about the parameter-sharing strategy described in the paper versus the one in the code. On page 7 of the paper, you write: "In experiments, we notice that for rational function, we share the denominator coefficient b_n among all groups and use different a_m for each group. It gets better performance."

From this I understand that the model has a single set of denominator weights shared by all groups and a separate set of numerator weights for each group (8 sets in total). However, the code appears to do the opposite: it has a single set of numerator weights and eight sets of denominator weights:

https://github.com/Adamdad/rational_kat_cu/blob/181bae8baf19075bef94b5f62dac320b3d4b27d3/kat_rational/kat_1dgroup_triton.py#L59-L61

            weight_numerator = torch.tensor(data[mode]["init_w_numerator"]).view(1, -1)
            weight_denominator = torch.tensor(data[mode]["init_w_denominator"])
            weight_denominator = torch.cat([weight_denominator] * self.num_groups).view(self.num_groups, -1)

https://github.com/Adamdad/rational_kat_cu/blob/181bae8baf19075bef94b5f62dac320b3d4b27d3/kat_rational/kat_1dgroup_triton.py#L85-L87

        # Repeat the weights for all groups
        weight_numerator = self.weight_numerator.repeat(self.num_groups, 1)
        return self.rational(input, weight_numerator, self.weight_denominator, self.num_groups)
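
To make the mismatch concrete, here is a minimal plain-Python sketch of the two layouts as I understand them. It only mirrors the shapes and is not the repository's implementation: the coefficient counts and initial values are placeholders I made up, and `code_layout`, `paper_layout` and `shape` are hypothetical helper names.

```python
NUM_GROUPS = 8  # group count mentioned in this issue

# Placeholder coefficients (assumption): the real values come from the
# repository's init data and are not reproduced here.
init_w_numerator = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
init_w_denominator = [0.0, 0.0, 0.0, 0.0]


def code_layout(num, den, groups):
    """Layout as I read the linked code: one numerator row, one denominator row per group."""
    weight_numerator = [list(num)]                           # like .view(1, -1)
    weight_denominator = [list(den) for _ in range(groups)]  # like cat(...).view(groups, -1)
    return weight_numerator, weight_denominator


def paper_layout(num, den, groups):
    """Layout as I read page 7 of the paper: a_m per group, b_n shared among all groups."""
    weight_numerator = [list(num) for _ in range(groups)]
    weight_denominator = [list(den)]
    return weight_numerator, weight_denominator


def shape(rows):
    """Return (rows, columns) of a rectangular list of lists."""
    return (len(rows), len(rows[0]))


code_num, code_den = code_layout(init_w_numerator, init_w_denominator, NUM_GROUPS)
paper_num, paper_den = paper_layout(init_w_numerator, init_w_denominator, NUM_GROUPS)

print("code :", "numerator", shape(code_num), "denominator", shape(code_den))
print("paper:", "numerator", shape(paper_num), "denominator", shape(paper_den))
```

With these placeholder sizes, the code's layout has one trainable numerator row and eight denominator rows, while the paper's description would give eight numerator rows and one denominator row.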

Which of the two did you intend?
Thank you in advance!
