In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import namedtuple
import os
from time import time, sleep
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision
import torchvision.transforms as transforms

import opacus
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator

from utils import generate_run_id, get_input_args, Args, parse_run_id
from models import Discriminator_FC, Generator_MNIST, Weight_Clipper, G_weights_init, Generator_FC, Encoder_Mini, Decoder_Mini, VAE, Encoder_VAE, Decoder_VAE
from data import load_MNIST
from metrics import get_IS, get_FID
from model_inversion import enc_fp, dec_fp, gen_fp
from evaluate_metrics import last_num_models

# Device Configuration: use the GPU when CUDA is available, otherwise the CPU.
# The models built in later cells are moved to this device with `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Load the pre-trained public models (generator, decoder, encoder) and switch
# them to inference mode.
# map_location=device remaps the stored tensors onto the active device, so a
# checkpoint written on a CUDA machine still loads on a CPU-only one.
pub_G = Generator_FC(hidden_sizes=[256], nz=100).to(device)
pub_G.load_state_dict(torch.load(gen_fp, map_location=device))

pub_Dec = Decoder_Mini(latent_size=100).to(device)
pub_Dec.load_state_dict(torch.load(dec_fp, map_location=device))

pub_Enc = Encoder_Mini(latent_size=100).to(device)
pub_Enc.load_state_dict(torch.load(enc_fp, map_location=device))

# eval() so later cells run the models in inference mode.
for public_model in (pub_G, pub_Dec, pub_Enc):
    public_model.eval()
print("Loaded public models")

Loaded public models


In [5]:
# Inception Scores and FIDs are stored in separate CSV files that share the
# `model_fp` column.
df_1 = pd.read_csv("vae_results.csv")  # Inception Scores
df_2 = pd.read_csv("vae_FID.csv")      # FIDs

# Join on model_fp, keep only the n_g == 48000 runs, reduce to the columns of
# interest, then order by noise level, clipping value and FID.
df = (
    df_1.merge(df_2, on="model_fp", suffixes=("", "_gan"))
    .loc[
        lambda merged: merged["n_g"] == 48000,
        ["noise_multiplier", "c_p", "activation", "lr", "IS", "FID", "model_fp"],
    ]
    .sort_values(by=["noise_multiplier", "c_p", "FID"])
    .reset_index(drop=True)
)

df.head()

Unnamed: 0,noise_multiplier,c_p,activation,lr,IS,FID,model_fp
0,0.01,0.001,Tanh,0.02,1.092349,338.987852,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....
1,0.01,0.001,LeakyReLU,0.01,1.004933,350.900943,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....
2,0.01,0.001,Tanh,0.01,1.064879,354.84327,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....
3,0.01,0.001,LeakyReLU,0.02,1.017361,358.692219,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....
4,0.01,0.005,LeakyReLU,0.01,1.335345,242.313955,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....


In [6]:
# Best run (lowest FID) for every (noise_multiplier, c_p) setting.
# The rows are sorted here instead of relying on the ordering left behind by an
# earlier cell, and the complete winning row is kept: GroupBy.first() returns
# the first non-null value of each column independently, so a missing value
# could silently combine fields from different runs.
df_best = (
    df.sort_values(by=["noise_multiplier", "c_p", "FID"])
    .drop_duplicates(subset=["noise_multiplier", "c_p"], keep="first")
    .reset_index(drop=True)
)
df_best.head()

Unnamed: 0,noise_multiplier,c_p,activation,lr,IS,FID,model_fp
0,0.01,0.001,Tanh,0.02,1.092349,338.987852,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....
1,0.01,0.005,LeakyReLU,0.01,1.335345,242.313955,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....
2,0.01,0.01,LeakyReLU,0.01,1.65778,178.152178,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....
3,0.01,0.05,LeakyReLU,0.01,1.863078,125.267373,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0....
4,0.05,0.001,Tanh,0.02,1.119791,341.618543,runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.05_0....


In [44]:
# Copy every best-run folder into the scratch directory runs_vae_tmp/.
import subprocess

# Make sure the destination exists so each run lands in runs_vae_tmp/<run id>.
os.makedirs("runs_vae_tmp", exist_ok=True)
for fp in df_best["model_fp"]:
    # An argument list avoids shell parsing of the path, "--" stops a path that
    # starts with "-" from being read as an option, and check=True raises
    # CalledProcessError instead of silently ignoring a failed copy.
    subprocess.run(["cp", "-r", "--", fp, "runs_vae_tmp/"], check=True)

In [7]:
# for fp in list(df_best["model_fp"].values):
#     os.system("rm -rf {}".format(fp))

In [8]:
from itertools import product

# Sweep definition. Trailing comments keep the values used by earlier sweeps.
hiddens = [64]
noise_multipliers = [0.2, 0.3, 0.4]  # [0.01, 0.05, 0.1]
activations = ["LeakyReLU", "Tanh"]
c_ps = [0.25, 0.1]  # [0.05, 0.01, 0.005, 0.001]
lrs = [0.01]  # 0.005]

nz = 32
n_d = 0
n_g = 48000
batch_size = 64

