In [1]:
import numpy as np

# Load a .npy file
arr = np.load(r'C:\Users\h\Desktop\PET_Recons\100307_rfMRI_REST1_LR_norm_550.npy')
print(arr.shape)


(91, 109, 91, 550)


In [2]:
# first_volume = arr
# first_volume = arr[:, :, :, 0:32]
first_volume = arr[:, :, :, 0]
# first_volume = np.random.rand(3, 4, 5, 6, 7)

print(first_volume.shape)

(91, 109, 91)


In [3]:
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

first_volume = sigmoid(first_volume)

In [4]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def visualize_volume_cross_sections_comparison(original, reconstruction):
    """
    Create an interactive plotly visualization comparing original and reconstructed 3D volumes.
    
    Parameters:
    original (numpy.ndarray): 3D numpy array with shape (91, 109, 91)
    reconstruction (numpy.ndarray): 3D numpy array with same shape as original
    
    Returns:
    plotly.graph_objects.Figure: Interactive figure with comparison and error visualization
    """
    
    # Validate inputs
    if original.shape != reconstruction.shape:
        raise ValueError("Original and reconstruction must have the same shape")
    
    # Get volume dimensions
    d0, d1, d2 = original.shape
    
    # Calculate error volume
    error = original - reconstruction
    
    # Calculate color scale ranges
    vol_min = min(original.min(), reconstruction.min())
    vol_max = max(original.max(), reconstruction.max())
    error_abs_max = max(abs(error.min()), abs(error.max()))
    
    # Create 3x3 subplots
    subplot_titles = [
        'Original YZ', 'Original XZ', 'Original XY',
        'Reconstruction YZ', 'Reconstruction XZ', 'Reconstruction XY', 
        'Error YZ', 'Error XZ', 'Error XY'
    ]
    
    fig = make_subplots(
        rows=3, cols=3,
        subplot_titles=subplot_titles,
        vertical_spacing=0.08,
        horizontal_spacing=0.05
    )
    
    # Initialize with middle slices
    mid0, mid1, mid2 = d0//2, d1//2, d2//2
    
    # Add original volume heatmaps (top row)
    fig.add_trace(
        go.Heatmap(
            z=original[mid0, :, :],
            colorscale='Viridis',
            zmin=vol_min, zmax=vol_max,
            showscale=False,
            name='Original YZ'
        ), row=1, col=1
    )
    
    fig.add_trace(
        go.Heatmap(
            z=original[:, mid1, :],
            colorscale='Viridis',
            zmin=vol_min, zmax=vol_max,
            showscale=False,
            name='Original XZ'
        ), row=1, col=2
    )
    
    fig.add_trace(
        go.Heatmap(
            z=original[:, :, mid2],
            colorscale='Viridis',
            zmin=vol_min, zmax=vol_max,
            showscale=True,
            colorbar=dict(x=0.32, len=0.3, y=0.85),
            name='Original XY'
        ), row=1, col=3
    )
    
    # Add reconstruction heatmaps (middle row)
    fig.add_trace(
        go.Heatmap(
            z=reconstruction[mid0, :, :],
            colorscale='Viridis',
            zmin=vol_min, zmax=vol_max,
            showscale=False,
            name='Reconstruction YZ'
        ), row=2, col=1
    )
    
    fig.add_trace(
        go.Heatmap(
            z=reconstruction[:, mid1, :],
            colorscale='Viridis',
            zmin=vol_min, zmax=vol_max,
            showscale=False,
            name='Reconstruction XZ'
        ), row=2, col=2
    )
    
    fig.add_trace(
        go.Heatmap(
            z=reconstruction[:, :, mid2],
            colorscale='Viridis',
            zmin=vol_min, zmax=vol_max,
            showscale=False,
            name='Reconstruction XY'
        ), row=2, col=3
    )
    
    # Add error heatmaps (bottom row)
    fig.add_trace(
        go.Heatmap(
            z=error[mid0, :, :],
            colorscale='RdBu',
            zmin=-error_abs_max, zmax=error_abs_max,
            showscale=False,
            name='Error YZ'
        ), row=3, col=1
    )
    
    fig.add_trace(
        go.Heatmap(
            z=error[:, mid1, :],
            colorscale='RdBu',
            zmin=-error_abs_max, zmax=error_abs_max,
            showscale=False,
            name='Error XZ'
        ), row=3, col=2
    )
    
    fig.add_trace(
        go.Heatmap(
            z=error[:, :, mid2],
            colorscale='RdBu',
            zmin=-error_abs_max, zmax=error_abs_max,
            showscale=True,
            colorbar=dict(x=1.02, len=0.3, y=0.15),
            name='Error XY'
        ), row=3, col=3
    )
    
    # Create synchronized slider steps
    steps = []
    for axis in range(3):
        axis_steps = []
        axis_size = [d0, d1, d2][axis]
        
        for i in range(axis_size):
            if axis == 0:  # YZ plane
                z_data = [
                    original[i, :, :], original[:, mid1, :], original[:, :, mid2],
                    reconstruction[i, :, :], reconstruction[:, mid1, :], reconstruction[:, :, mid2],
                    error[i, :, :], error[:, mid1, :], error[:, :, mid2]
                ]
            elif axis == 1:  # XZ plane
                z_data = [
                    original[mid0, :, :], original[:, i, :], original[:, :, mid2],
                    reconstruction[mid0, :, :], reconstruction[:, i, :], reconstruction[:, :, mid2],
                    error[mid0, :, :], error[:, i, :], error[:, :, mid2]
                ]
            else:  # XY plane
                z_data = [
                    original[mid0, :, :], original[:, mid1, :], original[:, :, i],
                    reconstruction[mid0, :, :], reconstruction[:, mid1, :], reconstruction[:, :, i],
                    error[mid0, :, :], error[:, mid1, :], error[:, :, i]
                ]
            
            step = dict(
                method="restyle",
                args=[{"z": z_data}],
                label=str(i)
            )
            axis_steps.append(step)
        
        steps.append(axis_steps)
    
    # Create sliders
    sliders = [
        dict(
            active=mid0,
            currentvalue={"prefix": "Axis 0 slice: "},
            pad={"t": 50},
            steps=steps[0],
            len=0.25,
            x=0.02,
            y=0.02
        ),
        dict(
            active=mid1,
            currentvalue={"prefix": "Axis 1 slice: "},
            pad={"t": 50},
            steps=steps[1],
            len=0.25,
            x=0.37,
            y=0.02
        ),
        dict(
            active=mid2,
            currentvalue={"prefix": "Axis 2 slice: "},
            pad={"t": 50},
            steps=steps[2],
            len=0.25,
            x=0.72,
            y=0.02
        )
    ]
    
    # Update layout
    fig.update_layout(
        title="Original vs Reconstruction Comparison with Error Analysis",
        sliders=sliders,
        height=900,
        margin=dict(b=150, t=100),
        showlegend=False
    )
    
    # Remove axis labels for cleaner look
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    
    return fig

