In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax.numpy as jnp
import jax
from jax import random, vmap, jit, lax
from src.load_data import load_data
from src.kmeans import SphericalKMeans
from functools import partial
from jax import random
import jax.scipy as jsp
import matplotlib.pyplot as plt

In [65]:
def flatten_axes(a: jax.Array, start: int = 0, end: int = -1) -> jax.Array:
    """Merge the axes ``start`` through ``end`` (inclusive) into a single axis.

    Args:
        a: Array to reshape.
        start: Index of the first axis to merge; axes before it are kept.
        end: Index of the last axis to merge; axes after it are kept.

    Returns:
        ``a`` reshaped so the selected run of axes becomes one axis whose size
        is inferred by ``reshape``.
    """
    kept_before = a.shape[:start]
    # Slice from ``end`` and drop the first entry so that ``end`` itself is
    # merged; this also works for ``end=-1`` (nothing is kept after it).
    kept_after = a.shape[end:][1:]
    return a.reshape((*kept_before, -1, *kept_after))


def get_patched_op(length, patch_radius):
    """Build a batched operator that extracts a square patch around every pixel.

    Args:
        length: Number of patch centres along each of the two leading image
            axes; centres run over ``0 .. length - 1`` on both axes.
        patch_radius: Pixels taken on each side of a centre, so every patch
            has side ``2 * patch_radius + 1``.

    Returns:
        A function vmapped over the leading (batch) axis. For each 3-D image
        it yields an array of shape
        ``(length, length, 2 * patch_radius + 1, 2 * patch_radius + 1, M)``,
        where ``M`` is the size of the image's last axis.
    """
    patch_size = 2 * patch_radius + 1

    def img_to_patches(img):
        # Zero-pad only the two leading axes so border pixels also receive
        # full-size patches; the last axis is left untouched.
        img_new = jnp.pad(
            img, ((patch_radius, patch_radius), (patch_radius, patch_radius), (0, 0))
        )

        def _get_patch(c0, c1):
            # (c0, c1) is the patch centre expressed in padded coordinates.
            tmp = lax.dynamic_slice_in_dim(
                img_new, c0 - patch_radius, patch_size, axis=0
            )
            return lax.dynamic_slice_in_dim(
                tmp, c1 - patch_radius, patch_size, axis=1
            )

        # The inner vmap sweeps the second coordinate and the outer vmap the
        # first, so the result is indexed as patches[row, col].
        get_patch = vmap(_get_patch, (None, 0), 0)
        get_patch = vmap(get_patch, (0, None), 0)
        # Shift the centres by the padding width.
        index = jnp.arange(0, length) + patch_radius
        return get_patch(index, index)

    return vmap(img_to_patches, 0, 0)


def patch_images(images, patch_radius):
    """Replace every pixel by the flattened patch centred on it.

    Takes images laid out as N x C x C x M and returns an array laid out as
    N x C x C x (M * (2 * patch_radius + 1) ** 2).
    """
    side = images.shape[1]
    patches = get_patched_op(side, patch_radius)(images)
    # Merge the trailing (patch row, patch column, channel) axes into one.
    return flatten_axes(patches, start=-3)


def sample_patches(key, images, num_samples):
    """Draw ``num_samples`` patch vectors without replacement.

    Args:
        key: JAX PRNG key used for the draw.
        images: Array laid out as N x C x C x M.
        num_samples: Number of rows to draw.

    Returns:
        Array of shape (num_samples, M).
    """
    # Pool every spatial position of every image into one long list of rows.
    pooled = flatten_axes(images, 0, 2)
    return random.choice(
        key, pooled, shape=(num_samples,), replace=False, axis=0
    )