# Print the run id of every configuration in the sweep together with whether
# its folder already exists under runs_vae/.
sweep = product(activations, c_ps, noise_multipliers, lrs)
for activation, c_p, noise_multiplier, lr in sweep:
    args = Args(
        # Model Parameters
        hidden=[64],
        nz=32,
        ngf=32,
        nc=1,
        activation=activation,
        # Privacy Parameters
        epsilon=50.0,
        delta=1e-6,
        noise_multiplier=noise_multiplier,
        c_p=c_p,
        # Training Parameters
        lr=lr,
        beta1=0.5,
        batch_size=batch_size,
        n_d=0,
        n_g=n_g,
        lambda_gp=0.0,
    )
    # main(args, latent_type="ae_enc")
    run_id = "ae-grad_" + generate_run_id(args)
    already_trained = os.path.exists(f"runs_vae/{run_id}")
    print(run_id, already_trained)

ae-grad_64_32_32_1_50.0_1e-06_0.2_0.25_0.01_0.5_64_0_48000_LeakyReLU_0.0 True
ae-grad_64_32_32_1_50.0_1e-06_0.3_0.25_0.01_0.5_64_0_48000_LeakyReLU_0.0 True
ae-grad_64_32_32_1_50.0_1e-06_0.4_0.25_0.01_0.5_64_0_48000_LeakyReLU_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.2_0.1_0.01_0.5_64_0_48000_LeakyReLU_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.3_0.1_0.01_0.5_64_0_48000_LeakyReLU_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.4_0.1_0.01_0.5_64_0_48000_LeakyReLU_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.2_0.25_0.01_0.5_64_0_48000_Tanh_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.3_0.25_0.01_0.5_64_0_48000_Tanh_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.4_0.25_0.01_0.5_64_0_48000_Tanh_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.2_0.1_0.01_0.5_64_0_48000_Tanh_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.3_0.1_0.01_0.5_64_0_48000_Tanh_0.0 False
ae-grad_64_32_32_1_50.0_1e-06_0.4_0.1_0.01_0.5_64_0_48000_Tanh_0.0 False


In [7]:
from scipy.interpolate import interp1d
import pickle

def calculate_epsilon_used(run_fp, delta=1e-5):
    """Estimate the epsilon a run has spent by the end of training.

    The ``accountant_<batch>.pt`` checkpoints found in ``run_fp`` are loaded in
    increasing batch order and each is asked for epsilon at ``delta``. A
    piecewise-linear curve is fitted through the collected (batch, epsilon)
    points and extrapolated to ``args.n_g``. Extrapolation is needed because
    the accountant uses too much memory beyond a few thousand batches and the
    process gets killed.

    Collection stops at the first checkpoint that is past batch 20000, fails
    to load or evaluate, or reports an epsilon above 200 (that last point is
    still recorded).

    Args:
        run_fp: Path of the run folder. Its last "/"-separated component is
            parsed with ``parse_run_id``.
        delta: Delta passed to ``accountant.get_epsilon``.
            NOTE(review): the run ids parsed in this notebook carry
            delta=1e-06 while the default here is 1e-5 -- confirm which one
            is intended.

    Returns:
        A 0-d numpy array with the curve evaluated at ``args.n_g``. The value
        is -1 when fewer than two points were collected, because a line
        cannot be fitted through zero or one point.

    Side effects:
        Prints progress for every checkpoint and pickles the fitted
        interpolator to ``<run_fp>/epsilon_used.pkl``.
    """
    run_id = run_fp.split("/")[-1]
    args = parse_run_id(run_id)

    # According to the original notes, losses.txt in the run folder also logs
    # lines such as "Epsilon 200 68.85602575854573"; the saved accountants are
    # queried here instead.

    # (batch index, file name) pairs, sorted numerically by batch index.
    # splitext removes the ".pt" suffix itself; str.strip(".pt") would strip
    # any of the characters ".", "p" and "t" from both ends.
    accts = sorted(
        (int(os.path.splitext(name)[0].split("_")[-1]), name)
        for name in os.listdir(run_fp)
        if name.startswith("accountant") and name.endswith(".pt")
    )

    epsilons = []
    for batch_idx, acct_fp in accts:
        if batch_idx > 20000:
            print("Breaking")
            break
        try:
            print(batch_idx, "Noise:", args.noise_multiplier, "Clip:", args.c_p, acct_fp)
            accountant = torch.load(f"{run_fp}/{acct_fp}")
            print(run_fp, acct_fp)
            curr_eps = accountant.get_epsilon(delta)
            print(batch_idx, curr_eps)
        except Exception as e:
            # Best effort: keep the points gathered so far (the recorded
            # outputs show get_epsilon failing with MemoryError).
            print("Error, breaking", e)
            break
        epsilons.append([batch_idx, curr_eps])

        # If epsilon is too high, later checkpoints are not worth evaluating
        if curr_eps > 200:
            print("Epsilon too high, breaking")
            break

    if len(epsilons) < 2:
        # interp1d needs at least two points. Return a constant -1 sentinel
        # curve instead of raising. This also covers the case where the only
        # point collected was already above the epsilon cut-off.
        f = interp1d(np.array([0, 1]), np.array([-1, -1]), kind="linear", fill_value="extrapolate")
    else:
        epsilons = np.array(epsilons)
        f = interp1d(epsilons[:, 0], epsilons[:, 1], kind="linear", fill_value="extrapolate")

    # Persist the fitted curve so other cells can reload it without touching
    # the accountants again.
    with open(f"{run_fp}/epsilon_used.pkl", "wb") as file:
        pickle.dump(f, file)

    return f(args.n_g)



