In [1]:
import os
import time
import copy
import tqdm
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import warnings
warnings.filterwarnings("ignore") 

def _load_real_dataset(dst_train, num_classes, device):
    """Materialize the real training set on ``device`` and group sample indices by label.

    Args:
        dst_train: indexable dataset whose items expose the image at ``[0]``
            and the integer class label at ``[1]``.
        num_classes: number of classes; labels are used directly as list indices.
        device: torch device the stacked images are moved to.

    Returns:
        ``(images_all, indices_class)``: ``images_all`` stacks every image along a
        new leading dimension, and ``indices_class[c]`` lists the rows of
        ``images_all`` whose label is ``c``.
    """
    indices_class = [[] for _ in range(num_classes)]
    images = []
    # Read each dataset item once, so any per-item transform runs a single time.
    for i in range(len(dst_train)):
        item = dst_train[i]
        images.append(item[0])
        indices_class[item[1]].append(i)
    images_all = torch.stack(images, dim=0).to(device)
    return images_all, indices_class


def _sample_class_images(images_all, indices_class, c, n):
    """Return up to ``n`` distinct real images of class ``c`` in random order."""
    idx_shuffle = np.random.permutation(indices_class[c])[:n]
    return images_all[idx_shuffle]


def _class_batch_pair(images_all, indices_class, image_syn, c, args):
    """Return the ``(real, synthetic)`` image batches used for class ``c`` in one iteration.

    The real batch holds up to ``args.batch_real`` images; the synthetic batch is
    the ``args.ipc`` rows of ``image_syn`` belonging to class ``c``. When
    ``args.dsa`` is set, both go through ``DiffAugment`` with the same seed,
    presumably so real and synthetic images receive the same random augmentation.
    """
    img_real = _sample_class_images(images_all, indices_class, c, args.batch_real)
    img_syn = image_syn[c * args.ipc:(c + 1) * args.ipc]
    if args.dsa:
        seed = int(time.time() * 1000) % 100000
        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)
    return img_real, img_syn


def _dm_loss(embed, images_all, indices_class, image_syn, num_classes, args):
    """Sum over classes of the squared L2 distance between mean real and mean synthetic embeddings.

    Real embeddings are detached, so gradients flow only into ``image_syn``.
    """
    if 'BN' not in args.model:
        # Models without 'BN' in their name: embed each class separately.
        loss = torch.tensor(0.0, device=args.device)
        for c in range(num_classes):
            img_real, img_syn = _class_batch_pair(images_all, indices_class, image_syn, c, args)
            output_real = embed(img_real).detach()
            output_syn = embed(img_syn)
            loss = loss + torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0)) ** 2)
        return loss

    # BN models (e.g. ConvNetBN): embed all classes in a single batch, presumably so
    # BatchNorm (the net is in train mode) normalizes with statistics over every class.
    pairs = [_class_batch_pair(images_all, indices_class, image_syn, c, args) for c in range(num_classes)]
    real_counts = [real.shape[0] for real, _ in pairs]
    syn_counts = [syn.shape[0] for _, syn in pairs]
    output_real = embed(torch.cat([real for real, _ in pairs], dim=0)).detach()
    output_syn = embed(torch.cat([syn for _, syn in pairs], dim=0))
    output_real = output_real.reshape(output_real.shape[0], -1)
    output_syn = output_syn.reshape(output_syn.shape[0], -1)
    # Split by the actual per-class counts instead of reshaping with batch_real,
    # which would fail whenever a class has fewer than batch_real images.
    mean_real = torch.stack([chunk.mean(dim=0) for chunk in output_real.split(real_counts)])
    mean_syn = torch.stack([chunk.mean(dim=0) for chunk in output_syn.split(syn_counts)])
    return torch.sum((mean_real - mean_syn) ** 2)


def _save_visualization(image_syn, save_name, mean, std, nrow):
    """Undo per-channel ``mean``/``std`` normalization, clip to [0, 1] and save the images as a grid."""
    # clone(): when image_syn already lives on the CPU, .cpu() returns the same
    # tensor, and the in-place edits below would corrupt the synthetic data.
    image_syn_vis = image_syn.detach().cpu().clone()
    for ch in range(image_syn_vis.shape[1]):
        image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
    image_syn_vis.clamp_(0.0, 1.0)
    save_image(image_syn_vis, save_name, nrow=nrow)  # Trying normalize = True/False may get better visual effects.


