# **DCGAN Unlearning Metrics**

Requires:
- Full training dataset (if evaluate original model).
- Desired dataset (obtained by removing the unwanted samples/class from the training data).
- Original model.
- Unlearned models (all 4 losses `inv`, `neg`, `exp`, and `ens`).
- Retrained model using desired dataset.

## **Import**

In [7]:
import os
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.functional as F


import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from scipy.linalg import sqrtm
from sklearn.manifold import TSNE


import networks as nws
import utils

In [8]:
# Thiết bị sử dụng
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print(device)

cuda:0


## **Load DCGAN models**

In [None]:
# Pretrained model location
original_train_epochs = 300
original_save_dir = "checkpoints_original"

# Choose an undesired class
ul_class = 8  # `1` or `8`

# Unlearned model location
do_mask = True
ckpt_dir_suffix = "_mask" if do_mask else ""
dset_name = f"MNIST_ul{ul_class}_data"
ckpt_dir = f"checkpoints_unlearn{ckpt_dir_suffix}"

# Hyperparams for all the repulsion loss
rlweight_inv, rlweight_neg, rlweight_exp, rlweight_ens = 5, 1, 20, 20
epochs_inv, epochs_neg, epochs_exp, epochs_ens = 5, 300, 300, 300
alpha = 0.005

In [None]:
nz = 128
version = 3

G_original = nws.Generator(version=version).to(device)
G_retrained_desired = nws.Generator(version=version).to(device)
G_unlearned_l2_inv = nws.Generator(version=version).to(device)
G_unlearned_l2_neg = nws.Generator(version=version).to(device)
G_unlearned_l2_exp = nws.Generator(version=version).to(device)
G_unlearned_l2_ens = nws.Generator(version=version).to(device)

G_original.load_state_dict(torch.load(f'{original_save_dir}/gen_MNIST_full_v{version}_nz={nz}_epochs={original_train_epochs}.pth'))
G_retrained_desired.load_state_dict(torch.load(f'{original_save_dir}/gen_MNIST_without_{ul_class}s_v{version}_nz={nz}_epochs={original_train_epochs}.pth'))  # change

dir_l2inv = f"rlweight={rlweight_inv}_alpha=None_epochs={epochs_inv}_dataset={dset_name}"
dir_l2neg = f"rlweight={rlweight_neg}_alpha=None_epochs={epochs_neg}_dataset={dset_name}"
dir_l2exp = f"rlweight={rlweight_exp}_alpha={alpha}_epochs={epochs_exp}_dataset={dset_name}"
dir_l2ens = f"rlweight={rlweight_ens}_alpha={alpha}_epochs={epochs_ens}_dataset={dset_name}"

G_unlearned_l2_inv.load_state_dict(torch.load(f'{ckpt_dir}/l2_inv/{dir_l2inv}/gen_unlearned_v3_nz=128_l2_inv.pth'))
G_unlearned_l2_neg.load_state_dict(torch.load(f'{ckpt_dir}/l2_neg/{dir_l2neg}/gen_unlearned_v3_nz=128_l2_neg.pth'))
G_unlearned_l2_exp.load_state_dict(torch.load(f'{ckpt_dir}/l2_exp/{dir_l2exp}/gen_unlearned_v3_nz=128_l2_exp.pth'))
G_unlearned_l2_ens.load_state_dict(torch.load(f'{ckpt_dir}/l2_ens/{dir_l2ens}/gen_unlearned_v3_nz=128_l2_ens.pth'))

  G_original.load_state_dict(torch.load(f'{original_save_dir}/gen_MNIST_full_v{version}_nz={nz}_epochs={original_train_epochs}.pth'))
  G_retrained_desired.load_state_dict(torch.load(f'{original_save_dir}/gen_MNIST_without_{ul_class}s_v{version}_nz={nz}_epochs={original_train_epochs}.pth'))  # change
  G_unlearned_l2_inv.load_state_dict(torch.load(f'{ckpt_dir}/l2_inv/{dir_l2inv}/gen_unlearned_v3_nz=128_l2_inv.pth'))
  G_unlearned_l2_neg.load_state_dict(torch.load(f'{ckpt_dir}/l2_neg/{dir_l2neg}/gen_unlearned_v3_nz=128_l2_neg.pth'))
  G_unlearned_l2_exp.load_state_dict(torch.load(f'{ckpt_dir}/l2_exp/{dir_l2exp}/gen_unlearned_v3_nz=128_l2_exp.pth'))
  G_unlearned_l2_ens.load_state_dict(torch.load(f'{ckpt_dir}/l2_ens/{dir_l2ens}/gen_unlearned_v3_nz=128_l2_ens.pth'))


<All keys matched successfully>

## **Load classifier**

In [10]:
classifier = nws.CNN_Classifier().to(device)
classifier.load_state_dict(torch.load(f'classifier/mnist_classifier.pth'))

  classifier.load_state_dict(torch.load(f'classifier/mnist_classifier.pth'))


<All keys matched successfully>