In [8]:
# Estimate the epsilon spent by one run; the recorded output below reports
# noise 0.2 and clip 0.5 for it.
run_fp = "runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0"
calculate_epsilon_used(run_fp)

200 Noise: 0.2 Clip: 0.5 accountant_200.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0 accountant_200.pt
200 58.99997843029757
400 Noise: 0.2 Clip: 0.5 accountant_400.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0 accountant_400.pt
400 76.08746105795657
600 Noise: 0.2 Clip: 0.5 accountant_600.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0 accountant_600.pt
600 89.88315850500715
800 Noise: 0.2 Clip: 0.5 accountant_800.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0 accountant_800.pt
800 102.03902484103334
1000 Noise: 0.2 Clip: 0.5 accountant_1000.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0 accountant_1000.pt
1000 113.1689867808569
1200 Noise: 0.2 Clip: 0.5 accountant_1200.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0 accountant_1200.pt


KeyboardInterrupt: 

In [None]:
# Same call as the previous cell (this cell has no recorded output).
run_fp = "runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0"
calculate_epsilon_used(run_fp)

In [14]:
# Estimate epsilon for a run with noise 0.1 and clip 0.001. The recorded output
# shows the first accountant already above the cut-off, giving the -1 sentinel.
run_fp = "runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.001_0.01_0.5_64_0_48000_LeakyReLU_0.0"
calculate_epsilon_used(run_fp)

200 Noise: 0.1 Clip: 0.001 accountant_200.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.001_0.01_0.5_64_0_48000_LeakyReLU_0.0 accountant_200.pt
200 293.5059794795981
Epsilon too high, breaking


array(-1.)

In [11]:
# Collect the ae-grad runs with noise_multiplier >= 0.1 and n_g strictly
# between 20000 and 50000, sorted by path.
valid_runs_fps = []
for run_id in os.listdir("runs_vae"):
    run_fp = f"runs_vae/{run_id}"
    args = parse_run_id(run_id)
    is_candidate = (
        run_id.startswith("ae-grad")
        and args.noise_multiplier >= 0.1
        and 20000 < args.n_g < 50000
    )
    if is_candidate:
        valid_runs_fps.append(run_fp)
valid_runs_fps.sort()
valid_runs_fps

['runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.001_0.01_0.5_64_0_48000_LeakyReLU_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.001_0.01_0.5_64_0_48000_Tanh_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.001_0.02_0.5_64_0_48000_Tanh_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.005_0.01_0.5_64_0_48000_Tanh_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.005_0.02_0.5_64_0_48000_LeakyReLU_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.005_0.02_0.5_64_0_48000_Tanh_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.01_0.01_0.5_64_0_48000_Tanh_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.01_0.02_0.5_64_0_48000_LeakyReLU_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.01_0.02_0.5_64_0_48000_Tanh_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.05_0.01_0.5_64_0_48000_Tanh_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.05_0.02_0.5_64_0_48000_LeakyReLU_0.0',
 'runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.1_0.05_0.02_0.5_64_0_48000_Tanh_0.0']

In [9]:
# run_fp = "runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.2_0.5_0.01_0.5_64_0_100000_LeakyReLU_0.0"
# calculate_epsilon_used(run_fp)

# Collect every ae-grad run trained with a positive noise multiplier and an
# n_g strictly between 20000 and 50000, sorted by path.
valid_runs_fps = []
for run_id in os.listdir("runs_vae"):
    run_fp = f"runs_vae/{run_id}"
    args = parse_run_id(run_id)
    if not run_id.startswith("ae-grad"):
        continue
    if not args.noise_multiplier > 0:
        continue
    if not 20000 < args.n_g < 50000:
        continue
    valid_runs_fps.append(run_fp)
valid_runs_fps.sort()

len(valid_runs_fps)

36

In [10]:
# Show the hyper-parameters encoded in the first valid run's folder name.
parse_run_id(os.path.basename(valid_runs_fps[0]))

Args(hidden=[64], nz=32, ngf=32, nc=1, epsilon=50.0, delta=1e-06, noise_multiplier=0.01, c_p=0.001, lr=0.01, beta1=0.5, batch_size=64, n_d=0, n_g=48000, activation='LeakyReLU', lambda_gp=0.0)

In [11]:
# Load one saved accountant into a fresh PrivacyEngine and query epsilon at
# delta=1e-5. The recorded output of this cell is a MemoryError (a 74.1 GiB
# allocation request) raised for this low-noise run.
acct_fp = "runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0.001_0.01_0.5_64_0_48000_LeakyReLU_0.0/accountant_200.pt"
privacy_engine = PrivacyEngine()
privacy_engine.accountant = torch.load(acct_fp)
privacy_engine.get_epsilon(1e-5)

MemoryError: Unable to allocate 74.1 GiB for an array with shape (9945573802,) and data type float64

