
Generative Replay #931

Merged
merged 58 commits into ContinualAI:master on Apr 8, 2022

Conversation

@travela (Contributor) commented Mar 9, 2022

closes #927

This adds a GenerativeReplayPlugin and a GenerativeReplay strategy to the library, with which one can train models according to the vanilla Generative Replay algorithm. It can be applied either to a generator model alone or to a pair consisting of a classifier and a generator.

In particular, I have added two usage examples in which we train two models on the splitMNIST scenario (a minimal usage sketch follows after this list):

  • the SimpleMLP model in examples/generative_replay_splitMNIST.py and
  • a VAE model in examples/generative_replay_MNIST_generator.py
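
For orientation, here is a minimal usage sketch of the classifier path. It assumes the GenerativeReplay strategy name introduced in this PR, the existing SimpleMLP model, and the standard SplitMNIST benchmark; exact constructor arguments and import paths may differ from the merged code.

```python
# Hedged sketch: train a classifier on Split MNIST with generative replay.
# No generator_strategy is passed here; how the generator is then provided
# is discussed later in this thread.
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.supervised import GenerativeReplay

benchmark = SplitMNIST(n_experiences=5, seed=0)
model = SimpleMLP(num_classes=10)

strategy = GenerativeReplay(
    model,
    Adam(model.parameters(), lr=1e-3),
    criterion=CrossEntropyLoss(),
    train_mb_size=100,
    train_epochs=4,
    eval_mb_size=100,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

for experience in benchmark.train_stream:
    strategy.train(experience)
    strategy.eval(benchmark.test_stream)
```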

Commit messages from the PR timeline include:

  • …into init of solver strategy.
  • extend GR plugin to work without generator initialization
  • clean up GR template and make it more modular; rename VAETraining
  • VAETraining can now be trained alone, with or without GR, simply by adding the plugin.
@AntonioCarta (Collaborator)

Hey @travela thanks for your contribution. It looks solid in general, I just left a couple of minor comments.
Did you reproduce the results on Split MNIST? It would be ideal to have a script to add to the reproducible-cl repository, as we do for the other strategies.

@coveralls

Pull Request Test Coverage Report for Build 1958289232

  • 56 of 158 (35.44%) changed or added relevant lines in 5 files are covered.
  • 2 unchanged lines in 1 file lost coverage.
  • Overall coverage decreased (-0.5%) to 77.416%

Changes missing coverage (covered lines / changed or added lines / %):
  • avalanche/training/supervised/strategy_wrappers.py: 8 / 24 (33.33%)
  • avalanche/training/plugins/generative_replay.py: 12 / 46 (26.09%)
  • avalanche/models/generator.py: 34 / 86 (39.53%)

Files with coverage reduction (new missed lines / %):
  • avalanche/benchmarks/scenarios/generic_definitions.py: 2 (85.37%)

Totals:
  • Change from base Build 1958111169: -0.5%
  • Covered Lines: 11682
  • Relevant Lines: 15090

💛 - Coveralls

@travela (Contributor, Author) commented Mar 11, 2022

Hey @travela thanks for your contribution. It looks solid in general, I just left a couple of minor comments. Did you reproduce the results on Split MNIST? It would be ideal to have a script to add to the reproducible-cl repository, as we do for the other strategies.

Hi @AntonioCarta, great thanks! I don't seem to be able to see your comments when looking in the "Files changed" tab. Did you publish them or could it be that they are still pending?

And yes, I reproduced the results on the 10-class splitMNIST scenario in generative_replay_splitMNIST.py. I will look into the reproducible-cl repo and try to convert my example into a test script that can be added.

avalanche/models/generator.py: two review threads (outdated, resolved)
# Sample data from generator
memory = self.generator.generate(
    len(strategy.adapted_dataset) *
    (strategy.experience.current_experience)
).to(strategy.device)
Collaborator:

Can you explain why the generated data has the same length as the original data? Couldn't we generate the data on demand with a dataloader?

Contributor (Author):

Good question! There are two approaches (as mentioned here): either we generate all the data needed for the experience beforehand and then start training, or we save a copy of the "old" generator and generate the data on demand. I chose the former approach, also because this way I could stay in line with the existing Replay strategy, where we update strategy.data_loader once before each experience.

Do you think the other method would be more efficient? What would the implementation of an additional dataloader roughly look like? Since the memory variable is only used temporarily and is then passed to the ReplayDataLoader in the next step, I was hoping that we are already making use of the dataloader properties and are not clogging RAM too much (but I am not entirely sure about this).

Collaborator:

Both are reasonable solutions. If you want to use this one, you should:

  • batch the results of the generator. In general you want minibatch size < dataset size, because the entire dataset may be too large to generate in one step;
  • for the same reason, unless the dataset is small, it's better to keep the generated data on the CPU and move only the current minibatch to the GPU when needed.

If you don't do this you will get an out-of-memory error.
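
To illustrate the batching suggestion, here is a minimal sketch of a hypothetical helper (not part of this PR); it only assumes the generator.generate(n) interface used in the code excerpts of this thread.

```python
import torch

def generate_in_batches(generator, n_samples, batch_size=256):
    """Hypothetical helper: create n_samples replay examples in chunks.

    Only one chunk at a time lives on the generator's device; results are
    moved to the CPU so the full replay set never occupies GPU memory.
    """
    chunks = []
    remaining = n_samples
    while remaining > 0:
        current = min(batch_size, remaining)
        with torch.no_grad():
            samples = generator.generate(current)  # interface used in this PR
        chunks.append(samples.cpu())
        remaining -= current
    return torch.cat(chunks)
```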

Contributor (Author):

I see. How would the first point tie in with the ReplayDataLoader? I could generate the data one mini-batch at a time to make it more robust for larger datasets, but then I would still have to obtain an AvalancheDataset (i.e. concatenate them again?). This is because ReplayDataLoader expects an AvalancheDataset and not some kind of dataloader, if I understand it correctly?

Collaborator:

This is because ReplayDataLoader expects an AvalancheDataset and not some kind of Dataloader, if I understand it correctly?

Exactly. I would make another dataloader similar to ReplayDataLoader that accepts a dataset (the current data) and an iterator (the data generator).
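
As a rough illustration of that suggestion (no such class exists in Avalanche; the name and behavior below are purely hypothetical), such a loader could zip minibatches of the current experience with replay minibatches drawn lazily from a Python generator:

```python
import torch
from torch.utils.data import DataLoader

class GeneratorReplayDataLoader:
    """Hypothetical loader mixing current data with replay batches that are
    produced on demand by an iterator (e.g. sampling from the old generator)."""

    def __init__(self, dataset, replay_batches, batch_size=32, **loader_kwargs):
        self.loader = DataLoader(dataset, batch_size=batch_size, **loader_kwargs)
        self.replay_batches = replay_batches  # iterator yielding (x_replay, y_replay)

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        for (x, y), (x_r, y_r) in zip(self.loader, self.replay_batches):
            # each minibatch contains both current and generated replay samples
            yield torch.cat([x, x_r]), torch.cat([y, y_r])
```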

Contributor (Author):

I definitely like the idea of using a Python generator. But that would then correspond to the second of the two approaches I mentioned above, since when yielding the next batch we need access to the old version of the generator. I was trying to avoid having to store an additional model, and to stick to the existing ReplayDataLoader, which I understood to be designed for any "rehearsal/replay" strategy.

I am aiming to update the pull request by the beginning of next week.

Contributor (Author):

Hi! I just merged my changes (bba78b9). As mentioned in my last comment, I opted for the "on-demand" replay data generation, where we store the old generator and model and create replay data before each training iteration. This way I got around using the whole ReplayDataLoader machinery and instead simply extend the current mini-batch with newly generated replay data.

Accuracy on splitMNIST stayed the same.
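
For orientation, here is a condensed sketch of that approach for the classifier-plus-generator case, pieced together from the code excerpts quoted in this review. The class name is hypothetical, the body is illustrative rather than a copy of the merged plugin, and the SupervisedPlugin import path may differ between Avalanche versions.

```python
import copy

import torch

from avalanche.core import SupervisedPlugin

class OnDemandGenerativeReplay(SupervisedPlugin):
    """Illustrative plugin: before every training iteration, extend the
    current minibatch with data sampled from a frozen copy of the generator
    trained on the previous experiences."""

    def __init__(self, generator):
        super().__init__()
        self.generator = generator
        self.old_generator = None
        self.old_model = None

    def before_training_iteration(self, strategy, **kwargs):
        if self.old_generator is None:
            return  # first experience: nothing to replay yet
        n_replay = len(strategy.mbatch[0]) * strategy.experience.current_experience
        with torch.no_grad():
            replay_x = self.old_generator.generate(n_replay).to(strategy.device)
            # label the replayed inputs with the previous classifier's predictions
            replay_y = torch.argmax(self.old_model(replay_x), dim=-1)
        # task label 0 assumed (single-task, class-incremental benchmark)
        replay_t = torch.zeros(
            n_replay, dtype=strategy.mbatch[-1].dtype, device=strategy.device
        )
        strategy.mbatch[0] = torch.cat([strategy.mbatch[0], replay_x], dim=0)
        strategy.mbatch[1] = torch.cat([strategy.mbatch[1], replay_y], dim=0)
        strategy.mbatch[-1] = torch.cat([strategy.mbatch[-1], replay_t], dim=0)

    def after_training_exp(self, strategy, **kwargs):
        # freeze copies of the generator and the classifier for the next experience
        self.old_generator = copy.deepcopy(self.generator).eval()
        self.old_model = copy.deepcopy(strategy.model).eval()
```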

Member:

I saw your implementation and IMO it is clearer now.
Just one note:
when you create new data, you create as much replay data as the current minibatch size times the experience counter. This can result in huge minibatches (especially in benchmarks with many experiences) that can easily fill GPU memory when CUDA is used, especially with datasets whose images are bigger than MNIST's.
We should take this into consideration if we want to use this implementation as the starting point for all the generative replay strategies.

Contributor (Author):

Good point. I have actually implemented an alternative solution where we use a weighted loss to gradually reduce the importance of a new experience/class as the total number grows. As mentioned below, this solution yielded a lower accuracy (in my particular example). However, not using either of the two methods yielded much lower accuracy, suggesting that at least one of them is necessary.

When implementing the weighted loss I had to override the criteria of both the model and the generator, which maybe makes it less intuitive, but I guess proper documentation will take care of that.

I was thinking of adding the weighted loss as an option for the user. Maybe I could make it the default option and offer the increasing minibatch size as an alternative (for simple cases like MNIST)?
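
The weighted-loss variant itself isn't shown in this thread, so the following is only a guess at its general shape: a hypothetical criterion that weights the current experience's loss by 1/T (with T the number of experiences seen so far) and the replayed part by 1 - 1/T. The author's exact formulation may differ.

```python
import torch.nn as nn

class WeightedReplayCriterion(nn.Module):
    """Hypothetical weighted loss for generative replay.

    Assumes each minibatch is ordered as [current samples, replay samples]
    and that n_current and n_experiences_seen are updated by the training loop."""

    def __init__(self, base_criterion=None):
        super().__init__()
        self.base_criterion = base_criterion or nn.CrossEntropyLoss()
        self.n_experiences_seen = 1
        self.n_current = None  # number of current-experience samples in the batch

    def forward(self, predictions, targets):
        if self.n_experiences_seen <= 1 or self.n_current is None:
            return self.base_criterion(predictions, targets)
        w = 1.0 / self.n_experiences_seen
        current = self.base_criterion(
            predictions[: self.n_current], targets[: self.n_current]
        )
        replay = self.base_criterion(
            predictions[self.n_current:], targets[self.n_current:]
        )
        return w * current + (1.0 - w) * replay
```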

Collaborator:

Makes sense. Consider that the simple replay in Avalanche does not change the batch size, so it would be better to have consistent behavior here.

avalanche/models/generator.py: review thread (outdated, resolved)
* Extend current mbatch with replay data dynamically before each iteration.

* Update boolean after first experience.

* Fix mbatch[-1] extension

* Put replay_output to device.

* Resolve change requests: class names; VAELoss doc

* Documentation.
@ggraffieti (Member) left a comment

Very nice work!
I added some comments and small requests, but this seems like a great starting point!
I labeled this review as "request changes" only so we can follow the next steps in detail, but the amount of remaining work is fairly small.
As a general comment, many classes are named quite generically (Generator, VAE, VAEEncoder, ...) but lack this generality in the implementation (e.g. there is no CNN VAE). I'd suggest naming them differently (e.g. SimpleVAE or MlpVAE) so that we don't need to rename them in the future, which would break compatibility with old code.

avalanche/models/generator.py: five review threads (outdated, resolved)
avalanche/training/plugins/generative_replay.py: review thread (outdated, resolved)
or we use the strategy's model as the generator.
If the generator is None after initialization
we assume that strategy.model is the generator."""
if not self.generator_strategy:
Member:

This is clear but confusing. If I understand correctly, the plugin needs a generative strategy, which contains the generative model. If no strategy is passed, the "default" strategy is used, and the model defined for that strategy is used as the generative model. In this case, is the generative model the only model used (no classifier)?

Contributor (Author):

Yes, exactly: if no generative strategy is passed, we assume that the strategy our plugin was added to already has a generative model. For example, this allows us to easily train a generator with generative replay by simply adding the plugin; in that case there is no classifier. (Another scenario is that the classifier and generator are combined in a single model, as in this paper about generative replay with feedback connections.)

I agree that it can be confusing (a result of trying to combine all scenarios in a single plugin). I added another line to the docstring (4f5246f) referring to the example of training a generator alone, hoping this makes it clearer. Or do you have another suggestion?
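
To make the generator-only scenario concrete, here is a minimal sketch of training a VAE alone by adding the plugin, using the names introduced in this PR (VAETraining, GenerativeReplayPlugin, and the renamed MlpVAE); import paths and constructor arguments may differ from the merged code.

```python
# Hedged sketch: generative replay applied to a generator alone.
# No generator_strategy is passed to the plugin, so strategy.model
# (the VAE itself) is used as the generator.
import torch
from torch.optim import Adam

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import MlpVAE
from avalanche.training.plugins import GenerativeReplayPlugin
from avalanche.training.supervised import VAETraining

benchmark = SplitMNIST(n_experiences=5, seed=0)
model = MlpVAE((1, 28, 28), nhid=2)

strategy = VAETraining(
    model,
    Adam(model.parameters(), lr=1e-3),
    train_mb_size=100,
    train_epochs=4,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    plugins=[GenerativeReplayPlugin()],
)

for experience in benchmark.train_stream:
    strategy.train(experience)
```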

Member:

No, it's fine; I read the code sequentially, so I didn't have a complete picture until the end.
I still believe it is a bit confusing, but I don't have a better idea at the moment.
I really like the generality of the strategy, and overall the pros of this implementation greatly outweigh a bit of confusion in this part 😉

Comment on lines 126 to 128
replay = self.old_generator.generate(
    len(strategy.mbatch[0]) * (strategy.experience.current_experience)
).to(strategy.device)
Member:

Why does the amount of generated data per minibatch increase with the number of experiences? Is this a particular implementation used in some paper?

Contributor (Author):

I actually got the inspiration for this from the intro_to_generative_replay notebook in the Avalanche colab repo. In the literature I usually see authors employ a weighted loss instead to achieve a similar effect (e.g. as described here in equation 3). It could be due to the VAE and classifier I use, but in my case the increasing amount of generated data fared a few percentage points higher in accuracy than the weighted loss approach.

avalanche/training/plugins/generative_replay.py: review thread (outdated, resolved)
Comment on lines 303 to 317
    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion=CrossEntropyLoss(),
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: int = None,
        device=None,
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: EvaluationPlugin = default_evaluator,
        eval_every=-1,
        generator_strategy: BaseTemplate = None,
        **base_kwargs
    ):
Member:

I really like the idea of having two different strategies, one for the classifier (this one) and one for the generative model (passed as an argument). This allows for a great deal of generality.
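
To make that wiring concrete, here is a rough sketch based on the signature quoted above. All names come from this PR, but the constructor arguments are illustrative, and whether the GenerativeReplayPlugin must be added to a user-supplied generator strategy explicitly depends on the final implementation.

```python
# Hedged sketch: a separate generator strategy passed into the classifier strategy.
from torch.optim import Adam

from avalanche.models import MlpVAE, SimpleMLP
from avalanche.training.plugins import GenerativeReplayPlugin
from avalanche.training.supervised import GenerativeReplay, VAETraining

# Generator strategy: trains the VAE. The plugin is added here explicitly,
# although the merged code may wire it up automatically.
vae = MlpVAE((1, 28, 28), nhid=2)
generator_strategy = VAETraining(
    vae,
    Adam(vae.parameters(), lr=1e-3),
    train_mb_size=100,
    train_epochs=4,
    plugins=[GenerativeReplayPlugin()],
)

# Classifier strategy: receives the generator strategy as an argument.
classifier = SimpleMLP(num_classes=10)
strategy = GenerativeReplay(
    classifier,
    Adam(classifier.parameters(), lr=1e-3),
    train_mb_size=100,
    train_epochs=4,
    generator_strategy=generator_strategy,
)
```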

@travela (Contributor, Author) commented Mar 27, 2022

Ciao @ggraffieti! Thanks a lot for the detailed review and the helpful suggestions. I added all requested changes just now.

@AntonioCarta (Collaborator)

@ggraffieti are there any other changes that you want on this PR?

@AntonioCarta (Collaborator)

I checked the test manually and it seems to work. I'm merging this, and we can investigate the CI problems separately.

@travela thanks for your contribution and please remember to push the reproducibility script to the reproducibility-repo.

@AntonioCarta AntonioCarta merged commit 26b5cb2 into ContinualAI:master Apr 8, 2022
Linked issue: Implementation of a Generative Replay Strategy (#927)