In [89]:
import numpy as np
import scipy.ndimage

import math
from itertools import combinations, permutations, product
from typing import List, Union

from skimage.measure import block_reduce

import numpy as np
from medimage import image as MEDimage

from scipy.signal import fftconvolve

# !pip install medimage
import SimpleITK as sitk
from scipy.ndimage import convolve as scipy_convolve

from abc import ABC
from medimage import image as MEDimage
import pywt

In [15]:
class image_volume_obj:
    def __init__(self, data: np.ndarray):
        self.data = data
        

In [16]:
def pad_imgs(
            images: np.ndarray,
            padding_length: List,
            axis: List,
            mode: str
            )-> np.ndarray:
    """Apply padding on a 3d images using a 2D padding pattern.

    Args:
        images (ndarray): a numpy array that represent the image.
        padding_length (List): The padding length that will apply on each side of each axe.
        axis (List): A list of axes on which the padding will be done.
        mode (str): The padding mode. Check options here: `numpy.pad 
            <https://numpy.org/doc/stable/reference/generated/numpy.pad.html>`__.

    Returns:
        ndarray: A numpy array that represent the padded image.
    """
    pad_tuple = ()
    j = 1

    for i in range(np.ndim(images)):
        if i in axis:
            pad_tuple += ((padding_length[-j], padding_length[-j]),)
            j += 1
        else:
            pad_tuple += ((0, 0),)

    return np.pad(images, pad_tuple, mode=mode)

def convolve(
        dim: int,
        kernel: np.ndarray,
        images: np.ndarray,
        orthogonal_rot: bool=False,
        mode: str = "symmetric"
    ) -> np.ndarray:
    """Convolve a given n-dimensional array with the kernel to generate a filtered image.

    Args:
        dim (int): The dimension of the images.
        kernel (ndarray): The kernel to use for the convolution.
        images (ndarray): A n-dimensional numpy array that represent a batch of images to filter.
        orthogonal_rot (bool, optional): If true, the 3D images will be rotated over coronal, axial and sagittal axis.
        mode (str, optional): The padding mode. Check options here: `numpy.pad 
            <https://numpy.org/doc/stable/reference/generated/numpy.pad.html>`__.

    Returns:
        ndarray: The filtered image.
    """

    in_size = np.shape(images)

    # We only handle 2D or 3D images.
    assert len(in_size) == 3 or len(in_size) == 4, \
        "The tensor should have the followed shape (B, H, W) or (B, D, H, W)"

    if not orthogonal_rot:
        # If we have a 2D kernel but a 3D images, we squeeze the tensor
        if dim < len(in_size) - 1:
            images = images.reshape((in_size[0] * in_size[1], in_size[2], in_size[3]))

        # We compute the padding size along each dimension
        padding = [int((kernel.shape[-1] - 1) / 2) for _ in range(dim)]
        pad_axis_list = [i for i in range(1, dim+1)]

        # We pad the images and we add the channel axis.
        padded_imgs = pad_imgs(images, padding, pad_axis_list, mode)
        new_imgs = np.expand_dims(padded_imgs, axis=1)

        # Operate the convolution
        if dim < len(in_size) - 1:
            # If we have a 2D kernel but a 3D images, we convolve slice by slice
            result_list = [fftconvolve(np.expand_dims(new_imgs[i], axis=0), kernel, mode='valid') for i in range(len(images))]
            result = np.squeeze(np.stack(result_list), axis=2)

        else :
            result = fftconvolve(new_imgs, kernel, mode='valid')

        # Reshape the data to retrieve the following format: (B, C, D, H, W)
        if dim < len(in_size) - 1:
            result = result.reshape((
                in_size[0], in_size[1], result.shape[1], in_size[2], in_size[3])
            ).transpose(0, 2, 1, 3, 4)

    # If we want orthogonal rotation
    else:
        coronal_imgs = images
        axial_imgs, sagittal_imgs = np.rot90(images, 1, (1, 2)), np.rot90(images, 1, (1, 3))
        
        result_coronal = convolve(dim, kernel, coronal_imgs, False, mode)
        result_axial = convolve(dim, kernel, axial_imgs, False, mode)
        result_sagittal = convolve(dim, kernel, sagittal_imgs, False, mode)

        # split and unflip and stack the result on a new axis
        result_axial = np.rot90(result_axial, 1, (3, 2))
        result_sagittal = np.rot90(result_sagittal, 1, (4, 2))

        result = np.stack([result_coronal, result_axial, result_sagittal])

    return result


