## Project Initialization

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.auto import tqdm
from torchvision.utils import save_image
from thop import profile
from utils import get_network, get_dataset
import main_AttentionMatching
import sys

# Relative directory into which the figures produced below are saved.
figures_dir = '../report/figures/'

# Report the PyTorch build, then select the compute device: the first CUDA
# GPU when one is available, otherwise the CPU.
print(f"PyTorch Version: {torch.__version__}")
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("GPU is not available")
else:
    device = torch.device("cuda")
    properties = torch.cuda.get_device_properties(0)
    device_name = torch.cuda.get_device_name(0)
    compute_capability = f"{properties.major}.{properties.minor}"
    # total_memory is scaled down by 1024**3 before being printed.
    total_memory = properties.total_memory / (1024 ** 3)

    print(f"CUDA Device: {device_name}")
    print(f"CUDA Compute Capability: {compute_capability}")
    print(f"Total Memory: {total_memory:.2f} GB")
    
    
def epoch_S(mode, dataloader, net, optimizer, criterion, device, progress_bar):
    """Run one pass over ``dataloader`` and return the mean loss and accuracy.

    Args:
        mode: ``'train'`` to update ``net`` with ``optimizer`` after every
            batch; any other value runs an evaluation pass (the network is
            put in eval mode and no gradients are computed).
        dataloader: Iterable of batches whose first item is the input tensor
            and whose second item is the tensor of class labels.
        net: Model to train or evaluate; it is moved to ``device``.
        optimizer: Optimizer stepped once per batch in train mode; not used
            otherwise.
        criterion: Loss module called as ``criterion(output, labels)``; it is
            moved to ``device``.
        device: Device on which the pass is executed.
        progress_bar: Object exposing ``set_postfix(loss=...)`` and
            ``update(n)`` (for example a tqdm bar). It is advanced by one per
            batch and its postfix shows the running mean loss.

    Returns:
        Tuple ``(loss_avg, acc_avg)``: the per-batch loss averaged with each
        batch weighted by its size, and the fraction of samples whose arg-max
        prediction equals the label.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    is_train = mode == 'train'
    loss_avg, acc_avg, num_exp = 0, 0, 0
    net = net.to(device)
    criterion = criterion.to(device)

    if is_train:
        net.train()
    else:
        net.eval()

    # Gradients are only needed when the weights are updated; disabling them
    # during evaluation avoids building an autograd graph for every batch.
    with torch.set_grad_enabled(is_train):
        for datum in dataloader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            n_b = lab.shape[0]

            output = net(img)
            loss = criterion(output, lab)
            # Number of correct arg-max predictions in this batch.
            acc = (output.argmax(dim=-1) == lab).sum().item()

            loss_avg += loss.item() * n_b
            acc_avg += acc
            num_exp += n_b

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Show the running mean loss over the samples seen so far (the
            # max() only guards against a degenerate empty first batch).
            progress_bar.set_postfix(loss=loss_avg / max(num_exp, 1))
            progress_bar.update(1)

    if num_exp == 0:
        raise ValueError("dataloader yielded no samples")

    loss_avg /= num_exp
    acc_avg /= num_exp

    return loss_avg, acc_avg

# Profile a model's operation count on a single sample.
def get_flops(model, dataloader, device):
    """Profile ``model`` with ``thop.profile`` on one sample from ``dataloader``.

    Only the first batch is read; its first sample is given a leading batch
    dimension (the model is called on a batch, not on a single image) and
    moved to ``device`` before profiling.

    Args:
        model: Network to profile.
            NOTE(review): presumably it must already live on ``device``;
            it is not moved here.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device to which the sample is moved.

    Returns:
        The value returned by ``thop.profile`` unchanged. Callers in this
        file unpack it as ``(flops, _)``.
        NOTE(review): thop's documentation names the first element MACs and
        the second the parameter count; confirm before reporting it as FLOPs.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    try:
        inputs, _ = next(iter(dataloader))
    except StopIteration:
        raise ValueError("dataloader is empty; cannot profile the model") from None

    # Keep a single image and restore the batch dimension the model expects.
    single_image = inputs[0].unsqueeze(0).to(device)
    flops = profile(model, inputs=(single_image, ), verbose=False)
    return flops


## MNIST

In [None]:
# Fetch the MNIST metadata, datasets and dataloaders through utils.get_dataset.
# NOTE(review): the test dataloader is unpacked before the train dataloader;
# confirm this matches the order in which get_dataset returns them.
(
    channel, im_size, num_classes, class_names, mean, std,
    train_MNIST_dataset, test_MNIST_dataset,
    test_MNIST_dataloader, train_MNIST_dataloader,
) = get_dataset("MNIST", "../datasets")

