In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffsort import DiffSortNet
from torch.optim import Adam
from tqdm import tqdm

In [32]:
import functools, operator

class NFeatsEstimate(nn.Module):
    def __init__(self, n_channels, input_size):
        super(NFeatsEstimate, self).__init__()
        self.n_channels = n_channels
        
        self.feat_extractor = nn.Sequential(
            nn.Conv2d(self.n_channels, 32, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        # Trick from https://datascience.stackexchange.com/questions/40906/determining-size-of-fc-layer-after-conv-layer-in-pytorch
        num_feats_before_fcnn = functools.reduce(
            operator.mul,
            list(self.feat_extractor(
                torch.rand(1, n_channels, input_size, input_size)
            ).shape)
        )
        print('Estimated:', num_feats_before_fcnn)
        
        self.fc1 = nn.Linear(num_feats_before_fcnn, 64)
        self.fc2 = nn.Linear(64, self.n_channels)
    
    def forward(self, x):
        batch_size = x.shape[0]  
        
        out = self.feat_extractor(x)
        out = out.view(batch_size, -1)
        print('Actual:', out.shape[1])
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        
        return out
        

In [33]:
nfeatsmodel = NFeatsEstimate(3, 24)

Estimated: 576


In [34]:
x = torch.rand(10, 3, 24, 24)

In [35]:
y = nfeatsmodel(x)

Actual: 576


In [37]:
y[0, :]

tensor([ 0.0854,  0.0291, -0.0963], grad_fn=<SliceBackward0>)

In [38]:
output_sorter = DiffSortNet('bitonic', 3, steepness=5)

In [39]:
sorted_vectors, permutation_matrices = output_sorter(y)

In [40]:
sorted_vectors[0, :]

tensor([-0.0130,  0.0111,  0.0201], grad_fn=<SliceBackward0>)

In [12]:
vector_length = 2**4
vectors = torch.randperm(vector_length, dtype=torch.float32, device='cpu', requires_grad=True).view(1, -1)
vectors = vectors - 5.

# sort using a bitonic-sorting-network
sorter = DiffSortNet('bitonic', vector_length, steepness=5)
sorted_vectors, permutation_matrices = sorter(vectors)
print(sorted_vectors)

tensor([[-4.3655e+00, -3.6170e+00, -2.4922e+00, -1.7446e+00, -7.4628e-01,
          1.1393e-03,  1.1265e+00,  1.8729e+00,  3.0004e+00,  3.8730e+00,
          4.9990e+00,  5.6191e+00,  6.7452e+00,  7.8702e+00,  8.4936e+00,
          9.3647e+00]], grad_fn=<AddBackward0>)


In [42]:
permutation_matrices[0, 0, :]

tensor([0.0089, 0.0096, 0.0013, 0.0018, 0.0364, 0.0276, 0.0594, 0.7393, 0.0595,
        0.0294, 0.0031, 0.0138, 0.0011, 0.0008, 0.0011, 0.0068],
       grad_fn=<SliceBackward0>)

In [38]:
permutation_matrices.shape

torch.Size([1, 16, 16])

In [10]:
device = 'cuda'

In [16]:
# Toy example

order = torch.range(0, 4, dtype=torch.float32, device='cuda').view(1, -1)

  order = torch.range(0, 4, dtype=torch.float32, device='cuda').view(1, -1)


In [25]:
rp = torch.randperm(5)
xx = x[:, :, rp]

In [26]:
rp

tensor([4, 2, 0, 3, 1])

In [32]:
order = torch.arange(0, 4)

In [33]:
order

tensor([0, 1, 2, 3])

In [34]:
xxx = x[:, :, order]

In [35]:
xxx[:, :, 0]

tensor([[0.5484, 0.2939, 0.5618,  ..., 0.1041, 0.7093, 0.3379],
        [0.9113, 0.9405, 0.8850,  ..., 0.3878, 0.8800, 0.4208],
        [0.4252, 0.1380, 0.9859,  ..., 0.9901, 0.6075, 0.0245],
        ...,
        [0.8270, 0.5652, 0.9940,  ..., 0.7776, 0.4276, 0.6980],
        [0.5070, 0.3908, 0.7123,  ..., 0.8939, 0.8492, 0.2130],
        [0.1029, 0.6063, 0.6649,  ..., 0.4715, 0.1156, 0.1094]])

In [36]:
x[:, :, 0]

tensor([[0.5484, 0.2939, 0.5618,  ..., 0.1041, 0.7093, 0.3379],
        [0.9113, 0.9405, 0.8850,  ..., 0.3878, 0.8800, 0.4208],
        [0.4252, 0.1380, 0.9859,  ..., 0.9901, 0.6075, 0.0245],
        ...,
        [0.8270, 0.5652, 0.9940,  ..., 0.7776, 0.4276, 0.6980],
        [0.5070, 0.3908, 0.7123,  ..., 0.8939, 0.8492, 0.2130],
        [0.1029, 0.6063, 0.6649,  ..., 0.4715, 0.1156, 0.1094]])

In [29]:
xx[:, :, 1]

tensor([[0.1903, 0.7706, 0.3053,  ..., 0.3698, 0.1177, 0.8657],
        [0.2454, 0.4232, 0.0497,  ..., 0.6887, 0.4966, 0.5243],
        [0.1692, 0.4444, 0.0849,  ..., 0.1065, 0.0825, 0.8605],
        ...,
        [0.8798, 0.2669, 0.0934,  ..., 0.9787, 0.4542, 0.4507],
        [0.0577, 0.9786, 0.0999,  ..., 0.8592, 0.8941, 0.0569],
        [0.8539, 0.4721, 0.9536,  ..., 0.4351, 0.8505, 0.2961]])

In [30]:
x[:, :, 2]

tensor([[0.1903, 0.7706, 0.3053,  ..., 0.3698, 0.1177, 0.8657],
        [0.2454, 0.4232, 0.0497,  ..., 0.6887, 0.4966, 0.5243],
        [0.1692, 0.4444, 0.0849,  ..., 0.1065, 0.0825, 0.8605],
        ...,
        [0.8798, 0.2669, 0.0934,  ..., 0.9787, 0.4542, 0.4507],
        [0.0577, 0.9786, 0.0999,  ..., 0.8592, 0.8941, 0.0569],
        [0.8539, 0.4721, 0.9536,  ..., 0.4351, 0.8505, 0.2961]])

In [14]:
order

tensor([[0., 1., 2., 3., 4.]], device='cuda:0')

In [21]:
x = torch.rand(32, 32, 5)

In [None]:
class TensorSorter(nn.Module):
    def __init__(self, vector_length):
        super(TensorSorter, self).__init__()
        self.vector_length = vector_length
        self.sortnet = DiffSortNet('bitonic', self.vector_length,
                                   steepness=5)
        
    def forward(self, x):
        init_order = torch.range(0, vector_length)
        new_order, _ = self.sortnet(init_order)

In [None]:
tensor_sorter = DiffSortNet('bitonic', 5, steepness=5, device=device)

torch.cuda.empty_cache()

batch_size = 4

n_iter = 100

opt = Adam(tensor_sorter.parameters(), 1e-3)

for i in tqdm(range(n_iter)):
    
    