Hi, thank you for the great baseline repo.
I am trying to set up each dataset (each with a different number of classes) as a task and perform continual learning. However, I am a little lost regarding how the models handle multiple heads with potentially different numbers of output features. Do you have any suggestions on how this might be addressed?
Currently, I am using something like:
```python
import torch.nn as nn
import torchvision
from torchvision.models.resnet import ResNet, BasicBlock

# task_list (ordered task names) and classes_per_task (number of classes
# for each task) are defined globally elsewhere.

def kaiming_normal_init(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)

class MultiTaskModel(ResNet):
    def __init__(self):
        # BasicBlock with [3, 4, 6, 3] matches the resnet34 loaded below
        super(MultiTaskModel, self).__init__(BasicBlock, [3, 4, 6, 3])
        resnet = torchvision.models.resnet34(pretrained=True)
        self.in_feature = resnet.fc.in_features
        self.tasks = []
        self.fc = None
        # add all layers that are not fc or classifier to the model
        self.shared = nn.Sequential()
        for name, module in resnet.named_children():
            if name != 'fc' and name != 'classifiers':
                self.shared.add_module(name, module)
        # self.classifiers.append(resnet.fc)

    def set_task(self, task):
        print("Setting task to", task)
        self.tasks.append(task)
        print(f"tasks are {self.tasks}")
        print(f"task index is {task_list.index(task)}")
        # add a new fc layer for the new task
        # (note: this replaces, and so discards, the previous task's head)
        self.fc = nn.Linear(self.in_feature, classes_per_task[task_list.index(task)])
        self.fc.apply(kaiming_normal_init)
        print(f"fc is {self.fc}")

    def forward(self, x):
        x = self.shared(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
```
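For comparison, one common way to keep every task's head around (instead of re-creating `self.fc` on each `set_task` call) is to store one classifier per task, e.g. in an `nn.ModuleDict`. This is only a minimal sketch, not code from the repo; `shared`, `in_features`, and string task names are assumptions:

```python
import torch.nn as nn

class MultiHeadModel(nn.Module):
    """Sketch: shared backbone plus one linear head per task."""
    def __init__(self, shared, in_features):
        super().__init__()
        self.shared = shared            # assumed shared feature extractor, as above
        self.in_features = in_features
        self.heads = nn.ModuleDict()    # task name (str) -> classifier head
        self.current_task = None

    def set_task(self, task, num_classes=None):
        # create the head only once, so earlier heads survive task switches
        if task not in self.heads:
            self.heads[task] = nn.Linear(self.in_features, num_classes)
        self.current_task = task

    def forward(self, x):
        x = self.shared(x)
        x = x.view(x.size(0), -1)
        return self.heads[self.current_task](x)
```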
However, I am unable to replicate the results for some methods, such as EWC. If I use the default `lambda = 400` as in the repo, the loss becomes NaN.
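For reference, the quantity that `lambda` scales in EWC is the quadratic penalty from Kirkpatrick et al., `L = L_task + (lambda / 2) * sum_i F_i * (theta_i - theta_i*)^2`, so a large lambda multiplied by large Fisher estimates can overflow quickly. A minimal sketch of that term only (the `fisher` / `star_params` dicts are assumed helpers, not the repo's API):

```python
import torch

def ewc_penalty(model, fisher, star_params, lam=400.0):
    """Sketch: (lam / 2) * sum_i F_i * (theta_i - theta_i_star)^2."""
    # fisher and star_params map parameter names to tensors saved after
    # the previous task finished (assumed to exist; not the repo's API).
    penalty = 0.0
    for name, p in model.named_parameters():
        if name in fisher:
            penalty = penalty + (fisher[name] * (p - star_params[name]) ** 2).sum()
    return 0.5 * lam * penalty
```

If this term goes NaN, checking the magnitude of the Fisher estimates, or temporarily lowering `lam`, helps localize whether the penalty or the task loss is the source.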
Currently, I am maintaining the task sequence of the recogseq datasets, but I observe catastrophic forgetting for EWC and LwF. To be precise, I calculated the performance of each task after training on the last task.
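Concretely, that measurement can be written as a small loop over per-task test loaders. A sketch under two assumptions: `set_task` selects an existing head rather than re-initializing it, and `test_loaders` maps task names to `DataLoader`s:

```python
import torch

def eval_all_tasks(model, test_loaders):
    # test_loaders: dict mapping task name -> DataLoader (assumed)
    accs = {}
    model.eval()
    with torch.no_grad():
        for task, loader in test_loaders.items():
            model.set_task(task)  # must select, not re-create, the head
            correct, total = 0, 0
            for x, y in loader:
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)
            accs[task] = correct / total
    return accs
```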
This is my training loop: