In [None]:
import itertools as it
import numpy as np
import jax.numpy as jnp
import jax.random as random
import geometricconvolutions.geometric as geom
import geometricconvolutions.utils as utils
%load_ext autoreload
%autoreload 2

In [None]:
# Experiment configuration.
D = 2        # dimension of the torus the images live on
N = 3        # image side length (pixels per axis)
img_k = 1    # tensor order of each input pixel (1 -> vector images)
max_k = 2    # filters of tensor order 0..max_k are generated below

group_operators = geom.make_all_operators(D)
n_operators = len(group_operators)
print(n_operators)

In [None]:
# Build the filter groups (presumably those invariant under `group_operators`)
# for side length N, tensor orders 0..max_k and parities 0/1, then plot each group.
allfilters, maxn = geom.get_invariant_filters_dict([N], range(max_k+1), [0,1], D, group_operators)
# Each dict key unpacks as (D, M, k, parity). The loop uses its own names so this
# cell does not overwrite the global `D` (reused later to build the images) or `key`.
for (filter_D, filter_M, filter_k, filter_parity), filters in allfilters.items():
    names = ["{} {}".format(geom.tensor_name(filter_k, filter_parity), i) for i in range(len(filters))]
    utils.plot_filters(filters, names, maxn[(filter_D, filter_M)])

In [None]:
# Flatten all filter groups into a single list, preserving dict iteration order.
filter_list = list(it.chain.from_iterable(allfilters.values()))
print(len(filter_list))

In [None]:
# Make `num_images` random, normalized, parity=0 geometric images on a D-torus.
# Each has side length N and pixel data of shape (N,)*D + (D,)*img_k.
# Number of random images to generate for each supported side length N.
NUM_IMAGES_BY_N = {3: 3, 5: 7}
if N not in NUM_IMAGES_BY_N:
    # Without this check `num_images` would be undefined and the loop below
    # would fail with an unhelpful NameError.
    raise ValueError(
        f"no image count configured for N={N}; supported N: {sorted(NUM_IMAGES_BY_N)}"
    )
num_images = NUM_IMAGES_BY_N[N]

key = random.PRNGKey(0)  # fixed seed so the notebook is reproducible
vector_images = []
for _ in range(num_images):
    key, subkey = random.split(key)
    vector_images.append(
        geom.GeometricImage(random.normal(subkey, shape=((N,)*D + (D,)*img_k)), 0, D).normalize()
    )

In [None]:
def quadratic_filter(img, c1, c2, c3):
    """Apply a quadratic (second-order in ``img``) filter.

    ``img`` is convolved separately with ``c1`` and then ``c2``. The two
    results are combined with ``*`` (presumably the geometric image's
    tensor/outer product; confirm in ``GeometricImage.__mul__``), and that
    product is convolved with ``c3``.

    Args:
        img: image exposing ``convolve_with``.
        c1, c2: filters applied to ``img`` before the product.
        c3: filter applied to the product.

    Returns:
        The result of ``(img * c1) * (img * c2)`` convolved with ``c3``.
    """
    left = img.convolve_with(c1)
    right = img.convolve_with(c2)
    product = left * right
    return product.convolve_with(c3)

In [None]:
def getVectorImgs(vector_image, extra_images=(), verbose=False):
    """Build one data row per (filter triple, index contraction) pair.

    Filter triples ``c1, c2, c3`` are drawn from the global ``filter_list`` as
    distinct-index triples in ``itertools.combinations`` order. A triple is
    kept only when both ``c1.k + c2.k + c3.k + vector_image.k`` and
    ``c1.parity + c2.parity + c3.parity + vector_image.parity`` are even.

    For each kept triple, ``img = quadratic_filter(vector_image, c1, c2, c3)``
    is computed. Then every choice of ``(img.k - vector_image.k) // 2``
    pairwise-disjoint index pairs is passed to ``img.multicontract``. The result
    is asserted to have ``vector_image``'s shape. Its flattened data starts a
    row, and the same triple and contraction applied to each image in
    ``extra_images`` is appended to that row.

    Args:
        vector_image: primary geometric image to filter.
        extra_images: further geometric images processed identically; their
            flattened data extends each row. Defaults to none.
        verbose: if True, print each filter index triple as it is visited.

    Returns:
        A ``jnp`` array with one row per (triple, contraction) pair. It is
        empty if no triple passes the order/parity check.

    NOTE(review): ``combinations`` excludes repeated filters (c1 == c2) and
    reorderings, even though c3 plays a different role from c1/c2 — confirm
    this restriction is intended.
    """
    rows = []
    for c1_idx, c2_idx, c3_idx in it.combinations(range(len(filter_list)), 3):
        if verbose:
            print(c1_idx, c2_idx, c3_idx)
        c1 = filter_list[c1_idx]
        c2 = filter_list[c2_idx]
        c3 = filter_list[c3_idx]

        # Conditions suitable for a sequence of kronecker contractions: the
        # combined tensor order and the combined parity must both be even.
        if (
            (c1.k + c2.k + c3.k + vector_image.k) % 2 != 0
            or (c1.parity + c2.parity + c3.parity + vector_image.parity) % 2 != 0
        ):
            continue

        img = quadratic_filter(vector_image, c1, c2, c3)
        # The extra images depend only on the filter triple, not on the
        # contraction choice, so filter them once per triple.
        extra_filtered = [quadratic_filter(extra_image, c1, c2, c3) for extra_image in extra_images]

        n_pairs = (img.k - vector_image.k) // 2
        for pair_choice in it.combinations(it.combinations(range(img.k), 2), n_pairs):
            flat_idxs = [idx for pair in pair_choice for idx in pair]
            # Pairs that share an index cannot be contracted together; skip them.
            if len(set(flat_idxs)) != len(flat_idxs):
                continue

            img_contracted = img.multicontract(pair_choice)
            assert img_contracted.shape() == vector_image.shape()

            long_image = img_contracted.data.flatten()
            for extra_img in extra_filtered:
                extra_img_data = extra_img.multicontract(pair_choice).data.flatten()
                long_image = jnp.concatenate((long_image, extra_img_data))

            rows.append(long_image)

    return jnp.array(rows)

In [None]:
# The first random image is the primary one; the rest are filtered the same
# way and their data is appended to each row.
vector_image, *extra_images = vector_images
datablock = getVectorImgs(vector_image, extra_images=extra_images)

In [None]:
# Report the total row count, then the number of rows that stay distinct once
# values are rounded to 4 decimal places.
print(datablock.shape)
rounded_rows = jnp.around(datablock, decimals=4)
print(jnp.unique(rounded_rows, axis=0).shape)

In [None]:
# Count singular values above a small absolute threshold. This is the numerical
# rank of the data matrix, i.e. the number of linearly independent images.
# Only the singular values are needed: computing U and V^T (full_matrices by
# default) would allocate a rows x rows U that is never used.
singular_values = jnp.linalg.svd(jnp.unique(datablock, axis=0), compute_uv=False)
print("there are", np.sum(singular_values > 10*geom.TINY), "different images")