In [12]:
# Estimate epsilon for every selected run; each call also writes
# <run_fp>/epsilon_used.pkl. After the loop `run_fp` holds the last run visited.
for run_fp in valid_runs_fps:
    print(run_fp)
    calculate_epsilon_used(run_fp)

runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0.001_0.01_0.5_64_0_48000_LeakyReLU_0.0
200 0.01 0.001 accountant_200.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0.001_0.01_0.5_64_0_48000_LeakyReLU_0.0 accountant_200.pt
Error, breaking Unable to allocate 74.1 GiB for an array with shape (9945573802,) and data type float64
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0.001_0.01_0.5_64_0_48000_Tanh_0.0
200 0.01 0.001 accountant_200.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0.001_0.01_0.5_64_0_48000_Tanh_0.0 accountant_200.pt
Error, breaking Unable to allocate 74.1 GiB for an array with shape (9945573802,) and data type float64
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0.001_0.02_0.5_64_0_48000_LeakyReLU_0.0
200 0.01 0.001 accountant_200.pt
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.01_0.001_0.02_0.5_64_0_48000_LeakyReLU_0.0 accountant_200.pt
Error, breaking Unable to allocate 74.1 GiB for an array with shape (9945573802,) and data type float64
runs_vae/ae-grad_64_32_32_1_50.0_1e-06_0.

: 

: 

In [None]:
# Calculate the epsilon used for each run
# Estimate and print epsilon for the first selected run only.
idx = 0
print(calculate_epsilon_used(valid_runs_fps[idx]))
# if not os.path.exists(f"{run_fp}/epsilon_used.pkl"):

In [None]:
# Estimate and print epsilon for the second and third selected runs.
for idx in (1, 2):
    print(calculate_epsilon_used(valid_runs_fps[idx]))

In [None]:
# Plot the fitted epsilon curve over the first 10000 batches.
# NOTE(review): `run_fp` is whatever an earlier cell last assigned, so the
# curve shown depends on execution order.
xs = np.linspace(0, 10000, 200)
# Load the function
# pickle.load executes code embedded in the file; only open files written by
# calculate_epsilon_used in this project.
with open(f"{run_fp}/epsilon_used.pkl", "rb") as file:
    f = pickle.load(file)
plt.plot(xs, f(xs), 'k', lw=3, alpha=0.7)

In [None]:
def generate_samples(G, args, latent_type=None, plot=True, batch_size=32, seed=False):
    """Draw latent noise, run it through ``G`` and return 28x28 images.

    Args:
        G: Callable applied to a ``(batch_size, args.nz)`` noise tensor.
        args: Object exposing ``nz``, the width of the noise vectors.
        latent_type: "wgan" passes G's output through the public generator
            ``pub_G``; "ae" passes it through the public decoder ``pub_Dec``;
            any other value treats G's output as the images themselves.
        plot: When true, show the generated images in a grid 8 columns wide.
        batch_size: Number of samples to generate.
        seed: When true, reseed torch with 42 first so repeated calls draw the
            same noise.

    Returns:
        Tensor viewed as ``(batch, 1, 28, 28)``.
    """
    if seed:
        torch.manual_seed(42)

    # Generate Sample Images
    noise = torch.randn(batch_size, args.nz).to(device)
    if latent_type == "wgan":
        fake = pub_G(G(noise))
    elif latent_type == "ae":
        fake = pub_Dec(G(noise))
    else:
        fake = G(noise)
    fake = fake.view(fake.size(0), 1, 28, 28)

    # Plot Sample Images
    if plot:
        n_images = fake.size(0)
        n_cols = 8
        # Size the grid from the actual batch. The previous fixed 4x8 grid
        # raised IndexError for fewer than 32 samples and dropped any image
        # beyond the 32nd. 32 samples still give a 4x8 grid of size (10, 5).
        n_rows = max(1, (n_images + n_cols - 1) // n_cols)
        # squeeze=False keeps `ax` two-dimensional even for a single row.
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(10, 1.25 * n_rows), squeeze=False)
        for k, axis in enumerate(ax.flat):
            axis.axis('off')
            if k < n_images:
                axis.imshow(fake[k][0].detach().cpu().numpy(), cmap='gray')
        plt.show()
        plt.close()

    return fake

In [None]:
def plot_reconstructed(G, args, latent_type=None, r0=(-10, 10), r1=(-10, 10), n=12):
    """Show an n x n mosaic of images decoded from a grid of 2-D latent points.

    The first latent coordinate sweeps ``r0`` from left to right and the second
    sweeps ``r1`` from bottom to top. Each point is fed as a ``(1, 2)`` tensor,
    so ``G`` must accept two-dimensional latents. ``args`` is accepted but not
    used.

    ``latent_type`` selects the post-processing applied to G's output:
    "wgan" -> ``pub_G``, "ae" -> ``pub_Dec``, anything else -> none.
    """
    def decode(z):
        # Map one latent point to a flat image according to latent_type.
        if latent_type == "wgan":
            return pub_G(G(z))
        if latent_type == "ae":
            return pub_Dec(G(z))
        return G(z)

    tile = 28
    canvas = np.zeros((n * tile, n * tile))
    second_coords = np.linspace(*r1, n)
    first_coords = np.linspace(*r0, n)
    for row, y in enumerate(second_coords):
        # Row 0 of the sweep goes at the bottom of the canvas.
        top = (n - 1 - row) * tile
        for col, x in enumerate(first_coords):
            left = col * tile
            z = torch.Tensor([[x, y]]).to(device)
            patch = decode(z).reshape(28, 28).to('cpu').detach().numpy()
            canvas[top:top + tile, left:left + tile] = patch
    plt.imshow(canvas, extent=[*r0, *r1])
