In [1]:
import os
import time
import copy
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
from fvcore.nn import FlopCountAnalysis
import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug


In [3]:
def _count_correct(output, lab):
    """Return how many rows of ``output`` have their argmax equal to the matching entry of ``lab``."""
    return (output.argmax(dim=-1) == lab).sum().item()


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return ``(mean_loss, accuracy)``.

    Runs in eval mode under ``torch.no_grad()`` so no autograd graph is built.
    Each batch's loss is re-weighted by its size, so the returned loss is a
    per-sample mean (assuming ``criterion`` averages over the batch).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, num_exp = 0.0, 0, 0
    with torch.no_grad():
        for datum in loader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            output = model(img)
            n_b = lab.shape[0]
            loss_sum += criterion(output, lab).item() * n_b
            correct += _count_correct(output, lab)
            num_exp += n_b
    if num_exp == 0:
        raise ValueError('test_loader yielded no samples')
    return loss_sum / num_exp, correct / num_exp


def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs, device=None):
    """Train ``model`` for ``epochs`` epochs, then evaluate it once on ``test_loader``.

    Each epoch iterates ``train_loader`` behind a tqdm progress bar. It steps
    ``optimizer`` once per batch and ``scheduler`` once per epoch, then prints
    the running train accuracy and the mean train loss. On the very first batch,
    the FLOPs of a forward pass are counted with fvcore's ``FlopCountAnalysis``
    and printed.

    Args:
        model: module to train; moved to ``device``.
        train_loader: iterable of ``(images, labels)`` batches used for training.
        test_loader: iterable of ``(images, labels)`` batches used for the final evaluation.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        criterion: loss module called as ``criterion(logits, labels)``; moved to ``device``.
        epochs: number of training epochs.
        device: device to run on. Defaults to ``'cuda'`` when available, else ``'cpu'``.

    Returns:
        ``(test_loss, test_acc)``: per-sample mean test loss and test accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``train_loader`` or ``test_loader`` yields no samples.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    criterion = criterion.to(device)
    model.train()

    # `ep` (not `epoch`) avoids shadowing the `epoch` helper imported from utils.
    for ep in range(epochs):
        print('Start Epoch #{}'.format(ep + 1))
        loss_sum, correct, num_exp = 0.0, 0, 0
        with tqdm.tqdm(total=len(train_loader)) as pbar:
            for i_batch, datum in enumerate(train_loader):
                img = datum[0].float().to(device)
                lab = datum[1].long().to(device)
                if ep == 0 and i_batch == 0:
                    # One-off cost report, measured on the first training batch.
                    flops = FlopCountAnalysis(model, img)
                    print('The FLOPs is {}'.format(flops.total()))
                output = model(img)
                loss = criterion(output, lab)
                n_b = lab.shape[0]
                # Re-weight by batch size so a short final batch does not skew the epoch mean.
                loss_sum += loss.item() * n_b
                correct += _count_correct(output, lab)
                num_exp += n_b

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(1)

        if num_exp == 0:
            raise ValueError('train_loader yielded no samples')
        scheduler.step()
        print('Train accuracy is {}%'.format(correct / num_exp * 100))
        print('Average train loss is {}'.format(loss_sum / num_exp))

    loss_avg, acc_avg = _evaluate(model, test_loader, criterion, device)
    print('Test accuracy is {}%'.format(acc_avg * 100))
    print('Average test loss is {}'.format(loss_avg))
    return loss_avg, acc_avg

class CustomDataset(Dataset):
    """In-memory dataset over a tensor of images and a parallel tensor of labels.

    Args:
        images: tensor whose first dimension indexes samples (e.g. ``n x c x h x w``);
            stored detached and cast to float.
        labels: tensor with one label per image; stored detached, dtype unchanged.
        transform: callable applied to each image when it is fetched, or ``None``
            to return images unchanged.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, images, labels, transform=None):
        if len(images) != len(labels):
            raise ValueError(
                'images and labels must have the same length, got {} and {}'.format(
                    len(images), len(labels)))
        self.images = images.detach().float()
        self.labels = labels.detach()
        self.transform = transform

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]

    def __len__(self):
        return self.images.shape[0]
    
# --- Configuration -----------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 0                 # seeds weight init and DataLoader shuffling
IPC = 10                 # images per class in the condensed set ("10ipc" in the file name)
MODEL_NAME = 'AlexNet'
BATCH_SIZE = 8
LR = 0.01
epochs = 20
data_path = './CIFAR10result/realres_DC_CIFAR10_ConvNet_10ipc.pt'
CIFAR10_dataset = 'CIFAR10'
CIFAR10_data_path = './CIFAR10data'

torch.manual_seed(SEED)

# --- Condensed training images -----------------------------------------------
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; train() moves every batch to its device itself.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
data = torch.load(data_path, map_location='cpu')
# Assumes data['data'][experiment][0] holds the synthetic images -- TODO confirm;
# data['data'][0][1] may already hold the matching labels.
images = data['data'][0][0]

