In [1]:
import time
import math
import copy
import torch
import random
import logging
import warnings
import datetime

import numpy as np
import seaborn as sns
import torch.nn as nn
import torch.optim as opt
import matplotlib.pyplot as plt

from tqdm import tqdm
from sklearn import linear_model
from torch.autograd import grad
from torch.autograd.functional import vhp
from data_processing import get_data_adult
from torch.utils.data import Subset, DataLoader
from sklearn.metrics import mean_absolute_error, r2_score, accuracy_score

plt.rcParams['figure.dpi'] = 300
warnings.filterwarnings("ignore")

E = math.e

### Utility Functions

In [2]:
def graph(data, title, x_label, y_label, labels):
    sns.set(font_scale=1)

    ticks = np.arange(0, len(data[2]), step=1)
    plt.xticks(ticks=ticks, labels=data[2], rotation='vertical')
    
    plt.plot(data[0], 'b-', linewidth=2.0, label=labels[0])
    plt.plot(data[1], 'r-', linewidth=2.0, label=labels[1])
    plt.plot([a_i - b_i for a_i, b_i in zip(data[0], data[1])], 'k', linewidth=2.0, label='difference')
    
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

In [3]:
 class CreateData(torch.utils.data.Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label = self.targets[idx]

        return out_data, out_label

### Randomized Response

In [4]:
def randomize_resp(label, epsilon):

    probability = float(E ** epsilon) / float(1 + (E ** epsilon))
    
    if label == 0:
        new_label = np.random.choice([0,1], p=[probability, 1-probability])
    else:
        new_label = np.random.choice([0,1], p=[1-probability, probability])

    return new_label

### Models

In [5]:
class LogisticRegression(torch.nn.Module):
    def __init__(self, num_features, weight_decay, device):
        super(LogisticRegression, self).__init__()

        self.wd = torch.FloatTensor([weight_decay]).to(device)
        self.w = torch.nn.Parameter(torch.zeros([num_features], requires_grad=True))

    def forward(self, x):
        logits = torch.matmul(x, torch.reshape(self.w, [-1, 1]))

        return logits
    
    def loss(self, logits, y):
        preds = torch.sigmoid(logits)
        loss = -torch.mean(y * log_clip(preds) + (1 - y) * log_clip(1 - preds))

        return loss

In [6]:
def log_clip(x):
    return torch.log(torch.clamp(x, 1e-10, None))

### Influence Calculation Functions


In [7]:
def calc_influence_single(model, epsilon, train_data, test_data, device, rec_depth, r, damp, scale):
    
    s_test_vec = s_test_sample(model, train_data, [test_data[1], test_data[2]], device, rec_depth, r, damp, scale)

    # Calculate the influence function
    train_dataset_size = len(train_data[0].dataset)
    y_perts = []

    for i, y_ in enumerate(train_data[2]):
        y_pert = randomize_resp(y_, epsilon)
        y_perts.append(y_pert)

    time_a = datetime.datetime.now()

    grad_z_vec = grad_training([train_data[1], train_data[2]], y_perts, model, device)

    time_b = datetime.datetime.now()
    time_delta = time_b - time_a
    logging.info(f"Time for grad_z iter:" f" {time_delta.total_seconds() * 1000}")

    with torch.no_grad():
        influence = sum([torch.sum(k * j).data for k, j in zip(grad_z_vec, s_test_vec)])

    return influence.cpu(), y_perts

In [8]:
def s_test_sample(model, train_data, test_data, device, rec_depth, r, damp, scale):

    inverse_hvp = [torch.zeros_like(params, dtype=torch.float) for params in model.parameters()]

    for i in range(r):

        hessian_loader = DataLoader(train_data[0].dataset, sampler=torch.utils.data.RandomSampler(train_data[0].dataset, True, num_samples=rec_depth),batch_size=1,num_workers=4,)

        cur_estimate = s_test(test_data, model, i, hessian_loader, device, damp, scale)

        with torch.no_grad():
            inverse_hvp = [old + (cur / scale) for old, cur in zip(inverse_hvp, cur_estimate)]

    with torch.no_grad():
        inverse_hvp = [component / r for component in inverse_hvp]

    return inverse_hvp

In [9]:
def grad_z(test_data, model, device):

    model.eval()

    test_data_features = test_data[0]
    test_data_labels = test_data[1]

    logits = model(test_data_features)
    prediction = torch.sigmoid(logits)
    loss = -torch.mean(test_data_labels * log_clip(prediction) + (1 - test_data_labels) * log_clip(1 - prediction))
    
    return grad(loss, model.parameters())

In [10]:
def grad_training(train_data, y_perts, model, device):
    
    model.eval()

    train_data_features = torch.FloatTensor(train_data[0]).to(device)
    train_data_labels = torch.FloatTensor(train_data[1]).to(device)
    train_pert_data_labels = torch.FloatTensor(y_perts).to(device)

    logits = model(train_data_features)
    prediction = torch.sigmoid(logits)
    orig_loss = -torch.mean(train_data_labels * log_clip(prediction) + (1 - train_data_labels) * log_clip(1 - prediction))
    pert_loss = -torch.mean(train_pert_data_labels * log_clip(prediction) + (1 - train_pert_data_labels) * log_clip(1 - prediction))
    
    loss = (pert_loss -  orig_loss)
    
    return grad(loss, model.parameters())

In [11]:
def s_test(test_data, model, i, hessian_loader, device, damp, scale):

    v = grad_z(test_data, model, device)
    print(v)
    h_estimate = v

    params, names = make_functional(model)
    params = tuple(p.detach().requires_grad_() for p in params)

    progress_bar = tqdm(hessian_loader, desc=f"IHVP sample {i}")
    for i, (x_train, y_train) in enumerate(progress_bar):

        x_train, y_train = x_train.type(torch.FloatTensor).to(device), y_train.type(torch.FloatTensor).to(device)

        def f(*new_params):
            load_weights(model, names, new_params)
            out = model(x_train)
            loss = calc_loss(out, y_train)
            return loss

        hv = vhp(f, params, tuple(h_estimate), strict=True)[1]

        # Recursively calculate h_estimate
        
        # Influence = grad(x_test)^T*(H_1 + lambda*I)*\sum grad(x_train)
        with torch.no_grad():
            h_estimate = [_v + (1 - damp) * _h_e - _hv / scale for _v, _h_e, _hv in zip(v, h_estimate, hv)]

            if i % 100 == 0:
                norm = sum([h_.norm() for h_ in h_estimate])
                progress_bar.set_postfix({"est_norm": norm.item()})

    with torch.no_grad():
        load_weights(model, names, params, as_params=True)

    return h_estimate


In [12]:
def calc_loss(logits, labels):
    preds = torch.sigmoid(logits)
    loss = -torch.mean(labels * log_clip(preds) + (1 - labels) * log_clip(1 - preds))

    return loss

In [13]:
def make_functional(model):
    orig_params = tuple(model.parameters())

    names = []

    for name, p in list(model.named_parameters()):
        del_attr(model, name.split("."))
        names.append(name)

    return orig_params, names

In [14]:
def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

In [15]:
def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

In [16]:
def load_weights(model, names, params, as_params=False):
    for name, p in zip(names, params):
        if not as_params:
            set_attr(model, name.split("."), p)
        else:
            set_attr(model, name.split("."), torch.nn.Parameter(p))

### Perform Influence Calculation and LOO

In [17]:
class TestLeaveOneOut():
    def test_leave_one_out(self, epsilon, weight_decay, rec_depth, r, scale, damp):

        num_features, train_data, test_data = get_data_adult()

        device = 'cuda:5' if torch.cuda.is_available() else 'cpu'

        train_sample_num = len(train_data[1])

        train_dataset = CreateData(train_data[0], train_data[1])
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)

        # prepare sklearn model to train w
        C = 1.0 / (train_sample_num * weight_decay)
        sklearn_model = linear_model.LogisticRegression(C=C, solver='lbfgs', tol=1e-8, fit_intercept=False)

        # prepare pytorch model to compute influence function
        torch_model = LogisticRegression(num_features, weight_decay, device)

        # train
        sklearn_model.fit(train_data[0], train_data[1])
        pred_logr = sklearn_model.predict(test_data[0])
        score = accuracy_score(test_data[1], pred_logr)
        print(f'lbfgs training took {sklearn_model.n_iter_} iter. Accuracy: {score:0.3f}'  )

        # assign W into pytorch model
        w_opt = sklearn_model.coef_.ravel()
        with torch.no_grad():
            torch_model.w = torch.nn.Parameter(
                torch.tensor(w_opt, dtype=torch.float)
            )
        torch_model = torch_model.to(device)
        
        # calculate original loss
        x_test_input = torch.FloatTensor(test_data[0]).to(device)
        y_test_input = torch.LongTensor(test_data[1]).to(device)

        test_dataset = CreateData(test_data[0], test_data[1])
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
      
        test_loss_ori = torch_model.loss(torch_model(x_test_input), y_test_input).detach().cpu().numpy()

        loss_diff_approx, train_data_perts = calc_influence_single(torch_model, epsilon, [train_loader, train_data[0], train_data[1]], [test_loader, x_test_input, y_test_input], device, rec_depth, r, damp, scale)
        
        loss_diff_approx = - torch.FloatTensor(loss_diff_approx).cpu().numpy()

        # retrain
        sklearn_model_pert = linear_model.LogisticRegression(C=C, fit_intercept=False, tol=1e-8, solver='lbfgs')
        sklearn_model_pert.fit(train_data[0], train_data_perts)
        
        pred_logr = sklearn_model_pert.predict(test_data[0])
        score = accuracy_score(test_data[1], pred_logr)
        print(f'Perturbation lbfgs training took {sklearn_model_pert.n_iter_} iter. Accuracy: {score:0.3f}'  )
       
        w_retrain = sklearn_model_pert.coef_.T.ravel()
        with torch.no_grad():
            torch_model.w = torch.nn.Parameter(
                torch.tensor(w_retrain, dtype=torch.float)
            )

        torch_model = torch_model.to(device)

        # get retrain loss
        test_loss_retrain = torch_model.loss(torch_model(x_test_input), y_test_input).detach().cpu().numpy()

        # get true loss diff
        loss_diff_true = test_loss_retrain - test_loss_ori
        
        est_loss_diff = loss_diff_approx
        avg_loss_diff = loss_diff_true
     
        print('Real avg. loss diff: ', avg_loss_diff, 'Est. avg. loss diff: ', est_loss_diff)
        
        return avg_loss_diff, est_loss_diff

