In [3]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
import torchvision
from torch.utils.data import Subset
from tqdm import tqdm
from torch.distributions.laplace import Laplace


from ntk_utils import gen_h_dis, gen_alpha, gen_z_embed, process_query

# Run on the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ImageNet-pretrained ResNet-18, used here only as a feature extractor.
# NOTE(review): torchvision >= 0.13 prefers `weights=models.ResNet18_Weights.IMAGENET1K_V1`
# over the deprecated `pretrained=True`.
model = models.resnet18(pretrained=True).to(device)

# Features are read from the global average-pooling layer through a forward hook
# (see get_vector).
layer = model.avgpool

# Put BatchNorm/Dropout layers into inference behaviour.
model.eval()

# Preprocessing pieces: resize to 224x224, convert to a tensor, then normalise with
# the ImageNet per-channel mean/std.
scaler = transforms.Resize((224, 224))
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def get_vector(img):
    """Return the ResNet-18 'avgpool' activation of one image as a 512-element vector.

    The image is resized to 224x224, converted with ``to_tensor``, normalised with
    the mean/std in ``normalize`` and pushed through ``model`` as a batch of one.
    A forward hook on ``layer`` copies that layer's output into the returned vector.

    Args:
        img: an image accepted by ``scaler``/``to_tensor`` (e.g. a PIL image).
            NOTE(review): assumes 3 channels, matching ``normalize``'s 3-element
            mean/std.

    Returns:
        A 1-D tensor of length 512 on ``device``.
    """
    t_img = normalize(to_tensor(scaler(img))).unsqueeze(0).to(device)
    my_embedding = torch.zeros(512, device=device)

    def copy_data(module, inputs, output):
        # The batch holds a single image, so flattening the activation yields
        # exactly the 512 values that fit my_embedding.
        my_embedding.copy_(output.detach().reshape(-1))

    handle = layer.register_forward_hook(copy_data)
    try:
        # Only the hooked activation is needed; skip building an autograd graph.
        with torch.no_grad():
            model(t_img)
    finally:
        # Detach the hook even if the forward pass fails, so hooks do not pile up
        # on `layer` across calls.
        handle.remove()

    return my_embedding


# CIFAR-10 training split read from ./data; download=False, so it must already exist there.
ds = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=False,
)

def gen_2classes_indices(cls1_name, cls2_name, dataset=None):
    """Return the indices of every sample belonging to two named classes.

    Args:
        cls1_name: name of the first class, looked up in ``dataset.class_to_idx``.
        cls2_name: name of the second class, looked up the same way.
            An unknown name raises KeyError.
        dataset: dataset to scan; defaults to the module-level CIFAR-10 ``ds``.

    Returns:
        (cls1_indices, cls2_indices): ascending lists of sample indices. If both
        names map to the same class, every match goes to ``cls1_indices`` and
        ``cls2_indices`` is empty.
    """
    if dataset is None:
        dataset = ds
    cls1_idx = dataset.class_to_idx[cls1_name]
    cls2_idx = dataset.class_to_idx[cls2_name]

    # Prefer the `.targets` label list: indexing the dataset materialises every
    # image just to read its label. Fall back to indexing when `.targets` is missing
    # or a target_transform could make it differ from `dataset[i][1]`.
    labels = getattr(dataset, "targets", None)
    if labels is None or getattr(dataset, "target_transform", None) is not None:
        labels = [dataset[i][1] for i in range(len(dataset))]

    cls1_indices, cls2_indices = [], []
    for i, label in enumerate(labels):
        if label == cls1_idx:
            cls1_indices.append(i)
        elif label == cls2_idx:
            cls2_indices.append(i)

    return cls1_indices, cls2_indices

def gen_feature_tensor(idx_list):
    """Stack the ``get_vector`` embeddings of the images at ``idx_list`` in ``ds``.

    Args:
        idx_list: indices into the module-level dataset ``ds``.

    Returns:
        A tensor of shape (len(idx_list), 512), one row per index in the given order.
        ``torch.stack`` raises if ``idx_list`` is empty.
    """
    embeddings = [get_vector(ds[idx][0]) for idx in tqdm(idx_list)]
    return torch.stack(embeddings, dim=0)