In [277]:
class Filtering:
    """Class to handle various filtering operations on imaging data."""

    def __init__(self):
        pass
    
        
    def average_pooling(self, image, pool_size=(2, 2, 2), strides=None):
        """
        Apply 3D average pooling to the input image.

        Parameters:
        image (np.array): 3D array representing the image to be pooled.
        pool_size (tuple): Size of the pooling window (default: (2, 2, 2)).
        strides (tuple): Strides for the pooling operation (default: same as pool_size).

        Returns:
        np.array: Average-pooled 3D image.
        """
        if strides is None:
            strides = pool_size
        depth, height, width = image.shape

        out_depth = (depth - pool_size[0]) // strides[0] + 1
        out_height = (height - pool_size[1]) // strides[1] + 1
        out_width = (width - pool_size[2]) // strides[2] + 1

        pooled_image = np.zeros((out_depth, out_height, out_width))

        for d in range(out_depth):
            for h in range(out_height):
                for w in range(out_width):
                    start_d = d * strides[0]
                    start_h = h * strides[1]
                    start_w = w * strides[2]

                    end_d = start_d + pool_size[0]
                    end_h = start_h + pool_size[1]
                    end_w = start_w + pool_size[2]

                    pooled_image[d, h, w] = np.mean(image[start_d:end_d, start_h:end_h, start_w:end_w])

        return pooled_image

    def max_pooling(self, image, pool_size=(2, 2, 2), strides=None):
        """
        Apply 3D max pooling to the input image.

        Parameters:
        image (np.array): 3D array representing the image to be pooled.
        pool_size (tuple): Size of the pooling window (default: (2, 2, 2)).
        strides (tuple): Strides for the pooling operation (default: same as pool_size).

        Returns:
        np.array: Max-pooled 3D image.
        """
        if strides is None:
            strides = pool_size
        depth, height, width = image.shape

        out_depth = (depth - pool_size[0]) // strides[0] + 1
        out_height = (height - pool_size[1]) // strides[1] + 1
        out_width = (width - pool_size[2]) // strides[2] + 1

        pooled_image = np.zeros((out_depth, out_height, out_width))

        for d in range(out_depth):
            for h in range(out_height):
                for w in range(out_width):
                    start_d = d * strides[0]
                    start_h = h * strides[1]
                    start_w = w * strides[2]

                    end_d = start_d + pool_size[0]
                    end_h = start_h + pool_size[1]
                    end_w = start_w + pool_size[2]

                    pooled_image[d, h, w] = np.max(image[start_d:end_d, start_h:end_h, start_w:end_w])

        return pooled_image
    
    def upsample3d(self, image, output_shape):
        zoom_factors = np.array(output_shape) / np.array(image.shape)
        return scipy.ndimage.zoom(image, zoom_factors, order=3)


    
    def mean_filter(self, ndims: int, size: int, images: np.ndarray, orthogonal_rot: bool = False, padding="symmetric") -> np.ndarray:
        """
        Apply mean filtering to the input images.

        Args:
            ndims (int): Number of dimensions of the kernel filter.
            size (int): An integer that represents the length along one dimension of the kernel.
            padding (str): The padding type that will be used to produce the convolution.
            images (np.ndarray): A n-dimensional numpy array that represents the images to filter.
            orthogonal_rot (bool, optional): If true, the 3D images will be rotated over coronal, axial and sagittal axes.

        Returns:
            np.ndarray: The filtered image.
        """
        assert isinstance(ndims, int) and ndims > 0, "ndims should be a positive integer"
        assert ((size + 1) / 2).is_integer() and size > 0, "size should be a positive odd number."

        # Initialize the kernel as a tensor of zeros
        weight = 1 / np.prod(size ** ndims)
        kernel = np.ones([size for _ in range(ndims)]) * weight
        kernel = np.expand_dims(kernel, axis=(0, 1))

        # Ensure images is at least 4-dimensional (B, W, H, D)
        if images.ndim < 4:
            raise ValueError("Input images must have at least 4 dimensions (B, W, H, D)")

        # Swap the second axis with the last, to convert image B, W, H, D --> B, D, H, W
        image = np.swapaxes(images, 1, 3)
        result = np.squeeze(convolve(ndims, kernel, image, orthogonal_rot, padding), axis=1)
        
        return np.swapaxes(result, 1, 3)
    
    def mean_filtering(self, input_images: Union[np.ndarray, sitk.Image], ndims: int = 3, size: int = 15, orthogonal_rot: bool = False, padding: str = "symmetric") -> np.ndarray:
        """Apply mean filtering to the input images."""
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)

        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)
        
        result = self.mean_filter(ndims, size, input_images, orthogonal_rot, padding)
        
        return np.squeeze(result)
    
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    def log_filter(self, ndims:int, size: int, images: np.ndarray, sigma: float, orthogonal_rot: bool = False, padding="constant") -> np.ndarray:
        """The constructor of the laplacian of gaussian (LoG) filter

        Args:
            ndims (int): Number of dimension of the kernel filter
            size (int): An integer that represent the length along one dimension of the kernel.
            sigma (float): The gaussian standard deviation parameter of the laplacian of gaussian filter
            padding (str): The padding type that will be used to produce the convolution

        Returns:
            None
        """
        assert isinstance(ndims, int) and ndims > 0, "ndims should be a positive integer"
#         assert ((size+1)/2).is_integer() and size > 0, "size should be a positive odd number."
        assert sigma > 0, "alpha should be a positive float."
        self.dim = ndims
        self.size = size
        self.sigma = sigma
        
        def compute_weight(position):
            distance_2 = np.sum(position**2)
            first_part = -1/((2*math.pi)**(self.dim/2) * self.sigma**(self.dim+2))
            second_part = (self.dim - distance_2/self.sigma**2)*math.e**(-distance_2/(2 * self.sigma**2))

            return first_part * second_part

        kernel = np.zeros([self.size for _ in range(self.dim)])

        for k in product(range(self.size), repeat=self.dim):
            kernel[k] = compute_weight(np.array(k)-int((self.size-1)/2))

        kernel -= np.sum(kernel)/np.prod(kernel.shape)
        kernel = np.expand_dims(kernel, axis=(0, 1))
            
        print(kernel.shape)

        image = np.swapaxes(images, 1, 3)
        print(image.shape)
        result = np.squeeze(convolve(ndims, kernel, image, orthogonal_rot, padding), axis=1)
        
        return np.swapaxes(result, 1, 3)
        

    def log_filtering(self, input_images:Union[np.ndarray, sitk.Image], ndims: int = 3, size: int = 15, sigma: int = 3, orthogonal_rot: bool = False, padding: str = "symmetric") -> np.ndarray:
        
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)

        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)
        
        result = self.log_filter(ndims, size, input_images, sigma, orthogonal_rot, padding)
        
        return np.squeeze(result)
        
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   
    def gabor_filter(self, ndims:int, size: int, images: np.ndarray, sigma: float, lamb: float, gamma: float, theta= float, orthogonal_rot: bool = False, padding= "constant") -> np.ndarray:
        """
        The constructor of the Gabor filter. Highly inspired by Ref 1.

        Args:
            size (int): An integer that represent the length along one dimension of the kernel.
            sigma (float): A positive float that represent the scale of the Gabor filter
            lamb (float): A positive float that represent the wavelength in the Gabor filter. (mm or pixel?)
            gamma (float): A positive float that represent the spacial aspect ratio
            theta (float): Angle parameter used in the rotation matrix
            rot_invariance (bool): If true, rotation invariance will be done on the kernel and the kernel
                                   will be rotate 2*pi / theta times.
            padding: The padding type that will be used to produce the convolution

        Returns:
            None
        """
        
        assert ((size + 1) / 2).is_integer() and size > 0, "size should be a positive odd number."
        assert sigma > 0, "sigma should be a positive float"
        assert lamb > 0, "lamb represent the wavelength, so it should be a positive float"
        assert gamma > 0, "gamma is the ellipticity of the support of the filter, so it should be a positive float"

        self.dim = ndims
        self.padding = padding
        self.size = size
        self.sigma = sigma
        self.lamb = lamb
        self.gamma = gamma
        self.theta = theta
        self.rot = orthogonal_rot
        
        def compute_weight(position, theta):
            k_2 = position[0]*math.cos(theta) + position[1] * math.sin(theta)
            k_1 = position[1]*math.cos(theta) - position[0] * math.sin(theta)

            common = math.e**(-(k_1**2 + (self.gamma*k_2)**2)/(2*self.sigma**2))
            real = math.cos(2*math.pi*k_1/self.lamb)
            im = math.sin(2*math.pi*k_1/self.lamb)
            return common*real, common*im
        
        # Rotation invariance
        nb_rot = round(2*math.pi/abs(self.theta)) if self.rot else 1
        real_list = []
        im_list = []

        for i in range(1, nb_rot+1):
            # Initialize the kernel as tensor of zeros
            real_kernel = np.zeros([self.size for _ in range(2)])
            im_kernel = np.zeros([self.size for _ in range(2)])

            for k in product(range(self.size), repeat=2):
                real_kernel[k], im_kernel[k] = compute_weight(np.array(k)-int((self.size-1)/2), self.theta*i)

            real_list.extend([real_kernel])
            im_list.extend([im_kernel])

        kernel = np.expand_dims(np.concatenate((real_list, im_list), axis=0), axis=1)
        # Ensure images is at least 4-dimensional (B, W, H, D)
        if images.ndim < 4:
            raise ValueError("Input images must have at least 4 dimensions (B, W, H, D)")

        image = np.swapaxes(images, 1, 3)