# --- Real CIFAR10: only its test loader and metadata are used here -----------
(CIFAR10_channel, CIFAR10_im_size, CIFAR10_num_classes, CIFAR10_class_names,
 CIFAR10_mean, CIFAR10_std, CIFAR10_dst_train, CIFAR10_dst_test,
 CIFAR10_testloader) = get_dataset(CIFAR10_dataset, CIFAR10_data_path)

# Labels: IPC consecutive images per class, in class order (0,...,0, 1,...,1, ...).
# train() casts labels with .long(), so integer class indices are used directly.
condensed_labs_train = torch.arange(CIFAR10_num_classes).repeat_interleave(IPC)
assert images.shape[0] == condensed_labs_train.shape[0], (
    'expected {} condensed images, got {}'.format(condensed_labs_train.shape[0], images.shape[0]))

# Normalize with the statistics returned by get_dataset (presumably the ones the
# real test set is normalized with), so train and test inputs share one scheme.
# NOTE(review): if the condensed images were saved already normalized, this
# normalizes them a second time -- confirm how the result file was produced.
mean, std = CIFAR10_mean, CIFAR10_std
transform = transforms.Compose([transforms.Normalize(mean=mean, std=std)])

condensed_train_dst = CustomDataset(images, condensed_labs_train, transform)
CIFAR10_trainloader = torch.utils.data.DataLoader(
    condensed_train_dst, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# --- Model and optimization --------------------------------------------------
model = get_network(MODEL_NAME, CIFAR10_channel, CIFAR10_num_classes, CIFAR10_im_size).to(device)  # randomly initialised
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# train() steps the scheduler once per epoch, so anneal over the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)

# --- Train on the condensed set, evaluate on the real CIFAR10 test set -------
start = time.perf_counter()
test_loss, test_acc = train(model, CIFAR10_trainloader, CIFAR10_testloader, optimizer, scheduler, criterion, epochs)
print('Training on the condensed CIFAR10 dataset ({} images/class) takes {} seconds'.format(
    IPC, time.perf_counter() - start))




Files already downloaded and verified
Files already downloaded and verified
Start Epoch #1


  0%|                                                                                           | 0/13 [00:00<?, ?it/s]Unsupported operator aten::max_pool2d encountered 3 time(s)
 31%|█████████████████████████▌                                                         | 4/13 [00:00<00:00, 33.04it/s]

The FLOPs is 1960034304


100%|██████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 39.85it/s]


Train accuracy is 7.000000000000001%
Average train loss is 2.322800941467285
Start Epoch #2


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 191.59it/s]


Train accuracy is 15.0%
Average train loss is 2.2777725219726563
Start Epoch #3


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 203.52it/s]


Train accuracy is 16.0%
Average train loss is 2.2258460330963135
Start Epoch #4


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 206.47it/s]


Train accuracy is 22.0%
Average train loss is 2.143567943572998
Start Epoch #5


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 210.64it/s]


Train accuracy is 19.0%
Average train loss is 2.103898735046387
Start Epoch #6


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 215.61it/s]


Train accuracy is 26.0%
Average train loss is 2.007688226699829
Start Epoch #7


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 210.23it/s]


Train accuracy is 30.0%
Average train loss is 1.9272135877609253
Start Epoch #8


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 224.75it/s]


Train accuracy is 30.0%
Average train loss is 1.845277795791626
Start Epoch #9


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 182.99it/s]


Train accuracy is 32.0%
Average train loss is 1.8748187828063965
Start Epoch #10


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 224.81it/s]


Train accuracy is 33.0%
Average train loss is 1.6779120635986329
Start Epoch #11


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 220.59it/s]


Train accuracy is 48.0%
Average train loss is 1.539889931678772
Start Epoch #12


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 222.31it/s]


Train accuracy is 45.0%
Average train loss is 1.5718065452575685
Start Epoch #13


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 219.38it/s]


Train accuracy is 54.0%
Average train loss is 1.3636942100524903
Start Epoch #14


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 219.52it/s]


Train accuracy is 54.0%
Average train loss is 1.298374834060669
Start Epoch #15


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 233.81it/s]


Train accuracy is 59.0%
Average train loss is 1.2198048973083495
Start Epoch #16


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 227.73it/s]

Train accuracy is 66.0%
Average train loss is 1.147015345096588
Start Epoch #17



100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 235.76it/s]


Train accuracy is 69.0%
Average train loss is 1.0841358757019044
Start Epoch #18


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 234.99it/s]


Train accuracy is 75.0%
Average train loss is 1.028213918209076
Start Epoch #19


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 235.62it/s]


Train accuracy is 76.0%
Average train loss is 1.0008485651016235
Start Epoch #20


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 234.38it/s]


Train accuracy is 76.0%
Average train loss is 0.9808714413642883
Test accuracy is 29.520000000000003%
Average test loss is 2.0852886209487913
Training on the original CIFAR10 dataset takes 3.881880760192871 seconds