In [5]:
# import numpy as np
# import matplotlib.pyplot as plt
# import tensorly as tl
# from tensorly.decomposition import tensor_train
# from tensorly.tt_tensor import tt_to_tensor
# import torch

# # Set PyTorch backend and GPU
# tl.set_backend('pytorch')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TT_RANKS_TO_TEST = [1, 2, 5, 32, 64, 512]

# def calculate_compression_ratio(original_tensor, tt_cores):
#    original_size = original_tensor.numel()
#    compressed_size = sum(core.numel() for core in tt_cores)
#    ratio = compressed_size / original_size
#    return ratio

# # Move tensor to GPU
# image_tensor_gpu = torch.tensor(first_volume, dtype=torch.float32, device=device)

# print(f"Original image tensor shape: {image_tensor_gpu.shape}")
# print(f"Total elements in original image: {image_tensor_gpu.numel()}\n")

# reconstructed_images = []
# compression_ratios = []
# tensor_train_cores = []

# for rank in TT_RANKS_TO_TEST:
#    print(f"--- Processing for Rank={rank} ---")

#    tt_cores = tensor_train(image_tensor_gpu, rank=rank)
#    tensor_train_cores.append(tt_cores)

#    reconstructed_tensor = tt_to_tensor(tt_cores)
#    reconstructed_images.append(reconstructed_tensor)

