import statements

In [4]:
"""Empirical Sensitivity."""
import argparse
import os

import numpy as np
import torch
from torch import nn

from utils import get_data_loaders
from logistic_regression import nonprivate_logistic_regression
from torchvision import transforms

your code

In [11]:
def plot_hist(array_of_empirical_sensitivities, n, lmbda, name):
    """Save a histogram of empirical sensitivities to ./figs and return its path.

    The histogram is drawn on a log-scaled x axis with a dashed red vertical
    line marking the theoretical maximum sensitivity, computed here as
    ``2 / (n * lmbda)``.

    Args:
        array_of_empirical_sensitivities: np.ndarray of sensitivities, one per run.
        n: value used as ``n`` in the ``2 / (n * lmbda)`` bound.
        lmbda: value used as ``lmbda`` in the ``2 / (n * lmbda)`` bound.
        name: str used in the plot title and as the file-name stem.

    Returns:
        Path of the written image, ``./figs/<name>.histogram.png``.

    Raises:
        ValueError: if ``array_of_empirical_sensitivities`` is not a
            np.ndarray or ``name`` is not a str.
    """
    if not isinstance(array_of_empirical_sensitivities, np.ndarray):
        raise ValueError('array_of_empirical_sensitivities should be a np.ndarray.')
    if not isinstance(name, str):
        raise ValueError('name should be a str')

    # Select the non-interactive backend before pyplot is imported so the
    # figure can be rendered to disk without a display.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # Position of the reference line drawn on top of the histogram.
    max_theoretical_sensitivity = 2./(n*lmbda)

    num_bins = 20
    dirname = './figs'
    filename = os.path.join(dirname, name) + '.histogram.png'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dirname, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        ax.set_xscale('log')
        ax.hist(array_of_empirical_sensitivities,
                num_bins, label='empirical sensitivities')
        ax.set_title('histogram of sensitivities: ' + name)
        ax.axvline(x=max_theoretical_sensitivity, color='r', linestyle='dashed', linewidth=2,
                label='theoretical max sensitivity')
        ax.legend()
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
    return filename


def plot_extreme_neighbors(sensitivities, list_of_neighboring_examples, name):
    """Plots to disk the neighboring-example pairs with the most and least empirical sensitivity

    Note on the data structures used:
        sensitivities: a np.ndarray containing empirical sensitivities for each run
        list_of_neighboring_examples: a list of neighboring example pairs, one for each run. in other words:

        list_of_neighboring_examples = [
            neighboring_example_1,
            neighboring_example_2,
            ...
            neighboring_example_n,
            ]

        where each tuple in the list represents the data diff between the neighboring
        datasets and is formatted like this:

        neighboring_example_i = (
            (neighbor_img_i, neighbor_label_i),
            (neighbor_img_i_prime, neighbor_label_i_prime),
        )

        See utils.py if you are still confused.

    Args:
        sensitivities: np.ndarray of empirical sensitivities.
        list_of_neighboring_examples: non-empty list of neighbor pairs (see above).
        name: str used in the figure titles and as the file-name stem.

    Returns:
        A 2-tuple of image paths:
        ``(./figs/<name>.maximum_sensitivity.png, ./figs/<name>.minimum_sensitivity.png)``.

    Raises:
        ValueError: if ``sensitivities`` is not a np.ndarray, if
            ``list_of_neighboring_examples`` is not a non-empty list of tuple
            pairs whose first image is a torch.Tensor, or if ``name`` is not
            a str.
    """
    if not isinstance(sensitivities, np.ndarray):
        raise ValueError('sensitivities should be a np.ndarray.')
    # Validate the container BEFORE indexing into it, so that a non-list or an
    # empty list is reported as a ValueError rather than escaping as a
    # TypeError/IndexError from the [0] lookup.
    if not isinstance(list_of_neighboring_examples, list) or not list_of_neighboring_examples:
        raise ValueError('list_of_neighboring_examples should be a list of tuple pairs, where tuple contains img tensors')
    first_neighbor_pair = list_of_neighboring_examples[0]
    if not isinstance(first_neighbor_pair, tuple) \
            or not isinstance(first_neighbor_pair[0][0], torch.Tensor):
        raise ValueError('list_of_neighboring_examples should be a list of tuple pairs, where tuple contains img tensors')
    if not isinstance(name, str):
        raise ValueError('name should be a str')

    # Select the non-interactive backend before pyplot is imported so the
    # figures can be rendered to disk without a display.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    pil_transform = transforms.ToPILImage()
    dirname = './figs'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dirname, exist_ok=True)

    # One figure per extreme: the pair at the arg-max and the pair at the
    # arg-min of the sensitivities, each drawn as two side-by-side images.
    extremes = (
        ('maximum', np.argmax(sensitivities)),
        ('minimum', np.argmin(sensitivities)),
    )
    filenames = []
    for label, idx in extremes:
        (img, _target), (img_prime, _target_prime) = list_of_neighboring_examples[idx]
        filename = os.path.join(dirname, name) + '.' + label + '_sensitivity.png'
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        try:
            axes[0].imshow(pil_transform(img))
            axes[1].imshow(pil_transform(img_prime))
            fig.suptitle(label + ' sensitivity: ' + name)
            fig.savefig(filename)
        finally:
            # pyplot keeps figures alive until closed; release each one.
            plt.close(fig)
        filenames.append(filename)

    return tuple(filenames)


