# importation

In [1]:
import os
import numpy as np
import torch
import argparse
import copy
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.utils import save_image
from torchvision import datasets, transforms

# def

## get_dataset

In [2]:
def get_dataset(dataset, data_path):
    if dataset == 'MNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.1307]
        std = [0.3081]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'FashionMNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.2861]
        std = [0.3530]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'SVHN':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        mean = [0.4377, 0.4438, 0.4728]
        std = [0.1980, 0.2010, 0.1970]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.SVHN(data_path, split='train', download=True, transform=transform)  # no augmentation
        dst_test = datasets.SVHN(data_path, split='test', download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'CIFAR10':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'CIFAR100':
        channel = 3
        im_size = (32, 32)
        num_classes = 100
        mean = [0.5071, 0.4866, 0.4409]
        std = [0.2673, 0.2564, 0.2762]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.CIFAR100(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'TinyImageNet':
        channel = 3
        im_size = (64, 64)
        num_classes = 200
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        data = torch.load(os.path.join(data_path, 'tinyimagenet.pt'), map_location='cpu')

        class_names = data['classes']

        images_train = data['images_train']
        labels_train = data['labels_train']
        images_train = images_train.detach().float() / 255.0
        labels_train = labels_train.detach()
        for c in range(channel):
            images_train[:,c] = (images_train[:,c] - mean[c])/std[c]
        dst_train = TensorDataset(images_train, labels_train)  # no augmentation

        images_val = data['images_val']
        labels_val = data['labels_val']
        images_val = images_val.detach().float() / 255.0
        labels_val = labels_val.detach()

        for c in range(channel):
            images_val[:, c] = (images_val[:, c] - mean[c]) / std[c]

        dst_test = TensorDataset(images_val, labels_val)  # no augmentation

    else:
        exit('unknown dataset: %s'%dataset)


    testloader = torch.utils.data.DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=0)
    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader


## get_loops

In [3]:
def get_loops(ipc):
    # Get the two hyper-parameters of outer-loop and inner-loop.
    # The following values are empirically good.
    if ipc == 1:
        outer_loop, inner_loop = 1, 1
    elif ipc == 10:
        outer_loop, inner_loop = 10, 50
    elif ipc == 20:
        outer_loop, inner_loop = 20, 25
    elif ipc == 30:
        outer_loop, inner_loop = 30, 20
    elif ipc == 40:
        outer_loop, inner_loop = 40, 15
    elif ipc == 50:
        outer_loop, inner_loop = 50, 10
    else:
        outer_loop, inner_loop = 0, 0
        exit('loop hyper-parameters are not defined for %d ipc'%ipc)
    return outer_loop, inner_loop


## get_daparam

In [4]:
def get_daparam(dataset, model, model_eval, ipc):
    # We find that augmentation doesn't always benefit the performance.
    # So we do augmentation for some of the settings.

    dc_aug_param = dict()
    dc_aug_param['crop'] = 4
    dc_aug_param['scale'] = 0.2
    dc_aug_param['rotate'] = 45
    dc_aug_param['noise'] = 0.001
    dc_aug_param['strategy'] = 'none'

    if dataset == 'MNIST':
        dc_aug_param['strategy'] = 'crop_scale_rotate'

    if model_eval in ['ConvNetBN']: # Data augmentation makes model training with Batch Norm layer easier.
        dc_aug_param['strategy'] = 'crop_noise'


## distance_wb

In [5]:
def distance_wb(gwr, gws):
    shape = gwr.shape
    if len(shape) == 4: # conv, out*in*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2] * shape[3])
        gws = gws.reshape(shape[0], shape[1] * shape[2] * shape[3])
    elif len(shape) == 3:  # layernorm, C*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2])
        gws = gws.reshape(shape[0], shape[1] * shape[2])
    elif len(shape) == 2: # linear, out*in
        tmp = 'do nothing'
    elif len(shape) == 1: # batchnorm/instancenorm, C; groupnorm x, bias
        gwr = gwr.reshape(1, shape[0])
        gws = gws.reshape(1, shape[0])
        return torch.tensor(0, dtype=torch.float, device=gwr.device)

    dis_weight = torch.sum(1 - torch.sum(gwr * gws, dim=-1) / (torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + 0.000001))
    dis = dis_weight
    return dis



def match_loss(gw_syn, gw_real, args):
    dis = torch.tensor(0.0).to(args.device)

    if args.dis_metric == 'ours':
        for ig in range(len(gw_real)):
            gwr = gw_real[ig]
            gws = gw_syn[ig]
            dis += distance_wb(gwr, gws)

    elif args.dis_metric == 'mse':
        gw_real_vec = []
        gw_syn_vec = []
        for ig in range(len(gw_real)):
            gw_real_vec.append(gw_real[ig].reshape((-1)))
            gw_syn_vec.append(gw_syn[ig].reshape((-1)))
        gw_real_vec = torch.cat(gw_real_vec, dim=0)
        gw_syn_vec = torch.cat(gw_syn_vec, dim=0)
        dis = torch.sum((gw_syn_vec - gw_real_vec)**2)

    elif args.dis_metric == 'cos':
        gw_real_vec = []
        gw_syn_vec = []
        for ig in range(len(gw_real)):
            gw_real_vec.append(gw_real[ig].reshape((-1)))
            gw_syn_vec.append(gw_syn[ig].reshape((-1)))
        gw_real_vec = torch.cat(gw_real_vec, dim=0)
        gw_syn_vec = torch.cat(gw_syn_vec, dim=0)
        dis = 1 - torch.sum(gw_real_vec * gw_syn_vec, dim=-1) / (torch.norm(gw_real_vec, dim=-1) * torch.norm(gw_syn_vec, dim=-1) + 0.000001)

    else:
        exit('unknown distance function: %s'%args.dis_metric)

    return dis

# get_time

In [6]:
def get_time():
    return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))

In [7]:
class TensorDataset(Dataset):
    def __init__(self, images, labels): # images: n x c x h x w tensor
        self.images = images.detach().float()
        self.labels = labels.detach()

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return self.images.shape[0]

In [8]:
def DiffAugment(x, strategy='', seed = -1, param = None):
    if strategy == 'None' or strategy == 'none' or strategy == '':
        return x

    if seed == -1:
        param.Siamese = False
    else:
        param.Siamese = True

    param.latestseed = seed

    if strategy:
        if param.aug_mode == 'M': # original
            for p in strategy.split('_'):
                for f in AUGMENT_FNS[p]:
                    x = f(x, param)
        elif param.aug_mode == 'S':
            pbties = strategy.split('_')
            set_seed_DiffAug(param)
            p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
            for f in AUGMENT_FNS[p]:
                x = f(x, param)
        else:
            exit('unknown augmentation mode: %s'%param.aug_mode)
        x = x.contiguous()
    return x

In [9]:
class ParamDiffAug():
    def __init__(self):
        self.aug_mode = 'S' #'multiple or single'
        self.prob_flip = 0.5
        self.ratio_scale = 1.2
        self.ratio_rotate = 15.0
        self.ratio_crop_pad = 0.125
        self.ratio_cutout = 0.5 # the size would be 0.5x0.5
        self.brightness = 1.0
        self.saturation = 2.0
        self.contrast = 0.5

In [10]:
def epoch(mode, dataloader, net, optimizer, criterion, args, aug):
    loss_avg, acc_avg, num_exp = 0, 0, 0
    net = net.to(args.device)
    criterion = criterion.to(args.device)

    if mode == 'train':
        net.train()
    else:
        net.eval()

    for i_batch, datum in enumerate(dataloader):
        img = datum[0].float().to(args.device)
        if aug:
            if args.dsa:
                img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param)
            else:
                img = augment(img, args.dc_aug_param, device=args.device)
        lab = datum[1].long().to(args.device)
        n_b = lab.shape[0]

        output = net(img)
        loss = criterion(output, lab)
        acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy()))

        loss_avg += loss.item()*n_b
        acc_avg += acc
        num_exp += n_b

        if mode == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    loss_avg /= num_exp
    acc_avg /= num_exp

    return loss_avg, acc_avg



