In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffsort import DiffSortNet
from torch.optim import Adam
from tqdm import tqdm
import functools, operator
import matplotlib.pyplot as plt
from bending_modules import BendingDiffSort, BendingDiffSort_XY
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
B = 2
data_tensor = torch.arange(1, 25).view(B, 3, 2, 2).float()

perm_matrices = torch.tensor([
    [[0, 1], [1, 0]],
    [[1, 0], [0, 1]]
], dtype=torch.float32)

In [9]:
x = torch.arange(1, 97).view(2, 3, 4, 4).float()
randperm = torch.randperm(x.shape[2])

In [18]:
x[0, 0, :, :]

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])

In [17]:
x[0, 0, randperm, :]

tensor([[13., 14., 15., 16.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [ 1.,  2.,  3.,  4.]])

In [5]:
data_tensor.shape, perm_matrices.shape

(torch.Size([2, 3, 2, 2]), torch.Size([2, 2, 2]))

In [8]:
data_tensor_reshaped = data_tensor.view(-1, 2, 2)
perm_matrices_expanded = perm_matrices.unsqueeze(1).repeat(1, 3, 1, 1).view(-1, 2, 2)

In [17]:
data_tensor_reshaped.shape

torch.Size([6, 2, 2])

In [13]:
permuted_data_tensor = torch.bmm(perm_matrices_expanded,
                                 data_tensor_reshaped)
permuted_data_tensor = permuted_data_tensor.view(B, 3, 2, 2)

In [15]:
permuted_data_tensor[0]

tensor([[[ 3.,  4.],
         [ 1.,  2.]],

        [[ 7.,  8.],
         [ 5.,  6.]],

        [[11., 12.],
         [ 9., 10.]]])

In [16]:
data_tensor[0]

tensor([[[ 1.,  2.],
         [ 3.,  4.]],

        [[ 5.,  6.],
         [ 7.,  8.]],

        [[ 9., 10.],
         [11., 12.]]])

In [18]:
# Step 1: Transpose the last two dimensions of the data tensor to switch the rows and columns
data_tensor_transposed = data_tensor.permute(0, 1, 3, 2)

# Step 2 & 3: Reshape the data tensor and adjust the permutation matrices tensor for batch matrix multiplication
data_tensor_transposed_reshaped = data_tensor_transposed.view(-1, 2, 2)

# Step 4: Use torch.bmm to perform batch matrix multiplication on the transposed data tensor to permute the columns
permuted_data_tensor_transposed = torch.bmm(perm_matrices_expanded, data_tensor_transposed_reshaped)

# Step 5 & 6: Transpose back and Reshape the result back to the original shape
permuted_data_tensor_cols = permuted_data_tensor_transposed.view(B, 3, 2, 2).permute(0, 1, 3, 2)

# Display the original and permuted tensors for both items in the batch
data_tensor, permuted_data_tensor_cols


(tensor([[[[ 1.,  2.],
           [ 3.,  4.]],
 
          [[ 5.,  6.],
           [ 7.,  8.]],
 
          [[ 9., 10.],
           [11., 12.]]],
 
 
         [[[13., 14.],
           [15., 16.]],
 
          [[17., 18.],
           [19., 20.]],
 
          [[21., 22.],
           [23., 24.]]]]),
 tensor([[[[ 2.,  1.],
           [ 4.,  3.]],
 
          [[ 6.,  5.],
           [ 8.,  7.]],
 
          [[10.,  9.],
           [12., 11.]]],
 
 
         [[[13., 14.],
           [15., 16.]],
 
          [[17., 18.],
           [19., 20.]],
 
          [[21., 22.],
           [23., 24.]]]]))

In [10]:
perm_matrices_expanded.shape

torch.Size([6, 2, 2])

In [7]:
data_tensor_reshaped.shape

torch.Size([6, 2, 2])

In [86]:
B = 1
C = 2
t = torch.arange(1, 33).view(C, 4, 4).unsqueeze(0).repeat(B, 1, 1, 1)

out_rows = t.permute(0, 2, 1, 3).reshape(B, 4, -1).permute(0, 2, 1)
out_cols = t.permute(0, 3, 1, 2).reshape(B, 4, -1).permute(0, 2, 1)

In [87]:
out_rows[0, :, 0]

tensor([ 1,  2,  3,  4, 17, 18, 19, 20])

In [79]:
out_rows[0, :, 0]

tensor([ 1,  2,  3,  4, 17, 18, 19, 20])

In [81]:
t[0, :, 0, :]

tensor([[ 1,  2,  3,  4],
        [17, 18, 19, 20]])

In [82]:
out_cols[0, :, 0]

tensor([ 1,  5,  9, 13, 17, 21, 25, 29])

In [83]:
t[0, :, :, 0]

tensor([[ 1,  5,  9, 13],
        [17, 21, 25, 29]])

In [73]:
out_cols.shape

torch.Size([1, 4, 8])

In [69]:
t[0, :, 0, :] # first row, all channels and columns

tensor([[ 1,  2,  3,  4],
        [17, 18, 19, 20]])

In [67]:
out_rows[0][0]

tensor([ 1, 17,  2, 18,  3, 19,  4, 20])

In [60]:
nrows = 4
ncols = 5
rect = torch.randn(1, 2, nrows, ncols)

In [61]:
rect[0, :, 0, :]

tensor([[ 1.7924, -0.3046, -1.0712, -0.1341, -2.9213],
        [ 1.4698,  0.0924, -0.3523,  0.0364,  0.7477]])

In [62]:
rect_rows = rect.permute(0,2,1,3).reshape(1, -1, nrows)
rect_cols = rect.permute(0,3,1,2).reshape(1, -1, ncols)

In [63]:
rect_rows[0, :, 0]

tensor([ 1.7924, -2.9213,  0.0364,  0.3977,  0.6793,  0.2705,  0.2828, -0.4925,
         0.6413, -0.8219])

In [56]:
rect.reshape(1, -1, ncols).shape

torch.Size([1, 128, 5])

In [58]:
rect.reshape(1, nrows, -1).permute(0,2,1).shape

torch.Size([1, 160, 4])

In [59]:
rect[0][0]

tensor([[ 0.3375,  1.8205,  1.1124,  0.8353,  0.0926],
        [ 1.1759,  1.9132,  1.9754, -1.4028, -0.8067],
        [ 0.3202, -1.1928,  0.0972,  1.3331,  1.1699],
        [-0.4395, -1.1938, -0.1809,  1.3625, -0.3301]])

In [2]:
benddiff = BendingDiffSort_XY(64, 4)

In [3]:
x = torch.randn(1, 64, 4, 4)

In [9]:
x.permute(0,2,1,3).shape

torch.Size([1, 4, 64, 4])

In [11]:
x.reshape(1, -1, 4).shape

torch.Size([1, 256, 4])

In [8]:
x[0,1,:,:]

tensor([[ 0.9746,  0.2272,  0.0774, -1.6365],
        [ 0.3266, -0.3427,  0.3522, -0.7015],
        [ 0.1381,  0.7946,  1.0612, -0.4956],
        [ 0.0207,  0.6857,  2.0444, -0.3357]])

In [5]:
y = benddiff(x)

In [7]:
y[0,1,:,:]

tensor([[ 0.2746,  0.2235,  0.0247,  0.6085],
        [-0.4800, -0.5519, -0.3721, -0.6134],
        [ 0.0942, -0.6721, -0.2880, -0.5985],
        [ 0.7746,  0.1724,  0.3683,  1.1201]], grad_fn=<SliceBackward0>)

In [12]:
feat_extractor = nn.Sequential(
            nn.Conv2d(64, 32, 
                      1, padding='same'),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 
                      1, padding='same'),
            nn.ReLU(inplace=True)
            )

