In [1]:
import os
import time
import copy
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
from fvcore.nn import FlopCountAnalysis
import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug


In [3]:
def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs):
    """Train `model` for `epochs` passes over `train_loader`, then evaluate it on `test_loader`.

    Placement of the model, the criterion and every batch uses the module-level
    `device` string defined in the notebook's data/model cell.

    Args:
        model: network to train; moved to `device`.
        train_loader: iterable of (images, labels) batches used for training.
        test_loader: iterable of (images, labels) batches used for the final evaluation.
        optimizer: optimizer over `model`'s parameters; stepped once per batch.
        scheduler: learning-rate scheduler; stepped once per epoch.
        criterion: loss module; assumed to use mean reduction, since the per-batch loss
            is re-weighted by batch size to form a per-sample average.
        epochs: number of passes over `train_loader`.

    Returns:
        (test_loss, test_acc): per-sample mean loss and accuracy (a fraction in [0, 1])
        on `test_loader`.

    Raises:
        ValueError: if `train_loader` or `test_loader` yields no samples.
    """
    model = model.to(device)
    criterion = criterion.to(device)

    # `ep` rather than `epoch`, so the `epoch` helper imported from utils is not shadowed.
    for ep in range(epochs):
        print('Start Epoch #{}'.format(ep + 1))
        loss_avg, acc_avg = _train_one_epoch(model, train_loader, optimizer, criterion,
                                             report_flops=(ep == 0))
        scheduler.step()
        print('Train accuracy is {}%'.format(acc_avg * 100))
        print('Average train loss is {}'.format(loss_avg))

    loss_avg, acc_avg = _evaluate(model, test_loader, criterion, 'test_loader')
    print('Test accuracy is {}%'.format(acc_avg * 100))
    print('Average test loss is {}'.format(loss_avg))

    return loss_avg, acc_avg


def _train_one_epoch(model, loader, optimizer, criterion, report_flops=False):
    """Run one optimisation pass over `loader`; return (per-sample mean loss, accuracy)."""
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    with tqdm.tqdm(total=len(loader)) as pbar:
        for i_batch, datum in enumerate(loader):
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            if report_flops and i_batch == 0:
                # Reported once, on the first batch, before any parameter update.
                flops = FlopCountAnalysis(model, img)
                print('The FLOPs is {}'.format(flops.total()))

            output = model(img)
            loss = criterion(output, lab)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_b = lab.shape[0]
            loss_sum += loss.item() * n_b
            correct += (output.argmax(dim=-1) == lab).sum().item()
            seen += n_b
            pbar.update(1)
    return _averages(loss_sum, correct, seen, 'train_loader')


def _evaluate(model, loader, criterion, loader_name):
    """Return (per-sample mean loss, accuracy) of `model` on `loader` with gradients disabled."""
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    # Inference only: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for datum in loader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            output = model(img)
            n_b = lab.shape[0]
            loss_sum += criterion(output, lab).item() * n_b
            correct += (output.argmax(dim=-1) == lab).sum().item()
            seen += n_b
    return _averages(loss_sum, correct, seen, loader_name)


def _averages(loss_sum, correct, seen, loader_name):
    """Turn running sums into (mean loss, accuracy); reject loaders that produced no samples."""
    if seen == 0:
        raise ValueError('{} yielded no samples'.format(loader_name))
    return loss_sum / seen, correct / seen

class CustomDataset(Dataset):
    """Map-style dataset over image and label tensors that are already in memory.

    Args:
        images: tensor indexed along dim 0 (expected n x c x h x w); detached and cast
            to float32.
        labels: tensor of per-image labels indexed along dim 0; detached.
        transform: optional callable applied to each image when it is fetched.
            ``None`` (the default) returns images unchanged.

    Raises:
        ValueError: if `images` and `labels` disagree in length along dim 0.
    """

    def __init__(self, images, labels, transform=None):  # images: n x c x h x w tensor
        self.images = images.detach().float()
        self.labels = labels.detach()
        # Fail at construction rather than with an IndexError halfway through an epoch.
        if self.images.shape[0] != self.labels.shape[0]:
            raise ValueError('got {} images but {} labels'.format(
                self.images.shape[0], self.labels.shape[0]))
        self.transform = transform

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]

    def __len__(self):
        return self.images.shape[0]
    
# ---- Configuration ------------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = 0          # fixes weight init and DataLoader shuffling so reruns are reproducible
ipc = 10          # images per class in the condensed set (the "_10ipc" in data_path)
batch_size = 8
lr = 0.01
epochs = 20
data_path = './MNISTresult/realres_DC_MNIST_ConvNet_10ipc.pt'
MNIST_dataset = 'MNIST'
MNIST_data_path = './MNISTdata'
# MNIST normalization constants. These should agree with the normalization get_dataset
# applies to the test set (MNIST_mean / MNIST_std) - TODO confirm.
mean = [0.1307]
std = [0.3081]

torch.manual_seed(seed)

