In [3]:
import os
import argparse
import numpy as np
import torch
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
from tqdm import tqdm
from utils import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug
import wandb
import copy
import random
from reparam_module import ReparamModule

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

def _list_expert_files(expert_dir):
    """Return the consecutive expert buffer files ``replay_buffer_{n}.pt`` in *expert_dir*.

    Files are probed for n = 0, 1, 2, ... and the scan stops at the first missing
    index, so the returned paths are ordered by n.

    Raises:
        AssertionError: if ``replay_buffer_0.pt`` does not exist.
    """
    paths = []
    while True:
        path = os.path.join(expert_dir, "replay_buffer_{}.pt".format(len(paths)))
        if not os.path.exists(path):
            break
        paths.append(path)
    if not paths:
        raise AssertionError("No buffers detected at {}".format(expert_dir))
    return paths


def _vis_grid(images, dataset):
    """Arrange a batch of images into one normalized grid image for wandb.

    Non-ImageNet images are upsampled 4x (each pixel repeated along dims 2 and 3)
    so small images remain visible in the dashboard.
    """
    if dataset != "ImageNet":
        images = torch.repeat_interleave(images, repeats=4, dim=2)
        images = torch.repeat_interleave(images, repeats=4, dim=3)
    return torchvision.utils.make_grid(images, nrow=10, normalize=True, scale_each=True)


def _clip_around_mean(images, clip_val):
    """Clip *images* to ``mean +/- clip_val * std`` computed over the whole tensor."""
    std = torch.std(images)
    mean = torch.mean(images)
    return torch.clip(images, min=mean - clip_val * std, max=mean + clip_val * std)