#         result = convolve(self.dim, kernel, image, orthogonal_rot, self.padding)
        result = convolve(ndims, kernel, image, orthogonal_rot, padding)
        # Reshape to get real and imaginary response on the first axis.
        _dim = 2 if orthogonal_rot else 1
        nb_rot = int(result.shape[_dim]/2)
        result = np.stack(np.array_split(result, np.array([nb_rot]), _dim), axis=0)

        # 2D modulus response map
        result = np.linalg.norm(result, axis=0)

        # Rotation invariance.
        result = np.mean(result, axis=2) if orthogonal_rot else np.mean(result, axis=1)

        # Aggregate orthogonal rotation
        result = np.mean(result, axis=0) if orthogonal_rot else result
            
        return np.swapaxes(result, 1, 3)


    def gabor_filtering(self, input_images:Union[np.ndarray, sitk.Image], ndims: int = 3, size: int = 5, sigma: float = 10, lamb: float = 0, gamma: float = 0, theta: float = 0, orthogonal_rot: bool = False, padding: str = "symmetric", average_pooling: bool = False) -> np.ndarray:
        
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)

        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)
        
        result = self.gabor_filter(ndims, size, input_images, sigma, lamb, gamma, theta, orthogonal_rot, padding)
        
        if average_pooling:
            result = np.squeeze(result)
            print(result.shape)
            pooled_result = self.average_pooling(result)
            print(pooled_result.shape)
#             result = self.upsample3d(pooled_result, result.shape)
#             print(result.shape)
        
        return np.squeeze(result)
           
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    def laws_filter(self, ndims: int, config: List[str], images: np.ndarray, energy_distance: int = 7,
                    rot_invariance: bool = False, orthogonal_rot: bool = False, padding: str = "symmetric", energy_image: bool = False) -> np.ndarray:
        """
        Apply Laws filter to the input images.

        Args:
            ndims (int): Number of dimensions for the filter.
            config (List[str]): A list of strings specifying the 1D filters to use.
            images (np.ndarray): The input images to be filtered.
            energy_distance (int): Distance for creating the energy kernel.
            rot_invariance (bool): If true, apply rotation invariance.
            orthogonal_rot (bool): If true, apply orthogonal rotation.
            padding (str): The type of padding to use.

        Returns:
            np.ndarray: The filtered images.
        """
        
        if images.ndim == 3:  # Assuming the input is a 3D image
            images = np.expand_dims(images, axis=0)  # Add batch dimension
        elif images.ndim != 4:  # Must be 4D (batch, channels, depth, height, width)
            raise ValueError("Input images must be a 3D or 4D array.")

        ndims = len(config)
        self.config = config
        self.energy_dist = energy_distance
        self.dim = ndims
        self.padding = padding
        self.rot = rot_invariance
        self.energy_kernel = None

        def __get_filter(name,
                         pad=False) -> np.ndarray:
            """This method create a 1D filter according to the given filter name.

            Args:
                name (float): The filter name. (Such as L3, L5, E3, E5, S3, S5, W5 or R5)
                pad (bool): If true, add zero padding of length 1 each side of kernel L3, E3 and S3

            Returns:
                ndarray: A 1D filter that is needed to construct the Laws kernel.
            """

            if name == "L3":
                ker = np.array([0, 1, 2, 1, 0]) if pad else np.array([1, 2, 1])
                return 1/math.sqrt(6) * ker
            elif name == "L5":
                return 1/math.sqrt(70) * np.array([1, 4, 6, 4, 1])
            elif name == "E3":
                ker = np.array([0, -1, 0, 1, 0]) if pad else np.array([-1, 0, 1])
                return 1 / math.sqrt(2) * ker
            elif name == "E5":
                return 1 / math.sqrt(10) * np.array([-1, -2, 0, 2, 1])
            elif name == "S3":
                ker = np.array([0, -1, 2, -1, 0]) if pad else np.array([-1, 2, -1])
                return 1 / math.sqrt(6) * ker
            elif name == "S5":
                return 1 / math.sqrt(6) * np.array([-1, 0, 2, 0, -1])
            elif name == "W5":
                return 1 / math.sqrt(10) * np.array([-1, 2, 0, -2, 1])
            elif name == "R5":
                return 1 / math.sqrt(70) * np.array([1, -4, 6, -4, 1])
            else:
                raise Exception(f"{name} is not a valid filter name. "
                                "Choose between : L3, L5, E3, E5, S3, S5, W5 or R5")
                
        def __compute_energy_image(self,
                                   images: np.ndarray) -> np.ndarray:
            """Compute the Laws texture energy images as described in (Ref 1).

            Args:
                images (ndarray): A n-dimensional numpy array that represent the filtered images

            Returns:
                ndarray: A numpy multi-dimensional array of the Laws texture energy map.
            """
            # If we have a 2D kernel but a 3D images, we swap dimension channel with dimension batch.
            images = np.swapaxes(images, 0, 1)

            # absolute image intensities are used in convolution
            result = fftconvolve(np.abs(images), self.energy_kernel, mode='valid') 

            if self.dim == 2:
                return np.swapaxes(result, axis1=0, axis2=1)
            else:
                return np.squeeze(result, axis=1)
            
        def create_energy_kernel(energy_dist) -> np.ndarray:
            """Create the kernel that will be used to generate Laws texture energy images

            Returns:
                ndarray: A numpy multi-dimensional arrays that represent the Laws energy kernel.
            """

            # Initialize the kernel as tensor of zeros
            kernel = np.zeros([self.energy_dist*2+1 for _ in range(self.dim)])

            for k in product(range(self.energy_dist*2 + 1), repeat=self.dim):
                position = np.array(k)-self.energy_dist
                kernel[k] = 1 if np.max(abs(position)) <= self.energy_dist else 0

            return np.expand_dims(kernel/np.prod(kernel.shape), axis=(0, 1))
            

        ker_length = np.array([int(name[-1]) for name in self.config])

        pad = not(ker_length.min == ker_length.max)
        
        filter_list = np.array([[__get_filter(name, pad) for name in self.config]])

        if self.rot:
            filter_list = np.concatenate((filter_list, np.flip(filter_list, axis=2)), axis=0)
            prod_list = [prod for prod in product(*np.swapaxes(filter_list, 0, 1))]

            perm_list = []
            for i in range(len(prod_list)):
                perm_list.extend([perm for perm in permutations(prod_list[i])])

            filter_list = np.unique(perm_list, axis=0)

        kernel_list = []
        for perm in filter_list:
            kernel = perm[0]
            shape = kernel.shape

            for i in range(1, len(perm)):
                sub_kernel = perm[i]
                shape += np.shape(sub_kernel)
                kernel = np.outer(sub_kernel, kernel).reshape(shape)
            if self.dim == 3:
                kernel_list.extend([np.expand_dims(np.flip(kernel, axis=(1, 2)), axis=0)])
            else:
                kernel_list.extend([np.expand_dims(np.flip(kernel, axis=(0, 1)), axis=0)])

        kernel = np.unique(kernel_list, axis=0)
        
        energy_kernel = create_energy_kernel(self.energy_dist)
        
        print(images.shape)
        print(kernel.shape)
        
                    
        images = np.swapaxes(images, 1, 3)

        if orthogonal_rot:
            raise NotImplementedError

        result = convolve(self.dim, kernel, images, orthogonal_rot, self.padding)
        result = np.amax(result, axis=1) if self.dim == 2 else np.amax(result, axis=0)
        
        if energy_image:
            # We pad the response map
            result = np.expand_dims(result, axis=1) if self.dim == 3 else result
            ndims = len(result.shape)

            padding = [self.energy_dist for _ in range(2 * self.dim)]
            pad_axis_list = [i for i in range(ndims - self.dim, ndims)]

            response = pad_imgs(result, padding, pad_axis_list, self.padding)

            # Free memory
            del result

            # We compute the energy map and we squeeze the second dimension of the energy maps.
            energy_imgs = self.__compute_energy_image(response)

            return np.swapaxes(energy_imgs, 1, 3)
        else:
            return np.swapaxes(result, 1, 3)


    def laws_filtering(self, input_images: Union[np.ndarray, sitk.Image], ndims: int = 3, config = ['E5', 'L5', 'S5'], energy_distance: int = 7, rot_invariance: bool = False, orthogonal_rot: bool = False, padding: str = "constant", energy_image = False, max_pooling: bool = False) -> np.ndarray:
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)

        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)
        
        result = self.laws_filter(ndims, config, input_images, energy_distance, rot_invariance, orthogonal_rot=False, padding="constant", energy_image=False)

        if max_pooling:
            result = np.squeeze(result)
            print(result.shape)
            pooled_result = self.max_pooling(result)
            print(pooled_result.shape)
            result = self.upsample3d(pooled_result, result.shape)
            print(result.shape)
        
        return np.squeeze(result)
    
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    def wavelet_filter(self, ndims: int, size: int, images: np.ndarray, wavelet_name: str, rot_invariance: bool = False, padding: str = "symmetric", level:int = 1, wt_filter: str = "LHL") -> np.ndarray:
        """
        Apply Laws filter to the input images.

        Args:
            ndims (int): Number of dimensions for the filter.
            config (List[str]): A list of strings specifying the 1D filters to use.
            images (np.ndarray): The input images to be filtered.
            energy_distance (int): Distance for creating the energy kernel.
            rot_invariance (bool): If true, apply rotation invariance.
            orthogonal_rot (bool): If true, apply orthogonal rotation.
            padding (str): The type of padding to use.

        Returns:
            np.ndarray: The filtered images.
        """
        self.dim = ndims
        self.padding = padding
        self.rot = rot_invariance
        
        _filter = wt_filter
        wavelet = pywt.Wavelet(wavelet_name)
        
        self.wavelet = wavelet
        kernel_length = max(wavelet.rec_len, wavelet.dec_len)
        
        image_shape = np.shape(images[0])
        level = 1
        padding = []
        ker_length = kernel_length * level
        
        for l in image_shape:
            padded_length = math.ceil((l + 2*(ker_length-1)) / 2**level) * 2**level - l
            padding.extend([math.floor(padded_length/2), math.ceil(padded_length/2)])
            
        axis_list = [i for i in range(0, self.dim)]

        pad_tuple = ()
        j=0
        for i in range(np.ndim(images[0])):
            if i in axis_list:
                pad_tuple += ((padding[j], padding[j+1]),) 
                j +=2
            else:
                pad_tuple += ((0,0),)
        
