# DER (dynamic expandable representation)
참고 논문 : https://arxiv.org/pdf/2103.16788.pdf <br>
참고 코드 : https://github.com/Rhyssiyan/DER-ClassIL.pytorch or https://github.com/G-U-N/PyCIL



## Structure of DER

<img src="https://github.com/Young-Jo-Choi/paper_study/assets/59189961/909d8c72-2c82-4c15-a56a-9be381a65c36" alt="My Image" width=2000>

출처 : https://arxiv.org/pdf/2103.16788.pdf <br>


In [1]:
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms, datasets
from tqdm import tqdm
from PIL import Image
import math
import copy

## Training paramters

In [2]:
EPSILON = 1e-8

init_epoch = 201
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.1
init_weight_decay = 0.0005

epochs = 171
lrate = 0.1
milestones = [80, 120, 150]
lrate_decay = 0.1
batch_size = 128
weight_decay = 2e-4
num_workers = 8
T = 2

## Utility functions

In [3]:
# utils
def tensor2numpy(x):
    return x.cpu().data.numpy() if x.is_cuda else x.data.numpy()

def _map_new_class_index(y, order):
    return np.array(list(map(lambda x: order.index(x), y)))

def count_parameters(model, trainable=False):
    if trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())

def accuracy(y_pred, y_true, nb_old, increment=10):
    assert len(y_pred) == len(y_true), "Data length error."
    all_acc = {}
    all_acc["total"] = np.around((y_pred == y_true).sum() * 100 / len(y_true), decimals=2)

    # Grouped accuracy
    for class_id in range(0, np.max(y_true), increment):
        idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0]
        label = "{}-{}".format(str(class_id).rjust(2, "0"), str(class_id + increment - 1).rjust(2, "0"))
        all_acc[label] = np.around((y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2)

    # Old accuracy
    idxes = np.where(y_true < nb_old)[0]
    all_acc["old"] = (0 if len(idxes) == 0
                        else np.around((y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2))

    # New accuracy
    idxes = np.where(y_true >= nb_old)[0]
    all_acc["new"] = np.around((y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2)

    return all_acc

## Problem Setup
During the class incremental learning, the model observes a stream of class groups {$\mathcal{Y}_t$} and their corresponding training data {$\mathcal{D}_t$}. Particularly, the incoming dataset {$\mathcal{D}_t$} at step $t$ has a form of ($x_i^t,y_i^t$) where $x_i^t$ is the input image and $y_i^t \in \mathcal{Y}_t$ is the label within the label set $\mathcal{Y}_t$. The label space of the model is all seen categories $\tilde{\mathcal{Y}_t}=\cup_{i=1}^t\mathcal{Y}_i$ and the model is expected to predict well on classes in $\tilde{\mathcal{Y}_t}$.
<br>


Our method adopts the rehearsal strategy, which saves a part of data as the memory $\mathcal{M}_t$ for future training.

In [5]:
# cifar100 다운로드
class iCIFAR100():
    train_trsf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor()
    ]
    test_trsf = [transforms.ToTensor()]
    common_trsf = [transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761))]
    class_order = np.arange(100).tolist()

    def download_data(self):
        train_dataset = datasets.cifar.CIFAR100("./data", train=True, download=True)
        test_dataset = datasets.cifar.CIFAR100("./data", train=False, download=True)
        self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets)
        self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets)

In [6]:
class DummyDataset(Dataset):
    def __init__(self, images, labels, trsf):
        assert len(images) == len(labels), "Data size error!"
        self.images = images
        self.labels = labels
        self.trsf = trsf

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.trsf(Image.fromarray(self.images[idx]))
        label = self.labels[idx]
        return idx, image, label