def main(args):
    """Distill a synthetic dataset by matching expert training trajectories.

    Each iteration samples a saved expert trajectory, trains a freshly initialised
    student network for ``args.syn_steps`` steps on the synthetic images using a
    learnable learning rate ``syn_lr``, and then updates the synthetic images and
    ``syn_lr`` to minimise the squared distance between the student's final
    parameters and the expert's parameters ``args.expert_epochs`` epochs later,
    divided by the squared distance the expert itself travelled over that span.

    At ``eval_it_pool[-1]`` a fresh network is trained on the synthetic set via
    ``evaluate_synset`` and its test accuracy is printed. At every iteration in
    ``eval_it_pool`` a PNG grid is written to ``./CIFAR10result`` and image
    statistics are logged to wandb; tensors are saved at the last checkpoint.

    Args:
        args: configuration object (see ``arguments``). It is mutated in place
            (``dsa``, ``device``, ``im_size``, ``dc_aug_param``, ``dsa_param`` and,
            when both limits are set, ``total_experts``) and then replaced
            internally by an object rebuilt from ``wandb.config``.

    Raises:
        AssertionError: if ``zca`` and ``texture`` are both enabled, or if no
            ``replay_buffer_0.pt`` exists in the expert directory.
    """
    if args.zca and args.texture:
        raise AssertionError("Cannot use zca and texture together")

    if args.texture and args.pix_init == "real":
        print("WARNING: Using texture with real initialization will take a very long time to smooth out the boundaries between images.")

    if args.max_experts is not None and args.max_files is not None:
        args.total_experts = args.max_experts * args.max_files

    print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled))

    # Accept both the string flag and an already-converted bool. The bool is written
    # back into the caller's object, so comparing only against 'True' would silently
    # disable DSA when the same args object is passed to main() a second time.
    args.dsa = str(args.dsa) == 'True'
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    eval_it_pool = [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
    (channel, im_size, num_classes, class_names, mean_orig, std_orig, dst_train, dst_test,
     testloader, loader_train_dict, class_map, class_map_inv) = get_dataset(
        args.dataset, args.data_path, args.batch_real, args.subset, args=args)

    args.im_size = im_size

    # DiffAugment parameters (applied below only when args.dsa is set); the
    # dc_aug_param path is always disabled.
    args.dc_aug_param = None
    args.dsa_param = ParamDiffAug()

    dsa_params = args.dsa_param
    zca_trans = args.zca_trans if args.zca else None

    wandb.init(sync_tensorboard=False,
               project="DatasetDistillation",
               job_type="CleanRepo",
               config=args,
               )

    # Rebuild args from the wandb config (presumably so wandb-side overrides take
    # effect); objects that are not plain config values are re-attached afterwards.
    args = type('', (), {})()
    for key in wandb.config._items:
        setattr(args, key, wandb.config._items[key])

    args.dsa_param = dsa_params
    args.zca_trans = zca_trans

    if args.batch_syn is None:
        args.batch_syn = num_classes * args.ipc

    args.distributed = torch.cuda.device_count() > 1

    print('Hyper-parameters: \n', args.__dict__)

    ''' organize the real dataset '''
    images_all = []
    labels_all = []
    indices_class = [[] for _ in range(num_classes)]
    print("BUILDING DATASET")
    for i in tqdm(range(len(dst_train))):
        sample = dst_train[i]
        images_all.append(torch.unsqueeze(sample[0], dim=0))
        labels_all.append(class_map[torch.tensor(sample[1]).item()])

    for i, lab in tqdm(enumerate(labels_all)):
        indices_class[lab].append(i)
    images_all = torch.cat(images_all, dim=0).to("cpu")
    labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu")

    for c in range(num_classes):
        print('class c = %d: %d real images' % (c, len(indices_class[c])))

    for ch in range(channel):
        print('real images channel %d, mean = %.4f, std = %.4f' % (ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

    def get_images(c, n):  # get random n images from class c
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle]

    ''' initialize the synthetic data '''
    # Class-major labels: [0]*ipc + [1]*ipc + ... + [num_classes-1]*ipc.
    label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

    if args.texture:
        image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0] * args.canvas_size, im_size[1] * args.canvas_size), dtype=torch.float)
    else:
        image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float)

    syn_lr = torch.tensor(args.lr_teacher).to(args.device)

    if args.pix_init == 'real':
        print('initialize synthetic data from random real images')
        if args.texture:
            # Fill every (i, j) tile of each canvas with an independently sampled
            # real image of that canvas's class.
            for c in range(num_classes):
                for i in range(args.canvas_size):
                    for j in range(args.canvas_size):
                        image_syn.data[c * args.ipc:(c + 1) * args.ipc, :, i * im_size[0]:(i + 1) * im_size[0],
                        j * im_size[1]:(j + 1) * im_size[1]] = torch.cat(
                            [get_images(c, 1).detach().data for s in range(args.ipc)])
        else:
            # Whole-image copy. In texture mode the canvases are canvas_size times
            # larger, so this copy would not broadcast; the tiled fill is used instead.
            for c in range(num_classes):
                image_syn.data[c * args.ipc:(c + 1) * args.ipc] = get_images(c, args.ipc).detach().data
    else:
        print('initialize synthetic data from random noise')

    ''' training '''
    image_syn = image_syn.detach().to(args.device).requires_grad_(True)
    syn_lr = syn_lr.detach().to(args.device).requires_grad_(True)
    optimizer_img = torch.optim.SGD([image_syn], lr=args.lr_img, momentum=0.5)
    optimizer_lr = torch.optim.SGD([syn_lr], lr=args.lr_lr, momentum=0.5)
    optimizer_img.zero_grad()

    criterion = nn.CrossEntropyLoss().to(args.device)
    print('%s training begins' % get_time())

    expert_dir = os.path.join(args.buffer_path, args.dataset)
    if args.dataset == "ImageNet":
        expert_dir = os.path.join(expert_dir, args.subset, str(args.res))
    if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:
        expert_dir += "_NO_ZCA"
    expert_dir = os.path.join(expert_dir, args.model)
    print("Expert Dir: {}".format(expert_dir))

    expert_files = _list_expert_files(expert_dir)
    if args.load_all:
        # Keep every trajectory from every file in memory and sample uniformly.
        buffer = []
        for path in expert_files:
            buffer.extend(torch.load(path))
    else:
        # Stream one (shuffled) file at a time and walk its trajectories in order.
        file_idx = 0
        expert_idx = 0
        random.shuffle(expert_files)
        if args.max_files is not None:
            expert_files = expert_files[:args.max_files]
        print("loading file {}".format(expert_files[file_idx]))
        buffer = torch.load(expert_files[file_idx])
        if args.max_experts is not None:
            buffer = buffer[:args.max_experts]
        random.shuffle(buffer)

    for it in range(0, args.Iteration + 1):
        # Hook for "best" checkpoints; nothing in this loop currently sets it to True.
        save_this_it = False

        wandb.log({"Progress": it}, step=it)

        student_net = get_network(args.model, channel, num_classes, im_size, dist=False).to(args.device)  # get a random model
        student_net = ReparamModule(student_net)
        if args.distributed:
            student_net = torch.nn.DataParallel(student_net)
        student_net.train()

        num_params = sum(p.numel() for p in student_net.parameters())

        if args.load_all:
            expert_trajectory = buffer[np.random.randint(0, len(buffer))]
        else:
            expert_trajectory = buffer[expert_idx]
            expert_idx += 1
            if expert_idx == len(buffer):
                # Current file exhausted: move to the next file, reshuffling the
                # file order after a full pass over all files.
                expert_idx = 0
                file_idx += 1
                if file_idx == len(expert_files):
                    file_idx = 0
                    random.shuffle(expert_files)
                print("loading file {}".format(expert_files[file_idx]))
                if args.max_files != 1:
                    del buffer
                    buffer = torch.load(expert_files[file_idx])
                if args.max_experts is not None:
                    buffer = buffer[:args.max_experts]
                random.shuffle(buffer)

        start_epoch = np.random.randint(0, args.max_start_epoch)
        # Flatten the expert checkpoints into 1-D vectors; the student consumes its
        # parameters in this flat form via the flat_param argument below.
        starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in expert_trajectory[start_epoch]], 0)
        target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in expert_trajectory[start_epoch + args.expert_epochs]], 0)
        student_params = [starting_params.clone().requires_grad_(True)]

        y_hat = label_syn.to(args.device)

        indices_chunks = []
        for step in range(args.syn_steps):
            # Minibatches without replacement; reshuffle once every chunk is used.
            if not indices_chunks:
                indices = torch.randperm(len(image_syn))
                indices_chunks = list(torch.split(indices, args.batch_syn))

            these_indices = indices_chunks.pop()

            x = image_syn[these_indices]
            this_y = y_hat[these_indices]

            if args.texture:
                # Randomly rolled (wrap-around) crops of each canvas, canvas_samples times.
                x = torch.cat([torch.stack([torch.roll(im, (torch.randint(im_size[0] * args.canvas_size, (1,)), torch.randint(im_size[1] * args.canvas_size, (1,))), (1, 2))[:, :im_size[0], :im_size[1]] for im in x]) for _ in range(args.canvas_samples)])
                this_y = torch.cat([this_y for _ in range(args.canvas_samples)])

            if args.dsa and (not args.no_aug):
                x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param)

            if args.distributed:
                forward_params = student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1)
            else:
                forward_params = student_params[-1]
            x = student_net(x, flat_param=forward_params)
            ce_loss = criterion(x, this_y)

            # create_graph=True keeps the unrolled student updates differentiable
            # with respect to image_syn and syn_lr.
            grad = torch.autograd.grad(ce_loss, student_params[-1], create_graph=True)[0]

            student_params.append(student_params[-1] - syn_lr * grad)

        # ||student_end - target||^2 / ||start - target||^2 (both terms also / num_params).
        param_loss = F.mse_loss(student_params[-1], target_params, reduction="sum") / num_params
        param_dist = F.mse_loss(starting_params, target_params, reduction="sum") / num_params
        grand_loss = param_loss / param_dist

        optimizer_img.zero_grad()
        optimizer_lr.zero_grad()

        grand_loss.backward()

        optimizer_img.step()
        optimizer_lr.step()

        # Log at step=it like every other metric: an unstepped wandb.log advances the
        # step counter, and the later step=it logs of this iteration would be dropped.
        wandb.log({"Grand_Loss": grand_loss.detach().cpu(),
                   "Start_Epoch": start_epoch}, step=it)

        # Drop the references to the unrolled parameter vectors.
        del student_params

        if it % 10 == 0:
            print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item()))

        ''' Evaluate synthetic data '''
        if it == eval_it_pool[-1]:
            print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d' % (args.model, args.model, it))
            print('\n==================== Final Results ====================\n')
            net_eval = get_network(args.model, channel, num_classes, im_size).to(args.device)  # get a random model

            # Deep copies so evaluation cannot modify the tensors being optimised.
            image_syn_eval = copy.deepcopy(image_syn.detach())
            label_syn_eval = copy.deepcopy(label_syn.detach())

            args.lr_net = syn_lr.item()
            _, acc_train, acc_test = evaluate_synset(net_eval, image_syn_eval, label_syn_eval, testloader, args, texture=args.texture)
            print('After {} iterations, the model test accuracy on synthetic data is {}%'.format(it, acc_test * 100))

        if it in eval_it_pool or save_this_it:
            with torch.no_grad():
                # Use args.device instead of a hard-coded .cuda() so CPU-only runs work.
                image_save = image_syn.detach().to(args.device)

                save_dir = os.path.join(".", "CIFAR10result")
                os.makedirs(save_dir, exist_ok=True)

                if it in eval_it_pool:
                    # Undo the dataset normalisation and clamp to [0, 1] for the PNG grid.
                    save_name = os.path.join(save_dir, args.pix_init + 'images_{}.png'.format(it))
                    image_syn_vis = image_save.cpu().clone()
                    for ch in range(channel):
                        image_syn_vis[:, ch] = image_syn_vis[:, ch] * std_orig[ch] + mean_orig[ch]
                    image_syn_vis.clamp_(0.0, 1.0)
                    save_image(image_syn_vis, save_name, nrow=args.ipc)  # Trying normalize = True/False may get better visual effects.

                if it == eval_it_pool[-1]:
                    torch.save(image_save.cpu(), os.path.join(save_dir, args.pix_init + "images_{}.pt".format(it)))
                    torch.save(label_syn.cpu(), os.path.join(save_dir, args.pix_init + "labels_{}.pt".format(it)))

                if save_this_it:
                    torch.save(image_save.cpu(), os.path.join(save_dir, args.pix_init + "images_best.pt"))
                    torch.save(label_syn.cpu(), os.path.join(save_dir, args.pix_init + "labels_best.pt"))

                wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_syn.detach().cpu()))}, step=it)

                if args.ipc < 50 or args.force_save:
                    grid = _vis_grid(image_save, args.dataset)
                    wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)
                    wandb.log({'Synthetic_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it)

                    for clip_val in [2.5]:
                        grid = _vis_grid(_clip_around_mean(image_save, clip_val), args.dataset)
                        wandb.log({"Clipped_Synthetic_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)

                    if args.zca:
                        image_save = args.zca_trans.inverse_transform(image_save.to(args.device))

                        torch.save(image_save.cpu(), os.path.join(save_dir, args.pix_init + "images_zca_{}.pt".format(it)))

                        grid = _vis_grid(image_save, args.dataset)
                        wandb.log({"Reconstructed_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)
                        wandb.log({'Reconstructed_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it)

                        for clip_val in [2.5]:
                            grid = _vis_grid(_clip_around_mean(image_save, clip_val), args.dataset)
                            wandb.log({"Clipped_Reconstructed_Images/std_{}".format(clip_val): wandb.Image(
                                torch.nan_to_num(grid.detach().cpu()))}, step=it)

        wandb.log({"Synthetic_LR": syn_lr.detach().cpu()}, step=it)

    wandb.finish()


class arguments():
    """Default configuration for a CIFAR-10 / ConvNet trajectory-matching distillation run.

    ``main`` reads these as plain attributes. Calling ``arguments()`` with no
    arguments yields exactly the defaults below. Any attribute can be overridden
    by keyword, e.g. ``arguments(pix_init='noise')``. An extra attribute can be
    supplied the same way, e.g. ``res``, which ``main`` reads only when
    ``dataset == 'ImageNet'``. Overrides are applied after the defaults, so
    default attribute order is preserved and new names are appended.
    """

    def __init__(self, **overrides):
        self.dataset = 'CIFAR10'
        self.subset = 'imagenette'   # passed to get_dataset; part of the ImageNet buffer path
        self.model = 'ConvNet'
        self.lr_img = 1000           # SGD learning rate for the synthetic images
        self.lr_lr = 1e-05           # SGD learning rate for the learnable student learning rate
        self.lr_teacher = 0.01       # initial value of the learnable student learning rate
        self.lr_init = 0.01
        self.batch_real = 256        # passed to get_dataset
        self.batch_train = 256
        self.batch_syn = None        # None -> num_classes * ipc (whole synthetic set per step)
        self.pix_init = 'real'       # 'real' copies real images; any other value keeps random noise
        self.dsa = 'False'           # string flag; main() converts it to a bool
        self.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'
        self.data_path = 'CIFAR10data'
        self.buffer_path = './CIFAR10buffers'
        self.expert_epochs = 2       # target checkpoint = start_epoch + expert_epochs
        self.syn_steps = 22          # unrolled student steps per iteration
        self.max_start_epoch = 15    # exclusive upper bound for the sampled start epoch
        self.load_all = False        # True: load every replay buffer file up front
        self.no_aug = False
        self.zca = False
        self.texture = False
        self.canvas_size = 2
        self.canvas_samples = 1
        self.max_files = None
        self.max_experts = None
        self.force_save = False
        self.ipc = 10                # synthetic images per class
        self.eval_mode = 'S'
        self.num_eval = 5
        self.eval_it = 100
        self.epoch_eval_train = 1000
        self.Iteration = 2000        # number of distillation iterations (inclusive upper bound)

        for name, value in overrides.items():
            setattr(self, name, value)
        

# Run 1: distil with the default configuration (pix_init='real', i.e. the
# synthetic images are initialised from randomly chosen real training images).
# NOTE: main() mutates the object it receives (it sets e.g. device, im_size and
# dsa_param, and converts dsa to a bool).
args = arguments()

main(args)




CUDNN STATUS: True
Files already downloaded and verified
Files already downloaded and verified


0,1
Progress,▁

0,1
Progress,0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

Hyper-parameters: 
 {'dataset': 'CIFAR10', 'subset': 'imagenette', 'model': 'ConvNet', 'lr_img': 1000, 'lr_lr': 1e-05, 'lr_teacher': 0.01, 'lr_init': 0.01, 'batch_real': 256, 'batch_train': 256, 'batch_syn': 100, 'pix_init': 'real', 'dsa': False, 'dsa_strategy': 'color_crop_cutout_flip_scale_rotate', 'data_path': 'CIFAR10data', 'buffer_path': './CIFAR10buffers', 'expert_epochs': 2, 'syn_steps': 22, 'max_start_epoch': 15, 'load_all': False, 'no_aug': False, 'zca': False, 'texture': False, 'canvas_size': 2, 'canvas_samples': 1, 'max_files': None, 'max_experts': None, 'force_save': False, 'ipc': 10, 'eval_mode': 'S', 'num_eval': 5, 'eval_it': 100, 'epoch_eval_train': 1000, 'Iteration': 2000, 'device': 'cuda', 'im_size': [32, 32], 'dc_aug_param': None, 'dsa_param': <utils.ParamDiffAug object at 0x0000029223BF8A30>, '_wandb': {}, 'zca_trans': None, 'distributed': False}
BUILDING DATASET


100%|██████████████████████████████████████████████████████████████████████████| 50000/50000 [00:09<00:00, 5199.76it/s]
50000it [00:00, 3066460.01it/s]


class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real images
class c = 9: 5000 real images
real images channel 0, mean = -0.0000, std = 1.2211
real images channel 1, mean = -0.0002, std = 1.2211
real images channel 2, mean = 0.0002, std = 1.3014
initialize synthetic data from random real images
[2022-12-02 10:38:46] training begins
Expert Dir: ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 10:38:48] iter = 0000, loss = 1.2316
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 10:39:02] iter = 0010, loss = 0.9666
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFA

[2022-12-02 10:47:55] iter = 0390, loss = 0.8686
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 10:48:09] iter = 0400, loss = 0.9334
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 10:48:23] iter = 0410, loss = 0.9491
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 10:48:37] iter = 0420, loss = 0.9643
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 10:48:51] iter = 0430, loss = 0.9127
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 10:49:05] iter = 0440, 

loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 10:57:42] iter = 0820, loss = 0.8467
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 10:57:56] iter = 0830, loss = 0.9506
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 10:58:09] iter = 0840, loss = 0.7558
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 10:58:23] iter = 0850, loss = 0.7255
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 10:58:36] iter = 0860, loss = 0.7272
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR

loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:07:26] iter = 1250, loss = 0.6528
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:07:40] iter = 1260, loss = 0.6398
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:07:54] iter = 1270, loss = 0.9260
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:08:08] iter = 1280, loss = 0.8898
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:08:22] iter = 1290, loss = 0.7373
loading file ./CIFAR10buffers\CIFAR

[2022-12-02 11:17:09] iter = 1670, loss = 0.7106
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:17:22] iter = 1680, loss = 0.9094
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:17:36] iter = 1690, loss = 0.7830
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:17:50] iter = 1700, loss = 0.8066
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:18:03] iter = 1710, loss = 0.8161
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:18:17] iter = 1720, 

100%|██████████████████████████████████████████████████████████████████████████████| 1001/1001 [00:20<00:00, 49.80it/s]


[2022-12-02 11:24:58] Evaluate: epoch = 1000 train time = 20 s train loss = 0.000592 train acc = 1.0000, test acc = 0.4982
After 2000 iterations, the model test accuracy on synthetic data is 49.82%


0,1
Grand_Loss,▆█▄▄▅▆█▄▃███▃▇▁█▇▁███▅█▆▇█▇▅▂█▄▃▃▅█▆▆▆▅▇
Progress,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Start_Epoch,▂▇▂▃▁▃▆▁▁▇▇▇▂▆▁▇▇▂█▇▇▃█▅▅▇▅▄▃█▃▁▁▄█▅▅▅▄█