#         print(pad_tuple)
#         print(images.shape)
        arr = np.pad(images[0], pad_tuple, mode=self.padding)
        
        images_padded = np.expand_dims(arr, axis=0)
        
        _index = str().join(['a' if _filter[i] == 'L' else 'd' for i in range(len(_filter))])

        if self.rot:
            result = []
            _index_list = np.unique([str().join(perm) for perm in permutations(_index, self.dim)])

            # For each images, we flip each axis.
            for image in images:
                axis_rot = [comb for j in range(self.dim+1) for comb in combinations(np.arange(self.dim), j)]
                images_rot = [np.flip(image, axis) for axis in axis_rot]

                res_rot = []
                for i in range(len(images_rot)):
                    filtered_image = pywt.swtn(images_rot[i], self.wavelet, level=level)[0]
                    res_rot.extend([np.flip(filtered_image[j], axis=axis_rot[i]) for j in _index_list])

                result.extend([np.mean(res_rot, axis=0)])
        else:
            result = []
            for i in range(len(images)):
                coeffs = pywt.swtn(images[i], wavelet, level=level)
                filtered_image = coeffs[level - 1]['dad']
                result.append(filtered_image)
                
#             return np.array(result)

        return np.swapaxes(result, 1, 2)

            
    def wavelet_filtering(self, input_images: Union[np.ndarray, sitk.Image], ndims: int = 3, size: int = 5, wavelet_name: str ="haar", rot_invariance: bool = False, padding: str = "constant", level: int = 1, average_pooling: bool = False, wt_filter: str = "LHL") -> np.ndarray:
        
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)

        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)
        
        result = self.wavelet_filter(ndims, size, input_images, wavelet_name, rot_invariance, padding, level, wt_filter)

        if average_pooling:
            result = np.squeeze(result)
            print(result.shape)
            pooled_result = self.average_pooling(result)
            print(pooled_result.shape)
