In [1]:
import os
import time
import copy
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
from fvcore.nn import FlopCountAnalysis
import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug


In [3]:
def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs):
    """Train ``model`` for ``epochs`` epochs, then evaluate it once on ``test_loader``.

    Uses the notebook-global ``device`` for the model, the loss module and
    every batch. On the very first training batch the fvcore FLOP count of a
    forward pass is printed. ``scheduler.step()`` is called once per epoch.
    Per-epoch train accuracy/loss and the final test accuracy/loss are printed.

    Args:
        model: network to train; moved to ``device``.
        train_loader: iterable of ``(images, labels)`` batches with ``len()``.
        test_loader: iterable of ``(images, labels)`` batches for evaluation.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped at the end of every epoch.
        criterion: loss module applied to ``(logits, labels)``.
        epochs: number of passes over ``train_loader``.

    Returns:
        ``(test_loss, test_acc)``: sample-weighted mean test loss and test
        accuracy as a fraction in [0, 1]; both are ``nan`` when
        ``test_loader`` yields no samples.
    """
    model = model.to(device)
    criterion = criterion.to(device)

    # `epoch_idx`, not `epoch`: `epoch` is a helper imported from utils.
    for epoch_idx in range(epochs):
        print('Start Epoch #{}'.format(epoch_idx + 1))
        model.train()
        loss_sum, correct, num_exp = 0.0, 0, 0
        with tqdm.tqdm(total=len(train_loader)) as pbar:
            for i_batch, datum in enumerate(train_loader):
                img = datum[0].float().to(device)
                lab = datum[1].long().to(device)
                if epoch_idx == 0 and i_batch == 0:
                    # The count is for the whole first batch (`img`), not a single image.
                    flops = FlopCountAnalysis(model, img)
                    print('The FLOPs is {}'.format(flops.total()))
                n_b = lab.shape[0]
                output = model(img)
                loss = criterion(output, lab)
                loss_sum += loss.item() * n_b
                correct += output.detach().argmax(dim=-1).eq(lab).sum().item()
                num_exp += n_b

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(1)

        # Guard against an empty loader instead of raising ZeroDivisionError.
        loss_avg = loss_sum / num_exp if num_exp else float('nan')
        acc_avg = correct / num_exp if num_exp else float('nan')
        scheduler.step()
        print('Train accuracy is {}%'.format(acc_avg * 100))
        print('Average train loss is {}'.format(loss_avg))

    model.eval()
    loss_sum, correct, num_exp = 0.0, 0, 0
    # Evaluation only: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for datum in test_loader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            n_b = lab.shape[0]
            output = model(img)
            loss_sum += criterion(output, lab).item() * n_b
            correct += output.argmax(dim=-1).eq(lab).sum().item()
            num_exp += n_b

    loss_avg = loss_sum / num_exp if num_exp else float('nan')
    acc_avg = correct / num_exp if num_exp else float('nan')
    print('Test accuracy is {}%'.format(acc_avg * 100))
    print('Average test loss is {}'.format(loss_avg))

    return loss_avg, acc_avg

class CustomDataset(Dataset):
    """Map-style dataset over in-memory image and label tensors.

    Images are detached and converted to float32 once, at construction time;
    ``transform`` is applied lazily to each image when it is fetched.

    Args:
        images: tensor whose leading dimension indexes samples
            (n x c x h x w in this notebook).
        labels: tensor of per-sample labels, indexed in parallel with ``images``.
        transform: callable applied to each image tensor on access.
    """

    def __init__(self, images, labels, transform):
        self.images = images.detach().to(torch.float32)
        self.labels = labels.detach()
        self.transform = transform

    def __getitem__(self, index):
        image = self.transform(self.images[index])
        return image, self.labels[index]

    def __len__(self):
        # Sample count is the size of the leading dimension.
        return self.images.size(0)
    
# Train a freshly initialized AlexNet on the condensed MNIST set stored at
# `data_path`, then evaluate it on the real MNIST test set.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 0
# Seeds the global RNGs; shuffling is additionally pinned by the DataLoader
# generator below in case a helper reseeds the global RNG.
torch.manual_seed(SEED)
np.random.seed(SEED)

IPC = 10  # images per class in the condensed set (the "10ipc" in the file name)
data_path = './MNISTresult/realres_DC_MNIST_ConvNet_10ipc.pt'
# map_location='cpu' lets a file saved with GPU tensors load on a CPU-only
# machine; batches are moved to `device` inside train().
data = torch.load(data_path, map_location='cpu')
# NOTE(review): assumes data['data'][0][0] is the image tensor of the first
# saved experiment, ordered class by class with IPC images each — confirm.
images = data['data'][0][0]

