Gradient computation in case of multi-head (task incremental setting) #7

Closed
bhattg opened this issue Apr 17, 2021 · 4 comments

bhattg commented Apr 17, 2021

Hi,

Thanks for the great baseline. I was looking at the A-GEM implementation (models/agem.py), and in particular at the training pipeline for the multi-headed case. I have a few doubts about the code.

During the forward pass in the multi-headed setting (Task-IL), when task 0 contains, say, classes [0, 1, 2, 3, 4], why is there no masking? From what I've understood of the code flow, the following calls happen:

(I am only mentioning the steps relevant to A-GEM, so I will omit the rest.)

for epoch in range(args.n_epochs):
    for i, data in enumerate(train_loader):
        if hasattr(dataset.train_loader.dataset, 'logits'):
            pass  # the logits branch is not relevant for A-GEM; ignoring it
        else:
            inputs, labels, not_aug_inputs = data
            inputs, labels = inputs.to(model.device), labels.to(
                model.device)
            not_aug_inputs = not_aug_inputs.to(model.device)
            loss = model.observe(inputs, labels, not_aug_inputs)

When model.observe(inputs, labels, not_aug_inputs) is called, it executes the following steps:

self.zero_grad()
p = self.net.forward(inputs)
loss = self.loss(p, labels)  # Here, why is there no masking on the output p for the classes corresponding to the other tasks?
loss.backward()
self.opt.step()
return loss.item()

As mentioned in the comment on line 3, why is there no masking on the output p for the classes corresponding to the other tasks? Without masking, the weights corresponding to the classes of the other tasks will also change, defeating the purpose of the multi-head setup.
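
For concreteness, such a masked loss might look like the following rough sketch (cpt, i.e. classes per task, and task_id are illustrative names of my own; they are not arguments of observe() in this repo):

import torch
import torch.nn.functional as F

def masked_task_il_loss(logits, labels, task_id, cpt):
    # Set the logits of all classes outside the current task to -inf,
    # so that their softmax probability (and hence their gradient) is zero.
    mask = torch.full_like(logits, float('-inf'))
    mask[:, task_id * cpt:(task_id + 1) * cpt] = 0.0
    return F.cross_entropy(logits + mask, labels)

With such a mask, the output weights of the other tasks' classes receive zero gradient.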

Similarly, when computing the reference gradient (in the observe function inside agem.py), why is there no masking corresponding to the tasks of the examples in the memory buffer?

if not self.buffer.is_empty():
    store_grad(self.parameters, self.grad_xy, self.grad_dims)
    buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size, transform=self.transform)
    self.net.zero_grad()
    buf_outputs = self.net.forward(buf_inputs)
    penalty = self.loss(buf_outputs, buf_labels)
    penalty.backward()

All of this masking is present in the original A-GEM source code (TF implementation), lines 392 to 426.
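
For reference, a per-example masked penalty might look something like this sketch (it assumes the buffer also stores each example's task id, here buf_tasks, which I believe the Buffer in this repo does not do by default):

import torch
import torch.nn.functional as F

def masked_replay_penalty(net, buf_inputs, buf_labels, buf_tasks, cpt):
    # Restrict each replayed example to the classes of its own task
    # before computing the reference loss.
    logits = net(buf_inputs)
    mask = torch.full_like(logits, float('-inf'))
    for t in buf_tasks.unique().tolist():
        rows = buf_tasks == t
        mask[rows, t * cpt:(t + 1) * cpt] = 0.0
    return F.cross_entropy(logits + mask, buf_labels)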

Please let me know your thoughts on my concerns. Thanks for the PyTorch implementations.


mbosc commented Apr 19, 2021

Hello!

In our work we primarily focus on the Class-IL setting: a single head is used for all predictions. This is how we train all models on non-Domain-IL settings, for which we then report both the Class-IL accuracy and the task-masked accuracy (i.e., the accuracy when the outputs of the wrong tasks are masked out). So yes, you're right: this is not precisely Task-IL, but rather a slightly more difficult evaluation, where the model is trained with a single head but given task ids at test time.

The reason for this is that we get both class and task accuracy with a single run; in doing so, we prioritize Class-IL, as it is generally acknowledged to be more realistic and difficult than Task-IL. If you want pure Task-IL results, you have to mask as you describe; note, however, that doing so makes the Class-IL numbers you get as output incorrect.
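
To make the task-masked evaluation concrete, it amounts to something like this sketch (the names are illustrative, not the ones used in the repo):

import torch

def task_masked_predictions(logits, task_ids, cpt):
    # Given each test example's task id, take the argmax only over
    # the classes belonging to that task.
    mask = torch.full_like(logits, float('-inf'))
    for t in task_ids.unique().tolist():
        rows = task_ids == t
        mask[rows, t * cpt:(t + 1) * cpt] = 0.0
    return (logits + mask).argmax(dim=1)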


bhattg commented Apr 19, 2021

Hi,

Thanks for the response! From what I've understood, if we order the settings by difficulty, the following should hold (in general):

Task-IL (pure multi-head setting) < Task-IL (task ids provided only at evaluation) < Class-IL (most realistic)

bhattg closed this as completed Apr 19, 2021

mbosc commented Apr 19, 2021

Yes, I think you're correct! Some good references for a comparison of existing settings are https://arxiv.org/abs/1805.09733 and https://arxiv.org/abs/1904.07734, although neither of them directly addresses our eval-time-only Task-IL.


bhattg commented Apr 19, 2021

Thanks for the info!!