def train_forward(
    key, images, elem_wise_func, num_centroids=20, num_samples=10_000, num_iter=500
):
    """Fit the centroids of one layer and the matrix derived from them.

    Args:
        key: JAX PRNG key used to sample training patches.
        images: Patch array handed to ``sample_patches``.
        elem_wise_func: Callable applied element-wise to the centroid Gram
            matrix.
        num_centroids: Number of clusters requested from ``SphericalKMeans``.
        num_samples: Number of patches sampled for clustering.
        num_iter: ``max_iter`` passed to ``SphericalKMeans``.

    Returns:
        ``(centroids, transform_mat)``: the transposed first output of
        ``SphericalKMeans.fit`` and the real part of the matrix square root
        of ``elem_wise_func(centroids.T @ centroids)``.
    """
    patches = sample_patches(key, images, num_samples)
    # Clamp the norms from below so all-zero patches do not divide by zero.
    patch_norms = jnp.maximum(
        jnp.linalg.norm(patches, axis=1, keepdims=True), 1e-6
    )
    unit_patches = patches / patch_norms

    kmeans = SphericalKMeans(nb_clusters=num_centroids, max_iter=num_iter)
    centroids = kmeans.fit(unit_patches)[0].T

    gram = elem_wise_func(centroids.T @ centroids)
    # NOTE(review): this is the matrix square root itself, not its inverse;
    # confirm that this is the intended projection.
    transform_mat = jnp.real(jsp.linalg.sqrtm(gram))
    return centroids, transform_mat


def forward_pass(images, element_wise, centroids, transform_mat, sections=32):
    """Apply the trained layer to every vector along the last axis of ``images``.

    Args:
        images: Array whose last axis holds the patch vectors; it is split
            along axis 0 and processed chunk by chunk.
        element_wise: Callable applied to the normalised centroid responses.
        centroids: Matrix whose transpose is multiplied with each patch.
        transform_mat: Matrix applied to the element-wise responses.
        sections: Forwarded to ``jnp.array_split`` to choose the chunking.

    Returns:
        The per-chunk results concatenated back along axis 0.
    """

    @jit
    def _one_pixel(p):
        # Clamp the norm from below so an all-zero patch does not divide by zero.
        norm = jnp.maximum(jnp.linalg.norm(p), 1e-6)
        responses = element_wise((centroids.T @ p) / norm)
        return (norm * transform_mat) @ responses

    chunks = jnp.array_split(images, sections)
    mapped = [jnp.apply_along_axis(_one_pixel, -1, chunk) for chunk in chunks]
    return jnp.concatenate(mapped, axis=0)


def one_layer_gaussian_pooling(images, dilation, sigma):
    """Blur single-channel images with a 9x9 Gaussian kernel and subsample them.

    Args:
        images: Rank-3 array indexed as (image, row, column); a singleton
            channel axis is added internally for the convolution.
        dilation: Despite its name, this value is passed to ``lax.conv`` as
            the window stride on both spatial axes.
        sigma: Scale of the Gaussian sampled at the integers -4 .. 4.

    Returns:
        Rank-3 array indexed as (image, row, column). The image axis is
        always kept, even when it (or a spatial axis) has size 1.
    """
    images = images[:, :, :, jnp.newaxis]
    x = jnp.linspace(-4, 4, 9)
    pdf = jsp.stats.norm.pdf(x, scale=sigma)
    # Separable 2-D Gaussian: outer product of the 1-D density with itself.
    gauss_kernel = pdf * pdf[:, None]
    kernel = gauss_kernel[:, :, jnp.newaxis, jnp.newaxis]
    out = lax.conv(
        jnp.transpose(images, [0, 3, 1, 2]),  # lhs = NCHW image tensor
        jnp.transpose(kernel, [3, 2, 0, 1]),  # rhs = OIHW conv kernel tensor
        (dilation, dilation),  # window strides
        "SAME",  # padding mode
    )
    # Drop only the singleton channel axis. An axis-less squeeze would also
    # remove a batch or spatial axis of size 1 and change the output rank.
    return out[:, 0, :, :]


# Pool every feature channel independently: map over axis 3 of the input,
# share the stride and sigma arguments, and put the mapped axis back at 3.
gaussian_pooling = vmap(
    one_layer_gaussian_pooling, in_axes=(3, None, None), out_axes=3
)