In [13]:
out = feat_extractor(x)

In [22]:
out.shape

torch.Size([1, 32, 4, 4])

In [35]:
out_1 = out.permute(0, 2, 1, 3).reshape(1, -1, 4).reshape(1, -1, 4, 4)
out_2 = out.reshape(1, -1, 4).reshape(1, -1, 4, 4)
out_3 = out.permute(0, 3, 1, 2).reshape(1, -1, 4).reshape(1, -1, 4, 4).permute(0, 2, 3, 1)

In [32]:
out_3.shape

torch.Size([1, 4, 4, 32])

In [18]:
out_1[0,0,:,:]

tensor([[0.0065, 0.0000, 0.4086, 0.4803],
        [0.2278, 0.1691, 0.2756, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0007],
        [0.0702, 0.2661, 0.2555, 0.0000]], grad_fn=<SliceBackward0>)

In [37]:
out_2[0,0,:,:]

tensor([[0.0065, 0.0000, 0.4086, 0.4803],
        [0.3354, 0.1345, 0.2500, 0.0000],
        [0.2185, 0.0000, 0.0000, 0.0558],
        [0.0000, 0.1893, 0.7107, 0.5130]], grad_fn=<SliceBackward0>)

In [36]:
out[0,0,:,:]

tensor([[0.0065, 0.0000, 0.4086, 0.4803],
        [0.3354, 0.1345, 0.2500, 0.0000],
        [0.2185, 0.0000, 0.0000, 0.0558],
        [0.0000, 0.1893, 0.7107, 0.5130]], grad_fn=<SliceBackward0>)

