Multihead classification #6

Open
appledora opened this issue Apr 16, 2024 · 1 comment
appledora commented Apr 16, 2024

Hi, thank you for the great baseline repo.
I am trying to set up each dataset (each with a different number of classes) as a task and perform continual learning. However, I am a little lost as to how the models handle multiple heads with potentially different output features. Do you have any suggestions on how this might be addressed?

Currently I am using something like:

```python
import torch
import torch.nn as nn
import torchvision
from torchvision.models.resnet import ResNet, BasicBlock

# task_list, classes_per_task, and kaiming_normal_init are defined elsewhere

class MultiTaskModel(ResNet):
    def __init__(self):
        super().__init__(BasicBlock, [2, 2, 2, 2])
        resnet = torchvision.models.resnet34(pretrained=True)
        self.in_feature = resnet.fc.in_features
        self.tasks = []
        self.fc = None

        # add all layers that are not fc or classifier to the model
        self.shared = nn.Sequential()
        for name, module in resnet.named_children():
            if name != 'fc' and name != 'classifiers':
                self.shared.add_module(name, module)

    def set_task(self, task):
        print("Setting task to", task)
        self.tasks.append(task)
        print(f"tasks are {self.tasks}")
        print(f"task index is {task_list.index(task)}")
        # add a new fc layer for the new task
        self.fc = nn.Linear(self.in_feature, classes_per_task[task_list.index(task)])
        self.fc.apply(kaiming_normal_init)
        print(f"fc is {self.fc}")

    def forward(self, x):
        x = self.shared(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
```

But I am unable to replicate the results for some methods, such as EWC: if I use the default lambda = 400 as in the repo, the loss becomes NaN.
Currently, I am keeping the task sequence of the recogseq datasets, but I observe catastrophic forgetting for both EWC and LwF. To be precise, I measured the performance of each task after training on the last task.
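For context, my understanding of the EWC penalty I am regularizing with is the sketch below (`ewc_penalty` and the `fisher` / `old_params` dicts are my own names, not the repo's API). Logging this term separately from the task loss is how I check whether lambda = 400 is making it blow up:

```python
import torch

def ewc_penalty(model, fisher, old_params, reg_lambda):
    """Quadratic EWC penalty: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2.

    fisher and old_params are dicts keyed by parameter name, captured after
    training on the previous task.
    """
    loss = 0.0
    for name, p in model.named_parameters():
        if name in fisher:
            loss = loss + (fisher[name] * (p - old_params[name]) ** 2).sum()
    return (reg_lambda / 2.0) * loss
```

If this term is already huge before scaling, a large lambda would plausibly push the total loss to NaN.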

This is my training loop:

```python
for task in dataset_names:
    train_loader, val_loader, _, _, _ = get_dataloaders(task, 0.8, batch_size)
    all_train_loaders[task] = train_loader
    all_val_loaders[task] = val_loader

for idx, task in tqdm.tqdm(enumerate(task_list)):
    current_train_loader = all_train_loaders[task]
    current_val_loader = all_val_loaders[task]
    model = MultiTaskModel().to(device)
    if idx > 0:
        print(f"Previous task: {task_list[idx-1]}")
        ckpt = torch.load(f"epoch_{task_list[idx-1]}.pth.tar")
        model.load_state_dict(ckpt['state_dict'])
        model = model.to(device)

    start_time = time.time()
    model, acc = fine_tune_EWC_acuumelation(current_train_loader, current_val_loader, model, reg_lambda=1, num_epochs=num_epochs, lr=0.008, batch_size=batch_size, weight_decay=0, current_task=task)
```
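One detail I am unsure about: each checkpoint stores a task-specific `fc` whose shape differs between tasks, so a strict `load_state_dict` can complain about mismatched heads. A sketch of loading only the shared weights (`load_shared_weights` is a hypothetical helper, not from the repo):

```python
import torch

def load_shared_weights(model, ckpt_path, device):
    """Load a previous-task checkpoint but skip the task-specific fc head,
    whose output dimension differs between tasks."""
    ckpt = torch.load(ckpt_path, map_location=device)
    # drop the head entries so only the shared trunk is restored
    state = {k: v for k, v in ckpt['state_dict'].items() if not k.startswith('fc')}
    # strict=False tolerates the now-missing fc keys
    model.load_state_dict(state, strict=False)
    return model.to(device)
```

After this, `set_task` can attach a fresh (or previously saved) head of the right size for the new task.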

Here's a sample output:

```bash
Model loaded for task svhn

Performance of previous task: flowers
fc set to Linear(in_features=512, out_features=103, bias=True)
Accuracy of the network on the 1311 test images: 3.75

Performance of previous task: scenes
fc set to Linear(in_features=512, out_features=67, bias=True)
Accuracy of the network on the 3123 test images: 5.013020833333333

Performance of previous task: birds
fc set to Linear(in_features=512, out_features=201, bias=True)
Accuracy of the network on the 2358 test images: 0.9982638888888888

Performance of previous task: cars
fc set to Linear(in_features=512, out_features=196, bias=True)
Accuracy of the network on the 1621 test images: 0.6875

Performance of previous task: aircraft
fc set to Linear(in_features=512, out_features=56, bias=True)
Accuracy of the network on the 2000 test images: 15.574596774193548

Performance of previous task: chars
fc set to Linear(in_features=512, out_features=63, bias=True)
Accuracy of the network on the 12599 test images: 43.247767857142854

Performance of previous task: svhn
fc set to Linear(in_features=512, out_features=10, bias=True)
Accuracy of the network on the 26032 test images: 96.13223522167488

Running accuracy for task svhn is [3.75, 5.013020833333333, 0.9982638888888888, 0.6875, 15.574596774193548, 43.247767857142854, 96.13223522167488]
Mean accuracy for task svhn is 23.629054939319072
```
appledora (Author) commented:

@Mattdl