In [2]:
import torch

### Exploration to be deleted

In [10]:
# Create a 4x4 image tensor
image = torch.arange(1, 17).reshape(1, 1, 4, 4).float()  # Shape: (N, C, H, W)
print("Original Image:")
print(image)

Original Image:
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])


In [11]:
# Unfold the tensor to extract 2x2 patches with a stride of 1
patches = image.unfold(2, 2, 1).unfold(3, 2, 1)
print("Unfolded Image Patches:")
print(patches)

Unfolded Image Patches:
tensor([[[[[[ 1.,  2.],
            [ 5.,  6.]],

           [[ 2.,  3.],
            [ 6.,  7.]],

           [[ 3.,  4.],
            [ 7.,  8.]]],


          [[[ 5.,  6.],
            [ 9., 10.]],

           [[ 6.,  7.],
            [10., 11.]],

           [[ 7.,  8.],
            [11., 12.]]],


          [[[ 9., 10.],
            [13., 14.]],

           [[10., 11.],
            [14., 15.]],

           [[11., 12.],
            [15., 16.]]]]]])


In [12]:
# Reshape the unfolded tensor to get patches as columns
patches = patches.contiguous().view(1, 1, -1, 2 * 2)  # Shape: (N, C, number_of_patches, patch_size)
patches = patches.view(-1, 2 * 2)  # Shape: (N * C * number_of_patches, patch_size)
print("Reshaped Image Patches (im2col format):")
print(patches)

Reshaped Image Patches (im2col format):
tensor([[ 1.,  2.,  5.,  6.],
        [ 2.,  3.,  6.,  7.],
        [ 3.,  4.,  7.,  8.],
        [ 5.,  6.,  9., 10.],
        [ 6.,  7., 10., 11.],
        [ 7.,  8., 11., 12.],
        [ 9., 10., 13., 14.],
        [10., 11., 14., 15.],
        [11., 12., 15., 16.]])


In [15]:
# Step 1: Create the Input Tensor
image = torch.arange(1, 17).reshape(1, 1, 4, 4).float()  # Shape: (N, C, H, W)
print("Original Image:")
print(image)

# Step 2: Use torch.unfold to extract 2x2 patches with a stride of 1
patches = image.unfold(2, 2, 1).unfold(3, 2, 1)
print("Unfolded Image Patches:")
print(patches)

# Step 3: Reshape the Unfolded Tensor to im2col format
patches = patches.contiguous().view(1, 1, -1, 2 * 2)  # Shape: (N, C, number_of_patches, patch_size)
patches = patches.view(-1, 2 * 2)  # Shape: (N * C * number_of_patches, patch_size)
print("Reshaped Image Patches (im2col format):")
print(patches)

# Create a 2x2 convolution kernel
kernel = torch.tensor([[1, 0], [0, -1]]).float().view(1, -1)  # Shape: (1, patch_size)

# Perform convolution by matrix multiplication
output = patches @ kernel.t()  # Shape: (number_of_patches, 1)
output = output.view(1, 1, 3, 3)  # Shape: (N, C, H', W') where H' and W' are the output dimensions

print("Convolution Output:")
print(output)

Original Image:
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
Unfolded Image Patches:
tensor([[[[[[ 1.,  2.],
            [ 5.,  6.]],

           [[ 2.,  3.],
            [ 6.,  7.]],

           [[ 3.,  4.],
            [ 7.,  8.]]],


          [[[ 5.,  6.],
            [ 9., 10.]],

           [[ 6.,  7.],
            [10., 11.]],

           [[ 7.,  8.],
            [11., 12.]]],


          [[[ 9., 10.],
            [13., 14.]],

           [[10., 11.],
            [14., 15.]],

           [[11., 12.],
            [15., 16.]]]]]])
Reshaped Image Patches (im2col format):
tensor([[ 1.,  2.,  5.,  6.],
        [ 2.,  3.,  6.,  7.],
        [ 3.,  4.,  7.,  8.],
        [ 5.,  6.,  9., 10.],
        [ 6.,  7., 10., 11.],
        [ 7.,  8., 11., 12.],
        [ 9., 10., 13., 14.],
        [10., 11., 14., 15.],
        [11., 12., 15., 16.]])