def main(args):
    """Learn ``args.ipc`` synthetic images per class by matching mean embeddings of real data.

    For each of ``args.num_exp`` experiments, the function:

    1. Loads the real training set onto ``args.device``.
    2. Initializes the synthetic set from noise or real images (``args.init``).
    3. For ``args.Iteration + 1`` steps, samples a freshly initialized frozen network
       and takes one SGD step on the synthetic images. The step reduces the distance
       between per-class mean embeddings of real and synthetic images.
    4. At the last step, appends the synthetic set to a ``.pt`` file in
       ``args.save_path`` and evaluates it with ``evaluate_synset`` on
       ``args.num_eval`` (at least one) random networks.
    5. Saves a PNG visualization at every iteration in the evaluation pool.

    Side effects: creates ``args.data_path`` / ``args.save_path`` if missing, writes
    files, prints progress. Returns ``None``.
    """
    os.makedirs(args.data_path, exist_ok=True)
    os.makedirs(args.save_path, exist_ok=True)

    # Iterations at which the synthetic images are visualized. The pool always ends
    # at args.Iteration, so the final evaluation runs on the final synthetic set
    # even when Iteration is not a multiple of 200.
    if args.eval_mode in ('S', 'SS'):
        eval_it_pool = np.arange(0, args.Iteration + 1, 200).tolist()
    else:
        eval_it_pool = [args.Iteration]
    if eval_it_pool[-1] != args.Iteration:
        eval_it_pool.append(args.Iteration)
    print('eval_it_pool: ', eval_it_pool)
    eval_it_set = set(eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

    data_save = []

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)

        # Organize the real dataset.
        images_all, indices_class = _load_real_dataset(dst_train, num_classes, args.device)
        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))
        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

        # Initialize the synthetic data, laid out class by class (ipc rows per class).
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)  # [0,0,0, 1,1,1, ..., 9,9,9]

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            with torch.no_grad():
                for c in range(num_classes):
                    image_syn[c*args.ipc:(c+1)*args.ipc] = _sample_class_images(images_all, indices_class, c, args.ipc)
        else:
            print('initialize synthetic data from random noise')

        # Training: only the synthetic images are optimized.
        optimizer_img = torch.optim.SGD([image_syn], lr=args.lr_img, momentum=0.5)
        optimizer_img.zero_grad()
        print('%s training begins'%get_time())

        for it in range(args.Iteration + 1):
            # Sample a freshly initialized network with frozen weights.
            net = get_network(args.model, channel, num_classes, im_size).to(args.device)
            net.train()
            net.requires_grad_(False)
            # Unwrap nn.DataParallel when present instead of inferring it from the GPU count.
            embed = getattr(net, 'module', net).embed

            loss = _dm_loss(embed, images_all, indices_class, image_syn, num_classes, args)
            optimizer_img.zero_grad()
            loss.backward()
            optimizer_img.step()
            loss_avg = loss.item() / num_classes

            if it % 10 == 0:
                print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))

            if it == args.Iteration:  # only record and evaluate the final results
                data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                torch.save({'data': data_save}, os.path.join(args.save_path, args.init+'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))

                accs = []
                for _ in range(max(1, args.num_eval)):
                    net_eval = get_network(args.model, channel, num_classes, im_size).to(args.device)  # get a random model
                    image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                    _, _, acc_test = evaluate_synset(net_eval, image_syn_eval, label_syn_eval, testloader, args)
                    accs.append(acc_test)
                print('\n==================== Final Results ====================\n')
                print('After {} iterations, the model test accuracy on synthetic data is {}%'.format(it, np.mean(accs)*100))
                if len(accs) > 1:
                    print('Evaluated {} random models, std = {}%'.format(len(accs), np.std(accs)*100))

            if it in eval_it_set:
                # Visualize and save.
                save_name = os.path.join(args.save_path, args.init+'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
                _save_visualization(image_syn, save_name, mean, std, nrow=args.ipc)

class arguments():
    """Default hyper-parameters for one call of ``main``.

    A notebook-friendly stand-in for an ``argparse`` namespace. Create an
    instance, override attributes as needed (e.g. ``args.init = 'noise'``),
    then call ``main(args)``. Attributes are assigned in a fixed order, so
    printing ``args.__dict__`` always lists them the same way.
    """

    def __init__(self):
        settings = {
            'method': 'DM',
            'dataset': 'CIFAR10',
            'model': 'ConvNet',
            # Number of synthetic images per class.
            'ipc': 10,
            # S: the same as the training model, M: multiple architectures, W: net width,
            # D: net depth, A: activation function, P: pooling layer, N: normalization layer.
            # In main(), 'S' and 'SS' save snapshots every 200 iterations; other modes only at the end.
            'eval_mode': 'S',
            # Number of independent experiments.
            'num_exp': 1,
            # Number of randomly initialized models used for evaluation.
            'num_eval': 1,
            # NOTE(review): presumably the number of epochs evaluate_synset trains the
            # evaluation network for; not read by main() — confirm in utils.
            'epoch_eval_train': 100,
            # Number of synthetic-image update iterations.
            'Iteration': 2000,
            # SGD learning rate for the synthetic images.
            'lr_img': 1.0,
            # NOTE(review): presumably the learning rate of the evaluation network; confirm in utils.
            'lr_net': 0.01,
            # Real images sampled per class at each iteration.
            'batch_real': 256,
            'batch_train': 256,
            # 'noise' / 'real': initialize synthetic images from random noise or from randomly sampled real images.
            'init': 'real',
            # Augmentation strategy string handed to DiffAugment.
            'dsa_strategy': 'none',
            'data_path': 'CIFAR10data',
            'save_path': 'CIFAR10result',
            # NOTE(review): dis_metric, outer_loop and inner_loop are not read by main();
            # presumably leftovers from a gradient-matching setup — confirm before removing.
            'dis_metric': 'ours',
            'outer_loop': 10,
            'inner_loop': 50,
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            'dsa_param': ParamDiffAug(),
            # NOTE(review): dsa stays True even though dsa_strategy is 'none'; presumably
            # DiffAugment then returns images unchanged — confirm in utils.
            'dsa': True,
        }
        for name, value in settings.items():
            setattr(self, name, value)
        
# Build the default configuration (synthetic images initialized from real images) and run it.
args = arguments()
main(args=args)




eval_it_pool:  [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
Files already downloaded and verified
Files already downloaded and verified

 
Hyper-parameters: 
 {'method': 'DM', 'dataset': 'CIFAR10', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 1, 'epoch_eval_train': 100, 'Iteration': 2000, 'lr_img': 1.0, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'real', 'dsa_strategy': 'color_crop_cutout_flip_scale_rotate', 'data_path': 'CIFAR10data', 'save_path': 'CIFAR10result', 'dis_metric': 'ours', 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x00000206A5E64130>, 'dsa': True}
class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real images
class c = 9: 5000 real images
real images c

[2022-12-01 21:00:48] iter = 01390, loss = 5.1676
[2022-12-01 21:00:53] iter = 01400, loss = 4.9596
[2022-12-01 21:00:57] iter = 01410, loss = 4.8319
[2022-12-01 21:01:00] iter = 01420, loss = 5.1900
[2022-12-01 21:01:04] iter = 01430, loss = 4.8074
[2022-12-01 21:01:09] iter = 01440, loss = 4.9288
[2022-12-01 21:01:13] iter = 01450, loss = 4.7676
[2022-12-01 21:01:16] iter = 01460, loss = 4.8914
[2022-12-01 21:01:20] iter = 01470, loss = 4.8280
[2022-12-01 21:01:25] iter = 01480, loss = 4.5113
[2022-12-01 21:01:28] iter = 01490, loss = 4.7616
[2022-12-01 21:01:32] iter = 01500, loss = 4.6956
[2022-12-01 21:01:36] iter = 01510, loss = 4.9331
[2022-12-01 21:01:39] iter = 01520, loss = 4.7885
[2022-12-01 21:01:44] iter = 01530, loss = 4.4702
[2022-12-01 21:01:48] iter = 01540, loss = 4.7576
[2022-12-01 21:01:52] iter = 01550, loss = 4.7279
[2022-12-01 21:01:55] iter = 01560, loss = 4.8036
[2022-12-01 21:02:00] iter = 01570, loss = 4.7272
[2022-12-01 21:02:04] iter = 01580, loss = 4.6700


In [2]:
# Reuse the same args object, switching only the initialization to random noise.
args.init = 'noise'
main(args=args)

eval_it_pool:  [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
Files already downloaded and verified
Files already downloaded and verified

 
Hyper-parameters: 
 {'method': 'DM', 'dataset': 'CIFAR10', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 1, 'epoch_eval_train': 100, 'Iteration': 2000, 'lr_img': 1.0, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'noise', 'dsa_strategy': 'none', 'data_path': 'CIFAR10data', 'save_path': 'CIFAR10result', 'dis_metric': 'ours', 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x000001DCDBA27700>, 'dsa': True}
class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real images
class c = 9: 5000 real images
real images channel 0, mean = -0.0000, std 

[2022-12-01 22:48:32] iter = 01390, loss = 5.5567
[2022-12-01 22:48:33] iter = 01400, loss = 5.4769
[2022-12-01 22:48:35] iter = 01410, loss = 5.2686
[2022-12-01 22:48:36] iter = 01420, loss = 5.4101
[2022-12-01 22:48:38] iter = 01430, loss = 5.2225
[2022-12-01 22:48:39] iter = 01440, loss = 5.5644
[2022-12-01 22:48:41] iter = 01450, loss = 5.7927
[2022-12-01 22:48:42] iter = 01460, loss = 5.2755
[2022-12-01 22:48:44] iter = 01470, loss = 5.3785
[2022-12-01 22:48:45] iter = 01480, loss = 5.1901
[2022-12-01 22:48:47] iter = 01490, loss = 5.1421
[2022-12-01 22:48:49] iter = 01500, loss = 5.4532
[2022-12-01 22:48:50] iter = 01510, loss = 5.0711
[2022-12-01 22:48:52] iter = 01520, loss = 5.5171
[2022-12-01 22:48:53] iter = 01530, loss = 5.2506
[2022-12-01 22:48:55] iter = 01540, loss = 5.1772
[2022-12-01 22:48:56] iter = 01550, loss = 5.4313
[2022-12-01 22:48:58] iter = 01560, loss = 5.3931
[2022-12-01 22:48:59] iter = 01570, loss = 5.4736
[2022-12-01 22:49:01] iter = 01580, loss = 5.3460