In [45]:
out_rows = out.permute(0, 2, 1, 3).reshape(1, -1, 4)
out_cols = out.permute(0, 3, 1, 2).reshape(1, -1, 4)

In [46]:
out_rows.shape

torch.Size([1, 128, 4])

In [23]:
x.shape

torch.Size([1, 64, 4, 4])

In [13]:
# Rows only
x_rows = x.permute(0, 2, 1, 3).reshape(1, -1, 4)
x_cols = x.permute(0, 3, 1, 2).reshape(1, -1, 4)

In [43]:
elementwise = nn.Conv1d(128, 1, 1)

In [39]:
sorter = DiffSortNet('bitonic', 4, steepness=50)

In [47]:
y = elementwise(out_rows).flatten(1)

In [48]:
y.shape

torch.Size([1, 4])

In [49]:
y_sorted, sort_mat = sorter(y)

In [51]:
sort_mat.transpose(1,2)

tensor([[[0.0398, 0.0407, 0.1750, 0.7445],
         [0.0237, 0.2369, 0.5567, 0.1827],
         [0.1187, 0.6069, 0.2420, 0.0324],
         [0.8178, 0.1155, 0.0263, 0.0404]]], grad_fn=<TransposeBackward0>)

In [40]:
x.permute(0, 2, 1, 3).reshape(1, -1, 4).shap4

torch.Size([1, 256, 4])

In [43]:
x.shape

torch.Size([1, 64, 4, 4])

In [41]:
sort_mat 

tensor([[[0.6562, 0.2141, 0.0594, 0.0703],
         [0.2058, 0.5314, 0.1547, 0.1081],
         [0.0774, 0.0776, 0.3771, 0.4679],
         [0.0605, 0.1770, 0.4088, 0.3537]]], grad_fn=<AddBackward0>)

In [47]:
yyy = torch.bmm(x.permute(0, 2, 1, 3).reshape(1, -1, 4),
          sort_mat).reshape(1, -1, 4, 4)

In [48]:
yyy.shape

torch.Size([1, 64, 4, 4])

In [37]:
x.shape

torch.Size([1, 64, 4, 4])

In [7]:
x_cols.shape

torch.Size([1, 4, 256])

In [8]:
bendiffsort = BendingDiffSort_XY(3, 2)
x = torch.arange(1, 13).reshape(1, 3, 2, 2).float()
y = bendiffsort(x)