0,1
Grand_Loss,0.86508
Progress,2000.0
Start_Epoch,7.0


In [4]:
# Run 2: same configuration, but keep the synthetic images as random normal noise.
# Build a fresh config rather than reusing the previous `args`: main() mutates
# the object it is given (device, im_size, dc_aug_param, dsa_param, and dsa is
# converted from a string to a bool), so reuse could silently change the run.
args = arguments()
args.pix_init = 'noise'
main(args)

CUDNN STATUS: True
Files already downloaded and verified
Files already downloaded and verified


Hyper-parameters: 
 {'dataset': 'CIFAR10', 'subset': 'imagenette', 'model': 'ConvNet', 'lr_img': 1000, 'lr_lr': 1e-05, 'lr_teacher': 0.01, 'lr_init': 0.01, 'batch_real': 256, 'batch_train': 256, 'batch_syn': 100, 'pix_init': 'noise', 'dsa': False, 'dsa_strategy': 'color_crop_cutout_flip_scale_rotate', 'data_path': 'CIFAR10data', 'buffer_path': './CIFAR10buffers', 'expert_epochs': 2, 'syn_steps': 22, 'max_start_epoch': 15, 'load_all': False, 'no_aug': False, 'zca': False, 'texture': False, 'canvas_size': 2, 'canvas_samples': 1, 'max_files': None, 'max_experts': None, 'force_save': False, 'ipc': 10, 'eval_mode': 'S', 'num_eval': 5, 'eval_it': 100, 'epoch_eval_train': 1000, 'Iteration': 2000, 'device': 'cuda', 'im_size': [32, 32], 'dc_aug_param': None, 'dsa_param': <utils.ParamDiffAug object at 0x000002956388B9A0>, '_wandb': {}, 'zca_trans': None, 'distributed': False}
