In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax.numpy as jnp
import jax
from jax import random, vmap, jit, lax
from src.load_data import load_data
from src.kmeans import SphericalKMeans
from functools import partial
from jax import random
import jax.scipy as jsp
import matplotlib.pyplot as plt

In [3]:
def flatten_axes(a: jax.Array, start: int = 0, end: int = -1) -> jax.Array:
    return a.reshape(a.shape[:start] + (-1,) + a.shape[end:][1:])


def get_patched_op(length, patch_radius):
    # @jit
    def img_to_patches(img):
        img_new = jnp.pad(
            img, ((patch_radius, patch_radius), (patch_radius, patch_radius), (0, 0))
        )
        # print(img_new.shape)

        def _get_patch(c0, c1):
            tmp = lax.dynamic_slice_in_dim(
                img_new, c0 - patch_radius, 2 * patch_radius + 1, axis=0
            )
            return lax.dynamic_slice_in_dim(
                tmp, c1 - patch_radius, 2 * patch_radius + 1, axis=1
            )

        get_patch = vmap(_get_patch, (None, 0), 0)
        get_patch = vmap(get_patch, (0, None), 0)
        index = jnp.arange(0, length) + patch_radius
        patches = get_patch(index, index)
        return patches

    return vmap(img_to_patches, 0, 0)


def patch_images(images, patch_radius):
    # take images in the form NxCxCxM
    # return images in the form NxCxCx(M*(2patch_radius+1)^2)
    image_length = images.shape[1]
    patcher = get_patched_op(image_length, patch_radius)
    return flatten_axes(patcher(images), start=-3)


def sample_patches(key, images, num_samples):
    # images NxCxCxM
    # return (Num_patches x M)
    flat_images = flatten_axes(images, 0, 2)
    samples = random.choice(
        key, flat_images, shape=(num_samples,), replace=False, axis=0
    )
    return samples


def train_forward(
    key, images, elem_wise_func, num_centroids=20, num_samples=10_000, num_iter=500
):
    sample_images = sample_patches(key, images, num_samples)
    norms = jnp.linalg.norm(sample_images, axis=1, keepdims=True)
    norms = jnp.where(norms < 1e-6, 1e-6, norms)
    sample_images = sample_images / norms
    centroids = (
        SphericalKMeans(nb_clusters=num_centroids, max_iter=num_iter)
        .fit(sample_images)[0]
        .T
    )
    transform_mat = jnp.real(jsp.linalg.sqrtm(elem_wise_func(centroids.T @ centroids)))
    return centroids, transform_mat


def forward_pass(images, element_wise, centroids, transform_mat, sections=32):
    @jit
    def _one_pixel(p):
        norm = jnp.max(jnp.array([jnp.linalg.norm(p), 1e-6]))
        inter = element_wise(centroids.T @ p / norm)
        return norm * transform_mat @ inter

    batches = jnp.array_split(images, sections)
    output = []
    for batch in batches:
        output.append(jnp.apply_along_axis(_one_pixel, -1, batch))
    return jnp.concatenate(output, axis=0)


def one_layer_gaussian_pooling(images, dilation, sigma):
    images = images[:, :, :, jnp.newaxis]
    x = jnp.linspace(-4, 4, 9)
    gauss_kernel = jsp.stats.norm.pdf(x, scale=sigma) * jsp.stats.norm.pdf(
        x[:, None], scale=sigma
    )
    kernel = gauss_kernel[:, :, jnp.newaxis, jnp.newaxis]
    out = lax.conv(
        jnp.transpose(images, [0, 3, 1, 2]),  # lhs = NCHW image tensor
        jnp.transpose(kernel, [3, 2, 0, 1]),  # rhs = OIHW conv kernel tensor
        (dilation, dilation),  # window strides
        "SAME",
    )  # padding mode
    return jnp.squeeze(jnp.transpose(out, [0, 2, 3, 1]))


gaussian_pooling = vmap(one_layer_gaussian_pooling, (3, None, None), 3)


# patch_radius = 2
# num_centroids = 20
# key = random.key(42)
# patched_images =pooling_sigma
# alpha_kernel_func = lambda x, alpha: jnp.exp(alpha * (x + 1))
# kernel_func = partial(alpha_kernel_func, 1)
# centroids, transform_mat = train_forward(
#     key, patched_images, kernel_func, num_centroids
# )