In [None]:
def samples_during_training(run_fp, end=50000, step=5000):
    """Show how samples from a run evolve across its saved VAE checkpoints.

    For every ``j`` in ``step, 2*step, ..., end`` the checkpoint
    ``<run_fp>/vae_<j>.pt`` is loaded and 10 samples are drawn from its decoder
    with a fixed seed, so each row of the figure follows one latent vector
    through training. Columns are checkpoints.

    Args:
        run_fp: Run folder; its last path component is parsed by
            ``parse_run_id``.
        end: Last checkpoint index to include.
        step: Spacing between checkpoint indices.

    Raises:
        ValueError: If a checkpoint file is missing, or if ``end``/``step``
            select no checkpoint at all.
    """
    # Parse args
    run_id = run_fp.split("/")[-1]
    args = parse_run_id(run_id)
    latent_type = "wgan" if run_id.startswith("wgan") else "ae"

    imgs = []
    for j in range(step, end + 1, step):
        gen_fp = os.path.join(run_fp, f"vae_{j}.pt")

        # Check if model exists
        if not os.path.exists(gen_fp):
            print("Model does not exist", gen_fp)
            raise ValueError("Model does not exist")

        vae = VAE(
            Encoder_VAE(args.hidden, latent_size=args.nz),
            Decoder_VAE(args.hidden, latent_size=args.nz)
        ).to(device)
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        vae.load_state_dict(torch.load(gen_fp, map_location=device))
        vae.eval()

        G = vae.decoder
        G.eval()

        fake = generate_samples(
            G, args, latent_type=latent_type,
            plot=False, batch_size=10, seed=True
        )
        imgs.append(fake)

    n = len(imgs)
    if n == 0:
        raise ValueError(f"No checkpoints selected (end={end}, step={step})")

    # Plot Sample Images
    # squeeze=False keeps `ax` two-dimensional when there is only one
    # checkpoint; otherwise ax[i, j] fails on the 1-D array that is returned.
    fig, ax = plt.subplots(10, n, squeeze=False)
    for i in range(10):
        for j in range(n):
            ax[i, j].imshow(imgs[j][i][0].detach().cpu().numpy(), cmap='gray')
            ax[i, j].axis('off')
            if i == 0:
                # Column header: the checkpoint index of that column.
                ax[i, j].set_title(f"{(j+1)*step}", color="white", fontsize=9)
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.patch.set_facecolor('black')
    fig.suptitle("Samples during training epochs", color="white", fontsize=12)
    plt.show()

In [None]:
# Visualise the checkpoints vae_10000.pt ... vae_100000.pt of one run whose id
# encodes noise_multiplier 0.0.
run_folder = "runs_vae"
run_id = "ae-grad_64_32_32_1_50.0_1e-06_0.0_1.0_0.01_0.5_64_0_100000_LeakyReLU_0.0"

run_fp = os.path.join(run_folder, run_id)
print(run_id)

samples_during_training(run_fp, end=100000, step=10000)

In [None]:
# Show samples from the most recent VAE checkpoint of the first run under
# runs_vae/ that matches the filter below.
run_folder = "runs_vae"
run_ids = os.listdir(run_folder)
for run_id in run_ids:
    run_fp = os.path.join(run_folder, run_id)
    args = parse_run_id(run_id)

    # Keep only noised runs trained with batch size 64 and n_g == 48000.
    # (Earlier explorations also filtered on lr, n_g == 50000 and c_p.)
    wanted = (
        args.noise_multiplier != 0.0
        and args.batch_size == 64
        and args.n_g == 48000
    )
    if not wanted:
        continue

    print(f"Noise: {args.noise_multiplier}, Clip: {args.c_p}")

    # Alternative: pin a specific checkpoint, e.g. gen_fp = "vae_48000.pt"
    gen_fp = last_num_models(run_fp, num=1, query="vae")[0]
    print(run_id + gen_fp)

    gen_fp = os.path.join(run_fp, gen_fp)

    # Check if exists
    if not os.path.exists(gen_fp):
        print("Generator model not found")
        continue

    encoder = Encoder_VAE(args.hidden, latent_size=args.nz)
    decoder = Decoder_VAE(args.hidden, latent_size=args.nz)
    vae = VAE(encoder, decoder).to(device)
    vae.load_state_dict(torch.load(gen_fp))
    vae.eval()

    G = vae.decoder
    G.eval()

    latent_type = "wgan" if run_id.startswith("wgan") else "ae"
    generate_samples(G, args, latent_type=latent_type)
    # One matching run is enough.
    break

In [None]:
# run_id = "/home/jason/p2/runs_latent/wgan_96_64_32_1_50.0_1e-06_0.0_0.01_5e-05_0.5_64_1_50000_LeakyReLU_0.0"

