In [175]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffsort import DiffSortNet
from torch.optim import Adam
from tqdm import tqdm
import functools, operator
import matplotlib.pyplot as plt
from bending_modules import BendingDiffSort

In [179]:
bendiffsort = BendingDiffSort(3, 24)
x = torch.rand(16, 3, 24, 24)
y = bendiffsort(x)

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [16, 576] but got: [16, 3].

In [13]:
class ToyDiffSort(nn.Module):
    def __init__(self, n_channels, input_size):
        super(ToyDiffSort, self).__init__()
        self.n_channels = n_channels
        
        self.feat_extractor = nn.Sequential(
            nn.Conv2d(self.n_channels, 32, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        # Trick from https://datascience.stackexchange.com/questions/40906/determining-size-of-fc-layer-after-conv-layer-in-pytorch
        num_feats_before_fcnn = functools.reduce(
            operator.mul,
            list(self.feat_extractor(
                torch.rand(1, n_channels, input_size, input_size)
            ).shape)
        )
        #print('Estimated:', num_feats_before_fcnn)
        
        self.fc1 = nn.Linear(num_feats_before_fcnn, 64)
        self.fc2 = nn.Linear(64, self.n_channels)
        self.output_sorter = DiffSortNet('bitonic', self.n_channels, steepness=5)
    
    def forward(self, x):
        batch_size = x.shape[0]  
        
        out = self.feat_extractor(x)
        out = out.view(batch_size, -1)
        #print('Actual:', out.shape[1])
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        out_sorted, sort_mat = self.output_sorter(out)
        
        return out_sorted
        

In [79]:
sorter = DiffSortNet('bitonic', 3, steepness=50)

In [105]:
x = torch.randn(4, 3, 2, 2)
y = torch.randn(4, 3)

In [106]:
x_ = x.reshape(4, -1, 3)

In [107]:
x_.shape

torch.Size([4, 4, 3])

In [99]:
x_sort, sort_mat = sorter(y)

In [100]:
sort_mat.shape

torch.Size([4, 3, 3])

In [125]:
x.permute(0, 2, 3, 1).shape

torch.Size([4, 2, 2, 3])

In [126]:
x_sorted = torch.bmm(x.permute(0, 2, 3, 1).reshape(4, -1, 3), sort_mat).reshape(0, 2, 3, 1)

RuntimeError: shape '[0, 2, 3, 1]' is invalid for input of size 48

In [128]:
sort_mat = torch.Tensor([[[0., 1.], [1., 0.]], [[1., 0.], [0., 1.]]])

In [131]:
sort_mat.shape

torch.Size([2, 2, 2])

In [130]:
sort_mat[0]

tensor([[0., 1.],
        [1., 0.]])

In [165]:
x = torch.Tensor([
    [[[0, 0, 0], [0, 0, 0]], [[1, 1, 1], [1, 1, 1]]], 
    [[[2, 2, 2], [2, 2, 2]], [[3, 3, 3], [3, 3, 3]]]
    ])

In [166]:
x.shape

torch.Size([2, 2, 2, 3])

In [170]:
y = torch.bmm(x.reshape(2,2,-1).permute(0, 2, 1), sort_mat).permute(0,2,1).reshape(2,2,2,3)

In [171]:
y.shape

torch.Size([2, 2, 2, 3])

In [172]:
x[0][0]

tensor([[0., 0., 0.],
        [0., 0., 0.]])

In [174]:
y[0][0]

tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [122]:
x[0][0]

tensor([[ 0.0042, -0.4893],
        [ 1.1642,  1.3981]])

In [118]:
x_sorted[0][1]

tensor([[ 1.3691,  0.2172],
        [ 0.1267, -0.7040]])

In [120]:
sort_mat[0]

tensor([[0.0053, 0.9749, 0.0198],
        [0.0041, 0.0199, 0.9760],
        [0.9906, 0.0052, 0.0042]])

In [114]:
x.reshape(4, -1, 3).reshape(4, 3, 2, 2)[0][0]

tensor([[ 0.0042, -0.4893],
        [ 1.1642,  1.3981]])

In [109]:
y.shape

torch.Size([4, 4, 3])

In [95]:
x

tensor([[ 1.0119,  1.2738, -1.0657],
        [-0.3608,  2.0084,  0.3825],
        [-0.0903, -1.5681,  1.4969],
        [ 0.8572,  0.5036,  0.0334]])

In [86]:
y.shape

torch.Size([4, 4, 3])

In [87]:
x.shape

torch.Size([4, 3])

In [78]:
for mat in sort_mat:
    print(f'{mat[0][0]:.4f}, {mat[0][1]:.4f}, {mat[0][2]:.4f}')

0.0017, 0.0019, 0.9964
0.0031, 0.9894, 0.0075
0.0013, 0.0069, 0.9918
0.0091, 0.9895, 0.0015


In [72]:
for mat in sort_mat:
    print(f'{mat[0][0]:.4f}, {mat[0][1]:.4f}, {mat[0][2]:.4f}')

0.9985, 0.0015, 0.0000
0.0016, 0.9817, 0.0167
0.9985, 0.0015, 0.0000
0.0013, 0.0035, 0.9952


In [None]:
def compute_loss(out_sorted):
    loss = out_sorted * torch.Tensor([1., 0., 0.])[None, ...]
    return 1./loss.sum()

In [25]:
toydiffsort = ToyDiffSort(3, 24)

In [29]:
toydiffsort.output_sorter

DiffSortNet()

In [15]:
x = torch.rand(10, 3, 24, 24)

In [16]:
y_sort = toydiffsort(x)

In [20]:
torch.argsort(y_sort)

tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2]])

