In [1]:
import os
import time
import copy
import argparse
import numpy as np
import torch
from fvcore.nn import FlopCountAnalysis
import tqdm
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs):
    model = model.to(device)
    criterion = criterion.to(device)
    model.train()
    sythetic_img_all = torch.ones((100,1,28,28)).to(device)
    flops_syn = FlopCountAnalysis(model, sythetic_img_all)
    print('The FLOPs for 100 synthetic images is {}'.format(flops_syn.total()))
    print('The FLOPs for 60000 real training images is {}'.format(flops_syn.total()*600))
    print('The FLOPs for 10000 real validation images is {}'.format(flops_syn.total()*100))
    del sythetic_img_all
    for epoch in range(epochs):
        print('Start Epoch #{}'.format(epoch+1))
        with tqdm.tqdm(total=len(train_loader)) as pbar:
            loss_avg, acc_avg, num_exp = 0, 0, 0
            for i_batch, datum in enumerate(train_loader):
                img = datum[0].float().to(device)
        #         if aug:
        #             if args.dsa:
        #                 img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param)
        #             else:
        #                 img = augment(img, args.dc_aug_param, device=args.device)
                lab = datum[1].long().to(device)
                n_b = lab.shape[0]
                output = model(img)
                loss = criterion(output, lab)
                acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy()))
                loss_avg += loss.item()*n_b
                acc_avg += acc
                num_exp += n_b

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(1)
                del img, lab
                
        loss_avg /= num_exp
        acc_avg /= num_exp
        scheduler.step()
        print('Train accuracy is {}%'.format(acc_avg*100))
        print('Average train loss is {}'.format(loss_avg))
        
    loss_avg, acc_avg, num_exp = 0, 0, 0   
    model.eval()
    for i_batch, datum in enumerate(test_loader):
        img = datum[0].float().to(device)
        lab = datum[1].long().to(device)
        n_b = lab.shape[0]
        output = model(img)
        loss = criterion(output, lab)
        acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy()))
        loss_avg += loss.item()*n_b
        acc_avg += acc
        num_exp += n_b
        
        del img, lab

    loss_avg /= num_exp
    acc_avg /= num_exp
    print('Test accuracy is {}%'.format(acc_avg*100))
    print('Average test loss is {}'.format(loss_avg))          

    return loss_avg, acc_avg

MNIST_dataset = 'MNIST'
MNIST_data_path = './MNISTdata'
MNIST_channel, MNIST_im_size, MNIST_num_classes, MNIST_class_names, MNIST_mean, MNIST_std, MNIST_dst_train, MNIST_dst_test, MNIST_testloader = get_dataset(MNIST_dataset, MNIST_data_path)
MNIST_trainloader = torch.utils.data.DataLoader(MNIST_dst_train, batch_size=8, shuffle=True, num_workers=0)
model = get_network('ConvNet', MNIST_channel, MNIST_num_classes, MNIST_im_size).to(device) # get a random model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0)
epochs = 20
start = time.time()
train(model, MNIST_trainloader , MNIST_testloader , optimizer, scheduler, criterion, epochs)
print('Training on the original MNIST dataset takes {} seconds'.format(time.time()-start))


Unsupported operator aten::avg_pool2d encountered 3 time(s)


The FLOPs for 100 synthetic images is 4924620800
The FLOPs for 60000 real training images is 2954772480000
The FLOPs for 10000 real validation images is 492462080000
Start Epoch #1


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:41<00:00, 181.48it/s]


Train accuracy is 97.21333333333332%
Average train loss is 0.09636120648886232
Start Epoch #2


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:41<00:00, 182.15it/s]


Train accuracy is 98.86833333333334%
Average train loss is 0.037596627227744706
Start Epoch #3


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:39<00:00, 188.29it/s]


Train accuracy is 99.13833333333332%
Average train loss is 0.028208444553001026
Start Epoch #4


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:38<00:00, 195.77it/s]


Train accuracy is 99.345%
Average train loss is 0.0221524243795627
Start Epoch #5


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:39<00:00, 191.30it/s]


Train accuracy is 99.42333333333333%
Average train loss is 0.018589427906607184
Start Epoch #6


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:39<00:00, 189.53it/s]


Train accuracy is 99.56833333333334%
Average train loss is 0.014966347912801697
Start Epoch #7


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:40<00:00, 183.64it/s]


Train accuracy is 99.66166666666668%
Average train loss is 0.01233487842734891
Start Epoch #8


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:40<00:00, 184.19it/s]


Train accuracy is 99.73333333333333%
Average train loss is 0.010501599295541503
Start Epoch #9


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:39<00:00, 188.52it/s]


Train accuracy is 99.80166666666666%
Average train loss is 0.00862585836301514
Start Epoch #10


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:38<00:00, 193.20it/s]