## **Original GAN metrics**

In [None]:
batch_size = 64
dataset = dset.MNIST(
    root="MNIST_full",
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]),
    download=True
)

# Create the dataloader
original_dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True
)

# print(len(original_dataloader))
# print(len(original_dataloader.dataset))

# features_original_list = []
# count = 0
# for i, (images, labels) in enumerate(original_dataloader):
#     if count >= 12800: break
#     features_original_list.append(utils.get_features(classifier, images))
#     count += images.size(0)

# features_original = np.concatenate(features_original_list, axis=0)
# print(features_original.shape)

In [None]:
n_iters = 1
fids = utils.compute_FIDs(G_original, classifier, nz, device, dset_loader=original_dataloader, n_iters=n_iters)
torch.cuda.empty_cache()
np.savetxt(f"FID_save/FID_niters={n_iters}_originalGAN_v{version}_nz={nz}_epochs={original_train_epochs}.txt", fids)
print(f">> FID ({n_iters} iters): {fids.mean()} (err: {fids.std()})")

In [13]:
# tsne = TSNE(n_components=2, perplexity=30, random_state=42)

# features_original_data_list = []
# count = 0
# for i, (images, labels) in enumerate(original_dataloader):
#     if count >= 12800: break
#     features_original_data_list.append(get_features(classifier, images))
#     count += images.size(0)
# features_original = np.concatenate(features_original_data_list, axis=0)

# generated = generate_folder(G_original, examples=12800)
# features_generated = get_features(classifier, generated)
# X_original_GAN = np.concatenate([features_original, features_generated])
# X_tsne_original_GAN = tsne.fit_transform(X_original_GAN)

In [14]:
# ft_tsne_original = X_tsne_original_GAN[0:12800+1]
# ft_tsne_original_generated = X_tsne_original_GAN[12800+1:]

# fig, ax = plt.subplots(1, 2)

# ax[0].scatter(ft_tsne_original[:, 0], ft_tsne_original[:, 1], alpha=0.1)
# ax[0].set_title("Original data (full class)")
# ax[1].scatter(ft_tsne_original_generated[:, 0], ft_tsne_original_generated[:, 1], alpha=0.1)
# ax[1].set_title("Generated data (original GAN)")

# fig.set_size_inches((12, 6))
# plt.show()

## **Unlean metrics**

In [21]:
batch_size = 64
dataset = dset.ImageFolder(
    root=f"MNIST_without_{ul_class}s",
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Resize((28, 28)), 
        transforms.Grayscale()
    ])
)

# Create the dataloader
desired_dset_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                        shuffle=True)
print(len(desired_dset_dataloader))

847


### **Compute PUL**

In [None]:
n_iters = 1   # ---> Set to >1 if measure multiple times
puls_l2inv = utils.compute_PULs(G_original, G_unlearned_l2_inv, classifier, ul_class, nz, device, n_iters=n_iters)
puls_l2neg = utils.compute_PULs(G_original, G_unlearned_l2_neg, classifier, ul_class, nz, device, n_iters=n_iters)
puls_l2exp = utils.compute_PULs(G_original, G_unlearned_l2_exp, classifier, ul_class, nz, device, n_iters=n_iters)
puls_l2ens = utils.compute_PULs(G_original, G_unlearned_l2_exp, classifier, ul_class, nz, device, n_iters=n_iters)
print(f">> PUL l2_inv ({n_iters} iters): {puls_l2inv.mean()} (err: {puls_l2inv.std()})")
print(f">> PUL l2_neg ({n_iters} iters): {puls_l2neg.mean()} (err: {puls_l2neg.std()})")
print(f">> PUL l2_exp ({n_iters} iters): {puls_l2exp.mean()} (err: {puls_l2exp.std()})")
print(f">> PUL l2_exp ({n_iters} iters): {puls_l2ens.mean()} (err: {puls_l2ens.std()})")

np.savetxt(f"PUL_save/PUL_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2inv_nz={nz}_dset={dset_name}.txt", puls_l2inv)
np.savetxt(f"PUL_save/PUL_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2neg_nz={nz}_dset={dset_name}.txt", puls_l2neg)
np.savetxt(f"PUL_save/PUL_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2exp_nz={nz}_dset={dset_name}.txt", puls_l2exp)
np.savetxt(f"PUL_save/PUL_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2ens_nz={nz}_dset={dset_name}.txt", puls_l2ens)

>> PUL l2_inv (1 iters): 75.1709626512362 (err: 0.0)
>> PUL l2_neg (1 iters): 76.01931330472102 (err: 0.0)
>> PUL l2_exp (1 iters): 77.2823779193206 (err: 0.0)
>> PUL l2_exp (1 iters): 77.9596290234588 (err: 0.0)


### **Compute FID**
- ref: https://machinelearningmastery.com/how-to-implement-the-frechet-inception-distance-fid-from-scratch/

