In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
import torchvision
from torch.utils.data import Subset
from tqdm import tqdm
from torch.distributions.laplace import Laplace
import math
from ntk_utils import gen_h_dis, gen_alpha, gen_z_embed, process_query

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# the following code is adapted from
# https://github.com/josharnoldjosh/Resnet-Extract-Image-Feature-Pytorch-Python 

# Load the pretrained resnet18 model
model = models.resnet18(pretrained=True)
model = model.to(device)

# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

# Set model to evaluation mode
model.eval()

# Image transforms
scaler = transforms.Resize((224, 224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()

def get_vector(img):
    # 2. Create a PyTorch Variable with the transformed image
    t_img = Variable(normalize(to_tensor(scaler(img))).unsqueeze(0))
    # 3. Create a vector of zeros that will hold our feature vector
    #    The 'avgpool' layer has an output size of 512
    my_embedding = torch.zeros(512).to(device)
    # 4. Define a function that will copy the output of a layer
    def copy_data(m, i, o):
        # my_embedding.copy_(o.data)
        my_embedding.copy_(o.data.reshape(o.data.size(1)))
    # 5. Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)
    # 6. Run the model on our transformed image
    t_img = t_img.to(device)
    model(t_img)
    # 7. Detach our copy function from the layer
    h.remove()

    return my_embedding

# download cifar10 dataset
ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)





Files already downloaded and verified


In [2]:

# mapping from class name to index
class_to_idx_dict = ds.class_to_idx
# mapping from index to class name
idx_to_cls_dict = {}

# collect all class name
class_name_list = []
# collect all class index
class_idx_list = []
for key, value in list(class_to_idx_dict.items()):
    class_name_list.append(key)
    class_idx_list.append(value)
    idx_to_cls_dict[value] = key

# # collect all train index by its class index
# ds_cls_idx_dict = {}
# for cls_idx in class_idx_list:
#     # init with empty list
#     ds_cls_idx_dict[cls_idx] = []

ds_train_idx_by_cls_list = [] 
for cls_idx in class_idx_list:
    ds_train_idx_by_cls_list.append([])

# start collecting
for i in tqdm(range(len(ds))):
    cur_cls_idx = ds[i][1]
    ds_train_idx_by_cls_list[cur_cls_idx].append(i)


100%|██████████| 50000/50000 [00:01<00:00, 34745.58it/s]


In [14]:
######### normalization x data start ###########
def calculate_norm(input_data):
    # input_data: n * d
    square_data = input_data * input_data
    norm_data = square_data.sum(dim=1)
    norm_data = torch.sqrt(norm_data)
    return norm_data

def data_normalization(input_data):
    # input data: n * d
    # print("data normalized")
    x_norm = calculate_norm(input_data)
    x_norm = x_norm[..., None]
    ball_data = input_data / x_norm

    return ball_data
######### normalization x data end ###########


train_num = 1000
test_num = 100
label_num = 10

total_train_idx_list = []
total_test_idx_list = []
for train_idx_list in ds_train_idx_by_cls_list:
    total_train_idx_list += train_idx_list[:train_num]
    total_test_idx_list += train_idx_list[-test_num:]

print(len(total_train_idx_list))
print(len(total_test_idx_list))

############ get img and label tensor according to idx start ###############
def get_img_label_tensor(idx_list):
    img_ts_list = []
    label_list = []
    for idx in tqdm(idx_list):
        image, label = ds[idx]
        img_ts = get_vector(image)
        img_ts_list.append(img_ts)
        label_list.append(label)

    # concat image tensor
    img_ts = torch.stack(img_ts_list, dim=0)

    label_ts = torch.zeros((len(idx_list), label_num), dtype=torch.float32)
    # set the negative label to -1
    label_ts -= 1.0
    for i, label in enumerate(label_list):
        # set corresponding label to 1
        label_ts[i][label] = 1.0
    # print(img_ts.shape)
    # print(len(label_list))

    cls_index_label_ts = torch.tensor(label_list, dtype=torch.int64)

    return img_ts, label_ts, cls_index_label_ts
############ get img and label tensor according to idx end ###############

# img_ts: n * 512
# label_ts: n * 10
# cls_index_label_ts: n * 1
train_img_ts, train_label_ts, train_cls_index_label_ts = get_img_label_tensor(total_train_idx_list)
test_img_ts, test_label_ts, test_cls_index_label_ts = get_img_label_tensor(total_test_idx_list)

train_img_ts = data_normalization(train_img_ts)
test_img_ts = data_normalization(test_img_ts)

m = 256
reg_lambda = 10.0


cpu_device = torch.device("cpu")

x_data = train_img_ts.to(cpu_device)
y_data = train_label_ts.to(cpu_device)
n, d = x_data.shape

# generate w_r
w_r = torch.randn((m, d), dtype=torch.float32).to(cpu_device)
# h_dis: n * n
h_dis = gen_h_dis(w_r, x_data)

print("hdis shape", h_dis.shape)

# calculate NTK Regression alpha
alpha = gen_alpha(h_dis, reg_lambda, y_data)

print("alpha shape", alpha.shape)
    

10000
1000


100%|██████████| 10000/10000 [00:49<00:00, 200.18it/s]
100%|██████████| 1000/1000 [00:05<00:00, 199.84it/s]


hdis shape torch.Size([10000, 10000])
alpha shape torch.Size([10000, 10])


In [19]:
# from ntk_utils import process_10_cls_query
from ntk_utils import gen_z_embed

# cpu_device = torch.device("cpu")

def process_10_cls_query(z, w_r, x_data, alpha):
    # z denote the query, nz denote the query num
    # z: nz * d
    # w_r: m * d
    # x_data: n * d
    # alpha: n * 10
    # return: pred: nz * 1

    # nz * n
    query_embed = gen_z_embed(z, x_data, w_r)

    # nz * 10
    query_pred = query_embed @ alpha

    query_result = torch.argmax(query_pred, dim=1)

    # nz * 1
    return query_result


def test_accuracy_for_10_cls(test_dataset, gt_label, w_r, x_data, alpha):
    # test_dataset: n * 512
    # gt_label: n * 1
    
    # pred: nz * 1
    pred = process_10_cls_query(test_dataset, w_r, x_data, alpha)
    nz = pred.shape[0]
    succ_cnt = torch.sum(pred == gt_label)
    test_acc = succ_cnt / nz
    return test_acc

# test_img_ts = test_img_ts.to(cpu_device)
# test_cls_index_label_ts = test_cls_index_label_ts.to(cpu_device)
# w_r = w_r.to(cpu_device)
# train_img_ts = train_img_ts.to(cpu_device)
# alpha = alpha.to(cpu_device)


unprivate_train_acc = test_accuracy_for_10_cls(train_img_ts, train_cls_index_label_ts, w_r, train_img_ts, alpha)
unprivate_test_acc = test_accuracy_for_10_cls(test_img_ts, test_cls_index_label_ts, w_r, train_img_ts, alpha)

print(unprivate_train_acc, unprivate_test_acc)


tensor(0.8669) tensor(0.8340)


In [22]:



def cal_k(eps, delta, beta):
    eta = 7e-3
    n = 1e3
    k_bound = (eps * eps * eta * eta) / (8 * math.log(1 / delta) * n * n * beta * beta)
    k = int(math.floor(k_bound))
    return k


def gaussain_sampling_on_k(h_dis, y_data, reg_lambda, k=None):
    
    test_acc_list = []
    train_acc_list = []

    repeat_time = 10
    for _ in tqdm(range(repeat_time)):
        if k is None:
            wt_h_dis = h_dis
        else:
            # setup gaussian sampler
            gaussian_sampler = torch.distributions.MultivariateNormal(
                loc=torch.zeros(n).to(device), covariance_matrix=h_dis
            )

            # gausian sampling
            wt_h_dis = torch.empty(k, n, n).to(device)

            for i in range(k):
                sample_vec = gaussian_sampler.sample()
                wt_h_dis[i] = sample_vec[..., None] @ sample_vec[None, ...]

            # take mean over dim k
            # n * n
            wt_h_dis = wt_h_dis.mean(dim=0)

        alpha = gen_alpha(wt_h_dis, reg_lambda, y_data)
        alpha = alpha / n

        cur_train_acc = test_accuracy_for_10_cls(train_img_ts, train_cls_index_label_ts, w_r, train_img_ts, alpha)
        cur_test_acc = test_accuracy_for_10_cls(test_img_ts, test_cls_index_label_ts, w_r, train_img_ts, alpha)

        train_acc_list.append(cur_train_acc)
        test_acc_list.append(cur_test_acc)

        cur_train_mean = sum(train_acc_list) / len(train_acc_list)
        cur_test_mean = sum(test_acc_list) / len(test_acc_list)
        print(cur_train_mean, cur_test_mean)

    final_train_acc = sum(train_acc_list) / len(train_acc_list)
    final_test_acc = sum(test_acc_list) / len(test_acc_list)

    return final_train_acc, final_test_acc


# fix beta and delta
beta = 1e-6
delta = 1e-3

# we run different eps exponent, from (0.5, 1.5)
eps_exponent_list = [0.5 + i * 0.1 for i in range(11)]

for eps_exponent in eps_exponent_list:
    eps = 10 ** eps_exponent
    # calculate number of Gaussian Samples according to eps
    k = cal_k(eps, delta, beta)

    print("-" * 50)

    print(f"eps exponent {eps_exponent}, k {k}")

    # h_dis = h_dis.to(cpu_device)
    # y_data = y_data.to(cpu_device)
    # get private test acc and train acc
    final_test_acc, final_train_acc = gaussain_sampling_on_k(h_dis, y_data, reg_lambda, k)
    
    print("test acc", final_test_acc)
    print("train acc", final_train_acc)

    print("-" * 50)

        

--------------------------------------------------
eps exponent 0.5, k 8


  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:33<05:00, 33.36s/it]

tensor(0.2443) tensor(0.2540)


 20%|██        | 2/10 [01:07<04:31, 33.99s/it]

tensor(0.2418) tensor(0.2455)


 30%|███       | 3/10 [01:41<03:56, 33.78s/it]

tensor(0.3376) tensor(0.3350)


 40%|████      | 4/10 [02:14<03:21, 33.65s/it]

tensor(0.3839) tensor(0.3830)


 50%|█████     | 5/10 [02:48<02:48, 33.69s/it]

tensor(0.4131) tensor(0.4116)


 60%|██████    | 6/10 [03:23<02:16, 34.23s/it]

tensor(0.4448) tensor(0.4420)


 70%|███████   | 7/10 [03:57<01:42, 34.17s/it]

tensor(0.4502) tensor(0.4501)


 80%|████████  | 8/10 [04:31<01:07, 33.97s/it]

tensor(0.4235) tensor(0.4232)


 90%|█████████ | 9/10 [05:08<00:35, 35.04s/it]

tensor(0.4204) tensor(0.4183)


100%|██████████| 10/10 [05:43<00:00, 34.33s/it]


tensor(0.4302) tensor(0.4266)
test acc tensor(0.4302)
train acc tensor(0.4266)
--------------------------------------------------
--------------------------------------------------
eps exponent 0.6, k 14


 10%|█         | 1/10 [00:36<05:27, 36.35s/it]