Train accuracy is 99.85166666666667%
Average train loss is 0.007250313250330813
Start Epoch #11


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:38<00:00, 193.12it/s]


Train accuracy is 99.88666666666667%
Average train loss is 0.005959797833087244
Start Epoch #12


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:38<00:00, 194.15it/s]


Train accuracy is 99.91833333333334%
Average train loss is 0.004990835475930317
Start Epoch #13


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:38<00:00, 196.32it/s]


Train accuracy is 99.94666666666666%
Average train loss is 0.00422198680354201
Start Epoch #14


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:38<00:00, 194.94it/s]


Train accuracy is 99.94333333333333%
Average train loss is 0.0038661902848921574
Start Epoch #15


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:38<00:00, 194.84it/s]


Train accuracy is 99.96666666666667%
Average train loss is 0.0034085706100623306
Start Epoch #16


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:54<00:00, 137.36it/s]


Train accuracy is 99.97333333333333%
Average train loss is 0.003172207000024476
Start Epoch #17


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:51<00:00, 145.73it/s]


Train accuracy is 99.97666666666667%
Average train loss is 0.0029685941946937418
Start Epoch #18


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:51<00:00, 145.99it/s]


Train accuracy is 99.98%
Average train loss is 0.002843619347076992
Start Epoch #19


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:52<00:00, 142.76it/s]


Train accuracy is 99.98333333333333%
Average train loss is 0.0027625877702018064
Start Epoch #20


100%|█████████████████████████████████████████████████████████████████████████████| 7500/7500 [00:46<00:00, 160.34it/s]


Train accuracy is 99.98166666666667%
Average train loss is 0.0027157509171857478
Test accuracy is 99.48%
Average test loss is 0.015956512201280564
Training on the original MNIST dataset takes 851.3110814094543 seconds


In [3]:
class arguments():
    def __init__(self,): 
        self.method = 'DC'
        self.dataset = 'MNIST'
        self.model = 'ConvNet'
        #'image(s) per class'
        self.ipc = 10
         # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
        self.eval_mode = 'S'
        #the number of experiments
        self.num_exp = 1
        #the number of evaluating randomly initialized models
        self.num_eval = 1
        #c
        self.epoch_eval_train = 20
        #training iterations
        self.Iteration = 100
        self.lr_img = 0.1
        self.lr_net = 0.01
        self.batch_real = 256
        self.batch_train = 256
        #'noise/real: initialize synthetic images from random noise or randomly sampled real images.'
        self.init = 'real'
        self.dsa_strategy = 'None'
        self.data_path = 'MNISTdata'
        self.save_path = 'MNISTresult'
        self.dis_metric = 'ours'
        self.outer_loop = 10
        self.inner_loop = 50
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.dsa_param = ParamDiffAug()
        self.dsa = False
        

def main(args):
    if not os.path.exists(args.data_path):
        os.mkdir(args.data_path)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    eval_it_pool = [0, 9, 19, 29, 39, 49, 59, 69, 79, 89, args.Iteration-1] # The list of iterations when we evaluate models and record results.
    print('eval_it_pool: ', np.array(eval_it_pool)+1)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    data_save = []
    all_losses = []
    all_accs = []

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)

        ''' organize the real dataset '''
        images_all = []
        labels_all = []
        indices_class = [[] for c in range(num_classes)]

        images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
        labels_all = [dst_train[i][1] for i in range(len(dst_train))]
        for i, lab in enumerate(labels_all):
            indices_class[lab].append(i)
        images_all = torch.cat(images_all, dim=0).to(args.device)
        labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))

        def get_images(c, n): # get random n images from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))


        ''' initialize the synthetic data '''
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')


        ''' training '''
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
        optimizer_img.zero_grad()
        criterion = nn.CrossEntropyLoss().to(args.device)
        print('%s training begins'%get_time())

        for it in range(args.Iteration):
            
            if it in eval_it_pool:
                ''' visualize and save '''
                save_name = os.path.join(args.save_path, args.init+'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                image_syn_vis[image_syn_vis<0] = 0.0
                image_syn_vis[image_syn_vis>1] = 1.0
                save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.
            
            ''' Train synthetic data '''
            net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
            net.train()
            net_parameters = list(net.parameters())
            optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)  # optimizer_img for synthetic data
            optimizer_net.zero_grad()
            loss_avg = 0
            args.dc_aug_param = None  # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in oder to be consistent with DC paper.


            for ol in range(args.outer_loop):

                ''' freeze the running mu and sigma for BatchNorm layers '''
                # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
                # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
                # This would make the training with BatchNorm layers easier.

                BN_flag = False
                BNSizePC = 16  # for batch normalization
                for module in net.modules():
                    if 'BatchNorm' in module._get_name(): #BatchNorm
                        BN_flag = True
                if BN_flag:
                    img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)
                    net.train() # for updating the mu, sigma of BatchNorm
                    output_real = net(img_real) # get running mu, sigma
                    for module in net.modules():
                        if 'BatchNorm' in module._get_name():  #BatchNorm
                            module.eval() # fix mu and sigma of every BatchNorm layer


                ''' update synthetic data '''
                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
                    lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c

                    if args.dsa:
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    output_real = net(img_real)
                    loss_real = criterion(output_real, lab_real)
                    gw_real = torch.autograd.grad(loss_real, net_parameters)
                    gw_real = list((_.detach().clone() for _ in gw_real))

                    output_syn = net(img_syn)
                    loss_syn = criterion(output_syn, lab_syn)
                    gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                    loss += match_loss(gw_syn, gw_real, args)

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()
                loss_avg += loss.item()

                if ol == args.outer_loop - 1:
                    break


                ''' update network '''
                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                for il in range(args.inner_loop):
                    epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)


            loss_avg /= (num_classes*args.outer_loop)
            all_losses.append(loss_avg)

            if (it+1)%10 == 0:
                print('%s iter = %04d, loss = %.4f' % (get_time(), it+1, loss_avg))

            if it == (args.Iteration-1): # only record the final results
                data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                torch.save({'data': data_save}, os.path.join(args.save_path, args.init+'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))

            ''' Evaluate synthetic data '''
            if it == eval_it_pool[-1]:
                print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, args.model, it+1))
                args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.
                args.epoch_eval_train = 20

                accs = []
                net_eval = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
                image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
                start = time.time()
                _, acc_train, acc_test = evaluate_synset(net_eval, image_syn_eval, label_syn_eval, testloader, args)
                print('\n==================== Final Results ====================\n')
                print('After {} iterations, the model test accuracy on synthetic data is {}%'.format(it+1, acc_test*100))
                print('Training on synthetic MNIST dataset takes {} seconds'.format(time.time()-start))
                all_accs.append(acc_test)
                
    return all_losses, all_accs

