In [67]:
%load_ext autoreload
%autoreload 2
from typing import List

from copy import copy
from itertools import combinations, permutations
import pickle
from random import choice, random

import pandas as pd
import numpy as np

import sys
sys.path.append('../src')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import torch
import torch.nn as nn
import torch.optim as optim

In [32]:
# need a convex function
def sort_loss(input: List[int]) -> float:
    loss = 0
    for i in range(len(input)):
        for j in range(i, len(input)):
            if input[i] > input[j]:
                loss += 1
    return loss


def get_swaps_possible(input: List[int]) -> List[List[int]]:
    res = []
    for i in range(len(input)-1):
        input_copy = copy(input)
        input_copy[i], input_copy[i+1] = input_copy[i+1], input_copy[i]
        res.append(input_copy)
    
    return res

In [29]:
sort_loss([1,5,4,3,2])

6

In [54]:
# generate simple data
SOURCE = [x for x in range(7)]
DATASET_SIZE = 10
SEQ_LEN = 4

DATA = []

SAMPLES = list(permutations(SOURCE, SEQ_LEN))
_picked = set()
for _ in range(10):
    v = choice(SAMPLES)
    if v not in _picked:
        _picked.add(v)

    swap_chosen = None
    for swap in get_swaps_possible(list(v)):
        if swap_chosen is None:
            swap_chosen = swap
        if sort_loss(swap) < sort_loss(swap_chosen):
            swap_chosen = swap
    
    DATA.append((list(v), swap_chosen))

In [55]:
DATA

[([1, 2, 0, 3], [1, 0, 2, 3]),
 ([0, 5, 2, 3], [0, 2, 5, 3]),
 ([6, 3, 1, 4], [3, 6, 1, 4]),
 ([6, 1, 5, 3], [1, 6, 5, 3]),
 ([0, 3, 4, 5], [3, 0, 4, 5]),
 ([2, 1, 6, 3], [1, 2, 6, 3]),
 ([6, 1, 4, 2], [1, 6, 4, 2]),
 ([5, 0, 1, 6], [0, 5, 1, 6]),
 ([4, 6, 2, 3], [4, 2, 6, 3]),
 ([2, 1, 6, 0], [1, 2, 6, 0])]

In [71]:
# sanity check
for i, row in enumerate(DATA):
    if i > 9:
        break
    print('------------------\n')
    print(row)
    print('------------------\n')
    for v in get_swaps_possible(row[0]):
        print(v, sort_loss(v))
    print('-------------------\n')

------------------

([1, 2, 0, 3], [1, 0, 2, 3])
------------------

[2, 1, 0, 3] 3
[1, 0, 2, 3] 1
[1, 2, 3, 0] 3
-------------------

------------------

([0, 5, 2, 3], [0, 2, 5, 3])
------------------

[5, 0, 2, 3] 3
[0, 2, 5, 3] 1
[0, 5, 3, 2] 3
-------------------

------------------

([6, 3, 1, 4], [3, 6, 1, 4])
------------------

[3, 6, 1, 4] 3
[6, 1, 3, 4] 3
[6, 3, 4, 1] 5
-------------------

------------------

([6, 1, 5, 3], [1, 6, 5, 3])
------------------

[1, 6, 5, 3] 3
[6, 5, 1, 3] 5
[6, 1, 3, 5] 3
-------------------

------------------

([0, 3, 4, 5], [3, 0, 4, 5])
------------------

[3, 0, 4, 5] 1
[0, 4, 3, 5] 1
[0, 3, 5, 4] 1
-------------------

------------------

([2, 1, 6, 3], [1, 2, 6, 3])
------------------

[1, 2, 6, 3] 1
[2, 6, 1, 3] 3
[2, 1, 3, 6] 1
-------------------

------------------

([6, 1, 4, 2], [1, 6, 4, 2])
------------------

[1, 6, 4, 2] 3
[6, 4, 1, 2] 5
[6, 1, 2, 4] 3
-------------------

------------------

([5, 0, 1, 6], [0, 5, 1, 6])
------

## Split data

In [58]:
SPLIT_RATE = 0.8
train_data = []
test_data = []
for row in DATA:
    if random() < SPLIT_RATE:
        train_data.append(row)
    else:
        test_data.append(row)

## Dump data

In [62]:
with open('../data/train.pkl', 'wb') as f:
    pickle.dump(train_data, f)

with open('../data/test.pkl', 'wb') as f:
    pickle.dump(test_data, f)

## Load data

In [61]:
with open('../data/train.pkl', 'rb') as f:
    train_data = pickle.load(f)

with open('../data/test.pkl', 'rb') as f:
    test_data = pickle.load(f)

[([2, 1, 6, 0], [1, 2, 6, 0])]

## Define a simple sorter NN

In [26]:
class SorterNN(nn.Module):
    
    def __init__(self, input_size, hidden_size, output_size):
        super(SorterNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [64]:
input_train_data = torch.tensor([x[0] for x in train_data], dtype=torch.float32)
output_train_data = torch.tensor([x[1] for x in train_data], dtype=torch.float32)

input_test_data = torch.tensor([x[0] for x in test_data], dtype=torch.float32)
output_test_data = torch.tensor([x[1] for x in test_data], dtype=torch.float32)    

In [None]:
input_data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float32)
output_data = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90], [100, 110, 120]], dtype=torch.float32)

In [None]:
input_size = input_data.shape[1]
output_size = output_data.shape[1]
hidden_size = 64
model = SorterNN(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [12]:
num_epochs = 5000
for epoch in range(num_epochs):
    outputs = model(input_data)
    loss = criterion(outputs, output_data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % 1000 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [100/5000], Loss: 2582.2021
Epoch [200/5000], Loss: 325.3589
Epoch [300/5000], Loss: 21.1943
Epoch [400/5000], Loss: 15.2081
Epoch [500/5000], Loss: 15.0118
Epoch [600/5000], Loss: 14.8628
Epoch [700/5000], Loss: 14.7003
Epoch [800/5000], Loss: 14.5249
Epoch [900/5000], Loss: 14.3375
Epoch [1000/5000], Loss: 14.1387
Epoch [1100/5000], Loss: 13.9292
Epoch [1200/5000], Loss: 13.7093
Epoch [1300/5000], Loss: 13.4795
Epoch [1400/5000], Loss: 13.2398
Epoch [1500/5000], Loss: 12.9903
Epoch [1600/5000], Loss: 12.7311
Epoch [1700/5000], Loss: 12.4618
Epoch [1800/5000], Loss: 12.1821
Epoch [1900/5000], Loss: 11.8914
Epoch [2000/5000], Loss: 11.5892
Epoch [2100/5000], Loss: 11.2744
Epoch [2200/5000], Loss: 10.9460
Epoch [2300/5000], Loss: 10.6028
Epoch [2400/5000], Loss: 10.2435
Epoch [2500/5000], Loss: 9.8666
Epoch [2600/5000], Loss: 9.4705
Epoch [2700/5000], Loss: 9.0538
Epoch [2800/5000], Loss: 8.6149
Epoch [2900/5000], Loss: 8.1528
Epoch [3000/5000], Loss: 7.6667
Epoch [3100/5000], Los

In [None]:
test_input = torch.tensor([[10, 20, 30]], dtype=torch.float32)
predicted_output = model(test_input)
print("Predicted Output:", predicted_output.detach().numpy())