In [None]:
import sys
sys.path.insert(0,'../src/geometricconvolutions/')
import itertools as it
import numpy as np
import jax.numpy as jnp
import jax.random as random
import geometric as geom
import utils
import pylab as plt
%load_ext autoreload
%autoreload 2

In [None]:
# Spatial dimension used by the rest of the notebook, and the full list of
# group operators that geom.make_all_operators returns for that dimension.
D = 2
group_operators = geom.make_all_operators(D)
print(f"{len(group_operators)}")

In [None]:
allfilters = {}
names = {}
maxn = {}
for M in [3, ]: #filter size, 3x3
    maxn[(D, M)] = 0
    for k in [0,1,2]: #tensor order
        for parity in [0,1]: #parity
            key = (D, M, k, parity)
            allfilters[key] = geom.get_unique_invariant_filters(M, k, parity, D, group_operators)
            n = len(allfilters[key])
            if n > maxn[(D, M)]:
                maxn[(D, M)] = n
            names[key] = ["{} {}".format(geom.tensor_name(k, parity), i) for i in range(n)]

In [None]:
# Plot each (D, M, k, parity) group of invariant filters.
# The key is NOT unpacked into D / M / k / parity: doing so would overwrite the
# notebook-level D and M, and later cells (e.g. the vector_image cell) would
# silently use the last key's dimension if allfilters ever held several D values.
for filter_key, filters in allfilters.items():
    # maxn is keyed by (D, M), i.e. the first two entries of the filter key.
    utils.plot_filters(filters, names[filter_key], maxn[filter_key[:2]])

In [None]:
# Flatten every group of filters into a single list, preserving dict order and
# the order of filters within each group.
filter_list = [flt for group in allfilters.values() for flt in group]
print(len(filter_list))

In [None]:
# Make an N side length, parity=0 geometric vector image on a D-torus
N = 3
key = random.PRNGKey(0)
vector_image = geom.GeometricImage(random.normal(key, shape=((N,)*D + (D,))), 0, D).normalize()
if D == 2:
    utils.plot_image(vector_image)

In [None]:
# For every unordered pair of filters (c1, c2), multiply the two convolutions of
# `vector_image` and contract the product back down to the rank of `vector_image`
# in every admissible way. Each result is a candidate vector image; the next cell
# counts how many of them are linearly independent.

# Previously the Levi-Civita branch was dead code: it sat after a `continue`, so
# parity-mismatched products were always skipped. That stays the default; set
# this to True to run the Levi-Civita contractions explicitly.
USE_LEVI_CIVITA = False


def _pair_contractions(img, target_k):
    """Yield every contraction of disjoint index pairs of `img` down to rank `target_k`.

    Pairings are enumerated in the order of
    ``itertools.combinations(itertools.combinations(range(img.k), 2), n_pairs)``,
    skipping any pairing that reuses an index. Nothing is yielded when
    ``img.k - target_k`` is negative or odd (no sequence of pair contractions
    reaches the target rank).
    """
    n_removed = img.k - target_k
    if n_removed < 0 or n_removed % 2:
        return
    for pairing in it.combinations(it.combinations(range(img.k), 2), n_removed // 2):
        idxs = [i for pair in pairing for i in pair]
        if len(set(idxs)) != len(idxs):
            continue  # overlapping pairs are not a valid set of contractions
        contracted = img
        while idxs:
            idx1, idx2, *idxs = idxs
            contracted = contracted.contract(idx1, idx2)  # could use multi-contract
            # Each contraction removes two axes, so renumber the remaining indices:
            # those above both removed axes shift down by 2, those between by 1.
            smaller_idx, larger_idx = sorted((idx1, idx2))
            idxs = [x - (x > smaller_idx) - (x > larger_idx) for x in idxs]
        yield contracted


# Convolve once per filter instead of once per pair. Assumes convolve_with has no
# side effects (it is only used for its return value here) — TODO confirm.
convolved = [vector_image.convolve_with(flt) for flt in filter_list]

vector_images = []
for c1_idx, c2_idx in it.combinations_with_replacement(range(len(filter_list)), 2):
    c1 = filter_list[c1_idx]
    c2 = filter_list[c2_idx]

    # conditions suitable for a sequence of kronecker contractions
    if (c1.k + c2.k + vector_image.k) % 2 != 0:
        continue
    img = convolved[c1_idx] * convolved[c2_idx]

    if img.parity == vector_image.parity:
        # Kronecker contractions alone bring the product back to vector_image's rank.
        candidates = [img]
    elif USE_LEVI_CIVITA:
        # Contract D - 1 indices with the Levi-Civita symbol first (presumably to
        # flip parity). NOTE(review): the pair contractions below use the rank of
        # the Levi-Civita-contracted image, not img.k — confirm for D != 2.
        candidates = [
            img.levi_civita_contract(levi_idxs)
            for levi_idxs in it.permutations(range(img.k), img.D - 1)
        ]
    else:
        continue

    for candidate in candidates:
        for img_contracted in _pair_contractions(candidate, vector_image.k):
            assert img_contracted.shape() == vector_image.shape()
            vector_images.append(img_contracted.normalize())

In [None]:
# Stack each flattened candidate image as one row of a matrix. The number of
# singular values above geom.TINY is the numerical rank of that matrix, i.e. how
# many linearly independent candidate images were produced.
datablock = np.array([candidate.data.flatten() for candidate in vector_images])
print(datablock.shape)
u, s, v = np.linalg.svd(datablock)
print("there are", (s > geom.TINY).sum(), "different images")