In [1]:
from datasets import *

In [21]:
"""
This code implements functionalities required for computing the NT Kernel for multi-layer
fully-connected neural networks. The computed kernels are saved to the disk. 

The code is written for Python 3.6. 

Inputs: 
	noise_id: The index of the noise intensity: valid range 0 to 14.
	num_layers: The number of layers: valid range {2, 3, 4}.
"""

from __future__ import print_function
import math
import os 
import sys
import time
import numpy as np

from jax import random
from neural_tangents import stax

num_layers = 2
dataset_name = "CIFAR2"
labels = [0, 3]
ratio = 1
X, Y, Xtest, Ytest = load_dataset(dataset_name, labels=labels, ratio=1.0, grayscale=True)
X = X.reshape(X.shape[0], -1)
Xtest = Xtest.reshape(Xtest.shape[0], -1)


if num_layers == 2:
	init_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(1024), stax.Relu(), stax.Dense(1))
elif num_layers == 3:
	init_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(1024), stax.Relu(), stax.Dense(1024), stax.Relu(), stax.Dense(1))
elif num_layers == 4:
	init_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(1024), stax.Relu(), stax.Dense(1024), stax.Relu(), stax.Dense(1024), stax.Relu(), stax.Dense(1))
else:
	raise Exception('Non-valid Kernel')

n = X.shape[0]
print(f"Data points: {n}")
kernel = np.zeros((n, n), dtype=np.float32)
batch_size = 10
m = n / batch_size
m = int(m)
# To avoid memory overflow, for training data, we fill the kernel matrix block by block
for i in range(batch_size):
    for j in range(batch_size):
        print('%d and %d'%(i, j))
        x1 = X[i * m:(i + 1) * m, :].reshape(m, -1)
        x2 = X[j * m:(j + 1) * m, :].reshape(m, -1)
        kernel[i * m:(i + 1) * m, j * m:(j + 1) * m] = kernel_fn(x1, x2, 'ntk')
print(kernel.shape)
directory = './NTK_Kernels/'
directory += f"{dataset_name}/"+f"labels_{labels[0]}_{labels[1]}_ratio_{ratio}/"
if not os.path.exists(directory):
    os.makedirs(directory)
file_name = 'Train_NTK_layers_%d.npy'%(num_layers)
np.save(directory + file_name, kernel)

file_name = 'Test_NTK_layers_%d.npy'%(num_layers)
kernel = kernel_fn(Xtest, X, 'ntk')
np.save(directory + file_name, kernel)

Data points: 10000
0 and 0
0 and 1
0 and 2
0 and 3
0 and 4
0 and 5
0 and 6
0 and 7
0 and 8
0 and 9
1 and 0
1 and 1
1 and 2
1 and 3
1 and 4
1 and 5
1 and 6
1 and 7
1 and 8
1 and 9
2 and 0
2 and 1
2 and 2
2 and 3
2 and 4
2 and 5
2 and 6
2 and 7
2 and 8
2 and 9
3 and 0
3 and 1
3 and 2
3 and 3
3 and 4
3 and 5
3 and 6
3 and 7
3 and 8
3 and 9
4 and 0
4 and 1
4 and 2
4 and 3
4 and 4
4 and 5
4 and 6
4 and 7
4 and 8
4 and 9
5 and 0
5 and 1
5 and 2
5 and 3
5 and 4
5 and 5
5 and 6
5 and 7
5 and 8
5 and 9
6 and 0
6 and 1
6 and 2
6 and 3
6 and 4
6 and 5
6 and 6
6 and 7
6 and 8
6 and 9
7 and 0
7 and 1
7 and 2
7 and 3
7 and 4
7 and 5
7 and 6
7 and 7
7 and 8
7 and 9
8 and 0
8 and 1
8 and 2
8 and 3
8 and 4
8 and 5
8 and 6
8 and 7
8 and 8
8 and 9
9 and 0
9 and 1
9 and 2
9 and 3
9 and 4
9 and 5
9 and 6
9 and 7
9 and 8
9 and 9
(10000, 10000)


In [25]:
from pvr_datasets import *

In [28]:
X, Y = create_boolean_pvr(pointer_label, 3)
torch.save((X, Y), "data/pvr/pointer_3")
X, Y = torch.load("data/pvr/pointer_3")