In [7]:
# Incremental learning을 위해 데이터를 분할해 관리하는 class
class DataManager(object):
    def __init__(self, seed, init_cls, increment):
        self._setup_data(seed)
        assert init_cls <= len(self._class_order), "No enough classes."
        self._increments = [init_cls]
        while sum(self._increments) + increment < len(self._class_order):
            self._increments.append(increment)
        offset = len(self._class_order) - sum(self._increments)
        if offset > 0:
            self._increments.append(offset)

    @property
    def nb_tasks(self):
        return len(self._increments)

    def get_task_size(self, task):
        return self._increments[task]

    def get_dataset(self, indices, source, mode, appendent=None, ret_data=False, m_rate=None):
        if source == "train":
            x, y = self._train_data, self._train_targets
        elif source == "test":
            x, y = self._test_data, self._test_targets
        else:
            raise ValueError("Unknown data source {}.".format(source))

        if mode == "train":
            trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
        elif mode == "test":
            trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
        else:
            raise ValueError("Unknown mode {}.".format(mode))

        data, targets = [], []
        for idx in indices:
            class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx + 1)
            data.append(class_data)
            targets.append(class_targets)

        # rehearsal memory
        if appendent is not None and len(appendent) != 0:
            appendent_data, appendent_targets = appendent
            data.append(appendent_data)
            targets.append(appendent_targets)

        data, targets = np.concatenate(data), np.concatenate(targets)
        if ret_data:
            return data, targets, DummyDataset(data, targets, trsf)
        else:
            return DummyDataset(data, targets, trsf)

    def _setup_data(self, seed):
        idata = iCIFAR100()
        idata.download_data()
        self._train_data, self._train_targets = idata.train_data, idata.train_targets
        self._test_data, self._test_targets = idata.test_data, idata.test_targets
        self._train_trsf = idata.train_trsf
        self._test_trsf = idata.test_trsf
        self._common_trsf = idata.common_trsf

        # class ordering
        order = [i for i in range(len(np.unique(self._train_targets)))]
        np.random.seed(seed)
        self._class_order = np.random.permutation(len(order)).tolist()

        # Map indices
        self._train_targets = _map_new_class_index(self._train_targets, self._class_order)
        self._test_targets = _map_new_class_index(self._test_targets, self._class_order)

    def _select(self, x, y, low_range, high_range):
        idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
        if isinstance(x, np.ndarray):
            x_return = x[idxes]
        else:
            x_return = []
            for id in idxes:
                x_return.append(x[id])
        return x_return, y[idxes]

## Training process
For the learning of step $t$, we decouple the learning process into two sequential stages as follows.

1) Representation Learning Stage. To achieve better trade-off between stability and plasticity, we fix the previous feature representation and expand it with a new feature extractor trained on the incoming and memory data. We design an auxiliary loss on the novel extractor to promote it to learn diverse and discriminative features. To improve the model efficiency, we dynamically expand the representation according to the complexity of new classes via introducing a channel-level mask-based pruning method.

2) Classifier Learning Stage. After the learning of representation, we retrain the classifier with currently available data $\tilde{\mathcal{D}}_t = \mathcal{D}_t \cup \mathcal{M}_t$ at step $t$.

$\Phi_t$ : super-feature extractor <br>
$\mathcal{F}_t$ : feature extractor <br>
$\mathcal{H}_t$ : classifier

In [None]:
# Feature Extractor : Resnet32
class ResNetBasicblock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm2d(planes)
        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        basicblock = self.conv_a(x)
        basicblock = self.bn_a(basicblock)
        basicblock = F.relu(basicblock, inplace=True)
        basicblock = self.conv_b(basicblock)
        basicblock = self.bn_b(basicblock)
        if self.downsample is not None:
            residual = self.downsample(x)
        return F.relu(residual + basicblock, inplace=True)

class DownsampleA(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super().__init__()
        assert stride == 2
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        x = self.avg(x)
        return torch.cat((x, x.mul(0)), 1)

class CifarResNet(nn.Module):
    def __init__(self, block, depth, channels=3):
        super().__init__()
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6
        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)
        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion
        self.fc = nn.Linear(64*block.expansion, 10)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1_3x3(x)  # [bs, 16, 32, 32]
        x = F.relu(self.bn_1(x), inplace=True)
        x_1 = self.stage_1(x)    # [bs, 16, 32, 32]
        x_2 = self.stage_2(x_1)  # [bs, 32, 16, 16]
        x_3 = self.stage_3(x_2)  # [bs, 64, 8, 8]
        pooled = self.avgpool(x_3)  # [bs, 64, 1, 1]
        features = pooled.view(pooled.size(0), -1)  # [bs, 64]
        return features

    @property
    def last_conv(self):
        return self.stage_3[-1].conv_b

def get_convnet():
    return CifarResNet(ResNetBasicblock, 32)