class KNN:
    """Skeleton of the multi-layer model; training and inference are stubs.

    Attributes:
        patch_size: The ``patch_sizes`` argument given at construction, or a
            fresh empty list when it was omitted.
    """

    def __init__(
        self, element_wise_fun, patch_sizes=None, dimensions=None, pooling=None
    ) -> None:
        """Store the layer configuration.

        Args:
            element_wise_fun: Accepted but not used yet.
            patch_sizes: Per-layer patch sizes; defaults to an empty list.
            dimensions: Accepted but not used yet.
            pooling: Accepted but not used yet.
        """
        # A ``None`` sentinel replaces the former ``[]`` defaults, which were
        # shared between all instances created without the argument.
        self.patch_size = [] if patch_sizes is None else patch_sizes

    def train(self, images):
        """Not implemented yet; returns ``None``."""
        pass

    def forward(self, images, batch_size):
        """Not implemented yet; returns ``None``."""
        pass

In [66]:
# Load the three arrays returned by the project loader (presumably training
# images, training labels and test images; confirm against src.load_data) and
# report their shapes.
Xtr, Ytr, Xte = load_data("./data")
print(Xtr.shape, Ytr.shape, Xte.shape)

(5000, 32, 32, 3) (5000,) (2000, 32, 32, 3)


In [68]:
patch_radius = 2
num_centroids = 20
key = random.key(42)
patched_images = patch_images(Xtr, patch_radius)


def alpha_kernel_func(x, alpha):
    """Element-wise kernel ``exp(alpha * (x + 1))``."""
    # NOTE(review): kernels of this family are often written with (x - 1) so
    # that the value at x = 1 is 1; confirm that the sign here is intended.
    return jnp.exp(alpha * (x + 1))


# Bind alpha by keyword. A positional ``partial(alpha_kernel_func, 1)`` fixes
# the first parameter, x, instead, which turns the kernel into exp(2 * value).
kernel_func = partial(alpha_kernel_func, alpha=1)
centroids, transform_mat = train_forward(
    key, patched_images, kernel_func, num_centroids
)

print(centroids.shape, transform_mat.shape)
forward_output = forward_pass(patched_images, kernel_func, centroids, transform_mat)
print(forward_output.shape)
# Stride 2, sigma 1.
pooled_forward_output = gaussian_pooling(forward_output, 2, 1)
print(pooled_forward_output.shape)

(36, 36, 3)
Spherical Kmean main loop starts
Spherical K means at iter 10 :  mean cosine similarity 0.4706
Spherical K means at iter 20 :  mean cosine similarity 0.4739
Spherical K means at iter 30 :  mean cosine similarity 0.4741
Spherical K means at iter 40 :  mean cosine similarity 0.4744
Spherical K means at iter 50 :  mean cosine similarity 0.4745
End of main loop
(75, 20) (20, 20)


In [63]:
def one_layer_gaussian_pooling(images, dilation, sigma):
    """Blur single-channel images with a 9x9 Gaussian kernel and subsample them.

    Args:
        images: Rank-3 array indexed as (image, row, column); a singleton
            channel axis is added internally for the convolution.
        dilation: Despite its name, this value is passed to ``lax.conv`` as
            the window stride on both spatial axes.
        sigma: Scale of the Gaussian sampled at the integers -4 .. 4.

    Returns:
        Rank-3 array indexed as (image, row, column). The image axis is
        always kept, even when it (or a spatial axis) has size 1.
    """
    images = images[:, :, :, jnp.newaxis]
    x = jnp.linspace(-4, 4, 9)
    pdf = jsp.stats.norm.pdf(x, scale=sigma)
    # Separable 2-D Gaussian: outer product of the 1-D density with itself.
    gauss_kernel = pdf * pdf[:, None]
    kernel = gauss_kernel[:, :, jnp.newaxis, jnp.newaxis]
    out = lax.conv(
        jnp.transpose(images, [0, 3, 1, 2]),  # lhs = NCHW image tensor
        jnp.transpose(kernel, [3, 2, 0, 1]),  # rhs = OIHW conv kernel tensor
        (dilation, dilation),  # window strides
        "SAME",  # padding mode
    )
    # Drop only the singleton channel axis. An axis-less squeeze would also
    # remove a batch or spatial axis of size 1 and change the output rank.
    return out[:, 0, :, :]