# Show samples from the netG_48000.pt checkpoint of every run under
# runs_latent/. The stray `print no` statement that used to open this cell was
# a syntax error, so the cell could never run; it has been removed.
run_ids = os.listdir("runs_latent")
for run_id in run_ids:
    print(run_id)

    run_fp = os.path.join('runs_latent/', run_id)
    args = parse_run_id(run_id)

    gen_fp = os.path.join(run_fp, 'netG_48000.pt')
    # Check if exists
    if not os.path.exists(gen_fp):
        print("Generator model not found")
        continue

    # This generator emits 100-dimensional latents, which generate_samples
    # decodes through pub_G ("wgan" runs) or pub_Dec (all other runs).
    G = Generator_FC(args.hidden, args.nz, output_size=(100,)).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    G.load_state_dict(torch.load(gen_fp, map_location=device))
    G.eval()

    generate_samples(G, args, latent_type="wgan" if run_id.startswith("wgan") else "ae")

In [None]:
# Generating samples from gradient ascent
from model_inversion import gradient_ascent, projected_gradient_ascent

# NOTE(review): the argument 10 is presumably the batch size (the plotting cell
# further down indexes images 0-9) -- confirm against data.load_MNIST.
labeling_loader, public_loader, private_loader, test_loader = load_MNIST(10)

# Get a single batch of private images
imgs, _ = next(iter(private_loader))
imgs.shape

In [None]:
# Iteration budgets compared in the inversion experiment below; each value is
# passed as `iterations` to both ascent routines.
iterations_list = [100, 1000, 10000, 100000, 200000]

In [None]:
# Invert the private batch with both public models at every iteration budget.
# Results are kept on the CPU for the plotting cell that follows.
imgs = imgs.to(device)
# Use pub_Enc and pub_Dec to encode/decode images
imgs_enc = pub_Dec(pub_Enc(imgs)).detach().cpu()

# One entry per value in iterations_list, in the same order.
imgs_dec = []
imgs_gen = []
for iterations in iterations_list:
    # Perform gradient ascent
    latent_dec = gradient_ascent(
        pub_Dec, imgs, latent_dim=100, start_lr=200,
        iterations=iterations)

    # Perform projected gradient ascent
    latent_gen = projected_gradient_ascent(
        pub_G, imgs, latent_dim=100, start_lr=200,
        iterations=iterations, z_0_mult=1)
    
    # Decode the latent vectors
    imgs_dec.append(pub_Dec(latent_dec).detach().cpu())
    imgs_gen.append(pub_G(latent_gen).detach().cpu())
# Move the originals back to the CPU as well.
imgs = imgs.detach().cpu()

In [None]:
# Plot the images: one row per private image, one column per reconstruction
# method. Column layout: original | encoder-decoder | gradient ascent on the
# decoder (one per budget) | projected gradient ascent on the generator (one
# per budget).
n_rows = 10
n_settings = len(iterations_list)
# The column count follows iterations_list instead of assuming five budgets.
fig, ax = plt.subplots(n_rows, 2 + 2 * n_settings)
rotation = 90


def _iterations_label(count):
    """Format an iteration count for a column title (1000 -> "1k", 100 -> "100")."""
    if count >= 1000 and count % 1000 == 0:
        return f"{count // 1000}k"
    return str(count)


# Original images
for i in range(n_rows):
    ax[i, 0].imshow(imgs[i][0].detach().cpu().numpy(), cmap='gray')
    ax[i, 0].axis('off')
    if i == 0:
        # Rotated title
        ax[i, 0].set_title("Original", color="white", fontsize=9, rotation=rotation)
# Decoded images
for i in range(n_rows):
    ax[i, 1].imshow(imgs_enc[i][0].detach().cpu().numpy(), cmap='gray')
    ax[i, 1].axis('off')
    if i == 0:
        ax[i, 1].set_title("ae-enc", color="white", fontsize=9, rotation=rotation)
# Gradient ascent and projected gradient ascent, one column per budget.
# The titles previously appended "k" to the raw iteration count, so 100
# iterations were labelled "100k".
for i in range(n_rows):
    for j in range(n_settings):
        label = _iterations_label(iterations_list[j])

        dec_ax = ax[i, j + 2]
        dec_ax.imshow(imgs_dec[j][i][0].detach().cpu().numpy(), cmap='gray')
        dec_ax.axis('off')

        gen_ax = ax[i, j + 2 + n_settings]
        gen_ax.imshow(imgs_gen[j][i][0].detach().cpu().numpy(), cmap='gray')
        gen_ax.axis('off')

        if i == 0:
            dec_ax.set_title(f"ae-grad {label}", color="white", fontsize=9, rotation=rotation)
            gen_ax.set_title(f"wgan {label}", color="white", fontsize=9, rotation=rotation)
# Black background
fig.patch.set_facecolor('black')
# Remove the white space around the plots
# plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

In [None]:
from data import load_latent

# Fix the torch seed before any loader is created or iterated.
torch.manual_seed(0)

# One loader per latent dataset.
loader_ae_enc = load_latent(16, data_fp="data/ae_enc_latent_dataset.pt")
loader_ae_grad = load_latent(16, data_fp="data/ae_grad_latent_dataset.pt")
loader_wgan = load_latent(16, data_fp="data/wgan_latent_dataset.pt")

