In [1]:
import os
import argparse
import torch
import torch.nn as nn
from tqdm import tqdm
from utils import get_dataset, get_network, get_daparam, TensorDataset, epoch, ParamDiffAug
import copy

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

def _snapshot_parameters(net):
    """Return CPU copies of all parameters of ``net`` that do not alias the live weights.

    ``Tensor.cpu()`` returns the very same tensor when it already lives on the
    CPU, so on a CPU-only machine a plain ``p.detach().cpu()`` snapshot shares
    storage with the parameter and is silently overwritten by every later
    optimizer step.  ``copy=True`` forces a fresh tensor on every device.
    """
    return [p.detach().to('cpu', copy=True) for p in net.parameters()]


def main(args):
    """Train ``args.num_experts`` teacher networks and save their parameter trajectories.

    For every expert a new network is built with ``get_network`` and trained
    for ``args.train_epochs`` epochs with plain SGD.  The parameters are
    snapshotted before training and after every epoch; the list of snapshots
    of one expert is one trajectory.  Each time ``args.save_interval``
    trajectories have accumulated they are written with ``torch.save`` to the
    first unused ``replay_buffer_<n>.pt`` inside the buffer directory.

    Args:
        args: namespace-like object.  Attributes read here: ``dataset``,
            ``data_path``, ``batch_real``, ``subset``, ``buffer_path``,
            ``model``, ``zca``, ``batch_train``, ``num_experts``,
            ``lr_teacher``, ``mom``, ``l2``, ``train_epochs``, ``decay``,
            ``save_interval``, ``dsa`` and, for ``dataset == "ImageNet"``
            only, ``res``.  The object is also handed to ``get_dataset`` and
            ``epoch``, which may read further attributes.

    Side effects:
        Sets ``args.dsa`` (bool), ``args.device``, ``args.dsa_param`` and
        ``args.dc_aug_param``; creates the buffer directory; writes buffer
        files; prints progress to stdout.
    """
    # Accept a real boolean as well as the 'True' / 'False' strings.
    args.dsa = args.dsa is True or args.dsa == 'True'
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()

    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args)

    print('Hyper-parameters: \n', args.__dict__)

    # Buffer directory layout: <buffer_path>/<dataset>[/<subset>/<res>][_NO_ZCA]/<model>
    save_dir = os.path.join(args.buffer_path, args.dataset)
    if args.dataset == "ImageNet":
        save_dir = os.path.join(save_dir, args.subset, str(args.res))
    if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:
        save_dir += "_NO_ZCA"
    save_dir = os.path.join(save_dir, args.model)
    os.makedirs(save_dir, exist_ok=True)

    # Organize the real dataset: materialise every sample in CPU memory and
    # remember which sample indices belong to which (remapped) class.
    images_all = []
    labels_all = []
    indices_class = [[] for _ in range(num_classes)]
    print("BUILDING DATASET")
    for i in tqdm(range(len(dst_train))):
        sample = dst_train[i]
        lab = class_map[torch.tensor(sample[1]).item()]
        images_all.append(torch.unsqueeze(sample[0], dim=0))
        labels_all.append(lab)
        indices_class[lab].append(i)
    images_all = torch.cat(images_all, dim=0).to("cpu")
    labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu")

    for c in range(num_classes):
        print('class c = %d: %d real images'%(c, len(indices_class[c])))

    for ch in range(channel):
        print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

    criterion = nn.CrossEntropyLoss().to(args.device)

    # Trajectories collected since the last buffer file was written.
    trajectories = []

    dst_train = TensorDataset(copy.deepcopy(images_all.detach()), copy.deepcopy(labels_all.detach()))
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    # Set augmentation for whole-dataset training.
    args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None)
    args.dc_aug_param['strategy'] = 'crop_scale_rotate'  # for whole-dataset training
    print('DC augmentation parameters: \n', args.dc_aug_param)

    for it in range(0, args.num_experts):

        # Train one teacher network on the real data.
        teacher_net = get_network(args.model, channel, num_classes, im_size).to(args.device)
        teacher_net.train()
        lr = args.lr_teacher
        teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
        teacher_optim.zero_grad()

        # First snapshot: the parameters before any training step.
        timestamps = [_snapshot_parameters(teacher_net)]

        # Epoch index after which the learning rate is decayed (if args.decay).
        lr_schedule = [args.train_epochs // 2 + 1]

        for e in range(args.train_epochs):

            train_loss, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim,
                                        criterion=criterion, args=args, aug=True)

            test_loss, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None,
                                        criterion=criterion, args=args, aug=False)

            print("Itr: {}\tEpoch: {}\tTrain Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc))

            # One snapshot per finished epoch.
            timestamps.append(_snapshot_parameters(teacher_net))

            if e in lr_schedule and args.decay:
                # A new optimizer object is built with the reduced rate, so
                # any optimizer state (e.g. momentum buffers) starts fresh.
                lr *= 0.1
                teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
                teacher_optim.zero_grad()

        trajectories.append(timestamps)

        if len(trajectories) == args.save_interval:
            # Never overwrite an existing buffer: take the first free index.
            n = 0
            buffer_file = os.path.join(save_dir, "replay_buffer_{}.pt".format(n))
            while os.path.exists(buffer_file):
                n += 1
                buffer_file = os.path.join(save_dir, "replay_buffer_{}.pt".format(n))
            print("Saving {}".format(buffer_file))
            torch.save(trajectories, buffer_file)
            trajectories = []

    if trajectories:
        # Make the otherwise silent loss of trained experts visible.
        print("WARNING: {} trained trajectories were left unsaved (buffers are written only when exactly save_interval={} have accumulated)".format(len(trajectories), args.save_interval))


class arguments():
    """Plain attribute container holding the hyper-parameters consumed by ``main``.

    Calling ``arguments()`` yields the default configuration.  Keyword
    arguments replace a default of the same name or add a new attribute
    (for example ``res``, which ``main`` reads when ``dataset`` is
    ``"ImageNet"``).  Names are not validated.
    """

    def __init__(self, **overrides):
        self.dataset = 'CIFAR10'        # dataset name handed to get_dataset
        self.subset = 'imagenette'      # handed to get_dataset; part of the save path for ImageNet
        self.model = 'ConvNet'          # architecture name handed to get_network
        self.num_experts = 10           # number of teacher networks to train
        self.lr_teacher = 0.01          # SGD learning rate of the teachers
        self.batch_real = 256           # batch size handed to get_dataset
        self.batch_train = 256          # batch size of the teacher training loader
        self.dsa = 'False'              # 'True' / 'False'; main converts it to a bool
        self.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'  # not read by main; presumably used by utils
        self.data_path = 'CIFAR10data'  # dataset location handed to get_dataset
        self.buffer_path = './CIFAR10buffers'  # root directory of the saved buffers
        self.train_epochs = 30          # epochs per teacher
        self.zca = False                # False adds the "_NO_ZCA" suffix to CIFAR save paths
        self.decay = False              # enable the one-off x0.1 learning-rate decay in main
        self.mom = 0                    # SGD momentum
        self.l2 = 0                     # SGD weight decay
        self.save_interval = 5          # trajectories per saved buffer file

        # Applied last so that defaults keep their original attribute order.
        for name, value in overrides.items():
            setattr(self, name, value)

# Start the run only when executed directly (script or notebook cell), so that
# importing this file does not kick off a full training session.
if __name__ == "__main__":
    args = arguments()
    main(args)




Files already downloaded and verified
Files already downloaded and verified
Hyper-parameters: 
 {'dataset': 'CIFAR10', 'subset': 'imagenette', 'model': 'ConvNet', 'num_experts': 10, 'lr_teacher': 0.01, 'batch_real': 256, 'batch_train': 256, 'dsa': False, 'dsa_strategy': 'color_crop_cutout_flip_scale_rotate', 'data_path': 'CIFAR10data', 'buffer_path': './CIFAR10buffers', 'train_epochs': 30, 'zca': False, 'decay': False, 'mom': 0, 'l2': 0, 'save_interval': 5, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x0000029AB440F3A0>}
BUILDING DATASET


100%|██████████████████████████████████████████████████████████████████████████| 50000/50000 [00:08<00:00, 5569.16it/s]
50000it [00:00, 16134420.68it/s]


class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real images
class c = 9: 5000 real images
real images channel 0, mean = -0.0000, std = 1.2211
real images channel 1, mean = -0.0002, std = 1.2211
real images channel 2, mean = 0.0002, std = 1.3014
DC augmentation parameters: 
 {'crop': 4, 'scale': 0.2, 'rotate': 45, 'noise': 0.001, 'strategy': 'crop_scale_rotate'}
Itr: 0	Epoch: 0	Train Acc: 0.34928	Test Acc: 0.4378
Itr: 0	Epoch: 1	Train Acc: 0.45586	Test Acc: 0.4968
Itr: 0	Epoch: 2	Train Acc: 0.50108	Test Acc: 0.5337
Itr: 0	Epoch: 3	Train Acc: 0.53628	Test Acc: 0.5639
Itr: 0	Epoch: 4	Train Acc: 0.55692	Test Acc: 0.5706
Itr: 0	Epoch: 5	Train Acc: 0.57772	Test Acc: 0.5999
Itr: 0	Epoch: 6	Train Acc: 0.59556	Test Acc: 0.6139
Itr: 0	Epoch: 7	Train Acc: 0.60658	Test Acc: 0.6247
Itr: 0	Epo

Itr: 4	Epoch: 26	Train Acc: 0.72956	Test Acc: 0.7194
Itr: 4	Epoch: 27	Train Acc: 0.73288	Test Acc: 0.7195
Itr: 4	Epoch: 28	Train Acc: 0.73508	Test Acc: 0.7353
Itr: 4	Epoch: 29	Train Acc: 0.73796	Test Acc: 0.7359
Saving ./CIFAR10buffers\CIFAR10_NO_ZCA\ConvNet\replay_buffer_0.pt
Itr: 5	Epoch: 0	Train Acc: 0.3458	Test Acc: 0.4429
Itr: 5	Epoch: 1	Train Acc: 0.45272	Test Acc: 0.4972
Itr: 5	Epoch: 2	Train Acc: 0.50056	Test Acc: 0.5381
Itr: 5	Epoch: 3	Train Acc: 0.5351	Test Acc: 0.5582
Itr: 5	Epoch: 4	Train Acc: 0.55664	Test Acc: 0.5778
Itr: 5	Epoch: 5	Train Acc: 0.57694	Test Acc: 0.5982
Itr: 5	Epoch: 6	Train Acc: 0.58854	Test Acc: 0.6207
Itr: 5	Epoch: 7	Train Acc: 0.60716	Test Acc: 0.6229
Itr: 5	Epoch: 8	Train Acc: 0.6201	Test Acc: 0.6387
Itr: 5	Epoch: 9	Train Acc: 0.63158	Test Acc: 0.6472
Itr: 5	Epoch: 10	Train Acc: 0.63952	Test Acc: 0.6648
Itr: 5	Epoch: 11	Train Acc: 0.64748	Test Acc: 0.6798
Itr: 5	Epoch: 12	Train Acc: 0.65634	Test Acc: 0.6633
Itr: 5	Epoch: 13	Train Acc: 0.66298	Test Acc: 