tensor(0.3983) tensor(0.3990)


 20%|██        | 2/10 [01:11<04:44, 35.62s/it]

tensor(0.4697) tensor(0.4645)


 30%|███       | 3/10 [01:49<04:15, 36.50s/it]

tensor(0.4989) tensor(0.4977)


 40%|████      | 4/10 [02:23<03:35, 35.86s/it]

tensor(0.4720) tensor(0.4710)


 50%|█████     | 5/10 [02:58<02:56, 35.39s/it]

tensor(0.4366) tensor(0.4358)


 60%|██████    | 6/10 [03:33<02:20, 35.16s/it]

tensor(0.4343) tensor(0.4338)


 70%|███████   | 7/10 [04:09<01:46, 35.62s/it]

tensor(0.4440) tensor(0.4417)


 80%|████████  | 8/10 [04:44<01:10, 35.44s/it]

tensor(0.4604) tensor(0.4578)


 90%|█████████ | 9/10 [05:19<00:35, 35.25s/it]

tensor(0.4446) tensor(0.4423)


100%|██████████| 10/10 [05:54<00:00, 35.49s/it]


tensor(0.4538) tensor(0.4527)
test acc tensor(0.4538)
train acc tensor(0.4527)
--------------------------------------------------
--------------------------------------------------
eps exponent 0.7, k 22


 10%|█         | 1/10 [00:35<05:22, 35.85s/it]

tensor(0.5540) tensor(0.5500)


 20%|██        | 2/10 [01:11<04:44, 35.56s/it]

tensor(0.4869) tensor(0.4875)


 30%|███       | 3/10 [01:47<04:11, 35.91s/it]