def test_accuracy(test_dataset, gt_label, w_r, x_data, alpha):
    """Fraction of ``process_query`` predictions on ``test_dataset`` that equal ``gt_label``.

    Args:
        test_dataset: query features passed to ``process_query``.
        gt_label: the single expected label for every query (e.g. 1 or -1).
        w_r, x_data, alpha: model state forwarded unchanged to ``process_query``.

    Returns:
        A 0-d tensor: number of matching predictions divided by ``pred.shape[0]``.
        NOTE(review): assumes ``process_query`` returns a 1-D tensor of predicted
        labels — confirm in ntk_utils.
    """
    predictions = process_query(test_dataset, w_r, x_data, alpha)
    num_correct = (predictions == gt_label).sum()
    return num_correct / predictions.shape[0]




In [17]:
# Experiment setup: ResNet-18 features for two CIFAR-10 classes, then NTK regression.
# The CIFAR-10 training split has 5k images per class.

# Seed the global torch RNGs (all devices) so w_r and the later Gaussian sampling
# are reproducible.
SEED = 0
torch.manual_seed(SEED)

train_num = 1000
test_num = 100
# Regression targets are +label_scale for class 1 and -label_scale for class 2.
label_scale = 1e12

cls1_name = "airplane"
cls2_name = "cat"

cls1_indices, cls2_indices = gen_2classes_indices(cls1_name, cls2_name)

# Train uses the first `train_num` indices of a class and test the last `test_num`;
# they are only disjoint while both fit in the smaller class.
assert train_num + test_num <= min(len(cls1_indices), len(cls2_indices)), \
    "train and test slices would overlap"

cls1_train_ts = gen_feature_tensor(cls1_indices[:train_num]).to(device)
cls2_train_ts = gen_feature_tensor(cls2_indices[:train_num]).to(device)

cls1_label = torch.full((train_num, ), label_scale, dtype=torch.float32)
cls2_label = torch.full((train_num, ), -label_scale, dtype=torch.float32)

# Explicit start index instead of [-test_num:], which would return the whole list
# when test_num == 0.
cls1_test_ts = gen_feature_tensor(cls1_indices[len(cls1_indices) - test_num:]).to(device)
cls2_test_ts = gen_feature_tensor(cls2_indices[len(cls2_indices) - test_num:]).to(device)


############# test on NTK Regression start #################

m = 256           # number of random weight vectors (rows of w_r)
reg_lambda = 10.0 # passed to gen_alpha; presumably the ridge regularisation strength

x_data = torch.cat((cls1_train_ts, cls2_train_ts), dim=0).to(device)
y_data = torch.cat((cls1_label, cls2_label), dim=0).to(device)

n, d = x_data.shape

# Drawn on the CPU and then moved, so the seeded values do not depend on whether a
# GPU is available.
w_r = torch.randn((m, d), dtype=torch.float32).to(device)

h_dis = gen_h_dis(w_r, x_data)

alpha = gen_alpha(h_dis, reg_lambda, y_data)

# may scale down alpha here
alpha = alpha / (n * n)





100%|██████████| 1000/1000 [00:05<00:00, 197.22it/s]
100%|██████████| 1000/1000 [00:04<00:00, 201.71it/s]
100%|██████████| 100/100 [00:00<00:00, 197.86it/s]
100%|██████████| 100/100 [00:00<00:00, 185.70it/s]