Convolution Output:
tensor([[[[-5., -5., -5.],
          [-5.

In [20]:
def im2col_convolution(image, kernel, stride):

    kernel_height = kernel.shape[0]
    kernel_width = kernel.shape[1]    

    # Unfold the image to extract patches with the given kernel size and stride
    # First we unfold by dimension 2 (height) then dimension 3 (width)
    patches = image.unfold(2, kernel_height, stride).unfold(3, kernel_width, stride)
    
    # Reshape the unfolded tensor to im2col format
    # Also ensure tensor is in contiguous in memory format
    patches = patches.contiguous().view(1, 1, -1, kernel_height * kernel_width)  # Shape: (N * C * number_of_patches, patch_size)

    # Reshape the kernel to match the patch size
    kernel = kernel.view(1, -1)  # Shape: (1, patch_size)

    # Perform convolution by matrix multiplication
    output = patches @ kernel.t()  # Shape: (number_of_patches, 1)

    # Calculate output dimensions
    output_height = (image.shape[2] - kernel_height) // stride + 1
    output_width = (image.shape[3] - kernel_width) // stride + 1
    output = output.view(1, 1, output_height, output_width)  # Shape: (N, C, H', W')

    return output

# Example usage
image = torch.arange(1, 17).reshape(1, 1, 4, 4).float()  # Shape: (N, C, H, W)
kernel = torch.tensor([[1, 0], [0, -1]]).float()  # Shape: (kernel_H, kernel_W)

kernel_size = 2
stride = 1

output = im2col_convolution(image, kernel, stride)
print("Convolution Output:")
print(output)

Convolution Output:
tensor([[[[-5., -5., -5.],
          [-5., -5., -5.],
          [-5., -5., -5.]]]])


### New Python Convolution -- no NaN

In [102]:

#CURRENT LIMITATIONS: EXPECTS 4D input, cannot deal with 3D input
def im2col_convolution(image, kernel, stride):
    N, C, H, W = image.shape  # Get the shape of the input image
    out_channels, in_channels, K_H, K_W = kernel.shape  # Get the shape of the kernel

    assert C == in_channels, "Number of input channels must match between image and kernel."

    # Unfold the image to extract patches with the given kernel size and stride
    patches = image.unfold(2, K_H, stride).unfold(3, K_W, stride)
    # print(patches)
    
    # Calculate the number of patches
    num_patches_h = patches.size(2)
    num_patches_w = patches.size(3)
    
    # Reshape the unfolded tensor to im2col format
    # Shape: (N, C, num_patches_h * num_patches_w, K_H * K_W)
    patches = patches.contiguous().view(N, C, num_patches_h * num_patches_w, K_H * K_W)
    # print(patches)

    # Reshape the kernel to match the patch size
    # Shape: (out_channels, in_channels, K_H * K_W)
    kernel = kernel.view(out_channels, in_channels * K_H * K_W)

    # Perform convolution by matrix multiplication
    # Shape of patches: (N, C, num_patches_h * num_patches_w, K_H * K_W)
    # Shape of kernel: (out_channels, in_channels * K_H * K_W)
    # Need to multiply patches with kernel properly
    # Reshape patches to (N * num_patches_h * num_patches_w, C * K_H * K_W)
    patches_reshaped = patches.permute(0, 2, 1, 3).reshape(-1, in_channels * K_H * K_W)
    kernel_reshaped = kernel.t()  # Shape: (in_channels * K_H * K_W, out_channels)

    # Matrix multiplication to get the output
    # Shape: (N * num_patches_h * num_patches_w, out_channels)
    output_reshaped = torch.mm(patches_reshaped, kernel_reshaped)

    # Reshape output to match the expected dimensions
    # output_height = (H - K_H) // stride + 1
    # output_width = (W - K_W) // stride + 1
    output = output_reshaped.view(N, num_patches_h, num_patches_w, out_channels)
    output = output.permute(0, 3, 1, 2)  # Shape: (N, out_channels, num_patches_h, num_patches_w)

    return output

# Example usage
image = torch.arange(1, 49).reshape(1, 3, 4, 4).float()  # Shape: (N, C, H, W)
kernel = torch.tensor([[[[1, 0], [0, -1]], [[1, 0], [0, -1]], [[1, 0], [0, -1]]],
                       [[[0, 1], [-1, 0]], [[0, 1], [-1, 0]], [[0, 1], [-1, 0]]]]).float()  # Shape: (out_channels, in_channels, K_H, K_W)

stride = 1

output = im2col_convolution(image, kernel, stride)
print("Convolution Output:")
print(output)

Convolution Output:
tensor([[[[-15., -15., -15.],
          [-15., -15., -15.],
          [-15., -15., -15.]],

         [[ -9.,  -9.,  -9.],
          [ -9.,  -9.,  -9.],
          [ -9.,  -9.,  -9.]]]])


In [186]:
print(image)
nan_counts = torch.sum(image.isnan(), dim=3)
    
# Find the rows that exceed the threshold
rows_to_extract = nan_counts/image.shape[3] > 0.5

# Extract the rows and their positions
extracted_rows = image[~rows_to_extract]
extracted_rows = extracted_rows.view(*image.shape[:2], (~rows_to_extract).sum(), image.shape[-1])
extracted_rows
positions = torch.nonzero(rows_to_extract) #.squeeze()
positions

extracted_rows, extracted_rows.shape
positions, rows_to_extract
# extract_rows_with_nans(image, 0.5)

tensor([[[[nan, nan, nan, nan],
          [nan, nan,  7.,  8.],
          [ 9., nan, 11., 12.],
          [nan, nan, 15., 16.]]]])


(tensor([[0, 0, 0]]), tensor([[[ True, False, False, False]]]))

In [120]:
from torch.masked import masked_tensor, as_masked_tensor


image[:, :, 0] = float('nan')
image[:, :, 1, :2] = float('nan')
image[:, :, 2, 2:] = float('nan')
image[:, :, 3, :] = float('nan')
print(image)


mt = masked_tensor(image, ~torch.isnan(image))
mt

tensor([[[[nan, nan, nan, nan],
          [nan, nan,  7.,  8.],
          [ 9., 10., nan, nan],
          [nan, nan, nan, nan]]]])


MaskedTensor(
  [
    [
      [
        [      --,       --,       --,       --],
        [      --,       --,   7.0000,   8.0000],
        [  9.0000,  10.0000,       --,       --],
        [      --,       --,       --,       --]
      ]
    ]
  ]
)

In [124]:
print(mt.shape)
sparse_coo_mt = mt[0,0].to_sparse_csr()
sparse_coo_mt.get_data()

torch.Size([1, 1, 4, 4])




tensor(crow_indices=tensor([0, 0, 2, 4, 4]),
       col_indices=tensor([2, 3, 0, 1]),
       values=tensor([ 7.,  8.,  9., 10.]), size=(4, 4), nnz=4,
       layout=torch.sparse_csr)

In [7]:
def extract_rows_with_nans(matrix, threshold):
    # Count the number of NaNs in each row
    nan_counts = torch.sum(matrix.isnan(), dim=3)
    
    # Find the rows that exceed the threshold
    rows_to_extract = nan_counts/matrix.shape[3] > threshold
    
    # Extract the rows and their positions
    # REPLACE NANS?
    extracted_rows = matrix[~rows_to_extract]
    extracted_rows = extracted_rows.view(*image.shape[:2], (~rows_to_extract).sum(), image.shape[-1])

    nan_positions = torch.nonzero(rows_to_extract).squeeze()
    valid_positions = torch.nonzero(~rows_to_extract).squeeze()

    return extracted_rows, nan_positions, valid_positions


image = torch.arange(1, 17).reshape(1, 1, 4, 4).float()  # Shape: (N, C, H, W)
image[:, :, 0] = float('nan')
image[:, :, 1, :2] = float('nan')
image[:, :, 2, 1] = float('nan')
image[:, :, 3, :2] = float('nan')

kernel = torch.tensor([[[[1, 0], [0, -1]], [[1, 0], [0, -1]], [[1, 0], [0, -1]]],
                       [[[0, 1], [-1, 0]], [[0, 1], [-1, 0]], [[0, 1], [-1, 0]]]]).float()  # Shape: (out_channels, in_channels, K_H, K_W)

# kernel = torch.tensor([[1, 0], [0, -1]]).float().view(1, 1, 1, -1)  # Shape: (1, patch_size)
kernel = torch.randn(1, 1, 2, 2)  # Shape: (out_channels, in_channels, kernel_height, kernel_width)
print(image)
stride = 1

# output = im2col_convolution(image, kernel, stride)
# print("Convolution Output:")
# print(output)


N, C, H, W = image.shape  # Get the shape of the input image
out_channels, in_channels, K_H, K_W = kernel.shape  # Get the shape of the kernel

assert C == in_channels, "Number of input channels must match between image and kernel."

# Unfold the image to extract patches with the given kernel size and stride
patches = image.unfold(2, K_H, stride).unfold(3, K_W, stride)
# print(patches)

# Calculate the number of patches
num_patches_h = patches.size(2)
num_patches_w = patches.size(3)

# Reshape the unfolded tensor to im2col format
# Shape: (N, C, num_patches_h * num_patches_w, K_H * K_W)
patches = patches.contiguous().view(N, C, num_patches_h * num_patches_w, K_H * K_W)
print(patches)

patches, nan_positions, valid_positions = extract_rows_with_nans(patches, 0.5)

print(patches)

# Reshape the kernel to match the patch size
# Shape: (out_channels, in_channels, K_H * K_W)
kernel = kernel.view(out_channels, in_channels * K_H * K_W)

# Perform convolution by matrix multiplication
# Shape of patches: (N, C, num_patches_h * num_patches_w, K_H * K_W)
# Shape of kernel: (out_channels, in_channels * K_H * K_W)
# Need to multiply patches with kernel properly
# Reshape patches to (N * num_patches_h * num_patches_w, C * K_H * K_W)
patches_reshaped = patches.permute(0, 2, 1, 3) #.reshape(-1, in_channels * K_H * K_W)
print(patches_reshaped.shape)
kernel_reshaped = kernel.t()  # Shape: (in_channels * K_H * K_W, out_channels)
print(kernel_reshaped.shape)

# Matrix multiplication to get the output
# Shape: (N * num_patches_h * num_patches_w, out_channels)
output_reshaped = torch.matmul(patches_reshaped, kernel_reshaped)

print(output_reshaped.shape)
# print(positions)

# Reshape output to match the expected dimensions
# output_height = (H - K_H) // stride + 1
# output_width = (W - K_W) // stride + 1
output = torch.zeros(N, num_patches_h, num_patches_w, out_channels)
print(output.shape)
print(valid_positions)
output[valid_positions] = output_reshaped
output[nan_positions, :] = float('nan')

# output = output_reshaped.view(N, num_patches_h, num_patches_w, out_channels)
output = output.permute(0, 3, 1, 2)  # Shape: (N, out_channels, num_patches_h, num_patches_w)




tensor([[[[nan, nan, nan, nan],
          [nan, nan,  7.,  8.],
          [ 9., nan, 11., 12.],
          [nan, nan, 15., 16.]]]])
tensor([[[[nan, nan, nan, nan],
          [nan, nan, nan,  7.],
          [nan, nan,  7.,  8.],
          [nan, nan,  9., nan],
          [nan,  7., nan, 11.],
          [ 7.,  8., 11., 12.],
          [ 9., nan, nan, nan],
          [nan, 11., nan, 15.],
          [11., 12., 15., 16.]]]])
tensor([[[[nan, nan,  7.,  8.],
          [nan,  7., nan, 11.],
          [ 7.,  8., 11., 12.],
          [nan, 11., nan, 15.],
          [11., 12., 15., 16.]]]])
torch.Size([1, 5, 1, 4])
torch.Size([4, 1])
torch.Size([1, 5, 1, 1])
torch.Size([1, 3, 3, 1])
tensor([[0, 0, 2],
        [0, 0, 4],
        [0, 0, 5],
        [0, 0, 7],
        [0, 0, 8]])


RuntimeError: shape mismatch: value tensor of shape [5, 1, 1] cannot be broadcast to indexing result of shape [5, 3, 3, 3, 1]

In [189]:
# Define a 4D tensor
matrix = torch.tensor([
    [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
    [[[13, 14], [15, 16]], [[17, 18], [19, 20]], [[21, 22], [23, 24]]],
    [[[25, 26], [27, 28]], [[29, 30], [31, 32]], [[33, 34], [35, 36]]],
    [[[37, 38], [39, 40]], [[41, 42], [43, 44]], [[45, 46], [47, 48]]]
], dtype=torch.float32)
print(matrix.shape)

# Define a condition to select certain rows
threshold = 40
condition = torch.sum(matrix[:, 0, :, :], dim=(1, 2)) > threshold

# Get the indices of the rows that meet the condition
indices = torch.nonzero(condition).squeeze()

# Modify the selected rows (for demonstration purposes, we'll just multiply them by 10)
modified_rows = matrix[indices] * 10
print(modified_rows.shape)

# Create a new tensor to hold the modified data
new_matrix = matrix.clone()

# Reinsert the modified rows back into their original positions
new_matrix[indices] = modified_rows

print("Modified Matrix:")
print(new_matrix)
new_matrix.shape

torch.Size([4, 3, 2, 2])
torch.Size([3, 3, 2, 2])
Modified Matrix:
tensor([[[[  1.,   2.],
          [  3.,   4.]],

         [[  5.,   6.],
          [  7.,   8.]],

         [[  9.,  10.],
          [ 11.,  12.]]],


        [[[130., 140.],
          [150., 160.]],

         [[170., 180.],
          [190., 200.]],

         [[210., 220.],
          [230., 240.]]],


        [[[250., 260.],
          [270., 280.]],

         [[290., 300.],
          [310., 320.]],

         [[330., 340.],
          [350., 360.]]],


        [[[370., 380.],
          [390., 400.]],

         [[410., 420.],
          [430., 440.]],

         [[450., 460.],
          [470., 480.]]]])


torch.Size([4, 3, 2, 2])

### Expansion to automatically adjust to 3D or 4D input

In [69]:

def im2col_convolution(image, kernel, stride):
    if image.dim() == 3:  # Handle 3D image case (single-channel, no batch)
        image = image.unsqueeze(0) #.unsqueeze(0)  # Add batch and channel dimensions
        kernel = kernel.unsqueeze(0)  # Add batch dimension to kernel

    if kernel.dim() == 3:  # Handle 3D kernel case (single channel, no batch)
        kernel = kernel.unsqueeze(0)  # Add batch dimension to kernel

    print(image.shape)

    N, C, H, W = image.shape  # Get the shape of the input image
    out_channels, in_channels, K_H, K_W = kernel.shape  # Get the shape of the kernel

    assert C == in_channels, "Number of input channels must match between image and kernel."

    # Unfold the image to extract patches with the given kernel size and stride
    patches = image.unfold(2, K_H, stride).unfold(3, K_W, stride)
    
    # Calculate the number of patches
    num_patches_h = patches.size(2)
    num_patches_w = patches.size(3)
    
    # Reshape the unfolded tensor to im2col format
    patches = patches.contiguous().view(N, C, num_patches_h * num_patches_w, K_H * K_W)

    # Reshape the kernel to match the patch size
    kernel = kernel.view(out_channels, in_channels * K_H * K_W)

    # Perform convolution by matrix multiplication
    patches_reshaped = patches.permute(0, 2, 1, 3).reshape(-1, in_channels * K_H * K_W)
    kernel_reshaped = kernel.t()  # Shape: (in_channels * K_H * K_W, out_channels)

    # Matrix multiplication to get the output
    output_reshaped = torch.mm(patches_reshaped, kernel_reshaped)

    # Reshape output to match the expected dimensions
    output_height = (H - K_H) // stride + 1
    output_width = (W - K_W) // stride + 1
    output = output_reshaped.view(N, num_patches_h, num_patches_w, out_channels)
    output = output.permute(0, 3, 1, 2)  # Shape: (N, out_channels, num_patches_h, num_patches_w)

    return output

# Example usage
image_4d = torch.arange(1, 49).reshape(1, 3, 4, 4).float()  # Shape: (N, C, H, W)
kernel_4d = torch.tensor([[[[1, 0], [0, -1]], [[1, 0], [0, -1]], [[1, 0], [0, -1]]],
                          [[[0, 1], [-1, 0]], [[0, 1], [-1, 0]], [[0, 1], [-1, 0]]]]).float()  # Shape: (out_channels, in_channels, K_H, K_W)

stride = 1

output = im2col_convolution(image_4d, kernel_4d, stride)
print("Convolution Output:")
print(output)

# Example for 3D kernel and image
image_3d = torch.arange(1, 17).reshape(1, 4, 4).float()  # Shape: (N, C, H, W)
kernel_3d = torch.tensor([[[[1, 0], [0, -1]]]]).float()  # Shape: (out_channels, in_channels, K_H, K_W)

output_3d = im2col_convolution(image_3d, kernel_3d, stride)
print("Convolution Output for 3D kernel and image:")
print(output_3d)

torch.Size([1, 3, 4, 4])
Convolution Output:
tensor([[[[-15., -15., -15.],
          [-15., -15., -15.],
          [-15., -15., -15.]],

         [[ -9.,  -9.,  -9.],
          [ -9.,  -9.,  -9.],
          [ -9.,  -9.,  -9.]]]])
torch.Size([1, 1, 4, 4])


ValueError: too many values to unpack (expected 4)

In [64]:
torch.nn.functional.conv2d(input=image_3d, weight=kernel_3d)

tensor([[[[-5., -5., -5.],
          [-5., -5., -5.],
          [-5., -5., -5.]]]])

In [65]:
image_3d.shape

torch.Size([1, 1, 4, 4])