#    ratio = calculate_compression_ratio(image_tensor_gpu, tt_cores)
#    compression_ratios.append(ratio)
   
#    print(f"Compressed size (sum of elements in cores): {sum(c.numel() for c in tt_cores)}")
#    print(f"Compression Ratio: {ratio:.4f} (Compressed is {ratio*100:.2f}% of Original size)\n")

In [6]:
import numpy as np
import matplotlib.pyplot as plt
import tensorly as tl
from tensorly.decomposition import tensor_train
from tensorly.tt_tensor import tt_to_tensor

# Lower rank = more compression, but lower quality.
# Higher rank = less compression, but higher quality.
TT_RANKS_TO_TEST = [1, 2, 5, 32, 64, 512]

def calculate_compression_ratio(original_tensor, tt_cores):
    """Calculates the compression ratio."""
    original_size = original_tensor.size
    compressed_size = sum(core.size for core in tt_cores)
    ratio = compressed_size / original_size
    return ratio

image_tensor_float = first_volume

print(f"Original image tensor shape: {image_tensor_float.shape}")
print(f"Total elements in original image: {image_tensor_float.size}\n")

reconstructed_images = []
compression_ratios = []
tensor_train_cores = []

for rank in TT_RANKS_TO_TEST:
    print(f"--- Processing for Rank={rank} ---")

    # Perform Tensor Train decomposition
    tt_cores = tensor_train(image_tensor_float, rank=rank)
    tensor_train_cores.append(tt_cores)

    # Reconstruct the tensor from the compressed TT-cores
    reconstructed_tensor = tt_to_tensor(tt_cores)
    reconstructed_images.append(reconstructed_tensor)

    # Calculate and store compression info
    ratio = calculate_compression_ratio(image_tensor_float, tt_cores)
    compression_ratios.append(ratio)
    
    print(f"Compressed size (sum of elements in cores): {sum(c.size for c in tt_cores)}")
    print(f"Compression Ratio: {ratio:.4f} (Compressed is {ratio*100:.2f}% of Original size)\n")

Original image tensor shape: (91, 109, 91)
Total elements in original image: 902629

--- Processing for Rank=1 ---
Compressed size (sum of elements in cores): 291
Compression Ratio: 0.0003 (Compressed is 0.03% of Original size)

--- Processing for Rank=2 ---
Compressed size (sum of elements in cores): 800
Compression Ratio: 0.0009 (Compressed is 0.09% of Original size)

--- Processing for Rank=5 ---
Compressed size (sum of elements in cores): 3635
Compression Ratio: 0.0040 (Compressed is 0.40% of Original size)

--- Processing for Rank=32 ---
Compressed size (sum of elements in cores): 117440
Compression Ratio: 0.1301 (Compressed is 13.01% of Original size)

--- Processing for Rank=64 ---
Compressed size (sum of elements in cores): 458112
Compression Ratio: 0.5075 (Compressed is 50.75% of Original size)

--- Processing for Rank=512 ---
Compressed size (sum of elements in cores): 919191
Compression Ratio: 1.0183 (Compressed is 101.83% of Original size)



In [None]:
th = 3
visualize_volume_cross_sections_comparison(original=first_volume, reconstruction=reconstructed_images[th])

In [8]:
for core in tensor_train_cores[th]:
    print(core.shape)

(1, 91, 32)
(32, 109, 32)
(32, 91, 1)