In [9]:
y.shape

torch.Size([1, 3, 2, 2])

In [10]:
x[0][0]

tensor([[1., 2.],
        [3., 4.]])

In [11]:
y[0][0]

tensor([[1.8245, 1.1755],
        [3.8245, 3.1755]], grad_fn=<SelectBackward0>)

In [13]:
class ToyDiffSort(nn.Module):
    def __init__(self, n_channels, input_size):
        super(ToyDiffSort, self).__init__()
        self.n_channels = n_channels
        
        self.feat_extractor = nn.Sequential(
            nn.Conv2d(self.n_channels, 32, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        # Trick from https://datascience.stackexchange.com/questions/40906/determining-size-of-fc-layer-after-conv-layer-in-pytorch
        num_feats_before_fcnn = functools.reduce(
            operator.mul,
            list(self.feat_extractor(
                torch.rand(1, n_channels, input_size, input_size)
            ).shape)
        )
        #print('Estimated:', num_feats_before_fcnn)
        
        self.fc1 = nn.Linear(num_feats_before_fcnn, 64)
        self.fc2 = nn.Linear(64, self.n_channels)
        self.output_sorter = DiffSortNet('bitonic', self.n_channels, steepness=5)
    
    def forward(self, x):
        batch_size = x.shape[0]  
        
        out = self.feat_extractor(x)
        out = out.view(batch_size, -1)
        #print('Actual:', out.shape[1])
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        out_sorted, sort_mat = self.output_sorter(out)
        
        return out_sorted
        

In [79]:
sorter = DiffSortNet('bitonic', 3, steepness=50)

In [105]:
x = torch.randn(4, 3, 2, 2)
y = torch.randn(4, 3)

In [106]:
x_ = x.reshape(4, -1, 3)

In [107]:
x_.shape

torch.Size([4, 4, 3])

In [99]:
x_sort, sort_mat = sorter(y)

In [100]:
sort_mat.shape

torch.Size([4, 3, 3])

In [125]:
x.permute(0, 2, 3, 1).shape

torch.Size([4, 2, 2, 3])

In [126]:
x_sorted = torch.bmm(x.permute(0, 2, 3, 1).reshape(4, -1, 3), sort_mat).reshape(0, 2, 3, 1)

RuntimeError: shape '[0, 2, 3, 1]' is invalid for input of size 48

In [128]:
sort_mat = torch.Tensor([[[0., 1.], [1., 0.]], [[1., 0.], [0., 1.]]])

In [131]:
sort_mat.shape

torch.Size([2, 2, 2])

In [130]:
sort_mat[0]

tensor([[0., 1.],
        [1., 0.]])

In [165]:
x = torch.Tensor([
    [[[0, 0, 0], [0, 0, 0]], [[1, 1, 1], [1, 1, 1]]], 
    [[[2, 2, 2], [2, 2, 2]], [[3, 3, 3], [3, 3, 3]]]
    ])

In [166]:
x.shape

torch.Size([2, 2, 2, 3])

In [170]:
y = torch.bmm(x.reshape(2,2,-1).permute(0, 2, 1), sort_mat).permute(0,2,1).reshape(2,2,2,3)

In [171]:
y.shape

torch.Size([2, 2, 2, 3])

In [172]:
x[0][0]

tensor([[0., 0., 0.],
        [0., 0., 0.]])

In [174]:
y[0][0]

tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [122]:
x[0][0]

tensor([[ 0.0042, -0.4893],
        [ 1.1642,  1.3981]])

In [118]:
x_sorted[0][1]

tensor([[ 1.3691,  0.2172],
        [ 0.1267, -0.7040]])

In [120]:
sort_mat[0]

tensor([[0.0053, 0.9749, 0.0198],
        [0.0041, 0.0199, 0.9760],
        [0.9906, 0.0052, 0.0042]])

In [114]:
x.reshape(4, -1, 3).reshape(4, 3, 2, 2)[0][0]

tensor([[ 0.0042, -0.4893],
        [ 1.1642,  1.3981]])

In [109]:
y.shape

torch.Size([4, 4, 3])

In [95]:
x

tensor([[ 1.0119,  1.2738, -1.0657],
        [-0.3608,  2.0084,  0.3825],
        [-0.0903, -1.5681,  1.4969],
        [ 0.8572,  0.5036,  0.0334]])

In [86]:
y.shape

torch.Size([4, 4, 3])

In [87]:
x.shape

torch.Size([4, 3])

In [78]:
for mat in sort_mat:
    print(f'{mat[0][0]:.4f}, {mat[0][1]:.4f}, {mat[0][2]:.4f}')

0.0017, 0.0019, 0.9964
0.0031, 0.9894, 0.0075
0.0013, 0.0069, 0.9918
0.0091, 0.9895, 0.0015


In [72]:
for mat in sort_mat:
    print(f'{mat[0][0]:.4f}, {mat[0][1]:.4f}, {mat[0][2]:.4f}')

0.9985, 0.0015, 0.0000
0.0016, 0.9817, 0.0167
0.9985, 0.0015, 0.0000
0.0013, 0.0035, 0.9952


In [None]:
def compute_loss(out_sorted):
    loss = out_sorted * torch.Tensor([1., 0., 0.])[None, ...]
    return 1./loss.sum()

In [25]:
toydiffsort = ToyDiffSort(3, 24)

In [29]:
toydiffsort.output_sorter

DiffSortNet()

In [15]:
x = torch.rand(10, 3, 24, 24)

In [16]:
y_sort = toydiffsort(x)

In [20]:
torch.argsort(y_sort)

tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2]])

