In [None]:
from se3_cnn import basis_kernels
from se3_cnn import SO3

In [None]:
# Build a large analytical cube basis for a repr5 -> repr5 kernel on a
# size x size x size grid, report its shape, then sanity-check equivariance.
size = 13
n_radial = 1      # not used by any cell visible here -- kept for reference
upsampling = 1    # not used by any cell visible here -- kept for reference
R_in = SO3.repr5
R_out = SO3.repr5

# Gaussian radial window handed to cube_basis_kernels_analytical.
radial_window_dict = dict(
    radial_window_fct=basis_kernels.gaussian_window_fct_convenience_wrapper,
    radial_window_fct_kwargs=dict(mode='compromise', border_dist=0., sigma=.6),
)

basis = basis_kernels.cube_basis_kernels_analytical(size, R_in, R_out, radial_window_dict)
# Optional step, currently disabled:
# basis = basis_kernels.orthonormalize(basis)

print(basis.shape)

# Equivariance check (the result is this cell's displayed output).
# NOTE(review): 3.14 / 4 looks like an approximate pi/4 rotation angle; the
# meaning of 0.12 and 0.05 is not visible here -- confirm against the
# check_basis_equivariance signature. Also note the representations are passed
# as (R_out, R_in) here but as (R_in, R_out) above; harmless while both are
# repr5, but confirm the expected order before changing either.
basis_kernels.check_basis_equivariance(basis, R_out, R_in, 3.14 / 4, 0.12, 0.05)

In [None]:
# 3D plot
import ipyvolume as ipv

# Isosurface rendering of a single kernel component.
# NOTE(review): assumes ipyvolume rescales [data_min, data_max] onto [0, 1]
# for `level`, so with the symmetric range below 0.5 corresponds to a value
# of 0 and the two iso-levels sit near -/+ 2 * dl * amp -- confirm in the
# ipyvolume documentation.
dl = 0.138   # offset of each iso-level from the midpoint 0.5
amp = 0.002  # display range is the symmetric interval [-amp, amp]
ipv.quickvolshow(
    basis[2, 0, 0],  # base element 2, output component 0, input component 0
    level=[0.5 - dl, 0.5 + dl],
    level_width=0.05,
    opacity=0.05,
    data_min=-amp,
    data_max=amp,
)

In [None]:
# Plot of 2D cuts
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def zheight_to_index(size, zheight):
    """Map a relative height to a valid slice index along a grid axis of length ``size``.

    ``zheight = 0`` selects the central slice ``size // 2``; the offset
    ``round(size / 2 * zheight)`` is then added. The result is clamped to
    ``[0, size - 1]`` so that out-of-range heights select the nearest border
    slice instead of raising ``IndexError`` (e.g. ``zheight=1`` with an even
    ``size``) or silently wrapping to the opposite side through a negative
    index (e.g. ``zheight < -1``).
    """
    index = size // 2 + round(size / 2 * zheight)
    return min(max(index, 0), size - 1)


def plot_kernel(basis, base_element=0, zheight=0):
    """Draw a grid of 2D cuts through one element of a kernel basis.

    Parameters
    ----------
    basis : array-like, indexed as ``basis[element, out, in, z, y, x]``
        Stack of basis kernels. ``shape[1]`` gives the grid rows (output
        components) and ``shape[2]`` the grid columns (input components).
        NOTE(review): the spatial size is read from ``shape[-1]`` only, so
        this assumes a cubic spatial grid -- confirm.
    base_element : int
        Which basis element to plot.
    zheight : float
        Relative position of the cut along the first spatial axis; 0 is the
        central slice and roughly -1 / +1 are the borders. Values outside the
        grid are clamped to the nearest border slice.

    The figure is created with pyplot but not shown; call ``plt.show()``
    afterwards.
    """
    size = basis.shape[-1]
    dim_out = basis.shape[1]
    dim_in = basis.shape[2]
    z_index = zheight_to_index(size, zheight)

    # Color limits come from the whole basis (not just this element), so
    # every panel and every base element share the same color scale.
    mean, std = basis.mean(), basis.std()
    vmin = mean - 2 * std
    vmax = mean + 2 * std

    plt.figure(figsize=(2 * dim_in, 2 * dim_out))
    for i in range(dim_out):
        for j in range(dim_in):
            plt.subplot(dim_out, dim_in, dim_in * i + j + 1)
            plt.imshow(basis[base_element, i, j, z_index, :, :], vmin=vmin, vmax=vmax)
            plt.axis("off")
    plt.tight_layout()
    
# One figure per basis element; the printed index labels the figure that follows.
n_elements = basis.shape[0]
for element_idx in range(n_elements):
    print(element_idx)
    plot_kernel(basis, base_element=element_idx)
    plt.show()