Initial evaluation fails because of missing model adaptation #838

Closed
AlbinSou opened this issue Nov 26, 2021 · 1 comment · Fixed by #843
Labels: bug (Something isn't working), Training (Related to the Training module)

Comments

@AlbinSou (Collaborator) commented Nov 26, 2021

🐛 Describe the bug
The initial evaluation phase (when _periodic_eval is called with do_initial=True) fails due to missing model adaptation:

  cl_strategy.train(experience, eval_streams=[val_stream], num_workers=args.num_workers)
  File "MYDIR/avalanche/training/strategies/base_strategy.py", line 269, in train
    self._periodic_eval(eval_streams, do_final=False, do_initial=True)
  File "/MYDIR/avalanche/training/strategies/base_strategy.py", line 361, in _periodic_eval
    prev_mode = _prev_model_training_modes[name]
KeyError: 'linear.classifiers.1'

More details:

  • The model is an instance of MultiTaskModule (requires task label during forward pass)
  • The classifier is an instance of MultiHeadClassifier
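
Judging from the traceback, the KeyError is raised while per-module training modes are being restored from a snapshot that was taken before the new task head existed. Below is a minimal, plain-PyTorch sketch of that symptom (not Avalanche code; the module layout and names are made up for illustration):

import torch.nn as nn

# Toy model with a single task head, mimicking a MultiHeadClassifier-style layout.
model = nn.ModuleDict({"classifiers": nn.ModuleDict({"0": nn.Linear(64, 10)})})

# Snapshot the train/eval mode of every module by name (what
# _prev_model_training_modes appears to hold in the traceback above).
prev_modes = {name: m.training for name, m in model.named_modules()}

# During evaluation, adaptation to an unseen task label creates a new head.
model["classifiers"]["1"] = nn.Linear(64, 10)

# Restoring the snapshot then fails for the module that did not exist before.
for name, m in model.named_modules():
    m.train(prev_modes[name])  # KeyError: 'classifiers.1'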

🐜 To Reproduce

Here is a fully working example; just change the datadir variable to your local data directory.

#!/usr/bin/env python3
import argparse
import os
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam
from torchvision import transforms

from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.strategies import Naive
from avalanche.logging import TextLogger, InteractiveLogger, TensorboardLogger
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.models import as_multitask
from avalanche.models import SimpleCNN
from torchvision.transforms import ToTensor, RandomCrop, RandomHorizontalFlip
from torchvision.datasets import CIFAR100, CIFAR10
from avalanche.benchmarks import nc_benchmark, benchmark_with_validation_stream

def main():
    num_tasks = 10
    val_size = 0.05
    batch_size = 64
    nepochs = 1
    datadir = 'YOURDATADIR'
    
    
    train_transform = transforms.Compose([
        RandomCrop(32, padding=4),
        RandomHorizontalFlip(),
        ToTensor(), 
        transforms.Normalize((0.5071, 0.4866, 0.4409), 
        (0.2009, 0.1984, 0.2023))
    ])
    test_transform = transforms.Compose([
        ToTensor(),
        transforms.Normalize((0.5071, 0.4866, 0.4409), 
        (0.2009, 0.1984, 0.2023))
    ])
    
    cifar_train = CIFAR100(root=datadir, train=True, download=False)
    cifar_test = CIFAR100(root=datadir, train=False, download=False)
    
    # Get data
    scenario_nc = nc_benchmark(cifar_train, cifar_test, num_tasks, 
                               task_labels=True, train_transform=train_transform, 
                               eval_transform=test_transform, 
                               seed=0, class_ids_from_zero_in_each_exp=True)
    scenario = benchmark_with_validation_stream(scenario_nc, 
                                                validation_size=val_size, 
                                                shuffle=True)

    # Create model
    model = SimpleCNN(1)
    model = as_multitask(model, 'classifier')

    # Create optimizer
    optimizer = SGD(model.parameters(), lr=0.1, 
                    momentum=0.9, 
                    weight_decay=0.0002)

    interactive_logger = InteractiveLogger()
    loggers = [interactive_logger]

    evaluator = EvaluationPlugin(
        accuracy_metrics(epoch=True, stream=True),
        loss_metrics(epoch=True),
        loggers=loggers)

    plugins = []

    cl_strategy = Naive(model, optimizer,
                        criterion=CrossEntropyLoss(),
                        train_mb_size=batch_size, 
                        eval_mb_size=batch_size, 
                        device=torch.device('cuda'), 
                        train_epochs=nepochs, 
                        plugins=plugins, 
                        evaluator=evaluator, 
                        eval_every=1)
    
    # TRAINING LOOP
    print('Starting New classes experiment...')
    results = []
    for t, (experience, val_stream) in enumerate(zip(scenario.train_stream, 
                                                 scenario.valid_stream)):
        print("Start of experience: ", experience.current_experience)
        print("Current Classes: ", experience.classes_in_this_experience)

        cl_strategy.train(experience, eval_streams=[val_stream], num_workers=4)
        print('Training completed')

        print('Computing accuracy on the whole test set')
        results.append(cl_strategy.eval(scenario.test_stream[:t+1]))

if __name__ == '__main__':
    main()

🐝 Expected behavior
The model should be adapted to the new experience's dataset before being evaluated on it.

馃 Additional context
I believe this happens because, at this stage, self.experience still refers to the old experience, while the evaluation is performed on the new experience's validation stream. When I instead pass eval_streams=[], training proceeds normally and the test evaluation works as well.
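
For reference, a possible interim workaround (untested, and not the fix merged in #843) could be to adapt the model to the upcoming validation experience before calling train, so the new task's head already exists when the initial evaluation runs. The sketch below mirrors what the strategy does for the training experience and assumes the adaptation(dataset) API of DynamicModule as it was at the time; names and signatures may differ in other versions:

from avalanche.models.dynamic_modules import DynamicModule

for experience, val_stream in zip(scenario.train_stream, scenario.valid_stream):
    # Hypothetical workaround: pre-create the heads for the new task label by
    # adapting every dynamic module on the validation experience's dataset.
    for module in model.modules():
        if isinstance(module, DynamicModule):
            module.adaptation(val_stream.dataset)
    cl_strategy.train(experience, eval_streams=[val_stream], num_workers=4)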

@AlbinSou added the bug label on Nov 26, 2021
@AntonioCarta added the Training label on Nov 30, 2021
@AlbinSou (Collaborator, Author) commented Dec 1, 2021

@AntonioCarta I updated it with a fully working example

AntonioCarta added a commit to AntonioCarta/avalanche that referenced this issue Dec 2, 2021