#             result = self.upsample3d(pooled_result, result.shape)
#             print(result.shape)

        return np.squeeze(result)
        
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    def riesz_transform(self, image: np.ndarray, l: tuple, aligned_str_tensor: bool = False, sigma_tensor: float = None) -> np.ndarray:
        """Compute Riesz transform of an input image."""
        fft_image = np.fft.fftn(image)

        if image.ndim < 3:
            raise ValueError("Input image must have at least three dimensions.")

        nx, ny, nz = image.shape[-3:]

        kx = np.fft.fftfreq(nx).reshape(-1, 1, 1)
        ky = np.fft.fftfreq(ny).reshape(1, -1, 1)
        kz = np.fft.fftfreq(nz).reshape(1, 1, -1)

        if aligned_str_tensor:
            gradient_filter = sitk.GradientRecursiveGaussianImageFilter()
            gradient_filter.SetSigma(sigma_tensor)
            gradient = gradient_filter.Execute(sitk.GetImageFromArray(image))

            gradient_np = sitk.GetArrayFromImage(gradient)
            # Verify gradient_np shape here: (3, nx, ny, nz)

            J = np.zeros((3, 3, *gradient_np.shape[:3]))

            for i in range(3):
                for j in range(3):
                    J[i, j] = gradient_np[..., i] * gradient_np[..., j]

            # Verify J tensor shape here: (3, 3, nx, ny, nz)

            tensor_eigenvalues, tensor_eigenvectors = np.linalg.eigh(J)
            dominant_eigenvector = tensor_eigenvectors[..., -1]  # Select the dominant eigenvector
            # Adjust transpose or reshape as needed:
            dominant_eigenvector = dominant_eigenvector.transpose(2, 0, 1)  # Or try other combinations

            kx = kx * dominant_eigenvector[0] + ky * dominant_eigenvector[1] + kz * dominant_eigenvector[2]
            ky = kx * dominant_eigenvector[1] + ky * dominant_eigenvector[1] + kz * dominant_eigenvector[2]
            kz = kx * dominant_eigenvector[2] + ky * dominant_eigenvector[2] + kz * dominant_eigenvector[2]

        riesz_component = l[0] * (1j * kx) + l[1] * (1j * ky) + l[2] * (1j * kz)
        riesz_transformed_fft = riesz_component * fft_image
        riesz_transformed_image = np.fft.ifftn(riesz_transformed_fft).real

        return riesz_transformed_image

    def riesz_filtering(self, input_images: np.ndarray, l: tuple, aligned_str_tensor: bool = False, sigma_tensor: float = None) -> np.ndarray:
        """Apply Riesz transform to input images."""
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)
        
        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)

        result = self.riesz_transform(input_images, l, aligned_str_tensor, sigma_tensor)
        
        return np.squeeze(result)
    
    def riesz_then_log_filtering(self, input_images, l: tuple, aligned_str_tensor: bool = False, sigma_tensor: float = None, ndims: int = 3, size: int = 15, sigma: float = 5.0, padding: str = "symmetric") -> np.ndarray:
        riesz_image = self.riesz_filtering(input_images, l, aligned_str_tensor, sigma_tensor)
        
        riesz_image = sitk.GetImageFromArray(riesz_image)
        riesz_image.CopyInformation(input_images)
        
        
        log_filtered_image = self.log_filtering(riesz_image, ndims, size, sigma, padding)
        
        return log_filtered_image
    
 

In [284]:
if __name__ == "__main__":
    
    filtering = Filtering()

    checkerboard_path = "/Users/kamleshranabhat/Desktop/test_dataset/checkerboard/image/checkerboard.nii"
    impulse_path = "/Users/kamleshranabhat/Desktop/test_dataset/impulse/image/impulse.nii"
    sphere_path = "/Users/kamleshranabhat/Desktop/test_dataset/sphere/image/sphere.nii"
    pattern1_path = "/Users/kamleshranabhat/Desktop/test_dataset/pattern_1/image/pattern_1.nii"
    
    output_path = "/Users/kamleshranabhat/Desktop/test_dataset/pattern_1/image/34.nii"
    
    
    checkerboard_image = sitk.ReadImage(checkerboard_path)
    impulse_image = sitk.ReadImage(impulse_path)
    sphere_image = sitk.ReadImage(sphere_path)
    pattern1_image = sitk.ReadImage(pattern1_path)

#     filtered_image_np = filtering.mean_filtering(impulse_image, ndims=2, size = 15, orthogonal_rot =False, padding = "constant" )
#     filtered_image_np= filtering.log_filtering(checkerboard_image, ndims=2, size=21, sigma=5, orthogonal_rot=False,padding= 'symmetric')
#     filtered_image_np = filtering.gabor_filtering(impulse_image, ndims = 2, size = 15, sigma = 10.0, lamb = 4.0,gamma = 0.5, theta = math.pi/4, orthogonal_rot = False, padding = "constant", average_pooling = True)                                               
#     filtered_image_np = filtering.laws_filtering(checkerboard_image, ndims=2, config = ['L5','S5'], energy_distance=7, rot_invariance= True, orthogonal_rot= False, padding= 'symmetric', energy_image= False, max_pooling=True)                                          
#     filtered_image_np = filtering.wavelet_filtering(checkerboard_image, ndims=3,size=15, wavelet_name = "sym2",rot_invariance=False, padding="wrap", level = 1, wt_filter = "B", average_pooling = False)
    filtered_image_np = filtering.riesz_filtering(pattern1_image, l=(0, 2, 0), sigma_tensor = None, aligned_str_tensor=False)
#     checkerboard_image = sitk.ReadImage(output_path)
#     filtered_image_np= filtering.log_filtering(checkerboard_image, ndims=3, size=15, sigma=3, orthogonal_rot=False, padding= 'constant')
#     filtered_image_np = filtering.wavelet_filtering(checkerboard_image, ndims=3,size=15, wavelet_name = "sym2",rot_invariance=False, padding="edge", level = 1, wt_filter = "B", average_pooling = False)

                                   

                        
#     filtered_image_np = filtering.log_filtering(checkerboard_image, ndims=3, sigma = 5, padding = "reflect" )
#     filtered_image_np = filtering.laws_filtering(checkerboard_image, ndims=3, sigma = 5, padding = "reflect" )
#     filtered_image_np = filtering.gabor_filtering(checkerboard_image, ndims = 2, size = 5, sigma = 10, lamb = 4, gamma = 0.5, theta = math.pi/3, orthogonal_rot = False, padding = "constant")
#     filtered_image_np = filtering.wavelet_filtering(checkerboard_image, ndims=3, wavelet_name = "sym2", rot_invariance=False, padding="wrap", level = 3, wt_filter = "B", avg_pooling = False)
#     filtered_image_np = filtering.riesz_filtering(checkerboard_image, ndims=3, sigma = 5, padding = "reflect" )

    
    filtered_image_sitk = sitk.GetImageFromArray(filtered_image_np)
    filtered_image_sitk.CopyInformation(checkerboard_image)
    

    print(f"Filtered image saved at: {output_path}")
    
    sitk.WriteImage(filtered_image_sitk, output_path)

Filtered image saved at: /Users/kamleshranabhat/Desktop/test_dataset/pattern_1/image/34.nii


In [222]:
wavelet = pywt.Wavelet("sym2")

for i in range(len(images)):
    coeffs = pywt.swtn(images[i], wavelet, level=1)
    # Print the keys of the wavelet coefficients for debugging
    print(f"Available keys in coeffs at level {level-1}: {list(coeffs[level-1].keys())}")



NameError: name 'images' is not defined

In [130]:
import numpy as np

image_shape = (1, 64, 64, 64)

pad_tuple = ((3, 3), (3, 3), (3, 3), (0,0))

image = np.random.rand(*image_shape)

padded_image = np.pad(image, pad_tuple, mode='constant')

print(image.shape)
print(padded_image.shape)


(1, 64, 64, 64)
(7, 70, 70, 64)


In [123]:
import pywt

wavelet_list = pywt.wavelist()

for wavelet_name in wavelet_list:
    try:
        wavelet = pywt.Wavelet(wavelet_name)
        family_name = wavelet.family_name
        print(f"{wavelet_name}: {family_name}")
    except ValueError:
        # Handle continuous wavelets separately
        wavelet = pywt.ContinuousWavelet(wavelet_name)
        family_name = "Continuous"
        print(f"{wavelet_name}: {family_name}")

        