In [11]:
''' ConvNet '''
class ConvNet(nn.Module):
    def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)):
        super(ConvNet, self).__init__()

        self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
        num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def embed(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        return out

    def _get_activation(self, net_act):
        if net_act == 'sigmoid':
            return nn.Sigmoid()
        elif net_act == 'relu':
            return nn.ReLU(inplace=True)
        elif net_act == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.01)
        elif net_act == 'swish':
            return Swish()
        else:
            exit('unknown activation function: %s'%net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == 'maxpooling':
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            return None
        else:
            exit('unknown net_pooling: %s'%net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        # shape_feat = (c*h*w)
        if net_norm == 'batchnorm':
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == 'layernorm':
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == 'instancenorm':
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == 'groupnorm':
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == 'none':
            return None
        else:
            exit('unknown net_norm: %s'%net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        layers = []
        in_channels = channel
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != 'none':
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        return nn.Sequential(*layers), shape_feat



''' LeNet '''
class LeNet(nn.Module):
    def __init__(self, channel, num_classes):
        super(LeNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_1 = nn.Linear(16 * 5 * 5, 120)
        self.fc_2 = nn.Linear(120, 84)
        self.fc_3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x))
        x = self.fc_3(x)
        return x



''' AlexNet '''
class AlexNet(nn.Module):
    def __init__(self, channel, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Linear(192 * 4 * 4, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def embed(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return x


''' AlexNetBN '''
class AlexNetBN(nn.Module):
    def __init__(self, channel, num_classes):
        super(AlexNetBN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 192, kernel_size=5, padding=2),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Linear(192 * 4 * 4, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def embed(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return x


''' VGG '''
cfg_vgg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
class VGG(nn.Module):
    def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'):
        super(VGG, self).__init__()
        self.channel = channel
        self.features = self._make_layers(cfg_vgg[vgg_name], norm)
        self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def embed(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return x

    def _make_layers(self, cfg, norm):
        layers = []
        in_channels = self.channel
        for ic, x in enumerate(cfg):
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1),
                           nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


def VGG11(channel, num_classes):
    return VGG('VGG11', channel, num_classes)
def VGG11BN(channel, num_classes):
    return VGG('VGG11', channel, num_classes, norm='batchnorm')
def VGG13(channel, num_classes):
    return VGG('VGG13', channel, num_classes)
def VGG16(channel, num_classes):
    return VGG('VGG16', channel, num_classes)
def VGG19(channel, num_classes):
    return VGG('VGG19', channel, num_classes)

''' MLP '''
class MLP(nn.Module):
    def __init__(self, channel, num_classes):
        super(MLP, self).__init__()
        self.fc_1 = nn.Linear(28*28*1 if channel==1 else 32*32*3, 128)
        self.fc_2 = nn.Linear(128, 128)
        self.fc_3 = nn.Linear(128, num_classes)

    def forward(self, x):
        out = x.view(x.size(0), -1)
        out = F.relu(self.fc_1(out))
        out = F.relu(self.fc_2(out))
        out = self.fc_3(out)
        return out

def ResNet18(channel, num_classes):
    return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes)

In [12]:
def get_default_convnet_setting():
    net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling'
    return net_width, net_depth, net_act, net_norm, net_pooling
def get_network(model, channel, num_classes, im_size=(32, 32)):
    torch.random.manual_seed(int(time.time() * 1000) % 100000)
    net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting()

    if model == 'MLP':
        net = MLP(channel=channel, num_classes=num_classes)
    elif model == 'ConvNet':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'LeNet':
        net = LeNet(channel=channel, num_classes=num_classes)
    elif model == 'AlexNet':
        net = AlexNet(channel=channel, num_classes=num_classes)
    elif model == 'AlexNetBN':
        net = AlexNetBN(channel=channel, num_classes=num_classes)
    elif model == 'VGG11':
        net = VGG11( channel=channel, num_classes=num_classes)
   
    elif model == 'ConvNetD1':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=1, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetD2':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=2, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetD3':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=3, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetD4':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=4, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)

    elif model == 'ConvNetW32':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=32, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetW64':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=64, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetW128':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetW256':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=256, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)

    elif model == 'ConvNetAS':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='sigmoid', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetAR':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='relu', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetAL':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='leakyrelu', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetASwish':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='swish', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetASwishBN':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='swish', net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size)

    elif model == 'ConvNetNN':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='none', net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetBN':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetLN':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='layernorm', net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetIN':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling, im_size=im_size)
    elif model == 'ConvNetGN':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='groupnorm', net_pooling=net_pooling, im_size=im_size)

    elif model == 'ConvNetNP':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='none', im_size=im_size)
    elif model == 'ConvNetMP':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='maxpooling', im_size=im_size)
    elif model == 'ConvNetAP':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='avgpooling', im_size=im_size)

    else:
        net = None
        exit('unknown model: %s'%model)

    gpu_num = torch.cuda.device_count()
    if gpu_num>0:
        device = 'cuda'
        if gpu_num>1:
            net = nn.DataParallel(net)
    else:
        device = 'cpu'
    net = net.to(device)

    return net


In [13]:
def get_eval_pool(eval_mode, model, model_eval):
    if eval_mode == 'M': # multiple architectures
        model_eval_pool = ['MLP', 'ConvNet', 'LeNet', 'AlexNet', 'VGG11', 'ResNet18']
    elif eval_mode == 'B':  # multiple architectures with BatchNorm for DM experiments
        model_eval_pool = ['ConvNetBN', 'ConvNetASwishBN', 'AlexNetBN', 'VGG11BN', 'ResNet18BN']
    elif eval_mode == 'W': # ablation study on network width
        model_eval_pool = ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256']
    elif eval_mode == 'D': # ablation study on network depth
        model_eval_pool = ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4']
    elif eval_mode == 'A': # ablation study on network activation function
        model_eval_pool = ['ConvNetAS', 'ConvNetAR', 'ConvNetAL', 'ConvNetASwish']
    elif eval_mode == 'P': # ablation study on network pooling layer
        model_eval_pool = ['ConvNetNP', 'ConvNetMP', 'ConvNetAP']
    elif eval_mode == 'N': # ablation study on network normalization layer
        model_eval_pool = ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN']
    elif eval_mode == 'S': # itself
        if 'BN' in model:
            print('Attention: Here I will replace BN with IN in evaluation, as the synthetic set is too small to measure BN hyper-parameters.')
        model_eval_pool = [model[:model.index('BN')]] if 'BN' in model else [model]
    elif eval_mode == 'SS':  # itself
        model_eval_pool = [model]
    else:
        model_eval_pool = [model_eval]
    return model_eval_pool

In [14]:
def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args):
    net = net.to(args.device)
    images_train = images_train.to(args.device)
    labels_train = labels_train.to(args.device)
    lr = float(args.lr_net)
    Epoch = int(args.epoch_eval_train)
    lr_schedule = [Epoch//2+1]
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss().to(args.device)

    dst_train = TensorDataset(images_train, labels_train)
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    start = time.time()
    for ep in range(Epoch+1):
        loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug = True)
        if ep in lr_schedule:
            lr *= 0.1
            optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)

    time_train = time.time() - start
    loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug = False)
    print('%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test))

    return net, acc_train, acc_test

In [15]:
class ParamDiffAug():
    def __init__(self):
        self.aug_mode = 'S' #'multiple or single'
        self.prob_flip = 0.5
        self.ratio_scale = 1.2
        self.ratio_rotate = 15.0
        self.ratio_crop_pad = 0.125
        self.ratio_cutout = 0.5 # the size would be 0.5x0.5
        self.brightness = 1.0
        self.saturation = 2.0
        self.contrast = 0.5


In [16]:
from torch.utils.data import Dataset
class TensorDataset(Dataset):
    def __init__(self, images, labels): # images: n x c x h x w tensor
        self.images = images.detach().float()
        self.labels = labels.detach()

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return self.images.shape[0]

# distillation avec la methode DSA 

In [29]:
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
def main():

    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--method', type=str, default='DSA', help='DC/DSA')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode')
    parser.add_argument('--num_exp', type=int, default=5, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=20, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=300, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=1000, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=0.1, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='noise', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')

    # Handle unrecognized arguments
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Unrecognized arguments: {unknown}")

    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = True if args.method == 'DSA' else False

    if not os.path.exists(args.data_path):
        os.mkdir(args.data_path)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    eval_it_pool = np.arange(0, args.Iteration+1, 500).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration]
    print('eval_it_pool: ', eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    accs_all_exps = dict()
    for key in model_eval_pool:
        accs_all_exps[key] = []

    data_save = []

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        images_all = []
        labels_all = []
        indices_class = [[] for c in range(num_classes)]

        images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
        labels_all = [dst_train[i][1] for i in range(len(dst_train))]
        for i, lab in enumerate(labels_all):
            indices_class[lab].append(i)
        images_all = torch.cat(images_all, dim=0).to(args.device)
        labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))

        def get_images(c, n):
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1)

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')

        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5)
        optimizer_img.zero_grad()
        criterion = nn.CrossEntropyLoss().to(args.device)
        print('%s training begins'%get_time())

        for it in range(args.Iteration+1):

            if it in eval_it_pool:
                for model_eval in model_eval_pool:
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                    if args.dsa:
                        args.epoch_eval_train = 1000
                        args.dc_aug_param = None
                        print('DSA augmentation strategy: \n', args.dsa_strategy)
                        print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
                    else:
                        args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc)
                        print('DC augmentation parameters: \n', args.dc_aug_param)

                    if args.dsa or args.dc_aug_param['strategy'] != 'none':
                        args.epoch_eval_train = 1000
                    else:
                        args.epoch_eval_train = 300

                    accs = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device)
                        image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())
                        _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                    print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                    if it == args.Iteration:
                        accs_all_exps[model_eval] += accs

                save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                image_syn_vis[image_syn_vis<0] = 0.0
                image_syn_vis[image_syn_vis>1] = 1.0
                save_image(image_syn_vis, save_name, nrow=args.ipc)

            net = get_network(args.model, channel, num_classes, im_size).to(args.device)
            net.train()
            net_parameters = list(net.parameters())
            optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)
            optimizer_net.zero_grad()
            loss_avg = 0
            args.dc_aug_param = None

            for ol in range(args.outer_loop):

                BN_flag = False
                BNSizePC = 16
                for module in net.modules():
                    if 'BatchNorm' in module._get_name():
                        BN_flag = True
                if BN_flag:
                    img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)
                    net.train()
                    output_real = net(img_real)
                    for module in net.modules():
                        if 'BatchNorm' in module._get_name():
                            module.eval()

                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
                    lab_syn = torch.ones((img_syn.shape[0],), device=args.device, dtype=torch.long) * c
                    if args.dsa:
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, param=args.dsa_param)
                    else:
                        img_syn = augment(img_syn, args.dc_aug_param)

                    output_real = net(img_real)
                    loss_real = criterion(output_real, lab_real)
                    gw_real = torch.autograd.grad(loss_real, net_parameters)
                    gw_real = list((_.detach().clone() for _ in gw_real))

                    output_syn = net(img_syn)
                    loss_syn = criterion(output_syn, lab_syn)
                    gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                    loss += match_loss(gw_syn, gw_real, args)

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()
                loss_avg += loss.item()

                if ol == args.outer_loop - 1:
                    break

                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                for il in range(args.inner_loop):
                    epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)

                if args.dsa:
                    image_syn.data = torch.clamp(image_syn, 0.0, 1.0)

            loss_avg /= (num_classes*args.outer_loop)

            if it % 10 == 0:
                print('%s iter = %04d, loss = %.4f'%(get_time(), it, loss_avg))

        data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
        torch.save({'data': data_save, 'accs_all_exps': accs_all_exps}, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))

    print('\n==================== Final Results ====================\n')
    for key in model_eval_pool:
        accs = accs_all_exps[key]
        print('Run %d experiments for %s: mean = %.4f std = %.4f\n-------------------------\n'%(len(accs), key, np.mean(accs), np.std(accs)))

if __name__ == '__main__':
    main()

Unrecognized arguments: ['-f', 'C:\\Users\\ahmed\\AppData\\Roaming\\jupyter\\runtime\\kernel-fbd20088-27d9-4d3a-a8ba-06626b58e328.json']
eval_it_pool:  [0, 500, 1000]
Files already downloaded and verified
Files already downloaded and verified

 
Hyper-parameters: 
 {'method': 'DSA', 'dataset': 'CIFAR10', 'model': 'ConvNet', 'ipc': 1, 'eval_mode': 'S', 'num_exp': 5, 'num_eval': 20, 'epoch_eval_train': 300, 'Iteration': 1000, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'noise', 'dsa_strategy': 'None', 'data_path': 'data', 'save_path': 'result', 'dis_metric': 'ours', 'outer_loop': 1, 'inner_loop': 1, 'device': 'cuda', 'dsa_param': <__main__.ParamDiffAug object at 0x000001DC10AC6C10>, 'dsa': True}
Evaluation model pool:  ['ConvNet']
class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000

  label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1)


[2024-08-08 10:16:02] Evaluate_00: epoch = 1000 train time = 3 s train loss = 0.000236 train acc = 1.0000, test acc = 0.1265
[2024-08-08 10:16:07] Evaluate_01: epoch = 1000 train time = 3 s train loss = 0.000236 train acc = 1.0000, test acc = 0.1312
[2024-08-08 10:16:12] Evaluate_02: epoch = 1000 train time = 3 s train loss = 0.000229 train acc = 1.0000, test acc = 0.1177
[2024-08-08 10:16:18] Evaluate_03: epoch = 1000 train time = 3 s train loss = 0.000227 train acc = 1.0000, test acc = 0.1163
[2024-08-08 10:16:23] Evaluate_04: epoch = 1000 train time = 3 s train loss = 0.000230 train acc = 1.0000, test acc = 0.1338
[2024-08-08 10:16:28] Evaluate_05: epoch = 1000 train time = 2 s train loss = 0.000234 train acc = 1.0000, test acc = 0.1229
[2024-08-08 10:16:33] Evaluate_06: epoch = 1000 train time = 3 s train loss = 0.000231 train acc = 1.0000, test acc = 0.1133
[2024-08-08 10:16:39] Evaluate_07: epoch = 1000 train time = 3 s train loss = 0.000239 train acc = 1.0000, test acc = 0.1140