In [23]:
x_perm = permute_tensor_batchwise(x, torch.argsort(y_sort), 1)

In [24]:
x_perm.shape

torch.Size([10, 3, 24, 24])

In [None]:
y_sort.shape

In [None]:
y_sort[:3, :]

In [None]:
torch.argsort(y_sort[:3, :])

In [None]:
xxx = torch.randn(16, 3)

In [None]:
A = torch.Tensor([[1, 2, 3], [4, 5, 6]])
order_tensor = torch.LongTensor([[0, 2, 1], [1, 2, 0]])

In [None]:
# Create row indices tensor
row_indices = torch.arange(A.shape[0]).unsqueeze(-1).expand_as(order_tensor)

In [None]:
# Use advanced indexing to permute the tensor
permuted_A_advanced = A[row_indices, order_tensor]

permuted_A_advanced

In [2]:
from utils import permute_tensor_batchwise

In [9]:
A = torch.randn(2, 3)

In [10]:
order_tensor = torch.LongTensor([[0, 2, 1],
                                 [2, 1, 0]])

In [11]:
A

tensor([[ 0.8444, -0.1195,  0.3843],
        [ 0.0372,  2.2617, -0.5762]])

In [12]:
permute_tensor_batchwise(A, order_tensor, 1)

tensor([[ 0.8444,  0.3843, -0.1195],
        [-0.5762,  2.2617,  0.0372]])

In [7]:
permute_tensor_batchwise(A, order_tensor, 2)[1]

tensor([[-1.0878,  1.7513],
        [-0.2414, -0.1225],
        [-1.9815,  0.4922]])

In [None]:
A.dim()

In [None]:
order_tensor.view(*([-1] + [1] * (A.dim() - 1)))

In [None]:
dim = 1

# Check if the shape of A along the given dim matches the shape of order_tensor
assert A.shape[dim] == order_tensor.size(0), "Mismatch in shapes of A and order_tensor along the specified dimension"

# Reshape order_tensor to be compatible with A
reshaped_order_tensor = order_tensor.view(*([-1] + [1] * (A.dim() - 1)))