At step $t$, our model is composed of a super-feature extractor $\Phi_t$ and the classifier $\mathcal{H}_t$. The super-feature extractor $\Phi_t$ is built by expanding the feature extractor $\Phi_{t-1}$ with a newly created feature extractor $\mathcal{F}_t$. Specifically, given an image $x \in \hat{\mathcal{D}}_t$, the feature $u$ extracted by $\Phi_t$ is obtained by concatenation as follows
$$u = \Phi_{t}(x) = [\Phi_{t-1}(x), \mathcal{F}_t(x)] $$

Here we reuse the previous $\mathcal{F}_1, . . . , \mathcal{F}_{t−1}$ and encourage the new extractor $\mathcal{F}_t$ to learn only novel aspect of new classes. The feature $u$ is then fed into the classifier Ht to make prediction as follows
$$p_{\mathcal{H}_t}(y|x) = \text{Softmax}(\mathcal{H}_t(u))$$

Then the prediction $\hat{y} = \text{argmax} p_{\mathcal{H}_t}(y|x), \hat{y} \in \hat{\mathcal{Y}}_t$. The classifier is designed to match its new input and output dimensions for step $t$. The parameters of $\mathcal{H}_t$ for the old features are inherited from $\mathcal{H}_{t-1}$ to retain old knowledge and its newly added parameters are randomly initialized.

To reduce catastrophic forgetting, we freeze the learned function $\Phi_{t-1}$ at step $t$, as it captures the intrinsic structure of previous data.

In [8]:
# DER network
class DERNet(nn.Module):
    def __init__(self, args, pretrained):
        super().__init__()
        self.convnets = nn.ModuleList()
        self.pretrained = pretrained
        self.out_dim = None
        self.fc = None
        self.aux_fc = None
        self.task_sizes = []
        self.args = args

    @property
    def feature_dim(self):
        if self.out_dim is None:
            return 0
        return self.out_dim * len(self.convnets)

    def extract_vector(self, x):
        features = [convnet(x) for convnet in self.convnets]
        features = torch.cat(features, 1)
        return features

    def forward(self, x):
        features = [convnet(x) for convnet in self.convnets]
        features = torch.cat(features, 1)
        out = self.fc(features)
        aux_logits = self.aux_fc(features[:, -self.out_dim :])
        out_dict = {"logits" :out, "aux_logits": aux_logits, "features": features}
        return out_dict

    def update_fc(self, nb_classes):
        # Feature extractor 구성
        if len(self.convnets) == 0:
            self.convnets.append(get_convnet())
        else:
            self.convnets.append(get_convnet())
            self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())

        if self.out_dim is None:
            self.out_dim = self.convnets[-1].out_dim
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            bias = copy.deepcopy(self.fc.bias.data)
            fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
            fc.bias.data[:nb_output] = bias
        del self.fc
        self.fc = fc

        new_task_size = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(new_task_size)
        self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)

    def generate_fc(self, in_dim, out_dim):
        fc = nn.Linear(in_dim, out_dim)
        return fc

    def weight_align(self, increment):
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew
        print("alignweights,gamma=", gamma)
        self.fc.weight.data[-increment:, :] *= gamma


## Training loss
We learn the model with cross-entropy loss on memory and incoming data as follows
$$ \mathcal{L}_{\mathcal{H}_t}= {-1 \over |\tilde{\mathcal{D}}_t|} \sum_{i=1}^{\hat{\mathcal{D}}_t} log(p_{\mathcal{H}_t}(y=y_i|x_i)) $$
where $x_i$ is image and $y_i$ is the corresponding label.


To enforce the network to learn the diverse and discriminative features for novel concepts, we further develop an auxiliary loss operating on the novel feature $\mathcal{F}_t(x)$. Specifically, we introduce an auxiliary classifier $\mathcal{H}_t^a$ , which predicts the probability $p_{\mathcal{H}_t^a}(y|x)=\text{Softmax}(\mathcal{H}_t^a(\mathcal{F}_t(x)))$. To encourage the network to learn features to discriminate between old and new concepts, the label space of $\mathcal{H}_t^a$ is $|\mathcal{Y}_t|+1$ including the new category set $\mathcal{Y}_t$ and the other class by treating all old concepts as one category. Thusly, we introduce the auxiliary loss and obtain the expandable representation loss as follows
$$\mathcal{L}_{ER} = \mathcal{L}_{\mathcal{H}_t} + \lambda_a \mathcal{L}_{\mathcal{H}_t^a}$$
where $\lambda_a$ is the hyper-parameter to control the effect of the auxiliary classifier. It is worth noting that $\lambda_a =0$ for first step $t = 1$.