MNIST_dataset = 'MNIST'
MNIST_data_path = './MNISTdata'
MNIST_channel, MNIST_im_size, MNIST_num_classes, MNIST_class_names, MNIST_mean, MNIST_std, MNIST_dst_train, MNIST_dst_test, MNIST_testloader = get_dataset(MNIST_dataset, MNIST_data_path)

# Labels 0,...,0,1,...,1,...: image i belongs to class i // IPC.
condensed_labs_train = torch.arange(MNIST_num_classes).repeat_interleave(IPC)
assert images.shape[0] == condensed_labs_train.shape[0], 'image count != num_classes * IPC'

mean = [0.1307]
std = [0.3081]
# NOTE(review): if the condensed images were saved already normalized, this
# normalizes them a second time — confirm how the .pt file was produced.
transform = transforms.Compose([transforms.Normalize(mean=mean, std=std)])

condensed_train_dst = CustomDataset(images, condensed_labs_train, transform)
MNIST_trainloader = torch.utils.data.DataLoader(
    condensed_train_dst, batch_size=8, shuffle=True, num_workers=0,
    generator=torch.Generator().manual_seed(SEED),  # reproducible shuffle order
)
model = get_network('AlexNet', MNIST_channel, MNIST_num_classes, MNIST_im_size).to(device)  # get a random model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epochs = 20
# Anneal over exactly the number of training epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)
start = time.time()
# The timing covers training, the one-off FLOP count and the test evaluation.
train(model, MNIST_trainloader, MNIST_testloader, optimizer, scheduler, criterion, epochs)
print('Training on the condensed MNIST dataset takes {} seconds'.format(time.time() - start))




Start Epoch #1


  0%|                                                                                           | 0/13 [00:00<?, ?it/s]Unsupported operator aten::max_pool2d encountered 3 time(s)
 31%|█████████████████████████▌                                                         | 4/13 [00:00<00:00, 33.56it/s]

The FLOPs is 1907605504


100%|██████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 46.42it/s]


Train accuracy is 8.0%
Average train loss is 2.315928840637207
Start Epoch #2


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 209.94it/s]


Train accuracy is 12.0%
Average train loss is 2.2920486545562744
Start Epoch #3


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 197.24it/s]


Train accuracy is 12.0%
Average train loss is 2.2721188640594483
Start Epoch #4


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 191.02it/s]


Train accuracy is 20.0%
Average train loss is 2.253855724334717
Start Epoch #5


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 201.60it/s]


Train accuracy is 31.0%
Average train loss is 2.231180410385132
Start Epoch #6


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 207.47it/s]


Train accuracy is 25.0%
Average train loss is 2.2040181159973145
Start Epoch #7


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 209.34it/s]


Train accuracy is 43.0%
Average train loss is 2.1644834518432616
Start Epoch #8


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 201.44it/s]


Train accuracy is 36.0%
Average train loss is 2.118197660446167
Start Epoch #9


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 178.72it/s]


Train accuracy is 33.0%
Average train loss is 2.070787172317505
Start Epoch #10


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 175.59it/s]


Train accuracy is 48.0%
Average train loss is 1.9771376848220825
Start Epoch #11


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 165.71it/s]


Train accuracy is 46.0%
Average train loss is 1.8678407669067383
Start Epoch #12


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 180.70it/s]


Train accuracy is 61.0%
Average train loss is 1.726105399131775
Start Epoch #13


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 191.24it/s]


Train accuracy is 62.0%
Average train loss is 1.5885427331924438
Start Epoch #14


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 188.44it/s]


Train accuracy is 63.0%
Average train loss is 1.4081426239013672
Start Epoch #15


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 157.06it/s]


Train accuracy is 59.0%
Average train loss is 1.34196298122406
Start Epoch #16


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 180.91it/s]


Train accuracy is 71.0%
Average train loss is 1.1887889528274536
Start Epoch #17


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 189.24it/s]


Train accuracy is 82.0%
Average train loss is 1.0989443731307984
Start Epoch #18


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 170.64it/s]


Train accuracy is 80.0%
Average train loss is 1.0355815982818604
Start Epoch #19


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 192.84it/s]


Train accuracy is 84.0%
Average train loss is 0.9929377126693726
Start Epoch #20


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 187.66it/s]

Train accuracy is 89.0%
Average train loss is 0.9691627073287964





Test accuracy is 83.32000000000001%
Average test loss is 1.6897642930984498
Training on the original MNIST dataset takes 3.576970100402832 seconds