bior1.1: Biorthogonal
bior1.3: Biorthogonal
bior1.5: Biorthogonal
bior2.2: Biorthogonal
bior2.4: Biorthogonal
bior2.6: Biorthogonal
bior2.8: Biorthogonal
bior3.1: Biorthogonal
bior3.3: Biorthogonal
bior3.5: Biorthogonal
bior3.7: Biorthogonal
bior3.9: Biorthogonal
bior4.4: Biorthogonal
bior5.5: Biorthogonal
bior6.8: Biorthogonal
cgau1: Continuous
cgau2: Continuous
cgau3: Continuous
cgau4: Continuous
cgau5: Continuous
cgau6: Continuous
cgau7: Continuous
cgau8: Continuous
cmor: Continuous
coif1: Coiflets
coif2: Coiflets
coif3: Coiflets
coif4: Coiflets
coif5: Coiflets
coif6: Coiflets
coif7: Coiflets
coif8: Coiflets
coif9: Coiflets
coif10: Coiflets
coif11: Coiflets
coif12: Coiflets
coif13: Coiflets
coif14: Coiflets
coif15: Coiflets
coif16: Coiflets
coif17: Coiflets
db1: Daubechies
db2: Daubechies
db3: Daubechies
db4: Daubechies
db5: Daubechies
db6: Daubechies
db7: Daubechies
db8: Daubechies
db9: Daubechies
db10: Daubechies
db11: Daubechies
db12: Daubechies
db13: Daubechies
db14: Daubechies


  wavelet = pywt.ContinuousWavelet(wavelet_name)
  wavelet = pywt.ContinuousWavelet(wavelet_name)
  wavelet = pywt.ContinuousWavelet(wavelet_name)


In [None]:
impulse_path = "/Users/kamleshranabhat/Desktop/test_dataset/impulse/image/impulse.nii"
sphere_path = "/Users/kamleshranabhat/Desktop/test_dataset/sphere/image/sphere.nii"
checkerboard = "/Users/kamleshranabhat/Desktop/test_dataset/checkerboard/image/checkerboard.nii"

print(math.pi/3)

In [None]:
import numpy as np
from scipy.signal import fftconvolve
import math
from itertools import permutations, product


class filtering:
    
    def __init__(self):
        self.config = None
        self.energy_dist = None
        self.dim = None
        self.padding = None
        self.rot = None
        self.energy_kernel = None

    def laws_filter(self, ndims: int, config: list, images: np.ndarray, energy_distance: int = 7,
                    rot_invariance: bool = False, orthogonal_rot: bool = False, padding: str = "symmetric", energy_image: bool = False) -> np.ndarray:
        """
        Apply Laws filter to the input images.

        Args:
            ndims (int): Number of dimensions for the filter.
            config (List[str]): A list of strings specifying the 1D filters to use.
            images (np.ndarray): The input images to be filtered.
            energy_distance (int): Distance for creating the energy kernel.
            rot_invariance (bool): If true, apply rotation invariance.
            orthogonal_rot (bool): If true, apply orthogonal rotation.
            padding (str): The type of padding to use.
            energy_image (bool): If true, compute energy image.

        Returns:
            np.ndarray: The filtered images.
        """
        if images.ndim == 3:
            images = np.expand_dims(images, axis=0)  # Add batch dimension
        elif images.ndim != 4:
            raise ValueError("Input images must be a 3D or 4D array.")

        self.config = config
        self.energy_dist = energy_distance
        self.dim = ndims
        self.padding = padding
        self.rot = rot_invariance

        filter_list = np.array([self.__get_filter(name, pad=(ndims == 3)) for name in config])

        if rot_invariance:
            filter_list = np.concatenate((filter_list, np.flip(filter_list, axis=-1)), axis=0)
            filter_list = np.array([list(prod) for prod in product(*filter_list.T)])
        else:
            filter_list = np.array([filter_list])

        kernel_list = []
        for perm in filter_list:
            kernel = perm[0]
            shape = kernel.shape

            for i in range(1, len(perm)):
                sub_kernel = perm[i]
                shape += np.shape(sub_kernel)
                kernel = np.outer(sub_kernel, kernel).reshape(shape)

            kernel_list.append(np.expand_dims(kernel, axis=0))

        kernel = np.unique(kernel_list, axis=0)

        ekernel = np.zeros([energy_distance*2+1 for _ in range(ndims)])
        for k in product(range(energy_distance*2 + 1), repeat=ndims):
            position = np.array(k) - energy_distance
            ekernel[k] = 1 if np.max(abs(position)) <= energy_distance else 0

        self.energy_kernel = np.expand_dims(ekernel / np.prod(ekernel.shape), axis=tuple(range(2 * ndims)))

        images = np.swapaxes(images, 1, 3)  # Swap axes to match fftconvolve requirement
        
        print(images.shape)
        print(kernel.shape)

        result = []
        for img in images:
            conv_result = np.array([fftconvolve(img, k, mode='same') for k in kernel])
            result.append(np.amax(conv_result, axis=0))

        result = np.array(result)

        if energy_image:
            # Compute energy image if required
            energy_imgs = self.__compute_energy_image(result)
            return np.swapaxes(energy_imgs, 1, 3)
        else:
            return np.swapaxes(result, 1, 3)

    @staticmethod
    def __get_filter(name, pad=False) -> np.ndarray:
        if name == "L3":
            ker = np.array([0, 1, 2, 1, 0]) if pad else np.array([1, 2, 1])
            return 1 / math.sqrt(6) * ker
        elif name == "L5":
            return 1 / math.sqrt(70) * np.array([1, 4, 6, 4, 1])
        elif name == "E3":
            ker = np.array([0, -1, 0, 1, 0]) if pad else np.array([-1, 0, 1])
            return 1 / math.sqrt(2) * ker
        elif name == "E5":
            return 1 / math.sqrt(10) * np.array([-1, -2, 0, 2, 1])
        elif name == "S3":
            ker = np.array([0, -1, 2, -1, 0]) if pad else np.array([-1, 2, -1])
            return 1 / math.sqrt(6) * ker
        elif name == "S5":
            return 1 / math.sqrt(6) * np.array([-1, 0, 2, 0, -1])
        elif name == "W5":
            return 1 / math.sqrt(10) * np.array([-1, 2, 0, -2, 1])
        elif name == "R5":
            return 1 / math.sqrt(70) * np.array([1, -4, 6, -4, 1])
        else:
            raise ValueError(f"{name} is not a valid filter name. Choose between: L3, L5, E3, E5, S3, S5, W5, or R5.")

    def __compute_energy_image(self, images: np.ndarray) -> np.ndarray:
        images = np.swapaxes(images, 0, 1)
        result = fftconvolve(np.abs(images), self.energy_kernel, mode='valid')
        return np.swapaxes(result, 0, 1)
        
    def laws_filtering(self, input_images: Union[np.ndarray, sitk.Image], ndims: int = 3, config = ['E5', 'L5', 'S5'], energy_distance: int = 7, rot_invariance: bool = False, orthogonal_rot: bool = False, padding: str = "constant", energy_image = False, max_pooling: bool = False) -> np.ndarray:
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)



        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)

        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)

        result = self.laws_filter(ndims, config, input_images, energy_distance, rot_invariance, orthogonal_rot=False, padding="constant", energy_image=False)

        if max_pooling:
            result = self.max_pooling(result, pool_size=2)  # Adjust the pool_size as needed

        return np.squeeze(result)