In [18]:
def visualize_result(actual_loss_diff, estimated_loss_diff):
    r2_s = r2_score(actual_loss_diff, estimated_loss_diff)

    max_abs = np.max([np.abs(actual_loss_diff), np.abs(estimated_loss_diff)])
    min_, max_ = -max_abs * 1.1, max_abs * 1.1
    plt.rcParams['figure.figsize'] = 6, 5
    plt.scatter(actual_loss_diff, estimated_loss_diff, zorder=2, s=10)
    plt.title('Loss diff Pert.')
    plt.xlabel('Actual loss diff')
    plt.ylabel('Estimated loss diff')
    range_ = [min_, max_]
    plt.plot(range_, range_, 'k-', alpha=0.2, zorder=1)
    text = 'MAE = {:.03}\nR2 score = {:.03}'.format(mean_absolute_error(actual_loss_diff, estimated_loss_diff),
                                                    r2_s)
    plt.text(max_abs, -max_abs, text, verticalalignment='bottom', horizontalalignment='right')
    plt.xlim(min_, max_)
    plt.ylim(min_, max_)

    plt.savefig("result.png")

    return r2_s

In [19]:
epsilons = [.001, .005, .01, .05, .1, .15, .2, .25, .3, .35, .4, .45, .5, .55, .6, .65, .7, .75, .8, .85, .9, .95, 1, 5, 10]