In [2]:
def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs):
    """Train ``model`` for ``epochs`` epochs, then evaluate it once on ``test_loader``.

    Uses the notebook-global ``device`` for the model, the loss module and
    every batch. On the very first training batch the fvcore FLOP count of a
    forward pass is printed. ``scheduler.step()`` is called once per epoch.
    Per-epoch train accuracy/loss and the final test accuracy/loss are printed.

    Args:
        model: network to train; moved to ``device``.
        train_loader: iterable of ``(images, labels)`` batches with ``len()``.
        test_loader: iterable of ``(images, labels)`` batches for evaluation.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped at the end of every epoch.
        criterion: loss module applied to ``(logits, labels)``.
        epochs: number of passes over ``train_loader``.

    Returns:
        ``(test_loss, test_acc)``: sample-weighted mean test loss and test
        accuracy as a fraction in [0, 1]; both are ``nan`` when
        ``test_loader`` yields no samples.
    """
    model = model.to(device)
    criterion = criterion.to(device)

    # `epoch_idx`, not `epoch`: `epoch` is a helper imported from utils.
    for epoch_idx in range(epochs):
        print('Start Epoch #{}'.format(epoch_idx + 1))
        model.train()
        loss_sum, correct, num_exp = 0.0, 0, 0
        with tqdm.tqdm(total=len(train_loader)) as pbar:
            for i_batch, datum in enumerate(train_loader):
                img = datum[0].float().to(device)
                lab = datum[1].long().to(device)
                if epoch_idx == 0 and i_batch == 0:
                    # The count is for the whole first batch (`img`), not a single image.
                    flops = FlopCountAnalysis(model, img)
                    print('The FLOPs is {}'.format(flops.total()))
                n_b = lab.shape[0]
                output = model(img)
                loss = criterion(output, lab)
                loss_sum += loss.item() * n_b
                correct += output.detach().argmax(dim=-1).eq(lab).sum().item()
                num_exp += n_b

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(1)

        # Guard against an empty loader instead of raising ZeroDivisionError.
        loss_avg = loss_sum / num_exp if num_exp else float('nan')
        acc_avg = correct / num_exp if num_exp else float('nan')
        scheduler.step()
        print('Train accuracy is {}%'.format(acc_avg * 100))
        print('Average train loss is {}'.format(loss_avg))

    model.eval()
    loss_sum, correct, num_exp = 0.0, 0, 0
    # Evaluation only: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for datum in test_loader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            n_b = lab.shape[0]
            output = model(img)
            loss_sum += criterion(output, lab).item() * n_b
            correct += output.argmax(dim=-1).eq(lab).sum().item()
            num_exp += n_b

    loss_avg = loss_sum / num_exp if num_exp else float('nan')
    acc_avg = correct / num_exp if num_exp else float('nan')
    print('Test accuracy is {}%'.format(acc_avg * 100))
    print('Average test loss is {}'.format(loss_avg))

    return loss_avg, acc_avg

class CustomDataset(Dataset):
    """Map-style dataset over in-memory image and label tensors.

    Images are detached and converted to float32 once, at construction time;
    ``transform`` is applied lazily to each image when it is fetched.

    Args:
        images: tensor whose leading dimension indexes samples
            (n x c x h x w in this notebook).
        labels: tensor of per-sample labels, indexed in parallel with ``images``.
        transform: callable applied to each image tensor on access.
    """

    def __init__(self, images, labels, transform):
        self.images = images.detach().to(torch.float32)
        self.labels = labels.detach()
        self.transform = transform

    def __getitem__(self, index):
        image = self.transform(self.images[index])
        return image, self.labels[index]

    def __len__(self):
        # Sample count is the size of the leading dimension.
        return self.images.size(0)
    
# Train a freshly initialized AlexNet on the condensed MNIST set stored at
# `data_path`, then evaluate it on the real MNIST test set.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 0
# Seeds the global RNGs; shuffling is additionally pinned by the DataLoader
# generator below in case a helper reseeds the global RNG.
torch.manual_seed(SEED)
np.random.seed(SEED)

IPC = 10  # images per class in the condensed set (the "10ipc" in the file name)
data_path = './MNISTresult/noiseres_DC_MNIST_ConvNet_10ipc.pt'
# map_location='cpu' lets a file saved with GPU tensors load on a CPU-only
# machine; batches are moved to `device` inside train().
data = torch.load(data_path, map_location='cpu')
# NOTE(review): assumes data['data'][0][0] is the image tensor of the first
# saved experiment, ordered class by class with IPC images each — confirm.
images = data['data'][0][0]