In [34]:
def gaussain_sampling_on_k(h_dis, y_data, reg_lambda, cls1_test_ts, cls2_test_ts, k=None, repeat_time=20):
    """Mean train/test accuracy when the kernel is replaced by a k-sample Gaussian estimate.

    Each of ``repeat_time`` trials proceeds as follows:
    - The kernel is ``h_dis`` itself when ``k`` is None. Otherwise it is the mean
      outer product ``(1/k) * sum_i s_i s_i^T`` of ``k`` draws ``s_i ~ N(0, h_dis)``.
    - ``gen_alpha`` is called on that kernel and its result is scaled by ``1 / n**2``.
    - Accuracy on both classes is measured with ``test_accuracy``.

    Args:
        h_dis: (n, n) kernel matrix. When ``k`` is given it is used as the sampler's
            covariance, so it must be symmetric positive definite.
        y_data: training targets passed to ``gen_alpha``.
        reg_lambda: passed to ``gen_alpha``.
        cls1_test_ts: test features of class 1 (expected label 1).
        cls2_test_ts: test features of class 2 (expected label -1).
        k: number of Gaussian samples per kernel estimate, or None for the exact kernel.
        repeat_time: number of independent trials to average.

    Returns:
        (final_test_acc, final_train_acc): per-class-averaged accuracies, averaged
        over trials.

    Raises:
        ValueError: if ``k`` is given and < 1, or ``repeat_time`` < 1.

    NOTE(review): reads the notebook globals ``w_r``, ``x_data``, ``cls1_train_ts``
    and ``cls2_train_ts`` instead of taking them as arguments.
    """
    if k is not None and k < 1:
        raise ValueError(f"k must be a positive integer or None, got {k}")
    if repeat_time < 1:
        raise ValueError(f"repeat_time must be >= 1, got {repeat_time}")

    n = h_dis.shape[0]

    gaussian_sampler = None
    if k is not None:
        # Built once: constructing the distribution factorises the n x n covariance,
        # which is identical for every trial.
        gaussian_sampler = torch.distributions.MultivariateNormal(
            loc=torch.zeros(n, dtype=h_dis.dtype, device=h_dis.device),
            covariance_matrix=h_dis,
        )

    test_acc_list = []
    train_acc_list = []

    for _ in tqdm(range(repeat_time)):
        if gaussian_sampler is None:
            # NOTE(review): every trial repeats the same computation here; averaging
            # only matters if gen_alpha / process_query are non-deterministic.
            wt_h_dis = h_dis
        else:
            # Mean of k outer products s_i s_i^T, computed as S^T S / k with S of shape
            # (k, n). This avoids materialising a (k, n, n) buffer.
            samples = gaussian_sampler.sample((k,))
            wt_h_dis = samples.T @ samples / k

        alpha = gen_alpha(wt_h_dis, reg_lambda, y_data)
        alpha = alpha / (n * n)

        cls1_accuracy = test_accuracy(cls1_test_ts, 1, w_r, x_data, alpha)
        cls2_accuracy = test_accuracy(cls2_test_ts, -1, w_r, x_data, alpha)

        cls1_train_acc = test_accuracy(cls1_train_ts, 1, w_r, x_data, alpha)
        cls2_train_acc = test_accuracy(cls2_train_ts, -1, w_r, x_data, alpha)

        # Weight both classes equally.
        test_acc_list.append((cls1_accuracy + cls2_accuracy) / 2)
        train_acc_list.append((cls1_train_acc + cls2_train_acc) / 2)

    final_test_acc = sum(test_acc_list) / len(test_acc_list)
    final_train_acc = sum(train_acc_list) / len(train_acc_list)

    return final_test_acc, final_train_acc



# Sweep the number of Gaussian samples k used to estimate the kernel.
# (An earlier run used k_list = list(range(5, 51, 5)).)
k_list = [200]
draw_test_acc_list = []
draw_train_acc_list = []

for k in k_list:
    final_test_acc, final_train_acc = gaussain_sampling_on_k(
        h_dis, y_data, reg_lambda, cls1_test_ts, cls2_test_ts, k=k
    )

    print(k)
    print("test acc", final_test_acc)
    print("train acc", final_train_acc)

    draw_test_acc_list.append(final_test_acc)
    draw_train_acc_list.append(final_train_acc)



100%|██████████| 20/20 [00:43<00:00,  2.19s/it]

200
test acc tensor(0.8952, device='cuda:0')
train acc tensor(0.9140, device='cuda:0')





In [27]:
# Reference point: the exact kernel h_dis, no Gaussian sampling (k=None).
final_test_acc, final_train_acc = gaussain_sampling_on_k(
    h_dis, y_data, reg_lambda, cls1_test_ts, cls2_test_ts, k=None
)
print("test acc", final_test_acc)
print("train acc", final_train_acc)

100%|██████████| 10/10 [00:12<00:00,  1.27s/it]

test acc tensor(0.9750, device='cuda:0')
train acc tensor(1., device='cuda:0')





In [29]:
# Convert the per-k accuracies to plain Python floats for printing/plotting.
# float() accepts both 0-d tensors and floats, so re-running this cell is safe;
# calling .item() on an already-converted float would raise AttributeError.
draw_test_acc_list = [float(acc) for acc in draw_test_acc_list]
draw_train_acc_list = [float(acc) for acc in draw_train_acc_list]

In [30]:
# Summary of the k sweep: the k values, then the mean test and train accuracy per k.
# NOTE: the saved output below comes from a run with k_list = list(range(5, 51, 5)),
# not the current k_list.
print(k_list, draw_test_acc_list, draw_train_acc_list, sep="\n")

[5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
[0.7695000171661377, 0.8557500243186951, 0.8659998774528503, 0.8422499895095825, 0.8502499461174011, 0.875999927520752, 0.8510000109672546, 0.8704999089241028, 0.8747499585151672, 0.8942500352859497]
[0.7809500098228455, 0.8646249771118164, 0.873699963092804, 0.8582750558853149, 0.8629249930381775, 0.890500009059906, 0.8567500114440918, 0.8806250691413879, 0.886650025844574, 0.9064000248908997]