# First batch of each loader (element 0 of the yielded batch).
batch_ae_enc = next(iter(loader_ae_enc))[0]
batch_ae_grad = next(iter(loader_ae_grad))[0]
batch_wgan = next(iter(loader_wgan))[0]

# Map latents back to image space: the autoencoder latents go through the
# public decoder, the WGAN latents through the public generator.
images_ae_enc = pub_Dec(batch_ae_enc)
images_ae_grad = pub_Dec(batch_ae_grad)
images_wgan = pub_G(batch_wgan)


def plot_batch(images):
    """Show the first 16 entries of ``images`` as a gap-free 4x4 grid on black."""
    fig, ax = plt.subplots(4, 4)
    # ax.flat walks the grid row by row, so position k holds images[k].
    for k, axis in enumerate(ax.flat):
        axis.imshow(images[k][0].detach().cpu().numpy(), cmap='gray')
        axis.axis('off')
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.patch.set_facecolor('black')

    plt.show()


for images in (images_ae_enc, images_ae_grad, images_wgan):
    plot_batch(images)



In [None]:


# Decode one batch of latents after adding unit Gaussian noise and show the
# result as a 4x8 grid (the batch must hold at least 32 items).
# NOTE(review): `train_loader` is not defined in any visible cell of this
# notebook -- it must be created before this cell can run.
# The batch is bound to `latents`, not `data`, so the `torch.utils.data as data`
# import alias is no longer overwritten. The grid loops use their own index
# names instead of reusing the outer loop's `i`.
for (latents,) in train_loader:
    # Print difference between item 0 and 1
    print(latents.shape)
    print(torch.mean(latents[0] - latents[1]))

    # Add noise to data
    noisy_latents = latents + torch.randn_like(latents)

    # Decode
    fake = pub_Dec(noisy_latents)
    fake = fake.view(fake.size(0), 1, 28, 28)

    # Plot Sample Images
    fig, ax = plt.subplots(4, 8, figsize=(10, 5))
    for row in range(4):
        for col in range(8):
            ax[row, col].imshow(fake[row * 8 + col][0].detach().cpu().numpy(), cmap='gray')
            ax[row, col].axis('off')

    plt.show()
    # Only the first batch is inspected.
    break

In [None]:
# Load the generator
# Runs inspected earlier (kept for reference; only the last assignment counts):
#   "/home/jason/p2/runs/private_16-12_100_32_1_inf_1e-06_0.0_0.01_0.0001_0.5_64_5_200000_LeakyReLU"  # decent
#   "/home/jason/p2/runs/private_16-12_100_32_1_inf_1e-06_0.4_0.005_0.0001_0.5_64_4_500000_LeakyReLU"  # better
#   "/home/jason/p2/runs/16-12_100_32_1_38.0_1e-06_0.05_0.01_0.0001_0.5_64_5_200000_LeakyReLU"  # somewhat reasonable noised
#   "/home/jason/p2/runs_gen_fc_3/public_128_100_32_1_inf_1e-06_0.0_0.01_5e-05_0.0_64_3_500000_LeakyReLU_0.0"
run_id = "/home/jason/p2/runs_gen_fc_3/public_256_100_32_1_inf_1e-06_0.0_0.01_5e-05_0.0_64_3_500000_LeakyReLU_0.0"

# Only the folder name is used; the run is looked up under runs_gen_fc_3/.
run_id = os.path.basename(run_id)
run_fp = os.path.join('runs_gen_fc_3/', run_id)
args = parse_run_id(run_id)

# The range currently selects the single checkpoint 500000.
for i in range(500000, 500000 + 1, 10000):
    gen_fp = os.path.join(run_fp, f'netG_{i}.pt')
    if not os.path.exists(gen_fp):
        continue
    print(f"Loading {gen_fp}")

    G = Generator_FC([256], args.nz).to(device)
    G.load_state_dict(torch.load(gen_fp))
    G.eval()

    generate_samples(G, args)

In [None]:
# Random Seeding
torch.manual_seed(0)
np.random.seed(0)


# Hyper-parameters of a run configuration.
activation = 'LeakyReLU'
args = Args(
    # Model Parameters
    hidden=[16, 12], nz=100, ngf=32, nc=1, activation=activation,
    # Privacy Parameters
    epsilon=50.0, delta=1e-6, noise_multiplier=0.3, c_p=0.01, 
    # Training Parameters
    lr=1e-3, beta1=0.5, batch_size=16, n_d=3, n_g=int(1e4), lambda_gp=10.0,
)

# Generate Run ID
run_id = generate_run_id(args)


# The generated id is overwritten by the hand-picked ids below, and only the
# last uncommented assignment takes effect. After this cell `args` therefore
# no longer describes the run that `run_fp` points at.
run_id = "public_16-12_100_32_1_inf_1e-06_0.4_0.005_0.0001_0.5_64_4_300000_LeakyReLU"
# run_id = "16-12_100_32_1_inf_1e-06_0.4_0.005_0.0005_0.5_64_5_50000"
# /home/jason/p2/runs/16-12_100_32_1_inf_1e-06_0.4_0.005_0.0005_0.5_64_5_50000
# run_id = "16-12_100_32_1_50.0_1e-06_0.6_0.005_0.0001_0.5_64_4_300000_Tanh"
# run_id = "16-12_100_32_1_50.0_1e-06_0.6_0.005_0.0001_0.5_64_5_300000_Tanh"
run_id = "16-12_100_32_1_50.0_1e-06_0.2_0.005_0.0001_0.5_64_4_300000_Tanh"

