# Initial NdLinear Exploration

In [None]:
# Import the NdLinear implementation
from NdLinear.ndlinear import NdLinear
import torch
import torch.nn as nn


def count_parameters(module):
    """Return the total number of scalar parameters held by *module*."""
    return sum(p.numel() for p in module.parameters())


# Create a simple example showing how NdLinear preserves structure
input_tensor = torch.randn(32, 3, 224, 224)  # Batch x Channels x Height x Width
# Per-sample dimensions (channels, height, width), derived from the tensor so the
# sizes below never drift out of sync with the input.
sample_dims = tuple(input_tensor.shape[1:])

# Standard approach flattens the tensor.
# flatten(start_dim=1) keeps whatever batch size the tensor has (no hard-coded 32)
# and, unlike view(), also works on non-contiguous tensors.
flattened = input_tensor.flatten(start_dim=1)  # Shape: [32, 150528]
# NOTE: this layer holds 150528 * 1000 weights + 1000 biases (~150.5M parameters,
# roughly 600 MB in float32), which is the cost that flattening incurs.
linear = nn.Linear(flattened.shape[1], 1000)
linear_output = linear(flattened)  # Shape: [32, 1000]

# NdLinear preserves the structure
# NOTE(review): the recorded TypeError below reports a missing 'hidden_size',
# but this call passes it explicitly, so that output likely comes from an
# earlier version of the cell. Re-run to confirm.
ndlinear = NdLinear(input_dims=sample_dims, hidden_size=(64, 16, 16))
ndlinear_output = ndlinear(input_tensor)  # Shape: [32, 64, 16, 16]

# Compare parameter counts
print(f"Linear parameters: {count_parameters(linear)}")
print(f"NdLinear parameters: {count_parameters(ndlinear)}")

TypeError: NdLinear.__init__() missing 1 required positional argument: 'hidden_size'

In [None]:
# Use the same import path as the first cell: the bare `ndlinear` module is not
# importable in this environment (see the ModuleNotFoundError output below).
from NdLinear.ndlinear import NdLinear
import torch
import torch.nn as nn


def count_parameters(module):
    """Return the total number of scalar parameters held by *module*."""
    return sum(p.numel() for p in module.parameters())


# Create sample data with multidimensional structure (like an image)
batch_size = 2
channels = 3
height = 32
width = 32

# Create a random tensor with image-like structure
x = torch.randn(batch_size, channels, height, width)
print(f"Input shape: {x.shape}")

# Create a standard linear layer (requires flattening)
flat_features = channels * height * width
linear_layer = nn.Linear(flat_features, 10)

# Apply standard linear layer (requires flattening first).
# flatten(start_dim=1) collapses every non-batch axis and works on non-contiguous input.
x_flat = x.flatten(start_dim=1)
linear_output = linear_layer(x_flat)
print(f"Linear layer output shape: {linear_output.shape}")
print(f"Linear layer parameters: {count_parameters(linear_layer)}")

# Create an NdLinear layer preserving structure.
# Keyword names match the constructor call in the first cell (input_dims / hidden_size);
# the previous `dims` / `output_dims` keywords do not match that signature.
ndlinear_layer = NdLinear(
    input_dims=(channels, height, width),  # Input dimensions
    hidden_size=(10, 1, 1),  # Output dimensions (10 classes, collapsed spatial dims)
)

# Apply NdLinear directly to structured input.
# NOTE(review): by analogy with the first cell, the output presumably keeps one
# axis per input axis, i.e. (batch, 10, 1, 1) rather than (batch, 10). Use
# .flatten(start_dim=1) before comparing with linear_output. TODO confirm.
ndlinear_output = ndlinear_layer(x)
print(f"NdLinear layer output shape: {ndlinear_output.shape}")
print(f"NdLinear layer parameters: {count_parameters(ndlinear_layer)}")



ModuleNotFoundError: No module named 'ndlinear'