In [1]:
from curvelets import curvelet_transform
import numpy as np
from matplotlib import pyplot as plt
import functools

In [2]:
def make_circle(size, radius):
    """
    Return a ``size`` x ``size`` image of a filled disc centred on the grid.

    Pixel coordinates are ``np.arange(-size//2, size//2) / (size//2)`` along
    each axis. For an even ``size`` they run from -1 (inclusive) to 1
    (exclusive). Pixels with ``x**2 + y**2 < radius**2`` are 1.0, all others
    are 0.0.

    Parameters
    ----------
    size : int
        Side length of the square output, in pixels. ``size // 2`` is used as
        a divisor, so ``size`` should be at least 2.
    radius : float
        Disc radius in the normalised coordinates above.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape ``(size, size)`` containing only 0.0 and 1.0.
    """
    half = size // 2
    # ``-size // 2`` parses as ``(-size) // 2``; together with ``half`` this
    # yields exactly ``size`` samples for both even and odd sizes.
    x = np.arange(-size // 2, half) / half
    x, y = np.meshgrid(x, x)
    # ``np.float`` (an alias of the builtin) was removed in NumPy 1.24;
    # the builtin ``float`` gives the same float64 result.
    return (x**2 + y**2 < radius**2).astype(float)

def test():
    """
    Test of shapes.

    Builds a curvelet transform for ``size_image`` x ``size_image`` images,
    applies it to a random 4-D batch of shape ``(n1, n2, size_image, size_image)``
    and prints the shape of every array the transform returns.
    """
    n1 = 3
    n2 = 4
    n_angles = [2, 4, 6]
    size_image = 256
    curv_trans = curvelet_transform(size_image, n_angles)
    im = np.random.randn(n1, n2, size_image, size_image)
    trans = curv_trans(im)
    # The transform returns an iterable of arrays (presumably one per scale;
    # confirm against curvelet_transform).
    print([arr.shape for arr in trans])
    
def test2():
    """
    Test of the geometrical explainability of the curvelet transform, and of
    its reconstruction.

    Transforms a single 2-D disc image and displays every coefficient map.
    It then displays the original image and its reconstruction side by side
    with a shared colour bar.
    """
    n_angles = [2, 4, 6]
    size_image = 64
    curv_trans = curvelet_transform(size_image, n_angles)
    im = make_circle(size_image, 0.7)
    trans = curv_trans(im)
    for tr in trans:
        # Each returned array is indexed as tr[a, p, :, :] below, i.e. it is
        # 4-D for a 2-D input. The naming suggests (angle, phase, y, x), but
        # that is an assumption to confirm against curvelet_transform.
        for ang in range(tr.shape[0]):
            for phase in range(tr.shape[1]):
                plt.imshow(tr[ang, phase, :, :])
                plt.show()
            print("############################")

    im_rec = curv_trans.inverse(trans)
    # Original (left) and reconstruction (right) in one figure.
    plt.imshow(np.hstack((im, im_rec)))
    plt.colorbar()
    plt.show()

def test3():
    """
    Test of the orthonormality of the transform.

    Compares the total energy (sum of squared norms) of the coefficients with
    the energy of the input batch and prints the relative difference. A value
    near 0 means the transform preserves the L2 norm.
    """
    n1 = 3
    n2 = 4
    size_image = 64
    n_angles = [2, 4, 6]
    im = np.random.randn(n1, n2, size_image, size_image)
    curv_trans = curvelet_transform(size_image, n_angles)
    trans = curv_trans(im)
    # np.linalg.norm with no axis flattens its input, so each term is the
    # squared 2-norm of one whole coefficient array.
    norm_t = sum(np.linalg.norm(arr)**2 for arr in trans)
    norm_i = np.linalg.norm(im)**2
    print((norm_t - norm_i) / norm_i)
    
def test4():
    """
    Test for the separability of the transform and its reconstruction.

    Applies the forward transform to a whole ``(n1, n2, size, size)`` batch
    in one call. It also applies it to every 2-D slice separately. It then
    prints how far apart the two results are. The same comparison is done
    for the inverse transform. Values near 0 mean the transform acts
    independently on each 2-D image of the batch.
    """
    n1 = 3
    n2 = 4
    n_angles = [2, 4, 6]
    size_image = 512
    curv_trans = curvelet_transform(size_image, n_angles)

    im = np.random.randn(n1, n2, size_image, size_image)

    # Forward transform slice by slice: per_slice[i][j] is the list of
    # coefficient arrays returned for the 2-D image im[i, j].
    per_slice = [[curv_trans(im[i, j, :, :]) for j in range(n2)]
                 for i in range(n1)]
    # Regroup so entry k stacks coefficient array k of every slice with
    # leading (n1, n2) axes. This presumably matches the batched layout, which
    # is indexed as c[i, j, ...] further down.
    n_coeffs = len(per_slice[0][0])
    curv_transform1 = [
        np.array([[per_slice[i][j][k] for j in range(n2)] for i in range(n1)])
        for k in range(n_coeffs)
    ]
    # Forward transform on the whole batch at once.
    curv_transform2 = curv_trans(im)
    # zip() would silently drop unmatched arrays; fail loudly instead.
    assert len(curv_transform1) == len(curv_transform2), \
        "batched and per-slice transforms return different numbers of arrays"

    # Ratio of squared norms (squared relative error).
    dist_transforms = sum(np.linalg.norm(c1 - c2)**2
                          for c1, c2 in zip(curv_transform1, curv_transform2))
    norm_transform1 = sum(np.linalg.norm(c1)**2 for c1 in curv_transform1)

    print("Separability :")
    print("    - transform :", dist_transforms/norm_transform1)

    # Inverse slice by slice (from slices of the batched coefficients)
    # versus inverse of the whole batch.
    im_rec1 = np.array([[curv_trans.inverse([c[i, j, ...] for c in curv_transform2])
                         for j in range(n2)]
                        for i in range(n1)])
    im_rec2 = curv_trans.inverse(curv_transform2)
    # NOTE: unlike the transform line above, this is a non-squared relative error.
    print("    - reconstruction : ", np.linalg.norm(im_rec1 - im_rec2)/np.linalg.norm(im_rec1))

    
if __name__ == "__main__":
    # Run the 512x512 separability check only when executed directly.
    # A notebook cell also runs with __name__ == "__main__", so it still
    # executes there; a plain import no longer triggers it.
    test4()
    

Separability :
    - transform : 0.0
    - reconstruction :  0.0