# print(centroids.shape, transform_mat.shape)
# forward_output = forward_pass(patched_images, kernel_func, centroids, transform_mat)
# print(forward_output.shape)
# pooled_forward_output = gaussian_pooling(forward_output,2,1)
# print(pooled_forward_output.shape)

In [16]:
class KNN:
    def __init__(
        self,
        element_wise_fun,
        patch_sizes=[3, 2],
        centroids_num=[75, 200],
        pooling=[2, 4],
        pooling_sigma=None,
    ) -> None:

        self.element_wise = element_wise_fun
        self.patch_size = patch_sizes
        self.centroids_num = centroids_num
        self.pooling = pooling
        self.pooling_sigma = pooling_sigma
        if pooling_sigma == None:
            pooling_sigma = [i / jnp.sqrt(2) for i in pooling]

        self.pooling_sigma = pooling_sigma
        layer_num = len(patch_sizes)

        if (
            len(centroids_num) != layer_num
            or len(pooling) != layer_num
            or len(pooling_sigma) != layer_num
        ):
            raise ValueError

        self.centroids = []
        self.transform_mat = []

    # patch_radius = 2
    # num_centroids = 20
    # key = random.key(42)
    # patched_images =pooling_sigma
    # alpha_kernel_func = lambda x, alpha: jnp.exp(alpha * (x + 1))
    # kernel_func = partial(alpha_kernel_func, 1)
    # centroids, transform_mat = train_forward(
    #     key, patched_images, kernel_func, num_centroids
    # )

    # print(centroids.shape, transform_mat.shape)
    # forward_output = forward_pass(patched_images, kernel_func, centroids, transform_mat)
    # print(forward_output.shape)
    # pooled_forward_output = gaussian_pooling(forward_output,2,1)
    # print(pooled_forward_output.shape)

    def forward(self, key, images):
        x = images
        keys = jax.random.split(key, len(self.patch_size) + 1)
        for i in range(len(self.patch_size)):
            x = patch_images(x, self.patch_size[i])
            cent, mat = train_forward(
                keys[i], x, self.element_wise, self.centroids_num[i]
            )
            self.centroids.append(cent)
            self.transform_mat.append(mat)
            x = forward_pass(x, self.element_wise, cent, mat, sections=100)
            # So that is fits in memeory
            batches = jnp.array_split(x, 100)
            output = []
            for batch in batches:
                output.append(
                    gaussian_pooling(batch, self.pooling[i], self.pooling_sigma[i])
                )
            x = jnp.concatenate(output, axis=0)
        return x

In [17]:
Xtr, Ytr, Xte = load_data("./data")
print(Xtr.shape, Ytr.shape, Xte.shape)

(5000, 32, 32, 3) (5000,) (2000, 32, 32, 3)


In [18]:
gaus_fun =  lambda x, alpha: jnp.exp(alpha * (x + 1))
element_wise = partial(gaus_fun, 0.35)
myKNN = KNN(element_wise)
key = jax.random.key(12)
output = myKNN.forward(key, Xtr)
output = flatten_axes(output, 1)

Spherical Kmean main loop starts
Spherical K means at iter 10 :  mean cosine similarity 0.4780
Spherical K means at iter 20 :  mean cosine similarity 0.4813
Spherical K means at iter 30 :  mean cosine similarity 0.4820
Spherical K means at iter 40 :  mean cosine similarity 0.4825
Spherical K means at iter 50 :  mean cosine similarity 0.4828
Spherical K means at iter 60 :  mean cosine similarity 0.4829
End of main loop


: 

In [15]:
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_tr = scaler.fit_transform(output)
X_train, X_val, y_train, y_val = train_test_split(
    X_tr, Ytr, test_size=0.2, random_state=42
)
clf = svm.SVC(kernel="rbf", verbose=False)
# Train the classifier on the training data
clf.fit(X_train, y_train)
# Make predictions on the validation data
y_pred = clf.predict(X_val)
# Calculate the accuracy
accuracy = accuracy_score(y_val, y_pred)
print(accuracy)

0.301


## Hyperparameters of the original source code:
*. Kernel elment_wise map sigma=0.25
*. Pooling: sigma=dilation/np.sqrt(2)
*. patch_size [3,2]
*. Subsampling [2,4]
*. Mapdim 12,200