### Reference
https://github.com/VICO-UoE/DatasetCondensation

# Continual Learning

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# IPython magic (not plain Python): switch the working directory to the project
# folder on Drive. The relative paths used later in this notebook
# ('./data', 'image_syn_mnist.pt', ...) are therefore resolved from here.
%cd /content/drive/MyDrive/ECE1513/Project_B_Supp/

Mounted at /content/drive
/content/drive/MyDrive/ECE1513/Project_B_Supp


In [None]:
import os
import numpy as np
import torch
import argparse
from utils import get_dataset, get_network, get_eval_pool, evaluate_synset, ParamDiffAug, TensorDataset
import copy
import gc

In [None]:
def continual_learning():
    """Run a class-incremental evaluation driven by the module-level ``args``.

    The classes are shuffled (seeded with the seed index) and revealed in
    ``args.steps`` equally sized groups. At every step two memories holding
    the classes seen so far are evaluated, each with ``args.num_eval`` freshly
    created networks:

    * the stored synthetic set (``image_syn*.pt`` / ``label_syn*.pt``),
      restricted to the seen classes;
    * ``args.ipc`` randomly drawn real training images per seen class
      (the ``'random'`` method).

    Both are tested on the test samples whose label is among the seen classes.
    Per-step mean/std are printed for both memories; the final LaTeX-style
    summary row reports the ``args.method`` (real image) memory.

    Reads from ``args``: ``dataset``, ``data_path``, ``device``, ``method``,
    ``steps``, ``num_eval``, ``ipc``, ``model``, ``dsa_strategy``,
    ``dsa_param``. ``args`` is also handed unchanged to ``evaluate_synset``.

    Returns:
        None. All results are printed.

    Raises:
        ValueError: if ``args.method`` or ``args.dataset`` is not supported by
            this function, or if ``args.steps`` does not allow at least one
            class per step.
    """
    # Only random selection of real images is implemented here, and synthetic
    # sets are only available for these datasets. Fail early with a clear
    # message instead of hitting a NameError deep inside the step loop.
    supported_methods = ('random',)
    syn_files = {
        'MNIST': ('image_syn_mnist.pt', 'label_syn_mnist.pt'),
        'MHIST': ('image_syn.pt', 'label_syn.pt'),
    }
    if args.method not in supported_methods:
        raise ValueError('unsupported method %r; supported: %s' % (args.method, ', '.join(supported_methods)))
    if args.dataset not in syn_files:
        raise ValueError('no synthetic set known for dataset %r; supported: %s' % (args.dataset, ', '.join(sorted(syn_files))))

    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

    if args.steps < 1 or args.steps > num_classes:
        raise ValueError('args.steps must be between 1 and the number of classes (%d), got %r' % (num_classes, args.steps))

    ''' all training data '''
    # Fetch every sample exactly once (each dst_train[i] may run transforms).
    train_samples = [dst_train[i] for i in range(len(dst_train))]
    indices_class = [[] for _ in range(num_classes)]
    for i, sample in enumerate(train_samples):
        indices_class[int(sample[1])].append(i)
    images_all = torch.cat([torch.unsqueeze(sample[0], dim=0) for sample in train_samples], dim=0).to(args.device)
    del train_samples

    ''' all test data '''
    # Loaded once; each step only selects the rows of the classes seen so far.
    test_samples = [dst_test[i] for i in range(len(dst_test))]
    images_test_all = torch.cat([torch.unsqueeze(sample[0], dim=0) for sample in test_samples], dim=0)
    labels_test_all = torch.tensor([int(sample[1]) for sample in test_samples], dtype=torch.long)
    del test_samples

    ''' synthetic data '''
    image_file, label_file = syn_files[args.dataset]
    image_syn = torch.load(image_file, map_location=torch.device('cpu')).to(args.device)
    label_syn = torch.load(label_file, map_location=torch.device('cpu')).to(args.device)

    def get_images(c, n):  # get random n images from class c
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle]

    def evaluate_memory(images, labels, loader, tag):
        """Train args.num_eval fresh networks on (images, labels); return their test accuracies."""
        accs = []
        for ep_eval in range(args.num_eval):
            net_eval = get_network(args.model, channel, num_classes, im_size)
            net_eval = net_eval.to(args.device)
            # Work on copies so the stored memory is never modified by training.
            img_eval = copy.deepcopy(images.detach())
            lab_eval = copy.deepcopy(labels.detach())

            _, _, acc_test = evaluate_synset(ep_eval, net_eval, img_eval, lab_eval, loader, args)
            del net_eval, img_eval, lab_eval
            gc.collect()  # to reduce memory cost
            accs.append(acc_test)
        print('With %s dataset: Evaluate %d random %s, mean = %.4f std = %.4f' % (tag, len(accs), args.model, np.mean(accs), np.std(accs)))
        return accs

    print()
    print('==================================================================================')
    print('method: ', args.method)

    # One column per (seed, evaluation run). The array must match the number of
    # seeds actually run, otherwise unused zero columns distort the mean/std.
    num_seeds = 1
    results = np.zeros((args.steps, num_seeds * args.num_eval))

    # Integer division: when num_classes is not a multiple of args.steps the
    # trailing classes of class_order are never shown to any step.
    num_classes_step = num_classes // args.steps

    for seed_cl in range(num_seeds):
        np.random.seed(seed_cl)
        class_order = np.random.permutation(num_classes).tolist()
        print('=========================================')
        print('seed: ', seed_cl)
        print('class_order: ', class_order)
        print('augmentation strategy: \n', args.dsa_strategy)
        print('augmentation parameters: \n', args.dsa_param.__dict__)

        # Real-image memory: args.ipc random images per class, grouped by step.
        images_train_all = []
        labels_train_all = []
        for step in range(args.steps):
            classes_current = class_order[step * num_classes_step: (step + 1) * num_classes_step]
            images_train_all.append(torch.cat([get_images(c, args.ipc) for c in classes_current], dim=0))
            labels_train_all.append(torch.tensor([c for c in classes_current for _ in range(args.ipc)], dtype=torch.long, device=args.device))

        for step in range(args.steps):
            print('\n-----------------------------\nmethod %s seed %d step %d ' % (args.method, seed_cl, step))

            classes_seen = class_order[: (step + 1) * num_classes_step]
            print('classes_seen: ', classes_seen)

            ''' train real data '''
            images_train = torch.cat(images_train_all[:step + 1], dim=0).to(args.device)
            labels_train = torch.cat(labels_train_all[:step + 1], dim=0).to(args.device)
            print('train data size: ', images_train.shape)

            '''Train synthetic data'''
            mask = torch.isin(label_syn, torch.tensor(classes_seen, device=args.device))
            images_train_syn = image_syn[mask]
            labels_train_syn = label_syn[mask]

            ''' test data '''
            test_mask = torch.isin(labels_test_all, torch.tensor(classes_seen, dtype=torch.long))
            images_test = images_test_all[test_mask].to(args.device)
            labels_test = labels_test_all[test_mask].to(args.device)
            dst_test_current = TensorDataset(images_test, labels_test)
            testloader = torch.utils.data.DataLoader(dst_test_current, batch_size=256, shuffle=False, num_workers=0)

            print('test set size: ', images_test.shape)

            '''Train model on synthetic dataset'''
            # Reported per step only; the summary table below is for args.method.
            evaluate_memory(images_train_syn, labels_train_syn, testloader, 'synthetic')

            ''' train model on the newest memory '''
            accs_real = evaluate_memory(images_train, labels_train, testloader, 'real')
            results[step, seed_cl * args.num_eval: (seed_cl + 1) * args.num_eval] = accs_real

    # Raw string: '\p' is not a valid escape sequence in a normal literal.
    results_str = ''.join(
        r'& %.1f$\pm$%.1f  ' % (np.mean(results[step]) * 100, np.std(results[step]) * 100)
        for step in range(args.steps)
    )
    print('\n\n')
    print('%d step learning %s perforamnce:'%(args.steps, args.method))
    print(results_str)
    print('Done')


