In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch import autograd
# from elastic_weight_consolidation import ElasticWeightConsolidation

In [2]:
class ElasticWeightConsolidation:
    def __init__(self, model, crit, device, lr=0.001, weight=1000000):
        self.model = model.to(device)
        self.weight = weight
        self.crit = crit
        self.device = device
        self.optimizer = optim.Adam(self.model.parameters(), lr)

    def _update_mean_params(self):
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            self.model.register_buffer(_buff_param_name+'_estimated_mean', param.data.clone())

    def _update_fisher_params(self, current_ds, batch_size, num_batch):
        dl = DataLoader(current_ds, batch_size, shuffle=True)
        log_liklihoods = []
        for i, (input, target) in enumerate(dl):
            input = input.to(self.device)
            target = target.to(self.device)
            if i > num_batch:
                break
            output = F.log_softmax(self.model(input), dim=1)
            log_liklihoods.append(output[:, target])
        log_likelihood = torch.cat(log_liklihoods).mean()
        grad_log_liklihood = autograd.grad(log_likelihood, self.model.parameters())
        _buff_param_names = [param[0].replace('.', '__') for param in self.model.named_parameters()]
        for _buff_param_name, param in zip(_buff_param_names, grad_log_liklihood):
            self.model.register_buffer(_buff_param_name+'_estimated_fisher', param.data.clone() ** 2)

    def register_ewc_params(self, dataset, batch_size, num_batches):
        self._update_fisher_params(dataset, batch_size, num_batches)
        self._update_mean_params()

    def _compute_consolidation_loss(self, weight):
        try:
            losses = []
            for param_name, param in self.model.named_parameters():
                _buff_param_name = param_name.replace('.', '__')
                estimated_mean = getattr(self.model, '{}_estimated_mean'.format(_buff_param_name))
                estimated_fisher = getattr(self.model, '{}_estimated_fisher'.format(_buff_param_name))
                losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
            return (weight / 2) * sum(losses)
        except AttributeError:
            return 0

    def forward_backward_update(self, input, target):
        input = input.to(self.device)
        target = target.to(self.device)
        output = self.model(input)
        loss = self._compute_consolidation_loss(self.weight) + self.crit(output, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, filename):
        torch.save(self.model, filename)

    def load(self, filename):
        self.model = torch.load(filename)

In [3]:
def accu(model, dataloader, device):
    model = model.eval()    # .to(device)
    acc = 0
    for input, target in dataloader:
        o = model(input.to(device))
        acc += (o.argmax(dim=1).long() == target.to(device)).float().mean()
    return acc / len(dataloader)

In [4]:
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [5]:
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mnist_train = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

In [6]:
class LinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim, act='relu', use_bn=False):
        super(LinearLayer, self).__init__()
        self.use_bn = use_bn
        self.lin = nn.Linear(input_dim, output_dim)
        self.act = nn.ReLU() if act == 'relu' else act
        if use_bn:
            self.bn = nn.BatchNorm1d(output_dim)
    def forward(self, x):
        if self.use_bn:
            return self.bn(self.act(self.lin(x)))
        return self.act(self.lin(x))

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


In [7]:
class BaseModel(nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs):
        super(BaseModel, self).__init__()
        self.f1 = Flatten()
        self.lin1 = LinearLayer(num_inputs, num_hidden, use_bn=True)
        self.lin2 = LinearLayer(num_hidden, num_hidden, use_bn=True)
        self.lin3 = nn.Linear(num_hidden, num_outputs)
        
    def forward(self, x):
        return self.lin3(self.lin2(self.lin1(self.f1(x))))

In [9]:
crit = nn.CrossEntropyLoss()
# ewc = ElasticWeightConsolidation(BaseModel(28 * 28, 100, 10), crit=crit, lr=1e-4)
ewc = ElasticWeightConsolidation(BaseModel(28 * 28, 100, 10), crit=crit, lr=1e-4, device=device)

In [10]:
for _ in range(10):
    for input, target in tqdm(train_loader):
        ewc.forward_backward_update(input, target)

100%|██████████| 600/600 [00:08<00:00, 69.72it/s]
100%|██████████| 600/600 [00:08<00:00, 72.76it/s]
100%|██████████| 600/600 [00:08<00:00, 68.84it/s]
100%|██████████| 600/600 [00:08<00:00, 71.50it/s]
100%|██████████| 600/600 [00:08<00:00, 73.53it/s]
100%|██████████| 600/600 [00:08<00:00, 72.68it/s]
100%|██████████| 600/600 [00:08<00:00, 73.32it/s]
100%|██████████| 600/600 [00:08<00:00, 72.73it/s]
100%|██████████| 600/600 [00:08<00:00, 71.80it/s]
100%|██████████| 600/600 [00:07<00:00, 75.58it/s]


In [11]:
accu(ewc.model, test_loader, device)

tensor(0.9769, device='cuda:0')

In [12]:
ewc.register_ewc_params(mnist_train, 100, 300)

In [13]:
f_mnist_train = datasets.FashionMNIST("./data", train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = datasets.FashionMNIST("./data", train=False, download=True, transform=transforms.ToTensor())
f_train_loader = DataLoader(f_mnist_train, batch_size = 100, shuffle=True)
f_test_loader = DataLoader(f_mnist_test, batch_size = 100, shuffle=False)

In [14]:
for _ in range(20):
    for input, target in tqdm(f_train_loader):
        ewc.forward_backward_update(input, target)

100%|██████████| 600/600 [00:10<00:00, 55.84it/s]
100%|██████████| 600/600 [00:10<00:00, 57.26it/s]
100%|██████████| 600/600 [00:10<00:00, 59.30it/s]
100%|██████████| 600/600 [00:10<00:00, 58.66it/s]
100%|██████████| 600/600 [00:10<00:00, 58.84it/s]
100%|██████████| 600/600 [00:10<00:00, 58.59it/s]
100%|██████████| 600/600 [00:10<00:00, 58.60it/s]
100%|██████████| 600/600 [00:09<00:00, 60.18it/s]
100%|██████████| 600/600 [00:10<00:00, 59.57it/s]
100%|██████████| 600/600 [00:10<00:00, 58.39it/s]


In [15]:
ewc.register_ewc_params(f_mnist_train, 100, 300)

In [16]:
accu(ewc.model, f_test_loader, device)

tensor(0.7999, device='cuda:0')

In [17]:
accu(ewc.model, test_loader, device)

tensor(0.9736, device='cuda:0')