In [23]:
x_perm = permute_tensor_batchwise(x, torch.argsort(y_sort), 1)

In [24]:
x_perm.shape

torch.Size([10, 3, 24, 24])

In [None]:
y_sort.shape

In [None]:
y_sort[:3, :]

In [None]:
torch.argsort(y_sort[:3, :])

In [None]:
xxx = torch.randn(16, 3)

In [None]:
A = torch.Tensor([[1, 2, 3], [4, 5, 6]])
order_tensor = torch.LongTensor([[0, 2, 1], [1, 2, 0]])

In [None]:
# Create row indices tensor
row_indices = torch.arange(A.shape[0]).unsqueeze(-1).expand_as(order_tensor)

In [None]:
# Use advanced indexing to permute the tensor
permuted_A_advanced = A[row_indices, order_tensor]

permuted_A_advanced

In [2]:
from utils import permute_tensor_batchwise

In [9]:
A = torch.randn(2, 3)

In [10]:
order_tensor = torch.LongTensor([[0, 2, 1],
                                 [2, 1, 0]])

In [11]:
A

tensor([[ 0.8444, -0.1195,  0.3843],
        [ 0.0372,  2.2617, -0.5762]])

In [12]:
permute_tensor_batchwise(A, order_tensor, 1)

tensor([[ 0.8444,  0.3843, -0.1195],
        [-0.5762,  2.2617,  0.0372]])

In [7]:
permute_tensor_batchwise(A, order_tensor, 2)[1]

tensor([[-1.0878,  1.7513],
        [-0.2414, -0.1225],
        [-1.9815,  0.4922]])

In [None]:
A.dim()

In [None]:
order_tensor.view(*([-1] + [1] * (A.dim() - 1)))

In [None]:
dim = 1

# Check if the shape of A along the given dim matches the shape of order_tensor
assert A.shape[dim] == order_tensor.size(0), "Mismatch in shapes of A and order_tensor along the specified dimension"

# Reshape order_tensor to be compatible with A
reshaped_order_tensor = order_tensor.view(*([-1] + [1] * (A.dim() - 1)))