In [None]:
# Experiment settings for MNIST. continual_learning() reads this module-level
# `args` object and also forwards it to evaluate_synset().
args = argparse.Namespace(
    method='random',
    dataset='MNIST',
    model='ConvNet',
    ipc=10,  # real images drawn per class for the 'random' memory
    steps=3,
    num_eval=3,  # evaluation number
    epoch_eval_train=200,  # epochs to train a model with synthetic data
    lr_net=0.01,
    batch_train=256,
    data_path='./data',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dsa_param=ParamDiffAug(),
    dsa=False,  # augment images for all methods
    dsa_strategy='color_crop_cutout_flip_scale_rotate',  # for CIFAR10/100
    dc_aug_param=None,
)

continual_learning()


method:  random
seed:  0
class_order:  [2, 8, 4, 9, 1, 6, 7, 3, 0, 5]
augmentation strategy: 
 color_crop_cutout_flip_scale_rotate
augmentation parameters: 
 {'aug_mode': 'S', 'prob_flip': 0.5, 'ratio_scale': 1.2, 'ratio_rotate': 15.0, 'ratio_crop_pad': 0.125, 'ratio_cutout': 0.5, 'brightness': 1.0, 'saturation': 2.0, 'contrast': 0.5}

-----------------------------
method random seed 0 step 0 
classes_seen:  [2, 8, 4]
train data size:  torch.Size([30, 1, 28, 28])
test set size:  torch.Size([2988, 1, 28, 28])
[2023-11-21 22:41:17] Evaluate_00: epoch = 0200 train time = 57 s train loss = 0.000449 train acc = 1.0000, test acc = 0.9391
[2023-11-21 22:42:25] Evaluate_01: epoch = 0200 train time = 54 s train loss = 0.000401 train acc = 1.0000, test acc = 0.9357
[2023-11-21 22:43:34] Evaluate_02: epoch = 0200 train time = 56 s train loss = 0.000536 train acc = 1.0000, test acc = 0.9388
With synthetic dataset: Evaluate 3 random ConvNet, mean = 0.9379 std = 0.0015
[2023-11-21 22:44:46] Evaluat