BUILDING DATASET


100%|██████████████████████████████████████████████████████████████████████████| 50000/50000 [00:09<00:00, 5178.76it/s]
50000it [00:00, 2935091.18it/s]


class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real images
class c = 9: 5000 real images
real images channel 0, mean = -0.0000, std = 1.2211
real images channel 1, mean = -0.0002, std = 1.2211
real images channel 2, mean = 0.0002, std = 1.3014
initialize synthetic data from random noise
[2022-12-02 11:25:16] training begins
Expert Dir: ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:25:18] iter = 0000, loss = 1.3547
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:25:32] iter = 0010, loss = 1.0441
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buf

loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:34:21] iter = 0400, loss = 0.7831
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:34:35] iter = 0410, loss = 0.8179
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:34:48] iter = 0420, loss = 0.8097
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:35:02] iter = 0430, loss = 0.9622
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:35:15] iter = 0440, loss = 0.8121
loading file ./CIFAR10buffers\CIFAR

[2022-12-02 11:43:50] iter = 0820, loss = 0.7550
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:44:04] iter = 0830, loss = 0.7429
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:44:17] iter = 0840, loss = 0.9590
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:44:31] iter = 0850, loss = 0.7447
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:44:44] iter = 0860, loss = 0.8731
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:44:58] iter = 0870, 

loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:53:32] iter = 1250, loss = 0.9710
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:53:46] iter = 1260, loss = 0.8963
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:53:59] iter = 1270, loss = 0.9068
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 11:54:13] iter = 1280, loss = 0.9226
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 11:54:26] iter = 1290, loss = 0.9618
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR

loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 12:03:13] iter = 1680, loss = 0.6479
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 12:03:27] iter = 1690, loss = 0.7754
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
[2022-12-02 12:03:40] iter = 1700, loss = 0.9334
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 12:03:54] iter = 1710, loss = 0.8215
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
loading file ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_1.pt
[2022-12-02 12:04:07] iter = 1720, loss = 0.8787
loading file ./CIFAR10buffers\CIFAR

100%|██████████████████████████████████████████████████████████████████████████████| 1001/1001 [00:20<00:00, 50.02it/s]


[2022-12-02 12:10:45] Evaluate: epoch = 1000 train time = 20 s train loss = 0.000592 train acc = 1.0000, test acc = 0.3960
After 2000 iterations, the model test accuracy on synthetic data is 39.6%


0,1
Grand_Loss,▇▆█▆▄█▆█▇▇█▅▅▁▇▆▆▄▇▅▃█▄▇▇▇▆▂▇▄▃▆▁▃▅▇▁▇▆▆
Progress,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Start_Epoch,▃▂▆▁▁█▄▇▆▅▇▁▅▂▅▅▅▁▇▅▃▇▄███▅▁█▁▃▅▃▃▅█▃▇▅▅

0,1
Grand_Loss,0.97127
Progress,2000.0
Start_Epoch,13.0