In [9]:
# Incremental learning을 위한 class
class DERBase():
    def __init__(self, args):
        self.args = args
        self._cur_task = -1
        self._known_classes = 0
        self._total_classes = 0
        self._network = DERNet(args, False)
        self._data_memory, self._targets_memory = np.array([]), np.array([])
        self.topk = 5

        self._memory_size = args["memory_size"]
        self._memory_per_class = args.get("memory_per_class", None)
        self._fixed_memory = args.get("fixed_memory", False)
        self._device = args["device"][0]

    def after_task(self):
        self._known_classes = self._total_classes
        print("Exemplar size: {}".format(self.exemplar_size))

    @property
    def feature_dim(self):
        if isinstance(self._network, nn.DataParallel):
            return self._network.module.feature_dim
        else:
            return self._network.feature_dim

    @property
    def exemplar_size(self):
        assert len(self._data_memory) == len(self._targets_memory), "Exemplar size error."
        return len(self._targets_memory)
    def _get_memory(self):
        if len(self._data_memory) == 0:
            return None
        else:
            return (self._data_memory, self._targets_memory)

    @property
    def samples_per_class(self):
        if self._fixed_memory:
            return self._memory_per_class
        else:
            assert self._total_classes != 0, "Total classes is 0"
            return self._memory_size // self._total_classes

    def _compute_accuracy(self, model, loader):
        model.eval()
        correct, total = 0, 0
        for i, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = model(inputs)["logits"]
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)
        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)

    def _reduce_exemplar(self, data_manager, m):
        print("Reducing exemplars...({} per classes)".format(m))
        dummy_data, dummy_targets = copy.deepcopy(self._data_memory), copy.deepcopy(self._targets_memory)
        self._class_means = np.zeros((self._total_classes, self.feature_dim))
        self._data_memory, self._targets_memory = np.array([]), np.array([])

        for class_idx in range(self._known_classes):
            mask = np.where(dummy_targets == class_idx)[0]
            dd, dt = dummy_data[mask][:m], dummy_targets[mask][:m]
            self._data_memory = (
                np.concatenate((self._data_memory, dd))
                if len(self._data_memory) != 0
                else dd
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, dt))
                if len(self._targets_memory) != 0
                else dt
            )

            # Exemplar mean
            idx_dataset = data_manager.get_dataset([], source="train", mode="test", appendent=(dd, dt))
            idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)

            self._class_means[class_idx, :] = mean

    def _construct_exemplar(self, data_manager, m):
        print("Constructing exemplars...({} per classes)".format(m))
        for class_idx in range(self._known_classes, self._total_classes):
            data, targets, idx_dataset = data_manager.get_dataset(
                np.arange(class_idx, class_idx + 1),
                source="train",
                mode="test",
                ret_data=True,
            )
            idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            class_mean = np.mean(vectors, axis=0)

            # Select
            selected_exemplars = []
            exemplar_vectors = []   # [n, feature_dim]
            for k in range(1, m + 1):
                S = np.sum(exemplar_vectors, axis=0)            # [feature_dim] sum of selected exemplars vectors
                mu_p = (vectors + S) / k                        # [n, feature_dim] sum to all vectors
                i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
                selected_exemplars.append(np.array(data[i]))    # New object to avoid passing by inference
                exemplar_vectors.append(np.array(vectors[i]))   # New object to avoid passing by inference

                vectors = np.delete(vectors, i, axis=0)         # Remove it to avoid duplicative selection
                data = np.delete(data, i, axis=0)               # Remove it to avoid duplicative selection

            selected_exemplars = np.array(selected_exemplars)
            exemplar_targets = np.full(m, class_idx)
            self._data_memory = (
                np.concatenate((self._data_memory, selected_exemplars))
                if len(self._data_memory) != 0
                else selected_exemplars
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, exemplar_targets))
                if len(self._targets_memory) != 0
                else exemplar_targets
            )

            # Exemplar mean
            idx_dataset = data_manager.get_dataset([],source="train",mode="test",appendent=(selected_exemplars, exemplar_targets))
            idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)
            self._class_means[class_idx, :] = mean

    def build_rehearsal_memory(self, data_manager, per_class):
        self._reduce_exemplar(data_manager, per_class)
        self._construct_exemplar(data_manager, per_class)

    def eval_task(self):
        y_pred, y_true = self._eval_cnn(self.test_loader)
        cnn_accy = self._evaluate(y_pred, y_true)
        return cnn_accy

    def _eval_cnn(self, loader):
        self._network.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = self._network(inputs)["logits"]
            predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]  # [bs, topk]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())
        return np.concatenate(y_pred), np.concatenate(y_true)  # [N, topk]

    def _evaluate(self, y_pred, y_true):
        ret = {}
        grouped = accuracy(y_pred.T[0], y_true, self._known_classes)
        ret["grouped"] = grouped
        ret["top1"] = grouped["total"]
        ret["top{}".format(self.topk)] = np.around((y_pred.T == np.tile(y_true, (self.topk, 1))).sum() * 100 / len(y_true),decimals=2)
        return ret

    def _extract_vectors(self, loader):
        self._network.eval()
        vectors, targets = [], []
        for _, _inputs, _targets in loader:
            _targets = _targets.numpy()
            if isinstance(self._network, nn.DataParallel):
                _vectors = tensor2numpy(self._network.module.extract_vector(_inputs.to(self._device)))
            else:
                _vectors = tensor2numpy(self._network.extract_vector(_inputs.to(self._device)))
            vectors.append(_vectors)
            targets.append(_targets)
        return np.concatenate(vectors), np.concatenate(targets)

