In [None]:
import torch
import os
import copy
import torch.nn as nn
from tqdm import tqdm
from torchvision import datasets, transforms
from utils import get_dataset, get_network, get_daparam,TensorDataset, epoch, ParamDiffAug

In [None]:
class arguments:
    """Bare attribute container standing in for a command-line argument namespace."""
    pass

args = arguments()
# --- data / paths ---
args.zca = False
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.buffer_path = './buffers'
args.dataset = 'kinetics400'
args.data_path = 'kinetics-dataset/k400'
args.batch_real = 10
args.subset = 'imagenette'  # only consulted when dataset == "ImageNet" (save_dir layout)
args.model = 'ConvNet'
# --- expert training ---
args.batch_train = 10
args.num_experts = 100
args.lr_teacher = 0.01
args.mom = 0
args.l2 = 0
args.train_epochs = 50
# The two attributes below are read by the expert-training loop and must exist.
# decay: when True, the teacher lr is multiplied by 0.1 once, after epoch train_epochs // 2 + 1.
args.decay = False
# save_interval: number of expert trajectories bundled into each replay_buffer_{n}.pt file.
args.save_interval = 10
args.dsa = True
args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'

In [None]:
# Frame geometry and per-channel normalisation statistics for the Kinetics clips.
im_size = (128, 128)
channel = 3  # RGB frames; used by the channel-stats printout and by get_network
mean = [0.43216, 0.394666, 0.37645]
std = [0.22803, 0.22145, 0.216989]
# NOTE(review): every clip is labelled 0 when the dataset is organised, so this is a
# single-class setup. If the network emits num_classes logits, CrossEntropyLoss over a
# single class is identically zero and the experts get no gradient - confirm intended.
num_classes = 1
class_map = {0: 1}
frames_per_clip = 10

# Applied by the dataset to each whole clip, a uint8 tensor of shape (T, C, H, W)
# because output_format="TCHW" is requested below. Tensor transforms act on the trailing
# (C, H, W) dims, so every frame is converted, resized, cropped and normalised the same
# way. The fixed output size is what lets clips be stacked with torch.cat afterwards.
# (The previous ToPILImage -> Normalize chain could not run: Normalize needs a tensor,
# ToPILImage rejects 4-D clips, and the transform was never passed to the dataset.)
transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float32),  # uint8 [0, 255] -> float [0, 1]
    transforms.Resize(im_size),
    transforms.CenterCrop(im_size),
    transforms.Normalize(mean=mean, std=std),
])
dst_train = datasets.Kinetics(args.data_path, frames_per_clip=frames_per_clip, split="train",
                              transform=transform, output_format="TCHW")
# Evaluate on the held-out split; without split= this silently re-used the train split.
dst_test = datasets.Kinetics(args.data_path, frames_per_clip=frames_per_clip, split="val",
                             transform=transform, output_format="TCHW")

In [None]:
print('Hyper-parameters: \n', args.__dict__)

# Replay buffers are written under
#   <buffer_path>/<dataset>[/<subset>/<res>][_NO_ZCA]/<model>
# where the bracketed parts only apply to ImageNet and to non-ZCA CIFAR respectively.
_save_parts = [args.buffer_path, args.dataset]
if args.dataset == "ImageNet":
    _save_parts.extend([args.subset, str(args.res)])
if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:
    _save_parts[-1] = _save_parts[-1] + "_NO_ZCA"
save_dir = os.path.join(*_save_parts, args.model)
del _save_parts

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [None]:
''' organize the real dataset '''
# Materialise every training clip in CPU memory so each expert epoch can reuse it.
images_all = []
labels_all = []
indices_class = [[] for c in range(num_classes)]
print("BUILDING DATASET")
for i in tqdm(range(len(dst_train))):
    # Element 0 of each sample is the video clip (for torchvision Kinetics the sample is
    # (video, audio, label)).
    clip = dst_train[i][0]
    images_all.append(torch.unsqueeze(clip, dim=0))
    # NOTE(review): the dataset's own label is ignored; every clip is class 0.
    labels_all.append(0)

for i, lab in tqdm(enumerate(labels_all)):
    indices_class[lab].append(i)
images_all = torch.cat(images_all, dim=0).to("cpu")
labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu")

for c in range(num_classes):
    print('class c = %d: %d real images'%(c, len(indices_class[c])))

