
# **Initial Imports and Drive Mounting**

In [None]:
# Main code adapted from here: https://github.com/VICO-UoE/DatasetCondensation

In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and notebook paths used later in this file live under it.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from scipy.ndimage.interpolation import rotate as scipyrotate

from PIL import Image
import pandas as pd
from sklearn.preprocessing import LabelEncoder

import cv2

from torch.profiler import profile, record_function, ProfilerActivity


# Switch to the Drive notebooks folder so that `networks` (imported just below) can be found.
%cd "/content/drive/MyDrive/Colab Notebooks/"

# import-ipynb allows importing another notebook as a module (see "importing Jupyter notebook from networks.ipynb" in the output).
# NOTE(review): the install is unpinned; consider pinning the version for reproducible runs.
!pip install import-ipynb
import import_ipynb

from networks import MLP, ConvNet, LeNet, AlexNet, AlexNetBN, VGG11, VGG11BN, ResNet18, ResNet18BN_AP, ResNet18BN

from torch.optim.lr_scheduler import CosineAnnealingLR

# from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug

  from scipy.ndimage.interpolation import rotate as scipyrotate


/content/drive/MyDrive/Colab Notebooks
Collecting import-ipynb
  Downloading import_ipynb-0.1.4-py3-none-any.whl (4.1 kB)
Collecting jedi>=0.16 (from IPython->import-ipynb)
  Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jedi, import-ipynb
Successfully installed import-ipynb-0.1.4 jedi-0.19.1
importing Jupyter notebook from networks.ipynb


# **Helper Functions**

In [3]:
class CustomDataset(Dataset):
    """Dataset over in-memory NumPy arrays of images and targets.

    The images are converted to a float tensor once, at construction time.
    ``transform`` is only stored on the instance; ``__getitem__`` does not
    apply it.
    """

    def __init__(self, images, targets, transform=None):
        # Convert both arrays up front so indexing later is a plain tensor lookup.
        image_tensor = torch.from_numpy(images)
        self.data = image_tensor.float()
        self.targets = torch.from_numpy(targets)
        self.transform = transform  # kept for reference only, never applied

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """Number of stored images."""
        return len(self.data)


def _resize_and_minmax_scale(images, im_size):
    """Resize channel-first images with OpenCV and min-max scale each one to [0, 1].

    Args:
        images: iterable of image arrays laid out as (C, H, W).
        im_size: target size, passed unchanged to ``cv2.resize`` as ``dsize``.

    Returns:
        np.ndarray stacking the processed images, again channel-first.
    """
    processed = []
    for image in images:
        hwc = np.transpose(image, (1, 2, 0))  # cv2.resize works on channel-last arrays
        hwc = cv2.resize(hwc, im_size)

        low = np.min(hwc)
        span = np.max(hwc) - low
        if span == 0:
            # A constant image would give 0/0 = NaN; map it to all zeros instead.
            scaled = np.zeros(hwc.shape, dtype=np.float64)
        else:
            scaled = (hwc - low) / span

        processed.append(np.transpose(scaled, (2, 0, 1)))
    return np.asarray(processed)


def _load_mhist_partition(partition_df, image_dir):
    """Read the images and encoded labels listed in one partition of the MHIST label table.

    Returns:
        ``(images, labels)`` as two parallel Python lists.
    """
    images, labels = [], []
    for _, row in partition_df.iterrows():
        image_path = os.path.join(image_dir, row['Image Name'])
        # Close the file handle as soon as the pixels have been copied out.
        with Image.open(image_path) as image:
            images.append(np.asarray(image))
        labels.append(row['Majority Vote Label'])
    return images, labels


def get_dataset(dataset, data_path):
    """Build the train/test datasets and their metadata for a named dataset.

    Args:
        dataset: one of 'MNIST', 'MNIST Custom', 'MHIST Custom', 'MHIST',
            'FashionMNIST', 'SVHN', 'CIFAR10', 'CIFAR100', 'TinyImageNet'.
        data_path: root directory handed to the torchvision datasets (and the
            directory holding ``tinyimagenet.pt``). The MHIST and "Custom"
            variants read from hard-coded Google Drive paths instead.

    Returns:
        ``(channel, im_size, num_classes, class_names, mean, std, dst_train,
        dst_test, testloader)`` where ``testloader`` iterates ``dst_test`` in
        batches of 256 without shuffling.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'MNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.1307]
        std = [0.3081]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'MNIST Custom':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.1307]
        std = [0.3081]

        transform = transforms.Compose([transforms.ToTensor(), transforms.Grayscale(), transforms.Normalize(mean=mean, std=std)])

        # Training images come from an image folder on Drive; the test set is the stock MNIST test split.
        dataset_root = "/content/drive/MyDrive/ECE1512/Project_B/Runs/MNIST Custom/"
        dst_train = datasets.ImageFolder(root=dataset_root, transform=transform)
        dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'MHIST Custom':
        channel = 3
        im_size = (128, 128)
        num_classes = 2
        mean = [188, 165, 197]
        std = [45, 58, 38]
        class_names = ["HP", "SSA"]

        # NOTE(review): mean/std look like 0-255 statistics while ToTensor yields [0, 1] values,
        # and the test images below are only min-max scaled - confirm the train/test scaling is intended.
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        transformTrain = transforms.Compose([transforms.Resize(im_size), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

        dataset_root = "/content/drive/MyDrive/ECE1512/Project_B/Runs/MHIST Custom/"
        dst_train = datasets.ImageFolder(root=dataset_root, transform=transformTrain)

        test_images = np.load("/content/drive/MyDrive/ECE1512/test_images.npy").astype(np.uint8)
        test_labels = np.load("/content/drive/MyDrive/ECE1512/test_labels.npy").astype(np.uint8)

        test_images = np.transpose(test_images, (0, 3, 1, 2))  # NHWC -> NCHW
        test_images_resized = _resize_and_minmax_scale(test_images, im_size)

        print(test_images_resized.shape)

        # CustomDataset stores `transform` but does not apply it.
        dst_test = CustomDataset(test_images_resized, test_labels, transform=transform)

    elif dataset == 'MHIST':
        channel = 3
        im_size = (128, 128)
        num_classes = 2
        mean = [188, 165, 197]
        std = [45, 58, 38]
        class_names = ["HP", "SSA"]

        # True: load the cached .npy arrays; False: re-read every image listed in labels.csv.
        use_existing_arrays = True

        if not use_existing_arrays:
            csv_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/images/labels.csv"
            data_df = pd.read_csv(csv_path)

            image_dir = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/images/images/"

            label_encoder = LabelEncoder()
            data_df['Majority Vote Label'] = label_encoder.fit_transform(data_df['Majority Vote Label'])

            train_images, train_labels = _load_mhist_partition(data_df[data_df['Partition'] == 'train'], image_dir)
            test_images, test_labels = _load_mhist_partition(data_df[data_df['Partition'] == 'test'], image_dir)
        else:
            train_images = np.load("/content/drive/MyDrive/ECE1512/train_images.npy").astype(np.uint8)
            test_images = np.load("/content/drive/MyDrive/ECE1512/test_images.npy").astype(np.uint8)
            train_labels = np.load("/content/drive/MyDrive/ECE1512/train_labels.npy").astype(np.uint8)
            test_labels = np.load("/content/drive/MyDrive/ECE1512/test_labels.npy").astype(np.uint8)

        train_labels = np.asarray(train_labels)
        test_labels = np.asarray(test_labels)

        # NHWC -> NCHW
        train_images = np.transpose(np.asarray(train_images), (0, 3, 1, 2))
        test_images = np.transpose(np.asarray(test_images), (0, 3, 1, 2))

        print("done loading images")

        train_images_resized = _resize_and_minmax_scale(train_images, im_size)
        test_images_resized = _resize_and_minmax_scale(test_images, im_size)

        # CustomDataset stores `transform` but does not apply it.
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = CustomDataset(train_images_resized, train_labels, transform=transform)
        dst_test = CustomDataset(test_images_resized, test_labels, transform=transform)

    elif dataset == 'FashionMNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.2861]
        std = [0.3530]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'SVHN':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        mean = [0.4377, 0.4438, 0.4728]
        std = [0.1980, 0.2010, 0.1970]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.SVHN(data_path, split='train', download=True, transform=transform)  # no augmentation
        dst_test = datasets.SVHN(data_path, split='test', download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'CIFAR10':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'CIFAR100':
        channel = 3
        im_size = (32, 32)
        num_classes = 100
        mean = [0.5071, 0.4866, 0.4409]
        std = [0.2673, 0.2564, 0.2762]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.CIFAR100(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'TinyImageNet':
        channel = 3
        im_size = (64, 64)
        num_classes = 200
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        data = torch.load(os.path.join(data_path, 'tinyimagenet.pt'), map_location='cpu')

        class_names = data['classes']

        # Scale to [0, 1] and standardise each channel in place.
        images_train = data['images_train']
        labels_train = data['labels_train']
        images_train = images_train.detach().float() / 255.0
        labels_train = labels_train.detach()
        for c in range(channel):
            images_train[:,c] = (images_train[:,c] - mean[c])/std[c]
        dst_train = TensorDataset(images_train, labels_train)  # no augmentation

        images_val = data['images_val']
        labels_val = data['labels_val']
        images_val = images_val.detach().float() / 255.0
        labels_val = labels_val.detach()
        for c in range(channel):
            images_val[:, c] = (images_val[:, c] - mean[c]) / std[c]
        dst_test = TensorDataset(images_val, labels_val)  # no augmentation

    else:
        # Raise instead of calling exit(): in a Jupyter kernel `exit` is IPython's exit helper,
        # which requests a kernel shutdown rather than raising at this point.
        raise ValueError('unknown dataset: %s'%dataset)


    testloader = torch.utils.data.DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=0)
    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader



class TensorDataset(Dataset):
    """Dataset view over an image tensor and a parallel label tensor.

    Both tensors are detached from any autograd graph; the images are
    additionally cast to float.
    """

    def __init__(self, images, labels): # images: n x c x h x w tensor
        detached_images = images.detach()
        self.images = detached_images.float()
        self.labels = labels.detach()

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``."""
        image = self.images[index]
        label = self.labels[index]
        return image, label

    def __len__(self):
        """Number of images, i.e. the size of the first dimension."""
        return self.images.shape[0]



def get_default_convnet_setting():
    """Return the default ConvNet hyper-parameters.

    Returns:
        ``(net_width, net_depth, net_act, net_norm, net_pooling)``.
    """
    return 128, 3, 'relu', 'instancenorm', 'avgpooling'