In [None]:
if __name__ == "__main__":
    filtering = filtering()

    checkerboard_path = "/Users/kamleshranabhat/Desktop/test_dataset/checkerboard/image/checkerboard.nii"
    checkerboard_image = sitk.ReadImage(checkerboard_path)
#     checkerboard_image = sitk.GetImageFromArray(filtered_image_np)

#     filtered_image_np = filtering.gabor_filtering(checkerboard_image, ndims = 2, size = 5, sigma = 10, lamb = 4, gamma = 0.5, theta = math.pi/3, orthogonal_rot = False, padding = "constant")
    
    filtered_image_np = filtering.laws_filtering(checkerboard_image, ndims=3, config=['E5', 'L5', 'S5'], energy_distance=7, rot_invariance=True, orthogonal_rot=False, padding="symmetric", energy_image = True, max_pooling=True)

    # Convert back to SimpleITK image if needed
    filtered_image_sitk = sitk.GetImageFromArray(filtered_image_np)
    filtered_image_sitk.CopyInformation(checkerboard_image)
    
    output_path = "/Users/kamleshranabhat/Desktop/test_dataset/checkerboard/image/3D_mirrorpadding_rotinvariance_maxpooling_energymap_lawsfiltered_checkerboard.nii"
    
    sitk.WriteImage(filtered_image_sitk, output_path)

In [28]:
if __name__ == "__main__":
    filtering = Filtering()

    checkerboard_path = "/Users/kamleshranabhat/Desktop/test_dataset/impulse/image/impulse.nii"
    checkerboard_image = sitk.ReadImage(checkerboard_path)
    
    l = (1, 0, 0)
    aligned_str_tensor = False
    sigma_tensor = 1.0

    filtered_image_np = filtering.riesz_then_log_filtering(
        checkerboard_image, 
        l=l, 
        aligned_str_tensor=aligned_str_tensor, 
        sigma_tensor=sigma_tensor, 
        ndims=3, 
        size=5, 
        sigma=3, 
        padding="constant"
    )

    filtered_image_sitk = sitk.GetImageFromArray(filtered_image_np)
    filtered_image_sitk.CopyInformation(checkerboard_image)
    
    output_path = "/Users/kamleshranabhat/Desktop/test_dataset/impulse/image/zeropadding_rieszlogfiltered_100_impulse.nii"
    
    sitk.WriteImage(filtered_image_sitk, output_path)


TypeError: convolve() missing 1 required positional argument: 'images'

In [104]:
import numpy as np
import SimpleITK as sitk
from itertools import product
import math
from scipy.ndimage import convolve

class Filtering:

    def __init__(self):
        pass

    def log_filter(self, ndims:int, size: int, images: np.ndarray, sigma: float, orthogonal_rot: bool = False, padding="constant") -> np.ndarray:
        """The constructor of the laplacian of gaussian (LoG) filter

        Args:
            ndims (int): Number of dimension of the kernel filter
            size (int): An integer that represent the length along one dimension of the kernel.
            sigma (float): The gaussian standard deviation parameter of the laplacian of gaussian filter
            padding (str): The padding type that will be used to produce the convolution

        Returns:
            None
        """
        assert isinstance(ndims, int) and ndims > 0, "ndims should be a positive integer"
        assert ((size+1)/2).is_integer() and size > 0, "size should be a positive odd number."
        assert sigma > 0, "alpha should be a positive float."
        self.dim = ndims
        self.size = size
        self.sigma = sigma
        
        def compute_weight(position):
            distance_2 = np.sum(position**2)
            first_part = -1/((2*math.pi)**(self.dim/2) * self.sigma**(self.dim+2))
            second_part = (self.dim - distance_2/self.sigma**2)*math.e**(-distance_2/(2 * self.sigma**2))

            return first_part * second_part

        kernel = np.zeros([self.size for _ in range(self.dim)])

        for k in product(range(self.size), repeat=self.dim):
            kernel[k] = compute_weight(np.array(k)-int((self.size-1)/2))

        kernel -= np.sum(kernel)/np.prod(kernel.shape)
        kernel = np.expand_dims(kernel, axis=(0, 1))
        
#         # Ensure images is at least 4-dimensional (B, W, H, D)
#         if images.ndim < 4:
#             raise ValueError("Input images must have at least 4 dimensions (B, W, H, D)")
            
        print(kernel.shape)

        image = np.swapaxes(images, 1, 3)
        print(image.shape)
        result = np.squeeze(convolve(ndims, kernel, image, orthogonal_rot, padding), axis=1)
        
        return np.swapaxes(result, 1, 3)
        

    def log_filtering(self, input_images:Union[np.ndarray, sitk.Image], ndims: int = 3, size: int = 15, sigma: int = 3, orthogonal_rot: bool = False, padding: str = "symmetric") -> np.ndarray:
        
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)

        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)
        
        result = self.log_filter(ndims, size, input_images, sigma, orthogonal_rot, padding)
        
        return np.squeeze(result)

    def riesz_transform(self, image: np.ndarray, l: tuple, aligned_str_tensor: bool = False, sigma_tensor: float = None) -> np.ndarray:
        """Compute Riesz transform of an input image."""
        fft_image = np.fft.fftn(image)

        # Check image dimensions
        if image.ndim < 3:
            raise ValueError("Input image must have at least three dimensions.")

        
        print(image.shape)
        nx, ny, nz = image.shape[-3:]
        
        kx = np.fft.fftfreq(nx).reshape(-1, 1, 1)
        ky = np.fft.fftfreq(ny).reshape(1, -1, 1)
        kz = np.fft.fftfreq(nz).reshape(1, 1, -1)

        if aligned_str_tensor:
            gradient_filter = sitk.GradientRecursiveGaussianImageFilter()
            gradient_filter.SetSigma(sigma_tensor)
            gradient = gradient_filter.Execute(sitk.GetImageFromArray(image))

            gradient_np = sitk.GetArrayFromImage(gradient)
            J = np.zeros((3, 3, *gradient_np.shape[:3]))

            for i in range(3):
                for j in range(3):
                    J[i, j] = gradient_np[..., i] * gradient_np[..., j]

            tensor_eigenvalues, tensor_eigenvectors = np.linalg.eigh(J)
            dominant_eigenvector = tensor_eigenvectors[:, :, :, 2]
            kx = kx * dominant_eigenvector[0] + ky * dominant_eigenvector[1] + kz * dominant_eigenvector[2]
            ky = kx * dominant_eigenvector[1] + ky * dominant_eigenvector[1] + kz * dominant_eigenvector[2]
            kz = kx * dominant_eigenvector[2] + ky * dominant_eigenvector[2] + kz * dominant_eigenvector[2]

        riesz_component = l[0] * (1j * kx) + l[1] * (1j * ky) + l[2] * (1j * kz)
        riesz_transformed_fft = riesz_component * fft_image
        riesz_transformed_image = np.fft.ifftn(riesz_transformed_fft).real
        
        return riesz_transformed_image
    
    def riesz_filtering(self, input_images: np.ndarray, l: tuple, aligned_str_tensor: bool = False, sigma_tensor: float = None) -> np.ndarray:
        """Apply Riesz transform to input images."""
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)
        
        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)

        result = self.riesz_transform(input_images, l, aligned_str_tensor, sigma_tensor)
        
        return np.squeeze(result)
    
    def riesz_then_log_filtering(self, input_images, l: tuple, aligned_str_tensor: bool = False, sigma_tensor: float = None, ndims: int = 3, size: int = 4, sigma: float = 5.0, padding: str = "symmetric") -> np.ndarray:
        riesz_image = self.riesz_filtering(input_images, l, aligned_str_tensor, sigma_tensor)
        
        riesz_image = sitk.GetImageFromArray(riesz_image)
        riesz_image.CopyInformation(input_images)
        
        log_filtered_image = self.log_filtering(riesz_image, ndims, size, sigma, padding)
        
        return log_filtered_image