In [11]:
import torch
import torch.nn as nn
from typing import List, Tuple

class TensorTrain(nn.Module):
    """
    Tensor Train decomposition module that creates learnable tensor cores
    for efficient representation of high-dimensional tensors.
    """
    
    def __init__(self, shape: List[int], rank: int):
        """
        Initialize tensor train with specified output shape and rank.
        
        Args:
            shape: List of dimensions for the target tensor (e.g., [91, 109, 91])
            rank: Rank of the tensor train (internal bond dimension)
        """
        super(TensorTrain, self).__init__()
        
        self.shape = shape
        self.rank = rank
        self.n_dims = len(shape)
        
        if self.n_dims < 2:
            raise ValueError("Tensor train requires at least 2 dimensions")
        
        # Create tensor cores
        self.cores = nn.ParameterList()
        
        for i, dim in enumerate(shape):
            if i == 0:  # First core: (1, d_i, rank)
                core_shape = (1, dim, rank)
            elif i == self.n_dims - 1:  # Last core: (rank, d_i, 1)
                core_shape = (rank, dim, 1)
            else:  # Middle cores: (rank, d_i, rank)
                core_shape = (rank, dim, rank)
            
            # Initialize with Xavier uniform for stable gradients
            core = nn.Parameter(torch.empty(core_shape))
            nn.init.xavier_uniform_(core)
            self.cores.append(core)
    
    def forward(self) -> torch.Tensor:
        """
        Reconstruct the full tensor from tensor train cores.
        
        Returns:
            Reconstructed tensor of shape specified during initialization
        """
        # Start with the first core, remove the first dimension (which is 1)
        result = self.cores[0].squeeze(0)  # Shape: (d_0, rank)
        
        # Contract with middle cores
        for i in range(1, self.n_dims - 1):
            # result shape: (..., rank)
            # core shape: (rank, d_i, rank)
            result = torch.einsum('...r,rdr->...dr', result, self.cores[i])
        
        # Contract with the last core and remove the last dimension (which is 1)
        # result shape: (..., rank)
        # last core shape: (rank, d_n, 1)
        result = torch.einsum('...r,rd->...d', result, self.cores[-1].squeeze(-1))
        
        return result
    
    def get_core_shapes(self) -> List[Tuple[int, ...]]:
        """Return the shapes of all tensor cores."""
        return [tuple(core.shape) for core in self.cores]
    
    def get_compression_ratio(self) -> float:
        """
        Calculate compression ratio compared to storing the full tensor.
        
        Returns:
            Ratio of full tensor size to tensor train parameter count
        """
        full_size = torch.prod(torch.tensor(self.shape)).item()
        tt_size = sum(torch.prod(torch.tensor(core.shape)).item() for core in self.cores)
        return full_size / tt_size

# Example usage
if __name__ == "__main__":
    # Create tensor train for your specific case
    tt = TensorTrain(shape=[91, 109, 91], rank=32)
    
    print(f"Target shape: {tt.shape}")
    print(f"Tensor core shapes: {tt.get_core_shapes()}")
    print(f"Compression ratio: {tt.get_compression_ratio():.2f}")
    
    # Reconstruct tensor
    reconstructed = tt()
    print(f"Reconstructed tensor shape: {reconstructed.shape}")
    
    # Verify it matches expected shape
    assert reconstructed.shape == tuple(tt.shape), f"Shape mismatch: {reconstructed.shape} vs {tuple(tt.shape)}"
    
    print("✓ Tensor train created successfully!")
    
    # Show parameter count
    total_params = sum(p.numel() for p in tt.parameters())
    print(f"Total parameters: {total_params:,}")
    
    # Example with different shape and rank
    tt2 = TensorTrain(shape=[50, 100, 75, 25], rank=16)
    reconstructed2 = tt2()
    print(f"\nSecond example - Target: {tt2.shape}, Reconstructed: {reconstructed2.shape}")
    print(f"Compression ratio: {tt2.get_compression_ratio():.2f}")