In [None]:
# Experiment settings for MHIST. continual_learning() reads this module-level
# `args` object and also forwards it to evaluate_synset().
args = argparse.Namespace(
    method='random',
    dataset='MHIST',
    model='ConvNet',
    ipc=10,  # real images drawn per class for the 'random' memory
    steps=2,
    num_eval=3,  # evaluation number
    epoch_eval_train=200,  # epochs to train a model with synthetic data
    lr_net=0.01,
    batch_train=128,
    data_path='./data',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dsa_param=ParamDiffAug(),
    dsa=False,  # augment images for all methods
    dsa_strategy='color_crop_cutout_flip_scale_rotate',  # for CIFAR10/100
    dc_aug_param=None,
)
continual_learning()


method:  random
seed:  0
class_order:  [1, 0]
augmentation strategy: 
 color_crop_cutout_flip_scale_rotate
augmentation parameters: 
 {'aug_mode': 'S', 'prob_flip': 0.5, 'ratio_scale': 1.2, 'ratio_rotate': 15.0, 'ratio_crop_pad': 0.125, 'ratio_cutout': 0.5, 'brightness': 1.0, 'saturation': 2.0, 'contrast': 0.5}

-----------------------------
method random seed 0 step 0 
classes_seen:  [1]
train data size:  torch.Size([10, 3, 32, 32])
test set size:  torch.Size([360, 3, 32, 32])
[2023-11-21 22:30:41] Evaluate_00: epoch = 0200 train time = 18 s train loss = 0.000000 train acc = 1.0000, test acc = 1.0000
[2023-11-21 22:31:11] Evaluate_01: epoch = 0200 train time = 28 s train loss = 0.000000 train acc = 1.0000, test acc = 1.0000
[2023-11-21 22:31:32] Evaluate_02: epoch = 0200 train time = 19 s train loss = 0.000000 train acc = 1.0000, test acc = 1.0000
With synthetic dataset: Evaluate 3 random ConvNet, mean = 1.0000 std = 0.0000
[2023-11-21 22:31:53] Evaluate_00: epoch = 0200 train time =