# Move the permuted dimension to the correct position
reshaped_order_tensor = reshaped_order_tensor.permute([i for i in range(1, dim+1)] + [0] + [i for i in range(dim+1, A.dim())])

# Generate meshgrid of indices
indices = list(torch.meshgrid([torch.arange(s) for s in A.shape]))

# Adjust the shape of order_tensor to match the dimensionality of A
expanded_shape = [s if i == dim else A.shape[i] for i, s in enumerate(reshaped_order_tensor.shape)]
order_tensor_expanded = reshaped_order_tensor.expand(*expanded_shape)

indices[dim] = order_tensor_expanded

In [None]:
yyy.shape

In [None]:
torch.argsort(y_sort)

In [None]:
xxx[:3,:]

In [None]:
yyy[:3, :]

In [None]:
yyy.shape

In [None]:
perm_mat[0]

In [None]:
loss = compute_loss(y_sort)

In [None]:
loss

In [None]:
torch.cuda.empty_cache()

batch_size = 16

n_iter = 1000

opt = Adam(toydiffsort.parameters(), 1e-3)

loss_log = []

for i in tqdm(range(n_iter)):
    
    x = torch.randn(batch_size, 3, 24, 24)
    
    out_sorted = toydiffsort(x)
    
    loss = compute_loss(out_sorted)
    
    loss_log.append(loss.detach().cpu().numpy())
    
    with torch.no_grad():
        loss.backward()
        opt.step()
        opt.zero_grad()
        
plt.plot(range(n_iter), loss_log)

In [None]:
x = torch.randn(batch_size, 3, 24, 24)

In [None]:
toydiffsort(x)

#
#
#
#
#

In [None]:
zz = y_sort * torch.Tensor([10., 1., 0.1])[None, ...]

In [None]:
y_sort

In [None]:
zz.sum()

In [None]:
y[0, :]

In [None]:
output_sorter = DiffSortNet('bitonic', 3, steepness=5)

In [None]:
sorted_vectors, permutation_matrices = output_sorter(y)

In [None]:
permutation_matrices[0]

In [None]:
torch.argsort(sorted_vectors)

In [None]:
sorted_vectors[0, :]

In [None]:
vector_length = 2**4
vectors = torch.randperm(vector_length, dtype=torch.float32, device='cpu', requires_grad=True).view(1, -1)
vectors = vectors - 5.

# sort using a bitonic-sorting-network
sorter = DiffSortNet('bitonic', vector_length, steepness=5)
sorted_vectors, permutation_matrices = sorter(vectors)
print(sorted_vectors)

In [None]:
permutation_matrices[0, 0, :]

In [None]:
permutation_matrices.shape

In [None]:
device = 'cuda'

In [None]:
# Toy example

order = torch.range(0, 4, dtype=torch.float32, device='cuda').view(1, -1)

In [None]:
rp = torch.randperm(5)
xx = x[:, :, rp]

In [None]:
rp

In [None]:
order = torch.arange(0, 4)

In [None]:
order

In [None]:
xxx = x[:, :, order]

In [None]:
xxx[:, :, 0]

In [None]:
x[:, :, 0]

In [None]:
xx[:, :, 1]

In [None]:
x[:, :, 2]

In [None]:
order

In [None]:
x = torch.rand(32, 32, 5)

In [None]:
class TensorSorter(nn.Module):
    def __init__(self, vector_length):
        super(TensorSorter, self).__init__()
        self.vector_length = vector_length
        self.sortnet = DiffSortNet('bitonic', self.vector_length,
                                   steepness=5)
        
    def forward(self, x):
        init_order = torch.range(0, vector_length)
        new_order, _ = self.sortnet(init_order)

In [None]:
tensor_sorter = DiffSortNet('bitonic', 5, steepness=5, device=device)

torch.cuda.empty_cache()

batch_size = 4

n_iter = 100

opt = Adam(tensor_sorter.parameters(), 1e-3)

for i in tqdm(range(n_iter)):
    
    