# ---- Condensed images -----------------------------------------------------------
# map_location='cpu' lets a file saved from a GPU session load on a CPU-only machine.
data = torch.load(data_path, map_location='cpu')
# Presumably data['data'][0] is [images, labels] for the first experiment - TODO confirm.
# NOTE(review): if labels are stored at data['data'][0][1], prefer them over the
# reconstructed labels below.
images = data['data'][0][0]

# ---- Real MNIST (only the test loader is used here) ---------------------------------
MNIST_channel, MNIST_im_size, MNIST_num_classes, MNIST_class_names, MNIST_mean, MNIST_std, MNIST_dst_train, MNIST_dst_test, MNIST_testloader = get_dataset(MNIST_dataset, MNIST_data_path)

# ---- Condensed training set -------------------------------------------------------
# Labels assume class-major storage: ipc images of class 0, then ipc of class 1, ...
assert images.shape[0] == ipc * MNIST_num_classes, (
    'expected {} condensed images, found {}'.format(ipc * MNIST_num_classes, images.shape[0]))
condensed_labs_train = torch.arange(MNIST_num_classes).repeat_interleave(ipc)
# NOTE(review): if the condensation run saved images that are already normalized,
# this Normalize is applied a second time, while the test loader normalizes real
# images only once. Confirm how the .pt file was produced.
transform = transforms.Compose([transforms.Normalize(mean=mean, std=std)])
condensed_train_dst = CustomDataset(images, condensed_labs_train, transform)
MNIST_trainloader = torch.utils.data.DataLoader(condensed_train_dst, batch_size=batch_size, shuffle=True, num_workers=0)

# ---- Model and optimisation ---------------------------------------------------------
model = get_network('AlexNet', MNIST_channel, MNIST_num_classes, MNIST_im_size).to(device)  # randomly initialised
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# T_max is tied to `epochs` so the cosine decay spans exactly the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)

start = time.perf_counter()
train(model, MNIST_trainloader, MNIST_testloader, optimizer, scheduler, criterion, epochs)
# The timed span includes the final test-set evaluation performed inside train().
print('Training on the condensed MNIST dataset takes {} seconds'.format(time.perf_counter() - start))




Start Epoch #1


  0%|                                                                                           | 0/13 [00:00<?, ?it/s]Unsupported operator aten::max_pool2d encountered 3 time(s)
 23%|███████████████████▏                                                               | 3/13 [00:00<00:00, 27.26it/s]

The FLOPs is 1907605504


100%|██████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 43.13it/s]


Train accuracy is 5.0%
Average train loss is 2.319987335205078
Start Epoch #2


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 199.77it/s]


Train accuracy is 20.0%
Average train loss is 2.296140041351318
Start Epoch #3


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 173.58it/s]


Train accuracy is 23.0%
Average train loss is 2.2753256511688233
Start Epoch #4


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 189.15it/s]


Train accuracy is 32.0%
Average train loss is 2.251053409576416
Start Epoch #5


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 203.99it/s]


Train accuracy is 20.0%
Average train loss is 2.2358448696136475
Start Epoch #6


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 180.28it/s]


Train accuracy is 31.0%
Average train loss is 2.2027390098571775
Start Epoch #7


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 188.75it/s]


Train accuracy is 40.0%
Average train loss is 2.163063716888428
Start Epoch #8


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 186.66it/s]


Train accuracy is 42.0%
Average train loss is 2.1130717849731444
Start Epoch #9


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 197.80it/s]


Train accuracy is 36.0%
Average train loss is 2.0642401790618896
Start Epoch #10


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 185.44it/s]


Train accuracy is 40.0%
Average train loss is 1.9499921894073486
Start Epoch #11


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 194.38it/s]


Train accuracy is 46.0%
Average train loss is 1.8242185974121095
Start Epoch #12


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 196.98it/s]


Train accuracy is 51.0%
Average train loss is 1.682830400466919
Start Epoch #13


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 181.76it/s]


Train accuracy is 62.0%
Average train loss is 1.5456060647964478
Start Epoch #14


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 209.18it/s]


Train accuracy is 68.0%
Average train loss is 1.3750854015350342
Start Epoch #15


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 213.85it/s]


Train accuracy is 69.0%
Average train loss is 1.2651242113113403
Start Epoch #16


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 203.25it/s]


Train accuracy is 73.0%
Average train loss is 1.1190080404281617
Start Epoch #17


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 212.94it/s]


Train accuracy is 77.0%
Average train loss is 1.0172993516921998
Start Epoch #18


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 209.40it/s]


Train accuracy is 79.0%
Average train loss is 0.9866066837310791
Start Epoch #19


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 219.33it/s]


Train accuracy is 84.0%
Average train loss is 0.925868968963623
Start Epoch #20


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 228.93it/s]


Train accuracy is 86.0%
Average train loss is 0.9039381456375122
Test accuracy is 82.89%
Average test loss is 1.647504403114319
Training on the original MNIST dataset takes 3.6319658756256104 seconds
