<a href="https://colab.research.google.com/github/Joyce-ZhouY/ECE1512-ProjectB/blob/main/ProjectB_Part2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# import libraries
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import time
import copy
from torchvision.utils import save_image
import os

# import utility classes
import networks
import utils

In [2]:
# initialize args
args = type('', (), {})()
args.batch_train = 256
args.batch_real = 256
args.epoch_eval_train = 30 # the number of epoch to train a model with synthetic dataset 
args.epoch_train = 10 # the number of epoch to train a network
args.lr_net = 0.01 # learning rate of the network
args.lr_img = 0.1 # learning rate of synthetic dataset
args.dsa_strategy = None
args.num_eval = 5 # number of randomly initialized networks
args.num_exp = 1 # number of experiments
args.ipc = 10 # image per class
args.outer_loop, args.inner_loop = utils.get_loops(args.ipc)
args.batch_real = 256 # batch size for original dataset
args.dsa = None
args.method = 'DC'
args.Iteration = 1000 # training iterations
args.model = 'ConvNet'
args.dataset = 'MNIST'
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = utils.ParamDiffAug()
args.dsa = True if args.method == 'DSA' else False
args.dc_aug_param = None
args.dis_metric = 'ours' # distance metric
args.data_path = 'data'
args.save_path = 'result'
args.eval_mode = 'S' # evaluate the synthetic data with the same model