# Pool every feature channel independently: map over axis 3 of the input,
# share the stride and sigma arguments, and put the mapped axis back at 3.
gaussian_pooling = vmap(
    one_layer_gaussian_pooling, in_axes=(3, None, None), out_axes=3
)

In [64]:
# Smoke-test the pooling on a single feature channel (index 1 of the last
# axis) with stride 2 and sigma 1, then report the resulting shape.
gaussian_out = one_layer_gaussian_pooling(forward_output[:, :, :, 1], 2, 1)
print(gaussian_out.shape)

(5000, 32, 32)
(9, 9, 1, 1)
(5000, 16, 16)


In [62]:
# Report the shape of whatever `gaussian_out` currently holds in the kernel.
print(gaussian_out.shape)

(5000, 16, 16, 20)


In [None]:
# Scratch experiment: strided 7x7 Gaussian convolution applied to all feature
# channels of `forward_output` with a single lax.conv call.
in_channels = forward_output.shape[-1]
sigma = 1
x = jnp.linspace(-3, 3, 7)
# NOTE(review): `dilation_rate` is not used below; the stride is hard-coded
# as (2, 2) in the lax.conv call.
dilation_rate = 2
kernel = jnp.zeros((7, 7, in_channels, in_channels), dtype=jnp.float32)  # HWIO
gauss_kernel = jsp.stats.norm.pdf(x, scale=sigma) * jsp.stats.norm.pdf(
    x[:, None], scale=sigma
)
# NOTE(review): the broadcast fills every (input, output) channel pair with
# the same Gaussian, so each output channel sums the blurred inputs of all
# channels. Confirm this is intended rather than a per-channel blur.
kernel += gauss_kernel[:, :, jnp.newaxis, jnp.newaxis]
print("Input shape: ", forward_output.shape)
print("Kernel shape: ", gauss_kernel.shape)
print("Input shape transposed: ", jnp.transpose(forward_output, [0, 3, 1, 2]).shape)
print("Kernel shape transposed: ", jnp.transpose(kernel, [3, 2, 0, 1]).shape)

out = lax.conv(
    jnp.transpose(forward_output, [0, 3, 1, 2]),  # lhs = NCHW image tensor
    jnp.transpose(kernel, [3, 2, 0, 1]),  # rhs = OIHW conv kernel tensor
    (2, 2),  # window strides
    "SAME",
)  # padding mode
print(out.shape)
print(jnp.transpose(out, [0, 2, 3, 1]).shape)

In [None]:
def on_layer_gaussian_pauling(images, dilation, sigma):
    """Blur NHWC single-channel images with a 9x9 Gaussian and subsample them.

    Args:
        images: Rank-4 array laid out as NHWC. The kernel has one input and
            one output channel, so the channel axis must have size 1.
        dilation: Despite its name, this value is passed to ``lax.conv`` as
            the window stride on both spatial axes.
        sigma: Scale of the Gaussian sampled at the integers -4 .. 4.

    Returns:
        The convolved images, laid out as NHWC.
    """
    x = jnp.linspace(-4, 4, 9)
    pdf = jsp.stats.norm.pdf(x, scale=sigma)
    # Separable 2-D Gaussian: outer product of the 1-D density with itself.
    gauss_kernel = pdf * pdf[:, None]
    # Keep the kernel 4-D (HWIO). It used to be overwritten with the 2-D
    # Gaussian here, which made the 4-axis transpose below fail.
    kernel = gauss_kernel[:, :, jnp.newaxis, jnp.newaxis]
    out = lax.conv(
        jnp.transpose(images, [0, 3, 1, 2]),  # lhs = NCHW image tensor
        jnp.transpose(kernel, [3, 2, 0, 1]),  # rhs = OIHW conv kernel tensor
        (dilation, dilation),  # window strides
        "SAME",  # padding mode
    )
    return jnp.transpose(out, [0, 2, 3, 1])

