
NTK calculation incorrect for networks with multiple outputs? #11

Closed

awe2 opened this issue Aug 24, 2021 · 6 comments

Comments

awe2 commented Aug 24, 2021

Howdy!

In: https://github.com/VITA-Group/TENAS/blob/main/lib/procedures/ntk.py

on line 45:

logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)

I am confused about your calculation of the NTK, and believe that you may be misusing the first argument of the torch.Tensor.backward() function.

E.g., when playing with the codebase using a very small 8-parameter network with 2 outputs:

import torch

# A tiny 8-parameter network with 2 outputs: two 2x2 linear layers, no bias.
class small(torch.nn.Module):
    def __init__(self):
        super(small, self).__init__()
        self.d1 = torch.nn.Linear(2, 2, bias=False)
        self.d2 = torch.nn.Linear(2, 2, bias=False)

    def forward(self, x):
        x = self.d1(x)
        x = self.d2(x)
        return x

where, for this explanation, I have modified line 45 to:

# a and b weight the two output dimensions; they are set to the 0/1 combinations shown below
gradient = torch.ones_like(logit[_idx:_idx+1])
gradient[0, 0] = a
gradient[0, 1] = b
logit[_idx:_idx+1].backward(gradient, retain_graph=True)

where by J I mean your 'grads' list for a single network:

e.g., lines 45 & 46, plus a print I added:

grads = [torch.stack(_grads, 0) for _grads in grads]                      # per-network [N x num_params] matrix of gradient rows
ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]  # NTK = J J^T
print('J: ', grads)

for

gradient[0,0] = 0
gradient[0,1] = 1

J: [tensor([[-0.6255, -0.5019, 0.1758, 0.1411, 0.0000, 0.0000, -0.0727, -0.4643],
[ 0.9368, -0.0947, -0.2633, 0.0266, 0.0000, 0.0000, 0.0955, -0.0812]])]

=======

for

gradient[0,0] = 1
gradient[0,1] = 0

J: [tensor([[ 0.1540, 0.1236, -0.6473, -0.5194, -0.0727, -0.4643, 0.0000, 0.0000],
[-0.2307, 0.0233, 0.9694, -0.0980, 0.0955, -0.0812, 0.0000, 0.0000]])]

=======

for

gradient[0,0] = 1
gradient[0,1] = 1

J: [tensor([[-0.4715, -0.3783, -0.4715, -0.3783, -0.0727, -0.4643, -0.0727, -0.4643],
[ 0.7061, -0.0714, 0.7062, -0.0714, 0.0955, -0.0812, 0.0955, -0.0812]])]

"""

So you can verify that your code is adding the two per-output components together to produce the last result.

The problem is that the Jacobian should have shape number_samples x (number_outputs x number_weights); see your own paper, page 2, where the Jacobian's components are indexed by the subscript i, the i-th output of the model.

If I am right, then the NTK of any network with multiple outputs is calculated incorrectly, and the time and memory footprint is systematically reduced by the fact that these per-output gradients are pooled together.
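
For concreteness, here is a minimal sketch of the per-output computation I have in mind (illustrative code of my own, not the repo's ntk.py; the full_ntk helper and the toy usage below are assumptions for this small example):

import torch

def full_ntk(model, x):
    # One Jacobian row per (sample, output-dim) pair; columns are the flattened weights.
    logits = model(x)                  # shape [N, C]
    N, C = logits.shape
    rows = []
    for n in range(N):
        for c in range(C):
            model.zero_grad()
            logits[n, c].backward(retain_graph=True)
            rows.append(torch.cat([(p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
                                   for p in model.parameters()]))
    J = torch.stack(rows, 0)           # shape [N*C, num_weights]
    return J @ J.t()                   # shape [N*C, N*C]

# e.g. with the toy network above:
net = small()
x = torch.randn(3, 2)
print(full_ntk(net, x).shape)          # torch.Size([6, 6]): (3 samples x 2 outputs) per side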

chenwydj (Collaborator) commented Aug 24, 2021

Hi @awe2,

Thanks for your question and interest in our work!

Yes, I understand your concern, and you are right: what I am doing is equivalent to summing over the output dimensions of the logit. In effect, I treat the output of the network as the sum of the logits instead of as a multi-dimensional output. This is much faster than back-propagating through each output dimension separately, and it also works well in practice. It is definitely possible to expand the NTK into [num_samples * num_out_dim] x [num_samples * num_out_dim].
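
To make the equivalence concrete, here is a minimal sketch (an illustration only, not the actual ntk.py code; summed_logit_ntk is just a name for this example). Back-propagating an all-ones vector is the same as back-propagating logit.sum(), so each sample contributes one gradient row and the kernel stays [num_samples] x [num_samples]:

import torch

def summed_logit_ntk(model, x):
    rows = []
    for n in range(x.shape[0]):
        model.zero_grad()
        logit = model(x[n:n+1])
        # all-ones vector-Jacobian product == gradient of logit.sum()
        logit.backward(torch.ones_like(logit))
        rows.append(torch.cat([(p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
                               for p in model.parameters()]))
    J = torch.stack(rows, 0)   # [num_samples, num_weights]
    return J @ J.t()           # [num_samples, num_samples]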

Hope that helps!

awe2 (Author) commented Aug 24, 2021

It is certainly faster, but I can calculate the NTK by hand for my simple network, and the results aren't the same. I'm not a math whiz; do you expect this transformation to leave the value of the condition number unchanged?

awe2 (Author) commented Aug 25, 2021

If the output of the network is the sum of the logits instead of the logits themselves, would the neural architectures you are searching over respond the same way? I.e., I see immediate application value in searching over architectures for image classification, where we are (naively) concerned with networks that output logits, but I don't see the value of a search over networks that predict the sum of the logits.

I'm relatively new to the field; is there something I am overlooking?

chenwydj (Collaborator) commented:

Thanks for your questions.

  1. I am not saying that "summing up the logits then backpropagating" and "backpropagating each output dimension" will give you the same condition number; I strongly suspect they will be different (see the sketch after this list for one way to compare them on a toy network). The core difference between the two approaches in the NAS setting (i.e., ranking architectures) is which one correlates better with accuracy: treating the sum of the logits as a single 1D output of the function, or treating the neural network as a function with multi-dimensional outputs. Our paper does not answer that question; we only show that the NTK of networks treated as having a 1D output correlates strongly with their classification accuracy.
  2. Different architectures will give different responses even if I treat them as 1D-output functions, as demonstrated in our Fig. 1 and our experiments. Again, I agree with you that calculating the NTK with multi-dimensional outputs is doable.
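
A minimal sketch of that comparison, assuming the toy small network from the first comment and the illustrative full_ntk / summed_logit_ntk helpers sketched earlier in this thread (all example code, not the repo's):

import torch

def cond_number(ntk):
    # Condition number = largest / smallest eigenvalue of the symmetric PSD kernel.
    eigvals = torch.linalg.eigvalsh(ntk)
    return (eigvals[-1] / eigvals[0]).item()

torch.manual_seed(0)
net = small()
x = torch.randn(4, 2)

print('summed-logit NTK cond:', cond_number(summed_logit_ntk(net, x)))
print('per-output NTK cond:  ', cond_number(full_ntk(net, x)))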

awe2 (Author) commented Aug 25, 2021

Ah, perfect, I think I understand. Thanks!

awe2 closed this as completed Aug 25, 2021
j0hngou commented Jan 13, 2023

According to this paper, the pseudo-NTK (sum of logits) converges to the true empirical NTK at initialization for any network with a wide enough final layer.