def get_network(model, channel, num_classes, im_size=(32, 32)):
    """Instantiate the architecture named ``model`` and move it to the available device.

    The torch RNG is re-seeded from the wall clock on every call, so repeated
    calls produce differently initialised networks.

    Args:
        model: architecture name, e.g. 'MLP', 'ConvNet', 'ResNet18BN', or one of the
            ConvNet ablation variants ('ConvNetD2', 'ConvNetW64', 'ConvNetBN', ...).
        channel: number of input channels.
        num_classes: number of output classes.
        im_size: input image size; only forwarded to the ConvNet variants.

    Returns:
        The network on 'cuda' when at least one GPU is visible (wrapped in
        ``nn.DataParallel`` when there are several), otherwise on 'cpu'.

    Raises:
        ValueError: if ``model`` is not a known architecture name.
    """
    torch.random.manual_seed(int(time.time() * 1000) % 100000)
    net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting()

    # Architectures constructed from (channel, num_classes) only.
    plain_archs = {
        'MLP': MLP,
        'LeNet': LeNet,
        'AlexNet': AlexNet,
        'AlexNetBN': AlexNetBN,
        'VGG11': VGG11,
        'VGG11BN': VGG11BN,
        'ResNet18': ResNet18,
        'ResNet18BN_AP': ResNet18BN_AP,
        'ResNet18BN': ResNet18BN,
    }

    # ConvNet variants: each entry lists only the settings that differ from the defaults.
    convnet_overrides = {
        'ConvNet': {},
        # depth ablation
        'ConvNetD1': {'net_depth': 1},
        'ConvNetD2': {'net_depth': 2},
        'ConvNetD3': {'net_depth': 3},
        'ConvNetD4': {'net_depth': 4},
        # width ablation
        'ConvNetW32': {'net_width': 32},
        'ConvNetW64': {'net_width': 64},
        'ConvNetW128': {'net_width': 128},
        'ConvNetW256': {'net_width': 256},
        # activation ablation
        'ConvNetAS': {'net_act': 'sigmoid'},
        'ConvNetAR': {'net_act': 'relu'},
        'ConvNetAL': {'net_act': 'leakyrelu'},
        'ConvNetASwish': {'net_act': 'swish'},
        'ConvNetASwishBN': {'net_act': 'swish', 'net_norm': 'batchnorm'},
        # normalisation ablation
        'ConvNetNN': {'net_norm': 'none'},
        'ConvNetBN': {'net_norm': 'batchnorm'},
        'ConvNetLN': {'net_norm': 'layernorm'},
        'ConvNetIN': {'net_norm': 'instancenorm'},
        'ConvNetGN': {'net_norm': 'groupnorm'},
        # pooling ablation
        'ConvNetNP': {'net_pooling': 'none'},
        'ConvNetMP': {'net_pooling': 'maxpooling'},
        'ConvNetAP': {'net_pooling': 'avgpooling'},
    }

    if model in plain_archs:
        net = plain_archs[model](channel=channel, num_classes=num_classes)
    elif model in convnet_overrides:
        convnet_kwargs = dict(net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
        convnet_kwargs.update(convnet_overrides[model])
        net = ConvNet(channel=channel, num_classes=num_classes, im_size=im_size, **convnet_kwargs)
    else:
        # Raise instead of calling exit(): in a Jupyter kernel `exit` is IPython's exit helper,
        # which requests a kernel shutdown rather than raising at this point.
        raise ValueError('unknown model: %s'%model)

    gpu_num = torch.cuda.device_count()
    if gpu_num>0:
        device = 'cuda'
        if gpu_num>1:
            net = nn.DataParallel(net)
    else:
        device = 'cpu'
    net = net.to(device)

    return net



def get_time():
    """Return the current local time formatted as ``[YYYY-MM-DD HH:MM:SS]``."""
    now = time.localtime()
    stamp = time.strftime("[%Y-%m-%d %H:%M:%S]", now)
    return str(stamp)



def distance_wb(gwr, gws):
    """Sum over output rows of ``1 - cosine_similarity`` between two gradient tensors.

    4-D (conv, out*in*h*w) and 3-D (layernorm, C*h*w) gradients are flattened to
    one row per leading index; 2-D (linear, out*in) gradients are used as is.
    1-D gradients (batchnorm/instancenorm/groupnorm parameters, biases) are
    ignored and contribute a distance of 0.

    Args:
        gwr: gradient from the real data.
        gws: gradient from the synthetic data, same shape as ``gwr``.

    Returns:
        A scalar tensor holding the distance.
    """
    shape = gwr.shape
    num_dims = len(shape)

    if num_dims == 1: # batchnorm/instancenorm, C; groupnorm x, bias
        return torch.tensor(0, dtype=torch.float, device=gwr.device)

    if num_dims == 4: # conv, out*in*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2] * shape[3])
        gws = gws.reshape(shape[0], shape[1] * shape[2] * shape[3])
    elif num_dims == 3:  # layernorm, C*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2])
        gws = gws.reshape(shape[0], shape[1] * shape[2])
    # num_dims == 2 (linear, out*in) already has one row per output unit.

    # The small constant keeps the division finite when a row has zero norm.
    cosine = torch.sum(gwr * gws, dim=-1) / (torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + 0.000001)
    return torch.sum(1 - cosine)



def match_loss(gw_syn, gw_real, args):
    """Distance between the synthetic-data gradients and the real-data gradients.

    Args:
        gw_syn: sequence of per-parameter gradients from the synthetic batch.
        gw_real: sequence of per-parameter gradients from the real batch.
        args: object providing ``device`` and ``dis_metric``. ``dis_metric`` is
            'ours' (sum of ``distance_wb`` over parameters), 'mse' (squared error
            over all gradients concatenated) or 'cos' (one minus the cosine
            similarity of the concatenated gradients).

    Returns:
        A scalar tensor holding the distance.

    Raises:
        ValueError: if ``args.dis_metric`` is not one of the metrics above.
    """
    dis = torch.tensor(0.0).to(args.device)
    num_grads = len(gw_real)

    if args.dis_metric == 'ours':
        for ig in range(num_grads):
            dis += distance_wb(gw_real[ig], gw_syn[ig])

    elif args.dis_metric == 'mse':
        gw_real_vec = torch.cat([gw_real[ig].reshape((-1)) for ig in range(num_grads)], dim=0)
        gw_syn_vec = torch.cat([gw_syn[ig].reshape((-1)) for ig in range(num_grads)], dim=0)
        dis = torch.sum((gw_syn_vec - gw_real_vec)**2)

    elif args.dis_metric == 'cos':
        gw_real_vec = torch.cat([gw_real[ig].reshape((-1)) for ig in range(num_grads)], dim=0)
        gw_syn_vec = torch.cat([gw_syn[ig].reshape((-1)) for ig in range(num_grads)], dim=0)
        dis = 1 - torch.sum(gw_real_vec * gw_syn_vec, dim=-1) / (torch.norm(gw_real_vec, dim=-1) * torch.norm(gw_syn_vec, dim=-1) + 0.000001)

    else:
        # Raise instead of calling exit(): in a Jupyter kernel `exit` is IPython's exit helper,
        # which requests a kernel shutdown rather than raising at this point.
        raise ValueError('unknown distance function: %s'%args.dis_metric)

    return dis



def get_loops(ipc):
    """Return the ``(outer_loop, inner_loop)`` hyper-parameters for a given images-per-class.

    The values are the empirically chosen settings for ipc in {1, 10, 20, 30, 40, 50}.

    Raises:
        ValueError: if no setting is defined for ``ipc``.
    """
    # ipc -> (outer_loop, inner_loop)
    loops_by_ipc = {
        1: (1, 1),
        10: (10, 50),
        20: (20, 25),
        30: (30, 20),
        40: (40, 15),
        50: (50, 10),
    }
    if ipc not in loops_by_ipc:
        # Raise instead of calling exit(): in a Jupyter kernel `exit` is IPython's exit helper,
        # which requests a kernel shutdown rather than raising at this point.
        raise ValueError('loop hyper-parameters are not defined for %d ipc'%ipc)
    outer_loop, inner_loop = loops_by_ipc[ipc]
    return outer_loop, inner_loop



def epoch(mode, dataloader, net, optimizer, criterion, args, aug):
    """Run one pass over ``dataloader`` and return the average loss and accuracy.

    Args:
        mode: 'train' to update ``net`` with ``optimizer``; any other value runs
            an evaluation pass (eval mode, gradients disabled).
        dataloader: iterable of batches whose first two items are images and labels.
        net: the model.
        optimizer: optimizer stepped after each batch in 'train' mode; unused otherwise.
        criterion: loss module called as ``criterion(output, labels)``.
        args: object providing ``device``, ``training_baseline`` and ``dataset``.
        aug: currently unused - the augmentation call is commented out below.

    Returns:
        ``(loss_avg, acc_avg)`` averaged over all samples seen.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    loss_avg, acc_avg, num_exp = 0, 0, 0
    net = net.to(args.device)
    criterion = criterion.to(args.device)

    is_train = mode == 'train'
    if is_train:
        net.train()
    else:
        net.eval()

    for i_batch, datum in enumerate(dataloader):

        # The image tensor layout depends on where the batch comes from, so the
        # channel axis is added or squeezed per dataset / baseline setting.
        if args.training_baseline and not(args.dataset=="MHIST"):
          if mode == 'train':
            img = datum[0][:, np.newaxis, :, :].float().to(args.device)
          else:
            img = datum[0].float().to(args.device)
        else:
            img = datum[0].squeeze(1).float().to(args.device)

        if args.dataset=="MNIST":
          img = datum[0].float().to(args.device)

        if args.dataset=="MNIST Custom":
          img = datum[0].float().to(args.device)

        if args.dataset=="MHIST Custom":
          img = datum[0].squeeze(0).float().to(args.device)

        # if aug:
        #     if args.dsa:
        #         img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param)
        #     else:
        #         img = augment(img, args.dc_aug_param, device=args.device)
        lab = datum[1].long().to(args.device)
        n_b = lab.shape[0]

        # save_image(img[0], "/content/drive/MyDrive/ECE1512/Project_B/Runs/MHIST Initial/testing.png")

        # Only track gradients when they are needed; evaluation never calls backward().
        with torch.set_grad_enabled(is_train):
            output = net(img)
            loss = criterion(output, lab)
        acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy()))

        loss_avg += loss.item()*n_b
        acc_avg += acc
        num_exp += n_b

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    loss_avg /= num_exp
    acc_avg /= num_exp

    return loss_avg, acc_avg



def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args):
    """Train ``net`` on the given training set, then report accuracy on ``testloader``.

    For the "MNIST Custom" / "MHIST Custom" datasets ``images_train`` is handed to
    the DataLoader directly and ``labels_train`` is not used; otherwise both
    tensors are moved to ``args.device`` and wrapped in a ``TensorDataset``.

    Training uses SGD with momentum for ``args.epoch_eval_train + 1`` epochs and
    divides the learning rate by 10 once, after epoch ``Epoch//2 + 1``, by
    creating a fresh optimizer.

    Returns:
        ``(net, acc_train, acc_test)``.
    """
    net = net.to(args.device)

    is_custom_dataset = args.dataset=="MNIST Custom" or args.dataset=="MHIST Custom"
    if is_custom_dataset:
        trainloader = torch.utils.data.DataLoader(images_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
    else:
        images_train = images_train.to(args.device)
        labels_train = labels_train.to(args.device)
        dst_train = TensorDataset(images_train, labels_train)
        trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    lr = float(args.lr_net)
    Epoch = int(args.epoch_eval_train)
    lr_drop_epoch = Epoch//2+1
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss().to(args.device)

    start = time.time()
    for ep in range(Epoch+1):
        # Baseline training runs without augmentation.
        shouldAug = not args.training_baseline

        loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, shouldAug)
        if ep == lr_drop_epoch:
            lr *= 0.1
            optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)

    time_train = time.time() - start

    shouldAug = True
    loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, shouldAug)
    print('%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test))

    return net, acc_train, acc_test



def augment(images, dc_aug_param, device):
    """Apply one randomly chosen augmentation to every image of the batch, in place.

    Args:
        images: batch tensor indexed as (N, C, H, W); it is modified in place.
        dc_aug_param: dict with keys 'scale', 'crop', 'rotate', 'noise' and
            'strategy', or None. 'strategy' is an underscore-separated list of
            the augmentations to draw from ('crop', 'scale', 'rotate', 'noise');
            'none' (or a None dict) disables augmentation.
        device: device on which temporary tensors are created.

    Returns:
        The same ``images`` tensor.
    """
    # This can be sped up in the future.

    if dc_aug_param is not None and dc_aug_param['strategy'] != 'none':
        scale = dc_aug_param['scale']
        crop = dc_aug_param['crop']
        rotate = dc_aug_param['rotate']
        noise = dc_aug_param['noise']
        strategy = dc_aug_param['strategy']

        shape = images.shape
        # Per-channel batch mean, used as the fill value for padded / rotated regions.
        mean = []
        for c in range(shape[1]):
            mean.append(float(torch.mean(images[:,c])))

        def cropfun(i):
            # Pad with the channel mean, then cut a random window of the original size.
            im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device)
            for c in range(shape[1]):
                im_[c] = mean[c]
            im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i]
            r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0]
            images[i] = im_[:, r:r+shape[2], c:c+shape[3]]

        def scalefun(i):
            # Rescale height and width independently, then centre-crop / zero-pad back.
            h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2])
            # The width is derived from shape[3]; using the height here breaks non-square images.
            w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[3])
            tmp = F.interpolate(images[i:i + 1], [h, w], )[0]
            mhw = max(h, w, shape[2], shape[3])
            im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device)
            r = int((mhw - h) / 2)
            c = int((mhw - w) / 2)
            im_[:, r:r + h, c:c + w] = tmp
            r = int((mhw - shape[2]) / 2)
            c = int((mhw - shape[3]) / 2)
            images[i] = im_[:, r:r + shape[2], c:c + shape[3]]

        def rotatefun(i):
            # Rotation runs on the CPU via SciPy; the result is centre-cropped to the original size.
            im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean))
            r = int((im_.shape[-2] - shape[-2]) / 2)
            c = int((im_.shape[-1] - shape[-1]) / 2)
            images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device)

        def noisefun(i):
            images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device)


        augs = strategy.split('_')

        for i in range(shape[0]):
            choice = np.random.permutation(augs)[0] # randomly implement one augmentation
            if choice == 'crop':
                cropfun(i)
            elif choice == 'scale':
                scalefun(i)
            elif choice == 'rotate':
                rotatefun(i)
            elif choice == 'noise':
                noisefun(i)

    return images



def get_daparam(dataset, model, model_eval, ipc):
    """Return the augmentation settings used when training on the condensed set.

    Augmentation doesn't always benefit the performance, so it is only enabled
    for some settings: MNIST uses 'crop_scale_rotate', and evaluating with
    'ConvNetBN' uses 'crop_noise' (which takes precedence). ``model`` and
    ``ipc`` are accepted but do not influence the result.

    Returns:
        dict with keys 'crop', 'scale', 'rotate', 'noise' and 'strategy'.
    """
    strategy = 'none'

    if dataset == 'MNIST':
        strategy = 'crop_scale_rotate'

    # Data augmentation makes model training with Batch Norm layer easier.
    if model_eval in ['ConvNetBN']:
        strategy = 'crop_noise'

    return {
        'crop': 4,
        'scale': 0.2,
        'rotate': 45,
        'noise': 0.001,
        'strategy': strategy,
    }


def get_eval_pool(eval_mode, model, model_eval):
    """Return the list of architecture names to evaluate for ``eval_mode``.

    'S' evaluates the training architecture itself, with any 'BN' suffix (and
    everything after it) stripped from the name; 'SS' evaluates it unchanged.
    Any unrecognised mode falls back to ``[model_eval]``.
    """
    fixed_pools = (
        ('M', ['MLP', 'ConvNet', 'LeNet', 'AlexNet', 'VGG11', 'ResNet18']),  # multiple architectures
        ('B', ['ConvNetBN', 'ConvNetASwishBN', 'AlexNetBN', 'VGG11BN', 'ResNet18BN']),  # multiple architectures with BatchNorm for DM experiments
        ('W', ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256']),  # ablation study on network width
        ('D', ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4']),  # ablation study on network depth
        ('A', ['ConvNetAS', 'ConvNetAR', 'ConvNetAL', 'ConvNetASwish']),  # ablation study on network activation function
        ('P', ['ConvNetNP', 'ConvNetMP', 'ConvNetAP']),  # ablation study on network pooling layer
        ('N', ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN']),  # ablation study on network normalization layer
    )
    for mode_key, pool in fixed_pools:
        if eval_mode == mode_key:
            return pool

    if eval_mode == 'S': # itself
        if 'BN' in model:
            print('Attention: Here I will replace BN with IN in evaluation, as the synthetic set is too small to measure BN hyper-parameters.')
            return [model[:model.index('BN')]]
        return [model]

    if eval_mode == 'SS':  # itself
        return [model]

    return [model_eval]


class ParamDiffAug:
    """Hyper-parameters read by the differentiable augmentation functions.

    ``DiffAugment`` additionally sets ``Siamese`` and ``latestseed`` on the
    instance when it is called.
    """

    def __init__(self):
        # 'S' applies a single randomly chosen strategy, 'M' applies all of them.
        self.aug_mode = 'S'

        # geometric augmentations
        self.prob_flip = 0.5
        self.ratio_scale = 1.2
        self.ratio_rotate = 15.0
        self.ratio_crop_pad = 0.125
        self.ratio_cutout = 0.5  # the size would be 0.5x0.5

        # colour augmentations
        self.brightness = 1.0
        self.saturation = 2.0
        self.contrast = 0.5


def set_seed_DiffAug(param):
    """Seed the torch RNG from ``param.latestseed`` and advance that seed by one.

    A ``latestseed`` of -1 means "unseeded": nothing is changed.
    """
    if param.latestseed != -1:
        torch.random.manual_seed(param.latestseed)
        param.latestseed += 1


def DiffAugment(x, strategy='', seed = -1, param = None):
    """Apply differentiable augmentation to the batch ``x``.

    Args:
        x: input batch tensor.
        strategy: underscore-separated augmentation names looked up in
            ``AUGMENT_FNS``; '', 'none' or 'None' returns ``x`` untouched.
        seed: -1 for unseeded augmentation; any other value makes the
            augmentation seeded and "Siamese" (``param.Siamese`` is set to True).
        param: settings object; must provide ``aug_mode`` ('M' applies every
            listed strategy in order, 'S' applies one randomly chosen strategy).
            ``Siamese`` and ``latestseed`` are written onto it.

    Returns:
        The augmented batch, made contiguous.

    Raises:
        ValueError: if ``param.aug_mode`` is neither 'M' nor 'S'.
    """
    if strategy == 'None' or strategy == 'none' or strategy == '':
        return x

    param.Siamese = seed != -1
    param.latestseed = seed

    if strategy:
        if param.aug_mode == 'M': # original
            for p in strategy.split('_'):
                for f in AUGMENT_FNS[p]:
                    x = f(x, param)
        elif param.aug_mode == 'S':
            pbties = strategy.split('_')
            set_seed_DiffAug(param)
            p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
            for f in AUGMENT_FNS[p]:
                x = f(x, param)
        else:
            # Raise instead of calling exit(): in a Jupyter kernel `exit` is IPython's exit helper,
            # which requests a kernel shutdown rather than raising at this point.
            raise ValueError('unknown augmentation mode: %s'%param.aug_mode)
        x = x.contiguous()
    return x


# We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans.
def rand_scale(x, param):
    """Randomly rescale each image of the batch along x and y.

    The scale factors are drawn uniformly from ``[1/ratio, ratio]`` with
    ``ratio = param.ratio_scale``; a factor of 1 keeps the original size and
    0.5 enlarges the content two times. With ``param.Siamese`` set, the
    transform drawn for the first sample is applied to the whole batch.
    """
    # x>1, max scale
    # sx, sy: (0, +oo), 1: orignial size, 0.5: enlarge 2 times
    ratio = param.ratio_scale
    batch_size = x.shape[0]
    set_seed_DiffAug(param)
    sx = torch.rand(batch_size) * (ratio - 1.0/ratio) + 1.0/ratio
    set_seed_DiffAug(param)
    sy = torch.rand(batch_size) * (ratio - 1.0/ratio) + 1.0/ratio

    # One 2x3 affine matrix per sample: [[sx, 0, 0], [0, sy, 0]].
    theta = torch.zeros(batch_size, 2, 3, dtype=torch.float)
    theta[:, 0, 0] = sx
    theta[:, 1, 1] = sy
    if param.Siamese: # Siamese augmentation:
        theta[:] = theta[0]

    # align_corners=False is the library default; passing it explicitly avoids the warning.
    grid = F.affine_grid(theta, x.shape, align_corners=False).to(x.device)
    x = F.grid_sample(x, grid, align_corners=False)
    return x


def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degree
    """Randomly rotate every image of the batch with an affine warp.

    Args:
        x: batch whose ``shape`` is passed to ``F.affine_grid`` and which is
            resampled with ``F.grid_sample``.
        param: reads ``ratio_rotate`` (maximum absolute angle in degrees) and
            ``Siamese``; also handed to ``set_seed_DiffAug``.

    Returns:
        The resampled batch, same shape as ``x``.
    """
    ratio = param.ratio_rotate
    batch = x.shape[0]
    set_seed_DiffAug(param)
    # Uniform angle in [-ratio, ratio) degrees, converted to radians; drawn
    # from the default (CPU) generator exactly as before.
    angle = (torch.rand(batch) - 0.5) * 2 * ratio / 180 * float(np.pi)
    # One 2x3 rotation matrix per sample, zero translation.
    theta = torch.zeros(batch, 2, 3, dtype=torch.float)
    theta[:, 0, 0] = torch.cos(angle)
    theta[:, 0, 1] = torch.sin(-angle)
    theta[:, 1, 0] = torch.sin(angle)
    theta[:, 1, 1] = torch.cos(angle)
    if param.Siamese: # Siamese augmentation: every sample uses sample 0's matrix
        theta[:] = theta[0]
    # Match x's device and dtype so grid_sample accepts non-float32 input, and
    # only the tiny theta tensor is transferred rather than the full grid.
    theta = theta.to(device=x.device, dtype=x.dtype)
    # align_corners=False is the library default; passing it explicitly
    # avoids the UserWarning both functions emit when it is omitted.
    grid = F.affine_grid(theta, x.shape, align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)


def rand_flip(x, param):
    """Flip samples of ``x`` along dimension 3 with probability ``param.prob_flip``.

    One uniform number is drawn per sample; a sample is replaced by
    ``x.flip(3)`` where its draw is below the probability.  With
    ``param.Siamese`` set, sample 0's draw is reused for the whole batch.
    """
    flip_probability = param.prob_flip
    set_seed_DiffAug(param)
    draws = torch.rand(x.size(0), 1, 1, 1, device=x.device)
    if param.Siamese: # Siamese augmentation:
        draws[:] = draws[0]
    should_flip = draws < flip_probability
    mirrored = x.flip(3)
    return torch.where(should_flip, mirrored, x)


def rand_brightness(x, param):
    """Add a random per-sample offset to ``x``.

    The offset is ``(u - 0.5) * param.brightness`` with ``u`` uniform in
    [0, 1), one draw per sample.  With ``param.Siamese`` set, sample 0's draw
    is reused for the whole batch.
    """
    strength = param.brightness
    set_seed_DiffAug(param)
    draws = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:  # Siamese augmentation:
        draws[:] = draws[0]
    offset = (draws - 0.5)*strength
    return x + offset


def rand_saturation(x, param):
    """Scale each sample's deviation from its mean over dimension 1.

    The deviation ``x - mean(x, dim=1)`` is multiplied by
    ``u * param.saturation`` (``u`` uniform in [0, 1), one draw per sample)
    and the mean is added back.  With ``param.Siamese`` set, sample 0's draw
    is reused for the whole batch.
    """
    strength = param.saturation
    mean_over_dim1 = x.mean(dim=1, keepdim=True)
    set_seed_DiffAug(param)
    draws = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:  # Siamese augmentation:
        draws[:] = draws[0]
    factor = draws * strength
    return (x - mean_over_dim1) * factor + mean_over_dim1


def rand_contrast(x, param):
    """Scale each sample's deviation from its own overall mean.

    The deviation ``x - mean(x, dim=[1, 2, 3])`` is multiplied by
    ``u + param.contrast`` (``u`` uniform in [0, 1), one draw per sample) and
    the mean is added back.  With ``param.Siamese`` set, sample 0's draw is
    reused for the whole batch.
    """
    base = param.contrast
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    set_seed_DiffAug(param)
    draws = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:  # Siamese augmentation:
        draws[:] = draws[0]
    factor = draws + base
    return (x - sample_mean) * factor + sample_mean


def rand_crop(x, param):
    """Randomly translate every image; vacated pixels read zero padding.

    The image is padded by one zero pixel on each side of dimensions 2 and 3
    and then sampled at integer-shifted positions.  Shifts are drawn
    uniformly from ``[-s, s]`` with ``s = int(size * param.ratio_crop_pad + 0.5)``
    per spatial dimension; shifted indices are clamped into the padded range.

    Args:
        x: 4-D batch (indexed as ``[batch, channel, dim2, dim3]``).
        param: reads ``ratio_crop_pad`` and ``Siamese``; also handed to
            ``set_seed_DiffAug``.

    Returns:
        The translated batch, same shape as ``x``.
    """
    ratio = param.ratio_crop_pad
    batch, size_x, size_y = x.size(0), x.size(2), x.size(3)
    shift_x, shift_y = int(size_x * ratio + 0.5), int(size_y * ratio + 0.5)
    set_seed_DiffAug(param)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[batch, 1, 1], device=x.device)
    set_seed_DiffAug(param)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[batch, 1, 1], device=x.device)
    if param.Siamese:  # Siamese augmentation: every sample uses sample 0's shift
        translation_x[:] = translation_x[0]
        translation_y[:] = translation_y[0]
    # Broadcastable index tensors instead of torch.meshgrid: same gather, but
    # no deprecated call without `indexing` and no full-size index tensors.
    index_batch = torch.arange(batch, dtype=torch.long, device=x.device).view(batch, 1, 1)
    index_x = torch.arange(size_x, dtype=torch.long, device=x.device).view(1, size_x, 1)
    index_y = torch.arange(size_y, dtype=torch.long, device=x.device).view(1, 1, size_y)
    # +1 accounts for the one-pixel padding; clamp keeps indices inside it.
    index_x = torch.clamp(index_x + translation_x + 1, 0, size_x + 1)
    index_y = torch.clamp(index_y + translation_y + 1, 0, size_y + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # Channels-last so one advanced-indexing gather picks whole pixel vectors.
    x = x_pad.permute(0, 2, 3, 1).contiguous()[index_batch, index_x, index_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, param):
    """Zero out one random rectangle per image.

    The rectangle measures ``int(size * param.ratio_cutout + 0.5)`` along each
    of dimensions 2 and 3; its position is drawn per sample and indices that
    fall outside the image are clamped onto the border.

    Args:
        x: 4-D batch (indexed as ``[batch, channel, dim2, dim3]``).
        param: reads ``ratio_cutout`` and ``Siamese``; also handed to
            ``set_seed_DiffAug``.

    Returns:
        ``x`` multiplied by a 0/1 mask shared across channels.
    """
    ratio = param.ratio_cutout
    batch, size_x, size_y = x.size(0), x.size(2), x.size(3)
    cutout_size = int(size_x * ratio + 0.5), int(size_y * ratio + 0.5)
    set_seed_DiffAug(param)
    offset_x = torch.randint(0, size_x + (1 - cutout_size[0] % 2), size=[batch, 1, 1], device=x.device)
    set_seed_DiffAug(param)
    offset_y = torch.randint(0, size_y + (1 - cutout_size[1] % 2), size=[batch, 1, 1], device=x.device)
    if param.Siamese:  # Siamese augmentation: every sample uses sample 0's offset
        offset_x[:] = offset_x[0]
        offset_y[:] = offset_y[0]
    # Broadcastable index tensors instead of torch.meshgrid: same scatter, but
    # no deprecated call without `indexing` and no full-size index tensors.
    index_batch = torch.arange(batch, dtype=torch.long, device=x.device).view(batch, 1, 1)
    index_x = torch.arange(cutout_size[0], dtype=torch.long, device=x.device).view(1, cutout_size[0], 1)
    index_y = torch.arange(cutout_size[1], dtype=torch.long, device=x.device).view(1, 1, cutout_size[1])
    # Centre the rectangle on the offset and clamp it into the image.
    index_x = torch.clamp(index_x + offset_x - cutout_size[0] // 2, min=0, max=size_x - 1)
    index_y = torch.clamp(index_y + offset_y - cutout_size[1] // 2, min=0, max=size_y - 1)
    mask = torch.ones(batch, size_x, size_y, dtype=x.dtype, device=x.device)
    mask[index_batch, index_x, index_y] = 0
    x = x * mask.unsqueeze(1)
    return x


# Registry used by DiffAugment: each strategy token (the pieces of an
# underscore-separated strategy string) maps to the list of augmentation
# functions applied, in order, for that token.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    crop=[rand_crop],
    cutout=[rand_cutout],
    flip=[rand_flip],
    scale=[rand_scale],
    rotate=[rand_rotate],
)

In [4]:
def _str2bool(value):
    """Parse a command-line boolean (``type=bool`` would turn 'False' into True)."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def _save_synthetic_images(image_syn, mean, std, channel, exp, it, args):
    """Write the synthetic set to ``args.save_path``: one grid PNG plus one PNG per image."""
    save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
    image_syn_vis = copy.deepcopy(image_syn.detach().cpu())

    # Undo the per-channel normalisation so values are back in image space.
    for ch in range(channel):
        image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]

    if args.dataset == "MHIST":
        # MHIST images are additionally min-max rescaled before saving.
        image_syn_vis = (image_syn_vis - image_syn_vis.min()) / (image_syn_vis.max() - image_syn_vis.min())

        print(image_syn_vis.min())
        print(image_syn_vis.max())

    image_syn_vis[image_syn_vis<0] = 0.0
    image_syn_vis[image_syn_vis>1] = 1.0
    save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.

    # Individual images go next to the grid, under the parsed save path.
    for count, synImg in enumerate(image_syn_vis):
        save_image(synImg, os.path.join(args.save_path, str(count)+"_"+str(it)+".png"))