# Create Folder Path
run_fp = os.path.join('runs/', run_id)
run_fp

In [None]:
# Folder names of all runs under runs_gen_fc/. os.listdir returns entries in
# arbitrary order, so positions in `fps` are not stable across machines.
fps = os.listdir("runs_gen_fc/")
len(fps)

In [None]:
# Display the currently parsed Args. The comment lines that follow are the
# author's notes; the numbers are presumably indices into `fps` -- confirm.
args
# 3,4,6*,8,9,10*,13*,16,19,20
# Args(hidden=[128], nz=50, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.005, lr=5e-05, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.01, lr=0.0005, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# * Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.005, lr=5e-05, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.01, lr=5e-05, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.005, lr=5e-05, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# * Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.005, lr=0.0005, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# * Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.01, lr=5e-05, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=50, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.01, lr=5e-05, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.005, lr=0.0005, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.01, lr=0.0005, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)

In [None]:
# bad
# Display the currently parsed Args. The comment lines that follow are the
# author's notes on configurations labelled "bad"; the numbers are presumably
# indices into `fps` -- confirm.
args
# 0,1,2,5,7,11,12,14,15,17,18,21
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.02, lr=5e-05, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.0, lr=0.0005, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=10.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.02, lr=0.0005, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=50, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.02, lr=5e-05, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.02, lr=5e-05, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.02, lr=0.0005, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.0, lr=0.0005, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=10.0)
# Args(hidden=[128], nz=50, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.0, lr=5e-05, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=10.0)
# Args(hidden=[128], nz=50, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.02, lr=5e-05, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=0.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.0, lr=5e-05, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=10.0)
# Args(hidden=[128], nz=50, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.0, lr=5e-05, beta1=0.0, batch_size=64, n_d=3, n_g=100000, activation='LeakyReLU', lambda_gp=10.0)
# Args(hidden=[128], nz=100, ngf=32, nc=1, epsilon='inf', delta=1e-06, noise_multiplier=0.0, c_p=0.0, lr=5e-05, beta1=0.0, batch_size=64, n_d=5, n_g=100000, activation='LeakyReLU', lambda_gp=10.0)

In [None]:
# Show samples from every saved checkpoint (multiples of 20000 up to 100000)
# of the run at position `idx` in `fps`.
idx = 12
fp = os.path.join("runs_gen_fc/", fps[idx])
args = parse_run_id(fps[idx])

for i in range(0, 100000 + 1, 20000):
    gen_fp = os.path.join(fp, f'netG_{i}.pt')
    if not os.path.exists(gen_fp):
        # Checkpoints that were never written are skipped silently.
        continue
    print(f"Loading {gen_fp}")

    G = Generator_FC([128], args.nz).to(device)
    G.load_state_dict(torch.load(gen_fp))
    G.eval()

    generate_samples(G, args)

In [None]:
def plot_loss(run_fp):
    """Plot the generator and discriminator losses logged in ``<run_fp>/loss.txt``.

    Each data line has the form ``"<index>, <loss>"``. An index containing a
    "." marks a discriminator entry; an index without one marks a generator
    entry. Lines containing "time" and blank lines are ignored.

    Args:
        run_fp: Run folder that contains ``loss.txt``.

    Raises:
        FileNotFoundError: If ``loss.txt`` does not exist.
        ValueError: If a data line does not match the expected format.
    """
    # Read loss.txt
    loss_fp = os.path.join(run_fp, 'loss.txt')
    d_loss, g_loss = [], []

    with open(loss_fp, 'r') as f:
        for line in f.read().splitlines():
            # Blank lines (e.g. a trailing newline) used to crash the unpacking
            # below; skip them together with the timing lines.
            if not line.strip() or "time" in line:
                continue
            idx, l = line.split(", ")

            # The index format tells the two series apart.
            series = d_loss if "." in idx else g_loss
            series.append((int(float(idx)), float(l)))

    # Graph Loss
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator_FC Loss During Training")
    plt.plot(*zip(*d_loss), label="Discriminator_FC")
    plt.plot(*zip(*g_loss), label="Generator")
    plt.xlabel("iterations")
    plt.ylabel("Loss")

    # Only show first 100 epochs
    # plt.xlim(-10, 20000)
    plt.legend()
    plt.show()

In [None]:
# Display the generator left in `G` by whichever earlier cell ran last.
G

In [None]:
# Generate 2048 fake images
# The latent width 100 is hard-coded, so `G` must be a generator that accepts
# 100-dimensional noise and emits 28x28 images directly.
noise = torch.randn(2048, 100).to(device)
fake = G(noise)
fake = fake.view(fake.size(0), 1, 28, 28)

In [None]:
# Score the batch generated in the previous cell with the project's metric
# helpers (imported from metrics).
# Calculate Inception Score
IS = get_IS(fake)
print("Inception Score:", IS)

# Calculate Frechet Inception Distance
FID = get_FID(fake)
print("Frechet Inception Distance:", FID)