In [4]:
from jax import random, jit

from itertools import product

import jax.numpy as jnp
import equiformer.tensor_product as tensor_product
import equiformer.spherical as spherical
import numpy as np
from sympy.physics.quantum.cg import CG
from einops import einsum, rearrange, repeat

from IPython.display import display

In [2]:
arrs = tensor_product.generate_cg_matrices(1, 1, 1)

In [3]:
display(arrs)

Array([[[ 0.       ,  0.       ,  0.       ],
        [ 0.       ,  0.       , -0.7071067],
        [ 0.       ,  0.7071067,  0.       ]],

       [[ 0.       ,  0.       ,  0.7071067],
        [ 0.       ,  0.       ,  0.       ],
        [-0.7071067,  0.       ,  0.       ]],

       [[ 0.       , -0.7071067,  0.       ],
        [ 0.7071067,  0.       ,  0.       ],
        [ 0.       ,  0.       ,  0.       ]]], dtype=float32)

#### Creating Random Inputs

In [4]:
key = random.PRNGKey(0)
key, subkey = random.split(key, num=2)

n = 100

coords = random.normal(subkey, shape=(n, 3))

In [7]:
f = spherical.solid_harmonics_jit(coords, 2)
g = spherical.solid_harmonics_jit(coords, 1)

In [8]:
stacked_f = jnp.stack([f, f], axis=-2)
stacked_g = jnp.stack([g, g], axis=-2)

print(stacked_f.shape, stacked_g.shape)

(100, 2, 5) (100, 2, 3)


In [9]:
tps = tensor_product.tensor_product(stacked_f, stacked_g)

In [10]:
print([tp.shape for tp in tps])
concatted = jnp.concatenate([rearrange(tp, "... c m -> ... (c m)") for tp in tps], axis=-1)
print(concatted.shape)

[(100, 4, 3), (100, 4, 5), (100, 4, 7)]
(100, 60)


### Testing new CG

In [33]:
import logging

logging.basicConfig(level=logging.DEBUG)

In [55]:
def generate_large_cg_matrix(ncs_1: list[int], ncs_2: list[int], lmax: int) -> jnp.ndarray:
    """Defines larger Clebsch-Gordan matrix for given angular momenta."""
    # Find the actual maximum angular momentum to calculate
    true_lmax = min(lmax, len(ncs_1) + len(ncs_2) - 2)

    # Define nested list to hold all blocks
    # Format needed for np.block
    cg_mats_all = [[[None] * (true_lmax + 1)] * len(ncs_2)] * len(ncs_1)

    cg_dict = {}

    long_list = []
    for (l1, nc1), (l2, nc2) in product(enumerate(ncs_1), enumerate(ncs_2)):
        # Constructing matrix that will match input channels with output
        match_io = rearrange(np.eye(nc1 * nc2), "(nci ncf) ncif -> nci ncf ncif", nci=nc1, ncf=nc2)

        for l3 in range(0, true_lmax + 1):
            # CG matrix will be trivially zero if l3 is not in the range
            # Choosing to calculate them anyway for simplicity
            cg_mat = tensor_product.generate_cg_matrix(l1, l2, l3)

            # Repeating a cg_mat a number of times for each input channel
            cg_mat_repeated = repeat(cg_mat, "mi mf mo -> nc1 mi nc2 mf mo", nc1=nc1, nc2=nc2)


            # Perform einsum to properly expand cg_mat_repeated
            cg_mat_matching = einsum(
                cg_mat_repeated,
                match_io,
                "nc1 mi nc2 mf mo, nc1 nc2 ncif -> nc1 mi nc2 mf ncif mo",
            )
            # Reshaping to 3 dimensional matrix
            cg_mat_slurp = rearrange(cg_mat_matching, "nc1 mi nc2 mf ncif mo -> (nc1 mi) (nc2 mf) (ncif mo)", nc1=nc1, nc2=nc2, ncif=(nc1 * nc2))
            logging.info(f"{l1}, {l2}, {l3}, {cg_mat_repeated.shape} -> {cg_mat_matching.shape} -> {cg_mat_slurp.shape}")

            # NOTE This line is broken? Shapes are weird?
            # I think the block operation is modifying the shape of the matrix
            cg_mats_all[l1][l2][l3] = cg_mat_slurp
            # Trying to debug list of lists
            long_list.append(cg_mat_slurp)
            cg_dict[(l1, l2, l3)] = cg_mat_slurp

    # TODO Fix this function - currently giving wrong sizes
    return np.block(cg_mats_all), cg_mats_all, long_list, cg_dict

