# Project B: Dataset Distillation: A Data-Efficient Learning Framework

## 1.Import

In [7]:
import gc
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from ptflops import get_model_complexity_info
%run utils.ipynb import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, train_only

  from scipy.ndimage.interpolation import rotate as scipyrotate


## 2.Train from scratch

In [None]:
def train(
    dataset:str,
    model:str,
    epoch:int,
    lr:float,
    data_path:str
):
    """Train a freshly initialised network on a dataset and report accuracy and FLOPs.

    Args:
        dataset: dataset name forwarded to ``get_dataset``.
        model: architecture name forwarded to ``get_network``.
        epoch: stored as ``args.epoch_eval_train`` for ``train_only``.
            NOTE: inside this function the parameter shadows the ``epoch``
            helper pulled in from utils; the helper is not used here.
        lr: stored as ``args.lr_net`` for ``train_only``.
        data_path: dataset location forwarded to ``get_dataset``.

    Returns:
        None. Results are printed.
    """
    parser = argparse.ArgumentParser()

    # parse_known_args() tolerates the extra options a notebook kernel leaves in sys.argv.
    args, _ = parser.parse_known_args()

    args.dataset = dataset
    args.model = model
    args.epoch_eval_train = epoch
    args.lr_net = lr
    args.data_path = data_path
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.batch_train = 256

    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

    print('Hyper-parameters: \n', args.__dict__)

    print('Training begins')
    print('-------------------------\nEvaluation\nmodel_train = %s'%(args.model))

    net = get_network(args.model, channel, num_classes, im_size).to(args.device)
    # ptflops prints the per-layer statistics itself; only the FLOPs string is reported below.
    flops, _ = get_model_complexity_info(net, (channel, im_size[0], im_size[1]), as_strings=True, print_per_layer_stat=True)

    # train_only is defined in utils; its third return value is reported as the test accuracy.
    _, _, acc_test = train_only(net, dst_train, testloader, args)
    # Report the model *name*; formatting `net` itself would dump the whole module repr.
    print('Evaluate %s, acc = %.4f FLOPs = %s\n-------------------------'%(args.model, acc_test, flops))


# Baseline run: dataset='MNIST', model='ConvNetD3', epoch=20, lr=0.01, data_path='data'.
train('MNIST', 'ConvNetD3', 20, 0.01, 'data')

Hyper-parameters: 
 {'dataset': 'MNIST', 'model': 'ConvNetD3', 'epoch_eval_train': 20, 'lr_net': 0.01, 'data_path': 'data', 'device': 'cpu', 'batch_train': 256}