# Preview one training image for each of the labels 0-9 on a 2 x 5 grid and
# save the figure.
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10, 5))
for i, ax in enumerate(axes.flat):
    # First training image whose target equals the current grid index.
    first_of_class = train_MNIST_dataset.data[train_MNIST_dataset.targets == i][0]
    ax.imshow(first_of_class, cmap="gray")
    ax.set_title(str(i))
    ax.axis("off")
plt.tight_layout()
plt.savefig(figures_dir + "MNIST_dataset.png", dpi=300)

### ConvNet3


In [None]:
# Build the 'ConvNetD3' architecture via utils.get_network, using the channel
# count, class count and image size unpacked from get_dataset above, and print
# its layer structure.
ConvNet3 = get_network('ConvNetD3', channel, num_classes, im_size)
print(ConvNet3)

In [None]:
# Train ConvNet3 on MNIST with SGD, running a test pass after every epoch and
# recording the per-epoch train/test accuracy.
n_epochs = 5
lr = 0.01
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(ConvNet3.parameters(), lr=lr)
train_acc, test_acc = [], []
trainLoader = train_MNIST_dataloader
testLoader = test_MNIST_dataloader
for ep in range(n_epochs):
    # One bar per epoch covering both the train and the test pass. epoch_S
    # advances the bar itself, so the bar only needs a total (no iterable),
    # and the with-block closes it once the epoch is finished.
    with tqdm(
        total=len(trainLoader) + len(testLoader),
        desc=f"Epoch {ep+1}",
    ) as progress_bar:
        train_loss_avg, train_acc_avg = epoch_S(
            'train', trainLoader, ConvNet3, optimizer, criterion, device, progress_bar
        )
        test_loss_avg, test_acc_avg = epoch_S(
            'test', testLoader, ConvNet3, optimizer, criterion, device, progress_bar
        )

    train_acc.append(train_acc_avg)
    test_acc.append(test_acc_avg)

# Profile the trained network on a single test sample.
flops, _ = get_flops(ConvNet3, testLoader, device)
print("FLOPS: {:,}".format(flops))

In [None]:
# Draw the per-epoch train and test accuracy curves on a shared set of axes.
for curve, curve_name in ((train_acc, "Train"), (test_acc, "Test")):
    plt.plot(curve, label=curve_name)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

### Synthetic dataset generation

In [None]:


# Options for main_AttentionMatching, written as flag -> value pairs and then
# flattened, in this order, into sys.argv (presumably parsed by main()).
attention_matching_args = {
    '--init': 'real',
    '--model': 'ConvNetD3',
    '--dataset': 'MNIST',
    '--output_file': 'NMIST_real_res.pt',       # Output file
    '--ipc': '10',                              # Images/class
    '--lr_img': '0.1',                          # eta_s
    '--lr_net': '0.01',                         # eta_theta
    '--num_eval': '50',                         # zeta_theta
    '--epoch_eval_train': '1',                  # zeta_s
    '--Iteration': '10',                        # T
}
sys.argv = ['main_AttentionMatching.py']
for flag, value in attention_matching_args.items():
    sys.argv.extend((flag, value))

# The return value is handed to torch.load when the results are plotted.
NMIST_real_res = main_AttentionMatching.main()

In [None]:
# Load the saved result of the real-initialised run and take its first
# 'data' entry as the synthetic images.
results = torch.load(NMIST_real_res)
# clip the images to [0, 1]
syn_imgs = results['data'][0][0].clamp(0, 1)

# 10 images per class, shown on a 10 x 10 grid
fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    # Move the first axis last and drop size-1 axes before display
    # (presumably (C, H, W) grayscale images; confirm against main()).
    img = syn_imgs[i].permute(1, 2, 0).squeeze()
    ax.imshow(img, cmap="gray")
    ax.axis("off")

In [None]:
# Options for main_AttentionMatching, written as flag -> value pairs and then
# flattened, in this order, into sys.argv (presumably parsed by main()).
attention_matching_args = {
    '--init': 'noise',
    '--model': 'ConvNetD3',
    '--dataset': 'MNIST',
    '--output_file': 'NMIST_noise_res.pt',      # Output file
    '--ipc': '10',                              # Images/class
    '--lr_img': '0.1',                          # eta_s
    '--lr_net': '0.01',                         # eta_theta
    '--num_eval': '50',                         # zeta_theta
    '--epoch_eval_train': '1',                  # zeta_s
    '--Iteration': '10',                        # T
}
sys.argv = ['main_AttentionMatching.py']
for flag, value in attention_matching_args.items():
    sys.argv.extend((flag, value))