all_avg = []
all_est = []

weight_decay = 0.01
rec_depth = 5350
r = 5
scale = 10
damp = .01 
rounds = 1

LOO = TestLeaveOneOut()

for ro in range(rounds):
    avg_losses = []
    est_losses = []
    for e in epsilons:
        print('Working on epsilon: ', e)
        avg_loss_diff, est_loss_diff = LOO.test_leave_one_out(e, weight_decay, rec_depth, r, scale, damp)
        avg_losses.append(avg_loss_diff)
        est_losses.append(est_loss_diff)
    all_avg.append(avg_losses)
    all_est.append(est_losses)
    
final_real = [0 for x in range(len(epsilons))]
final_est = [0 for x in range(len(epsilons))]

for i in range(len(final_real)):
    for j in range(len(all_avg)):
        final_real[i] = final_real[i] + all_avg[j][i]
        final_est[i] = final_est[i] + all_est[j][i]



Working on epsilon:  0.001
lbfgs training took [21] iter. Accuracy: 0.868
(tensor([ 2.2450e-03,  3.2799e-03,  2.4591e-03,  4.4891e-03,  4.0567e-03,
         1.8401e-03,  2.5347e-03,  1.1526e-02, -4.1895e-03, -3.0053e-03,
         7.0517e-03,  4.4526e-03, -2.0527e-03,  6.1148e-05,  3.2798e-02,
         2.2055e-04, -1.2353e-02, -4.0237e-04,  9.8094e-05, -1.4418e-03,
         2.0473e-03,  1.0371e-02, -3.8725e-05, -1.3131e-03, -8.3364e-04,
        -5.1374e-03, -7.1490e-05,  1.0301e-02,  6.7954e-04,  2.7084e-03,
         8.2644e-04,  2.7297e-04,  2.9098e-02, -3.7567e-03, -6.2755e-04,
        -6.9781e-03, -3.0594e-03,  3.6934e-03,  2.9833e-04,  2.3242e-03,
        -1.0353e-03,  5.6633e-04,  1.6216e-02,  4.9207e-05,  4.1101e-04,
         3.7966e-04,  5.8032e-05,  1.8531e-04,  5.8761e-05,  5.7860e-05,
         2.4978e-06,  3.1212e-04,  7.9009e-05,  3.1574e-04,  1.9264e-04,
         1.2895e-05,  7.8333e-05,  2.0767e-05,  5.2372e-05,  6.5280e-05,
         5.0709e-04,  1.0439e-04,  7.5282e-05,  1

IHVP sample 0:   1%|▎                                                  | 32/5350 [00:00<01:01, 86.36it/s, est_norm=0.15]


KeyboardInterrupt: 

In [None]:
graph([final_real, final_est, epsilons], 'Actual vs. Estimated Loss per Epsilon', 'epsilon', 'Average Loss', ['Actual', 'Estimated'])

In [None]:
r2_s = visualize_result(final_real, final_est)