KeyboardInterrupt: 

# distillation avec la methode DM 

In [17]:
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image

def main():
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=10, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode')
    parser.add_argument('--num_exp', type=int, default=5, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=20, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=500, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')

    # Ignorer les arguments non reconnus lorsque vous exécutez dans Jupyter ou des environnements similaires
    args, unknown = parser.parse_known_args()

    args.method = 'DM'
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = False if args.dsa_strategy in ['none', 'None'] else True

    if not os.path.exists(args.data_path):
        os.mkdir(args.data_path)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    eval_it_pool = np.arange(0, args.Iteration+1, 2000).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration]
    print('eval_it_pool: ', eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    accs_all_exps = dict()
    for key in model_eval_pool:
        accs_all_exps[key] = []

    data_save = []

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n ' % exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        ''' Organiser le dataset réel '''
        images_all = []
        labels_all = []
        indices_class = [[] for c in range(num_classes)]

        images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
        labels_all = [dst_train[i][1] for i in range(len(dst_train))]
        for i, lab in enumerate(labels_all):
            indices_class[lab].append(i)
        images_all = torch.cat(images_all, dim=0).to(args.device)
        labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

        for c in range(num_classes):
            print('class c = %d: %d real images' % (c, len(indices_class[c])))

        def get_images(c, n):
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f' % (ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

        ''' Initialiser les données synthétiques '''
        image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        label_syn = torch.tensor([np.ones(args.ipc) * i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1)

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c * args.ipc:(c + 1) * args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')

        ''' Entraînement '''
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5)
        optimizer_img.zero_grad()
        print('%s training begins' % get_time())

        for it in range(args.Iteration + 1):
            ''' Évaluer les données synthétiques '''
            if it in eval_it_pool:
                for model_eval in model_eval_pool:
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d' % (args.model, model_eval, it))

                    print('DSA augmentation strategy: \n', args.dsa_strategy)
                    print('DSA augmentation parameters: \n', args.dsa_param.__dict__)

                    accs = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device)
                        image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())
                        _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                    print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------' % (len(accs), model_eval, np.mean(accs), np.std(accs)))

                    if it == args.Iteration:
                        accs_all_exps[model_eval] += accs

                ''' Visualiser et enregistrer '''
                save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png' % (args.method, args.dataset, args.model, args.ipc, exp, it))
                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
                image_syn_vis[image_syn_vis < 0] = 0.0
                image_syn_vis[image_syn_vis > 1] = 1.0
                save_image(image_syn_vis, save_name, nrow=args.ipc)

            ''' Entraîner les données synthétiques '''
            net = get_network(args.model, channel, num_classes, im_size).to(args.device)
            net.train()
            for param in list(net.parameters()):
                param.requires_grad = False

            embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed

            loss_avg = 0

            ''' Mettre à jour les données synthétiques '''
            if 'BN' not in args.model:
                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    img_syn = image_syn[c * args.ipc:(c + 1) * args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                    if args.dsa:
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    output_real = embed(img_real).detach()
                    output_syn = embed(img_syn)

                    loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0)) ** 2)

            else:
                images_real_all = []
                images_syn_all = []
                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    img_syn = image_syn[c * args.ipc:(c + 1) * args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                    if args.dsa:
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    images_real_all.append(img_real)
                    images_syn_all.append(img_syn)

                images_real_all = torch.cat(images_real_all, dim=0)
                images_syn_all = torch.cat(images_syn_all, dim=0)

                output_real = embed(images_real_all).detach()
                output_syn = embed(images_syn_all)
                loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0)) ** 2)

            optimizer_img.zero_grad()
            loss.backward()
            optimizer_img.step()
            loss_avg += loss.item()

            if it % 10 == 0:
                loss_avg /= (num_classes * 10)
                print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))
                loss_avg = 0

        print('\n==================== Final Results ====================\n')
        for key in model_eval_pool:
            accs = accs_all_exps[key]
            print('Run %d: %s: mean = %.4f std = %.4f' % (exp, key, np.mean(accs), np.std(accs)))

        data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])

    torch.save({'data': data_save, 'accs_all_exps': accs_all_exps}, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt' % (args.method, args.dataset, args.model, args.ipc)))

if __name__ == '__main__':
    main()


SyntaxError: unexpected EOF while parsing (2779585543.py, line 171)

In [None]:
# cl with image distillé DSA

In [40]:

def main():
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--method', type=str, default='DSA', help='random/herding/DSA/DM')
    parser.add_argument('--dataset', type=str, default='CIFAR100', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=20, help='image(s) per class')
    parser.add_argument('--steps', type=int, default=10, help='5/10-step learning')
    parser.add_argument('--num_eval', type=int, default=3, help='evaluation number')
    parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--data_path', type=str, default='./data', help='dataset path')

    args, unknown = parser.parse_known_args()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = True # augment images for all methods
    args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate' # for CIFAR10/100

    if not os.path.exists(args.data_path):
        os.mkdir(args.data_path)

    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)


    ''' all training data '''
    images_all = []
    labels_all = []
    indices_class = [[] for c in range(num_classes)]

    images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
    labels_all = [dst_train[i][1] for i in range(len(dst_train))]
    for i, lab in enumerate(labels_all):
        indices_class[lab].append(i)
    images_all = torch.cat(images_all, dim=0).to(args.device)
    labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

    # for c in range(num_classes):
    #     print('class c = %d: %d real images' % (c, len(indices_class[c])))

    def get_images(c, n):  # get random n images from class c
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle]

    print()
    print('==================================================================================')
    print('method: ', args.method)
    results = np.zeros((args.steps, 5*args.num_eval))

    for seed_cl in range(5):
        num_classes_step = num_classes // args.steps
        np.random.seed(seed_cl)
        class_order = np.random.permutation(num_classes).tolist()
        print('=========================================')
        print('seed: ', seed_cl)
        print('class_order: ', class_order)
        print('augmentation strategy: \n', args.dsa_strategy)
        print('augmentation parameters: \n', args.dsa_param.__dict__)

        if args.method == 'random':
            images_train_all = []
            labels_train_all = []
            for step in range(args.steps):
                classes_current = class_order[step * num_classes_step: (step + 1) * num_classes_step]
                images_train_all += [torch.cat([get_images(c, args.ipc) for c in classes_current], dim=0)]
                labels_train_all += [torch.tensor([c for c in classes_current for i in range(args.ipc)], dtype=torch.long, device=args.device)]

        elif args.method == 'herding':
            fname = os.path.join(args.data_path, 'metasets', 'cl_data', 'cl_herding_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt'%(args.steps, seed_cl))
            data = torch.load(fname, map_location='cpu')['data']
            images_train_all = [data[step][0] for step in range(args.steps)]
            labels_train_all = [data[step][1] for step in range(args.steps)]
            print('use data: ', fname)

        elif args.method == 'DSA':
            fname = os.path.join(args.data_path, 'metasets', 'cl_data', 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt'%(args.steps, seed_cl))
            data = torch.load(fname, map_location='cpu')['data']
            images_train_all = [data[step][0] for step in range(args.steps)]
            labels_train_all = [data[step][1] for step in range(args.steps)]
            print('use data: ', fname)

        elif args.method == 'DM':
            fname = os.path.join(args.data_path, 'metasets', 'cl_data', 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt'%(args.steps, seed_cl))
            data = torch.load(fname, map_location='cpu')['data']
            images_train_all = [data[step][0] for step in range(args.steps)]
            labels_train_all = [data[step][1] for step in range(args.steps)]
            print('use data: ', fname)

        else:
            exit('unknown method: %s'%args.method)


        for step in range(args.steps):
            print('\n-----------------------------\nmethod %s seed %d step %d ' % (args.method, seed_cl, step))

            classes_seen = class_order[: (step+1)*num_classes_step]
            print('classes_seen: ', classes_seen)


            ''' train data '''
            images_train = torch.cat(images_train_all[:step+1], dim=0).to(args.device)
            labels_train = torch.cat(labels_train_all[:step+1], dim=0).to(args.device)
            print('train data size: ', images_train.shape)


            ''' test data '''
            images_test = []
            labels_test = []
            for i in range(len(dst_test)):
                lab = int(dst_test[i][1])
                if lab in classes_seen:
                    images_test.append(torch.unsqueeze(dst_test[i][0], dim=0))
                    labels_test.append(dst_test[i][1])

            images_test = torch.cat(images_test, dim=0).to(args.device)
            labels_test = torch.tensor(labels_test, dtype=torch.long, device=args.device)
            dst_test_current = TensorDataset(images_test, labels_test)
            testloader = torch.utils.data.DataLoader(dst_test_current, batch_size=256, shuffle=False, num_workers=0)

            print('test set size: ', images_test.shape)


            ''' train model on the newest memory '''
            accs = []
            for ep_eval in range(args.num_eval):
                net_eval = get_network(args.model, channel, num_classes, im_size)
                net_eval = net_eval.to(args.device)
                img_syn_eval = copy.deepcopy(images_train.detach())
                lab_syn_eval = copy.deepcopy(labels_train.detach())

                _, acc_train, acc_test = evaluate_synset(ep_eval, net_eval, img_syn_eval, lab_syn_eval, testloader, args)
                del net_eval, img_syn_eval, lab_syn_eval
                gc.collect()  # to reduce memory cost
                accs.append(acc_test)
                results[step, seed_cl*args.num_eval + ep_eval] = acc_test
            print('Evaluate %d random %s, mean = %.4f std = %.4f' % (len(accs), args.model, np.mean(accs), np.std(accs)))


    results_str = ''
    for step in range(args.steps):
        results_str += '& %.1f$\pm$%.1f  ' % (np.mean(results[step]) * 100, np.std(results[step]) * 100)
    print('\n\n')
    print('%d step learning %s perforamnce:'%(args.steps, args.method))
    print(results_str)
    print('Done')


if __name__ == '__main__':
    main()


Files already downloaded and verified
Files already downloaded and verified

method:  DSA
seed:  0
class_order:  [26, 86, 2, 55, 75, 93, 16, 73, 54, 95, 53, 92, 78, 13, 7, 30, 22, 24, 33, 8, 43, 62, 3, 71, 45, 48, 6, 99, 82, 76, 60, 80, 90, 68, 51, 27, 18, 56, 63, 74, 1, 61, 42, 41, 4, 15, 17, 40, 38, 5, 91, 59, 0, 34, 28, 50, 11, 35, 23, 52, 10, 31, 66, 57, 79, 85, 32, 84, 14, 89, 19, 29, 49, 97, 98, 69, 20, 94, 72, 77, 25, 37, 81, 46, 39, 65, 58, 12, 88, 70, 87, 36, 21, 83, 9, 96, 67, 64, 47, 44]
augmentation strategy: 
 color_crop_cutout_flip_scale_rotate
augmentation parameters: 
 {'aug_mode': 'S', 'prob_flip': 0.5, 'ratio_scale': 1.2, 'ratio_rotate': 15.0, 'ratio_crop_pad': 0.125, 'ratio_cutout': 0.5, 'brightness': 1.0, 'saturation': 2.0, 'contrast': 0.5}


FileNotFoundError: [Errno 2] No such file or directory: './data\\metasets\\cl_data\\cl_res_DSA_CIFAR100_ConvNet_20ipc_10steps_seed0.pt'

In [31]:
pip install avalanche-lib==0.5

Note: you may need to restart the kernel to use updated packages.




In [32]:
import torch
import torchvision
from avalanche.benchmarks.datasets import CIFAR10, CIFAR100
from avalanche.benchmarks.utils import classification_dataset
import avalanche.benchmarks.scenarios.dataset_scenario

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from avalanche.training.supervised import (
    Cumulative, Naive, ICaRL, LwF, EWC, GenerativeReplay, JointTraining, CWRStar
)

from avalanche.evaluation.metrics import (
    Accuracy, TaskAwareAccuracy, accuracy_metrics, loss_metrics, forgetting_metrics,
    cpu_usage_metrics, gpu_usage_metrics, MAC_metrics
)
from avalanche.logging import InteractiveLogger, WandBLogger
from avalanche.training.plugins import EvaluationPlugin, ReplayPlugin
import avalanche.checkpointing.checkpoint

import os
import json
parser = argparse.ArgumentParser(description='Parameter Processing')
method='DSA'
dataset='CIFAR100'
model='ConvNet'
ipc=1 # help='image(s) per class'
eval_mode='S' # help='eval_mode'
num_exp=5 # help='the number of experiments'
num_eval=20 # help='the number of evaluating randomly initialized models')
epoch_eval_train=300 # help='epochs to train a model with synthetic data')
Iteration=1000 # help='training iterations')
lr_img=0.1 # help='learning rate for updating synthetic images')
lr_net=0.01 # help='learning rate for updating network parameters')
batch_real=256 # help='batch size for real data')
batch_train=256 # help='batch size for training networks')
init='noise' # help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
dsa_strategy='color_crop_cutout_flip_scale_rotate' # help='differentiable Siamese augmentation strategy')
data_path='./data' # help='dataset path')
save_path='result'# help='path to save results')
dis_metric='ours' # help='distance metric')
steps=10 # help='5/10-step learning')
num_eval=3 # help='evaluation number')

# Handle unrecognized arguments

device =  'cpu'
dsa = True # augment images for all methods
dsa_strategy = 'color_crop_cutout_flip_scale_rotate' # for CIFAR10/100

if not os.path.exists(data_path):
    os.mkdir(data_path)

if not os.path.exists(save_path):
    os.mkdir(save_path)

eval_it_pool = np.arange(0, Iteration+1, 500).tolist() if eval_mode == 'S' or eval_mode == 'SS' else [args.Iteration]
print('eval_it_pool: ', eval_it_pool)

# Correctly loading CIFAR100 datasets
# channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(CIFAR100, args.data_path)
channel = 3
im_size = (32, 32)
num_classes = 100
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
dst_train = datasets.CIFAR100('./data/cifar100', train=True, download=True, transform=transform) # no augmentation
dst_test = datasets.CIFAR100('./data/cifar100', train=False, download=True, transform=transform)
class_names = dst_train.classes

model_eval_pool = get_eval_pool('S','ConvNet','ConvNet')

# Iterate over CIFAR100 dataset
for i, example in enumerate(dst_train):
    pass
print("Num. examples processed: {}".format(i + 1))


eval_it_pool:  [0, 500, 1000]
Files already downloaded and verified
Files already downloaded and verified
Num. examples processed: 50000


In [33]:
from typing import (
    List,
    Any,
    Sequence,
    Union,
    Optional,
    TypeVar,
    Callable,
    Dict,
    Tuple,
    Mapping,
    overload,
)

In [34]:
from avalanche.benchmarks.utils.utils import*
from typing import TypeVar, SupportsInt, Sequence, Protocol
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import *

T_co = TypeVar("T_co", covariant=True)
TTargetType_co = TypeVar("TTargetType_co", covariant=True)
TTargetType = int

class IDataset(Protocol[T_co]):
    """
    Protocol definition of a Dataset.

    Note: no __add__ method is defined.
    """

    def __getitem__(self, index: int) -> T_co: ...

    def __len__(self) -> int: ...
        
class IDatasetWithTargets(IDataset[T_co], Protocol[T_co, TTargetType_co]):
    """
    Protocol definition of a Dataset that has a valid targets field.
    """

    @property
    def targets(self) -> Sequence[TTargetType_co]:
        """
        A sequence of elements describing the targets of each pattern.
        """
        ...
TClassificationDataset = TypeVar(
    "TClassificationDataset", bound="ClassificationDataset"
)        
class TaskAwareClassificationDataset(AvalancheDataset[T_co]):
    @property
    def task_pattern_indices(self) -> Dict[int, Sequence[int]]:
        """A dictionary mapping task ids to their sample indices."""
        return self.targets_task_labels.val_to_idx  # type: ignore

    @property
    def task_set(self: TClassificationDataset) -> TaskSet[TClassificationDataset]:
        """Returns the datasets's ``TaskSet``, which is a mapping <task-id,
        task-dataset>."""
        return TaskSet(self)

    def subset(self, indices):
        data = super().subset(indices)
        return data.with_transforms(self._flat_data._transform_groups.current_group)

    def concat(self, other):
        data = super().concat(other)
        return data.with_transforms(self._flat_data._transform_groups.current_group)

    def __hash__(self):
        return id(self)                
class TaskAwareSupervisedClassificationDataset(TaskAwareClassificationDataset[T_co]):
    # TODO: remove? ClassificationDataset should have targets
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert "targets" in self._data_attributes, (
            "The supervised version of the ClassificationDataset requires "
            + "the targets field"
        )
        assert "targets_task_labels" in self._data_attributes, (
            "The supervised version of the ClassificationDataset requires "
            + "the targets_task_labels field"
        )

    @property
    def targets(self) -> DataAttribute[TTargetType]:
        return self._data_attributes["targets"]

    @property
    def targets_task_labels(self) -> DataAttribute[int]:
        return self._data_attributes["targets_task_labels"]
        

class ISupportedClassificationDataset(IDatasetWithTargets[T_co, SupportsInt], Protocol):
    """
    Protocol definition of a Dataset that has a valid targets field (like the
    Datasets in the torchvision package) for classification.

    For classification purposes, the targets field must be a sequence of ints.
    describing the class label of each pattern.

    This class however describes a targets field as a sequence of elements
    that can be converted to `int`. The main reason for this choice is that
    the targets field of some torchvision datasets is a Tensor. This means that
    this protocol class supports both sequence of native ints and Tensor of ints
    (or longs).

    On the contrary, class :class:`IClassificationDataset` strictly
    defines a `targets` field as sequence of native `int`s.
    """

    @property
    def targets(self) -> Sequence[SupportsInt]:
        """
        A sequence of ints or a PyTorch Tensor or a NumPy ndarray describing the
        label of each pattern contained in the dataset.
        """
        ...
        
class TaskAwareSupervisedClassificationDataset(TaskAwareClassificationDataset[T_co]):
    # TODO: remove? ClassificationDataset should have targets
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert "targets" in self._data_attributes, (
            "The supervised version of the ClassificationDataset requires "
            + "the targets field"
        )
        assert "targets_task_labels" in self._data_attributes, (
            "The supervised version of the ClassificationDataset requires "
            + "the targets_task_labels field"
        )

    @property
    def targets(self) -> DataAttribute[TTargetType]:
        return self._data_attributes["targets"]

    @property
    def targets_task_labels(self) -> DataAttribute[int]:
        return self._data_attributes["targets_task_labels"]      


In [35]:
from avalanche.benchmarks.utils.transform_groups import*
def _as_taskaware_supervised_classification_dataset(
    dataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset:
    if (
        transform is not None
        or target_transform is not None
        or transform_groups is not None
        or initial_transform_group is not None
        or task_labels is not None
        or targets is not None
        or collate_fn is not None
        or not isinstance(dataset, TaskAwareSupervisedClassificationDataset)
    ):
        result_dataset = _make_taskaware_classification_dataset(
            dataset=dataset,
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=task_labels,
            targets=targets,
            collate_fn=collate_fn,
        )

        if not isinstance(result_dataset, TaskAwareSupervisedClassificationDataset):
            raise ValueError(
                "The given dataset does not have supervision fields "
                "(targets, task_labels)."
            )

        return result_dataset

    return dataset

In [36]:
class SimpleCNN(nn.Module):
    

    def __init__(self, num_classes=110):
        super(SimpleCNN, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            nn.Conv2d(64, 64, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool2d(1),
            nn.Dropout(p=0.25),
        )
        self.classifier = nn.Sequential(nn.Linear(64, num_classes))


    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [37]:
from typing import (
    List,
    Any,
    Sequence,
    Union,
    Optional,
    TypeVar,
    Callable,
    Dict,
    Tuple,
    Mapping,
    overload,
)
import random
from pathlib import Path
from typing import Sequence, Optional, Union, Any
from avalanche.benchmarks.utils.classification_dataset import *
from torchvision import transforms
from avalanche.benchmarks.utils.utils import (
    TaskSet,
    _count_unique,
    find_common_transforms_group,
    _init_task_labels,
    _init_transform_groups,
    _split_user_def_targets,
    _split_user_def_task_label,
    _traverse_supported_dataset,
)

from avalanche.benchmarks.classic.classic_benchmarks_utils import (
    check_vision_benchmark,
)

from avalanche.benchmarks.datasets.external_datasets.cifar import (
    get_cifar100_dataset,
    get_cifar10_dataset,
)

def _concat_taskaware_classification_datasets_sequentially(
    train_dataset_list: Sequence[ISupportedClassificationDataset],
    test_dataset_list: Sequence[ISupportedClassificationDataset],
) -> Tuple[
    TaskAwareSupervisedClassificationDataset,
    TaskAwareSupervisedClassificationDataset,
    List[list],
]:
    
    remapped_train_datasets: List[TaskAwareSupervisedClassificationDataset] = []
    remapped_test_datasets: List[TaskAwareSupervisedClassificationDataset] = []
    next_remapped_idx = 0

    train_dataset_list_sup = list(
        map(_as_taskaware_supervised_classification_dataset, train_dataset_list)
    )
    test_dataset_list_sup = list(
        map(_as_taskaware_supervised_classification_dataset, test_dataset_list)
    )
    del train_dataset_list
    del test_dataset_list

    # Obtain the number of classes of each dataset
    classes_per_dataset = [
        _count_unique(
            train_dataset_list_sup[dataset_idx].targets,
            test_dataset_list_sup[dataset_idx].targets,
        )
        for dataset_idx in range(len(train_dataset_list_sup))
    ]

    new_class_ids_per_dataset = []
    for dataset_idx in range(len(train_dataset_list_sup)):
        # Get the train and test sets of the dataset
        train_set = train_dataset_list_sup[dataset_idx]
        test_set = test_dataset_list_sup[dataset_idx]

        # Get the classes in the dataset
        dataset_classes = set(map(int, train_set.targets))

        # The class IDs for this dataset will be in range
        # [n_classes_in_previous_datasets,
        #       n_classes_in_previous_datasets + classes_in_this_dataset)
        new_classes = list(
            range(
                next_remapped_idx,
                next_remapped_idx + classes_per_dataset[dataset_idx],
            )
        )
        new_class_ids_per_dataset.append(new_classes)

        # AvalancheSubset is used to apply the class IDs transformation.
        # Remember, the class_mapping parameter must be a list in which:
        # new_class_id = class_mapping[original_class_id]
        # Hence, a list of size equal to the maximum class index is created
        # Only elements corresponding to the present classes are remapped
        class_mapping = [-1] * (max(dataset_classes) + 1)
        j = 0
        for i in dataset_classes:
            class_mapping[i] = new_classes[j]
            j += 1

        a = _taskaware_classification_subset(train_set, class_mapping=class_mapping)

        # Create remapped datasets and append them to the final list
        remapped_train_datasets.append(
            _taskaware_classification_subset(train_set, class_mapping=class_mapping)
        )
        remapped_test_datasets.append(
            _taskaware_classification_subset(test_set, class_mapping=class_mapping)
        )
        next_remapped_idx += classes_per_dataset[dataset_idx]

    return (
        _concat_taskaware_classification_datasets(remapped_train_datasets),
        _concat_taskaware_classification_datasets(remapped_test_datasets),
        new_class_ids_per_dataset,
    )

from avalanche.benchmarks import nc_benchmark, NCScenario

_default_cifar100_train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    ]
)

_default_cifar100_eval_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    ]
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

steps=10 # help='5/10-step learning')
num_classes = 100
num_classes_step = num_classes // steps
class_order = np.random.permutation(num_classes).tolist()


''' organize the real dataset '''
images_all = []
labels_all = []
indices_class = [[] for c in range(num_classes)]

images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
labels_all = [dst_train[i][1] for i in range(len(dst_train))]
for i, lab in enumerate(labels_all):
    indices_class[lab].append(i)
images_all = torch.cat(images_all, dim=0).to(device)
labels_all = torch.tensor(labels_all, dtype=torch.long, device=device)


# for c in range(num_classes):
#     print('class c = %d: %d real images' % (c, len(indices_class[c])))


def get_images(c, n):  # get random n images from class c
    idx_shuffle = np.random.permutation(indices_class[c])[:n]
    return images_all[idx_shuffle]

data_path = r'C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data'
classes_seen = class_order[: (10)*num_classes_step]
print('classes_seen: ', classes_seen)


if method == 'DSA':
    # fname = os.path.join(data_path, 'metasets', 'cl_data', 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt'%(steps, 0))
    fname = os.path.join(data_path,'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt'%(steps, 0))
    data = torch.load(fname, map_location='cpu')['data']
    images_train_all = [data[step][0] for step in range(steps)]
    labels_train_all = [data[step][1] for step in range(steps)]
    print('use data: ', fname)

elif method == 'DM':
    # fname = os.path.join(data_path, 'metasets', 'cl_data', 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt'%(args.steps, 0))
    fname = os.path.join(data_path,'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt'%(steps, 0))
    data = torch.load(fname, map_location='cpu')['data']
    images_train_all = [data[step][0] for step in range(steps)]
    labels_train_all = [data[step][1] for step in range(steps)]
    print('use data: ', fname)

else:
    exit('unknown method: %s'%method)


''' train data '''
images_train = torch.cat(images_train_all[:10], dim=0).to(device)
labels_train = torch.cat(labels_train_all[:10], dim=0).to(device)
print('train data size: ', images_train.shape)


''' test data '''
images_test = []
labels_test = []
for i in range(len(dst_test)):
    lab = int(dst_test[i][1])
    if lab in classes_seen:
        images_test.append(torch.unsqueeze(dst_test[i][0], dim=0))
        labels_test.append(dst_test[i][1])

images_test = torch.cat(images_test, dim=0).to(device)
labels_test = torch.tensor(labels_test, dtype=torch.long, device=device)
dst_test_current = TensorDataset(images_test, labels_test)
testloader = torch.utils.data.DataLoader(dst_test_current, batch_size=256, shuffle=False, num_workers=0)

print('test set size: ', images_test.shape)


def cifar100_distilled_benchmark():
    return nc_benchmark(
        train_dataset=images_train,
        test_dataset=images_test,
        n_experiences=steps,
        task_labels=False
    )

    

if __name__ == "__main__":

    print("cifar100 distilled")
    benchmark_instance = cifar100_distilled_benchmark()
    check_vision_benchmark(benchmark_instance)

   




classes_seen:  [97, 92, 34, 33, 70, 14, 39, 50, 16, 96, 90, 18, 19, 73, 94, 28, 22, 65, 11, 79, 38, 27, 30, 47, 69, 21, 48, 13, 44, 63, 25, 42, 7, 80, 46, 49, 20, 81, 1, 4, 58, 72, 74, 55, 5, 43, 32, 95, 54, 51, 67, 31, 99, 26, 76, 52, 10, 86, 71, 62, 3, 6, 59, 57, 60, 77, 98, 82, 45, 87, 93, 0, 2, 75, 23, 78, 15, 64, 17, 84, 88, 91, 8, 53, 29, 40, 85, 56, 89, 83, 37, 35, 41, 9, 61, 66, 68, 12, 36, 24]
use data:  C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data\cl_res_DSA_CIFAR100_ConvNet_20ipc_10steps_seed0.pt
train data size:  torch.Size([2000, 3, 32, 32])
test set size:  torch.Size([10000, 3, 32, 32])
cifar100 distilled


ValueError: Unsupported dataset: must have a valid targets field or has to be a Tensor Dataset with at least 2 Tensors

In [None]:
import os
import torch
from torch.utils.data import TensorDataset
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor
from avalanche.benchmarks import nc_benchmark
import numpy as np

# Définition des arguments
class Args:
    def __init__(self):
        self.method = 'DSA'
        self.dataset = 'CIFAR100'
        self.model = 'ConvNet'
        self.ipc = 20
        self.steps = 10  # Assurez-vous que 'steps' est défini
        self.epoch_eval_train = 300
        self.lr_net = 0.01
        self.batch_train = 256
        self.data_path = r'C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data'
        self.save_path = './result'
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

args = Args()

# Initialisation de la variable device
device = args.device

# Charger les données
def load_data():
    if args.method == 'DSA':
        fname = os.path.join(args.data_path, 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    elif args.method == 'DM':
        fname = os.path.join(args.data_path, 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    else:
        raise ValueError(f'Unknown method: {args.method}')

    # Vérifier l'existence du fichier
    if not os.path.exists(fname):
        raise FileNotFoundError(f"File not found: {fname}")
    
    # Charger le fichier
    print(f"Loading file from: {fname}")
    data = torch.load(fname, map_location='cpu')['data']
    
    images_train_all = [data[step][0] for step in range(args.steps)]
    labels_train_all = [data[step][1] for step in range(args.steps)]
    
    return images_train_all, labels_train_all

images_train_all, labels_train_all = load_data()

# Concaténer les données
images_train = torch.cat(images_train_all, dim=0).to(device)
labels_train = torch.cat(labels_train_all, dim=0).to(device)
print('Train data size: ', images_train.shape)

# Charger et préparer les données de test
dst_test = CIFAR100(root='./data', train=False, download=True, transform=ToTensor())

# Préparer les données de test
images_test = []
labels_test = []
for i in range(len(dst_test)):
    img, lab = dst_test[i]
    if lab in set(labels_train.cpu().numpy()):  # Si l'étiquette de test fait partie des classes d'entraînement
        images_test.append(img)
        labels_test.append(lab)

images_test = torch.stack(images_test).to(device)
labels_test = torch.tensor(labels_test, dtype=torch.long, device=device)
print('Test data size: ', images_test.shape)

# Créer les TensorDatasets pour l'entraînement et les tests
train_dataset = TensorDataset(images_train, labels_train)
test_dataset = TensorDataset(images_test, labels_test)

# Fonction pour créer le benchmark
def cifar100_distilled_benchmark(steps):
    return nc_benchmark(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        n_experiences=steps,
        task_labels=False
    )

# Code principal
if __name__ == "__main__":
    print("CIFAR100 distilled")
    benchmark_instance = cifar100_distilled_benchmark(args.steps)
    # Suite de votre code pour entraîner et évaluer le modèle...


In [55]:
#At each experience, train model with data from all previous experiences

import argparse
import sys
import os
import torch
import torchvision
from avalanche.benchmarks.utils import classification_dataset
import avalanche.benchmarks.scenarios.dataset_scenario

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss

from avalanche.training.supervised import (
    Cumulative, Naive, ICaRL, LwF, EWC, GenerativeReplay, JointTraining, CWRStar
)

from avalanche.evaluation.metrics import (
    Accuracy, TaskAwareAccuracy, accuracy_metrics, loss_metrics, forgetting_metrics,
    cpu_usage_metrics, gpu_usage_metrics, MAC_metrics
)
from avalanche.logging import InteractiveLogger, WandBLogger
from avalanche.training.plugins import EvaluationPlugin, ReplayPlugin
from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint
from avalanche.training.determinism.rng_manager import RNGManager

def main_with_checkpointing(args):
    # STEP 1: SET THE RANDOM SEEDS to guarantee reproducibility
    RNGManager.set_random_seeds(1234)

    # Nothing new here...
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )
    print("Using device", device)

    # CL Benchmark Creation (as usual)
    benchmark = cifar100_distilled_benchmark(args.steps)  # Pass steps as an argument
    model = SimpleCNN(num_classes=100)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Create the evaluation plugin (as usual)
    evaluation_plugin = EvaluationPlugin(
        accuracy_metrics(experience=True, stream=True), loggers=[InteractiveLogger()]
    )

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=[interactive_logger],
    )

    # Create the strategy using Cumulative
    strategy = Cumulative(
        model=model,                # Ensure your model is capable of handling the increased number of classes
        optimizer=optimizer,        # Optimizer, e.g., Adam or SGD
        criterion=criterion,        # Loss function, e.g., CrossEntropyLoss
        train_mb_size=64,           # Batch size; adjust based on your hardware and model requirements
        train_epochs=30,            # Increased number of epochs to ensure adequate training
        eval_mb_size=64,            # Evaluation batch size
        device=device,              # Device, e.g., 'cuda' or 'cpu'
        evaluator=eval_plugin # Evaluation plugin or metric, e.g., accuracy
    )

    # STEP 2: TRY TO LOAD THE LAST CHECKPOINT
    # if the checkpoint exists, load it into the newly created strategy
    # the method also loads the experience counter, so we know where to
    # resume training
    fname = "./checkpoint/Cumulative.pkl"  # name of the checkpoint file
    os.makedirs(os.path.dirname(fname), exist_ok=True)  # Ensure the checkpoint directory exists
    strategy, initial_exp = maybe_load_checkpoint(strategy, fname)

    # STEP 3: USE THE "initial_exp" to resume training
    for train_exp in benchmark.train_stream[initial_exp:]:
        strategy.train(train_exp, num_workers=4, persistent_workers=True)
        strategy.eval(benchmark.test_stream, num_workers=4)

        # STEP 4: SAVE the checkpoint after training on each experience.
        save_checkpoint(strategy, fname)


if __name__ == "__main__":
    if 'ipykernel_launcher' in sys.argv[0]:
        # Running in a Jupyter environment
        class Args:
            cuda = 0
            steps = 10  # Définir une valeur par défaut pour steps
        args = Args()
    else:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="Select zero-indexed cuda device. -1 to use CPU.",
        )
        parser.add_argument(
            "--steps",
            type=int,
            default=10,
            help="Number of steps for the benchmark.",
        )
        # Parse known arguments and ignore the rest
        args, _ = parser.parse_known_args(sys.argv)
    main_with_checkpointing(args)



Using device cuda:0
Mapping cuda:0 to cuda:0
[InteractiveLogger] Resuming from checkpoint. Current time is 2024-08-04 14:34:41 +0100


AttributeError: 'dict' object has no attribute 'train_stream'

# cl with dataset distillation

## Cumulative

In [None]:
import argparse
import sys
import os
import torch
import torchvision
from avalanche.benchmarks.utils import classification_dataset


import avalanche.benchmarks.scenarios.dataset_scenario

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from avalanche.training.supervised import Cumulative
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint
from avalanche.training.determinism.rng_manager import RNGManager
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor
from torch.utils.data import TensorDataset
from avalanche.benchmarks import nc_benchmark

# Définition des arguments
class Args:
    def __init__(self):
        self.method = 'DSA'
        self.dataset = 'CIFAR100'
        self.model = 'ConvNet'
        self.ipc = 20
        self.steps = 10  # Assurez-vous que 'steps' est défini
        self.epoch_eval_train = 300
        self.lr_net = 0.01
        self.batch_train = 256
        self.data_path = r'C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data'
        self.save_path = './result'
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

args = Args()

# Initialisation de la variable device
device = args.device
device = 'cpu'
# Charger les données
def load_data():
    if args.method == 'DSA':
        fname = os.path.join(args.data_path, 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    elif args.method == 'DM':
        fname = os.path.join(args.data_path, 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    else:
        raise ValueError(f'Unknown method: {args.method}')

    # Vérifier l'existence du fichier
    if not os.path.exists(fname):
        raise FileNotFoundError(f"File not found: {fname}")
    
    # Charger le fichier
    print(f"Loading file from: {fname}")
    data = torch.load(fname, map_location='cpu')['data']
    
    images_train_all = [data[step][0] for step in range(args.steps)]
    labels_train_all = [data[step][1] for step in range(args.steps)]
    
    return images_train_all, labels_train_all

images_train_all, labels_train_all = load_data()

# Concaténer les données
images_train = torch.cat(images_train_all, dim=0).to(device)
labels_train = torch.cat(labels_train_all, dim=0).to(device)
print('Train data size: ', images_train.shape)

# Charger et préparer les données de test
dst_test = CIFAR100(root='./data', train=False, download=True, transform=ToTensor())

# Préparer les données de test
images_test = []
labels_test = []
for i in range(len(dst_test)):
    img, lab = dst_test[i]
    if lab in set(labels_train.cpu().numpy()):  # Si l'étiquette de test fait partie des classes d'entraînement
        images_test.append(img)
        labels_test.append(lab)

images_test = torch.stack(images_test).to(device)
labels_test = torch.tensor(labels_test, dtype=torch.long, device=device)
print('Test data size: ', images_test.shape)

# Créer les TensorDatasets pour l'entraînement et les tests
train_dataset = TensorDataset(images_train, labels_train)
test_dataset = TensorDataset(images_test, labels_test)

# Définir un modèle simple de CNN
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class SimpleCNN64(nn.Module):
    #p=0.05

    def __init__(self, num_classes=100):
        super(SimpleCNN64, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.05),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.05),
            nn.Conv2d(64, 64, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool2d(1),
            nn.Dropout(p=0.05),
        )
        self.classifier = nn.Sequential(nn.Linear(64, num_classes))


    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Fonction pour créer le benchmark
def cifar100_distilled_benchmark(steps):
    return nc_benchmark(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        n_experiences=steps,
        task_labels=False
    )

def main_with_checkpointing(args):
    # STEP 1: SET THE RANDOM SEEDS to guarantee reproducibility
    RNGManager.set_random_seeds(1234)

    # Nothing new here...
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )
    print("Using device", device)

    # CL Benchmark Creation (as usual)
    benchmark = cifar100_distilled_benchmark(args.steps)  # Pass steps as an argument
    model = SimpleCNN64(num_classes=100)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Create the evaluation plugin (as usual)
    evaluation_plugin = EvaluationPlugin(
        accuracy_metrics(experience=True, stream=True), loggers=[InteractiveLogger()]
    )

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=[interactive_logger],
    )

    # Create the strategy using Cumulative
    strategy = Cumulative(
        model=model,                # Ensure your model is capable of handling the increased number of classes
        optimizer=optimizer,        # Optimizer, e.g., Adam or SGD
        criterion=criterion,        # Loss function, e.g., CrossEntropyLoss
        train_mb_size=400,           # Batch size; adjust based on your hardware and model requirements
        train_epochs=30,            # Increased number of epochs to ensure adequate training
        eval_mb_size=2000,            # Evaluation batch size
        device=device,              # Device, e.g., 'cuda' or 'cpu'
        evaluator=eval_plugin # Evaluation plugin or metric, e.g., accuracy
    )

    # STEP 2: TRY TO LOAD THE LAST CHECKPOINT
    # if the checkpoint exists, load it into the newly created strategy
    # the method also loads the experience counter, so we know where to
    # resume training
    fname = "./checkpoint/Cumulative.pkl"  # name of the checkpoint file
    os.makedirs(os.path.dirname(fname), exist_ok=True)  # Ensure the checkpoint directory exists
    #strategy, initial_exp = maybe_load_checkpoint(strategy, fname)


    initial_exp=0
    # STEP 3: USE THE "initial_exp" to resume training
    for train_exp in benchmark.train_stream[initial_exp:]:
        strategy.train(train_exp, num_workers=4, persistent_workers=True)
        strategy.eval(benchmark.test_stream, num_workers=4)

        # STEP 4: SAVE the checkpoint after training on each experience.
        save_checkpoint(strategy, fname)


if __name__ == "__main__":
    if 'ipykernel_launcher' in sys.argv[0]:
        # Running in a Jupyter environment
        class Args:
            cuda = 0
            steps = 10  # Définir une valeur par défaut pour steps
        args = Args()
    else:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="Select zero-indexed cuda device. -1 to use CPU.",
        )
        parser.add_argument(
            "--steps",
            type=int,
            default=10,
            help="Number of steps for the benchmark.",
        )
        # Parse known arguments and ignore the rest
        args, _ = parser.parse_known_args(sys.argv)
    main_with_checkpointing(args)


## naive

In [3]:
import argparse
import sys
import os
import torch
import torchvision
from avalanche.benchmarks.utils import classification_dataset

from avalanche.training.supervised import (
    Cumulative, Naive, ICaRL, LwF, EWC, GenerativeReplay, JointTraining, CWRStar
)
import avalanche.benchmarks.scenarios.dataset_scenario

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from avalanche.training.supervised import Cumulative
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint
from avalanche.training.determinism.rng_manager import RNGManager
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor
from torch.utils.data import TensorDataset
from avalanche.benchmarks import nc_benchmark

# Définition des arguments
class Args:
    def __init__(self):
        self.method = 'DSA'
        self.dataset = 'CIFAR100'
        self.model = 'ConvNet'
        self.ipc = 20
        self.steps = 10  # Assurez-vous que 'steps' est défini
        self.epoch_eval_train = 300
        self.lr_net = 0.01
        self.batch_train = 256
        self.data_path = r'C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data'
        self.save_path = './result'
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

args = Args()

# Initialisation de la variable device
device = args.device
device = 'cpu'
# Charger les données
def load_data():
    if args.method == 'DSA':
        fname = os.path.join(args.data_path, 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    elif args.method == 'DM':
        fname = os.path.join(args.data_path, 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    else:
        raise ValueError(f'Unknown method: {args.method}')

    # Vérifier l'existence du fichier
    if not os.path.exists(fname):
        raise FileNotFoundError(f"File not found: {fname}")
    
    # Charger le fichier
    print(f"Loading file from: {fname}")
    data = torch.load(fname, map_location='cpu')['data']
    
    images_train_all = [data[step][0] for step in range(args.steps)]
    labels_train_all = [data[step][1] for step in range(args.steps)]
    
    return images_train_all, labels_train_all

images_train_all, labels_train_all = load_data()

# Concaténer les données
images_train = torch.cat(images_train_all, dim=0).to(device)
labels_train = torch.cat(labels_train_all, dim=0).to(device)
print('Train data size: ', images_train.shape)

# Charger et préparer les données de test
dst_test = CIFAR100(root='./data', train=False, download=True, transform=ToTensor())

# Préparer les données de test
images_test = []
labels_test = []
for i in range(len(dst_test)):
    img, lab = dst_test[i]
    if lab in set(labels_train.cpu().numpy()):  # Si l'étiquette de test fait partie des classes d'entraînement
        images_test.append(img)
        labels_test.append(lab)

images_test = torch.stack(images_test).to(device)
labels_test = torch.tensor(labels_test, dtype=torch.long, device=device)
print('Test data size: ', images_test.shape)

# Créer les TensorDatasets pour l'entraînement et les tests
train_dataset = TensorDataset(images_train, labels_train)
test_dataset = TensorDataset(images_test, labels_test)

# Définir un modèle simple de CNN
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Fonction pour créer le benchmark
def cifar100_distilled_benchmark(steps):
    return nc_benchmark(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        n_experiences=steps,
        task_labels=False
    )

def main_with_checkpointing(args):
    # STEP 1: SET THE RANDOM SEEDS to guarantee reproducibility
    RNGManager.set_random_seeds(1234)

    # Nothing new here...
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )
    print("Using device", device)

    # CL Benchmark Creation (as usual)
    benchmark = cifar100_distilled_benchmark(args.steps)  # Pass steps as an argument
    model = SimpleCNN(num_classes=100)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Create the evaluation plugin (as usual)
    evaluation_plugin = EvaluationPlugin(
        accuracy_metrics(experience=True, stream=True), loggers=[InteractiveLogger()]
    )
    

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=[interactive_logger],
    )
    

    # Create the strategy (as usual)
    strategy = Naive(
        model=model,                # Ensure your model is capable of handling the increased number of classes
        optimizer=optimizer,        # Optimizer, e.g., Adam or SGD
        criterion=criterion,        # Loss function, e.g., CrossEntropyLoss
        train_mb_size=64,           # Batch size; adjust based on your hardware and model requirements
        train_epochs=30,            # Increased number of epochs to ensure adequate training
        eval_mb_size=64,            # Evaluation batch size
        device=device,              # Device, e.g., 'cuda' or 'cpu'
        evaluator=eval_plugin # Evaluation plugin or metric, e.g., accuracy
    )

    
    # STEP 2: TRY TO LOAD THE LAST CHECKPOINT
    # if the checkpoint exists, load it into the newly created strategy
    # the method also loads the experience counter, so we know where to
    # resume training
    fname = "./checkpoint/Naive1.pkl"  # name of the checkpoint file
    os.makedirs(os.path.dirname(fname), exist_ok=True)  # Ensure the checkpoint directory exists
    #strategy, initial_exp = maybe_load_checkpoint(strategy, fname)

    initial_exp=0
    # STEP 3: USE THE "initial_exp" to resume training
    for train_exp in benchmark.train_stream[initial_exp:]:
        strategy.train(train_exp, num_workers=4, persistent_workers=True)
        strategy.eval(benchmark.test_stream, num_workers=4)

        # STEP 4: SAVE the checkpoint after training on each experience.
        save_checkpoint(strategy, fname)


if __name__ == "__main__":
    if 'ipykernel_launcher' in sys.argv[0]:
        # Running in a Jupyter environment
        class Args:
            cuda = 0
            steps = 10  # Définir une valeur par défaut pour steps
        args = Args()
    else:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="Select zero-indexed cuda device. -1 to use CPU.",
        )
        parser.add_argument(
            "--steps",
            type=int,
            default=10,
            help="Number of steps for the benchmark.",
        )
        # Parse known arguments and ignore the rest
        args, _ = parser.parse_known_args(sys.argv)
    main_with_checkpointing(args)


Loading file from: C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data\cl_res_DSA_CIFAR100_ConvNet_20ipc_10steps_seed0.pt
Train data size:  torch.Size([2000, 3, 32, 32])
Files already downloaded and verified
Test data size:  torch.Size([10000, 3, 32, 32])
Using device cuda:0
-- >> Start of training phase << --
100%|██████████| 4/4 [00:07<00:00,  1.93s/it]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 4.5677
	Loss_MB/train_phase/train_stream/Task000 = 4.4292
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.0850
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.1250
100%|██████████| 4/4 [00:00<00:00, 23.53it/s]
Epoch 1 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 3.9040
	Loss_MB/train_phase/train_stream/Task000 = 3.2662
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.1550
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.0000
100%|██████████| 4/4 [00:00<00:00, 45.10it/s]
Epoch 2 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 2.579

## Replay

In [19]:
import argparse
import sys
import os
import torch
import torchvision
from avalanche.benchmarks.utils import classification_dataset


import avalanche.benchmarks.scenarios.dataset_scenario

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from avalanche.training.supervised import Cumulative
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint
from avalanche.training.determinism.rng_manager import RNGManager
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor
from torch.utils.data import TensorDataset
from avalanche.benchmarks import nc_benchmark


from avalanche.training.supervised.strategy_wrappers import Replay



# Définition des arguments
class Args:
    def __init__(self):
        self.method = 'DSA'
        self.dataset = 'CIFAR100'
        self.model = 'ConvNet'
        self.ipc = 20
        self.steps = 10  # Assurez-vous que 'steps' est défini
        self.epoch_eval_train = 300
        self.lr_net = 0.01
        self.batch_train = 256
        self.data_path = r'C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data'
        self.save_path = './result'
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

args = Args()

# Initialisation de la variable device
device = args.device
device = 'cpu'
# Charger les données
def load_data():
    if args.method == 'DSA':
        fname = os.path.join(args.data_path, 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    elif args.method == 'DM':
        fname = os.path.join(args.data_path, 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    else:
        raise ValueError(f'Unknown method: {args.method}')

    # Vérifier l'existence du fichier
    if not os.path.exists(fname):
        raise FileNotFoundError(f"File not found: {fname}")
    
    # Charger le fichier
    print(f"Loading file from: {fname}")
    data = torch.load(fname, map_location='cpu')['data']
    
    images_train_all = [data[step][0] for step in range(args.steps)]
    labels_train_all = [data[step][1] for step in range(args.steps)]
    
    return images_train_all, labels_train_all

images_train_all, labels_train_all = load_data()

# Concaténer les données
images_train = torch.cat(images_train_all, dim=0).to(device)
labels_train = torch.cat(labels_train_all, dim=0).to(device)
print('Train data size: ', images_train.shape)

# Charger et préparer les données de test
dst_test = CIFAR100(root='./data', train=False, download=True, transform=ToTensor())

# Préparer les données de test
images_test = []
labels_test = []
for i in range(len(dst_test)):
    img, lab = dst_test[i]
    if lab in set(labels_train.cpu().numpy()):  # Si l'étiquette de test fait partie des classes d'entraînement
        images_test.append(img)
        labels_test.append(lab)

images_test = torch.stack(images_test).to(device)
labels_test = torch.tensor(labels_test, dtype=torch.long, device=device)
print('Test data size: ', images_test.shape)

# Créer les TensorDatasets pour l'entraînement et les tests
train_dataset = TensorDataset(images_train, labels_train)
test_dataset = TensorDataset(images_test, labels_test)

# Définir un modèle simple de CNN
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class SimpleCNN1(nn.Module):
    

    def __init__(self, num_classes=100):
        super(SimpleCNN1, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            nn.Conv2d(64, 64, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool2d(1),
            nn.Dropout(p=0.25),
        )
        self.classifier = nn.Sequential(nn.Linear(64, num_classes))


    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Fonction pour créer le benchmark
def cifar100_distilled_benchmark(steps):
    return nc_benchmark(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        n_experiences=steps,
        task_labels=False
    )

def main_with_checkpointing(args):
    # STEP 1: SET THE RANDOM SEEDS to guarantee reproducibility
    RNGManager.set_random_seeds(1234)

    # Nothing new here...
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )
    print("Using device", device)

    # CL Benchmark Creation (as usual)
    benchmark = cifar100_distilled_benchmark(args.steps)  # Pass steps as an argument
    model = SimpleCNN1(num_classes=100)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Create the evaluation plugin (as usual)
    evaluation_plugin = EvaluationPlugin(
        accuracy_metrics(experience=True, stream=True), loggers=[InteractiveLogger()]
    )

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=[interactive_logger],
    )

    # Create the strategy using Replay
    strategy = Replay(
        model=model,                # Ensure your model is capable of handling the increased number of classes
        optimizer=optimizer,        # Optimizer, e.g., Adam or SGD
        criterion=criterion,        # Loss function, e.g., CrossEntropyLoss
        mem_size=2000,               # Size of the replay buffer
        train_mb_size=64,           # Batch size; adjust based on your hardware and model requirements
        train_epochs=30,            # Increased number of epochs to ensure adequate training
        eval_mb_size=64,            # Evaluation batch size
        device=device,              # Device, e.g., 'cuda' or 'cpu'
        evaluator=eval_plugin # Evaluation plugin or metric, e.g., accuracy
    )


    # STEP 2: TRY TO LOAD THE LAST CHECKPOINT
    # if the checkpoint exists, load it into the newly created strategy
    # the method also loads the experience counter, so we know where to
    # resume training
    fname = "./checkpoint/Replay.pkl"  # name of the checkpoint file
    os.makedirs(os.path.dirname(fname), exist_ok=True)  # Ensure the checkpoint directory exists
    #strategy, initial_exp = maybe_load_checkpoint(strategy, fname)


    initial_exp=0
    # STEP 3: USE THE "initial_exp" to resume training
    for train_exp in benchmark.train_stream[initial_exp:]:
        strategy.train(train_exp, num_workers=4, persistent_workers=True)
        strategy.eval(benchmark.test_stream, num_workers=4)

        # STEP 4: SAVE the checkpoint after training on each experience.
        save_checkpoint(strategy, fname)


if __name__ == "__main__":
    if 'ipykernel_launcher' in sys.argv[0]:
        # Running in a Jupyter environment
        class Args:
            cuda = 0
            steps = 10  # Définir une valeur par défaut pour steps
        args = Args()
    else:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="Select zero-indexed cuda device. -1 to use CPU.",
        )
        parser.add_argument(
            "--steps",
            type=int,
            default=10,
            help="Number of steps for the benchmark.",
        )
        # Parse known arguments and ignore the rest
        args, _ = parser.parse_known_args(sys.argv)
    main_with_checkpointing(args)


Loading file from: C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data\cl_res_DSA_CIFAR100_ConvNet_20ipc_10steps_seed0.pt
Train data size:  torch.Size([2000, 3, 32, 32])
Files already downloaded and verified
Test data size:  torch.Size([10000, 3, 32, 32])
Using device cuda:0
-- >> Start of training phase << --
100%|██████████| 4/4 [00:05<00:00,  1.38s/it]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 4.5983
	Loss_MB/train_phase/train_stream/Task000 = 4.5959
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.0050
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.0000
100%|██████████| 4/4 [00:00<00:00, 21.55it/s]
Epoch 1 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 4.5769
	Loss_MB/train_phase/train_stream/Task000 = 4.5570
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.0050
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.0000
100%|██████████| 4/4 [00:00<00:00, 46.36it/s]
Epoch 2 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 4.539

##  GenerativeReplay

In [44]:
# Define a simple VAE generator model (example, adjust to your needs)
class MlpVAE(nn.Module):
    def __init__(self, input_shape, nhid, device):
        super(MlpVAE, self).__init__()
        self.device = device
        self.fc1 = nn.Linear(np.prod(input_shape), nhid)
        self.fc21 = nn.Linear(nhid, nhid)
        self.fc22 = nn.Linear(nhid, nhid)
        self.fc3 = nn.Linear(nhid, np.prod(input_shape))
        self.fc4 = nn.Linear(nhid, np.prod(input_shape))
    
    def encode(self, x):
        h1 = F.relu(self.fc1(x.view(-1, np.prod(x.size()[1:]))))
        return self.fc21(h1), self.fc22(h1)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar




In [45]:
import argparse
import sys
import os
import torch
import torchvision
from avalanche.benchmarks.utils import classification_dataset


import avalanche.benchmarks.scenarios.dataset_scenario

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from avalanche.training.supervised import Cumulative
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint
from avalanche.training.determinism.rng_manager import RNGManager
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor
from torch.utils.data import TensorDataset
from avalanche.benchmarks import nc_benchmark


from avalanche.training.supervised.strategy_wrappers import Replay

from torch.optim import Adam, SGD
from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint
import sys
from avalanche.training.determinism.rng_manager import RNGManager
from avalanche.training.plugins import EvaluationPlugin


# Définition des arguments
class Args:
    def __init__(self):
        self.method = 'DSA'
        self.dataset = 'CIFAR100'
        self.model = 'ConvNet'
        self.ipc = 20
        self.steps = 10  # Assurez-vous que 'steps' est défini
        self.epoch_eval_train = 300
        self.lr_net = 0.01
        self.batch_train = 256
        self.data_path = r'C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data'
        self.save_path = './result'
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

args = Args()

# Initialisation de la variable device
device = args.device
device = 'cpu'
# Charger les données
def load_data():
    if args.method == 'DSA':
        fname = os.path.join(args.data_path, 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    elif args.method == 'DM':
        fname = os.path.join(args.data_path, 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    else:
        raise ValueError(f'Unknown method: {args.method}')

    # Vérifier l'existence du fichier
    if not os.path.exists(fname):
        raise FileNotFoundError(f"File not found: {fname}")
    
    # Charger le fichier
    print(f"Loading file from: {fname}")
    data = torch.load(fname, map_location='cpu')['data']
    
    images_train_all = [data[step][0] for step in range(args.steps)]
    labels_train_all = [data[step][1] for step in range(args.steps)]
    
    return images_train_all, labels_train_all

images_train_all, labels_train_all = load_data()

# Concaténer les données
images_train = torch.cat(images_train_all, dim=0).to(device)
labels_train = torch.cat(labels_train_all, dim=0).to(device)
print('Train data size: ', images_train.shape)

# Charger et préparer les données de test
dst_test = CIFAR100(root='./data', train=False, download=True, transform=ToTensor())

# Préparer les données de test
images_test = []
labels_test = []
for i in range(len(dst_test)):
    img, lab = dst_test[i]
    if lab in set(labels_train.cpu().numpy()):  # Si l'étiquette de test fait partie des classes d'entraînement
        images_test.append(img)
        labels_test.append(lab)

images_test = torch.stack(images_test).to(device)
labels_test = torch.tensor(labels_test, dtype=torch.long, device=device)
print('Test data size: ', images_test.shape)

# Créer les TensorDatasets pour l'entraînement et les tests
train_dataset = TensorDataset(images_train, labels_train)
test_dataset = TensorDataset(images_test, labels_test)

# Définir un modèle simple de CNN
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Fonction pour créer le benchmark
def cifar100_distilled_benchmark(steps):
    return nc_benchmark(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        n_experiences=steps,
        task_labels=False
    )

def main_with_checkpointing(args):
    # STEP 1: SET THE RANDOM SEEDS to guarantee reproducibility
    RNGManager.set_random_seeds(1234)

    # Nothing new here...
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )
    print("Using device", device)

    # CL Benchmark Creation (as usual)
    benchmark = cifar100_distilled_benchmark(args.steps)  # Pass steps as an argument
    model = SimpleCNN(num_classes=100)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Define the VAE generator model and its training strategy
    generator = MlpVAE((3, 32, 32), nhid=256, device=device)
    optimizer_generator = Adam(generator.parameters(), lr=0.01, weight_decay=0.0001)
    
    # Define the VAE training strategy
    from avalanche.training.templates import SupervisedTemplate
    from avalanche.training.plugins import GenerativeReplayPlugin, TrainGeneratorAfterExpPlugin
    
    class VAETraining(SupervisedTemplate):
        # Implementation of VAE training strategy
        pass
  

    vae_training_strategy = VAETraining(
        model=generator,
        optimizer=optimizer_generator,
        criterion=CrossEntropyLoss(),
        train_mb_size=64,
        train_epochs=30,
        eval_mb_size=64,
        device=device,
        plugins=[GenerativeReplayPlugin(replay_size=200)],
    )
    

    
    # Create the Generative Replay strategy
    strategy = GenerativeReplay(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_mb_size=64,
        train_epochs=30,
        eval_mb_size=64,
        device=device,
        generator_strategy=vae_training_strategy,  # The generator strategy
        replay_size=200,  # Size of the replay buffer
        increasing_replay_size=False,  # Whether to increase the replay buffer size over time
        evaluator=EvaluationPlugin(
            accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            loggers=[InteractiveLogger()],
        ),
    )


    # STEP 2: TRY TO LOAD THE LAST CHECKPOINT
    # if the checkpoint exists, load it into the newly created strategy
    # the method also loads the experience counter, so we know where to
    # resume training
    fname = "./checkpoint/GenerativeReplay.pkl"  # name of the checkpoint file
    os.makedirs(os.path.dirname(fname), exist_ok=True)  # Ensure the checkpoint directory exists
    strategy, initial_exp = maybe_load_checkpoint(strategy, fname)

    # STEP 3: USE THE "initial_exp" to resume training
    for train_exp in benchmark.train_stream[initial_exp:]:
        strategy.train(train_exp, num_workers=4, persistent_workers=True)
        strategy.eval(benchmark.test_stream, num_workers=4)

        # STEP 4: SAVE the checkpoint after training on each experience.
        save_checkpoint(strategy, fname)


if __name__ == "__main__":
    if 'ipykernel_launcher' in sys.argv[0]:
        # Running in a Jupyter environment
        class Args:
            cuda = 0
            steps = 10  # Définir une valeur par défaut pour steps
        args = Args()
    else:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="Select zero-indexed cuda device. -1 to use CPU.",
        )
        parser.add_argument(
            "--steps",
            type=int,
            default=10,
            help="Number of steps for the benchmark.",
        )
        # Parse known arguments and ignore the rest
        args, _ = parser.parse_known_args(sys.argv)
    main_with_checkpointing(args)


Loading file from: C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data\cl_res_DSA_CIFAR100_ConvNet_20ipc_10steps_seed0.pt
Train data size:  torch.Size([2000, 3, 32, 32])
Files already downloaded and verified
Test data size:  torch.Size([10000, 3, 32, 32])
Using device cuda:0
-- >> Start of training phase << --
100%|██████████| 4/4 [00:05<00:00,  1.29s/it]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 4.5652
	Loss_MB/train_phase/train_stream/Task000 = 4.3164
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.0700
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.1250
100%|██████████| 4/4 [00:00<00:00, 21.45it/s]
Epoch 1 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 3.8783
	Loss_MB/train_phase/train_stream/Task000 = 3.0401
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.1000
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.1250
100%|██████████| 4/4 [00:00<00:00, 28.53it/s]
Epoch 2 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 2.567

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x3072 and 256x3072)

In [46]:
# EWC :
"""Elastic Weight Consolidation (EWC) strategy.

    See EWC plugin for details.
    This strategy does not use task identities.
    """

'Elastic Weight Consolidation (EWC) strategy.\n\n    See EWC plugin for details.\n    This strategy does not use task identities.\n    '

In [None]:
import argparse
import sys
import os
import torch
import torchvision
from avalanche.benchmarks.utils import classification_dataset


import avalanche.benchmarks.scenarios.dataset_scenario

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from avalanche.training.supervised import Cumulative
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint
from avalanche.training.determinism.rng_manager import RNGManager
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor
from torch.utils.data import TensorDataset
from avalanche.benchmarks import nc_benchmark


from avalanche.training.supervised.strategy_wrappers import Replay



# Définition des arguments
class Args:
    def __init__(self):
        self.method = 'DSA'
        self.dataset = 'CIFAR100'
        self.model = 'ConvNet'
        self.ipc = 20
        self.steps = 10  # Assurez-vous que 'steps' est défini
        self.epoch_eval_train = 300
        self.lr_net = 0.01
        self.batch_train = 256
        self.data_path = r'C:\Users\ahmed\OneDrive\Documents\GitHub\DatasetCondensation\cl_data'
        self.save_path = './result'
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

args = Args()

# Initialisation de la variable device
device = args.device
device = 'cpu'
# Charger les données
def load_data():
    if args.method == 'DSA':
        fname = os.path.join(args.data_path, 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    elif args.method == 'DM':
        fname = os.path.join(args.data_path, 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt' % (args.steps, 0))
    else:
        raise ValueError(f'Unknown method: {args.method}')

    # Vérifier l'existence du fichier
    if not os.path.exists(fname):
        raise FileNotFoundError(f"File not found: {fname}")
    
    # Charger le fichier
    print(f"Loading file from: {fname}")
    data = torch.load(fname, map_location='cpu')['data']
    
    images_train_all = [data[step][0] for step in range(args.steps)]
    labels_train_all = [data[step][1] for step in range(args.steps)]
    
    return images_train_all, labels_train_all

images_train_all, labels_train_all = load_data()

# Concaténer les données
images_train = torch.cat(images_train_all, dim=0).to(device)
labels_train = torch.cat(labels_train_all, dim=0).to(device)
print('Train data size: ', images_train.shape)

# Charger et préparer les données de test
dst_test = CIFAR100(root='./data', train=False, download=True, transform=ToTensor())

# Préparer les données de test
images_test = []
labels_test = []
for i in range(len(dst_test)):
    img, lab = dst_test[i]
    if lab in set(labels_train.cpu().numpy()):  # Si l'étiquette de test fait partie des classes d'entraînement
        images_test.append(img)
        labels_test.append(lab)

images_test = torch.stack(images_test).to(device)
labels_test = torch.tensor(labels_test, dtype=torch.long, device=device)
print('Test data size: ', images_test.shape)

# Créer les TensorDatasets pour l'entraînement et les tests
train_dataset = TensorDataset(images_train, labels_train)
test_dataset = TensorDataset(images_test, labels_test)

# Définir un modèle simple de CNN
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Fonction pour créer le benchmark
def cifar100_distilled_benchmark(steps):
    return nc_benchmark(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        n_experiences=steps,
        task_labels=False
    )

def main_with_checkpointing(args):
    # STEP 1: SET THE RANDOM SEEDS to guarantee reproducibility
    RNGManager.set_random_seeds(1234)

    # Nothing new here...
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )
    print("Using device", device)

    # CL Benchmark Creation (as usual)
    benchmark = cifar100_distilled_benchmark(args.steps)  # Pass steps as an argument
    model = SimpleCNN(num_classes=100)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Create the evaluation plugin (as usual)
    evaluation_plugin = EvaluationPlugin(
        accuracy_metrics(experience=True, stream=True), loggers=[InteractiveLogger()]
    )

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=[interactive_logger],
    )

    # Create the strategy using Replay
    strategy = Replay(
        model=model,                # Ensure your model is capable of handling the increased number of classes
        optimizer=optimizer,        # Optimizer, e.g., Adam or SGD
        criterion=criterion,        # Loss function, e.g., CrossEntropyLoss
        mem_size=200,               # Size of the replay buffer
        train_mb_size=64,           # Batch size; adjust based on your hardware and model requirements
        train_epochs=30,            # Increased number of epochs to ensure adequate training
        eval_mb_size=64,            # Evaluation batch size
        device=device,              # Device, e.g., 'cuda' or 'cpu'
        evaluator=eval_plugin # Evaluation plugin or metric, e.g., accuracy
    )


    # STEP 2: TRY TO LOAD THE LAST CHECKPOINT
    # if the checkpoint exists, load it into the newly created strategy
    # the method also loads the experience counter, so we know where to
    # resume training
    fname = "./checkpoint/Replay.pkl"  # name of the checkpoint file
    os.makedirs(os.path.dirname(fname), exist_ok=True)  # Ensure the checkpoint directory exists
    strategy, initial_exp = maybe_load_checkpoint(strategy, fname)

    initial_exp=0
    # STEP 3: USE THE "initial_exp" to resume training
    for train_exp in benchmark.train_stream[initial_exp:]:
        strategy.train(train_exp, num_workers=4, persistent_workers=True)
        strategy.eval(benchmark.test_stream, num_workers=4)

        # STEP 4: SAVE the checkpoint after training on each experience.
        save_checkpoint(strategy, fname)


if __name__ == "__main__":
    if 'ipykernel_launcher' in sys.argv[0]:
        # Running in a Jupyter environment
        class Args:
            cuda = 0
            steps = 10  # Définir une valeur par défaut pour steps
        args = Args()
    else:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="Select zero-indexed cuda device. -1 to use CPU.",
        )
        parser.add_argument(
            "--steps",
            type=int,
            default=10,
            help="Number of steps for the benchmark.",
        )
        # Parse known arguments and ignore the rest
        args, _ = parser.parse_known_args(sys.argv)
    main_with_checkpointing(args)