Target shape: [91, 109, 91]
Tensor core shapes: [(1, 91, 32), (32, 109, 32), (32, 91, 1)]
Compression ratio: 7.69
Reconstructed tensor shape: torch.Size([91, 109, 91])
✓ Tensor train created successfully!
Total parameters: 117,440

Second example - Target: [50, 100, 75, 25], Reconstructed: torch.Size([50, 100, 75, 25])
Compression ratio: 203.80


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple

class TensorTrainConv2d(nn.Module):
    """
    2D Convolution with Tensor Train decomposed kernel for parameter efficiency.
    
    Uses tensor train decomposition to represent the 4D convolution kernel
    [out_channels, in_channels, kernel_height, kernel_width] with fewer parameters.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        rank: int,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int], str] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros'
    ):
        """
        Initialize TensorTrain Conv2D layer.
        
        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels  
            kernel_size: Size of convolution kernel
            rank: Tensor train rank (controls parameter efficiency)
            stride: Stride of convolution
            padding: Padding applied to input
            dilation: Spacing between kernel elements
            groups: Number of blocked connections from input to output channels
            bias: If True, adds a learnable bias
            padding_mode: Padding mode ('zeros', 'reflect', 'replicate', 'circular')
        """
        super(TensorTrainConv2d, self).__init__()
        
        # Handle kernel_size as int or tuple
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        
        # Validate groups
        if in_channels % groups != 0:
            raise ValueError(f"in_channels ({in_channels}) must be divisible by groups ({groups})")
        if out_channels % groups != 0:
            raise ValueError(f"out_channels ({out_channels}) must be divisible by groups ({groups})")
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        
        # Store original requested rank for reference
        self.requested_rank = rank
        self.rank = rank
        
        # Tensor train kernel shape: [out_channels, in_channels//groups, kernel_h, kernel_w]
        kernel_shape = [out_channels, in_channels // groups, kernel_size[0], kernel_size[1]]
        
        # Validate and adjust rank if necessary
        min_dim = min(kernel_shape)
        if rank > min_dim:
            original_rank = rank
            rank = min_dim
            print(f"Warning: Tensor train rank ({original_rank}) exceeds minimum kernel dimension ({min_dim}). "
                  f"Using rank={rank} instead.")
        
        self.rank = rank  # Update stored rank to actual value used
        
        # Create tensor train for the kernel
        self.kernel_tt = TensorTrain(kernel_shape, rank)
        
        # Optional bias parameter
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)
        
        # Cache for reconstructed kernel (optional optimization)
        self._cached_kernel = None
        self._cache_valid = False
    
    def get_kernel(self) -> torch.Tensor:
        """
        Reconstruct the convolution kernel from tensor train cores.
        
        Returns:
            4D kernel tensor of shape [out_channels, in_channels//groups, kernel_h, kernel_w]
        """
        # Use caching to avoid redundant reconstruction
        if not self._cache_valid:
            self._cached_kernel = self.kernel_tt()
            self._cache_valid = True
        return self._cached_kernel
    
    def _invalidate_cache(self):
        """Invalidate kernel cache when parameters change."""
        self._cache_valid = False
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through tensor train convolution.
        
        Args:
            x: Input tensor of shape [batch, in_channels, height, width]
            
        Returns:
            Output tensor after convolution
        """
        # Reconstruct kernel from tensor train
        kernel = self.get_kernel()
        
        # Perform convolution
        return F.conv2d(
            x, 
            kernel, 
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups
        )
    
    def get_compression_ratio(self) -> float:
        """
        Calculate compression ratio compared to standard Conv2d.
        
        Returns:
            Ratio of standard conv parameters to tensor train parameters
        """
        # Standard conv2d kernel size
        standard_kernel_params = self.out_channels * (self.in_channels // self.groups) * \
                               self.kernel_size[0] * self.kernel_size[1]
        
        # Tensor train parameters (from kernel only, excluding bias)
        tt_kernel_params = sum(torch.prod(torch.tensor(core.shape)).item() 
                              for core in self.kernel_tt.cores)
        
        return standard_kernel_params / tt_kernel_params
    
    def get_total_compression_ratio(self) -> float:
        """
        Calculate total compression ratio including bias terms.
        
        Returns:
            Ratio of total standard conv parameters to tensor train parameters
        """
        # Standard conv2d total parameters
        standard_total = self.out_channels * (self.in_channels // self.groups) * \
                        self.kernel_size[0] * self.kernel_size[1]
        if self.bias is not None:
            standard_total += self.out_channels
        
        # Tensor train total parameters
        tt_total = sum(torch.prod(torch.tensor(core.shape)).item() 
                      for core in self.kernel_tt.cores)
        if self.bias is not None:
            tt_total += self.out_channels
            
        return standard_total / tt_total
    
    def extra_repr(self) -> str:
        """String representation with key parameters."""
        return (f'in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, rank={self.rank}, stride={self.stride}, '
                f'padding={self.padding}, dilation={self.dilation}, groups={self.groups}, '
                f'bias={self.bias is not None}, compression_ratio={self.get_compression_ratio():.2f}')


# Example usage and testing
if __name__ == "__main__":
    # Create tensor train conv2d layer
    tt_conv = TensorTrainConv2d(
        in_channels=64,
        out_channels=128, 
        kernel_size=3,
        rank=16,  # Will be automatically reduced to 3 for 3x3 kernel
        padding=1
    )
    
    # Test forward pass
    x = torch.randn(4, 64, 32, 32)  # [batch, channels, height, width]
    output = tt_conv(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Kernel compression ratio: {tt_conv.get_compression_ratio():.2f}")
    print(f"Total compression ratio: {tt_conv.get_total_compression_ratio():.2f}")
    print(f"TT Conv parameters: {sum(p.numel() for p in tt_conv.parameters())}")
    
    # Compare with standard conv2d
    standard_conv = nn.Conv2d(64, 128, 3, padding=1)
    print(f"Standard Conv parameters: {sum(p.numel() for p in standard_conv.parameters())}")
    
    # Test different kernel sizes and ranks
    print("\nCompression ratios for different configurations:")
    for kernel_size in [1, 3, 5, 7]:
        for rank in [4, 8, 16, 32]:
            try:
                conv = TensorTrainConv2d(64, 128, kernel_size, rank)
                print(f"Kernel={kernel_size}x{kernel_size}, Rank={rank}: {conv.get_compression_ratio():.2f}x compression")
            except ValueError as e:
                print(f"Kernel={kernel_size}x{kernel_size}, Rank={rank}: {e}")

Input shape: torch.Size([4, 64, 32, 32])
Output shape: torch.Size([4, 128, 32, 32])
Kernel compression ratio: 74.02
Total compression ratio: 65.71
TT Conv parameters: 1124
Standard Conv parameters: 73856

Compression ratios for different configurations:
Kernel=1x1, Rank=4: 42.23x compression
Kernel=1x1, Rank=8: 42.23x compression
Kernel=1x1, Rank=16: 42.23x compression
Kernel=1x1, Rank=32: 42.23x compression
Kernel=3x3, Rank=4: 74.02x compression
Kernel=3x3, Rank=8: 74.02x compression
Kernel=3x3, Rank=16: 74.02x compression
Kernel=3x3, Rank=32: 74.02x compression
Kernel=5x5, Rank=4: 125.18x compression
Kernel=5x5, Rank=8: 85.69x compression
Kernel=5x5, Rank=16: 85.69x compression
Kernel=5x5, Rank=32: 85.69x compression
Kernel=7x7, Rank=4: 239.50x compression
Kernel=7x7, Rank=8: 90.73x compression
Kernel=7x7, Rank=16: 90.73x compression
Kernel=7x7, Rank=32: 90.73x compression