In [52]:
blocked, individual, long_list, cg_dict = generate_large_cg_matrix([0, 1], [0, 1], 2)

INFO:root:0, 0, 0, (0, 1, 0, 1, 1) -> (0, 1, 0, 1, 0, 1) -> (0, 0, 0)
INFO:root:0, 0, 1, (0, 1, 0, 1, 3) -> (0, 1, 0, 1, 0, 3) -> (0, 0, 0)
INFO:root:0, 0, 2, (0, 1, 0, 1, 5) -> (0, 1, 0, 1, 0, 5) -> (0, 0, 0)
INFO:root:0, 1, 0, (0, 1, 1, 3, 1) -> (0, 1, 1, 3, 0, 1) -> (0, 3, 0)
INFO:root:0, 1, 1, (0, 1, 1, 3, 3) -> (0, 1, 1, 3, 0, 3) -> (0, 3, 0)
INFO:root:0, 1, 2, (0, 1, 1, 3, 5) -> (0, 1, 1, 3, 0, 5) -> (0, 3, 0)
INFO:root:1, 0, 0, (1, 3, 0, 1, 1) -> (1, 3, 0, 1, 0, 1) -> (3, 0, 0)
INFO:root:1, 0, 1, (1, 3, 0, 1, 3) -> (1, 3, 0, 1, 0, 3) -> (3, 0, 0)
INFO:root:1, 0, 2, (1, 3, 0, 1, 5) -> (1, 3, 0, 1, 0, 5) -> (3, 0, 0)
INFO:root:1, 1, 0, (1, 3, 1, 3, 1) -> (1, 3, 1, 3, 1, 1) -> (3, 3, 1)
INFO:root:1, 1, 1, (1, 3, 1, 3, 3) -> (1, 3, 1, 3, 1, 3) -> (3, 3, 3)
INFO:root:1, 1, 2, (1, 3, 1, 3, 5) -> (1, 3, 1, 3, 1, 5) -> (3, 3, 5)


In [53]:
blocked.shape

(6, 6, 9)

In [54]:
[[[cg.shape for cg in inner] for inner in outer] for outer in individual]

[[[(3, 3, 1), (3, 3, 3), (3, 3, 5)], [(3, 3, 1), (3, 3, 3), (3, 3, 5)]],
 [[(3, 3, 1), (3, 3, 3), (3, 3, 5)], [(3, 3, 1), (3, 3, 3), (3, 3, 5)]]]

In [45]:
[cg.shape for cg in long_list]

[(0, 0, 0),
 (0, 0, 0),
 (0, 0, 0),
 (0, 3, 0),
 (0, 3, 0),
 (0, 3, 0),
 (3, 0, 0),
 (3, 0, 0),
 (3, 0, 0),
 (3, 3, 1),
 (3, 3, 3),
 (3, 3, 5)]

In [65]:
outer = []
for l1 in range(2):
    middle = []
    for l2 in range(2):
        inner = []
        for l3 in range(3):
            inner.append(cg_dict[(l1, l2, l3)])
        middle.append(inner)
    outer.append(middle)

In [74]:
zerozero = np.concatenate(long_list[0:3], axis=-1)
zeroone = np.concatenate(long_list[3:6], axis=-1)
onezero = np.concatenate(long_list[6:9], axis=-1)
oneone = np.concatenate(long_list[9:12], axis=-1)

print(zerozero.shape, zeroone.shape, onezero.shape, oneone.shape)

(0, 0, 0) (0, 3, 0) (3, 0, 0) (3, 3, 9)


(0, 0, 0)