MNIST_dataset = 'MNIST'
MNIST_data_path = './MNISTdata'
MNIST_channel, MNIST_im_size, MNIST_num_classes, MNIST_class_names, MNIST_mean, MNIST_std, MNIST_dst_train, MNIST_dst_test, MNIST_testloader = get_dataset(MNIST_dataset, MNIST_data_path)

# Labels 0,...,0,1,...,1,...: image i belongs to class i // IPC.
condensed_labs_train = torch.arange(MNIST_num_classes).repeat_interleave(IPC)
assert images.shape[0] == condensed_labs_train.shape[0], 'image count != num_classes * IPC'

mean = [0.1307]
std = [0.3081]
# NOTE(review): if the condensed images were saved already normalized, this
# normalizes them a second time — confirm how the .pt file was produced.
transform = transforms.Compose([transforms.Normalize(mean=mean, std=std)])

condensed_train_dst = CustomDataset(images, condensed_labs_train, transform)
MNIST_trainloader = torch.utils.data.DataLoader(
    condensed_train_dst, batch_size=8, shuffle=True, num_workers=0,
    generator=torch.Generator().manual_seed(SEED),  # reproducible shuffle order
)
model = get_network('AlexNet', MNIST_channel, MNIST_num_classes, MNIST_im_size).to(device)  # get a random model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epochs = 20
# Anneal over exactly the number of training epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)
start = time.time()
# The timing covers training, the one-off FLOP count and the test evaluation.
train(model, MNIST_trainloader, MNIST_testloader, optimizer, scheduler, criterion, epochs)
print('Training on the condensed MNIST dataset takes {} seconds'.format(time.time() - start))




Start Epoch #1


  0%|                                                                                           | 0/13 [00:00<?, ?it/s]Unsupported operator aten::max_pool2d encountered 3 time(s)
  8%|██████▍                                                                            | 1/13 [00:02<00:34,  2.88s/it]

The FLOPs is 1907605504


100%|██████████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.41it/s]


Train accuracy is 8.0%
Average train loss is 2.328088493347168
Start Epoch #2


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 184.95it/s]


Train accuracy is 20.0%
Average train loss is 2.290717945098877
Start Epoch #3


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 174.34it/s]


Train accuracy is 21.0%
Average train loss is 2.262152605056763
Start Epoch #4


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 197.39it/s]


Train accuracy is 27.0%
Average train loss is 2.2375422382354735
Start Epoch #5


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 188.79it/s]


Train accuracy is 28.999999999999996%
Average train loss is 2.1944645977020265
Start Epoch #6


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 196.93it/s]


Train accuracy is 42.0%
Average train loss is 2.1312132358551024
Start Epoch #7


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 212.60it/s]


Train accuracy is 31.0%
Average train loss is 2.0392752838134767
Start Epoch #8


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 190.33it/s]


Train accuracy is 36.0%
Average train loss is 1.9330228233337403
Start Epoch #9


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 156.47it/s]


Train accuracy is 40.0%
Average train loss is 1.776557388305664
Start Epoch #10


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 174.12it/s]


Train accuracy is 49.0%
Average train loss is 1.5865830945968629
Start Epoch #11


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 195.03it/s]


Train accuracy is 61.0%
Average train loss is 1.290006217956543
Start Epoch #12


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 199.91it/s]


Train accuracy is 65.0%
Average train loss is 1.09957603931427
Start Epoch #13


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 187.18it/s]


Train accuracy is 68.0%
Average train loss is 0.9655977630615235
Start Epoch #14


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 192.20it/s]


Train accuracy is 80.0%
Average train loss is 0.7666957998275756
Start Epoch #15


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 188.69it/s]


Train accuracy is 83.0%
Average train loss is 0.694795196056366
Start Epoch #16


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 154.56it/s]


Train accuracy is 87.0%
Average train loss is 0.5846870398521423
Start Epoch #17


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 174.64it/s]


Train accuracy is 89.0%
Average train loss is 0.5356438875198364
Start Epoch #18


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 199.46it/s]


Train accuracy is 94.0%
Average train loss is 0.4765075373649597
Start Epoch #19


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 167.34it/s]


Train accuracy is 95.0%
Average train loss is 0.4528187727928162
Start Epoch #20


100%|█████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 215.12it/s]

Train accuracy is 97.0%
Average train loss is 0.4353028154373169





Test accuracy is 84.38%
Average test loss is 1.3267022100448609
Training on the original MNIST dataset takes 6.3093321323394775 seconds