tensor(0.5540) tensor(0.5467)


 40%|████      | 4/10 [02:23<03:35, 35.95s/it]

tensor(0.5244) tensor(0.5180)


 50%|█████     | 5/10 [02:59<02:59, 35.91s/it]

tensor(0.5111) tensor(0.5088)


 60%|██████    | 6/10 [03:35<02:24, 36.01s/it]

tensor(0.4490) tensor(0.4477)


 70%|███████   | 7/10 [04:11<01:47, 35.83s/it]

tensor(0.4649) tensor(0.4647)


 80%|████████  | 8/10 [04:46<01:11, 35.80s/it]

tensor(0.4759) tensor(0.4741)


 90%|█████████ | 9/10 [05:22<00:35, 35.92s/it]

tensor(0.4741) tensor(0.4712)


100%|██████████| 10/10 [05:58<00:00, 35.84s/it]


tensor(0.4712) tensor(0.4692)
test acc tensor(0.4712)
train acc tensor(0.4692)
--------------------------------------------------
--------------------------------------------------
eps exponent 0.8, k 35


 10%|█         | 1/10 [00:38<05:42, 38.06s/it]

tensor(0.3289) tensor(0.3090)


 20%|██        | 2/10 [01:16<05:05, 38.24s/it]

tensor(0.3591) tensor(0.3525)


 30%|███       | 3/10 [01:54<04:26, 38.00s/it]

tensor(0.4458) tensor(0.4380)


 40%|████      | 4/10 [02:32<03:48, 38.15s/it]

tensor(0.4469) tensor(0.4417)


 50%|█████     | 5/10 [03:10<03:10, 38.12s/it]

tensor(0.4531) tensor(0.4472)


 60%|██████    | 6/10 [03:49<02:33, 38.28s/it]

tensor(0.4783) tensor(0.4707)


 70%|███████   | 7/10 [04:26<01:54, 38.09s/it]

tensor(0.4892) tensor(0.4834)


 80%|████████  | 8/10 [05:03<01:15, 37.56s/it]

tensor(0.4729) tensor(0.4679)


 90%|█████████ | 9/10 [05:41<00:37, 37.61s/it]

tensor(0.4766) tensor(0.4712)


100%|██████████| 10/10 [06:18<00:00, 37.83s/it]


tensor(0.4938) tensor(0.4886)
test acc tensor(0.4938)
train acc tensor(0.4886)
--------------------------------------------------
--------------------------------------------------
eps exponent 0.9, k 55


 10%|█         | 1/10 [00:40<06:00, 40.04s/it]

tensor(0.7190) tensor(0.6830)


 20%|██        | 2/10 [01:20<05:22, 40.30s/it]

tensor(0.6549) tensor(0.6370)


 30%|███       | 3/10 [02:01<04:44, 40.60s/it]

tensor(0.6066) tensor(0.5933)


 40%|████      | 4/10 [02:41<04:02, 40.37s/it]

tensor(0.6445) tensor(0.6325)


 50%|█████     | 5/10 [03:21<03:21, 40.38s/it]

tensor(0.6154) tensor(0.6058)


 60%|██████    | 6/10 [04:02<02:41, 40.42s/it]

tensor(0.6154) tensor(0.6082)


 70%|███████   | 7/10 [04:44<02:02, 40.82s/it]

tensor(0.6085) tensor(0.6014)


 80%|████████  | 8/10 [05:25<01:21, 40.93s/it]

tensor(0.6168) tensor(0.6083)


 90%|█████████ | 9/10 [06:05<00:40, 40.58s/it]

tensor(0.6226) tensor(0.6140)


100%|██████████| 10/10 [06:44<00:00, 40.43s/it]


tensor(0.6226) tensor(0.6154)
test acc tensor(0.6226)
train acc tensor(0.6154)
--------------------------------------------------
--------------------------------------------------
eps exponent 1.0, k 88


 10%|█         | 1/10 [00:43<06:29, 43.25s/it]

tensor(0.6245) tensor(0.6070)


 20%|██        | 2/10 [01:29<05:58, 44.84s/it]

tensor(0.6505) tensor(0.6335)


 30%|███       | 3/10 [02:13<05:12, 44.59s/it]

tensor(0.6600) tensor(0.6497)


 40%|████      | 4/10 [02:58<04:28, 44.73s/it]