Training begins
-------------------------
Evaluation
model_train = ConvNetD3
ConvNet(
  317.71 k, 100.000% Params, 49.25 MMac, 99.306% MACs, 
  (features): Sequential(
    297.22 k, 93.551% Params, 49.23 MMac, 99.265% MACs, 
    (0): Conv2d(1.28 k, 0.403% Params, 1.31 MMac, 2.643% MACs, 1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3))
    (1): GroupNorm(256, 0.081% Params, 262.14 KMac, 0.529% MACs, 128, 128, eps=1e-05, affine=True)
    (2): ReLU(0, 0.000% Params, 131.07 KMac, 0.264% MACs, inplace=True)
    (3): AvgPool2d(0, 0.000% Params, 131.07 KMac, 0.264% MACs, kernel_size=2, stride=2, padding=0)
    (4): Conv2d(147.58 k, 46.453% Params, 37.78 MMac, 76.187% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): GroupNorm(256, 0.081% Params, 65.54 KMac, 0.132% MACs, 128, 128, eps=1e-05, affine=T

KeyboardInterrupt: 

## 3.Define distillation method
Based on `main.py` from https://github.com/VICO-UoE/DatasetCondensation

In [None]:
def distillation(
    dataset:str,
    model:str,
    ipc:int,
    epoch_eval_train:int,
    Iteration:int,
    lr_img:float,
    lr_net:float,
    batch:int,
    init:str,
    data_path:str,
    result:str,
    dis_metric:str
):
    """Learn a small synthetic dataset by matching network gradients on real and synthetic data.

    For every iteration a randomly initialised network is created; the synthetic
    images are updated so that the gradients they induce (per class) match the
    gradients induced by real images according to ``match_loss``, and the network
    is then trained on the current synthetic set between outer-loop steps.
    After the last iteration the synthetic set is evaluated, visualised and saved.

    Args:
        dataset: dataset name forwarded to ``get_dataset``.
        model: architecture name forwarded to ``get_network``.
        ipc: number of synthetic images per class.
        epoch_eval_train: stored on ``args`` for ``evaluate_synset``.
        Iteration: number of distillation iterations (the loop runs ``Iteration + 1`` times).
        lr_img: SGD learning rate for the synthetic images.
        lr_net: SGD learning rate for the network parameters.
        batch: real-image batch size per class; also used as the batch size
            when training the network on the synthetic set.
        init: ``'real'`` initialises the synthetic images from random real images
            of the matching class; any other value keeps the Gaussian-noise initialisation.
        data_path: dataset directory (created if missing).
        result: output directory for visualisations and the saved result file (created if missing).
        dis_metric: stored on ``args``; presumably selects the distance inside
            ``match_loss`` -- verify against utils.

    Returns:
        None. Progress and final accuracies are printed; images and a ``.pt`` file are written to ``result``.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')

    # parse_known_args() rather than parse_args(): inside a notebook kernel sys.argv
    # carries launcher options that an argument-less parser would reject with SystemExit.
    args, _ = parser.parse_known_args()

    args.dataset = dataset
    args.model = model
    args.ipc = ipc
    args.eval_mode = 'S'
    args.num_exp = 1
    args.num_eval = 1
    args.epoch_eval_train = epoch_eval_train
    args.Iteration = Iteration
    args.lr_img = lr_img
    args.lr_net = lr_net
    args.batch_real = batch
    args.batch_size = batch
    args.batch_train = batch  # read below when building the synthetic-data DataLoader
    args.init = init
    args.data_path = data_path
    args.result = result
    args.save_path = result  # every output path below is built from args.save_path
    args.dis_metric = dis_metric

    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # makedirs(exist_ok=True) also handles nested paths and already-existing directories.
    os.makedirs(args.data_path, exist_ok=True)
    os.makedirs(args.save_path, exist_ok=True)

    eval_it_pool = [args.Iteration]  # evaluate only after the final iteration
    print('eval_it_pool: ', eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    accs_all_exps = {key: [] for key in model_eval_pool}  # record performances of all experiments
    data_save = []

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        # ---- organize the real dataset ----
        images_all = []
        labels_all = []
        indices_class = [[] for _ in range(num_classes)]
        # Fetch every sample exactly once so the dataset transform is not applied twice per item.
        for i in range(len(dst_train)):
            sample = dst_train[i]
            images_all.append(torch.unsqueeze(sample[0], dim=0))
            labels_all.append(sample[1])
            indices_class[sample[1]].append(i)
        images_all = torch.cat(images_all, dim=0).to(args.device)
        labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))

        def get_images(c, n): # get random n images from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

        # ---- initialize the synthetic data ----
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        # [0,0,0, 1,1,1, ..., 9,9,9]
        label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')

        # ---- training ----
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
        optimizer_img.zero_grad()
        criterion = nn.CrossEntropyLoss().to(args.device)
        print('%s training begins'%get_time())

        for it in range(args.Iteration+1):

            # ---- evaluate synthetic data ----
            if it in eval_it_pool:
                for model_eval in model_eval_pool:
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                    args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc)
                    print('DC augmentation parameters: \n', args.dc_aug_param)

                    accs = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                        image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
                        _, _, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                    print('Evaluate %d random %s, mean = %.4f std = %.4f \n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                    if it == args.Iteration: # record the final results
                        accs_all_exps[model_eval] += accs

                # ---- visualize and save ----
                save_name = os.path.join(args.save_path, 'vis_%s_%s_%dipc_exp%d_iter%d.png'%(args.dataset, args.model, args.ipc, exp, it))
                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                # Undo the per-channel normalisation, then clamp to the displayable [0, 1] range.
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                image_syn_vis[image_syn_vis<0] = 0.0
                image_syn_vis[image_syn_vis>1] = 1.0
                save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.

            # ---- train synthetic data ----
            net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
            net.train()
            net_parameters = list(net.parameters())
            optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)  # optimizer for the network parameters
            optimizer_net.zero_grad()
            loss_avg = 0
            args.dc_aug_param = None  # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in order to be consistent with DC paper.

            for ol in range(args.outer_loop):

                # ---- freeze the running mu and sigma for BatchNorm layers ----
                # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
                # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
                # This would make the training with BatchNorm layers easier.
                BNSizePC = 16  # real images per class used to estimate the BatchNorm statistics
                BN_flag = any('BatchNorm' in module._get_name() for module in net.modules())
                if BN_flag:
                    img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)
                    net.train() # for updating the mu, sigma of BatchNorm
                    net(img_real) # forward pass only to refresh running mu, sigma
                    for module in net.modules():
                        if 'BatchNorm' in module._get_name():
                            module.eval() # fix mu and sigma of every BatchNorm layer

                # ---- update synthetic data ----
                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
                    lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c

                    # Gradients w.r.t. the network on real data are treated as constants (detached).
                    output_real = net(img_real)
                    loss_real = criterion(output_real, lab_real)
                    gw_real = torch.autograd.grad(loss_real, net_parameters)
                    gw_real = [g.detach().clone() for g in gw_real]

                    # create_graph=True keeps the graph so the matching loss can be back-propagated to image_syn.
                    output_syn = net(img_syn)
                    loss_syn = criterion(output_syn, lab_syn)
                    gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                    loss += match_loss(gw_syn, gw_real, args)

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()
                loss_avg += loss.item()

                # No network update is needed after the last synthetic-data update.
                if ol == args.outer_loop - 1:
                    break

                # ---- update network ----
                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                for il in range(args.inner_loop):
                    epoch('train', trainloader, net, optimizer_net, criterion, args, aug = False)

            loss_avg /= (num_classes*args.outer_loop)

            if it%10 == 0:
                print('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))

            if it == args.Iteration: # only record the final results
                data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%dipc.pt'%(args.dataset, args.model, args.ipc)))

    print('\n==================== Final Results ====================\n')
    for key in model_eval_pool:
        accs = accs_all_exps[key]
        print('Run %d experiments, train on %s, evaluate %d random %s, mean  = %.2f%%  std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))




## 4.Dataset Distillation using Gradient Matching

In [None]:
# Positional order: dataset, model, ipc, epoch_eval_train, Iteration, lr_img, lr_net, batch, init, data_path, result, dis_metric.
# Both runs initialise the synthetic images from real samples (init='real').
distillation('MNIST', 'ConvNetD3', 10, 20, 1, 0.1, 0.01, 256, 'real', 'data', 'result', 'ours')
distillation('MHIST', 'ConvNetD7', 50, 20, 1, 0.1, 0.01, 128, 'real', 'data', 'result', 'ours')

## 5.Dataset Distillation using Gradient Matching (With Gaussian Noise)

In [None]:
# Same settings as the runs with init='real', but init='noise' keeps the Gaussian-noise initialisation of the synthetic images.
distillation('MNIST', 'ConvNetD3', 10, 20, 1, 0.1, 0.01, 256, 'noise', 'data', 'result', 'ours')
distillation('MHIST', 'ConvNetD7', 50, 20, 1, 0.1, 0.01, 128, 'noise', 'data', 'result', 'ours')



## 6.Train from scratch using distilled dataset

In [None]:
# Positional order: dataset, model, epoch, lr, data_path.
# NOTE(review): assumes get_dataset (utils) recognises the '*_distilled' dataset names -- confirm.
train('MNIST_distilled', 'ConvNetD3', 20, 0.01, 'data')
train('MHIST_distilled', 'ConvNetD7', 20, 0.01, 'data')

## 7.Cross-architecture Generalization

In [None]:
def cross(
    dataset:str,
    model:str,
    ipc:int,
    eval_mode:str,
    num_eval:int,
    epoch_eval_train:int,
    Iteration:int,
    lr_img:float,
    lr_net:float,
    batch:int,
    init:str,
    data_path:str,
    result:str,
    dis_metric:str,
    method:str='DC',
    dsa_strategy:str='None'
):
    """Distil a synthetic dataset and evaluate it on a pool of architectures.

    The distillation loop is the same gradient-matching procedure as in
    ``distillation``; the difference is that the evaluation pool comes from
    ``get_eval_pool(eval_mode, ...)`` and each architecture is evaluated
    ``num_eval`` times.

    Args:
        dataset: dataset name forwarded to ``get_dataset``.
        model: architecture used while learning the synthetic data.
        ipc: number of synthetic images per class.
        eval_mode: forwarded to ``get_eval_pool``; ``'S'``/``'SS'`` evaluate every
            500 iterations, any other value evaluates only at the last iteration.
        num_eval: number of randomly initialised networks trained per evaluated architecture.
        epoch_eval_train: stored on ``args``.
            NOTE(review): the evaluation block overwrites it with 1000 (augmentation
            active) or 300 (no augmentation), so the caller's value is not
            what ``evaluate_synset`` sees -- confirm this is intended.
        Iteration: number of distillation iterations (the loop runs ``Iteration + 1`` times).
        lr_img: SGD learning rate for the synthetic images.
        lr_net: SGD learning rate for the network parameters.
        batch: real-image batch size per class; also used as the batch size
            when training the network on the synthetic set.
        init: ``'real'`` initialises the synthetic images from random real images;
            any other value keeps the Gaussian-noise initialisation.
        data_path: dataset directory (created if missing).
        result: output directory (created if missing).
        dis_metric: stored on ``args``; presumably selects the distance inside
            ``match_loss`` -- verify against utils.
        method: ``'DSA'`` enables ``DiffAugment`` on real and synthetic batches;
            any other value uses the DC augmentation parameters from ``get_daparam``.
            Also embedded in the output file names.
        dsa_strategy: augmentation strategy string handed to ``DiffAugment`` when
            ``method == 'DSA'``.

    Returns:
        None. Progress and final accuracies are printed; images and a ``.pt`` file are written to ``result``.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')

    # parse_known_args() rather than parse_args(): inside a notebook kernel sys.argv
    # carries launcher options that an argument-less parser would reject with SystemExit.
    args, _ = parser.parse_known_args()

    args.method = method  # read below for args.dsa and for the output file names
    args.dataset = dataset
    args.model = model
    args.ipc = ipc
    args.eval_mode = eval_mode
    args.num_exp = 1
    args.num_eval = num_eval
    args.epoch_eval_train = epoch_eval_train
    args.Iteration = Iteration
    args.lr_img = lr_img
    args.lr_net = lr_net
    args.batch_real = batch
    args.batch_size = batch
    args.batch_train = batch  # read below when building the synthetic-data DataLoader
    args.init = init
    args.data_path = data_path
    args.result = result
    args.save_path = result  # every output path below is built from args.save_path
    args.dis_metric = dis_metric

    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa_strategy = dsa_strategy
    args.dsa = args.method == 'DSA'

    # makedirs(exist_ok=True) also handles nested paths and already-existing directories.
    os.makedirs(args.data_path, exist_ok=True)
    os.makedirs(args.save_path, exist_ok=True)

    # The list of iterations when we evaluate models and record results.
    if args.eval_mode in ('S', 'SS'):
        eval_it_pool = np.arange(0, args.Iteration+1, 500).tolist()
    else:
        eval_it_pool = [args.Iteration]
    print('eval_it_pool: ', eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    accs_all_exps = {key: [] for key in model_eval_pool}  # record performances of all experiments
    data_save = []

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        # ---- organize the real dataset ----
        images_all = []
        labels_all = []
        indices_class = [[] for _ in range(num_classes)]
        # Fetch every sample exactly once so the dataset transform is not applied twice per item.
        for i in range(len(dst_train)):
            sample = dst_train[i]
            images_all.append(torch.unsqueeze(sample[0], dim=0))
            labels_all.append(sample[1])
            indices_class[sample[1]].append(i)
        images_all = torch.cat(images_all, dim=0).to(args.device)
        labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))

        def get_images(c, n): # get random n images from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

        # ---- initialize the synthetic data ----
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        # [0,0,0, 1,1,1, ..., 9,9,9]
        label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')

        # ---- training ----
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
        optimizer_img.zero_grad()
        criterion = nn.CrossEntropyLoss().to(args.device)
        print('%s training begins'%get_time())

        for it in range(args.Iteration+1):

            # ---- evaluate synthetic data ----
            if it in eval_it_pool:
                for model_eval in model_eval_pool:
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                    if args.dsa:
                        args.epoch_eval_train = 1000
                        args.dc_aug_param = None
                        print('DSA augmentation strategy: \n', args.dsa_strategy)
                        print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
                    else:
                        args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.
                        print('DC augmentation parameters: \n', args.dc_aug_param)

                    # NOTE(review): this overrides the caller-supplied epoch_eval_train.
                    if args.dsa or args.dc_aug_param['strategy'] != 'none':
                        args.epoch_eval_train = 1000  # Training with data augmentation needs more epochs.
                    else:
                        args.epoch_eval_train = 300

                    accs = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                        image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
                        _, _, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                    print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                    if it == args.Iteration: # record the final results
                        accs_all_exps[model_eval] += accs

                # ---- visualize and save ----
                save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                # Undo the per-channel normalisation, then clamp to the displayable [0, 1] range.
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                image_syn_vis[image_syn_vis<0] = 0.0
                image_syn_vis[image_syn_vis>1] = 1.0
                save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.

            # ---- train synthetic data ----
            net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
            net.train()
            net_parameters = list(net.parameters())
            optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)  # optimizer for the network parameters
            optimizer_net.zero_grad()
            loss_avg = 0
            args.dc_aug_param = None  # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in order to be consistent with DC paper.

            for ol in range(args.outer_loop):

                # ---- freeze the running mu and sigma for BatchNorm layers ----
                # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
                # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
                # This would make the training with BatchNorm layers easier.
                BNSizePC = 16  # real images per class used to estimate the BatchNorm statistics
                BN_flag = any('BatchNorm' in module._get_name() for module in net.modules())
                if BN_flag:
                    img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)
                    net.train() # for updating the mu, sigma of BatchNorm
                    net(img_real) # forward pass only to refresh running mu, sigma
                    for module in net.modules():
                        if 'BatchNorm' in module._get_name():
                            module.eval() # fix mu and sigma of every BatchNorm layer

                # ---- update synthetic data ----
                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
                    lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c

                    if args.dsa:
                        # The same seed is passed for both batches so real and synthetic data get the same augmentation draw.
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    # Gradients w.r.t. the network on real data are treated as constants (detached).
                    output_real = net(img_real)
                    loss_real = criterion(output_real, lab_real)
                    gw_real = torch.autograd.grad(loss_real, net_parameters)
                    gw_real = [g.detach().clone() for g in gw_real]

                    # create_graph=True keeps the graph so the matching loss can be back-propagated to image_syn.
                    output_syn = net(img_syn)
                    loss_syn = criterion(output_syn, lab_syn)
                    gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                    loss += match_loss(gw_syn, gw_real, args)

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()
                loss_avg += loss.item()

                # No network update is needed after the last synthetic-data update.
                if ol == args.outer_loop - 1:
                    break

                # ---- update network ----
                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                for il in range(args.inner_loop):
                    epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)

            loss_avg /= (num_classes*args.outer_loop)

            if it%10 == 0:
                print('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))

            if it == args.Iteration: # only record the final results
                data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))

    print('\n==================== Final Results ====================\n')
    for key in model_eval_pool:
        accs = accs_all_exps[key]
        print('Run %d experiments, train on %s, evaluate %d random %s, mean  = %.2f%%  std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))

# Positional order: dataset, model, ipc, eval_mode, num_eval, epoch_eval_train, Iteration, lr_img, lr_net, batch, init, data_path, result, dis_metric.
# eval_mode='M' with num_eval=2: each architecture returned by get_eval_pool is evaluated twice.
cross('MNIST', 'ConvNetD3', 10,'M', 2, 20, 1, 0.1, 0.01, 256, 'noise', 'data', 'result', 'ours')
cross('MHIST', 'ConvNetD7', 50,'M', 2, 20, 1, 0.1, 0.01, 128, 'noise', 'data', 'result', 'ours')


## 8.Application in continual learning
Based on `CL_DM.py` from https://github.com/VICO-UoE/DatasetCondensation

In [None]:
def CL(
    method:str='random',
    dataset:str='CIFAR100',
    model:str='ConvNet',
    ipc:int=20,
    steps:int=5,
    num_eval:int=3,
    epoch_eval_train:int=1000,
    lr_net:float=0.01,
    batch:int=256,
    data_path:str='data'
):
    """Class-incremental (continual) learning evaluation with a per-class memory of ``ipc`` images.

    The classes are split into ``steps`` equally sized groups in a seed-dependent
    random order. At each step the memory of all groups seen so far is used to
    train ``num_eval`` fresh networks, which are tested on the test samples of the
    seen classes. The procedure is repeated for 5 seeds.

    Args:
        method: how the memory is obtained. ``'random'`` samples real images;
            ``'herding'``, ``'DSA'`` and ``'DM'`` load a pre-computed file from
            ``<data_path>/metasets/cl_data`` (the file names are fixed to
            CIFAR100 / ConvNet / 20ipc).
        dataset: dataset name forwarded to ``get_dataset``.
        model: architecture name forwarded to ``get_network``.
        ipc: images per class kept in memory (used by ``'random'``).
        steps: number of incremental steps; ``num_classes // steps`` classes are added per step.
        num_eval: networks trained per step and seed.
        epoch_eval_train: stored on ``args`` for ``evaluate_synset``.
        lr_net: stored on ``args`` for ``evaluate_synset``.
        batch: stored as ``args.batch_train``.
        data_path: dataset directory (created if missing).

    Raises:
        ValueError: if ``method`` is not one of the supported names.

    Returns:
        None. Per-step mean/std accuracies are printed.
    """
    # File-name templates of the pre-computed memories, filled with (steps, seed).
    buffer_files = {
        'herding': 'cl_herding_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt',
        'DSA': 'cl_res_DSA_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt',
        'DM': 'cl_DM_CIFAR100_ConvNet_20ipc_%dsteps_seed%d.pt',
    }
    # Fail early with a normal exception; calling exit() inside a notebook acts on the kernel itself.
    if method != 'random' and method not in buffer_files:
        raise ValueError('unknown method: %s'%method)

    parser = argparse.ArgumentParser(description='Parameter Processing')

    # parse_known_args() rather than parse_args(): inside a notebook kernel sys.argv
    # carries launcher options that an argument-less parser would reject with SystemExit.
    args, _ = parser.parse_known_args()

    args.method = method
    args.dataset = dataset
    args.model = model
    args.ipc = ipc
    args.steps = steps
    args.num_eval = num_eval
    args.epoch_eval_train = epoch_eval_train
    args.lr_net = lr_net
    args.batch_train = batch
    args.data_path = data_path

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = True # augment images for all methods
    args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate' # for CIFAR10/100

    os.makedirs(args.data_path, exist_ok=True)

    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

    # ---- all training data ----
    images_all = []
    labels_all = []
    indices_class = [[] for _ in range(num_classes)]
    # Fetch every sample exactly once so the dataset transform is not applied twice per item.
    for i in range(len(dst_train)):
        sample = dst_train[i]
        images_all.append(torch.unsqueeze(sample[0], dim=0))
        labels_all.append(sample[1])
        indices_class[sample[1]].append(i)
    images_all = torch.cat(images_all, dim=0).to(args.device)
    labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

    def get_images(c, n):  # get random n images from class c
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle]

    print()
    print('==================================================================================')
    print('method: ', args.method)
    num_seeds = 5
    # One row per step; one column per (seed, evaluation) pair.
    results = np.zeros((args.steps, num_seeds*args.num_eval))

    for seed_cl in range(num_seeds):
        num_classes_step = num_classes // args.steps
        np.random.seed(seed_cl)
        class_order = np.random.permutation(num_classes).tolist()
        print('=========================================')
        print('seed: ', seed_cl)
        print('class_order: ', class_order)
        print('augmentation strategy: \n', args.dsa_strategy)
        print('augmentation parameters: \n', args.dsa_param.__dict__)

        if args.method == 'random':
            images_train_all = []
            labels_train_all = []
            for step in range(args.steps):
                classes_current = class_order[step * num_classes_step: (step + 1) * num_classes_step]
                images_train_all += [torch.cat([get_images(c, args.ipc) for c in classes_current], dim=0)]
                labels_train_all += [torch.tensor([c for c in classes_current for i in range(args.ipc)], dtype=torch.long, device=args.device)]
        else:
            fname = os.path.join(args.data_path, 'metasets', 'cl_data', buffer_files[args.method]%(args.steps, seed_cl))
            # NOTE: torch.load unpickles the file -- only load memories from a trusted source.
            data = torch.load(fname, map_location='cpu')['data']
            images_train_all = [data[step][0] for step in range(args.steps)]
            labels_train_all = [data[step][1] for step in range(args.steps)]
            print('use data: ', fname)

        for step in range(args.steps):
            print('\n-----------------------------\nmethod %s seed %d step %d ' % (args.method, seed_cl, step))

            classes_seen = class_order[: (step+1)*num_classes_step]
            print('classes_seen: ', classes_seen)

            # ---- train data: memory of every step up to and including this one ----
            images_train = torch.cat(images_train_all[:step+1], dim=0).to(args.device)
            labels_train = torch.cat(labels_train_all[:step+1], dim=0).to(args.device)
            print('train data size: ', images_train.shape)

            # ---- test data: only samples whose class has been seen so far ----
            seen_lookup = set(classes_seen)  # O(1) membership instead of scanning the list per sample
            images_test = []
            labels_test = []
            for i in range(len(dst_test)):
                sample = dst_test[i]
                if int(sample[1]) in seen_lookup:
                    images_test.append(torch.unsqueeze(sample[0], dim=0))
                    labels_test.append(sample[1])

            images_test = torch.cat(images_test, dim=0).to(args.device)
            labels_test = torch.tensor(labels_test, dtype=torch.long, device=args.device)
            dst_test_current = TensorDataset(images_test, labels_test)
            testloader = torch.utils.data.DataLoader(dst_test_current, batch_size=256, shuffle=False, num_workers=0)

            print('test set size: ', images_test.shape)

            # ---- train model on the newest memory ----
            accs = []
            for ep_eval in range(args.num_eval):
                net_eval = get_network(args.model, channel, num_classes, im_size)
                net_eval = net_eval.to(args.device)
                img_syn_eval = copy.deepcopy(images_train.detach())
                lab_syn_eval = copy.deepcopy(labels_train.detach())

                _, _, acc_test = evaluate_synset(ep_eval, net_eval, img_syn_eval, lab_syn_eval, testloader, args)
                del net_eval, img_syn_eval, lab_syn_eval
                gc.collect()  # to reduce memory cost
                accs.append(acc_test)
                results[step, seed_cl*args.num_eval + ep_eval] = acc_test
            print('Evaluate %d random %s, mean = %.4f std = %.4f' % (len(accs), args.model, np.mean(accs), np.std(accs)))

    # LaTeX table row: "& mean$\pm$std" per step; the backslash is escaped explicitly.
    results_str = ''
    for step in range(args.steps):
        results_str += '& %.1f$\\pm$%.1f  ' % (np.mean(results[step]) * 100, np.std(results[step]) * 100)
    print('\n\n')
    print('%d step learning %s perforamnce:'%(args.steps, args.method))
    print(results_str)
    print('Done')


# 5-step continual learning on MNIST with a randomly sampled memory; the remaining arguments keep their defaults.
CL(dataset='MNIST', model='ConvNet', steps=5, method='random')