In [None]:
# Demo cell: build a 3x3 edge-detection kernel, draw three coloured squares
# into a synthetic image, and show the result of a stride-2 convolution.
kernel = jnp.zeros((3, 3, 3, 1), dtype=jnp.float32)
print(kernel.shape)
# The 3x3 stencil is broadcast over the two trailing kernel axes.
kernel += jnp.array([[1, 1, 0], [1, 0, -1], [0, -1, -1]])[
    :, :, jnp.newaxis, jnp.newaxis
]
print(kernel.shape)

print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0])
plt.show()
# NHWC layout
img = jnp.zeros((1, 200, 198, 3), dtype=jnp.float32)
# One 10x10 square per channel, shifted diagonally by 60 pixels each time.
for k in range(3):
    x = 30 + 60 * k
    y = 20 + 60 * k
    img = img.at[0, x : x + 10, y : y + 10, k].set(1.0)

print("Original Image:")
plt.imshow(img[0])
plt.show()

# NOTE(review): `dilation_rate` is not used below; the stride is hard-coded
# as (2, 2) in the lax.conv call.
dilation_rate = 2
# NOTE(review): `lax` is already imported in the notebook's import cell.
from jax import lax

out = lax.conv(
    jnp.transpose(img, [0, 3, 1, 2]),  # lhs = NCHW image tensor
    jnp.transpose(kernel, [3, 2, 0, 1]),  # rhs = OIHW conv kernel tensor
    (2, 2),  # window strides
    "SAME",
)  # padding mode
# Show the first (only) output channel of the first image.
plt.imshow(jnp.array(out)[0, 0, :, :])
plt.show()

In [None]:
# Display the shape of the layer output computed by the training cell.
forward_output.shape

In [None]:
def _one_pixel(p):
    """Map one patch vector through the layer using the notebook-level globals.

    Reads ``kernel_func``, ``centroids_Z`` and ``k_mat`` from the notebook
    namespace.
    """
    # Clamp the norm from below, as the forward_pass function does, so an
    # all-zero patch does not divide by zero and produce NaN.
    norm = jnp.maximum(jnp.linalg.norm(p), 1e-6)
    inter = kernel_func(centroids_Z.T @ p / norm)
    return norm * k_mat @ inter


# NOTE(review): this rebinds the name `forward_pass` to an array, shadowing
# the function of the same name defined earlier in the notebook. It also
# needs `centroids_Z`, `k_mat` and `image_patched`, which are only assigned
# in a later cell.
forward_pass = jnp.apply_along_axis(_one_pixel, -1, image_patched)

In [None]:
# def flatten_axes(a: jax.Array, start: int = 0, end: int = -1) -> jax.Array:
#     return a.reshape(a.shape[:start] + (-1,) + a.shape[end:][1:])

# def sample_patches(key, images, num_patches):
#     #images NxCxCxM
#     # return (Num_patches x M)
#     flat_images = flatten_axes(images, 0,2)
#     samples = random.choice(key,flat_images,shape=(num_patches,),replace=False, axis=0),
#     return samples[0]

# def get_centroids():
#     pass

In [None]:
# Scratch run of the pipeline on a small batch: patch, sample, cluster, and
# build the transform matrix as notebook-level globals.
key = random.PRNGKey(42)
patch_radius = 1
patch_sampling_num = 1000
num_of_centroids = 20
batch_size = 50
channel_len = Xtr.shape[1]
patcher = get_patched_op(channel_len, patch_radius)
image_patched = patcher(Xtr[:batch_size, :, :, :])
# Merge the trailing (patch row, patch column, channel) axes into one.
image_patched = flatten_axes(image_patched, start=-3)
sampled_patches = sample_patches(key, image_patched, patch_sampling_num)
normalized_sampled_patched = sampled_patches / (
    jnp.linalg.norm(sampled_patches, axis=1, keepdims=True) + 1e-6
)
centroids_Z = (
    SphericalKMeans(nb_clusters=num_of_centroids).fit(normalized_sampled_patched)[0].T
)
print(image_patched.shape)
print(centroids_Z.shape)

