Merge pull request #931 from travela/generative_replay
Generative Replay
AntonioCarta committed Apr 8, 2022
2 parents 57a6c15 + dfa1d69 commit 26b5cb2
Showing 9 changed files with 791 additions and 3 deletions.
1 change: 1 addition & 0 deletions avalanche/models/__init__.py
@@ -19,3 +19,4 @@
from .base_model import BaseModel
from .helper_method import as_multitask
from .pnn import PNN
from .generator import *
193 changes: 193 additions & 0 deletions avalanche/models/generator.py
@@ -0,0 +1,193 @@
################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 03-03-2022 #
# Author: Florian Mies #
# Website: https://github.com/travela #
################################################################################

"""
This module collects generative models
and their respective helper functions.
"""

from abc import abstractmethod
import torch
import torch.nn as nn
from torchvision import transforms
from avalanche.models.utils import MLP, Flatten
from avalanche.models.base_model import BaseModel


class Generator(BaseModel):
"""
A base abstract class for generators
"""

@abstractmethod
def generate(self, batch_size=None, condition=None):
"""
Lets the generator sample random samples.
Output is either a single sample or, if provided,
a batch of samples of size "batch_size"
:param batch_size: Number of samples to generate
:param condition: Possible condition for a condotional generator
(e.g. a class label)
"""


###########################
# VARIATIONAL AUTOENCODER #
###########################


class VAEMLPEncoder(nn.Module):
    '''
    Encoder part of the VAE: computes the latent representation of the input.

    :param shape: Shape of the input to the network: (channels, height, width)
    :param latent_dim: Dimension of the last hidden layer
    '''

def __init__(self, shape, latent_dim=128):
super(VAEMLPEncoder, self).__init__()
flattened_size = torch.Size(shape).numel()
self.encode = nn.Sequential(
Flatten(),
nn.Linear(in_features=flattened_size, out_features=400),
nn.BatchNorm1d(400),
nn.LeakyReLU(),
MLP([400, latent_dim])
)

def forward(self, x, y=None):
x = self.encode(x)
return x


class VAEMLPDecoder(nn.Module):
    '''
    Decoder part of the VAE. Reverses the Encoder.

    :param shape: Shape of the output: (channels, height, width).
    :param nhid: Dimension of the input (the VAE latent space).
    '''

def __init__(self, shape, nhid=16):
super(VAEMLPDecoder, self).__init__()
flattened_size = torch.Size(shape).numel()
self.shape = shape
self.decode = nn.Sequential(
MLP([nhid, 64, 128, 256, flattened_size], last_activation=False),
nn.Sigmoid())
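        # Note: despite the name "invTrans", this applies the standard MNIST
        # normalization ((x - 0.1307) / 0.3081) to the decoder output, so that
        # generated samples live in the same space as the normalized inputs
        # (an interpretation added here, not stated in the original code).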
self.invTrans = transforms.Compose([
transforms.Normalize((0.1307,), (0.3081,))
])

    def forward(self, z, y=None):
        if y is None:
            return self.invTrans(self.decode(z).view(-1, *self.shape))
        else:
            return self.invTrans(
                self.decode(torch.cat((z, y), dim=1)).view(-1, *self.shape)
            )


class MlpVAE(Generator, nn.Module):
    '''
    Variational autoencoder module:
    fully-connected and suited for any input shape and type.

    The encoder only computes the latent representation,
    and we then have two possible output heads:
    one for the usual output distribution and one for classification.
    The latter is an extension of the conventional VAE and incorporates
    a classifier into the network.
    More details can be found in: https://arxiv.org/abs/1809.10635
    '''

def __init__(self, shape, nhid=16, n_classes=10, device="cpu"):
"""
        :param shape: Shape of each input sample.
        :param nhid: Dimension of the latent space of the Encoder.
        :param n_classes: Number of classes;
            defines the dimension of the classification head.
"""
super(MlpVAE, self).__init__()
self.dim = nhid
self.device = device
self.encoder = VAEMLPEncoder(shape, latent_dim=128)
self.calc_mean = MLP([128, nhid], last_activation=False)
self.calc_logvar = MLP([128, nhid], last_activation=False)
self.classification = MLP([128, n_classes], last_activation=False)
self.decoder = VAEMLPDecoder(shape, nhid)

def get_features(self, x):
"""
        Get the features computed by the encoder for input x.
"""
return self.encoder(x)

def generate(self, batch_size=None):
"""
Generate random samples.
Output is either a single sample if batch_size=None,
else it is a batch of samples of size "batch_size".
"""
        if batch_size:
            z = torch.randn((batch_size, self.dim)).to(self.device)
        else:
            z = torch.randn((1, self.dim)).to(self.device)
res = self.decoder(z)
if not batch_size:
res = res.squeeze(0)
return res

def sampling(self, mean, logvar):
"""
VAE 'reparametrization trick'
"""
        # Reparametrization trick: z = mean + sigma * eps,
        # with eps ~ N(0, I) and sigma = exp(logvar / 2).
        eps = torch.randn(mean.shape).to(self.device)
        sigma = torch.exp(0.5 * logvar)
        return mean + eps * sigma

def forward(self, x):
"""
Forward.
"""
represntations = self.encoder(x)
mean, logvar = self.calc_mean(
represntations), self.calc_logvar(represntations)
z = self.sampling(mean, logvar)
return self.decoder(z), mean, logvar


# Loss functions
BCE_loss = nn.BCELoss(reduction="sum")
MSE_loss = nn.MSELoss(reduction="sum")
CE_loss = nn.CrossEntropyLoss()


def VAE_loss(X, forward_output):
    '''
    Loss function of a VAE using mean squared error for the reconstruction loss.
    This is the criterion for the VAE training loop.

    :param X: Original input batch.
    :param forward_output: Return value of a VAE.forward() call.
        Triplet consisting of (X_hat, mean, logvar), i.e.
        (reconstructed input after subsequent Encoder and Decoder,
        mean of the VAE output distribution,
        logvar of the VAE output distribution).
    '''
X_hat, mean, logvar = forward_output
reconstruction_loss = MSE_loss(X_hat, X)
KL_divergence = 0.5 * torch.sum(-1 - logvar + torch.exp(logvar) + mean**2)
return reconstruction_loss + KL_divergence


__all__ = ["MlpVAE", "VAE_loss"]
45 changes: 44 additions & 1 deletion avalanche/models/utils.py
@@ -1,6 +1,7 @@
from avalanche.benchmarks.utils import AvalancheDataset
from avalanche.models.dynamic_modules import MultiTaskModule, DynamicModule
import torch.nn as nn
from collections import OrderedDict


def avalanche_forward(model, x, task_labels):
@@ -59,4 +60,46 @@ def add_hooks(self, model):
)


__all__ = ["avalanche_forward", "FeatureExtractorBackbone"]
class Flatten(nn.Module):
'''
Simple nn.Module to flatten each tensor of a batch of tensors.
'''

def __init__(self):
super(Flatten, self).__init__()

def forward(self, x):
batch_size = x.shape[0]
return x.view(batch_size, -1)


class MLP(nn.Module):
'''
Simple nn.Module to create a multi-layer perceptron
with BatchNorm and ReLU activations.
    :param hidden_size: A list with the number of neurons in each layer,
        including the input and output dimensions.
    :type hidden_size: list[int]
:param last_activation: Indicates whether to add BatchNorm and ReLU
after the last layer.
:type last_activation: Boolean
'''

def __init__(self, hidden_size, last_activation=True):
super(MLP, self).__init__()
q = []
for i in range(len(hidden_size)-1):
in_dim = hidden_size[i]
out_dim = hidden_size[i+1]
q.append(("Linear_%d" % i, nn.Linear(in_dim, out_dim)))
            # Add BatchNorm + ReLU after every layer except, optionally, the last.
            if i < len(hidden_size) - 2 or (
                i == len(hidden_size) - 2 and last_activation
            ):
q.append(("BatchNorm_%d" % i, nn.BatchNorm1d(out_dim)))
q.append(("ReLU_%d" % i, nn.ReLU(inplace=True)))
self.mlp = nn.Sequential(OrderedDict(q))

def forward(self, x):
return self.mlp(x)


__all__ = ["avalanche_forward", "FeatureExtractorBackbone", "MLP", "Flatten"]
2 changes: 2 additions & 0 deletions avalanche/training/plugins/__init__.py
@@ -13,5 +13,7 @@
from .lfl import LFLPlugin
from .early_stopping import EarlyStoppingPlugin
from .lr_scheduling import LRSchedulerPlugin
from .generative_replay import GenerativeReplayPlugin, \
TrainGeneratorAfterExpPlugin
from .rwalk import RWalkPlugin
from .mas import MASPlugin