def compute_empricial_sensivity(train_loader, neighbor_loader,
        num_epochs, learning_rate, lmbda, model_seed=None):
    """Return the L2 distance between the weights fitted on two data loaders.

    nonprivate_logistic_regression is run once on ``train_loader`` and once on
    ``neighbor_loader`` with identical ``num_epochs``, ``learning_rate``,
    ``lmbda`` and ``model_seed``. The result is the 2-norm, as computed by
    ``torch.norm``, of the difference between the two ``'weight'`` entries of
    the returned parameters.
    """
    training_args = (num_epochs, learning_rate, lmbda, model_seed)
    # Fit on the training loader first, then on the neighbor loader.
    fitted, fitted_prime = [
        nonprivate_logistic_regression(loader, *training_args)
        for loader in (train_loader, neighbor_loader)
    ]
    weight_gap = fitted['weight'] - fitted_prime['weight']
    return torch.norm(weight_gap, p=2)

main function

In [12]:
def main(n, runs, epochs, lr, batch_size, model_seed, lmbda):
    """Measure empirical sensitivity over several data seeds and plot the results.

    For each ``data_seed`` in ``range(runs)``, builds the 'train' and
    'neighbor' loaders via ``get_data_loaders`` and computes the empirical
    sensitivity between them. It then writes a histogram of all sensitivities
    and the image pairs with the largest and smallest sensitivity, printing
    the paths of the files written.

    Args:
        n: passed to ``get_data_loaders`` as ``num_train`` and to ``plot_hist``.
        runs: number of data seeds to evaluate (seeds ``0 .. runs-1``).
        epochs: forwarded to ``compute_empricial_sensivity`` as ``num_epochs``.
        lr: forwarded to ``compute_empricial_sensivity`` as ``learning_rate``.
        batch_size: passed to ``get_data_loaders``.
        model_seed: forwarded to ``compute_empricial_sensivity``.
        lmbda: forwarded to ``compute_empricial_sensivity`` and ``plot_hist``.
    """
    empirical_sensitivities = []
    list_of_neighboring_examples = []
    for data_seed in range(runs):
        loaders, neighboring_examples = get_data_loaders(data_seed, batch_size,
                num_train=n)
        sensitivity = compute_empricial_sensivity(
                loaders['train'], loaders['neighbor'],
                epochs, lr, lmbda, model_seed)
        # Store a plain Python float: building an ndarray directly from a list
        # of 0-dim tensors breaks when they live on an accelerator or track
        # gradients, whereas float() works for any scalar tensor.
        empirical_sensitivities.append(float(sensitivity))
        list_of_neighboring_examples.append(neighboring_examples)

    sensitivities = np.array(empirical_sensitivities)
    name = 'lambda={},n={}'.format(lmbda, n)
    filename = plot_hist(sensitivities, n, lmbda, name)
    print('see plot at', filename)

    filenames = plot_extreme_neighbors(sensitivities, list_of_neighboring_examples, name)
    print('see plots at {} and {}'.format(*filenames))

arguments and main function call

In [21]:
# Experiment configuration, passed positionally to main() below.
N = 1000  # main's `n`: given to get_data_loaders as num_train
RUNS = 20  # TODO(student): run more times once your code works; something like 100
EPOCHS = 100  # main's `epochs`: training epochs per model fit
LR = 0.1  # main's `lr`: learning rate per model fit
BATCH_SIZE = 256  # main's `batch_size`: given to get_data_loaders
MODEL_SEED = 0  # main's `model_seed`: same seed is used for both fits in a run
LMBDA = 5e-3  # main's `lmbda`: also enters the 2/(n*lmbda) line in plot_hist

# Runs the whole experiment; writes the plots under ./figs and prints their paths.
main(N, RUNS, EPOCHS, LR, BATCH_SIZE, MODEL_SEED, LMBDA)

100%|██████████| 100/100 [00:10<00:00,  9.68it/s]
100%|██████████| 100/100 [00:10<00:00,  9.43it/s]
100%|██████████| 100/100 [00:10<00:00,  9.60it/s]
100%|██████████| 100/100 [00:10<00:00,  9.54it/s]
100%|██████████| 100/100 [00:10<00:00,  9.04it/s]
100%|██████████| 100/100 [00:10<00:00,  8.97it/s]
100%|██████████| 100/100 [00:10<00:00,  9.21it/s]
100%|██████████| 100/100 [00:10<00:00,  8.99it/s]
100%|██████████| 100/100 [00:10<00:00,  8.83it/s]
100%|██████████| 100/100 [00:10<00:00,  9.26it/s]
100%|██████████| 100/100 [00:11<00:00,  9.17it/s]
100%|██████████| 100/100 [00:10<00:00,  9.40it/s]
100%|██████████| 100/100 [00:11<00:00,  9.20it/s]
100%|██████████| 100/100 [00:10<00:00,  9.04it/s]
100%|██████████| 100/100 [00:10<00:00,  9.29it/s]
100%|██████████| 100/100 [00:10<00:00,  9.17it/s]
100%|██████████| 100/100 [00:10<00:00,  9.17it/s]
100%|██████████| 100/100 [00:11<00:00,  9.30it/s]
100%|██████████| 100/100 [00:11<00:00,  8.81it/s]
100%|██████████| 100/100 [00:10<00:00,  9.17it/s]


see plot at ./figs/lambda=0.005,n=1000.histogram.png
see plots at ./figs/lambda=0.005,n=1000.maximum_sensitivity.png and ./figs/lambda=0.005,n=1000.minimum_sensitivity.png