# The return value is handed to torch.load when the results are plotted.
NMIST_noise_res = main_AttentionMatching.main()

In [None]:
# Load the saved result of the noise-initialised run and take its first
# 'data' entry as the synthetic images.
results = torch.load(NMIST_noise_res)
# clip the images to [0, 1]
syn_imgs = results['data'][0][0].clamp(0, 1)

# 10 images per class, shown on a 10 x 10 grid
fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    # Move the first axis last and drop size-1 axes before display
    # (presumably (C, H, W) grayscale images; confirm against main()).
    img = syn_imgs[i].permute(1, 2, 0).squeeze()
    ax.imshow(img, cmap="gray")
    ax.axis("off")

## MHIST

In [None]:
# Fetch the MHIST metadata, datasets and dataloaders through utils.get_dataset.
# NOTE(review): the test dataloader is unpacked before the train dataloader;
# confirm this matches the order in which get_dataset returns them.
(
    channel, im_size, num_classes, class_names, mean, std,
    train_MHIST_dataset, test_MHIST_dataset,
    test_MHIST_dataloader, train_MHIST_dataloader,
) = get_dataset("MHIST", "../datasets")

# Positions in the training set of the two samples to display.
indices = [100, 1560]

# Plot the two selected images side by side, titled by class, and save.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6, 3))
for i, ax in enumerate(axes.flat):
    image, label = train_MHIST_dataset[indices[i]]
    # Channels-first -> channels-last for plotting
    # (e.g. [3, 224, 224] -> [224, 224, 3]).
    image = image.permute(1, 2, 0)
    ax.imshow(image)
    # Label 0 is titled "HP"; every other label is titled "SSA".
    ax.set_title("HP" if label == 0 else "SSA")
    ax.axis("off")
plt.tight_layout()
plt.savefig(figures_dir + "MHIST_dataset.png", dpi=300)

### ConvNet7

In [None]:
# Build the 'ConvNetD7' architecture via utils.get_network, using the channel
# count, class count and image size unpacked from get_dataset above, and print
# its layer structure.
ConvNet7 = get_network('ConvNetD7', channel, num_classes, im_size)
print(ConvNet7)

In [None]:
# Train ConvNet7 on MHIST with SGD, running a test pass after every epoch and
# recording the per-epoch train/test accuracy.
n_epochs = 5
lr = 0.01
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(ConvNet7.parameters(), lr=lr)
train_acc, test_acc = [], []
trainLoader = train_MHIST_dataloader
testLoader = test_MHIST_dataloader
for ep in range(n_epochs):
    # One bar per epoch covering both the train and the test pass. epoch_S
    # advances the bar itself, so the bar only needs a total (no iterable),
    # and the with-block closes it once the epoch is finished.
    with tqdm(
        total=len(trainLoader) + len(testLoader),
        desc=f"Epoch {ep+1}",
    ) as progress_bar:
        train_loss_avg, train_acc_avg = epoch_S(
            'train', trainLoader, ConvNet7, optimizer, criterion, device, progress_bar
        )
        test_loss_avg, test_acc_avg = epoch_S(
            'test', testLoader, ConvNet7, optimizer, criterion, device, progress_bar
        )

    train_acc.append(train_acc_avg)
    test_acc.append(test_acc_avg)

# Profile the trained network on a single test sample.
flops, _ = get_flops(ConvNet7, testLoader, device)
print("FLOPS: {:,}".format(flops))

### Synthetic dataset generation

In [None]:
# Options for main_AttentionMatching, written as flag -> value pairs and then
# flattened, in this order, into sys.argv (presumably parsed by main()).
attention_matching_args = {
    '--init': 'real',
    '--model': 'ConvNetD7',
    '--dataset': 'MHIST',
    '--output_file': 'MHIST_real_res.pt',       # Output file
    '--ipc': '50',                              # Images/class
    '--lr_img': '0.1',                          # eta_s
    '--lr_net': '0.01',                         # eta_theta
    '--num_eval': '50',                         # zeta_theta
    '--epoch_eval_train': '1',                  # zeta_s
    '--Iteration': '10',                        # T
}
sys.argv = ['main_AttentionMatching.py']
for flag, value in attention_matching_args.items():
    sys.argv.extend((flag, value))

MHIST_real_res = main_AttentionMatching.main()