<a href="https://colab.research.google.com/github/VictoriaDraganova/Dataset-Condensation-with-Gradient-Matching/blob/main/CW.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/gdrive so files can be read from and written
# to it. `google.colab` is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Execute this code block to install dependencies when running on colab
try:
    import torch
except:
    from os.path import exists
    from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
    accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

    !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision

try: 
    import torchbearer
except:
    !pip install torchbearer

In [2]:
# automatically reload external modules if they change
%load_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchbearer
import tqdm.notebook as tq
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchbearer import Trial
import numpy as np
import copy
from torch.utils.data import Dataset
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.utils import save_image
import os
import statistics as st
import os.path
from os import path
import math

In [3]:
# Execution device: the first CUDA GPU when one is available, otherwise the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda:0"

# Dataset geometry; placeholder values until createData() overwrites them.
n_classes, channel = 0, 1
im_size = list()

# Dataset objects and cached training tensors. Each name gets its own empty
# list as a placeholder until createData() overwrites it.
trainset, testset, trainset_copy = [], [], []
images_all, labels_all, indices_class = [], [], []

In [4]:
class Synthetic(Dataset):
    """Map-style dataset pairing synthetic images with their class labels.

    Both tensors are detached from the autograd graph on construction, and the
    images are additionally cast to float32.
    """

    def __init__(self, data, targets):
        images = data.detach()
        self.data = images.float()
        self.targets = targets.detach()

    def __len__(self):
        # Number of samples = size of the leading dimension of the image tensor.
        return self.data.size(0)

    def __getitem__(self, index):
        sample = self.data[index]
        label = self.targets[index]
        return sample, label

In [5]:
def sample_batch(data):
  """Draw one shuffled mini-batch of up to 256 samples from ``data``.

  Returns the ``(inputs, targets)`` pair of the first batch the loader yields,
  both moved to the module-level ``device``. If the loader yields nothing the
  function falls through and returns ``None``.
  """
  loader = DataLoader(data, shuffle=True, batch_size=256)
  for inputs, labels in loader:
    # Only the first batch is wanted, so return from inside the loop.
    inputs, labels = inputs.to(device), labels.to(device)
    return inputs, labels

In [6]:
def updateNetwork(optimizer, steps, loss_function, net, syn_data_data, syn_data_target):
  """Train ``net`` on one fixed batch for ``steps`` optimizer updates.

  Every step runs a forward pass on ``syn_data_data``, evaluates
  ``loss_function`` against ``syn_data_target``, back-propagates and calls
  ``optimizer.step()``. Nothing is returned; ``net`` is updated in place.
  """
  for _ in range(steps):
    # Kept inside the loop: with steps == 0 the network's mode is left untouched.
    net.train()
    optimizer.zero_grad()
    batch_loss = loss_function(net(syn_data_data), syn_data_target)
    batch_loss.backward()
    optimizer.step()

In [11]:
#based on author's published code
def distance(grad1, grad2):
  """Layer-wise cosine distance between two sequences of gradient tensors.

  The tensors are paired positionally. 3-D and 4-D tensors are flattened to
  (first dimension, everything else) using the shape of the ``grad1`` tensor;
  1-D tensors are skipped and contribute nothing. For every remaining pair,
  ``1 - cosine similarity`` is computed along the last dimension and the
  results are summed into a scalar tensor on the module-level ``device``.
  """
  total = torch.tensor(0.0).to(device)
  for g_a, g_b in zip(grad1, grad2):
    dims = g_a.shape
    rank = len(dims)
    if rank == 1:
      # 1-D gradients are left out of the distance.
      continue
    if rank == 3 or rank == 4:
      # Collapse all trailing dimensions into one so each row is one output unit.
      trailing = 1
      for extent in dims[1:]:
        trailing *= extent
      g_a = g_a.reshape(dims[0], trailing)
      g_b = g_b.reshape(dims[0], trailing)
    # Any other rank (including 2-D) is compared as-is along its last dimension.
    dot = torch.sum(g_a * g_b, dim=-1)
    norm_product = torch.norm(g_a, dim=-1) * torch.norm(g_b, dim=-1)
    # The small constant keeps the division finite for all-zero rows.
    total += torch.sum(1 - dot / (norm_product + 0.000001))
  return total

In [12]:
#from author's published code
def get_images(c, n): # get random n images from class c
    """Return up to ``n`` randomly chosen images of class ``c``.

    The indices come from the module-level ``indices_class[c]`` and are used to
    index the module-level ``images_all``. If the class has fewer than ``n``
    entries, all of them are returned in random order.
    """
    shuffled = np.random.permutation(indices_class[c])
    chosen = shuffled[:n]
    return images_all[chosen]

In [13]:
#create synthetic data
def train_synthetic(model, dataset, images_per_class,  iterations, network_steps):
  """Learn a small synthetic training set by gradient matching.

  For each of ``iterations`` outer steps a freshly initialised network is
  created. The synthetic images are then updated so that, class by class, the
  gradients they induce in that network are close (under ``distance``) to the
  gradients induced by a batch of real images. Between image updates the
  network itself is trained on the synthetic set.

  Args:
    model: architecture name understood by ``new_network``.
    dataset: dataset name. Not read inside this function; the data comes from
      the module-level state (``images_all``, ``n_classes``, ...).
    images_per_class: number of synthetic images per class. Also the number of
      image-update rounds performed per freshly initialised network.
    iterations: number of freshly initialised networks to match against.
    network_steps: optimizer steps taken on the network between rounds.

  Returns:
    A one-element list holding the synthetic image tensor of shape
    ``(n_classes * images_per_class, channel, im_size[0], im_size[1])``.
    The tensor still has ``requires_grad=True``.
  """
  synthetic_datas = []
  T = images_per_class
  K = iterations
  for i in range(1): #to generate 1 synthetic datasets
    #create synthetic data: the images start as Gaussian noise and are the only thing being learned
    data_syn = torch.randn(size=(n_classes*T, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=device)
    #fixed labels [0]*T + [1]*T + ...: class c owns rows c*T .. (c+1)*T-1
    #(built with tensor ops rather than from a list of numpy arrays, which is a slow path in torch.tensor)
    targets_syn = torch.arange(n_classes, dtype=torch.long, device=device).view(-1, 1).repeat(1, T).view(-1)

    #optimizer for image
    optimizer_img = torch.optim.SGD([data_syn, ], lr=0.1) # optimizer_img for synthetic data; only update synthetic image, labels don't change
    optimizer_img.zero_grad()
    loss_function = nn.CrossEntropyLoss().to(device)

    #training synthetic data
    for k in tq.tqdm(range(K)):
      net = new_network(model).to(device)
      net.train()
      net_parameters = list(net.parameters())
      optimizer_net = torch.optim.SGD(net.parameters(), lr=0.01)  # optimizer_net for network
      optimizer_net.zero_grad()
      loss_avg = 0
      for t in range(T):
        loss = torch.tensor(0.0).to(device)
        for c in range(n_classes):
          #gradient of the loss on a real batch of class c w.r.t. the network parameters
          img_real = get_images(c, 256)
          targets_real = torch.ones((img_real.shape[0],), device=device, dtype=torch.long) * c
          prediction_real = net(img_real)
          loss_real = loss_function(prediction_real, targets_real)
          gw_real = torch.autograd.grad(loss_real, net_parameters)

          #same gradient for the synthetic images of class c; create_graph=True keeps it
          #differentiable so the matching loss can be back-propagated into the images
          data_synth = data_syn[c*T:(c+1)*T].reshape((T, channel, im_size[0], im_size[1]))
          targets_synth = torch.ones((T,), device=device, dtype=torch.long) * c
          prediction_syn = net(data_synth)
          loss_syn = loss_function(prediction_syn, targets_synth)
          gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

          dist = distance(gw_syn, gw_real)
          loss+=dist

        #update the synthetic images with the summed matching loss over all classes
        optimizer_img.zero_grad()
        loss.backward()
        optimizer_img.step()
        loss_avg += loss.item()

        #no network update is needed after the last image update for this network
        if t == T - 1:
          break

        #train the network on a detached view of the images: only the network parameters
        #need gradients here, so autograd does not back-propagate into data_syn for nothing
        updateNetwork(optimizer_net, network_steps, loss_function, net, data_syn.detach(), targets_syn)
        
      loss_avg /= (n_classes*T)
      if k%10 == 0:
            print('iter = %.4f, loss = %.4f' % (k, loss_avg))
            # model_save_name = 'data_syn.pt'
            # path = F"/content/gdrive/MyDrive/{model_save_name}"  #to save synthetic data
            # torch.save(data_syn, path)
    synthetic_datas.append(data_syn)

    print('Synthetic %d created ' % (i))  

  return synthetic_datas


In [14]:
#evaluation of synthetic data produced
def evaluation(model, all_synthetic_data, images_per_class):
  """Train fresh networks on the synthetic data and report test accuracy.

  For every tensor in ``all_synthetic_data``, 20 freshly initialised networks
  of type ``model`` are trained for 300 epochs on the synthetic images and
  evaluated on the module-level ``testset``. The mean and population standard
  deviation of all test accuracies are printed.

  Args:
    model: architecture name understood by ``new_network``.
    all_synthetic_data: iterable of synthetic image tensors, each ordered so
      that class c occupies rows ``c*images_per_class .. (c+1)*images_per_class-1``.
    images_per_class: number of synthetic images per class.

  Returns:
    ``(average_acc, std_acc)`` over all evaluation runs.

  Raises:
    ValueError: if ``all_synthetic_data`` is empty, so there is nothing to average.
  """
  accuracies = []
  #labels [0]*ipc + [1]*ipc + ..., matching the row layout of the synthetic images
  targets_syn = torch.arange(n_classes, dtype=torch.long, device=device).view(-1, 1).repeat(1, images_per_class).view(-1)
  loss_function = nn.CrossEntropyLoss().to(device)
  for data in all_synthetic_data:
    for it in range(20): #number of random models for evaluation
      print(it)
      net = new_network(model).to(device)
      net.train()
      optimizer_train = torch.optim.SGD(net.parameters(), lr=0.01) 
      trial = Trial(net,optimizer=optimizer_train, criterion=loss_function, metrics=['loss', 'accuracy'], verbose=0).to(device)
      syn_data_whole = Synthetic(data, targets_syn)
      train_loader = DataLoader(syn_data_whole, batch_size=256, shuffle=True)
      test_loader = DataLoader(testset, batch_size=256, shuffle=False)
      trial.with_generators(train_loader, test_generator=test_loader)
      trial.run(epochs=300)
      results = trial.evaluate(data_key=torchbearer.TEST_DATA)
      print()
      print(results)
      accuracies.append(results['test_acc'])

  if not accuracies:
    #previously this fell through to a ZeroDivisionError in the average below
    raise ValueError("all_synthetic_data is empty; there is nothing to evaluate")

  average_acc = sum(accuracies)/len(accuracies)
  std_acc = st.pstdev(accuracies)
  print("Model is: ", model)
  print("Standard deviation is : " , std_acc)
  print("Average is : " ,average_acc)
  return average_acc, std_acc


In [15]:
def createData(dataset):

  global im_size
  global trainset
  global testset
  global trainset_copy
  global n_classes
  global channel
  global images_all
  global labels_all
  global indices_class

  if dataset == "MNIST":
    !wget https://artist-cloud.ecs.soton.ac.uk/s/sFkQ7HYOekDoDEG/download
    !unzip download
    !mv mnist MNIST
    from torchvision.datasets import MNIST
    mean = [0.1307]
    std = [0.3015]

    transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=mean, std=std)
    ])
    trainset = MNIST(".", train=True, download=True, transform=transform)
    testset = MNIST(".", train=False, download=True, transform=transform)
    trainset_copy = MNIST(".", train=True, download=True, transform=transform)
    n_classes = 10
    channel = 1
    im_size = [28,28]

  elif dataset == "FashionMNIST":
    from torchvision.datasets import FashionMNIST
    mean = [0.2860]
    std =  [0.3205]
    transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=mean, std=std)
    ])
    trainset = FashionMNIST(".", train=True, download=True, transform=transform)
    testset = FashionMNIST(".", train=False, download=True, transform=transform)
    trainset_copy = FashionMNIST(".", train=True, download=True, transform=transform)
    n_classes = 10
    channel = 1
    im_size = [28,28]

  elif dataset == "SVHN":
    from torchvision.datasets import SVHN
    mean = [0.4377, 0.4438, 0.4728]
    std = [0.1201, 0.1231, 0.1052]
    transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=mean, std=std)
    ])
    trainset = SVHN(".", split='train', transform=transform, download=True)
    testset = SVHN(".", split='test', transform=transform, download=True)
    trainset_copy = SVHN(".", split='test', transform=transform, download=True)
    n_classes = 10
    channel = 3
    im_size = [32,32]

  elif dataset == "CIFAR10":
    from torchvision.datasets import CIFAR10
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=mean, std=std)
    ])
    trainset = CIFAR10(".", train=True, download=True, transform=transform)
    testset = CIFAR10(".", train=False, download=True, transform=transform)
    trainset_copy = CIFAR10(".", train=True, download=True, transform=transform)
    n_classes = 10
    channel = 3
    im_size = [32,32]

  #from author's published code
  indices_class = [[] for c in range(n_classes)]
  images_all = [torch.unsqueeze(trainset[i][0], dim=0) for i in range(len(trainset))]
  labels_all = [trainset[i][1] for i in range(len(trainset))]
  for i, lab in enumerate(labels_all):
      indices_class[lab].append(i)
  images_all = torch.cat(images_all, dim=0).to(device)
  labels_all = torch.tensor(labels_all, dtype=torch.long, device=device)

**Networks**

In [16]:
#to calculate image output size
def calculate(size, kernel, stride, padding):
  """Output length of a conv/pool layer along one spatial dimension.

  Evaluates ``(size + 2*padding - kernel) / stride + 1`` in floating point and
  truncates the result to an ``int``.
  """
  span = size + (2 * padding) - kernel
  return int(span / stride + 1)

In [20]:
#based on https://cs231n.github.io/convolutional-networks/
class CNN(torch.nn.Module):
  """ConvNet with three (3x3 conv -> GroupNorm -> ReLU -> 2x2 avg-pool) blocks and a linear classifier.

  The layer sizes are derived from the module-level ``channel``, ``im_size``
  and ``n_classes`` at construction time. Every convolution keeps the spatial
  size (kernel 3, padding 1) and every pooling layer halves it. ``forward``
  returns unnormalised class scores of shape ``(batch, n_classes)``.
  """
  def __init__(self):
    super(CNN, self).__init__()
    # Track height and width separately so non-square inputs size the classifier correctly.
    out_h, out_w = im_size[0], im_size[1]
    self.conv1 = nn.Conv2d(in_channels=channel, out_channels=128, kernel_size=3, padding=1) # size preserved
    out_h, out_w = calculate(out_h,3,1,1), calculate(out_w,3,1,1)
    self.norm1 = nn.GroupNorm(128, 128) # 128 groups over 128 channels: one group per channel
    self.avg_pooling1 = nn.AvgPool2d(kernel_size=2, stride=2) # size halved
    out_h, out_w = calculate(out_h,2,2,0), calculate(out_w,2,2,0)
    self.conv2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1) # size preserved
    out_h, out_w = calculate(out_h,3,1,1), calculate(out_w,3,1,1)
    self.norm2 = nn.GroupNorm(128, 128)
    self.avg_pooling2 = nn.AvgPool2d(kernel_size=2, stride=2) # size halved
    out_h, out_w = calculate(out_h,2,2,0), calculate(out_w,2,2,0)
    self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1) # size preserved
    out_h, out_w = calculate(out_h,3,1,1), calculate(out_w,3,1,1)
    self.norm3 = nn.GroupNorm(128, 128)
    self.avg_pooling3 = nn.AvgPool2d(kernel_size=2, stride=2) # size halved
    out_h, out_w = calculate(out_h,2,2,0), calculate(out_w,2,2,0)
    # Output width follows n_classes, as in MLP and LeNet, instead of a literal 10.
    self.classifier = nn.Linear(out_h*out_w*128, n_classes)

  def forward(self, x):
    out = self.avg_pooling1(F.relu(self.norm1(self.conv1(x))))
    out = self.avg_pooling2(F.relu(self.norm2(self.conv2(out))))
    out = self.avg_pooling3(F.relu(self.norm3(self.conv3(out))))
    out = out.view(out.size(0), -1) # flatten to (batch, features)
    out = self.classifier(out)
    return out

In [21]:
class MLP(nn.Module):
    """Fully connected classifier with two hidden ReLU layers of 128 units.

    The input width (``im_size[0] * im_size[1] * channel``) and the output
    width (``n_classes``) are read from the module-level state at construction.
    """

    def __init__(self):
        super().__init__()
        flat_features = im_size[0] * im_size[1] * channel
        hidden_units = 128
        self.fc1 = nn.Linear(flat_features, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, n_classes)

    def forward(self, x):
        # Flatten every sample to a vector before the linear layers.
        hidden = x.view(x.size(0), -1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

In [22]:
#based on https://en.wikipedia.org/wiki/LeNet
class LeNet(nn.Module):
    """LeNet-style network: two (5x5 conv -> sigmoid -> 2x2 avg-pool) stages and three linear layers.

    The first linear layer is sized from the module-level ``im_size`` rather
    than a fixed 28, so the network can also be built for 32x32 inputs; for
    28x28 inputs the layer sizes are unchanged. The input channels and output
    classes come from the module-level ``channel`` and ``n_classes``.
    ``forward`` returns unnormalised class scores.
    """
    def __init__(self):
        super(LeNet, self).__init__()
        # Follow the actual input size; a hard-coded 28 mis-sizes fc1 for 32x32 datasets.
        out_h, out_w = im_size[0], im_size[1]
        self.conv1 = nn.Conv2d(channel, 6, kernel_size=5)
        out_h, out_w = calculate(out_h, 5, 1, 0), calculate(out_w, 5, 1, 0)
        self.avg1 = nn.AvgPool2d(kernel_size=2, stride=2)
        out_h, out_w = calculate(out_h, 2, 2, 0), calculate(out_w, 2, 2, 0)
        self.conv2 = nn.Conv2d(6,16,kernel_size=5)
        out_h, out_w = calculate(out_h, 5, 1, 0), calculate(out_w, 5, 1, 0)
        self.avg2 = nn.AvgPool2d(kernel_size=2, stride=2)
        out_h, out_w = calculate(out_h, 2, 2, 0), calculate(out_w, 2, 2, 0)
        self.fc1 = nn.Linear(out_h*out_w*16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, n_classes)
        
    def forward(self, x):
        # torch.sigmoid replaces the deprecated F.sigmoid; the values are the same.
        out = self.conv1(x)
        out = torch.sigmoid(out)
        out = self.avg1(out)
        out = self.conv2(out)
        out = torch.sigmoid(out)
        out = self.avg2(out)
        out = out.view(out.size(0), -1)  # flatten to (batch, features)
        out = torch.sigmoid(self.fc1(out))
        out = torch.sigmoid(self.fc2(out))
        out = self.fc3(out)
        return out

In [23]:
# Inputs are resized to the 227x227 resolution the layer sizes below are computed for.
trans = transforms.Resize((227,227))

#based on https://www.analyticsvidhya.com/blog/2021/03/introduction-to-the-architecture-of-alexnet/ 
class AlexNet(torch.nn.Module):
  """AlexNet-style network operating on inputs resized to 227x227.

  The input channels and output classes come from the module-level ``channel``
  and ``n_classes``. ``forward`` returns unnormalised class scores (logits):
  the notebook trains every model with ``nn.CrossEntropyLoss``, which applies
  log-softmax itself, so a softmax inside the model would be applied twice.
  """
  def __init__(self):
    super(AlexNet, self).__init__()
    outsize = 227
    self.conv1 = nn.Conv2d(in_channels=channel, out_channels=96, kernel_size=11, padding=0, stride=4) # 227 -> 55
    outsize = calculate(outsize,11,4,0)
    self.max_pooling1 = nn.MaxPool2d(kernel_size=3, stride=2) # 55 -> 27
    outsize = calculate(outsize,3,2,0)
    self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2, stride=1) # 27 -> 27
    outsize = calculate(outsize,5,1,2)
    self.max_pooling2 = nn.MaxPool2d(kernel_size=3, stride=2) # 27 -> 13
    outsize = calculate(outsize,3,2,0)
    self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1, stride=1) # 13 -> 13
    outsize = calculate(outsize,3,1,1)
    self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1, stride=1) # 13 -> 13
    outsize = calculate(outsize,3,1,1)
    self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1, stride=1) # 13 -> 13
    outsize = calculate(outsize,3,1,1)
    self.max_pooling3 = nn.MaxPool2d(kernel_size=3, stride=2) # 13 -> 6
    outsize = calculate(outsize,3,2,0)
    self.dropout1 = nn.Dropout(p=0.5)
    self.fc1 = nn.Linear(outsize*outsize*256, 4096)
    self.dropout2 = nn.Dropout(p=0.5)
    self.fc2 = nn.Linear(4096, 4096)
    # Output width follows n_classes, as in MLP and LeNet, instead of a literal 10.
    self.fc3 = nn.Linear(4096, n_classes)

  def forward(self, x):
    x = trans(x)
    out = self.conv1(x)
    out = F.relu(out)
    out = self.max_pooling1(out)
    out = self.conv2(out)
    out = F.relu(out)
    out = self.max_pooling2(out)
    out = self.conv3(out)
    out = F.relu(out)
    out = self.conv4(out)
    out = F.relu(out)
    out = self.conv5(out)
    out = F.relu(out)
    out = self.max_pooling3(out)
    out = self.dropout1(out)
    out = out.view(out.size(0), -1)  # flatten to (batch, features)
    out = self.fc1(out)
    out = F.relu(out)
    out = self.dropout2(out)
    out = self.fc2(out)
    out = F.relu(out)
    # Return logits. The trailing F.softmax was removed because CrossEntropyLoss
    # expects unnormalised scores; the arg-max prediction is unaffected.
    out = self.fc3(out)
    return out

In [24]:
#Author's published implementation of LeNet
class LeNetTheirs(nn.Module):
    """LeNet variant with ReLU activations and max-pooling.

    The first convolution uses padding 2 for single-channel inputs and no
    padding otherwise. The first linear layer expects 16 feature maps of size
    5x5. ``forward`` returns unnormalised class scores of shape
    ``(batch, num_classes)``.
    """

    def __init__(self, channel, num_classes):
        super().__init__()
        first_padding = 2 if channel == 1 else 0
        conv_stack = [
            nn.Conv2d(channel, 6, kernel_size=5, padding=first_padding),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.fc_1 = nn.Linear(16 * 5 * 5, 120)
        self.fc_2 = nn.Linear(120, 84)
        self.fc_3 = nn.Linear(84, num_classes)

    def forward(self, x):
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        hidden = F.relu(self.fc_1(flat))
        hidden = F.relu(self.fc_2(hidden))
        return self.fc_3(hidden)

In [25]:
#Author's published implementation of AlexNet
class AlexNetTheirs(nn.Module):
    """Compact AlexNet variant: five convolutions, three max-pools, one linear classifier.

    The first convolution uses padding 4 for single-channel inputs and padding
    2 otherwise. The classifier expects 192 feature maps of size 4x4.
    ``forward`` returns unnormalised class scores of shape
    ``(batch, num_classes)``.
    """

    def __init__(self, channel, num_classes):
        super().__init__()
        first_padding = 4 if channel == 1 else 2
        conv_stack = [
            nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=first_padding),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.fc = nn.Linear(192 * 4 * 4, num_classes)

    def forward(self, x):
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.fc(flat)

In [26]:
#to get the model specified
def new_network(model):
  """Return a freshly initialised network for the architecture named ``model``.

  Supported names: "CNN", "AlexNet", "AlexNetTheirs", "MLP", "LeNet",
  "LeNetTheirs". The two "...Theirs" variants receive the module-level
  ``channel`` and ``n_classes``; the others read the module-level state
  themselves.

  Raises:
    ValueError: if ``model`` is not a supported name. Falling through used to
      return ``None``, which only failed later with an unrelated
      ``AttributeError`` on ``.to(device)``.
  """
  if model == "CNN":
    return CNN()
  if model == "AlexNet":
    return AlexNet()
  if model == "AlexNetTheirs":
    return AlexNetTheirs(channel,n_classes)
  if model == "MLP":
    return MLP()
  if model == "LeNet":
    return LeNet()
  if model == "LeNetTheirs":
    return LeNetTheirs(channel,n_classes)
  raise ValueError(
    "Unknown model %r; expected one of 'CNN', 'AlexNet', 'AlexNetTheirs', 'MLP', 'LeNet', 'LeNetTheirs'" % (model,))

In [27]:
# Experiment 1 
def experiment1(model, dataset, images_per_class,  iterations, network_steps):
  """Learn synthetic data with ``model`` and evaluate it with the same architecture.

  Loads ``dataset`` via ``createData``, runs ``train_synthetic`` with the given
  settings and passes the result to ``evaluation`` for ``model``.
  """
  createData(dataset)
  synthetic_sets = train_synthetic(model, dataset, images_per_class, iterations, network_steps)
  evaluation(model, synthetic_sets, images_per_class)

In [28]:
# Experiment 2
def experiment2(model, dataset, images_per_class,  iterations, network_steps):
  """Learn synthetic data with ``model`` and evaluate it on several architectures.

  Loads ``dataset`` via ``createData``, runs ``train_synthetic`` with the given
  settings, then calls ``evaluation`` once for each of "CNN", "MLP", "LeNet"
  and "AlexNet", in that order.
  """
  createData(dataset)
  synthetic_sets = train_synthetic(model, dataset, images_per_class, iterations, network_steps)
  # architectures used to evaluate the synthetic data
  for eval_model in ("CNN", "MLP", "LeNet", "AlexNet"):
    evaluation(eval_model, synthetic_sets, images_per_class)

In [None]:
experiment1("CNN", "SVHN", 1, 1000, 1) #ConvNet model, SVHN dataset, 1 image per class, 1000 iterations, 1 network update step

In [None]:
experiment2("AlexNet", "MNIST", 1, 1000, 1) #AlexNet model, MNIST dataset, 1 image per class, 1000 iterations, 1 network update step

**Computing the per-channel mean and std used for normalisation**

In [None]:
#To find the mean and std of the datasets
# MNIST is the active dataset; the other datasets are kept commented out below.
# NOTE: this cell rebinds the notebook-level names trainset, testset and
# trainset_copy (without normalisation) and creates mean / std.
transform = transforms.Compose([
  transforms.ToTensor()
])

from torchvision.datasets import MNIST
trainset = MNIST(".", train=True, download=True, transform=transform)
testset = MNIST(".", train=False, download=True, transform=transform)
trainset_copy = MNIST(".", train=True, download=True, transform=transform)

# from torchvision.datasets import FashionMNIST
# trainset = FashionMNIST(".", train=True, download=True, transform=transform)
# testset = FashionMNIST(".", train=False, download=True, transform=transform)
# trainset_copy = FashionMNIST(".", train=True, download=True, transform=transform)

# from torchvision.datasets import CIFAR10
# trainset = CIFAR10(".", train=True, download=True, transform=transform)
# testset = CIFAR10(".", train=False, download=True, transform=transform)
# trainset_copy = CIFAR10(".", train=True, download=True, transform=transform)

# from torchvision.datasets import SVHN
# trainset = SVHN(".", split='train', transform=transform, download=True)
# testset = SVHN(".", split='test', transform=transform, download=True)
# trainset_copy = SVHN(".", split='test', transform=transform, download=True)


loader = DataLoader(trainset, batch_size=256, num_workers=0, shuffle=False)

# Accumulate, per channel, the sum over all images of each image's own mean and std.
mean = 0.
std = 0.
for images, _ in loader:
    batch_samples = images.size(0) 
    # Flatten the spatial dimensions: (batch, channels, pixels).
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

# Divide by the number of images. Note that `std` is therefore the average of
# the per-image stds, which is not the same quantity as the std taken over all
# pixels of the dataset.
mean /= len(loader.dataset)
std /= len(loader.dataset)

print(mean)
print(std)

MNIST
mean = [0.1307], std = [0.3015]

FashionMNIST
mean = [0.2860], std = [0.3205]

SVHN
mean = [0.4377, 0.4438, 0.4728],
std = [0.1201, 0.1231, 0.1052]

CIFAR10
mean = [0.4914, 0.4822, 0.4465],
std = [0.2023, 0.1994, 0.2010]

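If the runtime crashes, reload the synthetic data that was saved to Google Drive (this requires the commented-out save lines in `train_synthetic` to have been enabled):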

In [None]:
# Reload synthetic images previously saved to Google Drive (Drive must be mounted).
# The variable is named syn_path so it does not overwrite the `path` name
# imported with `from os import path`.
syn_path = "/content/gdrive/MyDrive/data_syn.pt"
# map_location lets a tensor that was saved on the GPU be reloaded on a CPU-only runtime.
syn=torch.load(syn_path, map_location=device)
print(syn.shape)
all_synthetic_data=[]
all_synthetic_data.append(syn)
loss_function = nn.CrossEntropyLoss().to(device)
# evaluation() reads n_classes and testset, so run createData(<dataset>) in this
# kernel before enabling the call below.
#evaluation("CNN",all_synthetic_data, 1)