tensor(0.6643) tensor(0.6540)


 50%|█████     | 5/10 [03:43<03:43, 44.67s/it]

tensor(0.6573) tensor(0.6456)


 60%|██████    | 6/10 [04:27<02:58, 44.67s/it]

tensor(0.6708) tensor(0.6593)


 70%|███████   | 7/10 [05:12<02:14, 44.84s/it]

tensor(0.6770) tensor(0.6653)


 80%|████████  | 8/10 [05:56<01:29, 44.60s/it]

tensor(0.6770) tensor(0.6654)


 90%|█████████ | 9/10 [06:41<00:44, 44.50s/it]

tensor(0.6691) tensor(0.6587)


100%|██████████| 10/10 [07:25<00:00, 44.57s/it]


tensor(0.6634) tensor(0.6528)
test acc tensor(0.6634)
train acc tensor(0.6528)
--------------------------------------------------
--------------------------------------------------
eps exponent 1.1, k 140


 10%|█         | 1/10 [00:50<07:33, 50.37s/it]

tensor(0.7462) tensor(0.7120)


 20%|██        | 2/10 [01:41<06:46, 50.83s/it]

tensor(0.7218) tensor(0.6910)


 30%|███       | 3/10 [02:32<05:55, 50.80s/it]

tensor(0.7111) tensor(0.6837)


 40%|████      | 4/10 [03:22<05:03, 50.62s/it]

tensor(0.6976) tensor(0.6730)


 50%|█████     | 5/10 [04:12<04:11, 50.36s/it]

tensor(0.7024) tensor(0.6806)


 60%|██████    | 6/10 [05:04<03:23, 50.81s/it]

tensor(0.6999) tensor(0.6795)


 70%|███████   | 7/10 [05:54<02:32, 50.67s/it]

tensor(0.6918) tensor(0.6709)


 80%|████████  | 8/10 [06:45<01:41, 50.89s/it]

tensor(0.6939) tensor(0.6754)


 90%|█████████ | 9/10 [07:37<00:51, 51.15s/it]

tensor(0.6988) tensor(0.6801)


100%|██████████| 10/10 [08:28<00:00, 50.81s/it]


tensor(0.7072) tensor(0.6876)
test acc tensor(0.7072)
train acc tensor(0.6876)
--------------------------------------------------
--------------------------------------------------
eps exponent 1.2000000000000002, k 222


 10%|█         | 1/10 [01:01<09:12, 61.40s/it]

tensor(0.7537) tensor(0.7350)


 20%|██        | 2/10 [02:01<08:03, 60.39s/it]

tensor(0.7582) tensor(0.7340)


 30%|███       | 3/10 [03:01<07:03, 60.54s/it]

tensor(0.7628) tensor(0.7353)


 40%|████      | 4/10 [04:02<06:03, 60.62s/it]

tensor(0.7591) tensor(0.7330)


 50%|█████     | 5/10 [05:03<05:04, 60.87s/it]

tensor(0.7641) tensor(0.7362)


 60%|██████    | 6/10 [06:04<04:03, 60.83s/it]

tensor(0.7476) tensor(0.7238)


 70%|███████   | 7/10 [07:06<03:03, 61.24s/it]

tensor(0.7463) tensor(0.7230)


 80%|████████  | 8/10 [08:06<02:01, 60.88s/it]

tensor(0.7489) tensor(0.7254)


 90%|█████████ | 9/10 [09:09<01:01, 61.41s/it]

tensor(0.7470) tensor(0.7237)


100%|██████████| 10/10 [10:11<00:00, 61.10s/it]


tensor(0.7508) tensor(0.7267)
test acc tensor(0.7508)
train acc tensor(0.7267)
--------------------------------------------------
--------------------------------------------------
eps exponent 1.3, k 352


 10%|█         | 1/10 [01:55<17:17, 115.25s/it]

tensor(0.8141) tensor(0.7860)


 20%|██        | 2/10 [03:51<15:27, 115.91s/it]

tensor(0.7944) tensor(0.7685)


 30%|███       | 3/10 [05:40<13:09, 112.85s/it]

tensor(0.7727) tensor(0.7543)


 40%|████      | 4/10 [07:39<11:30, 115.15s/it]

tensor(0.7647) tensor(0.7448)


 50%|█████     | 5/10 [09:34<09:35, 115.03s/it]

tensor(0.7687) tensor(0.7488)