# Transform matrix: keep only the real part of the matrix square root, as
# train_forward does, so that downstream results stay real-valued.
k_mat = jnp.real(jax.scipy.linalg.sqrtm(kernel_func(centroids_Z.T @ centroids_Z)))
##forward pass

In [None]:
def _one_pixel(p):
    """Map one patch vector through the layer using the notebook-level globals.

    Reads ``kernel_func``, ``centroids_Z`` and ``k_mat`` from the notebook
    namespace.
    """
    # Clamp the norm from below, as the forward_pass function does, so an
    # all-zero patch does not divide by zero and produce NaN.
    norm = jnp.maximum(jnp.linalg.norm(p), 1e-6)
    inter = kernel_func(centroids_Z.T @ p / norm)
    return norm * k_mat @ inter


# NOTE(review): this rebinds the name `forward_pass` to an array, shadowing
# the function of the same name defined earlier in the notebook.
forward_pass = jnp.apply_along_axis(_one_pixel, -1, image_patched)

In [None]:
# Only valid after a scratch cell has rebound `forward_pass` to an array;
# the function of the same name has no `.shape`.
print(forward_pass.shape)

In [None]:
# Smoke test of SphericalKMeans on random normal data.
key = jax.random.PRNGKey(42)
k, key = random.split(key, 2)
skeans = SphericalKMeans(20)
# NOTE(review): the data is 3-D here, whereas train_forward passes a 2-D
# array to fit(). Confirm that fit() supports this input.
data = random.normal(k, shape=(100, 10, 10_000))
# NOTE(review): rebinds the global `centroids` produced by the training cell.
centroids, _ = skeans.fit(data)
print(centroids.shape)

In [None]:
def get_gradient_transform(image_shape, dim_out):
    """Unfinished stub: currently returns ``None``.

    ``image_shape`` is not used yet. ``dim_out`` sets the number of angles
    computed below, but nothing consumes them.
    """
    # dim_out evenly spaced angles in [0, 2*pi); the endpoint is dropped.
    theta = jnp.linspace(0, 2.0 * jnp.pi, dim_out + 1)[:-1]

    def gradient_transform_channel(image, angle):
        # Work in progress: the gradients are computed but not returned, and
        # `angle` is unused.
        dx, dy = jnp.gradient(image)

    pass


def get_shape_transform(image_shape, patch_shape):
    """Unimplemented placeholder: ignores its arguments and returns ``None``."""
    pass

In [None]:
# Scratch constant; no other visible cell references `old_shape`.
old_shape = (32, 32)
# @jit

In [None]:
key = jax.random.PRNGKey(42)
# get_patched_op already returns an operator vmapped over the batch axis.
# Wrapping it in another vmap would hand 2-D slices to the 3-axis jnp.pad.
img_to_patches = get_patched_op(Xtr.shape[1], 2)
img = img_to_patches(Xtr)
# Merge the trailing (patch row, patch column, channel) axes, keeping the
# batch and spatial sizes taken from the data rather than hard-coded.
img = img.reshape(img.shape[:3] + (-1,))
img.shape

In [None]:
# extract_patch = img_new[
#     center[0] : 1 + center[0] + 2 * pad_width,
#     center[1] : 1 + center[1] + 2 * pad_width,
#     :,
# ]

# patch_radius = 2
# img_new = jnp.pad(img, ((pad, pad), (pad, pad), (0, 0)))
# center = (32, 32)
# extract_patch = img_new[
#     center[0] : 1 + center[0] + 2 * pad_width,
#     center[1] : 1 + center[1] + 2 * pad_width,
#     :,
# ]
# extract_patch.shape

In [None]:
# NOTE(review): this looks like a call against an earlier signature. The
# `img_to_patches` defined in this notebook takes only the image, so the
# `patch_radius` keyword is probably not accepted. Confirm or remove the cell.
patches = img_to_patches(img, patch_radius=2)
patches = patches.reshape(32, 32, -1)
patches.shape