In [None]:
# Retrained
n_iters = 1   # ---> Set to >1 if measure multiple times
fids_retrained = utils.compute_FIDs(G_retrained_desired, classifier, nz, device, dset_loader=desired_dset_dataloader, n_iters=n_iters)
print(f">> FID retrained no {ul_class} ({n_iters} iters): {fids_retrained.mean()} (err: {fids_retrained.std()})")
np.savetxt(f"FID_save/FID_niters={n_iters}_retrainedGAN_v{version}_l2inv_nz={nz}_dset=MNIST_without_{ul_class}s.txt", fids_retrained)

>> FID retrained no 8 (1 iters): 21.090245077570014 (err: 0.0)


In [None]:
n_iters = 1   # ---> Set to >1 if measure multiple times
fids_l2inv = utils.compute_FIDs(G_unlearned_l2_inv, classifier, nz, device, dset_loader=desired_dset_dataloader, n_iters=n_iters)
fids_l2neg = utils.compute_FIDs(G_unlearned_l2_neg, classifier, nz, device, dset_loader=desired_dset_dataloader, n_iters=n_iters)
fids_l2exp = utils.compute_FIDs(G_unlearned_l2_exp, classifier, nz, device, dset_loader=desired_dset_dataloader, n_iters=n_iters)
fids_l2ens = utils.compute_FIDs(G_unlearned_l2_exp, classifier, nz, device, dset_loader=desired_dset_dataloader, n_iters=n_iters)
print(f">> FID l2_inv ({n_iters} iters): {fids_l2inv.mean()} (err: {fids_l2inv.std()})")
print(f">> FID l2_neg ({n_iters} iters): {fids_l2neg.mean()} (err: {fids_l2neg.std()})")
print(f">> FID l2_exp ({n_iters} iters): {fids_l2exp.mean()} (err: {fids_l2exp.std()})")
print(f">> FID l2_exp ({n_iters} iters): {fids_l2ens.mean()} (err: {fids_l2ens.std()})")

np.savetxt(f"FID_save/FID_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2inv_nz={nz}_dset={dset_name}.txt", fids_l2inv)
np.savetxt(f"FID_save/FID_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2neg_nz={nz}_dset={dset_name}.txt", fids_l2neg)
np.savetxt(f"FID_save/FID_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2exp_nz={nz}_dset={dset_name}.txt", fids_l2exp)
np.savetxt(f"FID_save/FID_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2ens_nz={nz}_dset={dset_name}.txt", fids_l2ens)

>> FID l2_inv (1 iters): 166.1316368791977 (err: 0.0)
>> FID l2_neg (1 iters): 48.95280833440727 (err: 0.0)
>> FID l2_exp (1 iters): 51.22682271719442 (err: 0.0)
>> FID l2_exp (1 iters): 48.90649580347077 (err: 0.0)


In [None]:
# Ret-FID
n_iters = 1   # ---> Set to >1 if measure multiple times
retfids_l2inv = utils.compute_FIDs(G_unlearned_l2_inv, classifier, nz, device, model2=G_retrained_desired, n_iters=n_iters)
retfids_l2neg = utils.compute_FIDs(G_unlearned_l2_neg, classifier, nz, device, model2=G_retrained_desired, n_iters=n_iters)
retfids_l2exp = utils.compute_FIDs(G_unlearned_l2_exp, classifier, nz, device, model2=G_retrained_desired, n_iters=n_iters)
retfids_l2ens = utils.compute_FIDs(G_unlearned_l2_exp, classifier, nz, device, model2=G_retrained_desired, n_iters=n_iters)
print(f">> Ret-FID l2_inv ({n_iters} iters): {retfids_l2inv.mean()} (err: {retfids_l2inv.std()})")
print(f">> Ret-FID l2_neg ({n_iters} iters): {retfids_l2neg.mean()} (err: {retfids_l2neg.std()})")
print(f">> Ret-FID l2_exp ({n_iters} iters): {retfids_l2exp.mean()} (err: {retfids_l2exp.std()})")
print(f">> Ret-FID l2_exp ({n_iters} iters): {retfids_l2ens.mean()} (err: {retfids_l2ens.std()})")

np.savetxt(f"FID_save/RetFID_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2inv_nz={nz}_dset={dset_name}.txt", retfids_l2inv)
np.savetxt(f"FID_save/RetFID_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2neg_nz={nz}_dset={dset_name}.txt", retfids_l2neg)
np.savetxt(f"FID_save/RetFID_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2exp_nz={nz}_dset={dset_name}.txt", retfids_l2exp)
np.savetxt(f"FID_save/RetFID_niters={n_iters}_unlearnedGAN{ckpt_dir_suffix}_v{version}_l2ens_nz={nz}_dset={dset_name}.txt", retfids_l2ens)

>> Ret-FID l2_inv (1 iters): 182.6613977219073 (err: 0.0)
>> Ret-FID l2_neg (1 iters): 59.44633157416936 (err: 0.0)
>> Ret-FID l2_exp (1 iters): 67.97716664634336 (err: 0.0)
>> Ret-FID l2_exp (1 iters): 67.17950681242627 (err: 0.0)
