In [None]:
import itertools as it
import numpy as np
import jax.numpy as jnp
import jax.random as random
import geometricconvolutions.geometric as geom
import geometricconvolutions.utils as utils
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-wide configuration.
D = 2  # dimension of the torus the images live on
N = 3  # side length of each image (and the N passed to the filter constructor)
img_k = 1  # number of trailing length-D tensor axes on each random image
max_k = 2  # filters are generated for tensor orders 0..max_k
# Operators handed to the invariant-filter search below; presumably the symmetry
# group elements for dimension D -- confirm against geom.make_all_operators.
group_operators = geom.make_all_operators(D)
print(len(group_operators))

In [None]:
# Collect the invariant filters for every tensor order k = 0..max_k and both
# parities, keyed by (D, N, k, parity), together with a display name per filter.
allfilters = {}
names = {}
for k, parity in it.product(range(max_k + 1), (0, 1)):
    key = (D, N, k, parity)
    allfilters[key] = geom.get_unique_invariant_filters(N, k, parity, D, group_operators)
    n = len(allfilters[key])
    names[key] = ["{} {}".format(geom.tensor_name(k, parity), i) for i in range(n)]

# Size of the largest filter family found for this (D, N); never below 0.
maxn = {(D, N): max([0] + [len(family) for family in allfilters.values()])}

In [None]:
# Plot each filter family, one call per (D, N, k, parity) key.
# The key is read into dedicated loop-local names so the notebook-level
# constant D (and k / parity from the cell above) are not rebound as a side
# effect of plotting, which would make later cells depend on execution order.
for filter_key, filters in allfilters.items():
    key_D, key_N = filter_key[:2]
    # maxn[(D, N)] is the size of the largest family for that (D, N).
    utils.plot_filters(filters, names[filter_key], maxn[(key_D, key_N)])

In [None]:
# Flatten the per-(k, parity) filter families into a single list, keeping the
# dictionary's insertion order.
filter_list = list(it.chain.from_iterable(allfilters.values()))
print(len(filter_list))

In [None]:
# Make three N side length, parity=0 geometric vector images on a D-torus.
# Each image is drawn from a fresh subkey split off the running PRNG key
# (one split per image, in order) and then normalized.
key = random.PRNGKey(0)
image_shape = (N,)*D + (D,)*img_k
random_images = []
for image_idx in range(3):
    key, subkey = random.split(key)
    noise = random.normal(subkey, shape=image_shape)
    random_images.append(geom.GeometricImage(noise, 0, D).normalize())
vector_image1, vector_image2, vector_image3 = random_images

In [None]:
def _contraction_index_sets(total_k, target_k):
    """
    List every way to pick disjoint axis pairs that reduce total_k tensor axes to target_k.

    args:
        total_k (int): number of tensor axes before contraction
        target_k (int): number of tensor axes that should remain
    returns:
        list of 1-D integer numpy arrays of length total_k - target_k. Each array is a set
        of (total_k - target_k) // 2 axis pairs, flattened, in which no axis appears twice.
        The order is the one produced by itertools.combinations.
    """
    n_pairs = (total_k - target_k) // 2
    axis_pairs = it.combinations(range(total_k), 2)
    return [
        np.array(pair_set).reshape((total_k - target_k,))
        for pair_set in it.combinations(axis_pairs, n_pairs)
        # keep only sets of disjoint pairs, i.e. no axis is contracted twice
        if len({axis for pair in pair_set for axis in pair}) == 2 * n_pairs
    ]

def getVectorImgs(vector_image, extra_images=()):
    """
    Build a block of flattened images of the form ((A conv c1) * (A conv c2)) conv c3,
    contracted back down to the tensor order of the input image A.

    Filters are taken from the notebook-level filter_list. Every unordered pair (c1, c2),
    every compatible c3, and every set of disjoint axis pairs yields one row. Each extra
    image is pushed through the same filters and contraction, and its flattened data is
    appended to that row after the data derived from vector_image.

    args:
        vector_image (GeometricImage): image whose tensor order k and parity select the
            compatible filter triples and whose shape every contracted result must match
        extra_images (iterable of GeometricImage): further images processed identically;
            assumed to have the same k as vector_image, since the same contraction
            indices are applied to them -- this is not checked
    returns:
        jnp array with one row per (c1, c2, c3, contraction index set)
    """
    images = [vector_image, *extra_images]
    # Convolve each image with each filter exactly once; the loops below only combine these.
    convolved = [[image.convolve_with(c) for c in filter_list] for image in images]
    # The contraction index sets depend only on the tensor orders, so cache them per order.
    index_sets_by_k = {}

    vector_images = []
    for c1_idx, c2_idx in it.combinations_with_replacement(range(len(filter_list)), 2):
        print(c1_idx, c2_idx)  # progress output
        c1 = filter_list[c1_idx]
        c2 = filter_list[c2_idx]

        # conditions suitable for a sequence of kronecker contractions: the total tensor
        # order must be even so axes can be removed in pairs, and the total parity even
        third_filters = [
            c3 for c3 in filter_list
            if (
                ((c1.k + c2.k + c3.k + vector_image.k)%2 == 0) and
                ((c1.parity + c2.parity + c3.parity + vector_image.parity)%2 == 0)
            )
        ]
        if not third_filters:
            continue

        # The product of the two convolutions does not depend on c3, so compute it once.
        products = [per_image[c1_idx] * per_image[c2_idx] for per_image in convolved]

        for c3 in third_filters:
            # One convolution per image; entry 0 always belongs to vector_image.
            filtered = [product.convolve_with(c3) for product in products]
            img = filtered[0]

            if img.k not in index_sets_by_k:
                index_sets_by_k[img.k] = _contraction_index_sets(img.k, vector_image.k)

            for idxs in index_sets_by_k[img.k]:
                idxs = jnp.array(idxs)
                contracted = [image.multicontract(idxs) for image in filtered]
                assert contracted[0].shape() == vector_image.shape()

                # vector_image data first, then each extra image in the order given
                long_image = jnp.concatenate([image.data.flatten() for image in contracted])
                vector_images.append(long_image)
    return jnp.array(vector_images)

In [None]:
# One row per filter triple and contraction; images 2 and 3 are appended to each row.
datablock = getVectorImgs(vector_image1, extra_images=[vector_image2, vector_image3])

In [None]:
# Shape of the stacked array returned by getVectorImgs.
datablock.shape

In [None]:
# Count the linearly independent rows of datablock: the number of singular
# values above the geom.TINY threshold.
print(datablock.shape)
u, s, v = np.linalg.svd(datablock)
n_distinct = np.count_nonzero(s > geom.TINY)
print("there are", n_distinct, "different images")