# , ndims=3, size=5, sigma=3, padding="constant"

In [109]:
if __name__ == "__main__":
    filtering = Filtering()

    checkerboard_path = "/Users/kamleshranabhat/Desktop/test_dataset/impulse/image/zeropadding_rieszfiltered_100_impulse.nii"
    checkerboard_image = sitk.ReadImage(checkerboard_path)

    l = (0, 2, 0)
    aligned_str_tensor = True
    sigma_tensor = 1.0

    filtered_image_np = filtering.wavelet_filtering(checkerboard_image, ndims=3, wavelet_name = "sym2", rot_invariance=False, padding="constant", level = 1, wt_filter = "B", avg_pooling = False)

    filtered_image_sitk = sitk.GetImageFromArray(filtered_image_np)
    filtered_image_sitk.CopyInformation(checkerboard_image)

    output_path = "/Users/kamleshranabhat/Desktop/test_dataset/impulse/image/zeropadding_rieszthenSimoncellifiltered_100_impulse.nii"

    sitk.WriteImage(filtered_image_sitk, output_path)


(1, 70, 70, 70)


RuntimeError: Exception thrown in SimpleITK Image_CopyInformation: /tmp/SimpleITK/Code/Common/src/sitkImage.cxx:308:
sitk::ERROR: Source image size of [ 64, 64, 64 ] does not match this image's size of [ 70, 70, 560 ]!

In [None]:
if __name__ == "__main__":
    filtering = Filtering()

    checkerboard_path = "/Users/kamleshranabhat/Desktop/test_dataset/sphere/image/zeropadding_rieszfiltered_020_sphere.nii"
    checkerboard_image = sitk.ReadImage(checkerboard_path)

    l = (0, 2, 0)
    aligned_str_tensor = False
    sigma_tensor = 1.0

    filtered_image_np = filtering.log_filtering(checkerboard_image, ndims=3, size=5, sigma=3, padding="constant")

    filtered_image_sitk = sitk.GetImageFromArray(filtered_image_np)
    filtered_image_sitk.CopyInformation(checkerboard_image)

    output_path = "/Users/kamleshranabhat/Desktop/test_dataset/sphere/image/zeropadding_rieszthenlogfiltered_020_sphere.nii"

    sitk.WriteImage(filtered_image_sitk, output_path)
    

In [70]:
class filtering:
    def log_filter(self, ndims:int, size: int, images: np.ndarray, sigma: float, orthogonal_rot: bool = False, padding="constant") -> np.ndarray:
        """The constructor of the laplacian of gaussian (LoG) filter

        Args:
            ndims (int): Number of dimension of the kernel filter
            size (int): An integer that represent the length along one dimension of the kernel.
            sigma (float): The gaussian standard deviation parameter of the laplacian of gaussian filter
            padding (str): The padding type that will be used to produce the convolution

        Returns:
            None
        """
        assert isinstance(ndims, int) and ndims > 0, "ndims should be a positive integer"
        assert ((size+1)/2).is_integer() and size > 0, "size should be a positive odd number."
        assert sigma > 0, "alpha should be a positive float."
        self.dim = ndims
        self.size = size
        self.sigma = sigma
        
        def compute_weight(position):
            distance_2 = np.sum(position**2)
            first_part = -1/((2*math.pi)**(self.dim/2) * self.sigma**(self.dim+2))
            second_part = (self.dim - distance_2/self.sigma**2)*math.e**(-distance_2/(2 * self.sigma**2))

            return first_part * second_part

        kernel = np.zeros([self.size for _ in range(self.dim)])

        for k in product(range(self.size), repeat=self.dim):
            kernel[k] = compute_weight(np.array(k)-int((self.size-1)/2))

        kernel -= np.sum(kernel)/np.prod(kernel.shape)
        kernel = np.expand_dims(kernel, axis=(0, 1))
        
#         # Ensure images is at least 4-dimensional (B, W, H, D)
#         if images.ndim < 4:
#             raise ValueError("Input images must have at least 4 dimensions (B, W, H, D)")
            
        print(kernel.shape)

        image = np.swapaxes(images, 1, 3)
        print(image.shape)
        result = np.squeeze(convolve(ndims, kernel, image, orthogonal_rot, padding), axis=1)
        
        return np.swapaxes(result, 1, 3)
        

    def log_filtering(self, input_images:Union[np.ndarray, sitk.Image], ndims: int = 3, size: int = 15, sigma: int = 3, orthogonal_rot: bool = False, padding: str = "symmetric") -> np.ndarray:
        
        if isinstance(input_images, sitk.Image):
            input_images = sitk.GetArrayFromImage(input_images)

        input_images = np.expand_dims(input_images.astype(np.float64), axis=0)
        
        result = self.log_filter(ndims, size, input_images, sigma, orthogonal_rot, padding)
        
        return np.squeeze(result)

In [71]:
filtering=filtering()

In [72]:
checkerboard_path = "/Users/kamleshranabhat/Desktop/test_dataset/checkerboard/image/checkerboard.nii"
checkerboard_image = sitk.ReadImage(checkerboard_path)

filtered_image_np = filtering.log_filtering(checkerboard_image, ndims=3, sigma = 5, padding = "reflect" )


filtered_image_sitk = sitk.GetImageFromArray(filtered_image_np)
filtered_image_sitk.CopyInformation(checkerboard_image)

output_path = "/Users/kamleshranabhat/Desktop/test_dataset/checkerboard/image/test.nii"

sitk.WriteImage(filtered_image_sitk, output_path)

(1, 1, 15, 15, 15)
(1, 64, 64, 64)