def main(method="DC", dataset="CIFAR10", model="ConvNet", ipc=1, eval_mode="S", num_exp=5,
         num_eval=20, epoch_eval_train=300, iteration=1000, lr_img=0.1, lr_net=0.01, batch_real=128,
         batch_train=128, init="noise", dsa_strategy="None", data_path="/content/drive/MyDrive/ECE1512/Project_B/",
         save_path="/content/drive/MyDrive/ECE1512/Project_B/Runs/", dis_metric="ours", training_baseline=False):
    """Learn a synthetic dataset by gradient matching and evaluate it.

    Every keyword argument becomes the default of the matching argparse
    option, so a notebook call configures the run while command-line flags,
    if present, still override it (``parse_known_args`` ignores unknown ones).

    For each of ``num_exp`` experiments the function loads the real training
    set onto the device, initialises ``ipc`` synthetic images per class (noise
    or random real images), then for ``iteration + 1`` steps:

    * at the iterations in ``eval_it_pool`` trains ``num_eval`` fresh networks
      on the synthetic set via ``evaluate_synset`` and saves the images;
    * samples a fresh network and, for ``outer_loop`` rounds, updates the
      synthetic images so that their per-class gradients match those of real
      batches (``match_loss``), training the network on the synthetic set for
      ``inner_loop`` epochs between rounds.

    Final images and accuracies are written to ``save_path`` with
    ``torch.save`` and a summary is printed.

    Note: ``epoch_eval_train`` is overwritten during evaluation (1000 when
    augmentation is used, 300 otherwise), so the passed value is not what
    ``evaluate_synset`` sees.
    """

    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--method', type=str, default=method, help='DC/DSA')
    parser.add_argument('--dataset', type=str, default=dataset, help='dataset')
    parser.add_argument('--model', type=str, default=model, help='model')
    parser.add_argument('--ipc', type=int, default=ipc, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default=eval_mode, help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
    parser.add_argument('--num_exp', type=int, default=num_exp, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=num_eval, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=epoch_eval_train, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=iteration, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=lr_img, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=lr_net, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=batch_real, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=batch_train, help='batch size for training networks')
    parser.add_argument('--init', type=str, default=init, help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default=dsa_strategy, help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default=data_path, help='dataset path')
    parser.add_argument('--save_path', type=str, default=save_path, help='path to save results')
    parser.add_argument('--dis_metric', type=str, default=dis_metric, help='distance metric')
    # type=bool would parse any non-empty string (including "False") as True.
    parser.add_argument('--training_baseline', type=_str2bool, default=training_baseline, help='Boolean for baseline training')


    args, _ = parser.parse_known_args()
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if args.device == 'cuda': print("RUNNING ON GPU")

    args.dsa_param = ParamDiffAug()
    args.dsa = args.method == 'DSA'

    # makedirs also creates missing parents (e.g. ".../Runs/<run name>/").
    os.makedirs(args.data_path, exist_ok=True)
    os.makedirs(args.save_path, exist_ok=True)

    # The list of iterations when we evaluate models and record results.
    if args.eval_mode == 'S' or args.eval_mode == 'SS':
        eval_it_pool = np.arange(0, args.Iteration+1, 10).tolist()
        # Final results are only recorded when the last iteration is
        # evaluated, so make sure it is part of the pool.
        if args.Iteration not in eval_it_pool:
            eval_it_pool.append(args.Iteration)
    else:
        eval_it_pool = [args.Iteration]
    print('eval_it_pool: ', eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)


    accs_all_exps = dict() # record performances of all experiments
    for key in model_eval_pool:
        accs_all_exps[key] = []

    data_save = []


    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        ''' organize the real dataset '''
        # Read every sample once (the dataset's __getitem__ may be costly).
        samples = [dst_train[i] for i in range(len(dst_train))]
        indices_class = [[] for c in range(num_classes)]
        for i, sample in enumerate(samples):
            indices_class[sample[1]].append(i)
        images_all = torch.cat([torch.unsqueeze(sample[0], dim=0) for sample in samples], dim=0).to(args.device)
        labels_all = torch.tensor([sample[1] for sample in samples], dtype=torch.long, device=args.device)
        del samples

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))

        def get_images(c, n): # get random n images from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))


        ''' initialize the synthetic data '''
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        # [0,0,0, 1,1,1, ..., 9,9,9] built directly as a tensor (no list of ndarrays).
        label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')


        ''' training '''
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
        optimizer_img.zero_grad()
        criterion = nn.CrossEntropyLoss().to(args.device)
        print('%s training begins'%get_time())

        for it in range(args.Iteration+1):

            ''' Evaluate synthetic data '''
            if it in eval_it_pool:
                for model_eval in model_eval_pool:
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                    if args.dsa:
                        args.epoch_eval_train = 1000
                        args.dc_aug_param = None
                        print('DSA augmentation strategy: \n', args.dsa_strategy)
                        print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
                    else:
                        args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.
                        print('DC augmentation parameters: \n', args.dc_aug_param)

                    if args.dsa or args.dc_aug_param['strategy'] != 'none':
                        args.epoch_eval_train = 1000  # Training with data augmentation needs more epochs.
                    else:
                        args.epoch_eval_train = 300

                    accs = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                        image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
                        _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                    print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                    if it == args.Iteration: # record the final results
                        accs_all_exps[model_eval] += accs

                ''' visualize and save '''
                _save_synthetic_images(image_syn, mean, std, channel, exp, it, args)



            ''' Train synthetic data '''
            net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
            net.train()
            net_parameters = list(net.parameters())
            optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)  # optimizer_net for the network weights
            optimizer_net.zero_grad()
            loss_avg = 0
            args.dc_aug_param = None  # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in order to be consistent with DC paper.


            for ol in range(args.outer_loop):

                ''' freeze the running mu and sigma for BatchNorm layers '''
                # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
                # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
                # This would make the training with BatchNorm layers easier.

                BN_flag = False
                BNSizePC = 16  # for batch normalization
                for module in net.modules():
                    if 'BatchNorm' in module._get_name(): #BatchNorm
                        BN_flag = True
                if BN_flag:
                    img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)
                    net.train() # for updating the mu, sigma of BatchNorm
                    output_real = net(img_real) # get running mu, sigma
                    for module in net.modules():
                        if 'BatchNorm' in module._get_name():  #BatchNorm
                            module.eval() # fix mu and sigma of every BatchNorm layer


                ''' update synthetic data '''
                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
                    lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c

                    if args.dsa:
                        # Same seed for both calls so real and synthetic batches get the same augmentation.
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    output_real = net(img_real)
                    loss_real = criterion(output_real, lab_real)
                    gw_real = torch.autograd.grad(loss_real, net_parameters)
                    gw_real = list((_.detach().clone() for _ in gw_real))

                    output_syn = net(img_syn)
                    loss_syn = criterion(output_syn, lab_syn)
                    # create_graph=True keeps the graph so the matching loss can be backpropagated to the images.
                    gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                    loss += match_loss(gw_syn, gw_real, args)

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()
                loss_avg += loss.item()

                if ol == args.outer_loop - 1:
                    break


                ''' update network '''
                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                for il in range(args.inner_loop):
                    # NOTE(review): augmentation was disabled for both DC and DSA in the
                    # previous if/else (both branches were False) - confirm that is intended for DSA.
                    epoch('train', trainloader, net, optimizer_net, criterion, args, False)


            loss_avg /= (num_classes*args.outer_loop)

            if it%10 == 0:
                print('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))

            if it == args.Iteration: # only record the final results
                data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))


    print('\n==================== Final Results ====================\n')
    for key in model_eval_pool:
        accs = accs_all_exps[key]
        print('Run %d experiments, train on %s, evaluate %d random %s, mean  = %.2f%%  std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))

In [6]:
class AttrDict(dict):
    """A dict whose items are also reachable as attributes.

    The instance's ``__dict__`` is the dict itself, so ``d.key`` and
    ``d['key']`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the mapping itself.
        self.__dict__ = self

# **Project Questions**

In [35]:
#2A

# Settings for the MNIST baseline (ConvNet trained on the full training set).
# They are exposed through an AttrDict so downstream helpers can read them as
# attributes (args.device, args.lr_net, ...).
args_dict = dict(
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    dis_metric="cos",
    lr_net=0.01,
    epoch_eval_train=20,
    batch_train=32,
    training_baseline=True,
    dataset="MNIST",
)

args = AttrDict(args_dict)

In [36]:
# Train the baseline network on the full MNIST training set.

mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"

# get_dataset returns its nine fields in this order; unpack them in one step.
MNIST_dataset = get_dataset("MNIST", mnist_path)
(channel_mnist, im_size_mnist, num_classes_mnist,
 class_names_mnist, mean_mnist, std_mnist,
 dst_train_mnist, dst_test_mnist, testloader_mnist) = MNIST_dataset[:9]

mnist_network = get_network("ConvNet", channel_mnist, num_classes_mnist, im_size_mnist)

# Last expression of the cell: the returned tuple is displayed as the output.
evaluate_synset(1, mnist_network, dst_train_mnist.data, dst_train_mnist.targets, testloader_mnist, args)

[2023-12-10 20:56:13] Evaluate_01: epoch = 0020 train time = 119 s train loss = 0.004321 train acc = 0.9997, test acc = 0.9950


(ConvNet(
   (features): Sequential(
     (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3))
     (1): GroupNorm(128, 128, eps=1e-05, affine=True)
     (2): ReLU(inplace=True)
     (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (5): GroupNorm(128, 128, eps=1e-05, affine=True)
     (6): ReLU(inplace=True)
     (7): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (9): GroupNorm(128, 128, eps=1e-05, affine=True)
     (10): ReLU(inplace=True)
     (11): AvgPool2d(kernel_size=2, stride=2, padding=0)
   )
   (classifier): Linear(in_features=2048, out_features=10, bias=True)
 ),
 0.9997,
 0.995)

In [17]:
#MNIST FLOPS

# Profile one forward pass of the MNIST ConvNet and report per-operator FLOPs.
args = AttrDict()
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MNIST"
}

args.update(args_dict)

mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"

# get_dataset returns its nine fields in this order; unpack them in one step.
MNIST_dataset = get_dataset("MNIST", mnist_path)
(channel_mnist, im_size_mnist, num_classes_mnist,
 class_names_mnist, mean_mnist, std_mnist,
 dst_train_mnist, dst_test_mnist, testloader_mnist) = MNIST_dataset[:9]

mnist_network = get_network("ConvNet", channel_mnist, num_classes_mnist, im_size_mnist)

# Create a single dummy image on whatever device the network lives on (a
# hard-coded 'cuda' fails on a CPU-only runtime) and with the same shape the
# network was built for.
network_device = next(mnist_network.parameters()).device
dummy_input = torch.randn(1, channel_mnist, im_size_mnist[0], im_size_mnist[1]).to(network_device)

# Use profiler to record FLOPS
with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_flops = True) as prof:
    with record_function("model_inference"):
        mnist_network(dummy_input)

# Print the FLOPS
print(prof.key_averages().table(sort_by="flops", row_limit=-1))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         1.04%      32.000us        58.34%       1.797ms     599.000us             3     96731.136  
                                            aten::addmm         3.83%     118.000us         5.06%     156.000us     156.000us             1        40.960  
                                        model_inference        17.08%     526.000us       100.00%       3.080ms       3.080ms             1            --  
                                      aten::convolution         

In [123]:
# Train MHIST using ConvNet to establish baseline

# Define the settings in this cell instead of relying on whichever `args`
# another cell left behind (in notebook order that would be the MNIST one).
args = AttrDict()
args.update({
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MHIST"
})

mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"

# get_dataset returns its nine fields in this order; unpack them in one step.
MHIST_dataset = get_dataset("MHIST", mhist_path)
(channel_mhist, im_size_mhist, num_classes_mhist,
 class_names_mhist, mean_mhist, std_mhist,
 dst_train_mhist, dst_test_mhist, testloader_mhist) = MHIST_dataset[:9]

mhist_network = get_network("ConvNet", channel_mhist, num_classes_mhist, im_size_mhist)

# Last expression of the cell: the returned tuple is displayed as the output.
evaluate_synset(1, mhist_network, dst_train_mhist.data, dst_train_mhist.targets, testloader_mhist, args)

done loading images
[2023-12-10 22:17:49] Evaluate_01: epoch = 0020 train time = 27 s train loss = 0.138041 train acc = 0.9706, test acc = 0.7410


(ConvNet(
   (features): Sequential(
     (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (1): GroupNorm(128, 128, eps=1e-05, affine=True)
     (2): ReLU(inplace=True)
     (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (5): GroupNorm(128, 128, eps=1e-05, affine=True)
     (6): ReLU(inplace=True)
     (7): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (9): GroupNorm(128, 128, eps=1e-05, affine=True)
     (10): ReLU(inplace=True)
     (11): AvgPool2d(kernel_size=2, stride=2, padding=0)
   )
   (classifier): Linear(in_features=32768, out_features=2, bias=True)
 ),
 0.9705747126436781,
 0.7410440122824974)

In [12]:
#MHIST FLOPS

# Profile one forward pass of the MHIST ConvNet and report per-operator FLOPs.
args = AttrDict()
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MHIST"
}

args.update(args_dict)

mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"

# get_dataset returns its nine fields in this order; unpack them in one step.
MHIST_dataset = get_dataset("MHIST", mhist_path)
(channel_mhist, im_size_mhist, num_classes_mhist,
 class_names_mhist, mean_mhist, std_mhist,
 dst_train_mhist, dst_test_mhist, testloader_mhist) = MHIST_dataset[:9]

mhist_network = get_network("ConvNet", channel_mhist, num_classes_mhist, im_size_mhist)

# Create a single dummy image on whatever device the network lives on (a
# hard-coded 'cuda' fails on a CPU-only runtime) and with the same shape the
# network was built for.
network_device = next(mhist_network.parameters()).device
dummy_input = torch.randn(1, channel_mhist, im_size_mhist[0], im_size_mhist[1]).to(network_device)

# Use profiler to record FLOPS
with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_flops = True) as prof:
    with record_function("model_inference"):
        mhist_network(dummy_input)

# Print the FLOPS
print(prof.key_averages().table(sort_by="flops", row_limit=-1))

done loading images
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         0.10%      20.000us        65.57%      13.020ms       4.340ms             3   1623195.648  
                                            aten::addmm        27.26%       5.412ms        27.53%       5.467ms       5.467ms             1       131.072  
                                        model_inference         3.33%     662.000us       100.00%      19.856ms      19.856ms             1            --  
                                      aten::

In [6]:
#2B

#MNIST

# Condense MNIST to 10 images per class, starting from randomly picked real images.
main(
    dataset="MNIST",
    model="ConvNet",
    iteration=10,
    batch_real=256,
    batch_train=256,
    num_exp=1,
    ipc=10,
    init="real",
    save_path="/content/drive/MyDrive/ECE1512/Project_B/Runs/MNIST Initial/",
)

RUNNING ON GPU
eval_it_pool:  [0]

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MNIST', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 20, 'epoch_eval_train': 300, 'Iteration': 10, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'real', 'dsa_strategy': 'None', 'data_path': '/content/drive/MyDrive/ECE1512/Project_B/', 'save_path': '/content/drive/MyDrive/ECE1512/Project_B/Runs/MNIST Initial/', 'dis_metric': 'ours', 'training_baseline': False, 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <__main__.ParamDiffAug object at 0x7cd95210c970>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
class c = 0: 5923 real images
class c = 1: 6742 real images
class c = 2: 5958 real images
class c = 3: 6131 real images
class c = 4: 5842 real images
class c = 5: 5421 real images
class c = 6: 5918 real images
class c = 7: 6265 real images
class c = 8: 5851 real images
class c = 9: 5949 real images
real images channel 0, 

  label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]


[2023-12-10 16:07:40] Evaluate_00: epoch = 1000 train time = 39 s train loss = 0.003873 train acc = 1.0000, test acc = 0.9558
[2023-12-10 16:08:14] Evaluate_01: epoch = 1000 train time = 31 s train loss = 0.005113 train acc = 1.0000, test acc = 0.9577
[2023-12-10 16:08:48] Evaluate_02: epoch = 1000 train time = 31 s train loss = 0.006244 train acc = 1.0000, test acc = 0.9538
[2023-12-10 16:09:21] Evaluate_03: epoch = 1000 train time = 31 s train loss = 0.003648 train acc = 1.0000, test acc = 0.9559
[2023-12-10 16:09:55] Evaluate_04: epoch = 1000 train time = 31 s train loss = 0.005302 train acc = 1.0000, test acc = 0.9524
[2023-12-10 16:10:29] Evaluate_05: epoch = 1000 train time = 31 s train loss = 0.005753 train acc = 1.0000, test acc = 0.9581
[2023-12-10 16:11:03] Evaluate_06: epoch = 1000 train time = 31 s train loss = 0.008147 train acc = 1.0000, test acc = 0.9568
[2023-12-10 16:11:36] Evaluate_07: epoch = 1000 train time = 31 s train loss = 0.005655 train acc = 1.0000, test acc =

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean, casting='unsafe',
  ret = ret.dtype.type(ret / rcount)


In [124]:
#MHIST

# Condense MHIST to 50 images per class, starting from randomly picked real images.
main(
    dataset="MHIST",
    model="ConvNet",
    iteration=10,
    batch_real=128,
    batch_train=128,
    num_exp=1,
    ipc=50,
    init="real",
    save_path="/content/drive/MyDrive/ECE1512/Project_B/Runs/MHIST Initial/",
)

mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"

# Reload the dataset fields and build a fresh ConvNet after the run.
MHIST_dataset = get_dataset("MHIST", mhist_path)
(channel_mhist, im_size_mhist, num_classes_mhist,
 class_names_mhist, mean_mhist, std_mhist,
 dst_train_mhist, dst_test_mhist, testloader_mhist) = MHIST_dataset[:9]

mhist_network = get_network("ConvNet", channel_mhist, num_classes_mhist, im_size_mhist)



RUNNING ON GPU
eval_it_pool:  [0]
done loading images

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MHIST', 'model': 'ConvNet', 'ipc': 50, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 20, 'epoch_eval_train': 300, 'Iteration': 10, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 128, 'batch_train': 128, 'init': 'real', 'dsa_strategy': 'None', 'data_path': '/content/drive/MyDrive/ECE1512/Project_B/', 'save_path': '/content/drive/MyDrive/ECE1512/Project_B/Runs/MHIST Initial/', 'dis_metric': 'ours', 'training_baseline': False, 'outer_loop': 50, 'inner_loop': 10, 'device': 'cuda', 'dsa_param': <__main__.ParamDiffAug object at 0x7baa2dced210>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
class c = 0: 1545 real images
class c = 1: 630 real images
real images channel 0, mean = 0.7195, std = 0.1973
real images channel 1, mean = 0.6218, std = 0.2513
real images channel 2, mean = 0.7605, std = 0.1662
initialize synthetic data from random real images
[2023-12-10 22:19:08] training begins
---

In [29]:
#2D MNIST (Longer)

# Same MNIST setup as 2B, but `init` is left at main's default ("noise").
main(
    dataset="MNIST",
    model="ConvNet",
    iteration=10,
    batch_real=256,
    batch_train=256,
    num_exp=1,
    ipc=10,
    save_path="/content/drive/MyDrive/ECE1512/Project_B/Runs/MNIST Noise/",
)

RUNNING ON GPU
eval_it_pool:  [0, 10]

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MNIST', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 20, 'epoch_eval_train': 300, 'Iteration': 10, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'noise', 'dsa_strategy': 'None', 'data_path': '/content/drive/MyDrive/ECE1512/Project_B/', 'save_path': '/content/drive/MyDrive/ECE1512/Project_B/Runs/MNIST Noise/', 'dis_metric': 'ours', 'training_baseline': False, 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <__main__.ParamDiffAug object at 0x7c157558a230>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
class c = 0: 5923 real images
class c = 1: 6742 real images
class c = 2: 5958 real images
class c = 3: 6131 real images
class c = 4: 5842 real images
class c = 5: 5421 real images
class c = 6: 5918 real images
class c = 7: 6265 real images
class c = 8: 5851 real images
class c = 9: 5949 real images
real images channel 

In [7]:
#2D MHIST (Longer)
# Runs main() on MHIST with a ConvNet: 10 iterations, one experiment, ipc=50,
# batches of 128, and the synthetic set explicitly initialised from noise.
main(
    dataset="MHIST",
    model="ConvNet",
    iteration=10,
    batch_real=128,
    batch_train=128,
    num_exp=1,
    ipc=50,
    init="noise",
    save_path="/content/drive/MyDrive/ECE1512/Project_B/Runs/MHIST Noise/",
)

RUNNING ON GPU
eval_it_pool:  [0, 10]
done loading images

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MHIST', 'model': 'ConvNet', 'ipc': 50, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 20, 'epoch_eval_train': 300, 'Iteration': 10, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 128, 'batch_train': 128, 'init': 'noise', 'dsa_strategy': 'None', 'data_path': '/content/drive/MyDrive/ECE1512/Project_B/', 'save_path': '/content/drive/MyDrive/ECE1512/Project_B/Runs/MHIST Noise/', 'dis_metric': 'ours', 'training_baseline': False, 'outer_loop': 50, 'inner_loop': 10, 'device': 'cuda', 'dsa_param': <__main__.ParamDiffAug object at 0x7c1575bf80d0>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
class c = 0: 1545 real images
class c = 1: 630 real images
real images channel 0, mean = 0.7195, std = 0.1973
real images channel 1, mean = 0.6218, std = 0.2513
real images channel 2, mean = 0.7605, std = 0.1662
initialize synthetic data from random noise
[2023-12-10 22:58:19] training begins
------

In [45]:
#2D MNIST Extreme
# Same MNIST / ConvNet setup but with 100 iterations, batches of 128, ipc=10
# and noise initialisation; results go to the "MNIST Noise Extreme" run folder.
main(
    dataset="MNIST",
    model="ConvNet",
    iteration=100,
    batch_real=128,
    batch_train=128,
    num_exp=1,
    ipc=10,
    init="noise",
    save_path="/content/drive/MyDrive/ECE1512/Project_B/Runs/MNIST Noise Extreme/",
)

RUNNING ON GPU
eval_it_pool:  [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MNIST', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 20, 'epoch_eval_train': 300, 'Iteration': 100, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 128, 'batch_train': 128, 'init': 'noise', 'dsa_strategy': 'None', 'data_path': '/content/drive/MyDrive/ECE1512/Project_B/', 'save_path': '/content/drive/MyDrive/ECE1512/Project_B/Runs/MNIST Noise Extreme/', 'dis_metric': 'ours', 'training_baseline': False, 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <__main__.ParamDiffAug object at 0x7d6e44b70a90>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
class c = 0: 5923 real images
class c = 1: 6742 real images
class c = 2: 5958 real images
class c = 3: 6131 real images
class c = 4: 5842 real images
class c = 5: 5421 real images
class c = 6: 5918 real images
class c = 7: 6265 real images
class c = 8: 5851 real images
clas

In [8]:
#2D MHIST Extreme
# Same MHIST / ConvNet setup but with 100 iterations, batches of 128, ipc=50
# and noise initialisation; results go to the "MHIST Noise Extreme" run folder.
main(
    dataset="MHIST",
    model="ConvNet",
    iteration=100,
    batch_real=128,
    batch_train=128,
    num_exp=1,
    ipc=50,
    init="noise",
    save_path="/content/drive/MyDrive/ECE1512/Project_B/Runs/MHIST Noise Extreme/",
)

RUNNING ON GPU
eval_it_pool:  [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
done loading images

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MHIST', 'model': 'ConvNet', 'ipc': 50, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 20, 'epoch_eval_train': 300, 'Iteration': 100, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 128, 'batch_train': 128, 'init': 'noise', 'dsa_strategy': 'None', 'data_path': '/content/drive/MyDrive/ECE1512/Project_B/', 'save_path': '/content/drive/MyDrive/ECE1512/Project_B/Runs/MHIST Noise Extreme/', 'dis_metric': 'ours', 'training_baseline': False, 'outer_loop': 50, 'inner_loop': 10, 'device': 'cuda', 'dsa_param': <__main__.ParamDiffAug object at 0x7d6e6e638340>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
class c = 0: 1545 real images
class c = 1: 630 real images
real images channel 0, mean = 0.7195, std = 0.1973
real images channel 1, mean = 0.6218, std = 0.2513
real images channel 2, mean = 0.7605, std = 0.1662
initialize synthetic data from random nois

  label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]


[2023-12-11 00:06:40] Evaluate_00: epoch = 0300 train time = 17 s train loss = 0.000005 train acc = 1.0000, test acc = 0.5189
[2023-12-11 00:06:50] Evaluate_01: epoch = 0300 train time = 10 s train loss = 0.000000 train acc = 1.0000, test acc = 0.3470
[2023-12-11 00:07:00] Evaluate_02: epoch = 0300 train time = 10 s train loss = 0.000000 train acc = 1.0000, test acc = 0.3644
[2023-12-11 00:07:11] Evaluate_03: epoch = 0300 train time = 10 s train loss = 0.000008 train acc = 1.0000, test acc = 0.4811
[2023-12-11 00:07:21] Evaluate_04: epoch = 0300 train time = 10 s train loss = 0.000000 train acc = 1.0000, test acc = 0.5241
[2023-12-11 00:07:31] Evaluate_05: epoch = 0300 train time = 10 s train loss = 0.000000 train acc = 1.0000, test acc = 0.5670
[2023-12-11 00:07:42] Evaluate_06: epoch = 0300 train time = 10 s train loss = 0.000000 train acc = 1.0000, test acc = 0.4596
[2023-12-11 00:07:52] Evaluate_07: epoch = 0300 train time = 10 s train loss = 0.000000 train acc = 1.0000, test acc =

KeyboardInterrupt: ignored

In [35]:
#2E - Train on synthetic set (MNIST, ConvNet)

# Settings bundle handed to evaluate_synset; the device entry picks the GPU
# when one is available and falls back to the CPU otherwise.
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MNIST Custom",
}
args = AttrDict()
args.update(args_dict)

# Load the "MNIST Custom" bundle and give each of its first nine positional
# fields (indices 0..8) its own name.
mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"
MNIST_dataset = get_dataset("MNIST Custom", mnist_path)
(
    channel_mnist,
    im_size_mnist,
    num_classes_mnist,
    class_names_mnist,
    mean_mnist,
    std_mnist,
    dst_train_mnist,
    dst_test_mnist,
    testloader_mnist,
) = (MNIST_dataset[field] for field in range(9))

# Build a ConvNet sized from the dataset's channel count, class count and image size.
mnist_network = get_network("ConvNet", channel_mnist, num_classes_mnist, im_size_mnist)

# Train on dst_train_mnist and score against the test loader. The call is the
# cell's last expression, so the returned (network, train acc, test acc) tuple
# is displayed as the cell output.
evaluate_synset(1, mnist_network, dst_train_mnist, dst_train_mnist.targets, testloader_mnist, args)

[2023-12-11 01:12:17] Evaluate_01: epoch = 0020 train time = 9 s train loss = 0.025028 train acc = 1.0000, test acc = 0.9063


(ConvNet(
   (features): Sequential(
     (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3))
     (1): GroupNorm(128, 128, eps=1e-05, affine=True)
     (2): ReLU(inplace=True)
     (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (5): GroupNorm(128, 128, eps=1e-05, affine=True)
     (6): ReLU(inplace=True)
     (7): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (9): GroupNorm(128, 128, eps=1e-05, affine=True)
     (10): ReLU(inplace=True)
     (11): AvgPool2d(kernel_size=2, stride=2, padding=0)
   )
   (classifier): Linear(in_features=2048, out_features=10, bias=True)
 ),
 1.0,
 0.9063)

In [42]:
#2E - train on synthetic set (MHIST, ConvNet)

# Settings bundle handed to evaluate_synset; the device entry picks the GPU
# when one is available and falls back to the CPU otherwise.
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MHIST Custom",
}
args = AttrDict()
args.update(args_dict)

# Load the "MHIST Custom" bundle and give each of its first nine positional
# fields (indices 0..8) its own name.
mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"
MHIST_dataset = get_dataset("MHIST Custom", mhist_path)
(
    channel_mhist,
    im_size_mhist,
    num_classes_mhist,
    class_names_mhist,
    mean_mhist,
    std_mhist,
    dst_train_mhist,
    dst_test_mhist,
    testloader_mhist,
) = (MHIST_dataset[field] for field in range(9))

# Build a ConvNet sized from the dataset's channel count, class count and image size.
mhist_network = get_network("ConvNet", channel_mhist, num_classes_mhist, im_size_mhist)

# Train on dst_train_mhist and score against the test loader. The call is the
# cell's last expression, so the returned (network, train acc, test acc) tuple
# is displayed as the cell output.
evaluate_synset(1, mhist_network, dst_train_mhist, dst_train_mhist.targets, testloader_mhist, args)

[2023-12-11 01:21:43] Evaluate_01: epoch = 0020 train time = 4 s train loss = 0.693508 train acc = 0.5000, test acc = 0.3685


(ConvNet(
   (features): Sequential(
     (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (1): GroupNorm(128, 128, eps=1e-05, affine=True)
     (2): ReLU(inplace=True)
     (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (5): GroupNorm(128, 128, eps=1e-05, affine=True)
     (6): ReLU(inplace=True)
     (7): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (9): GroupNorm(128, 128, eps=1e-05, affine=True)
     (10): ReLU(inplace=True)
     (11): AvgPool2d(kernel_size=2, stride=2, padding=0)
   )
   (classifier): Linear(in_features=32768, out_features=2, bias=True)
 ),
 0.5,
 0.368474923234391)

In [46]:
#3 Cross architecture performance
# Train MNIST using ResNet18 for comparison with the ConvNet run.

# Settings bundle handed to evaluate_synset; the device entry picks the GPU
# when one is available and falls back to the CPU otherwise.
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MNIST Custom",
}
args = AttrDict()
args.update(args_dict)

# Load the "MNIST Custom" bundle and give each of its first nine positional
# fields (indices 0..8) its own name.
mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"
MNIST_dataset = get_dataset("MNIST Custom", mnist_path)
(
    channel_mnist,
    im_size_mnist,
    num_classes_mnist,
    class_names_mnist,
    mean_mnist,
    std_mnist,
    dst_train_mnist,
    dst_test_mnist,
    testloader_mnist,
) = (MNIST_dataset[field] for field in range(9))

# Only the architecture name differs from the ConvNet training cell.
mnist_network = get_network("ResNet18", channel_mnist, num_classes_mnist, im_size_mnist)

# Train and evaluate; the returned (network, train acc, test acc) tuple is
# displayed because the call is the last expression of the cell.
evaluate_synset(1, mnist_network, dst_train_mnist, dst_train_mnist.targets, testloader_mnist, args)

[2023-12-11 02:31:39] Evaluate_01: epoch = 0020 train time = 3 s train loss = 0.177401 train acc = 1.0000, test acc = 0.8144


(ResNet(
   (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn1): GroupNorm(64, 64, eps=1e-05, affine=True)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): GroupNorm(64, 64, eps=1e-05, affine=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): GroupNorm(64, 64, eps=1e-05, affine=True)
       (shortcut): Sequential()
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): GroupNorm(64, 64, eps=1e-05, affine=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): GroupNorm(64, 64, eps=1e-05, affine=True)
       (shortcut): Sequential()
     )
   )
   (layer2): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 128, kernel_size=(3, 3),

In [18]:
#RESNET MNIST FLOPS
# Profiles a single forward pass of ResNet18 on an MNIST-shaped input and
# prints the per-operator table sorted by estimated FLOPs.

# Same settings block as the ResNet18 MNIST training cell.
args = AttrDict()
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MNIST Custom"
}

args.update(args_dict)

# Load the dataset bundle; its first nine positional fields are named below.
mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"

MNIST_dataset = get_dataset("MNIST Custom", mnist_path)
(
    channel_mnist,
    im_size_mnist,
    num_classes_mnist,
    class_names_mnist,
    mean_mnist,
    std_mnist,
    dst_train_mnist,
    dst_test_mnist,
    testloader_mnist,
) = (MNIST_dataset[field] for field in range(9))

mnist_network = get_network("ResNet18", channel_mnist, num_classes_mnist, im_size_mnist)

# Create the dummy input (one 1x28x28 image) on whatever device the network's
# weights live on. The previous hard-coded 'cuda' raised on a CPU-only runtime
# even though the "device" entry above falls back to "cpu".
dummy_input = torch.randn(1, 1, 28, 28).to(next(mnist_network.parameters()).device)

# Use profiler to record FLOPS. Only CPU-side activity is recorded; with_flops
# adds the "Total KFLOPs" column. no_grad() keeps autograd bookkeeping out of
# what is purely an inference measurement.
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_flops=True) as prof:
        with record_function("model_inference"):
            mnist_network(dummy_input)

# Print the FLOPS
print(prof.key_averages().table(sort_by="flops", row_limit=-1))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         6.23%       2.759ms        80.29%      35.543ms       1.777ms            20    911591.424  
                                            aten::addmm         0.43%     192.000us         0.54%     240.000us     240.000us             1        10.240  
                                        model_inference         3.93%       1.740ms       100.00%      44.267ms      44.267ms             1            --  
                                      aten::convolution         

In [66]:
#2E
# Train MHIST using ResNet18 to compare performance.
# NOTE: set im_size to 32, 32 before running.

# Settings bundle handed to evaluate_synset; the device entry picks the GPU
# when one is available and falls back to the CPU otherwise.
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MHIST Custom",
}
args = AttrDict()
args.update(args_dict)

# Load the "MHIST Custom" bundle and give each of its first nine positional
# fields (indices 0..8) its own name.
mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"
MHIST_dataset = get_dataset("MHIST Custom", mhist_path)
(
    channel_mhist,
    im_size_mhist,
    num_classes_mhist,
    class_names_mhist,
    mean_mhist,
    std_mhist,
    dst_train_mhist,
    dst_test_mhist,
    testloader_mhist,
) = (MHIST_dataset[field] for field in range(9))

mhist_network = get_network("ResNet18", channel_mhist, num_classes_mhist, im_size_mhist)

# Train and evaluate; the returned (network, train acc, test acc) tuple is
# displayed because the call is the last expression of the cell.
evaluate_synset(1, mhist_network, dst_train_mhist, dst_train_mhist.targets, testloader_mhist, args)

(977, 3, 32, 32)
[2023-12-11 02:44:10] Evaluate_01: epoch = 0020 train time = 5 s train loss = 0.723770 train acc = 0.5000, test acc = 0.3685


(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn1): GroupNorm(64, 64, eps=1e-05, affine=True)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): GroupNorm(64, 64, eps=1e-05, affine=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): GroupNorm(64, 64, eps=1e-05, affine=True)
       (shortcut): Sequential()
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): GroupNorm(64, 64, eps=1e-05, affine=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): GroupNorm(64, 64, eps=1e-05, affine=True)
       (shortcut): Sequential()
     )
   )
   (layer2): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 128, kernel_size=(3, 3),

In [19]:
#RESNET MHIST FLOPS
# Profiles a single forward pass of ResNet18 on an MHIST-shaped input and
# prints the per-operator table sorted by estimated FLOPs.

# Same settings block as the ResNet18 MHIST training cell.
args = AttrDict()
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MHIST Custom"
}

args.update(args_dict)

# Load the dataset bundle; its first nine positional fields are named below.
mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"

MHIST_dataset = get_dataset("MHIST Custom", mhist_path)
(
    channel_mhist,
    im_size_mhist,
    num_classes_mhist,
    class_names_mhist,
    mean_mhist,
    std_mhist,
    dst_train_mhist,
    dst_test_mhist,
    testloader_mhist,
) = (MHIST_dataset[field] for field in range(9))

mhist_network = get_network("ResNet18", channel_mhist, num_classes_mhist, im_size_mhist)

# Create the dummy input (one 3x32x32 image) on whatever device the network's
# weights live on. The previous hard-coded 'cuda' raised on a CPU-only runtime
# even though the "device" entry above falls back to "cpu".
# NOTE(review): the input is fixed at 32x32 regardless of im_size_mhist --
# confirm this matches the resolution used when training this architecture.
dummy_input = torch.randn(1, 3, 32, 32).to(next(mhist_network.parameters()).device)

# Use profiler to record FLOPS. Only CPU-side activity is recorded; with_flops
# adds the "Total KFLOPs" column. no_grad() keeps autograd bookkeeping out of
# what is purely an inference measurement.
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_flops=True) as prof:
        with record_function("model_inference"):
            mhist_network(dummy_input)

# Print the FLOPS
print(prof.key_averages().table(sort_by="flops", row_limit=-1))

(977, 3, 128, 128)
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         0.76%     264.000us        83.32%      29.015ms       1.451ms            20   1110835.200  
                                            aten::addmm         0.49%     172.000us         0.60%     208.000us     208.000us             1         2.048  
                                        model_inference         6.04%       2.104ms       100.00%      34.823ms      34.823ms             1            --  
                                      aten::c

In [67]:
#3 Cross architecture performance
# Train MNIST using MLP for comparison.

# Settings bundle handed to evaluate_synset; the device entry picks the GPU
# when one is available and falls back to the CPU otherwise.
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MNIST Custom",
}
args = AttrDict()
args.update(args_dict)

# Load the "MNIST Custom" bundle and give each of its first nine positional
# fields (indices 0..8) its own name.
mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"
MNIST_dataset = get_dataset("MNIST Custom", mnist_path)
(
    channel_mnist,
    im_size_mnist,
    num_classes_mnist,
    class_names_mnist,
    mean_mnist,
    std_mnist,
    dst_train_mnist,
    dst_test_mnist,
    testloader_mnist,
) = (MNIST_dataset[field] for field in range(9))

mnist_network = get_network("MLP", channel_mnist, num_classes_mnist, im_size_mnist)

# Train and evaluate; the returned (network, train acc, test acc) tuple is
# displayed because the call is the last expression of the cell.
evaluate_synset(1, mnist_network, dst_train_mnist, dst_train_mnist.targets, testloader_mnist, args)

[2023-12-11 02:50:44] Evaluate_01: epoch = 0020 train time = 2 s train loss = 0.176212 train acc = 0.9800, test acc = 0.7492


(MLP(
   (fc_1): Linear(in_features=784, out_features=128, bias=True)
   (fc_2): Linear(in_features=128, out_features=128, bias=True)
   (fc_3): Linear(in_features=128, out_features=10, bias=True)
 ),
 0.98,
 0.7492)

In [20]:
#MLP MNIST FLOPS
# Profiles a single forward pass of the MLP on an MNIST-shaped input and
# prints the per-operator table sorted by estimated FLOPs.

# Same settings block as the MLP MNIST training cell.
args = AttrDict()
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MNIST Custom"
}

args.update(args_dict)

# Load the dataset bundle; its first nine positional fields are named below.
mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"

MNIST_dataset = get_dataset("MNIST Custom", mnist_path)
(
    channel_mnist,
    im_size_mnist,
    num_classes_mnist,
    class_names_mnist,
    mean_mnist,
    std_mnist,
    dst_train_mnist,
    dst_test_mnist,
    testloader_mnist,
) = (MNIST_dataset[field] for field in range(9))

mnist_network = get_network("MLP", channel_mnist, num_classes_mnist, im_size_mnist)

# Create the dummy input (one 1x28x28 image) on whatever device the network's
# weights live on. The previous hard-coded 'cuda' raised on a CPU-only runtime
# even though the "device" entry above falls back to "cpu".
dummy_input = torch.randn(1, 1, 28, 28).to(next(mnist_network.parameters()).device)

# Use profiler to record FLOPS. Only CPU-side activity is recorded; with_flops
# adds the "Total KFLOPs" column. no_grad() keeps autograd bookkeeping out of
# what is purely an inference measurement.
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_flops=True) as prof:
        with record_function("model_inference"):
            mnist_network(dummy_input)

# Print the FLOPS
print(prof.key_averages().table(sort_by="flops", row_limit=-1))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::addmm        54.48%       1.010ms        64.46%       1.195ms     398.333us             3       236.032  
                                        model_inference        19.69%     365.000us       100.00%       1.854ms       1.854ms             1            --  
                                             aten::view         0.54%      10.000us         0.54%      10.000us      10.000us             1            --  
                                           aten::linear         

In [69]:
#3
# Train MHIST using MLP to compare performance across architectures.
# NOTE: set im_size to 32, 32 before running.

# Settings bundle handed to evaluate_synset; the device entry picks the GPU
# when one is available and falls back to the CPU otherwise.
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MHIST Custom",
}
args = AttrDict()
args.update(args_dict)

# Load the "MHIST Custom" bundle and give each of its first nine positional
# fields (indices 0..8) its own name.
mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"
MHIST_dataset = get_dataset("MHIST Custom", mhist_path)
(
    channel_mhist,
    im_size_mhist,
    num_classes_mhist,
    class_names_mhist,
    mean_mhist,
    std_mhist,
    dst_train_mhist,
    dst_test_mhist,
    testloader_mhist,
) = (MHIST_dataset[field] for field in range(9))

mhist_network = get_network("MLP", channel_mhist, num_classes_mhist, im_size_mhist)

# Train and evaluate; the returned (network, train acc, test acc) tuple is
# displayed because the call is the last expression of the cell.
evaluate_synset(1, mhist_network, dst_train_mhist, dst_train_mhist.targets, testloader_mhist, args)

(977, 3, 32, 32)
[2023-12-11 02:51:24] Evaluate_01: epoch = 0020 train time = 3 s train loss = 0.693711 train acc = 0.5000, test acc = 0.3685


(MLP(
   (fc_1): Linear(in_features=3072, out_features=128, bias=True)
   (fc_2): Linear(in_features=128, out_features=128, bias=True)
   (fc_3): Linear(in_features=128, out_features=2, bias=True)
 ),
 0.5,
 0.368474923234391)

In [24]:
#MLP MHIST FLOPS
# Profiles a single forward pass of the MLP on an MHIST-shaped input and
# prints the per-operator table sorted by estimated FLOPs.

# Same settings block as the MLP MHIST training cell.
args = AttrDict()
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MHIST Custom"
}

args.update(args_dict)

# NOTE: set im_size to 32, 32 before running.

# Load the dataset bundle; its first nine positional fields are named below.
mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"

MHIST_dataset = get_dataset("MHIST Custom", mhist_path)
(
    channel_mhist,
    im_size_mhist,
    num_classes_mhist,
    class_names_mhist,
    mean_mhist,
    std_mhist,
    dst_train_mhist,
    dst_test_mhist,
    testloader_mhist,
) = (MHIST_dataset[field] for field in range(9))

mhist_network = get_network("MLP", channel_mhist, num_classes_mhist, im_size_mhist)

# Create the dummy input (one 3x32x32 image) on whatever device the network's
# weights live on. The previous hard-coded 'cuda' raised on a CPU-only runtime
# even though the "device" entry above falls back to "cpu".
# NOTE(review): the input is fixed at 32x32 regardless of im_size_mhist --
# confirm this matches the resolution used when training this architecture.
dummy_input = torch.randn(1, 3, 32, 32).to(next(mhist_network.parameters()).device)

# Use profiler to record FLOPS. Only CPU-side activity is recorded; with_flops
# adds the "Total KFLOPs" column. no_grad() keeps autograd bookkeeping out of
# what is purely an inference measurement.
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_flops=True) as prof:
        with record_function("model_inference"):
            mhist_network(dummy_input)

# Print the FLOPS
print(prof.key_averages().table(sort_by="flops", row_limit=-1))

(977, 3, 128, 128)
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::addmm        54.43%     921.000us        61.11%       1.034ms     344.667us             3       819.712  
                                        model_inference        23.17%     392.000us       100.00%       1.692ms       1.692ms             1            --  
                                             aten::view         0.89%      15.000us         0.89%      15.000us      15.000us             1            --  
                                           at

In [72]:
#3 Cross architecture performance
# Train MNIST using ConvNetD4 for comparison.

# Settings bundle handed to evaluate_synset; the device entry picks the GPU
# when one is available and falls back to the CPU otherwise.
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MNIST Custom",
}
args = AttrDict()
args.update(args_dict)

# Load the "MNIST Custom" bundle and give each of its first nine positional
# fields (indices 0..8) its own name.
mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"
MNIST_dataset = get_dataset("MNIST Custom", mnist_path)
(
    channel_mnist,
    im_size_mnist,
    num_classes_mnist,
    class_names_mnist,
    mean_mnist,
    std_mnist,
    dst_train_mnist,
    dst_test_mnist,
    testloader_mnist,
) = (MNIST_dataset[field] for field in range(9))

mnist_network = get_network("ConvNetD4", channel_mnist, num_classes_mnist, im_size_mnist)

# Train and evaluate; the returned (network, train acc, test acc) tuple is
# displayed because the call is the last expression of the cell.
evaluate_synset(1, mnist_network, dst_train_mnist, dst_train_mnist.targets, testloader_mnist, args)

[2023-12-11 02:53:11] Evaluate_01: epoch = 0020 train time = 3 s train loss = 0.134468 train acc = 1.0000, test acc = 0.8930


(ConvNet(
   (features): Sequential(
     (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3))
     (1): GroupNorm(128, 128, eps=1e-05, affine=True)
     (2): ReLU(inplace=True)
     (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (5): GroupNorm(128, 128, eps=1e-05, affine=True)
     (6): ReLU(inplace=True)
     (7): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (9): GroupNorm(128, 128, eps=1e-05, affine=True)
     (10): ReLU(inplace=True)
     (11): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (13): GroupNorm(128, 128, eps=1e-05, affine=True)
     (14): ReLU(inplace=True)
     (15): AvgPool2d(kernel_size=2, stride=2, padding=0)
   )
   (classifier): Linear(in_features=512, out_features=10, bias=True)
 ),
 1.0,
 0.893)

In [26]:
#3 Cross architecture performance - ConvNetD4 MNIST FLOPS
# Profiles a single forward pass of ConvNetD4 on an MNIST-shaped input and
# prints the per-operator table sorted by estimated FLOPs.

# Same settings block as the ConvNetD4 MNIST training cell.
args = AttrDict()
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MNIST Custom"
}

args.update(args_dict)

# Load the dataset bundle; its first nine positional fields are named below.
mnist_path = "/content/drive/MyDrive/ECE1512/Project_B/MNIST Dataset/"

MNIST_dataset = get_dataset("MNIST Custom", mnist_path)
(
    channel_mnist,
    im_size_mnist,
    num_classes_mnist,
    class_names_mnist,
    mean_mnist,
    std_mnist,
    dst_train_mnist,
    dst_test_mnist,
    testloader_mnist,
) = (MNIST_dataset[field] for field in range(9))

mnist_network = get_network("ConvNetD4", channel_mnist, num_classes_mnist, im_size_mnist)

# Create the dummy input (one 1x28x28 image) on whatever device the network's
# weights live on. The previous hard-coded 'cuda' raised on a CPU-only runtime
# even though the "device" entry above falls back to "cpu".
dummy_input = torch.randn(1, 1, 28, 28).to(next(mnist_network.parameters()).device)

# Use profiler to record FLOPS. Only CPU-side activity is recorded; with_flops
# adds the "Total KFLOPs" column. no_grad() keeps autograd bookkeeping out of
# what is purely an inference measurement.
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_flops=True) as prof:
        with record_function("model_inference"):
            mnist_network(dummy_input)

# Print the FLOPS
print(prof.key_averages().table(sort_by="flops", row_limit=-1))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         0.32%      27.000us        79.64%       6.736ms       1.684ms             4    101449.728  
                                            aten::addmm         2.00%     169.000us         2.42%     205.000us     205.000us             1        10.240  
                                        model_inference         8.38%     709.000us       100.00%       8.458ms       8.458ms             1            --  
                                      aten::convolution         

In [71]:
#3D

# Baseline for comparison: train a ConvNetD4 directly on the MHIST training set.
# NOTE(review): an earlier comment in this cell mentioned ResNet18 and setting
# im_size to 32, 32 before running; the network built below is ConvNetD4 -
# confirm whether that instruction is still relevant.

args_dict = dict(
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    dis_metric="cos",
    lr_net=0.01,
    epoch_eval_train=20,
    batch_train=32,
    training_baseline=True,
    dataset="MHIST Custom",
)
args = AttrDict()
args.update(args_dict)

mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"
MHIST_dataset = get_dataset("MHIST Custom", mhist_path)

# Name the first nine entries returned by get_dataset, in positional order.
(
    channel_mhist,
    im_size_mhist,
    num_classes_mhist,
    class_names_mhist,
    mean_mhist,
    std_mhist,
    dst_train_mhist,
    dst_test_mhist,
    testloader_mhist,
) = MHIST_dataset[:9]

mhist_network = get_network(
    "ConvNetD4",
    channel_mhist,
    num_classes_mhist,
    im_size_mhist,
)

# Left as the cell's final bare expression so the returned value is displayed.
evaluate_synset(
    1,
    mhist_network,
    dst_train_mhist,
    dst_train_mhist.targets,
    testloader_mhist,
    args,
)

(977, 3, 128, 128)
[2023-12-11 02:52:34] Evaluate_01: epoch = 0020 train time = 5 s train loss = 0.705282 train acc = 0.5000, test acc = 0.4964


(ConvNet(
   (features): Sequential(
     (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (1): GroupNorm(128, 128, eps=1e-05, affine=True)
     (2): ReLU(inplace=True)
     (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (5): GroupNorm(128, 128, eps=1e-05, affine=True)
     (6): ReLU(inplace=True)
     (7): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (9): GroupNorm(128, 128, eps=1e-05, affine=True)
     (10): ReLU(inplace=True)
     (11): AvgPool2d(kernel_size=2, stride=2, padding=0)
     (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (13): GroupNorm(128, 128, eps=1e-05, affine=True)
     (14): ReLU(inplace=True)
     (15): AvgPool2d(kernel_size=2, stride=2, padding=0)
   )
   (classifier): Linear(in_features=8192, out_features=2, bias=True)
 ),
 0.5,
 0.4964

In [27]:
args = AttrDict()
args_dict = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dis_metric": "cos",
    "lr_net": 0.01,
    "epoch_eval_train": 20,
    "batch_train": 32,
    "training_baseline": True,
    "dataset": "MHIST Custom"
}

args.update(args_dict)

# Profile the forward-pass FLOPs of ConvNetD4 on MHIST.
# Nothing is trained in this cell; the network is only run once on a dummy input.

mhist_path = "/content/drive/MyDrive/ECE1512/Project_B/MHIST Dataset/"

# get_dataset returns a sequence; entries 0-8 are unpacked into named variables below.
MHIST_dataset = get_dataset("MHIST Custom", mhist_path)
channel_mhist = MHIST_dataset[0]
im_size_mhist = MHIST_dataset[1]
num_classes_mhist = MHIST_dataset[2]
class_names_mhist = MHIST_dataset[3]
mean_mhist = MHIST_dataset[4]
std_mhist = MHIST_dataset[5]
dst_train_mhist = MHIST_dataset[6]
dst_test_mhist = MHIST_dataset[7]
testloader_mhist = MHIST_dataset[8]

mhist_network = get_network("ConvNetD4", channel_mhist, num_classes_mhist, im_size_mhist)

# Create the dummy input on the same device as the network instead of a
# hard-coded 'cuda', so the cell also runs on a CPU-only runtime.
# NOTE(review): get_network is defined elsewhere; presumably it places the model
# on the GPU when one is available - reading the device from the parameters
# works either way.
mhist_device = next(mhist_network.parameters()).device

# Derive the input shape from the same metadata that was used to build the
# network, so changing im_size in get_dataset cannot leave the dummy input
# out of sync with the classifier's expected feature size.
# Assumes im_size_mhist is indexable as (height, width) - TODO confirm.
dummy_input = torch.randn(1, channel_mhist, im_size_mhist[0], im_size_mhist[1]).to(mhist_device)

# Use profiler to record FLOPS. Only CPU activity is recorded; with_flops adds
# the estimated FLOPs column to the table. Gradients are not needed for a pure
# forward pass, so autograd is switched off.
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_flops=True) as prof:
        with record_function("model_inference"):
            mhist_network(dummy_input)

# Print the FLOPS
print(prof.key_averages().table(sort_by="flops", row_limit=-1))

(977, 3, 128, 128)
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         0.53%      21.000us        53.72%       2.138ms     534.500us             4   1698693.120  
                                            aten::addmm         5.90%     235.000us         6.96%     277.000us     277.000us             1        32.768  
                                        model_inference        17.91%     713.000us       100.00%       3.980ms       3.980ms             1            --  
                                      aten::c

In [None]:
# APPLICATION

# See separate application notebook