# Channels sit at dim -3 for both (N, C, H, W) images and (N, T, C, H, W) clip batches,
# whereas dim 1 of a clip batch is time. Assumes clips are (T, C, H, W) (torchvision's
# "TCHW" output_format) - TODO confirm if THWC is ever used. .float() lets the stats
# also run on raw uint8 frames, for which torch.mean/std are not defined.
for ch in range(channel):
    channel_vals = images_all.select(-3, ch).float()
    print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(channel_vals), torch.std(channel_vals)))

criterion = nn.CrossEntropyLoss().to(args.device)

trajectories = []

# Kept under its own name (instead of re-binding dst_train) so re-running this cell
# rebuilds from the original video dataset.
dst_train_tensors = TensorDataset(copy.deepcopy(images_all.detach()), copy.deepcopy(labels_all.detach()))
trainloader = torch.utils.data.DataLoader(dst_train_tensors, batch_size=args.batch_train, shuffle=True, num_workers=0)


def _video_label_collate(batch):
    """Collate (video, audio, label, ...) samples into a (videos, labels) batch.

    Audio is dropped so each batch has the same (inputs, labels) layout that trainloader
    yields from TensorDataset. Labels are forced to 0 to match the single-class labelling
    used for the training clips above.
    """
    videos = torch.stack([sample[0] for sample in batch], dim=0)
    labels = torch.zeros(len(batch), dtype=torch.long)
    return videos, labels


# Test clips are streamed lazily rather than held in memory.
testloader = torch.utils.data.DataLoader(dst_test, batch_size=args.batch_train, shuffle=False,
                                         num_workers=0, collate_fn=_video_label_collate)

''' set augmentation for whole-dataset training '''
args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None)
args.dc_aug_param['strategy'] = 'crop_scale_rotate'  # for whole-dataset training
print('DC augmentation parameters: \n', args.dc_aug_param)


In [None]:
def _snapshot_params(net):
    """Return a detached CPU copy of every parameter of ``net``, in ``parameters()`` order.

    ``copy=True`` matters. On a CPU device ``.cpu()`` returns the very same tensor, so a
    plain ``p.detach().cpu()`` would alias the live weights. Every stored snapshot would
    then be overwritten in place by later optimizer steps.
    """
    return [p.detach().to("cpu", copy=True) for p in net.parameters()]


def _save_trajectories(trajectories, save_dir):
    """Write ``trajectories`` with torch.save to the first unused replay_buffer_{n}.pt in ``save_dir``."""
    n = 0
    while os.path.exists(os.path.join(save_dir, "replay_buffer_{}.pt".format(n))):
        n += 1
    path = os.path.join(save_dir, "replay_buffer_{}.pt".format(n))
    print("Saving {}".format(path))
    torch.save(trajectories, path)


for it in range(0, args.num_experts):

    ''' Train one expert on the real data and record its parameter trajectory '''
    # New network per expert (presumably randomly initialised by get_network).
    teacher_net = get_network(args.model, channel, num_classes, im_size).to(args.device)
    teacher_net.train()
    lr = args.lr_teacher
    teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
    teacher_optim.zero_grad()

    # Parameters at initialisation, then after every epoch: train_epochs + 1 snapshots.
    timestamps = [_snapshot_params(teacher_net)]

    # Epoch index after which the lr is multiplied by 0.1 (only when args.decay is set).
    lr_schedule = [args.train_epochs // 2 + 1]

    for e in range(args.train_epochs):

        train_loss, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim,
                                    criterion=criterion, args=args, aug=True)

        test_loss, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None,
                                    criterion=criterion, args=args, aug=False)

        print("Itr: {}\tEpoch: {}\tTrain Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc))

        timestamps.append(_snapshot_params(teacher_net))

        if e in lr_schedule and args.decay:
            lr *= 0.1
            # A fresh optimizer applies the new lr (its momentum state starts empty again).
            teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
            teacher_optim.zero_grad()

    trajectories.append(timestamps)

    if len(trajectories) == args.save_interval:
        _save_trajectories(trajectories, save_dir)
        trajectories = []

# Flush the remainder when num_experts is not a multiple of save_interval. Otherwise
# the last few expert trajectories would be silently discarded.
if trajectories:
    _save_trajectories(trajectories, save_dir)
    trajectories = []