args = arguments()
all_losses, all_accs = main(args)


eval_it_pool:  [  1  10  20  30  40  50  60  70  80  90 100]

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MNIST', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 1, 'epoch_eval_train': 20, 'Iteration': 100, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'real', 'dsa_strategy': 'None', 'data_path': 'MNISTdata', 'save_path': 'MNISTresult', 'dis_metric': 'ours', 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x0000011F721FB9A0>, 'dsa': False}
class c = 0: 5923 real images
class c = 1: 6742 real images
class c = 2: 5958 real images
class c = 3: 6131 real images
class c = 4: 5842 real images
class c = 5: 5421 real images
class c = 6: 5918 real images
class c = 7: 6265 real images
class c = 8: 5851 real images
class c = 9: 5949 real images
real images channel 0, mean = -0.0001, std = 1.0000
initialize synthetic data from random real images
[2022-11-28 22:49:45] training begins


  label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]


[2022-11-28 22:52:44] iter = 0010, loss = 36.7684
[2022-11-28 22:55:45] iter = 0020, loss = 36.6858
[2022-11-28 22:58:46] iter = 0030, loss = 36.3136
[2022-11-28 23:01:49] iter = 0040, loss = 37.6236
[2022-11-28 23:04:52] iter = 0050, loss = 35.4983
[2022-11-28 23:07:54] iter = 0060, loss = 35.9700
[2022-11-28 23:10:55] iter = 0070, loss = 36.7914
[2022-11-28 23:13:57] iter = 0080, loss = 34.2500
[2022-11-28 23:16:34] iter = 0090, loss = 36.1507
[2022-11-28 23:19:22] iter = 0100, loss = 35.5200
-------------------------
Evaluation
model_train = ConvNet, model_eval = ConvNet, iteration = 100


After 100 iterations, the model test accuracy on synthetic data is 94.37%
Training on synthetic MNIST dataset takes 5.802131652832031 seconds


In [4]:
args = arguments()
args.init = 'noise'
all_losses, all_accs = main(args)

eval_it_pool:  [  1  10  20  30  40  50  60  70  80  90 100]

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MNIST', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 1, 'epoch_eval_train': 20, 'Iteration': 100, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'noise', 'dsa_strategy': 'None', 'data_path': 'MNISTdata', 'save_path': 'MNISTresult', 'dis_metric': 'ours', 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x0000011F72239990>, 'dsa': False}
class c = 0: 5923 real images
class c = 1: 6742 real images
class c = 2: 5958 real images
class c = 3: 6131 real images
class c = 4: 5842 real images
class c = 5: 5421 real images
class c = 6: 5918 real images
class c = 7: 6265 real images
class c = 8: 5851 real images
class c = 9: 5949 real images
real images channel 0, mean = -0.0001, std = 1.0000
initialize synthetic data from random noise
[2022-11-28 23:19:52] training begins
[2022-