# Move the permuted dimension to the correct position
reshaped_order_tensor = reshaped_order_tensor.permute([i for i in range(1, dim+1)] + [0] + [i for i in range(dim+1, A.dim())])

# Generate meshgrid of indices
indices = list(torch.meshgrid([torch.arange(s) for s in A.shape]))

# Adjust the shape of order_tensor to match the dimensionality of A
expanded_shape = [s if i == dim else A.shape[i] for i, s in enumerate(reshaped_order_tensor.shape)]
order_tensor_expanded = reshaped_order_tensor.expand(*expanded_shape)

indices[dim] = order_tensor_expanded

In [None]:
yyy.shape

In [None]:
torch.argsort(y_sort)

In [None]:
xxx[:3,:]

In [None]:
yyy[:3, :]

In [None]:
yyy.shape

In [None]:
perm_mat[0]

In [None]:
loss = compute_loss(y_sort)

In [None]:
loss

In [None]:
torch.cuda.empty_cache()

batch_size = 16

n_iter = 1000

opt = Adam(toydiffsort.parameters(), 1e-3)

loss_log = []

for i in tqdm(range(n_iter)):
    
    x = torch.randn(batch_size, 3, 24, 24)
    
    out_sorted = toydiffsort(x)
    
    loss = compute_loss(out_sorted)
    
    loss_log.append(loss.detach().cpu().numpy())
    
    with torch.no_grad():
        loss.backward()
        opt.step()
        opt.zero_grad()
        
plt.plot(range(n_iter), loss_log)

In [None]:
x = torch.randn(batch_size, 3, 24, 24)

In [None]:
toydiffsort(x)

#
#
#
#
#

In [None]:
zz = y_sort * torch.Tensor([10., 1., 0.1])[None, ...]

In [None]:
y_sort

In [None]:
zz.sum()

In [None]:
y[0, :]

In [None]:
output_sorter = DiffSortNet('bitonic', 3, steepness=5)

In [None]:
sorted_vectors, permutation_matrices = output_sorter(y)

In [None]:
permutation_matrices[0]

In [None]:
torch.argsort(sorted_vectors)

In [None]:
sorted_vectors[0, :]

In [None]:
vector_length = 2**4
vectors = torch.randperm(vector_length, dtype=torch.float32, device='cpu', requires_grad=True).view(1, -1)
vectors = vectors - 5.

# sort using a bitonic-sorting-network
sorter = DiffSortNet('bitonic', vector_length, steepness=5)
sorted_vectors, permutation_matrices = sorter(vectors)
print(sorted_vectors)

In [None]:
permutation_matrices[0, 0, :]

In [None]:
permutation_matrices.shape

In [None]:
device = 'cuda'

In [None]:
# Toy example

order = torch.range(0, 4, dtype=torch.float32, device='cuda').view(1, -1)

In [None]:
rp = torch.randperm(5)
xx = x[:, :, rp]

In [None]:
rp

In [None]:
order = torch.arange(0, 4)

In [None]:
order

In [None]:
xxx = x[:, :, order]

In [None]:
xxx[:, :, 0]

In [None]:
x[:, :, 0]

In [None]:
xx[:, :, 1]

In [None]:
x[:, :, 2]

In [None]:
order

In [None]:
x = torch.rand(32, 32, 5)

In [None]:
class TensorSorter(nn.Module):
    def __init__(self, vector_length):
        super(TensorSorter, self).__init__()
        self.vector_length = vector_length
        self.sortnet = DiffSortNet('bitonic', self.vector_length,
                                   steepness=5)
        
    def forward(self, x):
        init_order = torch.range(0, vector_length)
        new_order, _ = self.sortnet(init_order)

In [None]:
tensor_sorter = DiffSortNet('bitonic', 5, steepness=5, device=device)

torch.cuda.empty_cache()

batch_size = 4

n_iter = 100

opt = Adam(tensor_sorter.parameters(), 1e-3)

for i in tqdm(range(n_iter)):
    
    