In [10]:
# Incremental learning을 위한 class
# DERBase를 상속
class DER(DERBase):
    def __init__(self, args):
        super().__init__(args)

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)
        if self._cur_task > 0:
            for i in range(self._cur_task):
                # 이전 task에서 학습한 convnet은 freeze
                for p in self._network.convnets[i].parameters():
                    p.requires_grad = False
        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes),
                                                 source="train", mode="train", appendent=self._get_memory())
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)

    # train mode로 전환
    def train(self):
        self._network.train()
        self._network.convnets[-1].train()
        if self._cur_task >= 1:
            for i in range(self._cur_task):
                self._network.convnets[i].eval()

    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, self._network.parameters()), momentum=0.9, lr=init_lr, weight_decay=init_weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay)
        if self._cur_task == 0:
            self.init_train(self._network, train_loader, test_loader, optimizer, scheduler)
        else:
            self.update_representation(self._network, train_loader, test_loader, optimizer, scheduler)
            self._network.weight_align(self._total_classes - self._known_classes)

    # init task의 학습 : incremental learning이 아닌 일반적인 모델 학습과 동일
    def init_train(self, network, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(init_epoch))
        for _, epoch in enumerate(prog_bar):
            self.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]
                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(self._cur_task, epoch + 1, init_epoch,
                                                                                                        losses / len(train_loader), train_acc, test_acc)
            prog_bar.set_description(info)

    # init task 이후의 학습 : auxiliary loss를 추가
    def update_representation(self, network, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(epochs))
        for _, epoch in enumerate(prog_bar):
            self.train()
            losses = 0.0
            losses_clf = 0.0
            losses_aux = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                outputs = self._network(inputs)
                logits, aux_logits = outputs["logits"], outputs["aux_logits"]
                loss_clf = F.cross_entropy(logits, targets)
                aux_targets = targets.clone()
                aux_targets = torch.where(
                    aux_targets - self._known_classes + 1 > 0,
                    aux_targets - self._known_classes + 1, 0)
                loss_aux = F.cross_entropy(aux_logits, aux_targets)
                loss = loss_clf + loss_aux
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_aux += loss_aux.item()
                losses_clf += loss_clf.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task, epoch + 1, epochs,
                    losses / len(train_loader), losses_clf / len(train_loader),losses_aux / len(train_loader),
                    train_acc, test_acc)
            prog_bar.set_description(info)

# Implementation

In [11]:
# init_cls는 처음에 학습할 classes의 개수
# increment는 두번째 task 이후 매 task마다 추가적으로 학습할 classes의 개수
args = {"memory_size": 2000, "memory_per_class": 20, "fixed_memory": False,
        "init_cls": 25,"increment": 25, "device": ["cuda:0"], "seed": [1993]}