In [3]:
#Inspired by https://arxiv.org/abs/2110.0418
def distribution_matching():
  if not os.path.exists(args.save_path):
    os.mkdir(args.save_path)
  
  evaluate_pool = []
  if args.eval_mode == 'S':
    evaluate_pool = np.arange(0, args.Iteration+1, 500).tolist()
  else:
    evaluate_pool = [args.Iteration]

  # load real dataset
  channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = utils.get_dataset(args.dataset, args.data_path)
  models = utils.get_eval_pool(args.eval_mode, args.model, args.model)

  # record accuracy of each model
  records = dict()

  for model in models:
    records[model] = []

  # start experiments loops
  for experiment in range(args.num_exp):
    # orgnize real data
    images_real = []
    labels_real = []
    for i in range(len(dst_train)):
      images_real.append(torch.unsqueeze(dst_train[i][0], dim=0))

    for i in range(len(dst_train)):
      labels_real.append(dst_train[i][1])

    class_index = [[] for c in range(num_classes)]
    for n, label in enumerate(labels_real):
      class_index[label].append(n)
    # move inputs to device
    images_real = torch.cat(images_real, dim=0).to(args.device)
    labels_real = torch.tensor(labels_real, dtype=torch.long, device=args.device)

    # implemente a method of ramdomly selecting n images from each class
    def pick_images(c, n):
      index = np.random.permutation(class_index[c])[:n]
      return images_real[index]

    # initialize synthetic dataset: randomly or gaussian noise
    images_syn = torch.randn(
        size = (args.ipc * num_classes, channel, im_size[0], im_size[1]),
        dtype = torch.float,
        requires_grad=True, 
        device=args.device
    )
    labels_syn = torch.tensor(
        [np.ones(args.ipc) * n for n in range(num_classes)], 
        dtype=torch.long, 
        requires_grad=False, 
        device=args.device
    ).view(-1)
  
    # if args.init = real, initilize synthetic data by ramdomly choosing from real data
    if args.init == 'real':
      for n in range(num_classes):
        images_syn.data[n * args.ipc : (n + 1) * args.ipc] = pick_images(n, args.ipc).detach().data


    # training process begins
    optimizer_syn = torch.optim.SGD(
      [images_syn, ], 
      lr=args.lr_img, 
      momentum=0.5) 
    optimizer_syn.zero_grad()
    loss_fn = nn.CrossEntropyLoss().to(args.device)

    for iter in range(args.Iteration + 1):
      # evaluate synthetic set 
      if iter in evaluate_pool:
        for model in models:
          args.epoch_eval_train = 300
        
          accuracy = []
          # loop over number of random model initialization
          for eva in range(args.num_eval):
            # load a network
            net = utils.get_network(model, channel, num_classes, im_size).to(args.device)
            _, acc_train, acc_test = utils.evaluate_synset(eva, net, images_syn, labels_syn, testloader, args)
            accuracy.append(acc_test)
          print('Evaluate synthetic data on model: %s, mean accuracy = %.4f'%(model, np.mean(accuracy)))

          # save the last accuracy into records
          if iter == args.Iteration:
            records[model] += accuracy

        #save synthetic data
        path =  os.path.join(args.save_path, 'distribution_syn_%s_%s_%s_%dipc_exp%d_iter%d.png'%('random' if args.init == 'real' else 'noise', args.dataset, args.model, args.ipc, experiment, iter))
        save_image(images_syn, path, nrow=args.ipc)

      # train synthetic data
      net = utils.get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
      net.train()
      for param in list(net.parameters()):
        param.requires_grad = False
      
      # parallel
      embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed

      # update synthetic data
      loss = torch.tensor(0.0).to(args.device)
      for num in range(num_classes):
        image_batch_real = pick_images(num, args.batch_real)
        image_batch_syn = images_syn[num * args.ipc : (num + 1) * args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
        out_real = embed(image_batch_real).detach()
        out_syn = embed(image_batch_syn)

        loss += torch.sum((torch.mean(out_real, dim=0) - torch.mean(out_syn, dim=0))**2)

      optimizer_syn.zero_grad()
      loss.backward()
      optimizer_syn.step()
         

  for model in models:
    accuracy = records[model]
    print("Experiments = %d, model = %s, accuracy= %.2f"%(args.num_exp, args.model, np.mean(accuracy)*100))
  return images_syn, labels_syn

In [4]:
args.model = 'ConvNet'
args.dataset = 'MNIST'
args.eval_mode = 'S'
args.init = 'real'
syn = distribution_matching()

  labels_syn = torch.tensor(
  labels_syn = torch.tensor(


[2022-12-12 01:01:37] Evaluate_00: epoch = 0300 train time = 10 s train loss = 0.003749 train acc = 1.0000, test acc = 0.9126
[2022-12-12 01:01:43] Evaluate_01: epoch = 0300 train time = 3 s train loss = 0.004039 train acc = 1.0000, test acc = 0.9018
[2022-12-12 01:01:49] Evaluate_02: epoch = 0300 train time = 3 s train loss = 0.003731 train acc = 1.0000, test acc = 0.9048
[2022-12-12 01:01:55] Evaluate_03: epoch = 0300 train time = 3 s train loss = 0.003970 train acc = 1.0000, test acc = 0.9139
[2022-12-12 01:02:00] Evaluate_04: epoch = 0300 train time = 3 s train loss = 0.003943 train acc = 1.0000, test acc = 0.9096
Evaluate synthetic data on model: ConvNet, mean accuracy = 0.9085
[2022-12-12 01:03:12] Evaluate_00: epoch = 0300 train time = 3 s train loss = 0.004265 train acc = 1.0000, test acc = 0.9345
[2022-12-12 01:03:18] Evaluate_01: epoch = 0300 train time = 3 s train loss = 0.004476 train acc = 1.0000, test acc = 0.9325
[2022-12-12 01:03:24] Evaluate_02: epoch = 0300 train time

In [5]:
pip install pthflops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pthflops
  Downloading pthflops-0.4.2-py3-none-any.whl (11 kB)
Installing collected packages: pthflops
Successfully installed pthflops-0.4.2


In [25]:
from pthflops import count_ops
args.dataset = 'MNIST'
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = utils.get_dataset(args.dataset, args.data_path)
args.model = 'ConvNet'
net = utils.get_network(args.model,channel=channel,num_classes=num_classes,im_size=im_size).to(args.device)

def count_flops(dst_test):
  flops = 0
  images_test = []
  labels_test = []
  for i in range(len(dst_test)):
    images_test.append(torch.unsqueeze(dst_test[i][0], dim=0))
  for i in range(len(dst_test)):
    labels_test.append(dst_test[i][1])
  class_index = [[] for c in range(num_classes)]
  for n, label in enumerate(labels_test):
    class_index[label].append(n)
  images_test = torch.cat(images_test, dim=0).to(args.device)
  def pick_images(c, n):
      index = np.random.permutation(class_index[c])[:n]
      return images_test[index]
  for n in range(num_classes):
    img_batch = pick_images(n, 2560)
    flops += count_ops(net, img_batch)[0]
  return flops

flops = count_flops(dst_test)
print("The number of FLOPS = %.2f"%(flops))


Operation     OPS           
------------  ------------  
features_0    1284505600    
features_2    256901120     
features_3    128450560     
features_4    37025873920   
features_6    64225280      
features_7    32112640      
features_8    9256468480    
features_10   16056320      
features_11   8028160       
classifier    20070410      
-----------   -----------   
Input size: (980, 1, 28, 28)
48,092,692,490 FLOPs or approx. 48.09 GFLOPs
Operation     OPS           
------------  ------------  
features_0    1487667200    
features_2    297533440     
features_3    148766720     
features_4    42882007040   
features_6    74383360      
features_7    37191680      
features_8    10720501760   
features_10   18595840      
features_11   9297920       
classifier    23244810      
-----------   -----------   
Input size: (1135, 1, 28, 28)
55,699,189,770 FLOPs or approx. 55.70 GFLOPs
Operation     OPS           
------------  ------------  
features_0    1352663040    
features_2

In [28]:
# train a network with syn data and test with real data

# testing data
args.dataset = 'MNIST'
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = utils.get_dataset(args.dataset, args.data_path)
# net
args.model = 'ConvNet'
net = utils.get_network(args.model,channel=channel,num_classes=num_classes,im_size=im_size).to(args.device)
it_eval = args.num_eval

# images_train = syn[0].to(args.device)
# labels_train = syn[1].to(args.device)
images_train, labels_train = copy.deepcopy(syn[0].detach()), copy.deepcopy(syn[1].detach())
dst_train = utils.TensorDataset(images_train, labels_train)

_, acc_train, acc_test = utils.evaluate_dataset(it_eval,net,dst_train,testloader,args)
print("test with real dataset, accuracy = %.4f"%(acc_test))

flops = count_ops(net, images_train)[0]
print("The number of FLOPS = %.2f"%(flops))

[2022-12-12 01:27:07] Evaluate_05: epoch = 0010 train time = 0 s train loss = 1.214679 train acc = 0.8700, test acc = 0.8477
test with real dataset, accuracy = 0.8477
Operation     OPS          
------------  -----------  
features_0    131072000    
features_2    26214400     
features_3    13107200     
features_4    3778150400   
features_6    6553600      
features_7    3276800      
features_8    944537600    
features_10   1638400      
features_11   819200       
classifier    2048010      
-----------   ----------   
Input size: (100, 1, 28, 28)
4,907,417,610 FLOPs or approx. 4.91 GFLOPs
The number of FLOPS = 4907417610.00


In [32]:
args.model = 'ConvNet'
args.dataset = 'MNIST'
args.eval_mode = 'S'
args.init = 'noise'
syn_noise = distribution_matching()

  labels_syn = torch.tensor(


[2022-12-12 01:30:58] Evaluate_00: epoch = 0300 train time = 3 s train loss = 0.004619 train acc = 1.0000, test acc = 0.1065
[2022-12-12 01:31:03] Evaluate_01: epoch = 0300 train time = 3 s train loss = 0.004660 train acc = 1.0000, test acc = 0.1075
[2022-12-12 01:31:09] Evaluate_02: epoch = 0300 train time = 3 s train loss = 0.004640 train acc = 1.0000, test acc = 0.0765
[2022-12-12 01:31:15] Evaluate_03: epoch = 0300 train time = 3 s train loss = 0.004846 train acc = 1.0000, test acc = 0.0646
[2022-12-12 01:31:21] Evaluate_04: epoch = 0300 train time = 3 s train loss = 0.004556 train acc = 1.0000, test acc = 0.0940
Evaluate synthetic data on model: ConvNet, mean accuracy = 0.0898
[2022-12-12 01:32:33] Evaluate_00: epoch = 0300 train time = 3 s train loss = 0.003946 train acc = 1.0000, test acc = 0.8811
[2022-12-12 01:32:39] Evaluate_01: epoch = 0300 train time = 3 s train loss = 0.004148 train acc = 1.0000, test acc = 0.8841
[2022-12-12 01:32:45] Evaluate_02: epoch = 0300 train time 