In [12]:
# 실행
model = DER(args)
data_manager = DataManager(args["seed"], args["init_cls"], args["increment"])

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:07<00:00, 23746243.88it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [13]:
cnn_curve = {"top1": [], "top5": []}
for task in range(data_manager.nb_tasks):
    print("All params: {}".format(count_parameters(model._network)))
    print("Trainable params: {}".format(count_parameters(model._network, True)))
    model.incremental_train(data_manager)
    cnn_accy = model.eval_task()
    model.after_task()
    print("CNN: {}".format(cnn_accy["grouped"]))

    cnn_curve["top1"].append(cnn_accy["top1"])
    cnn_curve["top5"].append(cnn_accy["top5"])

    print("top1 curve: {}".format(cnn_curve["top1"]))
    print("top5 curve: {}\n".format(cnn_curve["top5"]))
    print('Average Accuracy :', sum(cnn_curve["top1"])/len(cnn_curve["top1"]))
    print("Average Accuracy : {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"])))

All params: 0
Trainable params: 0


Task 0, Epoch 201/201 => Loss 0.015, Train_accy 99.86, Test_accy 85.04: 100%|██████████| 201/201 [31:11<00:00,  9.31s/it]

Reducing exemplars...(80 per classes)
Constructing exemplars...(80 per classes)





Exemplar size: 2000
CNN: {'total': 85.04, '00-09': 86.5, '10-19': 83.6, '20-29': 85.0, 'old': 0, 'new': 85.04}
top1 curve: [85.04]
top5 curve: [97.56]

Average Accuracy : 85.04
Average Accuracy : 85.04
All params: 467469
Trainable params: 467469


Task 1, Epoch 171/171 => Loss 0.041, Loss_clf 0.022, Loss_aux 0.020, Train_accy 99.79, Test_accy 71.62: 100%|██████████| 171/171 [32:47<00:00, 11.51s/it]


alignweights,gamma= tensor(0.7513, device='cuda:0')
Reducing exemplars...(40 per classes)
Constructing exemplars...(40 per classes)
Exemplar size: 2000
CNN: {'total': 74.56, '00-09': 75.8, '10-19': 70.1, '20-29': 75.1, '30-39': 75.5, '40-49': 76.3, 'old': 74.08, 'new': 75.04}
top1 curve: [85.04, 74.56]
top5 curve: [97.56, 93.92]

Average Accuracy : 79.80000000000001
Average Accuracy : 79.80000000000001
All params: 936448
Trainable params: 472294


Task 2, Epoch 171/171 => Loss 0.046, Loss_clf 0.025, Loss_aux 0.022, Train_accy 99.86, Test_accy 61.53: 100%|██████████| 171/171 [36:39<00:00, 12.86s/it]

alignweights,gamma= tensor(0.6516, device='cuda:0')
Reducing exemplars...(26 per classes)





Constructing exemplars...(26 per classes)
Exemplar size: 1950
CNN: {'total': 68.48, '00-09': 69.0, '10-19': 64.0, '20-29': 64.5, '30-39': 68.4, '40-49': 72.1, '50-59': 76.4, '60-69': 63.7, '70-79': 71.0, 'old': 67.6, 'new': 70.24}
top1 curve: [85.04, 74.56, 68.48]
top5 curve: [97.56, 93.92, 90.73]

Average Accuracy : 76.02666666666669
Average Accuracy : 76.02666666666669
All params: 1408627
Trainable params: 480319


Task 3, Epoch 171/171 => Loss 0.045, Loss_clf 0.026, Loss_aux 0.018, Train_accy 99.85, Test_accy 54.79: 100%|██████████| 171/171 [42:39<00:00, 14.97s/it]

alignweights,gamma= tensor(0.6183, device='cuda:0')
Reducing exemplars...(20 per classes)





Constructing exemplars...(20 per classes)
Exemplar size: 2000
CNN: {'total': 64.86, '00-09': 63.6, '10-19': 58.6, '20-29': 53.7, '30-39': 56.5, '40-49': 63.2, '50-59': 77.5, '60-69': 62.7, '70-79': 68.6, '80-89': 71.6, '90-99': 72.6, 'old': 62.84, 'new': 70.92}
top1 curve: [85.04, 74.56, 68.48, 64.86]
top5 curve: [97.56, 93.92, 90.73, 88.79]

Average Accuracy : 73.23500000000001
Average Accuracy : 73.23500000000001
